import torch
import torch.nn as nn

"""MiniUNet3D per patch 32x32x32 (input/output 1 canale).

Encoder/Decoder 3D con skip connections. Output con sigmoid in [0,1].
"""


class MiniUNet3D(nn.Module):
    """Small 3D U-Net for volumetric segmentation (tuned for 32x32x32 patches).

    Architecture: three encoder stages (each followed by a 2x max-pool), a
    bottleneck, and three decoder stages that upsample with
    ``ConvTranspose3d`` and concatenate the matching encoder feature map
    (symmetric skip connections). A final 1x1x1 convolution maps to
    ``out_channels`` and a sigmoid squashes each output voxel into ``[0, 1]``.

    Args:
        dropout_rate (float): Dropout probability used in every 3D block. The
            bottleneck uses ``min(0.8, 1.5 * dropout_rate)``.
        in_channels (int, keyword-only): Number of input channels. Default 1.
        out_channels (int, keyword-only): Number of output channels; the
            sigmoid is applied independently per channel. Default 1.
        base_channels (int, keyword-only): Width of the first encoder stage;
            each deeper stage doubles it. Default 32 (32/64/128/256).

    Shape:
        - Input: ``(N, in_channels, D, H, W)`` with ``D, H, W`` multiples of 8
          (optimized for 32).
        - Output: ``(N, out_channels, D, H, W)`` with values in ``[0, 1]``
          (sigmoid).

    Note:
        - Convolution and transposed-convolution weights use Xavier-uniform
          initialization with zero biases; BatchNorm starts as identity.
        - With the default arguments the module layout and parameter shapes
          are identical to the original 1-in / 1-out, 32-base network, so
          existing checkpoints load unchanged.
    """

    # Three 2x max-pool stages -> every spatial dim must be divisible by 2**3.
    _SPATIAL_MULTIPLE = 8

    def __init__(self, dropout_rate=0.5, *, in_channels=1, out_channels=1,
                 base_channels=32):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.base_channels = base_channels

        c1 = base_channels
        c2 = base_channels * 2
        c3 = base_channels * 4
        c4 = base_channels * 8

        self.enc1 = self._block(in_channels, c1, dropout_rate)
        self.pool1 = nn.MaxPool3d(2)

        self.enc2 = self._block(c1, c2, dropout_rate)
        self.pool2 = nn.MaxPool3d(2)

        self.enc3 = self._block(c2, c3, dropout_rate)
        self.pool3 = nn.MaxPool3d(2)

        # Slightly stronger regularization at the bottleneck, capped at 0.8.
        self.bottleneck = self._block(c3, c4, min(0.8, dropout_rate * 1.5))

        # Each decoder block receives upsampled features concatenated with the
        # same-resolution encoder features, hence 2x its output channels.
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, 2)
        self.dec3 = self._block(c3 * 2, c3, dropout_rate)

        self.up2 = nn.ConvTranspose3d(c3, c2, 2, 2)
        self.dec2 = self._block(c2 * 2, c2, dropout_rate)

        self.up1 = nn.ConvTranspose3d(c2, c1, 2, 2)
        self.dec1 = self._block(c1 * 2, c1, dropout_rate)

        self.out_conv = nn.Conv3d(c1, out_channels, kernel_size=1)

        self._init_weights()

    def _block(self, in_ch, out_ch, dr):
        """Build two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) layers plus Dropout3d.

        Args:
            in_ch (int): Input channels of the first convolution.
            out_ch (int): Output channels of both convolutions.
            dr (float): Dropout probability for the trailing ``Dropout3d``.

        Returns:
            nn.Sequential: The block; ``padding=1`` preserves spatial size.
        """
        # NOTE: conv biases directly before BatchNorm are mathematically
        # redundant; they are kept so state_dict keys stay checkpoint-compatible.
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dr),
        )

    def _init_weights(self):
        """Xavier-uniform conv weights, zero biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the 3D forward pass.

        Args:
            x (torch.Tensor): Input volume ``(N, in_channels, D, H, W)``.

        Returns:
            torch.Tensor: Probability volume ``(N, out_channels, D, H, W)``
            with values in ``[0, 1]``.

        Raises:
            ValueError: If any of ``D, H, W`` is not a multiple of 8. Without
                this check the skip-connection concatenation would fail with
                an opaque size-mismatch error.

        Shape:
            - Input: ``(N, in_channels, D, H, W)``
            - Output: ``(N, out_channels, D, H, W)``
        """
        spatial = tuple(x.shape[-3:])
        if any(s % self._SPATIAL_MULTIPLE for s in spatial):
            raise ValueError(
                f"MiniUNet3D expects spatial dims (D, H, W) divisible by "
                f"{self._SPATIAL_MULTIPLE}, got {spatial}"
            )

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))

        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return torch.sigmoid(self.out_conv(d1))


