import torch

"""Loader modelli locali per rendere il package portabile.

Le funzioni caricano checkpoint salvati con chiave 'model_state_dict'.
"""

from .mini_unet import MiniUNet
from .miniunet3d import MiniUNet3D
from .mini_unet_128 import MiniUNet128


def load_model_2d(model_path: str, device: str = 'cuda'):
    """
    Carica modello 2D (MiniUNet) dal checkpoint.

    Parameters
    ----------
    model_path : str
        Percorso al checkpoint del modello.
    device : str
        "cuda" | "cpu"; se None auto-detect.
    Returns
    -------
    torch.nn.Module
        Modello MiniUNet caricato e pronto per inferenza.
    """
    print(f"[load_model_2d] Loading model from: {model_path}")
    print(f"[load_model_2d] CUDA available: {torch.cuda.is_available()}")
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"[load_model_2d] Using device: {device}")
    print(f"[load_model_2d] Creating model...")
    model = MiniUNet().to(device)
    print(f"[load_model_2d] Loading checkpoint...")
    checkpoint = torch.load(model_path, map_location=device)
    print(f"[load_model_2d] Loading state dict...")
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"[load_model_2d] Setting eval mode...")
    model.eval()
    print(f"[load_model_2d] ✓ Model loaded successfully")
    return model


def load_model_3d(model_path: str, device: str = 'cuda', dropout_rate: float = 0.5):
    """ 
    Carica modello 3D (MiniUNet3D) dal checkpoint.

    Parameters
    ----------
    model_path : str
        Percorso al checkpoint del modello.
    device : str
        "cuda" | "cpu"; se None auto-detect.
    dropout_rate : float
        Tasso di dropout

    Returns
    -------
    torch.nn.Module
        Modello MiniUNet3D caricato e pronto per inferenza.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    model = MiniUNet3D(dropout_rate=dropout_rate).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def load_model_manhole_2d(
    model_path: str,
    device: str = 'cuda',
    *,
    base_ch: int = 32,
    depth: int = 3,
    dropout_rate: float = 0.5,
):
    """Carica modello 2D per tombini (MiniUNet128) dal checkpoint.

    Legge iperparametri dal checkpoint (ckpt['args']) se presenti, in modo da
    combaciare con la configurazione di training (come in swath_processing_turbo).

    Parameters
    ----------
    model_path : str
        Percorso al checkpoint del modello.
    device : str
        "cuda" | "cpu"; se None auto-detect.
    base_ch : int
        Numero di canali base.
    depth : int
        Profondità del modello.
    dropout_rate : float
        Tasso di dropout.

    Returns
    -------
    torch.nn.Module
        Modello MiniUNet128 caricato e pronto per inferenza.
    """
    dev = torch.device(device if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(model_path, map_location=dev)
    args = checkpoint.get('args', {}) or {}
    base_ch = int(args.get('base_ch', base_ch))
    depth = int(args.get('depth', depth))
    dropout_rate = float(args.get('dropout', dropout_rate))

    model = MiniUNet128(in_channels=1, out_channels=1, base_ch=base_ch, depth=depth, dropout_rate=dropout_rate).to(dev)
    state = checkpoint.get('model_state_dict') or checkpoint.get('model_state') or checkpoint
    model.load_state_dict(state)
    model.eval()
    return model


