from __future__ import annotations

"""Inferenza locale per rendere il package auto-contenuto.

Contiene tre funzioni principali utilizzate da `processor.process_swath_array`:
- `infer_transverse_fast`: patch 32x32 con padding laterale su X (19→32), sliding su Y
- `infer_longitudinal_fast`: patch 32x32 su piani (y,z) per ogni X
- `infer_3d_fast`: patch 32x32x32 con sliding su (y,z), padding su X (19→32) interno

Dipendenze: numpy, torch
"""

import logging
from typing import Sequence

import numpy as np
import torch
import torch.nn.functional as F

from .utils import _PAD_L, _PAD_R, compute_dynamic_padding


log = logging.getLogger(__name__)


@torch.no_grad()
def infer_transverse_fast(
    volume: np.ndarray,
    model: torch.nn.Module,
    *,
    device: str = "cpu",
    stride: int = 4,
    batch_size: int = 1024,
    target_channels: int = 32,
    tag: str = "",
    normalize_input: bool = False,
) -> np.ndarray:
    """Run 2D transverse inference over a 3D volume.

    For every Z index the (y, x) plane is zero-padded along X, cut into
    32x32 patches with a sliding window along Y, passed through ``model``
    and written back into the output, averaging overlapping predictions.

    When ``normalize_input`` is True every patch is normalised before being
    fed to the model: per-patch z-score, clipping to [-2, 2], remapping to
    [0, 1] via ``(z + 2) / 4`` and 8-bit quantisation
    (``round(v * 255) / 255``).

    Parameters
    ----------
    volume : np.ndarray
        Input volume with shape (x, y, z).
    model : torch.nn.Module
        2D model called on batches of shape (N, 1, 32, 32); channel 0 of its
        output is used as a 32x32 prediction map. The model is called as-is:
        this function neither moves it to ``device`` nor switches it to eval
        mode.
    device : str, optional
        Device the input patches are moved to. Default: "cpu".
    stride : int, optional
        Sliding-window step along Y. Default: 4.
    batch_size : int, optional
        Number of patches per model call. Default: 1024.
    target_channels : int, optional
        Forwarded to ``compute_dynamic_padding`` when ``x != 19``
        (presumably the padded X width -- confirm against ``utils``). The
        patch width is fixed at 32, so the resulting padded width must be 32.
        Default: 32.
    tag : str, optional
        Optional label; currently not used by this function.
    normalize_input : bool, optional
        Enable the per-patch normalisation described above. Default: False.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``volume``. Positions not
        covered by any window (the trailing remainder along Y when
        ``y - 32`` is not a multiple of ``stride``) are 0. If ``y < 32`` a
        warning is logged and an all-zero array is returned.

    Raises
    ------
    ValueError
        If the padded X width differs from the 32-pixel patch width.
    """
    # NOTE: the slice count is deliberately not called ``z`` so it cannot be
    # confused with the z-score tensor used during normalisation.
    x, y, n_slices = volume.shape
    mask = np.zeros_like(volume, dtype=np.float32)
    cnt = np.zeros_like(volume, dtype=np.float32)

    # F.unfold raises when the kernel is larger than the input; skip
    # gracefully instead, as the other inference helpers in this module do.
    if y < 32:
        log.warning("Inferenza trasversale saltata: dimensione y=%d < 32", y)
        return mask

    # Padding along X: fixed constants for the 19-channel case, dynamic otherwise.
    if x != 19:
        pad_l, pad_r = compute_dynamic_padding(x, target_channels=target_channels)
    else:
        pad_l, pad_r = _PAD_L, _PAD_R

    # The unfold kernel and the crop below both assume a padded width of 32;
    # any other width would silently mis-crop or fail at the accumulation.
    if x + pad_l + pad_r != 32:
        raise ValueError(
            f"padded X width is {x + pad_l + pad_r} (x={x}, pad=({pad_l}, {pad_r})); "
            "infer_transverse_fast requires exactly 32"
        )

    for zz in range(n_slices):
        section = volume[:, :, zz].T  # (y, x)
        section_p = np.pad(section, ((0, 0), (pad_l, pad_r)), mode="constant")
        sec_t = (
            torch.from_numpy(section_p).float().unsqueeze(0).unsqueeze(0).to(device)
        )  # (1, 1, y, 32)

        # One window per Y step; the X stride of 32 yields a single column.
        patches = F.unfold(sec_t, kernel_size=(32, 32), stride=(stride, 32))
        patches = patches.permute(0, 2, 1).reshape(-1, 1, 32, 32)
        n_patches = patches.size(0)

        for start in range(0, n_patches, batch_size):
            batch = patches[start : start + batch_size]
            if normalize_input:
                # per-patch z-score -> clip [-2, 2] -> map to [0, 1]
                mean = batch.mean(dim=[2, 3], keepdim=True)
                std = batch.std(dim=[2, 3], keepdim=True).clamp_min(1e-8)
                zscore = torch.clamp((batch - mean) / std, -2.0, 2.0)
                batch = (zscore + 2.0) / 4.0
                # 8-bit quantisation: round(v * 255) / 255
                batch = torch.round(batch * 255.0) / 255.0

            preds = model(batch)[:, 0].cpu().numpy()  # (B, 32, 32) as (y, x_padded)
            # Drop the padded X columns once for the whole batch.
            crops = preds[:, :, pad_l : 32 - pad_r]  # (B, 32, x)

            for k in range(crops.shape[0]):
                y0 = (start + k) * stride  # first Y row of the k-th window
                mask[:, y0 : y0 + 32, zz] += crops[k].T  # (x, 32)
                cnt[:, y0 : y0 + 32, zz] += 1

    # Average overlapping predictions; uncovered positions stay 0.
    return np.divide(mask, cnt, out=np.zeros_like(mask), where=cnt > 0)


