from scipy.signal import firwin, filtfilt
import numpy as np
from scipy.ndimage import gaussian_filter1d

def correct_radargram_orientation(data):
    """
    Corregge l'orientamento del radargramma ruotandolo per avere:
    - Asse X: Distanza (tracce)
    - Asse Y: Profondità (campioni)
    
    Parametri
    ---------
    data : ndarray
        Radargramma con orientamento da correggere.
        
    Ritorna
    -------
    data_corrected : ndarray (nsamples, ntraces)
        Radargramma con orientamento corretto.
    """
    # Se i dati hanno shape (ntraces, nsamples), li ruotiamo di 90° antiorario
    if data.shape[0] < data.shape[1]:
        print(f"Correzione orientamento: {data.shape} -> ruotazione 90° antiorario")
        # Ruota di 90° antiorario: np.rot90(data, k=1)
        # Questo equivale a trasporre e poi flipare verticalmente
        data_corrected = np.rot90(data, k=1)
        print(f"Shape dopo rotazione: {data_corrected.shape}")
        return data_corrected
    else:
        print(f"Orientamento già corretto: {data.shape}")
        return data

# def bandpass_fir(data, dt_ns=0.125, fmin_mhz=150, fmax_mhz=850, order=71):
def bandpass_fir(data, dt_ns=0.125, fmin_mhz=150, fmax_mhz=850, order=71):
    """
    Applica un filtro FIR passa-banda tra fmin_mhz e fmax_mhz (in MHz)
    a un radargramma 'data' di forma (nsamples, ntraces).

    Parametri
    ---------
    data : ndarray (nsamples, ntraces)
        Radargramma (una traccia per colonna).
    dt_ns : float
        Passo di campionamento in nanosecondi.
    fmin_mhz : float
        Frequenza di taglio inferiore in MHz.
    fmax_mhz : float
        Frequenza di taglio superiore in MHz.
    order : int
        Ordine del filtro FIR.

    Ritorna
    -------
    data_out : ndarray (nsamples, ntraces)
        Radargramma filtrato.
    """
    # Calcolo della frequenza di campionamento in Hz
    fs = 1e9 / dt_ns
    nyq = fs / 2.0

    # Normalizzo le frequenze di taglio in [0,1]
    low = (fmin_mhz * 1e6) / nyq
    high = (fmax_mhz * 1e6) / nyq

    # Calcolo order adattivo basato sulla lunghezza del radargramma
    min_length = 3 * order  # Lunghezza minima richiesta da filtfilt
    if data.shape[0] < min_length:
        # Calcola order adattivo più conservativo per evitare errori di padding
        # filtfilt usa internamente un padding di 3*(numtaps-1), quindi:
        # data.shape[0] > 3*(adaptive_order-1) 
        # adaptive_order < (data.shape[0]/3) + 1
        adaptive_order = max(5, (data.shape[0] - 1) // 3)
        
        # Verifica ulteriore per sicurezza - se ancora troppo grande, riduci drasticamente
        if 3 * (adaptive_order - 1) >= data.shape[0]:
            adaptive_order = max(3, data.shape[0] // 6)  # Molto conservativo
            
        print(f"Warning: Radargramma troppo corto ({data.shape[0]} campioni) per filtro order={order}")
        print(f"Usando order adattivo={adaptive_order} per mantenere filtfilt e qualità ottimale")
        
        try:
            # Ricalcola il filtro con order adattivo
            b_adaptive = firwin(numtaps=adaptive_order, cutoff=[low, high], pass_zero=False, window='blackman')
            data_out = filtfilt(b_adaptive, [1.0], data, axis=0)
        except Exception as e:
            # Fallback finale: usa lfilter se filtfilt continua a fallire
            print(f"Errore con filtfilt anche con order adattivo: {str(e)}")
            print(f"Fallback a lfilter per evitare crash")
            from scipy.signal import lfilter
            data_out = lfilter(b_adaptive, [1.0], data, axis=0)
    else:
        # Usa order originale per qualità ottimale
        b = firwin(numtaps=order, cutoff=[low, high], pass_zero=False, window='blackman')
        data_out = filtfilt(b, [1.0], data, axis=0)
    # data_out = lfilter(b, [1.0], data, axis=0)

    return data_out

def background_removal(data, window_traces=0):
    """
    Rimuove il background calcolato come media locale su una finestra di dimensione
    'window_traces' lungo l'asse delle tracce.
    
    Parametri
    ---------
    data : ndarray (nsamples, ntraces)
        Radargramma di input.
    window_traces : int
        Dimensione della finestra per il calcolo della media locale.
        
    Ritorna
    -------
    output : ndarray (nsamples, ntraces)
        Radargramma con background rimosso.
    """
    output = np.zeros_like(data)
    for i in range(0, data.shape[0]):
        w_min = max(i - window_traces, 0)
        w_max = min(i + window_traces, data.shape[0])
        a = np.average(data[w_min:w_max+1, :])
        output[i, :] = data[i, :] - a
    return output

def gain(data):
    """
    Applica un guadagno adattivo basato sui percentili per normalizzare l'ampiezza.
    
    Parametri
    ---------
    data : ndarray (nsamples, ntraces)
        Radargramma di input.
        
    Ritorna
    -------
    output : ndarray (nsamples, ntraces)
        Radargramma con guadagno applicato.
    """
    a = np.zeros(data.shape[0])
    b = np.zeros(data.shape[0])
    for i in range(0, data.shape[0]):
        a[i] = np.percentile(data[i, :], 97)
        b[i] = np.percentile(data[i, :], 3)
    a = gaussian_filter1d(a, 10)
    b = gaussian_filter1d(b, 10)
    output = np.zeros_like(data)
    for i in range(0, data.shape[0]):
        temp = data[i, :]
        temp[temp > 0] = temp[temp > 0] / a[i]
        temp[temp < 0] = -temp[temp < 0] / b[i]
        output[i, :] = temp
    return output

def start_time_shifterOGPR(radargram):
    """
    Mantiene solo le righe della matrice a partire dalla prima occorrenza del minimo globale
    successivo al massimo assoluto.
    
    Parametri
    ---------
    radargram : ndarray (nsamples, ntraces)
        Radargramma di input.
        
    Ritorna
    -------
    radargram : ndarray (nsamples_reduced, ntraces)
        Radargramma con tempo di inizio corretto.
    """
    n_tracce, n_campioni = radargram.shape
    max_val = np.max(radargram)
    min_val = np.min(radargram)
    indici_min = np.argmax(radargram == min_val, axis=0)
    mask_min = np.any(radargram == min_val, axis=0)
    indici_min[~mask_min] = -1
    indici_max = np.argmax(radargram == max_val, axis=0)
    mask_max = np.any(radargram == max_val, axis=0)
    indici_max[~mask_max] = -1
    
    # Controllo per evitare errore su array vuoti
    if np.any(mask_min):
        avg_min = int(np.round(np.average(indici_min[mask_min])))
    else:
        avg_min = 0  # Valore di default se nessun minimo trovato
        
    if np.any(mask_max):
        avg_max = int(np.round(np.average(indici_max[mask_max])))
    else:
        avg_max = 0  # Valore di default se nessun massimo trovato
    if (avg_min < avg_max):
        radargram = radargram[avg_max:, :]
        indici_min = np.argmax(radargram == min_val, axis=0)
        mask_min = np.any(radargram == min_val, axis=0)
        indici_min[~mask_min] = -1
        
        # Controllo per evitare errore su array vuoti
        if np.any(mask_min):
            avg_min = int(np.round(np.average(indici_min[mask_min])))
        else:
            avg_min = 0  # Valore di default se nessun minimo trovato
        radargram = radargram[avg_min:, :]
    else:
        radargram = radargram[avg_min:, :]
    return radargram

