# Thread/ApplyFilter.py

from PySide6.QtCore import QThread, Signal
import numpy as np
import os

# MODELLI
from Model.SurveyManager import SurveyManager

# FILTRI
from Algorithm.Filters.OGPR_filters import (
    bandpass_fir,
    background_removal,
    gain,
    start_time_shifterOGPR
)

class ApplyFilter(QThread):
    """Self-contained worker thread that applies filters to the radargrams.

    Every filter is applied to every channel of every swath of the currently
    selected survey: first with polarization VV (index 0), then HH (index 1).
    Each filtered channel is saved in its subswath's
    ``data_current_filter_folder`` as ``channel_XX.npy``; when that file
    already exists it is used as the input, so successive filters chain.

    The class is a singleton: every ``ApplyFilter()`` call returns the same
    instance, and the body of ``__init__`` only runs the first time.

    Signals:
        filter_progress(str): progress-bar update text.
        filter_completed(): the requested work finished (per-channel failures
            are only printed and do not prevent this signal).
        filter_error(str): the run aborted; carries the error message.
    """

    # Singleton instance shared by every ApplyFilter() call.
    _instance = None

    # Signals are declared as class attributes, as Qt requires.
    filter_progress = Signal(str)  # progress-bar update text
    filter_completed = Signal()    # the requested filter(s) have finished
    filter_error = Signal(str)     # the run aborted with this message

    # Channels are truncated to at most this many rows before filtering.
    MAX_SAMPLES = 512

    # Polarization indices passed to SurveyManager.set_current_polarization.
    POLARIZATION_VV = 0
    POLARIZATION_HH = 1

    # Chain applied when ``default_sequence`` is True: (filter name, kwargs).
    # The dicts are only ever unpacked with ``**``, never mutated.
    DEFAULT_FILTERS_SEQUENCE = (
        ('start_time_shifterOGPR', {}),
        ('bandpass_fir', {
            'dt_ns': 0.125,
            'fmin_mhz': 150,
            'fmax_mhz': 850,
            'order': 71,
        }),
        ('background_removal', {'window_traces': 0}),
        ('gain', {}),
    )

    def __new__(cls):
        # Create the instance only once; later calls return the same object.
        if cls._instance is None:
            cls._instance = super(ApplyFilter, cls).__new__(cls)
            # Flag read by __init__ so the QThread is initialised only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialise the thread state on the first construction only."""
        # __init__ runs on every ApplyFilter() call: skip re-initialisation.
        if getattr(self, '_initialized', False):
            return
        super().__init__()
        self._initialized = True

        # Filter configuration (see set_filter / set_filter_sequence).
        self.filter_name = None
        self.filter_params = None
        self.default_sequence = False
        self.filters_sequence = None  # list of (name, params) tuples
        self.reset_before_sequence = True

        # Survey manager
        self.survey_manager = SurveyManager()

        # Filter name -> filter function.
        self.filter_functions = {
            'bandpass_fir': bandpass_fir,
            'background_removal': background_removal,
            'gain': gain,
            'start_time_shifterOGPR': start_time_shifterOGPR
        }

    def set_filter(self, filter_name, params=None):
        """Select a single filter (manual mode) to apply on the next run.

        Args:
            filter_name: key of ``self.filter_functions``.
            params: keyword arguments for the filter function, or None.
        """
        print(f"ApplyFilter: Imposto filtro {filter_name} con parametri {params}")
        self.filter_name = filter_name
        self.filter_params = params
        self.default_sequence = False  # manual mode
        self.filters_sequence = None

    def set_filter_sequence(self, filters_sequence, reset_before=True):
        """Select an arbitrary sequence of filters to apply in order.

        Args:
            filters_sequence: list of ``(filter_name, params_dict)`` tuples;
                a ``None`` params entry is treated as ``{}``.
            reset_before: if True, the ``.npy`` files in the
                ``data_current_filter_folder`` of every swath of the selected
                survey (HH and VV) are deleted before the sequence starts, so
                it begins from the raw channels.
        """
        # Normalise: every params None -> {}
        self.filters_sequence = [(name, (params or {})) for (name, params) in (filters_sequence or [])]
        self.reset_before_sequence = reset_before
        self.default_sequence = False
        self.filter_name = None
        self.filter_params = None

    def run(self):
        """Thread entry point (invoked by ``QThread.start()``).

        Runs, in priority order: the custom sequence from
        ``set_filter_sequence``, the default chain when ``default_sequence``
        is True, or the single filter from ``set_filter``. Any exception
        aborts the run and is reported via ``filter_error`` instead of
        ``filter_completed``.
        """
        try:
            if self.filters_sequence is not None:
                # Custom sequence
                if self.reset_before_sequence:
                    self._reset_current_filters_for_selected()
                for filter_name, params in self.filters_sequence:
                    print(f"ApplyFilter: Applico filtro {filter_name} (sequenza custom)")
                    self.filter_progress.emit(f"Applicazione filtro {filter_name}...")
                    self._apply_single_filter(filter_name, params, prefer_filtered_input=True)
                    self.filter_progress.emit(f"Filtro {filter_name} completato")
                self.filter_completed.emit()
            elif self.default_sequence:
                print("ApplyFilter: Inizio applicazione filtri predefiniti (modalità sequenza)")
                for filter_name, params in self.DEFAULT_FILTERS_SEQUENCE:
                    print(f"ApplyFilter: Applico filtro {filter_name}")
                    self.filter_progress.emit(f"Applicazione filtro {filter_name}...")
                    self._apply_single_filter(filter_name, params, prefer_filtered_input=True)
                    print(f"ApplyFilter: Completato filtro {filter_name}")
                    self.filter_progress.emit(f"Filtro {filter_name} completato")
                print("ApplyFilter: Completata applicazione di tutti i filtri")
                self.filter_completed.emit()
            else:
                print(f"ApplyFilter: Inizio applicazione filtro singolo: {self.filter_name}")
                self._apply_single_filter(self.filter_name, self.filter_params, prefer_filtered_input=True)
                self.filter_completed.emit()
        except Exception as e:
            error_msg = f"Errore durante l'applicazione del filtro: {str(e)}"
            print(f"ApplyFilter: {error_msg}")
            self.filter_error.emit(error_msg)

    def _apply_single_filter(self, filter_name, params, prefer_filtered_input=True):
        """Apply one filter to every channel of the selected survey (VV, then HH).

        Args:
            filter_name: key of ``self.filter_functions``.
            params: keyword arguments for the filter function (None = none).
            prefer_filtered_input: if True, read a channel's previously
                filtered ``.npy`` when present instead of the raw channel.

        Raises:
            ValueError: if ``filter_name`` is not a known filter.

        Per-channel failures are printed and skipped, not raised. On return
        the SurveyManager's current polarization is left set to HH.
        """
        print(f"ApplyFilter: Inizio esecuzione filtro {filter_name}")
        if filter_name not in self.filter_functions:
            raise ValueError(f"Filtro {filter_name} non riconosciuto")
        filter_func = self.filter_functions[filter_name]
        params = params or {}

        selected_survey_id = self.survey_manager.get_id_selected_survey()
        selected_swath_id = self.survey_manager.get_id_selected_swath()
        print(f"ApplyFilter: Survey selezionato: {selected_survey_id}, Swath: {selected_swath_id}")

        # VV first (noted as having 19 channels), then HH.
        for label, polarization in (("VV", self.POLARIZATION_VV), ("HH", self.POLARIZATION_HH)):
            self._apply_to_polarization(
                filter_func, params, selected_survey_id, polarization, label, prefer_filtered_input
            )

        print(f"ApplyFilter: Completata elaborazione filtro {filter_name}")

    def _iter_swaths(self, survey_id, polarization):
        """Yield ``(swath_idx, swath)`` for every swath of a survey in one polarization.

        Switches the SurveyManager's current polarization before reading the
        survey; a survey whose swath count cannot be read yields nothing.
        """
        self.survey_manager.set_current_polarization(polarization)
        survey = self.survey_manager[survey_id]
        try:
            num_swaths = survey.get_num_tot_swath()
        except Exception:
            # Best effort: an unreadable survey simply has nothing to process.
            num_swaths = 0
        for swath_idx in range(num_swaths):
            yield swath_idx, survey[swath_idx]

    def _apply_to_polarization(self, filter_func, params, survey_id, polarization, label,
                               prefer_filtered_input):
        """Filter every channel of every swath/subswath for one polarization.

        ``label`` ("VV"/"HH") is only used in log messages. A failing channel
        is logged and skipped so the remaining channels are still processed.
        """
        for swath_idx, swath in self._iter_swaths(survey_id, polarization):
            num_subswaths = swath.get_num_tot_subswath()
            print(f"ApplyFilter: [{label}] Swath {swath_idx} - SubSwath: {num_subswaths}")
            for subswath_idx in range(num_subswaths):
                subswath = swath[subswath_idx]
                num_channels = len(subswath.file_subswath_channel_list)
                print(f"ApplyFilter: [{label}] SubSwath {subswath_idx} - Canali: {num_channels}")
                for channel_idx in range(num_channels):
                    try:
                        self._filter_channel(subswath, channel_idx, filter_func, params, prefer_filtered_input)
                    except Exception as e:
                        print(f"Errore [{label}] Swath {swath_idx} SubSwath {subswath_idx} Canale {channel_idx}: {str(e)}")

    def _filter_channel(self, subswath, channel_idx, filter_func, params, prefer_filtered_input):
        """Filter one channel and save it as ``channel_XX.npy`` in the subswath's filter folder."""
        out_dir = subswath.data_current_filter_folder
        out_path = os.path.join(out_dir, f"channel_{channel_idx:02d}.npy")
        if prefer_filtered_input and os.path.exists(out_path):
            # Chain onto the output of the previously applied filter.
            input_data = np.load(out_path)
        else:
            input_data = subswath.channel(channel_idx)
        # Keep at most MAX_SAMPLES rows so every channel has a consistent height.
        if input_data.shape[0] > self.MAX_SAMPLES:
            input_data = input_data[:self.MAX_SAMPLES, :]
        # float32 with NaN/inf -> 0 so the filters never see integers or NaNs.
        input_data = self._sanitize(input_data)
        # Sanitise the output too, so NaNs cannot propagate to the next filter
        # in a chain (e.g. the start-time shifter).
        filtered_data = self._sanitize(filter_func(input_data, **params))
        os.makedirs(out_dir, exist_ok=True)
        np.save(out_path, filtered_data)

    @staticmethod
    def _sanitize(data):
        """Return ``data`` as float32 with NaN, +inf and -inf replaced by 0."""
        return np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)

    def _reset_current_filters_for_selected(self):
        """Delete the filtered ``.npy`` files of the selected survey (HH and VV).

        Covers every swath of the selected survey, i.e. exactly the set that
        ``_apply_single_filter`` processes, so a reset sequence restarts from
        the raw channels everywhere. Leaves the current polarization set to VV.
        """
        selected_survey_id = self.survey_manager.get_id_selected_survey()
        for polarization in (self.POLARIZATION_HH, self.POLARIZATION_VV):
            for _, swath in self._iter_swaths(selected_survey_id, polarization):
                for subswath_idx in range(swath.get_num_tot_subswath()):
                    self._clear_filtered_folder(swath[subswath_idx].data_current_filter_folder)

    @staticmethod
    def _clear_filtered_folder(folder):
        """Remove every ``.npy`` file in ``folder``; a missing folder is a no-op."""
        try:
            for name in os.listdir(folder):
                if name.endswith('.npy'):
                    os.remove(os.path.join(folder, name))
        except FileNotFoundError:
            pass  # folder not created yet: nothing to clear
        except Exception as e:
            # Best effort: log and keep clearing the remaining folders.
            print(f"ApplyFilter: impossibile svuotare {folder}: {e}")