from PySide6.QtCore import QThread, Signal
import os
import numpy as np

from Model.SurveyManager import SurveyManager
# Note: the heavy AI modules are imported lazily inside RunAI.run()


class RunAI(QThread):

    _instance = None  # Qui memorizziamo l'unica istanza

    def __new__(cls, *args, **kwargs):
        """Return the single shared RunAI instance, creating it on first use.

        Python forwards the constructor arguments to both ``__new__`` and
        ``__init__``. They are accepted (and ignored) here so that
        ``RunAI(parent)`` -- which ``__init__`` explicitly supports -- does not
        fail with a ``TypeError`` before ``__init__`` is even reached.
        """
        if cls._instance is None:
            cls._instance = super(RunAI, cls).__new__(cls)
            # Flag read by __init__ to initialise the underlying thread only once.
            cls._instance._initialized = False
        return cls._instance
    """
    Singleton worker thread that runs AI model inference on the selected survey.
    Mirrors the filter thread structure with start/complete/error signals.
    """

    ai_started = Signal(str)
    ai_completed = Signal(str)
    ai_error = Signal(str)
    ai_progress = Signal(str)

    # Tells the map controller to switch back to the original tomography
    ai_switch_tomo = Signal()

    def __init__(self, parent=None):
        """Initialise the shared thread object the first time it is constructed.

        Because ``__new__`` always hands back the same object, ``__init__`` is
        invoked on every ``RunAI(...)`` call; later calls return immediately so
        the underlying QThread is not initialised twice.
        """
        already_set_up = getattr(self, '_initialized', False)
        if already_set_up:
            return

        super().__init__(parent)
        self._initialized = True

        # Optional request dict; filled in through set_request()/set_model_paths().
        self.request = None
        self.survey_manager = SurveyManager()

        # Default model checkpoints, used by run() unless custom paths are supplied.
        default_paths = {
            'longitudinale': r"X:\R&D\DeepGeo\GPRDG\AI\modelli\2D_Longitudinali\output_training_valerio\checkpoints\best_model.pth",
            'trasversale': r"X:\R&D\DeepGeo\GPRDG\AI\modelli\2D_Trasversali\output_tra_akira\checkpoints\best_model.pth",
            'manhole': r"X:\R&D\DeepGeo\GPRDG\AI\modelli\MANHOLES\exp_longitudinale_128\checkpoints\best_epoch002_valloss0.218273.pth",
        }
        self.default_model_paths = default_paths

    def set_request(self, request: dict | None):
        """Store the request dict for the next run; a falsy value becomes ``{}``."""
        if request:
            self.request = request
        else:
            self.request = {}
    
    def set_model_paths(self, model_paths: dict):
        """Register custom model checkpoint paths for the next run.

        model_paths: dict with the keys 'longitudinale', 'trasversale' and
        'manhole'. It is stored under ``self.request['model_paths']`` and each
        entry is echoed to stdout.
        """
        # Make sure there is a request dict to attach the paths to.
        if getattr(self, 'request', None) is None:
            self.request = {}
        self.request['model_paths'] = model_paths

        print(f"[RunAI] Nuovi percorsi modelli impostati:")
        for model_name in model_paths:
            model_file = model_paths[model_name]
            print(f"[RunAI]   {model_name}: {model_file}")

    def run(self):
        """Load the AI models and run inference on every subswath of the selected survey.

        Steps, all executed on this worker thread:

        1. Lazily import the AI modules and load the longitudinal, transversal
           and manhole models (custom paths from ``self.request['model_paths']``
           when provided and non-empty, otherwise ``self.default_model_paths``).
        2. Record an MD5 hash and modification time for each model file so the
           outputs can be traced back to a model version.
        3. For polarization 0 walk every swath/subswath, build the input volume
           from the subswath's ``data_current_filter_folder``, call
           ``process_swath_array`` and write the resulting mask plus the model
           metadata JSON into the subswath's ``data_AI_folder``.

        Progress is reported through ``ai_started``, ``ai_progress``,
        ``ai_completed`` and ``ai_error``. A failure on a single subswath is
        reported and that subswath is skipped; any other exception aborts the
        run and is reported through ``ai_error``.
        """
        # Bound once up front so every exception handler below can use it.
        import traceback

        print("[RunAI] ========================================")
        print("[RunAI] Thread RunAI.run() AVVIATO")
        print("[RunAI] ========================================")
        try:
            # Lazy imports to reduce the risk of initialising OpenMP more than once.
            print("[RunAI] Import moduli AI...")
            from Algorithm.AI.processor import process_swath_array
            from Algorithm.AI.models import load_model_2d, load_model_manhole_2d
            import hashlib
            import json
            from datetime import datetime
            print("[RunAI] Import moduli completati")

            # Use custom paths when provided, otherwise the defaults. A missing,
            # None or empty 'model_paths' entry falls back to the defaults instead
            # of crashing on the first key lookup.
            custom_paths = (self.request or {}).get('model_paths')
            if custom_paths:
                model_paths = custom_paths
                print("[RunAI] Uso percorsi modelli CUSTOM forniti via set_model_paths()")
            else:
                model_paths = self.default_model_paths
                print("[RunAI] Uso percorsi modelli DEFAULT")

            print("\n[RunAI] === CARICAMENTO MODELLI ===")
            print(f"[RunAI] Longitudinale: {model_paths['longitudinale']}")
            print(f"[RunAI] Trasversale: {model_paths['trasversale']}")
            print(f"[RunAI] Manhole: {model_paths['manhole']}")

            print("[RunAI] Carico model longitudinale...")
            model_longitudinale = load_model_2d(model_paths['longitudinale'])
            print("[RunAI] ✓ Model longitudinale caricato")

            print("[RunAI] Carico model trasversale...")
            model_trasversale = load_model_2d(model_paths['trasversale'])
            print("[RunAI] ✓ Model trasversale caricato")

            print("[RunAI] Carico model manhole...")
            model_manhole = load_model_manhole_2d(model_paths['manhole'])
            print("[RunAI] ✓ Model manhole caricato")

            # Hash the model files so each output can be tied to a model version.
            model_metadata = {
                'timestamp': datetime.now().isoformat(),
                'models': {}
            }
            print("[RunAI] Calcolo hash dei modelli...")
            for name, path in model_paths.items():
                try:
                    if os.path.exists(path):
                        # Feed the file to MD5 in 1 MiB chunks: same digest as hashing
                        # the whole content, without loading the checkpoint in memory.
                        digest = hashlib.md5()
                        with open(path, 'rb') as f:
                            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                                digest.update(chunk)
                        file_hash = digest.hexdigest()
                        model_metadata['models'][name] = {
                            'path': path,
                            'hash': file_hash,
                            'modified': datetime.fromtimestamp(os.path.getmtime(path)).isoformat()
                        }
                        print(f"[RunAI] {name}: hash={file_hash[:8]}... modified={model_metadata['models'][name]['modified']}")
                except Exception as e:
                    # Hashing is best effort: record the problem and keep going.
                    model_metadata['models'][name] = {'path': path, 'error': str(e)}
                    print(f"[RunAI] ERRORE hash {name}: {e}")

            print("[RunAI] === MODELLI CARICATI CON SUCCESSO ===\n")

            self.ai_started.emit("AI elaboration started")
            selected_survey_id = self.survey_manager.get_id_selected_survey()
            self.ai_progress.emit(f"Selected survey: {selected_survey_id}")

            # Only polarization 0 (VV) is processed (0=VV, 1=HH); the loop form is
            # kept so that further polarizations can be added to the tuple.
            print(f"[RunAI] Inizio processing, polarization=0")
            for pol in (0,):
                self.survey_manager.set_current_polarization(pol)
                survey = self.survey_manager[selected_survey_id]
                try:
                    num_swath = survey.get_num_tot_swath()
                    print(f"[RunAI] Numero swath da processare: {num_swath}")
                except Exception:
                    num_swath = 0
                    print(f"[RunAI] ERRORE ottenimento num swath")

                for sw_idx in range(num_swath):
                    print(f"[RunAI] Processing swath {sw_idx}/{num_swath}")
                    swath = survey[sw_idx]
                    try:
                        num_sub = swath.get_num_tot_subswath()
                        print(f"[RunAI] Swath {sw_idx}: num sub={num_sub}")
                    except Exception:
                        num_sub = 0
                        print(f"[RunAI] ERRORE ottenimento num sub")

                    for sub_idx in range(num_sub):
                        print(f"[RunAI] Processing swath {sw_idx}, sub {sub_idx}/{num_sub}")
                        subswath = swath[sub_idx]
                        print(f"[RunAI] Subswath {sub_idx} ottenuto, check folders...")

                        # INPUT: Data_current_filter (filtered data); must already exist.
                        input_folder = getattr(subswath, 'data_current_filter_folder', None)
                        print(f"[RunAI] Data_current_filter: {input_folder}, exists={input_folder and os.path.isdir(input_folder)}")
                        if not input_folder or not os.path.isdir(input_folder):
                            print(f"[RunAI] ✗ Skip: no Data_current_filter for swath {sw_idx} sub {sub_idx}")
                            self.ai_progress.emit(f"Skip: no Data_current_filter for swath {sw_idx} sub {sub_idx}")
                            continue

                        # OUTPUT: Data_AI (inference results, kept apart from the filtered
                        # data). Only the presence of a path is checked here.
                        output_folder = getattr(subswath, 'data_AI_folder', None)
                        print(f"[RunAI] Data_AI: {output_folder}, exists={output_folder is not None}")
                        if not output_folder:
                            print(f"[RunAI] ✗ Skip: no Data_AI folder for swath {sw_idx} sub {sub_idx}")
                            self.ai_progress.emit(f"Skip: no Data_AI folder for swath {sw_idx} sub {sub_idx}")
                            continue
                        print(f"[RunAI] ✓ Folders OK, procedo con elaborazione...")

                        self.ai_progress.emit(f"Building volume (pol={pol}, swath={sw_idx}, sub={sub_idx}) from {input_folder}...")
                        print(f"[RunAI] INPUT: {input_folder}")
                        print(f"[RunAI] OUTPUT: {output_folder}")
                        try:
                            print(f"[RunAI] Building volume da {input_folder}...")
                            volume = self._volume_from_channel_fast(input_folder)
                            print(f"[RunAI] ✓ Volume built: shape={volume.shape}")
                        except Exception as e:
                            print(f"[RunAI] ✗ Volume build failed: {e}")
                            traceback.print_exc()
                            self.ai_progress.emit(f"Volume build failed: {e}")
                            continue

                        self.ai_progress.emit(f"Running model (shape={volume.shape})...")
                        print(f"[RunAI] Running model...")
                        try:
                            # The volume is passed in its original orientation (no axis swap).
                            print(f"[RunAI] Chiamata process_swath_array...")
                            result = process_swath_array(
                                volume,
                                mode="2d",
                                model_trasversale=model_trasversale,
                                model_longitudinale=model_longitudinale,
                                model_manhole=model_manhole,
                                normalize_input=True,
                            )
                            print(f"[RunAI] ✓ Model inference completata")
                            if isinstance(result, tuple):
                                # A tuple result carries the mask plus the manhole candidates.
                                # NOTE(review): per the original author's note the manholes are
                                # stored as JSON x,y,z in the subswath folder (x: channel number,
                                # y: depth, always zero, z: survey length) -- confirm against
                                # save_manholes().
                                mask, manhole_candidates = result
                                subswath.save_manholes(manhole_candidates)
                                print(f"Manholes saved to: {subswath.manholes_path}")
                            else:
                                mask = result
                        except Exception as e:
                            print(f"[RunAI] ✗ Model error: {e}")
                            traceback.print_exc()
                            self.ai_error.emit(f"Model error: {e}")
                            continue

                        # Save the AI masks into Data_AI (Data_current_filter is NOT overwritten).
                        try:
                            print(f"[RunAI] Salvo maschere in {output_folder}...")
                            self._write_mask_as_channels(mask, output_folder)
                            print(f"[RunAI] Maschere AI salvate in: {output_folder}")

                            # Store the model metadata next to the masks for versioning.
                            metadata_path = os.path.join(output_folder, 'ai_model_metadata.json')
                            with open(metadata_path, 'w', encoding='utf-8') as f:
                                json.dump(model_metadata, f, indent=2, ensure_ascii=False)
                            print(f"[RunAI] Metadata salvato in: {metadata_path}")
                            print(f"[RunAI] Timestamp: {model_metadata['timestamp']}")

                            self.ai_progress.emit("Mask saved to Data_AI")
                        except Exception as e:
                            # Log the traceback like the other handlers, then move on.
                            print(f"[RunAI] ✗ Save error: {e}")
                            traceback.print_exc()
                            self.ai_error.emit(f"Save error: {e}")

            print("[RunAI] ========================================")
            print("[RunAI] AI elaboration completata, emetto signal")
            print("[RunAI] ========================================")
            self.ai_completed.emit("AI elaboration completed")
            print("[RunAI] Signal ai_completed emesso")
        except Exception as e:
            # Top-level boundary of the thread: report instead of letting it escape.
            print(f"[RunAI] ERRORE in run(): {e}")
            traceback.print_exc()
            self.ai_error.emit(str(e))

    # --------------- helpers ---------------
    def _volume_from_channel_fast(self, channel_folder: str, max_height: int | None = 512) -> np.ndarray:
        """Build a float32 volume (x, y, z) by stacking the channel_XX.npy files of a folder.

        x = number of channels, y = subswath height, z = subswath length.

        Files are ordered by the integer that follows ``channel_`` in their
        name. Files whose index is not an integer (e.g. ``channel_backup.npy``)
        are skipped instead of aborting the whole build.

        Mismatched shapes (legacy files) are tolerated: the most frequent 2D
        shape becomes the target, arrays that match it once transposed are
        transposed, and any remaining mismatch is zero-padded or cropped to
        the target so that stacking always succeeds.

        Args:
            channel_folder: folder containing the ``channel_*.npy`` files.
            max_height: cap applied to the y axis of every channel; ``None``
                disables the cap. Defaults to 512.

        Returns:
            Array of shape (num_channels, y, z) and dtype float32, with NaN and
            infinite values replaced by 0.

        Raises:
            FileNotFoundError: no usable ``channel_*.npy`` file in the folder.
            ValueError: a channel file does not contain a 2D array.
        """
        # Collect (index, filename) pairs, ignoring files without a numeric index.
        indexed_files: list[tuple[int, str]] = []
        for fname in os.listdir(channel_folder):
            if not (fname.startswith("channel_") and fname.endswith(".npy")):
                continue
            try:
                channel_idx = int(fname.split("_")[1].split(".")[0])
            except ValueError:
                print(f"[RunAI] Skip file with non-numeric channel index: {fname}")
                continue
            indexed_files.append((channel_idx, fname))
        if not indexed_files:
            raise FileNotFoundError(f"No channel_*.npy found in {channel_folder}")
        # Stable sort on the index only, so equal indices keep their listing order.
        indexed_files.sort(key=lambda item: item[0])

        # Load every channel, sanitising non-finite values.
        arrays_raw: list[np.ndarray] = []
        for _, fname in indexed_files:
            arr = np.load(os.path.join(channel_folder, fname))
            if arr.ndim != 2:
                raise ValueError(f"Channel file {fname} is not 2D: shape={arr.shape}")
            cleaned = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
            arrays_raw.append(cleaned.astype(np.float32, copy=False))

        # Majority vote on the shapes. The transposed shape is registered with a
        # zero tally (as before), which only influences the order of the keys and
        # therefore which shape wins a tie.
        shape_counts: dict[tuple[int, int], int] = {}
        for a in arrays_raw:
            rows, cols = a.shape
            shape_counts[(rows, cols)] = shape_counts.get((rows, cols), 0) + 1
            shape_counts.setdefault((cols, rows), 0)
        ty, tz = max(shape_counts, key=shape_counts.get)

        # Bring every channel to the target shape.
        arrays: list[np.ndarray] = []
        for a in arrays_raw:
            if a.shape != (ty, tz) and a.shape[::-1] == (ty, tz):
                a = a.T
            if a.shape != (ty, tz):
                # Zero-pad where the channel is smaller, crop where it is larger.
                ay, az = a.shape
                py, pz = min(ay, ty), min(az, tz)
                padded = np.zeros((ty, tz), dtype=np.float32)
                padded[:py, :pz] = a[:py, :pz]
                a = padded
            # Enforce the requested cap on the subswath height.
            if max_height is not None and a.shape[0] > max_height:
                a = a[:max_height, :]
            arrays.append(a)

        return np.stack(arrays, axis=0).astype(np.float32, copy=False)

    def _write_mask_as_channels(self, mask: np.ndarray, channel_folder: str) -> None:
        """Write the (x, y, z) mask as one float32 ``channel_XX.npy`` file per channel.

        The files are written into ``channel_folder``, which is created when it
        does not exist yet (the caller only checks that a path is configured,
        not that the folder is present). Existing files with the same names in
        that folder are overwritten.

        Args:
            mask: array whose first axis indexes the channels.
            channel_folder: destination folder for the per-channel files.
        """
        os.makedirs(channel_folder, exist_ok=True)
        mask = np.asarray(mask)
        for ch in range(mask.shape[0]):
            out_path = os.path.join(channel_folder, f"channel_{ch:02d}.npy")
            np.save(out_path, mask[ch].astype(np.float32, copy=False))


