from PySide6.QtCore import QThread, Signal
import os
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
import shutil
from skimage.morphology import skeletonize

# Parametri visualizzazione/scheletro per allineare multi_PHT
DILATE_RADIUS = 8  # raggio dilatazione prima dello skeletonize (ellittico)


class RunAIDetection(QThread):
    """
    Thread to run multiPHT algorithm on completed tiles.
    Processes tiles as they are completed by Raster_tile.
    """

    ai_started = Signal(str)
    ai_completed = Signal(str)
    ai_error = Signal(str)
    ai_progress = Signal(str)

    def __init__(self, parent=None) -> None:
        super().__init__(parent)

    def run(self) -> None:
        print("[RunAIDetection] Avvio thread AI Detection...")
        print("[RunAIDetection] Thread in ascolto per signal tile_completed...")
        self.ai_started.emit("AI detection avviata (multiPHT)")
        
        # Il thread rimane attivo e in ascolto dei signal tile_completed
        # Non si completa mai, rimane in attesa
        print("[RunAIDetection] Thread attivo e in ascolto...")
        
        # Mantieni il thread vivo
        self.exec()

    # API chiamabile da altri thread (es. Raster_tile.tile_completed)
    def on_tile_completed(self, tile_id: int, out_dir: str, expected_depth: int):
        print(f"[RunAIDetection] Ricevuto tile_completed: tile {tile_id}, dir {out_dir}, depth {expected_depth}")
        try:
            self.ai_progress.emit(f"Tile completato: {tile_id} ({expected_depth} depth) -> {out_dir}")
            # Processa immediatamente il tile completato
            self._process_single_tile(tile_id, out_dir, expected_depth)
        except Exception as exc:
            print(f"[RunAIDetection] Errore processamento tile {tile_id}: {exc}")
            self.ai_progress.emit(f"Errore processamento tile {tile_id}: {exc}")

    def _preprocess_gray(self, gray: np.ndarray) -> np.ndarray:
        """Replica il pre-processing usato in multi_PHT per i TIFF:
        - normalizzazione robusta per percentili (2-98)
        - scaling a 0-255 uint8
        - Gaussian blur (3x3, sigma=1.5)
        """
        # Safeguard per valori NaN/Inf
        gray = np.nan_to_num(gray, nan=0.0, posinf=0.0, neginf=0.0)
        # Percentile stretch
        p2 = np.percentile(gray, 2)
        p98 = np.percentile(gray, 98)
        if p98 <= p2:
            p2, p98 = float(gray.min()), float(gray.max())
        if p98 <= p2:
            # tutto costante
            scaled = np.zeros_like(gray, dtype=np.uint8)
        else:
            scaled = np.clip((gray - p2) / (p98 - p2), 0.0, 1.0)
            scaled = (scaled * 255.0).astype(np.uint8)
        # Gaussian blur
        blurred = cv2.GaussianBlur(scaled, (3, 3), 1.5)
        return blurred

    def _process_single_tile(self, tile_id: int, out_dir: str, expected_depth: int):
        print(f"[RunAIDetection] Processo tile {tile_id} da {out_dir}")
        try:
            marker = os.path.join(out_dir, 'tile_complete.json')
            if not os.path.isfile(marker):
                print(f"[RunAIDetection] Marker non trovato per tile {tile_id}: {marker}")
                return
            print(f"[RunAIDetection] Marker trovato per tile {tile_id}, leggo metadati...")
            with open(marker, 'r', encoding='utf-8') as f:
                meta = json.load(f)
            print(f"[RunAIDetection] Metadati tile {tile_id}: {meta}")
            # avvia multiPHT in Python sullo stack NPY del tile
            print(f"[RunAIDetection] Avvio multiPHT per tile {tile_id}...")
            self._run_multipht_python(out_dir, expected_depth)
            print(f"[RunAIDetection] multiPHT completato per tile {tile_id}")
            self.ai_progress.emit(f"multiPHT completato per tile {tile_id}")
        except Exception as exc:
            print(f"[RunAIDetection] Errore multiPHT tile {tile_id}: {exc}")
            self.ai_progress.emit(f"Errore multiPHT tile {tile_id}: {exc}")

    def _run_multipht_python(self, tile_dir: str, expected_depth: int):
        # Esegue il multi_PHT sulla serie depth_XXX.npy del tile
        print(f"[RunAIDetection] _run_multipht_python: tile_dir={tile_dir}, expected_depth={expected_depth}")
        try:
            # Carica tutte le immagini depth_*.npy fino a expected_depth
            frames = []
            print(f"[RunAIDetection] Carico {expected_depth} immagini depth_*.npy...")
            for d in range(expected_depth):
                npy_path = os.path.join(tile_dir, f"depth_{d:03d}.npy")
                if not os.path.isfile(npy_path):
                    print(f"[RunAIDetection] File non trovato: {npy_path}")
                    break
                if d == 0 or d == expected_depth - 1:
                    print(f"[RunAIDetection] Carico {npy_path}")
                gray = np.load(npy_path)
                # Applica lo stesso pre-processing dei TIFF
                gray_p = self._preprocess_gray(gray)
                frames.append(gray_p)
                if d == 0 or d == expected_depth - 1:
                    print(f"[RunAIDetection] Immagine {d} caricata: shape {gray.shape}")
            
            if not frames:
                print("[RunAIDetection] Nessuna immagine caricata, esco")
                return
            
            print(f"[RunAIDetection] Caricate {len(frames)} immagini, avvio accumulo PHT...")
            
            # Accumula maschere di linee
            h, w = frames[0].shape
            acc = np.zeros((h, w), dtype=np.uint16)
            lines_detected = []
            
            # Applica PHT su ogni immagine
            for i, img in enumerate(frames):
                if i % 10 == 0 or i == len(frames) - 1:
                    print(f"[RunAIDetection] Processo immagine {i}/{len(frames)}...")
                try:
                    # Applica Canny edge detection (su immagine già pre-processata come TIFF)
                    edges = cv2.Canny(img, 50, 150)
                    
                    # Applica Probabilistic Hough Transform
                    lines = cv2.HoughLinesP(
                        edges,
                        rho=1.25,
                        theta=np.pi/180,
                        threshold=35,
                        minLineLength=35,
                        maxLineGap=10
                    )
                    
                    # Crea maschera delle linee
                    mask = np.zeros_like(img, dtype=np.uint8)
                    if lines is not None:
                        for line in lines:
                            x1, y1, x2, y2 = line[0]
                            cv2.line(mask, (x1, y1), (x2, y2), 255, 2)
                            # Salva informazioni sulla linea
                            length = np.sqrt((x2-x1)**2 + (y2-y1)**2)
                            lines_detected.append({
                                'depth': i,
                                'start': (x1, y1),
                                'end': (x2, y2),
                                'length': length
                            })
                    
                    # Accumula nella heatmap
                    acc += (mask > 0).astype(np.uint16)
                    
                    if i % 10 == 0 or i == len(frames) - 1:
                        print(f"[RunAIDetection] Immagine {i} processata, pixel con linee: {(mask > 0).sum()}")
                    
                except Exception as e:
                    print(f"[RunAIDetection] Errore processamento immagine {i}: {e}")
                    continue
            
            # Statistiche finali
            print(f"\n{'='*80}")
            print(f"RISULTATI MULTIPHT PER TILE")
            print(f"{'='*80}")
            print(f"Immagini processate: {len(frames)}")
            print(f"Valore massimo nell'accumulatore: {acc.max()}")
            print(f"Pixel con almeno 1 voto: {(acc > 0).sum()}")
            print(f"Pixel con ≥3 voti: {(acc >= 3).sum()}")
            print(f"Linee totali rilevate: {len(lines_detected)}")
            
            # Analizza linee per profondità
            depth_stats = {}
            for line in lines_detected:
                depth = line['depth']
                if depth not in depth_stats:
                    depth_stats[depth] = 0
                depth_stats[depth] += 1
            
            # if depth_stats:
            #     print(f"\nLinee per profondità:")
            #     for depth in sorted(depth_stats.keys()):
            #         print(f"  Profondità {depth}: {depth_stats[depth]} linee")
            
            # Salva risultati
            out_acc = os.path.join(tile_dir, 'multipht_acc.npy')
            print(f"\n[RunAIDetection] Salvo risultato in {out_acc}")
            np.save(out_acc, acc)
            print(f"[RunAIDetection] Risultato salvato: shape {acc.shape}, max {acc.max()}")
            
            # Salva statistiche JSON
            stats = {
                'tile_dir': tile_dir,
                'images_processed': int(len(frames)),
                'max_accumulator_value': int(acc.max()),
                'pixels_with_votes': int((acc > 0).sum()),
                'pixels_with_3plus_votes': int((acc >= 3).sum()),
                'total_lines_detected': int(len(lines_detected)),
                'depth_stats': {int(k): int(v) for k, v in depth_stats.items()},
                'lines_details': [
                    {
                        'depth': int(line['depth']),
                        'start': [int(line['start'][0]), int(line['start'][1])],
                        'end': [int(line['end'][0]), int(line['end'][1])],
                        'length': float(line['length'])
                    }
                    for line in lines_detected
                ]
            }
            
            stats_path = os.path.join(tile_dir, 'multipht_stats.json')
            with open(stats_path, 'w', encoding='utf-8') as f:
                import json
                json.dump(stats, f, indent=2, ensure_ascii=False)
            print(f"[RunAIDetection] Statistiche salvate in {stats_path}")
            
            # Salva JSON nel formato richiesto da geometrie.py
            self._save_geometrie_json(tile_dir, lines_detected, acc)
            
            # Crea plot come in multiPHT
            self._create_plots(tile_dir, acc, frames, lines_detected)
            
            # Organizza i file nelle sottocartelle dopo multiPHT
            self._organize_tile_outputs(tile_dir)
            
            print(f"{'='*80}")
            print(f"MULTIPHT COMPLETATO PER TILE")
            print(f"{'='*80}\n")

        except Exception as exc:
            print(f"[RunAIDetection] Errore _run_multipht_python: {exc}")
            self.ai_progress.emit(f"Errore _run_multipht_python: {exc}")

    def _create_plots(self, tile_dir: str, acc: np.ndarray, frames: list, lines_detected: list):
        """Crea plot come in multiPHT"""
        try:
            print(f"[RunAIDetection] Creo plot per tile...")
            
            # Plot 1: Accumulatore heatmap
            plt.figure(figsize=(12, 8))
            plt.imshow(acc, cmap='hot', interpolation='nearest')
            plt.colorbar(label='Voti accumulati')
            plt.title(f'Heatmap Accumulatore PHT - Tile {os.path.basename(tile_dir)}')
            plt.xlabel('X (pixel)')
            plt.ylabel('Y (pixel)')
            
            plot_path = os.path.join(tile_dir, 'multipht_heatmap.png')
            plt.savefig(plot_path, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"[RunAIDetection] Heatmap salvata: {plot_path}")
            
            # Plot 2: Hotspot (pixel con ≥3 voti)
            if acc.max() >= 3:
                hotspot_mask = acc >= 3
                plt.figure(figsize=(12, 8))
                plt.imshow(hotspot_mask, cmap='Reds', interpolation='nearest')
                plt.colorbar(label='Hotspot (≥3 voti)')
                plt.title(f'Hotspot PHT (≥3 voti) - Tile {os.path.basename(tile_dir)}')
                plt.xlabel('X (pixel)')
                plt.ylabel('Y (pixel)')
                
                hotspot_path = os.path.join(tile_dir, 'multipht_hotspots.png')
                plt.savefig(hotspot_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"[RunAIDetection] Hotspot salvato: {hotspot_path}")
            
            # Plot 3: Distribuzione linee per profondità
            if lines_detected:
                depth_counts = {}
                for line in lines_detected:
                    depth = line['depth']
                    depth_counts[depth] = depth_counts.get(depth, 0) + 1
                
                depths = sorted(depth_counts.keys())
                counts = [depth_counts[d] for d in depths]
                
                plt.figure(figsize=(10, 6))
                plt.bar(depths, counts, alpha=0.7, color='blue')
                plt.xlabel('Profondità (indice immagine)')
                plt.ylabel('Numero di linee rilevate')
                plt.title(f'Distribuzione Linee per Profondità - Tile {os.path.basename(tile_dir)}')
                plt.grid(True, alpha=0.3)
                
                depth_dist_path = os.path.join(tile_dir, 'multipht_depth_distribution.png')
                plt.savefig(depth_dist_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"[RunAIDetection] Distribuzione profondità salvata: {depth_dist_path}")
            
            # Plot 4: Esempio immagine originale + edge detection
            if frames:
                sample_idx = len(frames) // 2  # Immagine centrale
                sample_img = frames[sample_idx]
                edges = cv2.Canny(sample_img, 50, 150)
                
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
                
                ax1.imshow(sample_img, cmap='gray')
                ax1.set_title(f'Immagine Preprocessata (profondità {sample_idx})')
                ax1.set_xlabel('X (pixel)')
                ax1.set_ylabel('Y (pixel)')
                
                ax2.imshow(edges, cmap='gray')
                ax2.set_title(f'Edge Detection Canny (profondità {sample_idx})')
                ax2.set_xlabel('X (pixel)')
                ax2.set_ylabel('Y (pixel)')
                
                plt.tight_layout()
                sample_path = os.path.join(tile_dir, 'multipht_sample_edges.png')
                plt.savefig(sample_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"[RunAIDetection] Esempio edge detection salvato: {sample_path}")
            
            # Plot 5: Heatmap + tutte le linee estratte (come in multiPHT)
            if lines_detected:
                plt.figure(figsize=(14, 10))
                
                # Subplot 1: Heatmap + tutte le linee
                plt.subplot(1, 2, 1)
                plt.imshow(acc, interpolation="nearest", cmap='hot')
                plt.title(f"Heatmap + Tutte le Linee ({len(lines_detected)})")
                
                # Disegna tutte le linee estratte
                for i, line in enumerate(lines_detected):
                    x1, y1 = line['start']
                    x2, y2 = line['end']
                    color = plt.cm.cool(i / max(1, len(lines_detected)))
                    plt.plot([x1, x2], [y1, y2], color=color, linewidth=2, alpha=0.7)
                    
                    # Aggiungi etichetta con profondità (solo per le prime 20)
                    if i < 20:
                        mid_x, mid_y = (x1 + x2) / 2, (y1 + y2) / 2
                        plt.text(mid_x, mid_y, f"{line['depth']}", color='white', fontsize=7, 
                                ha='center', va='center', weight='bold')
                
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.colorbar(label="Voti accumulati")
                
                # Subplot 2: Scheletro globale
                binary_mask = (acc >= 3).astype(np.uint8)  # HOTSPOT_MIN_VOTES = 3
                kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                cleaned = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel_open)
                # Dilatazione ellittica con r=DILATE_RADIUS per tollerare piccoli disallineamenti
                if DILATE_RADIUS > 0:
                    ksz = 2 * DILATE_RADIUS + 1
                    kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksz, ksz))
                    cleaned = cv2.dilate(cleaned, kernel_dilate)
                skeleton = skeletonize(cleaned.astype(bool)).astype(np.uint8)
                
                plt.subplot(1, 2, 2)
                plt.imshow(skeleton, cmap='gray', interpolation="nearest")
                plt.title("Scheletro Globale")
                plt.xlabel("X")
                plt.ylabel("Y")
                
                plt.tight_layout()
                lines_skeleton_path = os.path.join(tile_dir, 'multipht_lines_skeleton.png')
                plt.savefig(lines_skeleton_path, dpi=150, bbox_inches='tight')
                plt.close()
                print(f"[RunAIDetection] Linee + scheletro salvato: {lines_skeleton_path}")
            
            print(f"[RunAIDetection] Tutti i plot creati per tile {os.path.basename(tile_dir)}")
            
        except Exception as e:
            print(f"[RunAIDetection] Errore creazione plot: {e}")

    def _save_geometrie_json(self, tile_dir: str, lines_detected: list, acc: np.ndarray):
        """Salva JSON nel formato richiesto da geometrie.py"""
        try:
            print(f"[RunAIDetection] Creo JSON per geometrie.py...")
            
            # Crea la struttura JSON richiesta da Geometrie.py
            pipes_json = {}
            
            for idx, line in enumerate(lines_detected):
                x1, y1 = line['start']
                x2, y2 = line['end']
                depth = line['depth']
                length = line['length']
                
                # Calcola forza normalizzata (0-1) basata sui voti nell'accumulatore
                # Per semplicità, usiamo la lunghezza come proxy della forza
                max_length = max([l['length'] for l in lines_detected]) if lines_detected else 1.0
                normalized_strength = length / max_length if max_length > 0 else 0.0
                
                # Crea la pipe in formato richiesto da read_json di Geometrie.py
                pipe_key = f"pipe_{idx:04d}"
                pipes_json[pipe_key] = {
                    "regPoints": [
                        [float(x1), float(y1), float(depth)],  # start point
                        [float(x2), float(y2), float(depth)]   # end point
                    ],
                    "best_score": float(normalized_strength),
                    "diametro": 0.0,
                    "diametroOutOfFlag": False,
                    "diametroRicercaRicorsiva": False,
                    "Distanza Percorso": float(length),
                    "edgeIntensities": [],
                    "Epoca": 0,
                    "MatrixValues": [],
                    "mediaValoriBaricentri": 0.0,
                    "medianaValueBaricentri": 0.0,
                    "points": [],
                    "valuesBaricentri": [],
                    "dpPoints": [],
                    "pcaAngles": [],
                    "pcaStartAngles": [],
                    "depth_cm": float(depth),
                    "source": "multiPHT_tile"
                }
            
            # Crea la struttura completa con Offset
            output_json = {
                "Offset": {
                    "Left": 0,
                    "Top": 0,
                    "Z": 0
                }
            }
            # Aggiungi le pipes al dizionario
            output_json.update(pipes_json)
            
            # Salva il JSON nel formato geometrie.py
            geometrie_path = os.path.join(tile_dir, 'final_output_with_analysis.json')
            with open(geometrie_path, 'w', encoding='utf-8') as f:
                json.dump(output_json, f, indent=4, ensure_ascii=False)
            
            print(f"[RunAIDetection] JSON geometrie.py salvato: {geometrie_path}")
            print(f"[RunAIDetection] Numero di pipes: {len(pipes_json)}")
            
        except Exception as e:
            print(f"[RunAIDetection] Errore salvataggio JSON geometrie.py: {e}")

    def _organize_tile_outputs(self, tile_dir: str):
        """Organizza i file del tile nelle sottocartelle dopo multiPHT"""
        try:
            print(f"[RunAIDetection] Organizzo struttura cartelle per tile {os.path.basename(tile_dir)}...")
            
            # Crea le sottocartelle
            subdirs = ['NPY', 'TIFF', 'PNG', 'Output_JSON', 'JSON']
            for subdir in subdirs:
                subdir_path = os.path.join(tile_dir, subdir)
                os.makedirs(subdir_path, exist_ok=True)
                print(f"[RunAIDetection] Creata cartella: {subdir}")
            
            # Sposta i file depth_XXX.npy in NPY
            npy_files = [f for f in os.listdir(tile_dir) if f.startswith('depth_') and f.endswith('.npy')]
            for npy_file in npy_files:
                src = os.path.join(tile_dir, npy_file)
                dst = os.path.join(tile_dir, 'NPY', npy_file)
                shutil.move(src, dst)
                # print(f"[RunAIDetection] Spostato {npy_file} -> NPY/")
            
            # Sposta i file TIFF in TIFF (se presenti)
            tiff_files = [f for f in os.listdir(tile_dir) if f.endswith('.tiff')]
            for tiff_file in tiff_files:
                src = os.path.join(tile_dir, tiff_file)
                dst = os.path.join(tile_dir, 'TIFF', tiff_file)
                shutil.move(src, dst)
                # print(f"[RunAIDetection] Spostato {tiff_file} -> TIFF/")
            
            # Sposta i file PNG in PNG
            png_files = [f for f in os.listdir(tile_dir) if f.endswith('.png')]
            for png_file in png_files:
                src = os.path.join(tile_dir, png_file)
                dst = os.path.join(tile_dir, 'PNG', png_file)
                shutil.move(src, dst)
                print(f"[RunAIDetection] Spostato {png_file} -> PNG/")
            
            # Sposta multipht_stats.json in Output_JSON
            stats_file = 'multipht_stats.json'
            if os.path.exists(os.path.join(tile_dir, stats_file)):
                src = os.path.join(tile_dir, stats_file)
                dst = os.path.join(tile_dir, 'Output_JSON', stats_file)
                shutil.move(src, dst)
                print(f"[RunAIDetection] Spostato {stats_file} -> Output_JSON/")
            
            # Sposta final_output_with_analysis.json in Output_JSON
            geometrie_file = 'final_output_with_analysis.json'
            if os.path.exists(os.path.join(tile_dir, geometrie_file)):
                src = os.path.join(tile_dir, geometrie_file)
                dst = os.path.join(tile_dir, 'Output_JSON', geometrie_file)
                shutil.move(src, dst)
                print(f"[RunAIDetection] Spostato {geometrie_file} -> Output_JSON/")
            
            # Sposta tutti gli altri JSON in JSON
            json_files = [f for f in os.listdir(tile_dir) if f.endswith('.json')]
            for json_file in json_files:
                src = os.path.join(tile_dir, json_file)
                dst = os.path.join(tile_dir, 'JSON', json_file)
                shutil.move(src, dst)
                print(f"[RunAIDetection] Spostato {json_file} -> JSON/")
            
            print(f"[RunAIDetection] Struttura cartelle completata per tile {os.path.basename(tile_dir)}")
            print(f"[RunAIDetection] multipht_acc.npy rimane nel root del tile")
            
        except Exception as e:
            print(f"[RunAIDetection] Errore organizzazione cartelle: {e}")


