# Functions to build the raster displayed on the map.
# INPUT: numpy arrays (values and coordinates) and the perimeter inside which the subswath is drawn
# --> build the grid over the samples' bounds, with a fixed pixel count along the longer side
# --> nearest-neighbour interpolation, then apply the mask to keep only the subswath values
# OUTPUT: numpy raster

import numpy as np
from shapely.geometry import Polygon, mapping
from rasterio.transform import from_bounds
from pyproj import Transformer
from scipy.interpolate import griddata
from rasterio.features import geometry_mask, rasterize
from scipy.ndimage import distance_transform_edt
from scipy.spatial import cKDTree
import os
import rasterio
import json


# Number of pixels along the longer side of the output grid.
_RISOLUZIONE_DEFAULT = 500
_RISOLUZIONE_TOMO = 800

# Supported ways of depositing samples on the grid (see _splat_seeds).
_SPLAT_MODES = ('nearest', 'bilinear')

# Set to True to dump every generated raster (GeoTIFF + PNG) into GLOBAL.trash_dir.
_DEBUG_SAVE = False


def _grid_shape(delta_lon, delta_lat, risoluzione):
    """Return ``(res_y, res_x)`` giving the longer axis ``risoluzione`` pixels.

    The shorter axis is scaled by the aspect ratio of the extent (truncated to
    an int) but never drops below one pixel, so very elongated footprints
    cannot produce a zero-sized grid.
    """
    if delta_lon > delta_lat:
        res_x = risoluzione
        res_y = int((delta_lat / delta_lon) * risoluzione)
    else:
        res_x = int((delta_lon / delta_lat) * risoluzione)
        res_y = risoluzione
    return max(res_y, 1), max(res_x, 1)


def _perimeter_mask(path_perimetro, transformer, shape, transform):
    """Rasterize the perimeter polygon into a boolean mask (True = inside).

    Reads the exterior ring of the first feature of the GeoJSON file at
    ``path_perimetro`` (coordinates in the source CRS of ``transformer``),
    reprojects it with ``transformer`` and burns it onto the grid described by
    ``shape`` and ``transform``.
    """
    with open(path_perimetro, "r", encoding="utf-8") as f:
        polygon_utm = json.load(f)["features"][0]["geometry"]["coordinates"][0]

    # ``*_`` tolerates positions carrying extra ordinates (e.g. altitude).
    ring = [transformer.transform(x, y) for x, y, *_ in polygon_utm]  # (lon, lat)
    geom = mapping(Polygon(ring))
    return geometry_mask([geom], out_shape=shape, transform=transform, invert=True)


def _splat_seeds(u, v, val_norm, shape, splat):
    """Accumulate normalized samples onto the grid as weighted seeds.

    ``u``/``v`` are the fractional positions of the samples along the lon/lat
    extent (0 = west/south edge, 1 = east/north edge). Row 0 of the returned
    arrays is the southern edge.

    Returns ``(img, seed_mask)``: ``img`` holds the weighted mean of the
    samples deposited on each pixel (0 where none were) and ``seed_mask``
    marks the pixels that received a non-zero weight.
    """
    res_y, res_x = shape

    if splat == 'nearest':
        # Pixel i covers [i, i + 1) / res of the extent, the same cells used by
        # ``from_bounds``; the far edge (u == 1) is folded into the last pixel.
        x0 = np.clip(np.floor(u * res_x).astype(np.int64), 0, res_x - 1)
        y0 = np.clip(np.floor(v * res_y).astype(np.int64), 0, res_y - 1)
        targets = [(y0, x0, np.ones_like(u, dtype=np.float32))]
    else:  # 'bilinear'
        # Continuous coordinates relative to pixel centres, clamped to the grid,
        # then the sample is shared among the 4 surrounding centres.
        px = np.clip(u * res_x - 0.5, 0, res_x - 1)
        py = np.clip(v * res_y - 0.5, 0, res_y - 1)
        x0 = np.floor(px).astype(np.int64)
        y0 = np.floor(py).astype(np.int64)
        x1 = np.minimum(x0 + 1, res_x - 1)
        y1 = np.minimum(y0 + 1, res_y - 1)
        wx = (px - x0).astype(np.float32)
        wy = (py - y0).astype(np.float32)
        targets = [
            (y0, x0, (1.0 - wx) * (1.0 - wy)),
            (y0, x1, wx * (1.0 - wy)),
            (y1, x0, (1.0 - wx) * wy),
            (y1, x1, wx * wy),
        ]

    seed_sum = np.zeros(res_y * res_x, dtype=np.float32)
    seed_w = np.zeros(res_y * res_x, dtype=np.float32)
    for rows, cols, weights in targets:
        # np.add.at (unbuffered) accumulates correctly when samples share a pixel.
        lin = rows * res_x + cols
        np.add.at(seed_sum, lin, val_norm * weights)
        np.add.at(seed_w, lin, weights)

    seed_mask = seed_w > 0
    img = np.zeros(res_y * res_x, dtype=np.float32)
    img[seed_mask] = seed_sum[seed_mask] / seed_w[seed_mask]
    return img.reshape(shape), seed_mask.reshape(shape)


