from __future__ import annotations

"""Utility locali per `network_core`.

Contiene:
- Resampling quadratico per portare Y a 128
- Blur gaussiano GPU/CPU con fallback automatico
- Filtri vesselness opzionali
- Costanti di padding per adattare X=19 a X=32 nelle UNet 32x32
"""

import numpy as np
import logging


log = logging.getLogger(__name__)


# Padding legacy per compatibilità (19 canali -> 32)
_PAD_L, _PAD_R = 6, 7


def compute_dynamic_padding(current_channels: int, target_channels: int = 32) -> tuple[int, int]:
    """
    Calcola il padding dinamico per adattare il numero di canali al modello.
    
    Args:
        current_channels: Numero di canali corrente (es. 19, 23, 11)
        target_channels: Numero di canali richiesto dal modello (default 32)
    
    Returns:
        tuple (pad_left, pad_right) per portare current_channels a target_channels
    
    Examples:
        >>> compute_dynamic_padding(19, 32)  # VV originale
        (6, 7)  # 6 + 19 + 7 = 32
        >>> compute_dynamic_padding(23, 32)  # VV con più canali
        (4, 5)  # 4 + 23 + 5 = 32
        >>> compute_dynamic_padding(11, 32)  # HH
        (10, 11)  # 10 + 11 + 11 = 32
    """
    if current_channels >= target_channels:
        # Se hai più o uguale canali del necessario, nessun padding
        log.warning(f"Canali correnti ({current_channels}) >= target ({target_channels}), nessun padding applicato")
        return 0, 0
    
    pad_total = target_channels - current_channels
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left
    
    log.info(f"Padding dinamico: {current_channels} canali -> {target_channels} (left={pad_left}, right={pad_right})")
    
    return pad_left, pad_right


def quadratic_resample_rows(arr: np.ndarray, num_rows: int, *, a: float = 0.0, b: float = 1.0, c: float = 0.0) -> np.ndarray:
    """Ricampionamento quadratico lungo l'asse delle righe.

    Parameters
    ----------
    arr : np.ndarray
        Matrice 2D (num_rows_originali, num_colonne)
    num_rows : int
        Numero di righe target
    a, b, c : float
        Coefficienti della trasformazione quadratica normalizzata

    Returns
    -------
    np.ndarray
        Matrice (num_rows, num_colonne) ricampionata
    """
    N, M = arr.shape
    t = np.linspace(0, 1, num_rows)
    quad_t = a * t**2 + b * t + c
    quad_t = (quad_t - quad_t.min()) / (quad_t.max() - quad_t.min())
    indices = quad_t * (N - 1)
    out = np.empty((num_rows, M), dtype=arr.dtype)
    for i, idx in enumerate(indices):
        lo = int(np.floor(idx))
        hi = min(int(np.ceil(idx)), N - 1)
        if hi == lo:
            out[i] = arr[lo]
        else:
            w = idx - lo
            out[i] = (1 - w) * arr[lo] + w * arr[hi]
    return out


def gaussian_blur_gpu(volume_np: np.ndarray,
                      sigma=(1.0, 1.0, 1.0),
                      mask_zeros: bool = False,
                      verbose: bool = False,
                      auto_threshold: bool = True) -> np.ndarray:
    """Blur gaussiano 3D con selezione automatica CPU/GPU.

    - Usa CuPy se disponibile e il volume è grande; altrimenti fallback a SciPy.
    - Impostare `mask_zeros=True` per trattare i voxel 0 come masked (NaN-safe).
    """
    volume_size = np.prod(volume_np.shape)
    size_threshold = 5_000_000

    if auto_threshold and volume_size < size_threshold:
        if verbose:
            print(f"Volume piccolo ({volume_size/1e6:.1f}M): uso CPU")
        from scipy.ndimage import gaussian_filter
        return gaussian_filter(volume_np, sigma=sigma, mode='nearest', truncate=3.0)

    try:
        import cupy as cp
        from cupyx.scipy.ndimage import gaussian_filter as gpu_gaussian_filter

        if verbose:
            device_info = cp.cuda.runtime.getDeviceProperties(cp.cuda.runtime.getDevice())
            total_mem = cp.cuda.runtime.memGetInfo()[1] / 1024**3
            free_mem = cp.cuda.runtime.memGetInfo()[0] / 1024**3
            print(f"GPU: {device_info['name'].decode()} — mem {free_mem:.1f}/{total_mem:.1f}GB")

        vol_gpu = cp.asarray(volume_np)
        if mask_zeros:
            vol_gpu = vol_gpu.astype(cp.float32)
            vol_gpu = cp.where(vol_gpu == 0, cp.nan, vol_gpu)
        vol_blur_gpu = gpu_gaussian_filter(vol_gpu, sigma=sigma, mode='nearest', truncate=3.0)
        if mask_zeros:
            vol_blur_gpu = cp.nan_to_num(vol_blur_gpu)
        return cp.asnumpy(vol_blur_gpu)
    except Exception:
        from scipy.ndimage import gaussian_filter
        return gaussian_filter(volume_np, sigma=sigma, mode='nearest', truncate=3.0)


