import torch
import torch.nn as nn

"""MiniUNet 2D per patch 32x32 (input/output 1 canale).

Architettura compatta con 3 livelli di encoder/decoder e skip connections.
Output con sigmoid in [0,1].
"""


class MiniUNet(nn.Module):
    """Compact 2D U-Net for segmentation of small patches (designed for 32x32).

    The encoder has three levels (32/64/128 channels), followed by a
    256-channel bottleneck. The decoder has three levels, each joined to the
    matching encoder level by a skip connection (channel concatenation).
    Every level is a "double conv" block:
    ``(Conv3x3 -> BatchNorm -> ReLU) x 2 -> Dropout2d``.

    Args:
        dropout_rate (float): Channel-dropout probability used in every
            double-conv block. The bottleneck uses the slightly stronger
            ``min(0.8, dropout_rate * 1.5)``.
        in_channels (int): Number of input channels (keyword-only, default 1).
        out_channels (int): Number of output channels (keyword-only, default 1).

    Shape:
        - Input: ``(N, in_channels, H, W)``. The network downsamples by 8
          (three 2x max-pools). If ``H`` or ``W`` is not a multiple of 8, the
          input is zero-padded at the bottom/right and the output is cropped
          back to ``H x W``.
        - Output: ``(N, out_channels, H, W)`` with values in ``[0, 1]`` (sigmoid).

    Note:
        - Upsampling uses ``ConvTranspose2d`` (kernel 2, stride 2).
        - Conv / transposed-conv weights use Xavier-uniform init with zero
          bias. BatchNorm starts as the identity (weight 1, bias 0).
        - Submodule attribute names and the layer order inside each block
          (e.g. ``enc_conv1.0.weight``, ``enc_conv1.4.running_mean``) define
          the ``state_dict`` layout. Keep them stable so checkpoints load.
    """

    # Three MaxPool2d(2) stages, so spatial dims must be divisible by 2**3
    # for the skip connections to line up.
    _SIZE_MULTIPLE = 8

    def __init__(
        self,
        dropout_rate: float = 0.5,
        *,
        in_channels: int = 1,
        out_channels: int = 1,
    ):
        super().__init__()
        bottleneck_p = min(0.8, dropout_rate * 1.5)

        # Attribute assignment order is module registration order. It fixes
        # the state_dict ordering and the RNG consumption during init, so keep it.
        self.enc_conv1 = self._double_conv(in_channels, 32, dropout_rate)
        self.pool1 = nn.MaxPool2d(2)
        self.enc_conv2 = self._double_conv(32, 64, dropout_rate)
        self.pool2 = nn.MaxPool2d(2)
        self.enc_conv3 = self._double_conv(64, 128, dropout_rate)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = self._double_conv(128, 256, bottleneck_p)

        # Each upconv halves the channels. Concatenating the matching encoder
        # skip doubles them again, hence dec_conv input = 2 * upconv output.
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec_conv3 = self._double_conv(256, 128, dropout_rate)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_conv2 = self._double_conv(128, 64, dropout_rate)
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec_conv1 = self._double_conv(64, 32, dropout_rate)

        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)

        self._initialize_weights()

    @staticmethod
    def _double_conv(in_ch: int, out_ch: int, p: float) -> nn.Sequential:
        """Build ``(Conv3x3 -> BN -> ReLU) x 2 -> Dropout2d(p)``.

        The layer order fixes the ``state_dict`` indices (0: conv, 1: bn,
        3: conv, 4: bn, 6: dropout), so do not reorder.
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p),
        )

    def _initialize_weights(self) -> None:
        """Xavier-uniform conv weights with zero bias; BatchNorm set to identity."""
        # NOTE(review): Kaiming init is the usual choice before ReLU; Xavier is
        # kept so existing training runs stay reproducible.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the U-Net forward pass.

        Args:
            x (torch.Tensor): Input of shape ``(N, in_channels, H, W)``.

        Returns:
            torch.Tensor: Probability map of shape ``(N, out_channels, H, W)``
            with values in ``[0, 1]``.
        """
        h, w = x.shape[-2:]
        pad_h = -h % self._SIZE_MULTIPLE
        pad_w = -w % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # Pad tuple runs from the last dim backwards: (left, right, top, bottom).
            # Zero padding matches the zero padding used by the 3x3 convs.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        e1 = self.enc_conv1(x)
        e2 = self.enc_conv2(self.pool1(e1))
        e3 = self.enc_conv3(self.pool2(e2))

        b = self.bottleneck(self.pool3(e3))

        d3 = self.dec_conv3(torch.cat([self.upconv3(b), e3], dim=1))
        d2 = self.dec_conv2(torch.cat([self.upconv2(d3), e2], dim=1))
        d1 = self.dec_conv1(torch.cat([self.upconv1(d2), e1], dim=1))

        out = self.final_conv(d1)
        if pad_h or pad_w:
            # Drop the padded rows/columns so the output matches the input size.
            out = out[..., :h, :w]
        return torch.sigmoid(out)