@torch.no_grad()
def infer_longitudinal_fast(
    volume: np.ndarray,
    model: torch.nn.Module,
    *,
    device: str = "cpu",
    stride: int = 4,
    batch_size: int = 512,
    normalize_input: bool = False,
) -> np.ndarray:
    """Run 2D longitudinal inference over a 3D volume.

    For every X index the (y, z) plane is cut into 32x32 patches with a
    sliding window over both axes, the patches are passed through ``model``
    and the predictions are written back, averaging where windows overlap.

    When ``normalize_input`` is True every patch is normalised before being
    fed to the model: per-patch z-score, clipping to [-2, 2], remapping to
    [0, 1] via ``(z + 2) / 4`` and 8-bit quantisation
    (``round(v * 255) / 255``).

    Parameters
    ----------
    volume : np.ndarray
        Input volume with shape (x, y, z).
    model : torch.nn.Module
        2D model called on batches of shape (N, 1, 32, 32); channel 0 of its
        output is used as a 32x32 prediction map.
    device : str, optional
        Device the input patches are moved to. Default: "cpu".
    stride : int, optional
        Sliding-window step along both Y and Z. Default: 4.
    batch_size : int, optional
        Number of patches per model call. Default: 512.
    normalize_input : bool, optional
        Enable the per-patch normalisation described above. Default: False.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``volume``; positions not
        covered by any window are 0. If ``y`` or ``z`` is smaller than 32 a
        warning is logged and an all-zero array is returned.
    """
    n_x, n_y, n_z = volume.shape
    acc = np.zeros_like(volume, dtype=np.float32)
    hits = np.zeros_like(volume, dtype=np.float32)

    if n_y < 32 or n_z < 32:
        log.warning("Inferenza longitudinale saltata: dimensioni (y=%d, z=%d) < 32", n_y, n_z)
        return acc

    # Window grid implied by the plane size; unfold enumerates it row-major
    # (Y outer, Z inner), which is what the divmod below relies on.
    rows = (n_y - 32) // stride + 1
    cols = (n_z - 32) // stride + 1
    expected_patches = rows * cols

    for xi in range(n_x):
        plane = torch.from_numpy(volume[xi]).float().to(device)[None, None]  # (1, 1, y, z)
        windows = F.unfold(plane, kernel_size=32, stride=stride)
        windows = windows.permute(0, 2, 1).reshape(-1, 1, 32, 32)

        n_patches = windows.size(0)
        if n_patches != expected_patches:
            log.warning("Numero patch inaspettato: %d vs %d", n_patches, expected_patches)

        for first in range(0, n_patches, batch_size):
            chunk = windows[first : first + batch_size]
            if normalize_input:
                mu = chunk.mean(dim=[2, 3], keepdim=True)
                sigma = chunk.std(dim=[2, 3], keepdim=True).clamp_min(1e-8)
                scored = torch.clamp((chunk - mu) / sigma, -2.0, 2.0)
                # map to [0, 1], then quantise to 8 bits
                chunk = torch.round((scored + 2.0) / 4.0 * 255.0) / 255.0
            outputs = model(chunk)[:, 0].cpu().numpy()

            for offset in range(chunk.size(0)):
                row, col = divmod(first + offset, cols)
                y0 = row * stride
                z0 = col * stride
                acc[xi, y0 : y0 + 32, z0 : z0 + 32] += outputs[offset]
                hits[xi, y0 : y0 + 32, z0 : z0 + 32] += 1

    # Mean over overlapping windows; untouched positions remain 0.
    return np.divide(acc, hits, out=np.zeros_like(acc), where=hits > 0)