def apply_vessel_filters(mask: np.ndarray, filters: list[str] | None, *, black_ridges: bool = False, thr: float = 0.0015) -> np.ndarray:
    """Applica filtri vesselness e interseca con la maschera UNet.

    filters: lista tra ["frangi", "sato", "meijering"].
    """
    if not filters:
        return mask
    from skimage.filters import frangi, sato, meijering
    fmap = {"frangi": frangi, "sato": sato, "meijering": meijering}
    enh = np.zeros_like(mask, dtype=np.float32)
    for f in filters:
        fl = f.lower()
        if fl not in fmap:
            log.warning("Filtro %s non riconosciuto", f)
            continue
        enh = np.maximum(enh, fmap[fl](mask, black_ridges=black_ridges))
    return np.where(enh > thr, mask, 0).astype(mask.dtype)


def show_slices_front(volume, mask):
    """
    Visualizza sezioni frontali del volume con slider interattivo.
    
    Args:
        volume (numpy.ndarray): Volume 3D da visualizzare
        mask (numpy.ndarray): Maschera 3D da visualizzare
        
    Note:
        Crea una visualizzazione interattiva con uno slider per navigare
        attraverso le sezioni frontali (piano ZY) del volume.
    """
    # Verifichiamo la forma dei dati
    print("Volume shape:", volume.shape)
    print("Mask shape:", mask.shape)
    
    # Assumiamo che il volume sia in formato (x, y, z)
    # x è la dimensione che vogliamo controllare con lo slider
    max_x = volume.shape[0] - 1
    
    fig, axs = plt.subplots(1, 2, figsize=(28, 10))
    plt.subplots_adjust(bottom=0.2)
    axcolor = 'lightgoldenrodyellow'
    
    # Slider per controllare la posizione x
    ax_x = plt.axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
    slider_x = Slider(ax_x, 'X', 0, max_x, valinit=max_x//2, valstep=1)
    
    def update(val):
        x_idx = int(slider_x.val)
        for ax in axs.flat:
            ax.clear()
        
        # Visualizziamo la sezione frontale (z,y) alla posizione x
        axs[0].imshow(volume[x_idx, :, :], cmap='gray')
        axs[0].set_title(f'Volume: Sezione X={x_idx} (Piano ZY)')
        axs[0].set_xlabel('Y')
        axs[0].set_ylabel('Z')
        
        axs[1].imshow(mask[x_idx, :, :], cmap='gray')
        axs[1].set_title(f'Mask: Sezione X={x_idx} (Piano ZY)')
        axs[1].set_xlabel('Y')
        axs[1].set_ylabel('Z')
        
        fig.canvas.draw_idle()
    
    slider_x.on_changed(update)
    update(0)
    plt.show()
def show_slices_front_both(volume, mask, mask1):
    """
    Visualizza sezioni frontali del volume con due maschere e slider interattivo.
    Piano ZY alla posizione X selezionata.
    """
    # Verifichiamo la forma dei dati
    print("Volume shape:", volume.shape)
    print("Mask A shape:", mask.shape)
    print("Mask B shape:", mask1.shape)

    max_x = volume.shape[0] - 1

    fig, axs = plt.subplots(3, 1, figsize=(28, 10))
    plt.subplots_adjust(bottom=0.2)
    axcolor = 'lightgoldenrodyellow'

    # Slider per controllare la posizione x
    ax_x = plt.axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
    slider_x = Slider(ax_x, 'X', 0, max_x, valinit=max_x//2, valstep=1)

    def update(val):
        x_idx = int(slider_x.val)
        for ax in axs.flat:
            ax.clear()

        axs[0].imshow(volume[x_idx, :, :], cmap='gray')
        axs[0].set_title(f'Volume: Sezione X={x_idx} (Piano ZY)')
        axs[0].set_xlabel('Y')
        axs[0].set_ylabel('Z')

        axs[1].imshow(mask[x_idx, :, :], cmap='jet', vmax=1, vmin=0)
        axs[1].set_title(f'Mask A: Sezione X={x_idx} (Piano ZY)')
        axs[1].set_xlabel('Y')
        axs[1].set_ylabel('Z')

        axs[2].imshow(mask1[x_idx, :, :], cmap='jet', vmax=1, vmin=0)
        axs[2].set_title(f'Mask B: Sezione X={x_idx} (Piano ZY)')
        axs[2].set_xlabel('Y')
        axs[2].set_ylabel('Z')

        fig.canvas.draw_idle()

    slider_x.on_changed(update)
    update(0)
    plt.show()
 
def show_slices_up(volume, mask):
    """
    Visualizza sezioni orizzontali del volume con slider interattivo.
    
    Args:
        volume (numpy.ndarray): Volume 3D da visualizzare
        mask (numpy.ndarray): Maschera 3D da visualizzare
        
    Note:
        Crea una visualizzazione interattiva con uno slider per navigare
        attraverso le sezioni orizzontali (piano XZ) del volume.
    """
    # Verifichiamo la forma dei dati
    print("Volume shape:", volume.shape)
    print("Mask shape:", mask.shape)
    
    # Assumiamo che il volume sia in formato (x, y, z)
    # y è la dimensione che vogliamo controllare con lo slider
    max_y = volume.shape[1] - 1
    
    fig, axs = plt.subplots(1, 2, figsize=(28, 10))
    plt.subplots_adjust(bottom=0.2)
    axcolor = 'lightgoldenrodyellow'
    
    # Slider per controllare la posizione y
    ax_y = plt.axes([0.15, 0.1, 0.65, 0.03], facecolor=axcolor)
    slider_y = Slider(ax_y, 'Y', 0, max_y, valinit=max_y//2, valstep=1)
    
    def update(val):
        y_idx = int(slider_y.val)
        for ax in axs.flat:
            ax.clear()
        
        # Visualizziamo la sezione trasversale (x,z) alla posizione y
        # volume[:, y_idx, :] seleziona tutti i punti x e z alla posizione y_idx
        axs[0].imshow(volume[:, y_idx, :], cmap='gray', vmax=2, vmin=-2)
        axs[0].set_title(f'Volume: Sezione Y={y_idx} (Piano XZ)')
        axs[0].set_xlabel('Z')
        axs[0].set_ylabel('X')
        
        # Rimappiamo l'indice Y del volume all'indice Y della maschera senza interpolazione
        vol_y = volume.shape[1]
        mask_y = mask.shape[1]
        if vol_y == mask_y:
            y_idx_mask = y_idx
        else:
            if vol_y <= 1:
                y_idx_mask = 0
            else:
                y_idx_mask = int(round(y_idx * (mask_y - 1) / (vol_y - 1)))
            # clamp per sicurezza
            y_idx_mask = max(0, min(mask_y - 1, y_idx_mask))

        axs[1].imshow(mask[:, y_idx_mask, :], cmap='jet', vmax=1, vmin=0)
        axs[1].set_title(f'Mask: Sezione Y(vol)={y_idx} → Y(mask)={y_idx_mask} (Piano XZ)')
        axs[1].set_xlabel('Z')
        axs[1].set_ylabel('X')

        
        fig.canvas.draw_idle()
    
    slider_y.on_changed(update)
    update(0)
    plt.show()