import numpy as np
from pyproj import Transformer
from scipy.interpolate import griddata
from shapely.geometry import Polygon, mapping
from rasterio.transform import from_bounds
from rasterio.features import geometry_mask
import rasterio
import os
import json
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from GLOBAL import trash_dir  # Definisci questo percorso globale

# === FUNZIONI DI SUPPORTO ==================================================

def build_mask(polygon_utm, transformer, transform, shape):
    """Rasterize a polygon ring into a boolean "inside" mask on the output grid.

    Args:
        polygon_utm: sequence of ring vertices in the transformer's source CRS.
            Only the first two ordinates of each vertex are used, so 3-D
            GeoJSON positions ``(x, y, z)`` are accepted as well.
        transformer: pyproj ``Transformer`` expected to return ``(lon, lat)``
            (e.g. built with ``always_xy=True``).
        transform: affine transform of the output raster.
        shape: ``(rows, cols)`` of the output raster.

    Returns:
        Boolean array of ``shape``: ``True`` for pixels inside the polygon.
    """
    # pt[0], pt[1] instead of `for x, y in ...` so vertices with an extra
    # ordinate (altitude) do not raise "too many values to unpack".
    ring = [transformer.transform(pt[0], pt[1]) for pt in polygon_utm]  # lon, lat
    geom = mapping(Polygon(ring))
    # geometry_mask takes plain geometries; invert=True marks inside pixels True.
    return geometry_mask(
        [geom],
        out_shape=shape,
        transform=transform,
        invert=True,
    )

def write_geotiff(out_path, data, transform, shape):
    """Write ``data`` as a single-band uint8 GeoTIFF in EPSG:4326.

    The file is tiled in 256x256 blocks, LZW-compressed, always written as
    BigTIFF, and declares 0 as its nodata value.

    Args:
        out_path: destination file path.
        data: 2-D array written to band 1.
        transform: affine transform of the raster.
        shape: ``(rows, cols)`` of the raster.
    """
    n_rows = shape[0]
    n_cols = shape[1]
    profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=1,
        dtype="uint8",
        crs="EPSG:4326",
        transform=transform,
        nodata=0,
        tiled=True,
        blockxsize=256,
        blockysize=256,
        compress="lzw",
        BIGTIFF="YES",
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(data, 1)

def process_single_depth(path, lon, lat, grid_lon_m, grid_lat_m,
                         mask_inside, transform, shape, out_dir_geotif):
    """Interpolate one depth slice onto the output grid and save it as a GeoTIFF.

    The slice's finite values are min-max normalised and scaled to 1..255, so
    that 0 stays reserved for the nodata value used for pixels outside
    ``mask_inside``. (Scaling to 0..255 made the minimum depth, and whole
    constant slices, indistinguishable from nodata.)

    Args:
        path: path of a ``.npy`` depth slice; presumably shaped like
            ``lon``/``lat`` (TODO confirm with the producer of these files).
        lon, lat: per-sample coordinates, presumably longitude/latitude.
        grid_lon_m, grid_lat_m: meshgrid of target coordinates whose rows are
            ordered by ascending latitude (rows are flipped before writing).
        mask_inside: boolean raster mask in output row order; pixels where it
            is False are written as 0.
        transform: affine transform of the output raster.
        shape: ``(rows, cols)`` of the output raster.
        out_dir_geotif: directory that receives ``gpr_map_<name>.tif``.

    Returns:
        The written file path, or ``None`` if the slice has no usable sample
        or processing failed. Errors are printed rather than raised, so one
        bad slice does not abort the whole batch.
    """
    try:
        depth = np.load(path, mmap_mode="r")
        # isfinite (not just ~isnan) so +/-inf cannot poison the min/max range;
        # coordinates are checked too because griddata cannot use NaN points.
        valid = np.isfinite(depth) & np.isfinite(lon) & np.isfinite(lat)
        if not np.any(valid):
            return None

        values = np.asarray(depth[valid], dtype=np.float32)
        v_min = values.min()
        v_range = values.max() - v_min
        if v_range > 0:
            val_norm = (values - v_min) / v_range
        else:
            # Constant slice: avoid dividing by zero; everything maps to 0.
            val_norm = np.zeros_like(values)

        grid_vals = griddata(
            (lon[valid], lat[valid]),
            val_norm,
            (grid_lon_m, grid_lat_m),
            method="nearest",  # nearest never leaves gaps; fill_value is unused
        )
        # Map [0, 1] -> [1, 255] with rounding; 0 is the nodata value.
        grid_uint8 = (1 + np.rint(grid_vals * 254)).astype(np.uint8)
        # Grid rows run south -> north; raster rows run north -> south.
        grid_uint8 = grid_uint8[::-1, :]
        data_to_write = np.where(mask_inside, grid_uint8, np.uint8(0))

        name = os.path.splitext(os.path.basename(path))[0]
        out_path = os.path.join(out_dir_geotif, f"gpr_map_{name}.tif")
        write_geotiff(out_path, data_to_write, transform, shape)
        print(f"Creato: {out_path}")
        return out_path

    except Exception as e:
        # Deliberate best-effort: report the failing slice and move on.
        print(f"Errore nel file {path}: {e}")
        return None