def _fill_from_nearest_seed(img, seed_mask):
    """Give every non-seed pixel the value of its closest seed (Euclidean EDT).

    Modifies ``img`` in place and returns it.
    """
    missing = ~seed_mask
    if not missing.any() or not seed_mask.any():
        return img
    # Only the indices of the nearest zero (= seed) element are needed.
    iy, ix = distance_transform_edt(missing, return_distances=False, return_indices=True)
    img[missing] = img[iy[missing], ix[missing]]
    return img


def _save_debug_outputs(data_to_write, transform):
    """Dump ``data_to_write`` as GeoTIFF and PNG under ``GLOBAL.trash_dir/raster`` (debug only)."""
    from GLOBAL import trash_dir
    import matplotlib.pyplot as plt

    out_dir = os.path.join(trash_dir, 'raster')
    os.makedirs(out_dir, exist_ok=True)
    height, width = data_to_write.shape
    try:
        with rasterio.open(
            os.path.join(out_dir, 'raster.tif'), "w",
            driver="GTiff",
            height=height,
            width=width,
            count=1,
            dtype="uint8",
            crs="EPSG:4326",
            transform=transform,
            nodata=0,
            tiled=True,
            blockxsize=256,
            blockysize=256,
            compress="lzw",
            BIGTIFF="YES",
        ) as dst:
            # NaN (outside the perimeter) cannot be stored as uint8: map it to nodata.
            dst.write(np.nan_to_num(data_to_write, nan=0).astype(np.uint8), 1)
    except Exception as exc:
        print(f"Errore nel salvataggio del raster: {exc}")
    plt.imsave(os.path.join(out_dir, 'raster.png'), data_to_write, cmap='gray')


def create_raster(values, x_coords, y_coords, path_perimetro, splat: str = 'nearest', tomo_inferenza: bool = False):
    """Rasterize scattered subswath samples onto a lon/lat grid clipped to a perimeter.

    Parameters
    ----------
    values : numpy.ndarray
        Sample values; NaN marks missing samples. Must have the same shape as
        the coordinate arrays.
    x_coords, y_coords : numpy.ndarray
        Sample coordinates in EPSG:32632, easting / northing.
    path_perimetro : str or os.PathLike
        GeoJSON file whose first feature's exterior ring (EPSG:32632) is the
        perimeter the raster is clipped to.
    splat : {'nearest', 'bilinear'}
        How each sample is deposited before the nearest-seed fill: on the
        single pixel containing it, or bilinearly over the 4 surrounding pixel
        centres.
    tomo_inferenza : bool
        Use a finer grid (800 instead of 500 pixels on the longer side).

    Returns
    -------
    numpy.ndarray or None
        Float array, north-up (row 0 = max latitude), on the EPSG:4326 grid
        spanning the samples' bounding box: values 0-255 (min-max normalized,
        truncated to integers) inside the perimeter, NaN outside. ``None`` if
        no sample is valid or the extent has zero width or height.

    Raises
    ------
    ValueError
        If ``splat`` is not one of the supported modes.
    """
    if splat not in _SPLAT_MODES:
        raise ValueError(f"Unsupported splat mode {splat!r}; expected one of {_SPLAT_MODES}")

    valid = ~np.isnan(values)
    if not np.any(valid):
        return None

    transformer = Transformer.from_crs("EPSG:32632", "EPSG:4326", always_xy=True)
    lon, lat = transformer.transform(x_coords.ravel(), y_coords.ravel())
    lon = lon.reshape(x_coords.shape)
    lat = lat.reshape(y_coords.shape)

    lon_min, lon_max = lon.min(), lon.max()
    lat_min, lat_max = lat.min(), lat.max()
    delta_lon = lon_max - lon_min
    delta_lat = lat_max - lat_min
    if not (delta_lon > 0 and delta_lat > 0):
        # A single point or a perfectly straight line has no 2-D extent to rasterize.
        return None

    risoluzione = _RISOLUZIONE_TOMO if tomo_inferenza else _RISOLUZIONE_DEFAULT
    shape = _grid_shape(delta_lon, delta_lat, risoluzione)
    res_y, res_x = shape
    # North-up affine splitting the extent into res_x x res_y equal cells.
    transform = from_bounds(lon_min, lat_min, lon_max, lat_max, res_x, res_y)

    mask_inside = _perimeter_mask(path_perimetro, transformer, shape, transform)

    val_valid = values[valid].astype(np.float32)
    rng = val_valid.max() - val_valid.min() or 1.0  # constant field -> all zeros
    val_norm = (val_valid - val_valid.min()) / rng

    # Fractional position of each valid sample inside the extent, in [0, 1].
    u = (lon[valid].astype(np.float64) - lon_min) / delta_lon
    v = (lat[valid].astype(np.float64) - lat_min) / delta_lat

    img, seed_mask = _splat_seeds(u, v, val_norm, shape, splat)
    img = _fill_from_nearest_seed(img, seed_mask)

    # Rows were built south -> north; flip so row 0 is the northern edge, as in ``transform``.
    grid_uint8 = (img * 255.0).astype(np.uint8)[::-1, :]
    data_to_write = np.where(mask_inside, grid_uint8, np.nan)

    if _DEBUG_SAVE:
        _save_debug_outputs(data_to_write, transform)

    return data_to_write

    