@torch.no_grad()
def infer_manhole_longitudinal_128(
    volume: np.ndarray,
    model: torch.nn.Module,
    *,
    device: str = "cpu",
    stride_z: int = 8,
    batch_size: int = 128,
    normalize_input: bool = True,
) -> np.ndarray:
    """Run 2D longitudinal inference for a 128x128 model.

    For every X index a single row of 128x128 patches is taken from the
    first 128 Y rows of the (y, z) plane, sliding along Z with step
    ``stride_z``. Predictions are written back and overlapping windows are
    averaged.

    When ``normalize_input`` is True every patch is normalised before being
    fed to the model: per-patch z-score, clipping to [-2, 2], remapping to
    [0, 1] via ``(z + 2) / 4`` and 8-bit quantisation
    (``round(v * 255) / 255``).

    Parameters
    ----------
    volume : np.ndarray
        Input volume with shape (x, y, z).
    model : torch.nn.Module
        2D model called on batches of shape (N, 1, 128, 128); channel 0 of
        its output is used as a 128x128 prediction map. The model is moved
        to the selected device and switched to eval mode by this function.
    device : str, optional
        Requested device. It is only honoured when CUDA is available;
        otherwise the CPU is used regardless of the value. Default: "cpu".
    stride_z : int, optional
        Sliding-window step along Z. Default: 8.
    batch_size : int, optional
        Number of patches per model call. Default: 128.
    normalize_input : bool, optional
        Enable the per-patch normalisation described above. Default: True.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``volume``. Only Y rows 0:128
        receive predictions; rows beyond 128 and Z positions not covered by
        any window are 0. If ``y`` or ``z`` is smaller than 128 a warning is
        logged and an all-zero array is returned.
    """
    x, y, z = volume.shape
    mask = np.zeros_like(volume, dtype=np.float32)
    cnt = np.zeros_like(volume, dtype=np.float32)

    if y < 128 or z < 128:
        log.warning("Inferenza manhole 128 saltata: dimensioni (y=%d, z=%d) < 128", y, z)
        return mask

    dev = torch.device(device if torch.cuda.is_available() else "cpu")
    model = model.to(dev).eval()

    for xx in range(x):
        # Restrict to the first 128 Y rows so unfold yields exactly one row
        # of patches. Unfolding the whole plane with a Y stride of 128 would
        # produce extra rows when y >= 256, and the index -> Z mapping below
        # (which assumes a single row) would then write past the Z extent.
        band = volume[xx, :128]  # (128, z)
        sec_t = torch.from_numpy(band).float().to(dev)[None, None]  # (1, 1, 128, z)
        patches = F.unfold(sec_t, kernel_size=(128, 128), stride=(128, stride_z))
        patches = patches.permute(0, 2, 1).reshape(-1, 1, 128, 128)  # (N, 1, 128, 128)
        n_patches = patches.size(0)

        for start in range(0, n_patches, batch_size):
            batch_patches = patches[start : start + batch_size]
            if normalize_input:
                mean = batch_patches.mean(dim=[2, 3], keepdim=True)
                std = batch_patches.std(dim=[2, 3], keepdim=True).clamp_min(1e-8)
                zt = torch.clamp((batch_patches - mean) / std, -2.0, 2.0)
                batch_patches = (zt + 2.0) / 4.0
                # 8-bit quantisation: round(v * 255) / 255
                batch_patches = torch.round(batch_patches * 255.0) / 255.0

            preds = model(batch_patches)[:, 0].cpu().numpy()  # (B, 128, 128)

            for k in range(preds.shape[0]):
                z0 = (start + k) * stride_z  # first Z column of the k-th window
                mask[xx, 0:128, z0 : z0 + 128] += preds[k]
                cnt[xx, 0:128, z0 : z0 + 128] += 1

    # Average overlapping predictions; uncovered positions stay 0.
    return np.divide(mask, cnt, out=np.zeros_like(mask), where=cnt > 0)


