# controllers/MapController.py
# IMPORT PYSIDE6
from PySide6.QtCore import QObject
from PySide6.QtCore import QTimer, QUrl

# IMPORT MAPPA
from PySide6.QtWebEngineCore import QWebEngineSettings, QWebEngineScript
from PySide6.QtWebChannel import QWebChannel
from pyqtlet2 import L, MapWidget
# IMPORT JSBRIDGE
from Controller.Main.JSBridge import JSBridge  

# IMPORT CHANGED
from Controller.Main.Changed import ChangedController

# MODELLI
from Model.SurveyManager import SurveyManager
from Model.TileManager import TileManager

# THREAD
from Thread.ReadFile import ReadFile
from Thread.Rendering import Rendering
from Thread.RunAI import RunAI

# GLOBAL
from GLOBAL import root_software_path

import logging
import os
import json
from pyproj import Transformer
import numpy as np
from scipy.interpolate import griddata
import rasterio
from rasterio.transform import from_bounds
from rasterio.features import geometry_mask
from shapely.geometry import Polygon, mapping

class MapController(QObject):

    def __init__(self, view, parent=None):
        super().__init__(parent)

        self.view = view
        self.survey_manager = SurveyManager()  # recupero il gestore dei survey
        self.tile_manager = TileManager()  # recupero il gestore dei tile
        self.map_logger = logging.getLogger('map')

        # NOTIFY CHANGED
        self.changed_controller = ChangedController()
        self.changed_controller.swath_changed.connect(self.change_select_swath)

        # THREAD
        self.th_readFile = ReadFile()

        self.th_runAI = RunAI()
        self.th_runAI.ai_completed.connect(self.switch_tomography)
        self.th_runAI.ai_switch_tomo.connect(self.switch_tomography)
        self.inferenza_ai = False

        self.th_rendering = Rendering()
        self.th_rendering.update_map.connect(self.update_map) # --> da fare
        #self.th_rendering.start() # --> thread sempre in esecuzione

        # CREAZIONE DELLA MAPPA -------------------------------------------------------------------------------------
        self.L = L
        self.mapWidget = MapWidget()
        self.mapWidget.settings().setAttribute(QWebEngineSettings.WebAttribute.LocalContentCanAccessRemoteUrls, True)

        self.map = self.L.map(self.mapWidget, {'attributionControl': False})

        self.map.setView([44.80772488328479, 10.322647213925565], 10)
        self.online = self.L.tileLayer('https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
                                       options={"maxNativeZoom": 22, "maxZoom": 24})
        self.online.addTo(self.map)
        
        self.view.setMap(self.mapWidget)
        # -----------------------------------------------------------------------------------------------------------

        # CREAZIONE DEI LAYER DALLE MAPPA
        # survey
        self.survey_layer = self.L.layerGroup()
        self.survey_layer.addTo(self.map)

        # layer per selezione/visualizzazione swath (usato per disegnare poligoni)
        self.select_swath_layer = self.L.layerGroup()
        self.select_swath_layer.addTo(self.map)

        # REGISTRA I LAYER IN JAVASCRIPT PER IL CONTROLLO AVANZATO
        self.map.runJavaScript(f"""
            window.surveyLayer = {self.survey_layer.jsName};
            window.select_swath_layer = {self.select_swath_layer.jsName};
        """, self.map.mapWidgetIndex)

        # BRIDGE JS
        self.js_bridge = JSBridge()
        self.js_bridge.depthChanged.connect(self.depth_change) # --> cattura la modifica dello slider di profondità
        self.js_bridge.zoomBoundsChanged.connect(self.on_zoom_bounds) # --> restituisce i bounds della mappa (visualizzata in quel momento)
        # Registra l'oggetto bridge nel canale
        self.mapWidget.channel.registerObject("bridge", self.js_bridge)

        js_file_path = os.path.join(root_software_path, "JS", "map_custom.js")
        # Carica JS dopo il caricamento della mappa
        with open(js_file_path, "r", encoding="utf-8") as f:
            js_code = f.read()
        # Esegui dopo delay
        QTimer.singleShot(500, lambda: self.map.runJavaScript(js_code, self.map.mapWidgetIndex))
        QTimer.singleShot(510, lambda: self.map.runJavaScript("initCustomFeatures(map);", self.map.mapWidgetIndex))
        # STORICO SUBSWATH IN VISUALIZZAZIONE
        self.old_subswath_in_bounds = [] # tengo le subswath che sono in visualizzazione cosi da confrontarle con quelle nuove
        self.old_depth = None # mi tengo l'ultima profondità selezionata per confrontarla con quella nuova
        self.prima_volta = True

        # Avvia la generazione dei tile (thread) all'apertura GUI SOLO se Open Project
        QTimer.singleShot(1000, self._start_tile_thread_if_ready)
    
    # EVENTO RICHIAMATO AL COMPLETAMENTO DELL'ELABORAZIONE AI
    def switch_tomography(self):
        print("[MapController] Elaborazione AI completata, cambio tomografia")
        self.inferenza_ai = not self.inferenza_ai
        self.th_rendering.switch_tomography(self.inferenza_ai)
    
    # EVENTO RICHIAMATO AL CAMBIO DEL VALORE DALLO SLIDER JS
    def depth_change(self, val):
        #return
        print(f"[MapController] Slider JS ha selezionato profondità: {val} cm")
        self.changed_controller.emit_depth_changed(val) # --> Aggiorno la profondità del mio survey selezionato
        self.th_rendering.update_depth_from_map()

    def on_zoom_bounds(self, minLat, minLng, maxLat, maxLng, zoomLevel):
        minLat, minLng = self.from_4326_to_32632(minLng, minLat)
        maxLat, maxLng = self.from_4326_to_32632(maxLng, maxLat)
        print(f"[MapController] Bounds: min_coord[{minLng}, {minLat}]  max_coord[{maxLng}, {maxLat}] - Zoom: {zoomLevel}")
        # Qui puoi aggiungere la logica che vuoi quando cambia lo zoom
        # Ad esempio, aggiornare la risoluzione dei tile o altri comportamenti
        #self.th_rendering.rendering_tile_in_bounds(minLat, minLng, maxLat, maxLng)
        self.th_rendering.run(minLat, minLng, maxLat, maxLng)

    # continuare a provare a disegnare i tile (array) senza salvare raster sulla mappa 
    def debug_tile(self):
        import numpy as np
        
        def plot_array_canvas(arr, bounds_sw_ne, zindex=4000):
            """
            arr: np.ndarray (H,W), (H,W,3) o (H,W,4). Verrà convertito a RGBA uint8 in Python.
            bounds_sw_ne: [[S,W],[N,E]] float in WGS84.
            zindex: intero per portare il layer sopra le tiles.
            """

            # --- to RGBA uint8 (robusto) ---
            a = np.ascontiguousarray(arr)
            if a.ndim == 2:
                # grayscale -> RGBA
                a = a.astype(np.float32)
                a -= np.nanmin(a)
                denom = float(np.nanmax(a))
                a = (255*(a/denom) if denom > 0 else a).clip(0, 255).astype(np.uint8)
                a = np.stack([a, a, a, np.full_like(a, 255)], axis=-1)
            elif a.ndim == 3 and a.shape[2] in (3, 4):
                if a.dtype != np.uint8:
                    a = np.clip(a, 0, 255).astype(np.uint8)
                if a.shape[2] == 3:
                    alpha = np.full(a.shape[:2] + (1,), 255, dtype=np.uint8)
                    a = np.concatenate([a, alpha], axis=2)
            else:
                raise ValueError("Array deve essere (H,W), (H,W,3) o (H,W,4).")
            H, W = int(a.shape[0]), int(a.shape[1])

            # Per trasmettere i pixel in JS in modo compatto/robusto: HEX (niente base64)
            rgba_hex = a.tobytes().hex()

            S, Wlon = float(bounds_sw_ne[0][0]), float(bounds_sw_ne[0][1])
            N, Elon = float(bounds_sw_ne[1][0]), float(bounds_sw_ne[1][1])
            JS = f"""
                    (function() {{
                    function hexToU8(hex) {{
                        const len = hex.length/2, out = new Uint8Array(len);
                        for (let i=0;i<len;i++) out[i] = parseInt(hex.substr(i*2,2),16);
                        return out;
                    }}
                    const map = (window._leaflet_map||window.map);
                    if (!map) {{ console.error('Leaflet map not ready'); return; }}

                    // Definisci la classe CanvasOverlay una volta
                    if (!window.CanvasOverlay) {{
                        window.CanvasOverlay = L.Layer.extend({{
                        initialize: function(bounds, opts) {{
                            this._bounds = L.latLngBounds(bounds);
                            L.setOptions(this, opts);
                        }},
                        onAdd: function(map) {{
                            this._map = map;
                            this._canvas = L.DomUtil.create('canvas', 'leaflet-canvas-layer');
                            this._ctx = this._canvas.getContext('2d');
                            if (this.options.zIndex != null) this._canvas.style.zIndex = String(this.options.zIndex);
                            map.getPanes().overlayPane.appendChild(this._canvas);
                            map.on('zoomend viewreset moveend resize', this._reset, this);
                            this._reset();
                        }},
                        onRemove: function(map) {{
                            L.DomUtil.remove(this._canvas);
                            map.off('zoomend viewreset moveend resize', this._reset, this);
                        }},
                        _reset: function() {{
                            const nw = this._map.latLngToLayerPoint(this._bounds.getNorthWest());
                            const se = this._map.latLngToLayerPoint(this._bounds.getSouthEast());
                            const size = se.subtract(nw);
                            L.DomUtil.setPosition(this._canvas, nw);
                            this._canvas.width  = Math.max(1, Math.round(Math.abs(size.x)));
                            this._canvas.height = Math.max(1, Math.round(Math.abs(size.y)));
                            this._redraw();
                        }},
                        setImageData: function(u8rgba, w, h) {{
                            this._imgW = w; this._imgH = h; this._u8 = u8rgba;
                            this._redraw();
                        }},
                        _redraw: function() {{
                            if (!this._u8 || !this._ctx || !this._canvas) return;
                            const ctx = this._ctx;
                            const Wc = this._canvas.width, Hc = this._canvas.height;
                            const imgData = new ImageData(new Uint8ClampedArray(this._u8), this._imgW, this._imgH);
                            const off = document.createElement('canvas');
                            off.width = this._imgW; off.height = this._imgH;
                            off.getContext('2d').putImageData(imgData, 0, 0);
                            ctx.clearRect(0,0,Wc,Hc);
                            ctx.imageSmoothingEnabled = true;
                            ctx.drawImage(off, 0, 0, Wc, Hc);
                        }}
                        }});
                    }}

                    const bounds = L.latLngBounds([{S}, {Wlon}], [{N}, {Elon}]);
                    const layer = new CanvasOverlay(bounds, {{ zIndex: {int(zindex)} }}).addTo(map);
                    const u8 = hexToU8("{rgba_hex}");
                    layer.setImageData(u8, {W}, {H});
                    }})();
                    """
            #self.map.runJavaScript(JS, self.map.mapWidgetIndex)
            rgba_js = json.dumps(rgba_hex)
            self.map.runJavaScript(f"drawTile(map, {bounds}, {rgba_js}, {W}, {H}, {zindex});", self.map.mapWidgetIndex)
        for index_tile in range(self.tile_manager.total_tiles):
            bounds = self.tile_manager.dic_tile[index_tile].bounds_zone

            H, W = 20, 20
            arr = np.random.rand(H, W, 4) * 1000  # matrice qualsiasi
            arr[:, :, 0] = self.tile_manager.dic_tile[index_tile].rgb[0]
            arr[:, :, 1] = self.tile_manager.dic_tile[index_tile].rgb[1]
            arr[:, :, 2] = self.tile_manager.dic_tile[index_tile].rgb[2]
            arr[:, :, 3] = 255

            min_bounds = bounds[3]
            max_bounds = bounds[1]

            b_min_x, b_min_y = self.from_32632_to_4326(min_bounds[1], min_bounds[0])
            b_max_x, b_max_y = self.from_32632_to_4326(max_bounds[1], max_bounds[0])

            bounds = [[b_min_y, b_min_x], [b_max_y, b_max_x]]
            print(f"bounds: {bounds}")

            plot_array_canvas(arr, bounds, zindex=3000)
        return   

    # AGGIORNO LA MAPPA
    def update_map(self):
        print("Cambio la visualizzazione")
        #self.debug_tile()
        # Cosa devo fare?
        # 0. Confronto tra le subswath in visualizzazione e quelle nuove
        # 1. Se ce ne sono di nuove, le aggiungo alla mappa
        # 2. Se alcune non ci sono più, le rimuovo dalla mappa(o le rendo non visibili)

        list_subswath_in_bounds = self.th_rendering.subswath_in_bounds # --> lista delle subswath che sono in visualizzazione
        list_index_tile = self.th_rendering.list_index_tile # --> lista degli indici dei tile che sono in visualizzazione

        
        if self.old_depth != self.survey_manager.get_depth(): # oltre ai nuovi aggiorno i raster della profondità corrente
            import time
            start = time.time()
            # 1) Aggiungo prima i nuovi raster (evita il "vuoto" visivo)
            for index_subswath in list_subswath_in_bounds:
                key = f"{index_subswath[0]}_{index_subswath[1]}_{index_subswath[2]}_{self.survey_manager.get_depth()}"
                raster = self.th_rendering.cache.cache_dict[key]["raster"]
                bounds_utm = self.th_rendering.cache.cache_dict[key]["bounds"]  # [[S,W],[N,E]] in EPSG:32632
                # Converto bounds in WGS84 per Leaflet
                S_utm, W_utm = float(bounds_utm[0][0]), float(bounds_utm[0][1])
                N_utm, E_utm = float(bounds_utm[1][0]), float(bounds_utm[1][1])
                lng_SW, lat_SW = self.from_32632_to_4326(W_utm, S_utm)
                lng_NE, lat_NE = self.from_32632_to_4326(E_utm, N_utm)
                bounds_js = [[lat_SW, lng_SW], [lat_NE, lng_NE]]
                # Serializzo per JavaScript (NaN consentiti dal motore JS)
                matrix_js = json.dumps(np.asarray(raster).tolist(), allow_nan=True)
                bounds_js_str = json.dumps(bounds_js)
                key_js = json.dumps(key)
                zindex = 3000
                self.map.runJavaScript(f"drawSubswath(map, {bounds_js_str}, {matrix_js}, {key_js}, {zindex});", self.map.mapWidgetIndex)
            # 2) Rimuovo i raster della vecchia profondità
            for index_subswath in self.old_subswath_in_bounds:
                key = f"{index_subswath[0]}_{index_subswath[1]}_{index_subswath[2]}_{self.old_depth}"
                key_js = json.dumps(key)
                self.map.runJavaScript(f"remove_from_map_dict(map, {key_js});", self.map.mapWidgetIndex)
            end = time.time()
            print(f"Tempo di aggiornamento della mappa: {end - start} secondi")
            # aggiorno la lista delle subswath in visualizzazione
            self.old_subswath_in_bounds = list_subswath_in_bounds.copy()
            self.old_depth = self.survey_manager.get_depth()
        else: # Ho la stessa profondità, quindi modifico solo i raster che sono apparsi in visualizzazione
            # 1. Confronto tra le subswath in visualizzazione e quelle nuove
            for index_subswath in list_subswath_in_bounds:
                if index_subswath not in self.old_subswath_in_bounds:
                    key = f"{index_subswath[0]}_{index_subswath[1]}_{index_subswath[2]}_{self.survey_manager.get_depth()}"
                    raster = self.th_rendering.cache.cache_dict[key]["raster"]
                    bounds_utm = self.th_rendering.cache.cache_dict[key]["bounds"]  # [[S,W],[N,E]] in EPSG:32632
                    # Converto bounds in WGS84 per Leaflet
                    S_utm, W_utm = float(bounds_utm[0][0]), float(bounds_utm[0][1])
                    N_utm, E_utm = float(bounds_utm[1][0]), float(bounds_utm[1][1])
                    lng_SW, lat_SW = self.from_32632_to_4326(W_utm, S_utm)
                    lng_NE, lat_NE = self.from_32632_to_4326(E_utm, N_utm)
                    bounds_js = [[lat_SW, lng_SW], [lat_NE, lng_NE]]
                    # Serializzo per JavaScript (NaN consentiti dal motore JS)
                    matrix_js = json.dumps(np.asarray(raster).tolist(), allow_nan=True)
                    bounds_js_str = json.dumps(bounds_js)
                    key_js = json.dumps(key)
                    zindex = 3000
                    self.map.runJavaScript(f"drawSubswath(map, {bounds_js_str}, {matrix_js}, {key_js}, {zindex});", self.map.mapWidgetIndex)
            # 2. Rimuovo i raster delle subswath che non sono più in visualizzazione
            for index_subswath in self.old_subswath_in_bounds:
                if index_subswath not in list_subswath_in_bounds:
                    key = f"{index_subswath[0]}_{index_subswath[1]}_{index_subswath[2]}_{self.survey_manager.get_depth()}"
                    key_js = json.dumps(key)
                    self.map.runJavaScript(f"remove_from_map_dict(map, {key_js});", self.map.mapWidgetIndex)

            # aggiorno la lista delle subswath in visualizzazione
            self.old_subswath_in_bounds = list_subswath_in_bounds.copy()
        if self.prima_volta:
            #self.debug_tile()
            self.change_select_swath(0)
            # zoom vicino alla tomografia
            # prendo una coordinata della tomografie e faccio un set view vicino a quella coordinata
            coord = self.th_rendering.tile_manager.dic_tile[0].bounds_zone[0]
            x, y = coord[1], coord[0]  # bounds_zone memorizza (y, x)
            lng, lat = self.from_32632_to_4326(x, y)  # ritorna (lng, lat)
            self.map.setView([lat, lng], 18)
            self.prima_volta = False
            return

    def change_select_swath(self, id_swath):
        # 1. prendo il geojson della subswath selezionata
        geojson_path = self.survey_manager[self.survey_manager.get_id_selected_survey()][id_swath].path_polygon_total_area
        with open(geojson_path, 'r') as f:
            geojson = json.load(f)
        polygon = geojson['features'][0]['geometry']['coordinates'][0]  # EPSG:32632 (lista di [x,y])
        # Trasforma le coordinate UTM (EPSG:32632) in WGS84 (EPSG:4326)
        transformer = Transformer.from_crs("EPSG:32632", "EPSG:4326", always_xy=True)
        polygon_4326 = []
        for x, y in polygon:
            lng, lat = transformer.transform(x, y)  # sempre (x,y)->(lng,lat)
            polygon_4326.append([lat, lng])        # Leaflet vuole [lat, lng]
        # chiudo il poligono se non chiuso
        if polygon_4326[0] != polygon_4326[-1]:
            polygon_4326.append(polygon_4326[0])
        js_coords = json.dumps(polygon_4326)
        self.map.runJavaScript(f"drawContourSwath(map, {js_coords});", self.map.mapWidgetIndex)
        return

    # funzione per passare da EPSG:32632 a EPSG:4326
    def from_32632_to_4326(self, x, y):
        transformer = Transformer.from_crs("EPSG:32632", "EPSG:4326", always_xy=True)
        return transformer.transform(x, y)

    # funzione per passare da EPSG:4326 a EPSG:32632
    def from_4326_to_32632(self, x, y):
        transformer = Transformer.from_crs("EPSG:4326", "EPSG:32632", always_xy=True)
        return transformer.transform(x, y)

    def _start_tile_thread_if_ready(self):
        try:
            import GLOBAL as G
            if not getattr(G, 'open_project_mode', False):
                return
            if hasattr(self, 'th_rendering') and hasattr(self.th_rendering, 'th_raster_tile'):
                if not self.th_rendering.th_raster_tile.isRunning():
                    self.th_rendering.th_raster_tile.start()
                    # resetto il flag per evitare riavvii inattesi
                    G.open_project_mode = False
        except Exception:
            pass
