from __future__ import annotations

"""
Processor: API principale per eseguire la pipeline su una singola swath già filtrata.

Flusso di alto livello:
1) Input: volume numpy (x, y, z) con filtri già applicati
2) Resample dell'asse Y a 128 livelli con `quadratic_resample_rows`
3) Inferenza:
   - mode="2d": esegue trasversale + longitudinale e fonde con media
   - mode="3d": esegue inferenza 3D con sliding window su (y,z)
4) Post-processing opzionale: blur gaussiano GPU/CPU, vesselness
5) Output: maschera riportata alla Y originale con interpolazione vettorizzata

Esempio d'uso minimo:
    >>> import numpy as np
    >>> from network_core import process_swath_array
    >>> vol = np.random.normal(0, 0.1, (19, 96, 64)).astype(np.float32)
    >>> mask = process_swath_array(
    ...     vol,
    ...     mode="2d",
    ...     model_trasversale_path=r"C:\\modelli\\best_tr.pth",
    ...     model_longitudinale_path=r"C:\\modelli\\best_lon.pth",
    ...     stride=4,
    ...     batch_size=1024,
    ... )
    >>> mask.shape
    (19, 96, 64)

Per la modalità 3D:
    >>> mask3d = process_swath_array(
    ...     vol,
    ...     mode="3d",
    ...     model_3d_path=r"C:\\modelli\\best_3d.pth",
    ...     stride_y_3d=16,
    ...     stride_z_3d=16,
    ...     batch_size_3d=32,
    ... )

Dipendenze runtime minime: numpy, torch; opzionali: scipy (interpolazione/blur CPU), cupy (blur GPU), scikit-image (vesselness)
"""

import logging
from typing import Iterable

import numpy as np
from scipy.ndimage import maximum_filter
# Importiamo torch e scipy solo quando servono per evitare inizializzazioni precoci di OpenMP
import importlib

# Import locali 
from .utils import quadratic_resample_rows, gaussian_blur_gpu, apply_vessel_filters, show_slices_up, show_slices_front
from .infer import infer_transverse_fast, infer_longitudinal_fast, infer_3d_fast, infer_manhole_longitudinal_128
from .models import load_model_2d, load_model_3d, load_model_manhole_2d
from .interpolation import interpolate_to_target_frames


log = logging.getLogger(__name__)


def _auto_device(device: str | None) -> str:
    """
    Rileva automaticamente il dispositivo CUDA disponibile se non specificato.
    
    Parameters
    ----------
    device : str | None
        "cuda" | "cpu"; se None auto-detect.

    Returns
    -------
    str
        "cuda" | "cpu".
    """
    if device:
        return device
    # Lazy import torch qui
    torch = importlib.import_module('torch')
    return "cuda" if torch.cuda.is_available() else "cpu"


def _resample_y_to_128(volume: np.ndarray) -> np.ndarray:
    """Resample lungo Y a 128 con interpolazione quadratica per ogni slice Z.

    Parameters
    ----------
    volume : np.ndarray
        Volume input (x, y, z).

    Returns
    -------
    np.ndarray
        Volume output (x, 128, z).
    """
    x, y, z = volume.shape
    if y == 128:
        return volume.astype(np.float32, copy=False)
    out = np.empty((x, 128, z), dtype=np.float32)
    for k in range(z):
        # quadratic_resample_rows lavora sulle righe; trasponiamo (y, x)
        yz = volume[:, :, k].T  # (y, x)
        yz128 = quadratic_resample_rows(yz, 128)  # (128, x)
        out[:, :, k] = yz128.T  # (x, 128)
    return out


def _maybe_vessel(mask: np.ndarray, vessel_filters: Iterable[str] | None) -> np.ndarray:
    """Applica filtri vesselness se specificati.
    
    Parameters
    ----------
    mask : np.ndarray
        Maschera input (x, y, z).
    vessel_filters : Iterable[str] | None
        Filtri vesselness opzionali ["frangi", "sato", "meijering"].
        
    Returns
    -------
    np.ndarray
        Maschera output (x, y, z).
    """
    if vessel_filters:
        return apply_vessel_filters(mask, list(vessel_filters))
    return mask