# === FUNZIONE PRINCIPALE ==================================================

def create_geotiffs_parallel(npy_lat_path, npy_lon_path, list_npy_depth_path,
                             georef_json_path, out_dir_geotif, *,
                             res_x=500, res_y=500, src_crs="EPSG:32632",
                             max_workers=None):
    """Convert a batch of depth slices into clipped EPSG:4326 GeoTIFFs in parallel.

    Args:
        npy_lat_path: ``.npy`` holding the y coordinate (northing) of every
            sample, in ``src_crs``.
        npy_lon_path: ``.npy`` holding the x coordinate (easting) of every
            sample, in ``src_crs``.
        list_npy_depth_path: iterable of ``.npy`` depth-slice paths.
        georef_json_path: GeoJSON file; the first ring of the first feature's
            geometry (in ``src_crs``) is used as the clipping polygon.
        out_dir_geotif: output directory; created if it does not exist.
        res_x, res_y: raster width and height in pixels (default 500x500).
        src_crs: CRS of the coordinate arrays and of the polygon.
        max_workers: number of worker processes (``None`` = executor default).

    Returns:
        Paths of the GeoTIFFs that were written, in input order; slices that
        were skipped or failed are omitted.

    Raises:
        ValueError: if there are no finite coordinates, or the coordinate
            extent is zero along one axis.
    """
    shape = (res_y, res_x)

    transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)

    x_utm = np.load(npy_lon_path, mmap_mode="r")
    y_utm = np.load(npy_lat_path, mmap_mode="r")

    lon, lat = transformer.transform(x_utm.ravel(), y_utm.ravel())
    lon = np.asarray(lon).reshape(x_utm.shape)
    lat = np.asarray(lat).reshape(y_utm.shape)

    # Bounds from finite coordinates only: a single NaN/inf would otherwise
    # turn the whole grid and transform into NaN.
    finite = np.isfinite(lon) & np.isfinite(lat)
    if not np.any(finite):
        raise ValueError("no finite coordinates to georeference")
    lon_min, lon_max = float(lon[finite].min()), float(lon[finite].max())
    lat_min, lat_max = float(lat[finite].min()), float(lat[finite].max())
    if lon_max <= lon_min or lat_max <= lat_min:
        raise ValueError("coordinate extent is zero along one axis")

    transform = from_bounds(lon_min, lat_min, lon_max, lat_max, res_x, res_y)

    # from_bounds treats the bounds as pixel EDGES, so sample at pixel CENTRES
    # (linspace over the bounds was off by up to half a pixel). Latitudes stay
    # ascending (south -> north), the row order process_single_depth flips.
    step_lon = (lon_max - lon_min) / res_x
    step_lat = (lat_max - lat_min) / res_y
    grid_lon = lon_min + (np.arange(res_x) + 0.5) * step_lon
    grid_lat = lat_min + (np.arange(res_y) + 0.5) * step_lat
    grid_lon_m, grid_lat_m = np.meshgrid(grid_lon, grid_lat)

    with open(georef_json_path, "r", encoding="utf-8") as f:
        polygon_utm = json.load(f)["features"][0]["geometry"]["coordinates"][0]

    mask_inside = build_mask(polygon_utm, transformer, transform, shape)

    os.makedirs(out_dir_geotif, exist_ok=True)

    shared = (lon, lat, grid_lon_m, grid_lat_m,
              mask_inside, transform, shape, out_dir_geotif)
    # Hand the large shared arrays to each worker once via the initializer,
    # instead of re-pickling them for every slice as repeat(...) did.
    with ProcessPoolExecutor(max_workers=max_workers,
                             initializer=_init_worker,
                             initargs=(shared,)) as executor:
        results = list(executor.map(_process_path, list_npy_depth_path))

    return [r for r in results if r is not None]


# Slice-independent arguments of process_single_depth, installed once per
# worker process by _init_worker.
_WORKER_SHARED = ()


def _init_worker(shared):
    """Pool initializer: keep the shared arguments in this worker process."""
    global _WORKER_SHARED
    _WORKER_SHARED = shared


def _process_path(path):
    """Pool task: process one depth slice using the worker's shared arguments."""
    return process_single_depth(path, *_WORKER_SHARED)