@torch.no_grad()
def infer_3d_fast(
    volume: np.ndarray,
    model: torch.nn.Module,
    *,
    device: str = "cpu",
    stride_y: int = 16,
    stride_z: int = 16,
    batch_size: int = 32,
    tag: str = "",
    normalize_input: bool = False,
) -> np.ndarray:
    """Run 3D inference with 32x32x32 patches over a 3D volume.

    The volume is zero-padded along X to a width of 32 (6 on the left and 7
    on the right for the usual 19-channel input), patches are extracted with
    a sliding window over (y, z), passed through ``model`` and the
    predictions are written back, averaging where windows overlap. The
    prediction is cropped with the same offsets used for the padding, so
    output channel ``i`` corresponds to input channel ``i``.

    When ``normalize_input`` is True every patch is normalised with a
    per-patch z-score, clipped to [-2, 2] and then min-max rescaled per
    patch to [0, 1].

    Parameters
    ----------
    volume : np.ndarray
        Input volume with shape (x, y, z), with ``x <= 32``.
    model : torch.nn.Module
        3D model called on batches of shape (N, 1, 32, 32, 32) and expected
        to return a tensor of the same shape.
    device : str, optional
        Device the input patches are moved to. Default: "cpu".
    stride_y : int, optional
        Sliding-window step along Y. Default: 16.
    stride_z : int, optional
        Sliding-window step along Z. Default: 16.
    batch_size : int, optional
        Number of patches per model call. Default: 32.
    tag : str, optional
        Optional label, only used in the "no patches" log message.
    normalize_input : bool, optional
        Enable the per-patch normalisation described above. Default: False.

    Returns
    -------
    np.ndarray
        float32 array with the same shape as ``volume``. Positions not
        covered by any window (the trailing remainder when ``y - 32`` or
        ``z - 32`` is not a multiple of the respective stride) are 0. If
        ``y`` or ``z`` is smaller than 32 a warning is logged and an
        all-zero array is returned.

    Raises
    ------
    ValueError
        If ``x`` exceeds the 32-voxel patch depth.
    """
    # NOTE: the Z extent is not called ``z`` so it cannot be confused with
    # the z-score tensor used during normalisation.
    x, y, z_len = volume.shape
    mask = np.zeros_like(volume, dtype=np.float32)
    cnt = np.zeros_like(volume, dtype=np.float32)

    if y < 32 or z_len < 32:
        log.warning("Volume troppo piccolo per patch 32^3: %s", volume.shape)
        return mask

    if x > 32:
        raise ValueError(f"infer_3d_fast supports at most 32 X channels, got {x}")

    # Split the padding as evenly as possible, extra voxel on the right
    # (x=19 -> (6, 7), identical to the previously hard-coded values).
    pad_total = 32 - x
    pad_l = pad_total // 2
    pad_r = pad_total - pad_l
    volume_padded = np.pad(volume, ((pad_l, pad_r), (0, 0), (0, 0)), mode="constant")

    # Window centres along Y and Z; each patch spans centre-16 .. centre+16.
    coords = [
        (yc, zc)
        for yc in range(16, y - 16 + 1, stride_y)
        for zc in range(16, z_len - 16 + 1, stride_z)
    ]
    if not coords:
        log.warning("Nessuna patch estratta per %s", tag)
        return mask

    for start in range(0, len(coords), batch_size):
        batch_coords = coords[start : start + batch_size]
        # Materialise only the current batch instead of every patch up front.
        batch_np = np.stack(
            [volume_padded[:, yc - 16 : yc + 16, zc - 16 : zc + 16] for yc, zc in batch_coords]
        )  # (B, 32, 32, 32)
        batch_input = torch.from_numpy(batch_np).float().unsqueeze(1).to(device)

        if normalize_input:
            mean = batch_input.mean(dim=[2, 3, 4], keepdim=True)
            std = batch_input.std(dim=[2, 3, 4], keepdim=True).clamp_min(1e-8)
            zscore = torch.clamp((batch_input - mean) / std, -2.0, 2.0)
            mn = zscore.amin(dim=[2, 3, 4], keepdim=True)
            mx = zscore.amax(dim=[2, 3, 4], keepdim=True)
            batch_input = (zscore - mn) / (mx - mn + 1e-8)

        batch_preds = model(batch_input).squeeze(1).cpu().numpy()  # (B, 32, 32, 32)

        for pred, (yc, zc) in zip(batch_preds, batch_coords):
            # Undo the X padding with the same offset used to apply it;
            # taking the first x channels would shift the result by pad_l.
            pred_crop = pred[pad_l : pad_l + x]
            mask[:, yc - 16 : yc + 16, zc - 16 : zc + 16] += pred_crop
            cnt[:, yc - 16 : yc + 16, zc - 16 : zc + 16] += 1

    # Average overlapping predictions; uncovered positions stay 0.
    return np.divide(mask, cnt, out=np.zeros_like(mask), where=cnt > 0)


