"""Smoke tests for network_core.

Run the pipeline with dummy models in both 2D and 3D mode and check that:
- the output shape matches the input shape
- the output contains no NaN/Inf values

Recommended invocation (as a package module):
    python -m network_core.test_network_core

Direct invocation (fallback):
    python network_core/test_network_core.py
"""

import numpy as np
import torch

# Robust import: works both when run as a module and when run as a script.
try:
    from . import process_swath_array  # when run as a package module
except ImportError:
    # Run as a plain script there is no parent package, so the relative import
    # raises ImportError. Put the directory containing network_core/ at the
    # FRONT of sys.path (so an installed copy cannot shadow the local one) and
    # import absolutely. Only ImportError is caught, so genuine errors raised
    # while importing the package are not silently swallowed.
    import pathlib
    import sys

    sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
    from network_core import process_swath_array  # when run as a script


class DummyUNet2D(torch.nn.Module):
    """Parameter-free stand-in for a 2D segmentation network.

    Applies an element-wise sigmoid, so the output has exactly the input's
    shape and every value lies in [0, 1]. Presumably fed (B, 1, H, W)
    patches by process_swath_array — TODO confirm.
    """

    def forward(self, x):
        # Element-wise logistic function; shape is preserved unchanged.
        return x.sigmoid()


class DummyUNet3D(torch.nn.Module):
    """Parameter-free stand-in for a 3D segmentation network.

    Applies an element-wise sigmoid, so the output has exactly the input's
    shape and every value lies in [0, 1]. Presumably fed (B, 1, D, H, W)
    patches by process_swath_array — TODO confirm.
    """

    def forward(self, x):
        # Element-wise logistic function; shape is preserved unchanged.
        return x.sigmoid()


def _make_volume(x=19, y=96, z=64, seed=0):
    rng = np.random.default_rng(seed)
    vol = rng.normal(0.0, 0.1, size=(x, y, z)).astype(np.float32)
    return vol


def test_process_2d_cpu():
    """Smoke-test the 2D pipeline on CPU with a sigmoid dummy model.

    The same dummy network is passed as both ``model_trasversale`` and
    ``model_longitudinale``; Gaussian blur is disabled. Checks that the
    returned mask has the input volume's shape and only finite values.
    """
    volume = _make_volume()
    dummy = DummyUNet2D()
    options = dict(
        mode="2d",
        model_trasversale=dummy,
        model_longitudinale=dummy,
        device="cpu",
        stride=16,
        batch_size=128,
        apply_gaussian_blur=False,
    )
    result = process_swath_array(volume, **options)

    assert result.shape == volume.shape, (
        f"Shape errata: {result.shape} vs {volume.shape}"
    )
    assert np.isfinite(result).all(), "NaN/Inf nella maschera 2D"
    print("✓ Test 2D OK — shape", result.shape)


def test_process_3d_cpu():
    """Smoke-test the 3D pipeline on CPU with a sigmoid dummy model.

    Uses 16-voxel strides along y and z, a batch size of 8, and disables
    Gaussian blur. Checks that the returned mask has the input volume's
    shape and only finite values.
    """
    volume = _make_volume()
    dummy = DummyUNet3D()
    options = dict(
        mode="3d",
        model_3d=dummy,
        device="cpu",
        stride_y_3d=16,
        stride_z_3d=16,
        batch_size_3d=8,
        apply_gaussian_blur=False,
    )
    result = process_swath_array(volume, **options)

    assert result.shape == volume.shape, (
        f"Shape errata: {result.shape} vs {volume.shape}"
    )
    assert np.isfinite(result).all(), "NaN/Inf nella maschera 3D"
    print("✓ Test 3D OK — shape", result.shape)


if __name__ == "__main__":
    # Run each smoke test in order; a failing assertion aborts the run before
    # the summary line is printed.
    for smoke_test in (test_process_2d_cpu, test_process_3d_cpu):
        smoke_test()
    print("Tutti i test smoke sono passati.")