def process_swath_array(
    volume_np: np.ndarray,
    *,
    mode: str = "2d",  # "2d" | "3d"
    # Modelli pre-caricati (opzionali)
    model_trasversale: torch.nn.Module | None = None,
    model_longitudinale: torch.nn.Module | None = None,
    model_3d: torch.nn.Module | None = None,
    model_manhole: torch.nn.Module | None = None,
    # Oppure percorsi checkpoint
    model_trasversale_path: str | None = None,
    model_longitudinale_path: str | None = None,
    model_3d_path: str | None = None,
    model_manhole_path: str | None = None,
    # Parametri inferenza
    device: str | None = None,
    stride: int = 4,
    batch_size: int = 1024,
    stride_manhole_z: int = 8,
    batch_size_manhole: int = 128,
    stride_y_3d: int = 16,
    stride_z_3d: int = 16,
    batch_size_3d: int = 32,
    dropout_rate_3d: float = 0.5,
    manhole_threshold: float = 0.95,
    manhole_connect_iterations: int = 1,
    manhole_min_height_y: int = 90,
    manhole_structure: tuple[int, int, int] = (3, 3, 3),
    # Post-processing
    apply_gaussian_blur: bool = False,
    apply_max_filter: bool = False,
    vessel_filters: Iterable[str] | None = None,
    # Normalizzazione input per-patch
    normalize_input: bool = False,
    # Interpolazione
    interpolate: bool = True,
    # Debug
    debug_manhole: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[dict[str, int]]]:
    """Esegue la pipeline di inferenza su una singola swath.

    Flusso:
    1) Ricampionamento lungo Y a 128 livelli (vol → vol128)
    2) Inferenza tubi (2D trasversale+longitudinale oppure 3D) → mask128
    3) Opzionale: inferenza tombini 128×128 su piani (y,z) prima dell'interpolazione
       - Threshold > manhole_threshold, chiusura 3D conservativa (manhole_connect_iterations iterazioni), labeling
       - Filtro blob con altezza in Y ≥ manhole_min_height_y
       - Calcolo centroidi e costruzione lista candidati nel formato
         [{"x": cx, "y": cy, "z": cz}, ...]
    4) Interpolazione della maschera tubi alla Y originale (se richiesto)

    Parameters
    ----------
    volume_np : np.ndarray
        Volume input (x, y, z) già filtrato.
    mode : str
        "2d" | "3d" per scegliere il tipo di inferenza tubi.
    model_trasversale, model_longitudinale, model_3d : torch.nn.Module | None
        Modelli già caricati; in alternativa fornire i percorsi .pth.
    model_*_path : str | None
        Percorsi ai checkpoint, usati se non si passano i modelli già caricati.
    device : str | None
        "cuda" | "cpu"; se None viene auto-determinato.
    stride, batch_size : int
        Parametri per inferenza 2D.
    stride_y_3d, stride_z_3d, batch_size_3d : int
        Parametri per inferenza 3D.
    stride_manhole_z, batch_size_manhole : int
        Parametri per inferenza tombini.
    manhole_threshold, manhole_connect_iterations, manhole_min_height_y, manhole_structure : float, int, int, tuple[int, int, int]
        Parametri per inferenza tombini.
    dropout_rate_3d : float
        Dropout del modello 3D (deve combaciare col training).
    apply_gaussian_blur : bool
        Se True applica blur gaussiano (GPU/CPU) alla maschera tubi.
    vessel_filters : Iterable[str] | None
        Filtri vesselness opzionali (frangi, sato, meijering) sulla maschera tubi.
    normalize_input : bool
        Se True applica normalizzazione stile training PNG alle patch.
    interpolate : bool
        Se True riporta la maschera tubi alla Y originale.
    debug_manhole : bool
        Se True visualizza volumi e maschera tombini prima del post-processing.

    Returns
    -------
    Union[np.ndarray, Tuple[np.ndarray, List[Dict[str, int]]]]
        - Solo maschera tubi (x, y_orig, z) se non è presente il modello tombini.
        - Coppia (mask_tubi, candidati) se il modello tombini è disponibile. I candidati sono una lista di dict con
          chiavi: "x", "y", "z" (coordinate del centroide nel sistema di vol128).
    """
    if not isinstance(volume_np, np.ndarray) or volume_np.ndim != 3:
        raise ValueError("'volume_np' deve essere un np.ndarray 3D (x,y,z)")

    volume_np = volume_np.astype(np.float32, copy=False)
    x, y_orig, z = volume_np.shape
    device = _auto_device(device)
    print(f"device: {device}")
    log.info("Process swath array: shape=%s, mode=%s, device=%s", volume_np.shape, mode, device)

    # 1) Resample Y→128
    vol128 = _resample_y_to_128(volume_np)

    # 2) Inferenza
    if mode.lower() == "2d":
        if model_trasversale is None:
            if not model_trasversale_path:
                raise ValueError("Serve 'model_trasversale' oppure 'model_trasversale_path'")
            model_trasversale = load_model_2d(model_trasversale_path).to(device).eval()
        if model_longitudinale is None:
            if not model_longitudinale_path:
                raise ValueError("Serve 'model_longitudinale' oppure 'model_longitudinale_path'")
            model_longitudinale = load_model_2d(model_longitudinale_path).to(device).eval()

        mask_tr = infer_transverse_fast(
            vol128,
            model_trasversale,
            device=device,
            stride=stride,
            batch_size=batch_size,
            normalize_input=normalize_input,
        )
        mask_lon = infer_longitudinal_fast(
            vol128,
            model_longitudinale,
            device=device,
            stride=stride,
            batch_size=max(1, batch_size // 2),
            normalize_input=normalize_input,
        )

        if apply_gaussian_blur:
            mask_tr = gaussian_blur_gpu(mask_tr, sigma=(1.0, 1.0, 1.0), mask_zeros=False)
            mask_lon = gaussian_blur_gpu(mask_lon, sigma=(1.0, 1.0, 1.0), mask_zeros=False)

        mask128 = np.maximum(mask_tr, mask_lon)
        
        if apply_max_filter:
            # Calcola il massimo locale sui valori assoluti del volume con una finestra 3x3x3
            max_volume = maximum_filter(np.abs(vol128), size=(4, 4, 4))
            max_volume= gaussian_blur_gpu(max_volume, sigma=(1.0, 1.0, 1.0), mask_zeros=False)
            max_volume=np.clip(max_volume, 0, 2) / 2 #clip e normalizzazione
            mask128=mask128*max_volume*max_volume 
        
        mask128 = _maybe_vessel(mask128, vessel_filters)

    elif mode.lower() == "3d":
        if model_3d is None:
            if not model_3d_path:
                raise ValueError("Serve 'model_3d' oppure 'model_3d_path'")
            model_3d = load_model_3d(model_3d_path, device=device, dropout_rate=dropout_rate_3d)

        mask128 = infer_3d_fast(
            vol128, model_3d, device=device, stride_y=stride_y_3d, stride_z=stride_z_3d, batch_size=batch_size_3d
        )
        if apply_gaussian_blur:
            mask128 = gaussian_blur_gpu(mask128, sigma=(1.0, 1.0, 1.0), mask_zeros=False)
        mask128 = _maybe_vessel(mask128, vessel_filters)

    else:
        raise ValueError("'mode' deve essere '2d' oppure '3d'")

    # ---- Inferenza TOMBINI (indipendente) su vol128 prima di qualunque interpolazione ----
    manhole_mask128: np.ndarray | None = None
    manhole_candidates_list: list[dict[str, int]] | None = None
    try:
        # Carica modello se fornito
        if model_manhole is None and model_manhole_path:
            model_manhole = load_model_manhole_2d(model_manhole_path, device=device)
        if model_manhole is not None:
            manhole_mask128 = infer_manhole_longitudinal_128(
                vol128,
                model_manhole,
                device=device or "cpu",
                stride_z=stride_manhole_z,
                batch_size=batch_size_manhole,
                normalize_input=normalize_input,
            ).astype(np.float32, copy=False)

            # Debug: show volume di riferimento e inferenza tombini
            if debug_manhole:
                try:
                    log.info("Visualizzazione debug manhole")
                    show_slices_up(vol128, manhole_mask128)
                    show_slices_front(vol128, manhole_mask128)
                except Exception:
                    log.debug("Visualizzazione debug manhole non disponibile in questo contesto")

            # Post: threshold e componenti connessi con tolleranza salti ≤ 3 voxel
            thr = manhole_threshold
            bin_mask = (manhole_mask128 > thr).astype(np.uint8)

            # Connessione con distanza di collegamento max 3: chiusura conservativa
            # Implementazione CPU con struttura 3D e 1 iterazione di dilatazione+erosione
            try:
                import scipy.ndimage as ndi
                # chiusura per unire piccoli buchi
                structure = np.ones(manhole_structure, dtype=np.uint8)
                # 3 iterazioni di dilatazione per connettere gap ≤3, poi erosione simmetrica
                dilated = bin_mask.copy()
                dilated = ndi.binary_dilation(dilated, structure=structure, iterations=manhole_connect_iterations)
                closed = ndi.binary_erosion(dilated, structure=structure, iterations=manhole_connect_iterations)

                labeled, num = ndi.label(closed, structure=structure)

                # Filtra per altezza lungo Y ≥ 90 e calcola centroidi
                manhole_candidates_list = []
                for label_id in range(1, num + 1):
                    coords = np.argwhere(labeled == label_id)
                    if coords.size == 0:
                        continue
                    # coords: (N, 3) con ordini (x, y, z)
                    y_min = int(coords[:, 1].min())
                    y_max = int(coords[:, 1].max())
                    height_y = y_max - y_min + 1
                    if height_y >= manhole_min_height_y:
                        # centroide (coordinate intere su vol128)
                        cx = int(np.round(coords[:, 0].mean()))
                        cy = int(np.round(coords[:, 1].mean()))
                        cz = int(np.round(coords[:, 2].mean()))
                        manhole_candidates_list.append({
                            "x": cx,
                            "y": cy,
                            "z": cz,
                        })
                        log.info(f"Tombino trovato: {cx}, {cy}, {cz}")
            except Exception as e:
                log.warning("Component labeling manhole fallita: %s", e)
    except Exception as e:
        log.warning("Inferenza manhole non eseguita: %s", e)

    # 3) Riporta maschera tubi a Y originale
    if y_orig != 128 and interpolate:
        mask_out = interpolate_to_target_frames(mask128, y_orig)
    else:
        mask_out = mask128.astype(np.float32, copy=False)

    # Ritorno: solo candidati opzionali oltre alla maschera tubi
    if  manhole_candidates_list is not None:
        return mask_out, manhole_candidates_list
    return mask_out


