"""
MiniUNet-128 (general purpose)
------------------------------
Una U-Net 2D compatta compatibile con patch 128x128 (e più in generale con input
multipli di 2^depth). L'architettura eredita la MiniUNet esistente ma non è
vincolata a 32x32: con depth=3 funziona perfettamente su 128x128
(128→64→32→16 → up → 128).

Compatibilità: PyTorch >= 1.8.

Esempio:
    model = MiniUNet128(in_channels=1, out_channels=1, base_ch=32, depth=3, dropout_rate=0.5)
    y = model(torch.randn(4, 1, 128, 128))  # -> (4,1,128,128), valori in [0,1]
"""
from __future__ import annotations

from typing import Tuple

import torch
import torch.nn as nn


class _DoubleConv(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, dropout: float) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        """Forward standard su blocco conv-conv."""
        return self.block(x)


class MiniUNet128(nn.Module):
    """Mini U-Net 2D per segmentazione con input 128×128 (profondità variabile).

    Args:
        in_channels (int): Numero di canali in input (es. 1 per scala di grigi).
        out_channels (int): Numero di canali in output (es. 1 per mappa di probabilità).
        base_ch (int): Canali del primo livello encoder.
        depth (int): Livelli di down/upsampling. Con ``depth=3`` l'input 128×128
            diventa 64→32→16 nel bottleneck e torna a 128 in decodifica.
        dropout_rate (float): Tasso di dropout nei blocchi; nel bottleneck viene
            incrementato leggermente.

    Shape:
        - Input: ``(N, in_channels, H, W)`` con ``H, W`` multipli di ``2**depth``
        - Output: ``(N, out_channels, H, W)`` con valori in ``[0, 1]`` (sigmoid)

    Note:
        - Skip-connections simmetriche e upsampling con ``ConvTranspose2d``.
        - Inizializzazione Xavier per pesi delle convoluzioni.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        base_ch: int = 32,
        depth: int = 3,
        dropout_rate: float = 0.5,
    ) -> None:
        super().__init__()
        assert depth >= 1, "depth deve essere >= 1"

        # Encoder
        enc_layers = []
        ch_in = in_channels
        ch_out = base_ch
        for _ in range(depth):
            enc_layers.append(_DoubleConv(ch_in, ch_out, dropout_rate))
            ch_in, ch_out = ch_out, ch_out * 2
        self.encoder = nn.ModuleList(enc_layers)
        self.pools = nn.ModuleList([nn.MaxPool2d(2) for _ in range(depth)])

        # Bottleneck
        self.bottleneck = _DoubleConv(ch_in, ch_out, min(0.8, dropout_rate * 1.5))

        # Decoder
        dec_layers = []
        upconvs = []
        skip_ch = ch_in  # ultimo canale encoder prima del bottleneck
        cur_ch = ch_out
        for _ in range(depth):
            upconvs.append(nn.ConvTranspose2d(cur_ch, cur_ch // 2, kernel_size=2, stride=2))
            dec_layers.append(_DoubleConv(cur_ch // 2 + skip_ch, skip_ch, dropout_rate))
            cur_ch = skip_ch
            skip_ch = skip_ch // 2
        self.upconvs = nn.ModuleList(upconvs)
        self.decoder = nn.ModuleList(dec_layers)

        # Head
        self.head = nn.Conv2d(cur_ch, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        """Esegue il forward pass con skip-connections simmetriche.

        Args:
            x (torch.Tensor): Tensore di input di forma ``(N, in_channels, H, W)``.

        Returns:
            torch.Tensor: Mappa di probabilità ``(N, out_channels, H, W)`` in ``[0, 1]``.

        Shape:
            - Input: ``(N, in_channels, H, W)``
            - Output: ``(N, out_channels, H, W)``
        """
        skips = []
        out = x
        for enc, pool in zip(self.encoder, self.pools):
            out = enc(out)
            skips.append(out)
            out = pool(out)

        out = self.bottleneck(out)

        for up, dec, skip in zip(self.upconvs, self.decoder, reversed(skips)):
            out = up(out)
            # Gestione eventuali disallineamenti di 1px
            if out.size(-1) != skip.size(-1) or out.size(-2) != skip.size(-2):
                dh = skip.size(-2) - out.size(-2)
                dw = skip.size(-1) - out.size(-1)
                out = nn.functional.pad(out, (0, dw, 0, dh))
            out = torch.cat([skip, out], dim=1)
            out = dec(out)

        out = self.head(out)
        return self.activation(out)


__all__ = ["MiniUNet128"]