import os
from glob import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize
from skimage.measure import label, regionprops
from scipy import ndimage

# ============================
# PARAMETRI (tutti commentati)
# ============================

IMG_DIR = r"C:\Users\Valer\Desktop\lavoro\SP_V2\temp2"      # Cartella che contiene i TIFF/NPY (tutte le immagini devono avere stessa dimensione)
IMG_GLOB = "*.tiff"                  # Pattern dei file da leggere nella cartella (sovrascritto se USE_NPY_IMAGES=True)

# --- Input format ---
# Se True, legge immagini da file .npy (array 2D) con naming depth_XXX.npy
# Se False, legge immagini TIFF con naming depth_XXX_cm.tiff (comportamento attuale)
USE_NPY_IMAGES = True

# --- Preprocessing ---
GAUSS_KSIZE = (3, 3)                # Finestra del blur gaussiano (deve essere dispari); più grande = più smoothing
GAUSS_SIGMA = 1.5                   # Deviazione standard del gaussiano; più alto = più smoothing
CANNY_LOW = 50                      # Soglia bassa per Canny; più bassa = più edge (anche rumore)
CANNY_HIGH = 150                    # Soglia alta per Canny; più alta = meno edge (più selettivo)

# --- Probabilistic Hough Transform (PHT) ---
H_RHO = 1.25                       # Risoluzione del parametro rho (in pixel) nello spazio di Hough
H_THETA = np.pi / 180              # Risoluzione del parametro theta (in radianti): pi/180 = 1 grado
H_THRESHOLD = 35                   # Voti minimi nell’accumulatore di Hough per accettare un segmento
H_MIN_LINE_LEN = 35                 # Lunghezza minima (in pixel) del segmento rilevato
H_MAX_LINE_GAP = 10                 # Gap massimo (in pixel) per unire tratti collineari interrotti

LINE_THICKNESS = 2                  # Spessore (in pixel) con cui rasterizzare i segmenti nella maschera
DILATE_RADIUS = 8                 # Raggio (in pixel) della dilatazione per tollerare piccoli disallineamenti (0 = disattiva)

# --- Hotspot / visualizzazione ---
HOTSPOT_MIN_VOTES = 3              # Soglia (conteggio immagini) per segnalare aree "calde"
SHOW_HOTSPOTS = True                # Se True, plottiamo anche la mappa binaria sopra soglia

# --- Batch processing ---
USE_BATCH_PROCESSING = False            # Se True, processa a batch; se False, processa tutte le immagini insieme
BATCH_SIZE = 40                     # Numero di immagini per batch (solo se USE_BATCH_PROCESSING=True)
BATCH_OVERLAP = 5                   # Overlap tra batch consecutivi (solo se USE_BATCH_PROCESSING=True)

# --- Debug input (confronto TIFF vs NPY) ---
DEBUG_INPUT_STATS = False           # Stampa statistiche su poche immagini per capire range/edge
DEBUG_SAMPLES_TO_PRINT = 3

# ============================
# FUNZIONI
# ============================

def read_gray(path):
    """Legge un'immagine in scala di grigi da TIFF/PNG/JPG o da NPY (2D)."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".npy":
        arr = np.load(path)
        if arr.ndim == 3:
            # Se multi-canale, converte in luminanza semplice
            arr = arr[..., 0]
        # Porta a uint8 [0,255] per Canny/Hough
        # Normalizza in modo robusto ignorando NaN/Inf e outlier
        a = arr.astype(np.float32)
        a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
        p1 = np.percentile(a, 1)
        p99 = np.percentile(a, 99)
        if p99 > p1:
            a = np.clip((a - p1) / (p99 - p1), 0.0, 1.0)
        else:
            # Fallback su min/max
            a_min = float(a.min())
            a_max = float(a.max())
            if a_max > a_min:
                a = (a - a_min) / (a_max - a_min)
            else:
                a = np.zeros_like(a, dtype=np.float32)
        img = (a * 255.0).astype(np.uint8)
        return img
    else:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f"Impossibile leggere: {path}")
        return img

def detect_lines_mask(gray):
    """
    Esegue: blur -> Canny -> PHT.
    Restituisce una maschera binaria (0/1) con i pixel attraversati da almeno un segmento.
    """
    # Preprocessing
    blur = cv2.GaussianBlur(gray, GAUSS_KSIZE, GAUSS_SIGMA)
    edges = cv2.Canny(blur, CANNY_LOW, CANNY_HIGH)

    if DEBUG_INPUT_STATS:
        # Piccolo dump per capire se ci sono edge
        nonzero_edges = int((edges > 0).sum())
        print(f"   [DEBUG] edges>0: {nonzero_edges}")

    # PHT
    lines = cv2.HoughLinesP(
        edges,
        rho=H_RHO,
        theta=H_THETA,
        threshold=H_THRESHOLD,
        minLineLength=H_MIN_LINE_LEN,
        maxLineGap=H_MAX_LINE_GAP
    )

    # Maschera
    mask = np.zeros_like(gray, dtype=np.uint8)
    if lines is not None:
        for (x1, y1, x2, y2) in lines[:, 0, :]:
            cv2.line(mask, (x1, y1), (x2, y2), 255, LINE_THICKNESS)

    if DILATE_RADIUS > 0 and mask.any():
        k = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (2 * DILATE_RADIUS + 1, 2 * DILATE_RADIUS + 1)
        )
        mask = cv2.dilate(mask, k)

    return (mask > 0).astype(np.uint8), lines, edges

def extract_lines_from_heatmap(heatmap, min_votes=3, min_line_length=20, trace_skeleton=True, debug=False):
    """
    Estrae linee dalla heatmap usando skeletonization + line tracing.
    
    Args:
        heatmap: array 2D con i voti accumulati
        min_votes: soglia minima di voti per considerare un pixel
        min_line_length: lunghezza minima delle linee estratte
        trace_skeleton: se True, traccia lo scheletro invece di linee rette
        debug: se True, stampa informazioni di debug
    
    Returns:
        lines: lista di tuple ((x1, y1), (x2, y2), strength) o path se trace_skeleton=True
    """
    if debug:
        print(f"\n=== DEBUG ESTRAZIONE LINEE ===")
        print(f"Heatmap shape: {heatmap.shape}")
        print(f"Heatmap range: {heatmap.min()} - {heatmap.max()}")
        print(f"Soglia min_votes: {min_votes}")
        print(f"Lunghezza minima: {min_line_length}")
    
    # 1) Thresholding: crea maschera binaria dalle zone "calde"
    binary_mask = (heatmap >= min_votes).astype(np.uint8)
    pixels_above_threshold = binary_mask.sum()
    
    if debug:
        print(f"Pixel sopra soglia: {pixels_above_threshold} / {heatmap.size} ({100*pixels_above_threshold/heatmap.size:.2f}%)")
    
    if not binary_mask.any():
        if debug:
            print("❌ Nessun pixel sopra soglia!")
        return []
    
    # 2) Pulizia morfologica: rimuovi rumore piccolo
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    cleaned = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
    pixels_after_cleaning = cleaned.sum()
    
    if debug:
        print(f"Pixel dopo pulizia morfologica: {pixels_after_cleaning} (rimossi: {pixels_above_threshold - pixels_after_cleaning})")
    
    # 3) Skeletonization: riduce a linee di 1 pixel di spessore
    skeleton = skeletonize(cleaned.astype(bool)).astype(np.uint8)
    skeleton_pixels = skeleton.sum()
    
    if debug:
        print(f"Pixel dello scheletro: {skeleton_pixels}")
    
    if not skeleton.any():
        if debug:
            print("❌ Scheletro vuoto dopo skeletonization!")
        return []
    
    if trace_skeleton:
        return trace_skeleton_paths(skeleton, heatmap, min_line_length, debug)
    else:
        return extract_straight_lines(skeleton, heatmap, min_line_length, debug)

def extract_straight_lines(skeleton, heatmap, min_line_length, debug=False):
    """Estrazione classica: linee rette tra estremi"""
    labeled_skeleton = label(skeleton)
    regions = regionprops(labeled_skeleton)
    lines = []
    
    if debug:
        print(f"Componenti connesse trovate: {len(regions)}")
    
    lines_rejected_short = 0
    
    for i, region in enumerate(regions):
        coords = np.array(region.coords)  # [(y, x), ...]
        
        if len(coords) < min_line_length:
            lines_rejected_short += 1
            if debug:
                print(f"Componente {i+1}: {len(coords)} pixel - TROPPO CORTA")
            continue
        
        if debug:
            print(f"Componente {i+1}: {len(coords)} pixel - analizzando...")
        
        # Trova endpoints reali usando connettività
        endpoints = find_skeleton_endpoints(skeleton, coords)
        
        if len(endpoints) >= 2:
            # Usa i primi due endpoints trovati
            (y1, x1), (y2, x2) = endpoints[:2]
            if debug:
                print(f"  Endpoints trovati: ({x1},{y1}) -> ({x2},{y2})")
        else:
            # Fallback: punti più lontani
            distances = np.linalg.norm(coords[:, None] - coords[None, :], axis=2)
            max_idx = np.unravel_index(distances.argmax(), distances.shape)
            y1, x1 = coords[max_idx[0]]
            y2, x2 = coords[max_idx[1]]
            if debug:
                print(f"  Fallback punti lontani: ({x1},{y1}) -> ({x2},{y2})")
        
        # Calcola forza lungo il percorso
        line_coords = get_line_coordinates(x1, y1, x2, y2)
        valid_coords = [(x, y) for x, y in line_coords 
                       if 0 <= x < heatmap.shape[1] and 0 <= y < heatmap.shape[0]]
        
        if valid_coords:
            strengths = [heatmap[y, x] for x, y in valid_coords]
            avg_strength = np.mean(strengths)
            lines.append(((x1, y1), (x2, y2), avg_strength))
            if debug:
                print(f"  ✅ Accettata (forza: {avg_strength:.2f})")
        else:
            if debug:
                print(f"  ❌ Nessuna coordinata valida")
    
    if debug:
        print(f"Linee rifiutate (troppo corte): {lines_rejected_short}")
        print(f"Linee finali: {len(lines)}")
    
    return lines

def trace_skeleton_paths(skeleton, heatmap, min_line_length, debug=False):
    """Traccia OGNI SINGOLO SEGMENTO dello scheletro come linea separata"""
    # Trova endpoints e junction points
    endpoints, junctions = trace_skeleton_endpoints(skeleton)
    
    # Trova coordinate degli endpoints e junctions
    endpoint_coords = np.argwhere(endpoints)  # [(y, x), ...]
    junction_coords = np.argwhere(junctions)  # [(y, x), ...]
    
    if debug:
        print(f"Endpoints trovati: {len(endpoint_coords)}")
        print(f"Junction points trovati: {len(junction_coords)}")
        print(f"Pixel totali scheletro: {skeleton.sum()}")
    
    # Identifica tutti i SEGMENTI dello scheletro
    segments = extract_skeleton_segments(skeleton, endpoints, junctions, debug)
    
    paths = []
    paths_attempted = 0
    paths_too_short = 0
    
    for i, segment in enumerate(segments):
        if len(segment) < min_line_length:
            paths_too_short += 1
            if debug:
                print(f"Segmento {i+1}: {len(segment)} pixel - TROPPO CORTO")
            continue
            
        paths_attempted += 1
        
        # Calcola forza media lungo il segmento
        strengths = [heatmap[y, x] for y, x in segment 
                    if 0 <= x < heatmap.shape[1] and 0 <= y < heatmap.shape[0]]
        
        if strengths:
            avg_strength = np.mean(strengths)
            # Primo e ultimo punto del segmento
            y1, x1 = segment[0]
            y2, x2 = segment[-1]
            paths.append(((x1, y1), (x2, y2), avg_strength, segment))
            if debug:
                print(f"Segmento {i+1}: {len(segment)} pixel da ({x1},{y1}) a ({x2},{y2}) - ✅ Accettato (forza: {avg_strength:.2f})")
        else:
            if debug:
                print(f"Segmento {i+1}: {len(segment)} pixel - ❌ Nessuna forza")
    
    if debug:
        print(f"\nRIASSUNTO:")
        print(f"Segmenti totali identificati: {len(segments)}")
        print(f"Segmenti troppo corti: {paths_too_short}")
        print(f"Segmenti finali: {len(paths)}")
        total_pixels_in_segments = sum(len(seg) for seg in segments)
        print(f"Pixel coperti dai segmenti: {total_pixels_in_segments} / {skeleton.sum()}")
        print(f"Copertura scheletro: {100*total_pixels_in_segments/skeleton.sum():.1f}%")
    
    return paths

def extract_skeleton_segments(skeleton, endpoints, junctions, debug=False):
    """Estrae tutti i segmenti individuali dello scheletro"""
    segments = []
    visited = np.zeros_like(skeleton, dtype=bool)
    
    # Trova tutte le coordinate di endpoints e junctions
    endpoint_coords = set(map(tuple, np.argwhere(endpoints)))
    junction_coords = set(map(tuple, np.argwhere(junctions)))
    special_points = endpoint_coords | junction_coords
    
    if debug:
        print(f"Punti speciali (endpoints + junctions): {len(special_points)}")
    
    # Per ogni punto speciale, traccia tutti i segmenti che partono da lì
    for y_start, x_start in special_points:
        if visited[y_start, x_start]:
            continue
            
        # Trova tutti i vicini non visitati
        neighbors = get_skeleton_neighbors(skeleton, y_start, x_start, visited)
        
        for y_next, x_next in neighbors:
            if visited[y_next, x_next]:
                continue
                
            # Traccia il segmento da questo punto speciale fino al prossimo punto speciale
            segment = trace_segment_between_special_points(
                skeleton, (y_start, x_start), (y_next, x_next), special_points, visited
            )
            
            if segment and len(segment) > 1:
                segments.append(segment)
                if debug:
                    print(f"Segmento trovato: {len(segment)} pixel da ({x_start},{y_start})")
    
    # Traccia eventuali cicli chiusi (senza endpoints o junctions)
    remaining_coords = np.argwhere(skeleton & ~visited)
    for y_start, x_start in remaining_coords:
        if visited[y_start, x_start]:
            continue
            
        # Traccia il ciclo completo
        cycle = trace_cycle(skeleton, (y_start, x_start), visited)
        if cycle and len(cycle) > 1:
            segments.append(cycle)
            if debug:
                print(f"Ciclo trovato: {len(cycle)} pixel da ({x_start},{y_start})")
    
    return segments

def get_skeleton_neighbors(skeleton, y, x, visited):
    """Trova tutti i vicini validi di un punto nello scheletro"""
    neighbors = []
    for dy in [-1, 0, 1]:
        for dx in [-1, 0, 1]:
            if dy == 0 and dx == 0:
                continue
            ny, nx = y + dy, x + dx
            if (0 <= ny < skeleton.shape[0] and 
                0 <= nx < skeleton.shape[1] and 
                skeleton[ny, nx] and 
                not visited[ny, nx]):
                neighbors.append((ny, nx))
    return neighbors

def trace_segment_between_special_points(skeleton, start_point, first_neighbor, special_points, visited):
    """Traccia un segmento tra due punti speciali (endpoint o junction)"""
    segment = [start_point]
    visited[start_point[0], start_point[1]] = True
    
    current = first_neighbor
    
    while current is not None:
        y, x = current
        segment.append(current)
        visited[y, x] = True
        
        # Se raggiungiamo un altro punto speciale, fermiamoci
        if current in special_points and len(segment) > 1:
            break
            
        # Trova il prossimo punto non visitato
        neighbors = get_skeleton_neighbors(skeleton, y, x, visited)
        
        # Se c'è un solo vicino, continua
        if len(neighbors) == 1:
            current = neighbors[0]
        # Se ci sono più vicini, siamo arrivati a un junction
        elif len(neighbors) > 1:
            # Scegliamo il primo e ci fermiamo (gli altri saranno tracciati separatamente)
            break
        else:
            # Nessun vicino, fine del segmento
            break
    
    return segment

def trace_cycle(skeleton, start_point, visited):
    """Traccia un ciclo chiuso nello scheletro"""
    cycle = []
    current = start_point
    
    while current is not None:
        y, x = current
        if visited[y, x]:
            break
            
        cycle.append(current)
        visited[y, x] = True
        
        # Trova il prossimo punto non visitato
        neighbors = get_skeleton_neighbors(skeleton, y, x, visited)
        
        if neighbors:
            current = neighbors[0]  # Prendi il primo vicino disponibile
        else:
            break
    
    return cycle

def find_skeleton_endpoints(skeleton, coords):
    """Trova i veri endpoints dello scheletro usando connettività"""
    endpoints = []
    
    for y, x in coords:
        # Conta i vicini nello scheletro
        neighbors = 0
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                if dy == 0 and dx == 0:
                    continue
                ny, nx = y + dy, x + dx
                if (0 <= ny < skeleton.shape[0] and 
                    0 <= nx < skeleton.shape[1] and 
                    skeleton[ny, nx]):
                    neighbors += 1
        
        # Endpoint ha 1 vicino, junction ha 3+ vicini
        if neighbors == 1:
            endpoints.append((y, x))
    
    return endpoints

def trace_path_from_point(skeleton, start_point, visited):
    """Traccia un percorso nello scheletro partendo da un punto"""
    path = []
    current = start_point
    
    while current is not None:
        y, x = current
        if visited[y, x]:
            break
            
        visited[y, x] = True
        path.append((y, x))
        
        # Trova il prossimo punto non visitato
        next_point = None
        for dy in [-1, 0, 1]:
            for dx in [-1, 0, 1]:
                if dy == 0 and dx == 0:
                    continue
                ny, nx = y + dy, x + dx
                if (0 <= ny < skeleton.shape[0] and 
                    0 <= nx < skeleton.shape[1] and 
                    skeleton[ny, nx] and 
                    not visited[ny, nx]):
                    next_point = (ny, nx)
                    break
            if next_point:
                break
        
        current = next_point
    
    return path

def get_line_coordinates(x1, y1, x2, y2):
    """
    Genera coordinate di tutti i pixel lungo una linea (algoritmo di Bresenham).
    """
    coords = []
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    sx = 1 if x1 < x2 else -1
    sy = 1 if y1 < y2 else -1
    err = dx - dy
    
    x, y = x1, y1
    
    while True:
        coords.append((x, y))
        
        if x == x2 and y == y2:
            break
            
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x += sx
        if e2 < dx:
            err += dx
            y += sy
    
    return coords

def trace_skeleton_endpoints(skeleton):
    """
    Trova endpoint e junction points nello scheletro per un tracing più preciso.
    """
    # Kernel per contare vicini
    kernel = np.array([[1, 1, 1],
                       [1, 0, 1], 
                       [1, 1, 1]], dtype=np.uint8)
    
    # Conta vicini per ogni pixel dello scheletro
    neighbor_count = cv2.filter2D(skeleton.astype(np.uint8), -1, kernel)
    neighbor_count = neighbor_count * skeleton  # solo sui pixel dello scheletro
    
    # Endpoint: 1 vicino, Junction: 3+ vicini
    endpoints = (neighbor_count == 1)
    junctions = (neighbor_count >= 3)
    
    return endpoints, junctions

def extract_depth_from_filename(filename):
    """
    Estrae la profondità dal nome del file.
    Supporta:
      - TIFF: depth_045_cm.tiff -> 45.0
      - NPY:  depth_045.npy     -> 45.0
    Ritorna None se non trova il pattern.
    """
    import re
    base = os.path.basename(filename)
    # Catch-all: qualunque file che contenga depth_XXX
    m_any = re.search(r'depth_(\d+)', base, re.IGNORECASE)
    if m_any:
        return float(m_any.group(1))
    return None

def calculate_line_depth(line_data, depth_map):
    """
    Calcola la profondità di una linea basandosi sul pixel più in alto (y minima).
    Ritorna la profondità del pixel più superficiale della linea.
    """
    if len(line_data) == 4:  # traced path
        (x1, y1), (x2, y2), strength, path = line_data
        # Trova il pixel con y minima nel path
        min_y = min(y for y, x in path)
        min_y_pixels = [(y, x) for y, x in path if y == min_y]
        # Prendi la profondità del primo pixel con y minima
        y, x = min_y_pixels[0]
        depth = depth_map[y, x] if depth_map[y, x] != np.inf else None
    else:  # straight line
        (x1, y1), (x2, y2), strength = line_data
        # Prendi il punto con y minore
        if y1 < y2:
            depth = depth_map[int(y1), int(x1)] if depth_map[int(y1), int(x1)] != np.inf else None
        else:
            depth = depth_map[int(y2), int(x2)] if depth_map[int(y2), int(x2)] != np.inf else None
    
    return depth

def accumulate_from_folder(folder, pattern):
    """
    Legge tutte le immagini della cartella, verifica stessa dimensione,
    crea l'accumulatore (uint32) sommando le maschere 0/1.
    Ritorna:
      - accumulator (uint32)
      - paths (lista di file)
      - sample (immagine centrale in grigio)
      - sample_lines (segmenti Hough della sample)
      - sample_edges (bordi Canny della sample)
      - depth_map (dict {y_coord: depth_value}) - mappa profondità per coordinata y
    """
    paths = sorted(glob(os.path.join(folder, pattern)))
    if not paths:
        raise FileNotFoundError("Nessun file trovato con il pattern specificato.")

    # Leggi la prima per dimensione
    first = read_gray(paths[0])
    h, w = first.shape
    acc = np.zeros((h, w), dtype=np.uint32)
    
    # Mappa di profondità: per ogni pixel, memorizziamo la profondità della prima immagine che lo ha attivato
    depth_map = np.full((h, w), fill_value=np.inf, dtype=np.float32)

    # Per visualizzazione "di esempio", useremo l'immagine centrale
    mid_idx = len(paths) // 2
    sample = None
    sample_lines = None
    sample_edges = None

    # Loop
    for i, p in enumerate(paths):
        g = read_gray(p)
        if DEBUG_INPUT_STATS and i < DEBUG_SAMPLES_TO_PRINT:
            g_float = g.astype(np.float32)
            print(f"[DEBUG] {os.path.basename(p)}: min={g_float.min():.3f} max={g_float.max():.3f} mean={g_float.mean():.3f} std={g_float.std():.3f}")
        if g.shape != (h, w):
            raise ValueError(f"Dimensione diversa per {p}: {g.shape}, atteso {(h, w)}")

        mask01, lines, edges = detect_lines_mask(g)
        acc += mask01.astype(np.uint32)
        
        # Estrai profondità dal nome file
        depth = extract_depth_from_filename(p)
        if depth is not None:
            # Per ogni pixel attivo in questa maschera, se non ha già una profondità, assegna questa
            active_pixels = mask01 > 0
            depth_map[active_pixels & (depth_map == np.inf)] = depth

        if i == mid_idx:
            sample = g
            sample_lines = lines
            sample_edges = edges

    return acc, paths, sample, sample_lines, sample_edges, depth_map

def process_batches(folder, pattern, batch_size=40, overlap=5):
    """
    Processa le immagini in batch con overlap.
    Ritorna una lista di risultati, uno per batch.
    """
    paths = sorted(glob(os.path.join(folder, pattern)))
    if not paths:
        raise FileNotFoundError("Nessun file trovato con il pattern specificato.")
    
    total_images = len(paths)
    step = batch_size - overlap
    
    # Calcola quanti batch servono
    batch_results = []
    batch_idx = 0
    
    print(f"Totale immagini: {total_images}")
    print(f"Batch size: {batch_size}, Overlap: {overlap}, Step: {step}")
    
    start_idx = 0
    while start_idx < total_images:
        end_idx = min(start_idx + batch_size, total_images)
        batch_paths = paths[start_idx:end_idx]
        
        print(f"\n{'='*60}")
        print(f"BATCH {batch_idx + 1}: immagini {start_idx} - {end_idx-1} ({len(batch_paths)} immagini)")
        print(f"{'='*60}")
        
        # Processa questo batch
        batch_result = process_single_batch(batch_paths, batch_idx, start_idx)
        batch_results.append(batch_result)
        
        batch_idx += 1
        start_idx += step
        
        # Se il prossimo batch sarebbe troppo piccolo, fermati
        if start_idx + batch_size > total_images and end_idx == total_images:
            break
    
    print(f"\n{'='*60}")
    print(f"Totale batch processati: {len(batch_results)}")
    print(f"{'='*60}\n")
    
    return batch_results

def process_single_batch(batch_paths, batch_idx, start_idx):
    """
    Processa un singolo batch di immagini.
    """
    # Leggi la prima per dimensione
    first = read_gray(batch_paths[0])
    h, w = first.shape
    acc = np.zeros((h, w), dtype=np.uint32)
    
    # Mappa di profondità
    depth_map = np.full((h, w), fill_value=np.inf, dtype=np.float32)
    
    # Immagine centrale del batch
    mid_idx = len(batch_paths) // 2
    sample = None
    sample_lines = None
    sample_edges = None
    
    # Loop sulle immagini del batch
    for i, p in enumerate(batch_paths):
        g = read_gray(p)
        if g.shape != (h, w):
            raise ValueError(f"Dimensione diversa per {p}: {g.shape}, atteso {(h, w)}")
        
        mask01, lines, edges = detect_lines_mask(g)
        acc += mask01.astype(np.uint32)
        
        # Estrai profondità
        depth = extract_depth_from_filename(p)
        if depth is not None:
            active_pixels = mask01 > 0
            depth_map[active_pixels & (depth_map == np.inf)] = depth
        
        if i == mid_idx:
            sample = g
            sample_lines = lines
            sample_edges = edges
    
    return {
        'batch_idx': batch_idx,
        'start_idx': start_idx,
        'end_idx': start_idx + len(batch_paths) - 1,
        'num_images': len(batch_paths),
        'accumulator': acc,
        'depth_map': depth_map,
        'paths': batch_paths,
        'sample': sample,
        'sample_lines': sample_lines,
        'sample_edges': sample_edges
    }

# ============================
# ESECUZIONE
# ============================

if USE_BATCH_PROCESSING:
    print(f"\n{'='*80}")
    print(f"MODALITÀ: BATCH PROCESSING (batch size={BATCH_SIZE}, overlap={BATCH_OVERLAP})")
    print(f"{'='*80}\n")
    
    # Processa in batch
    # Scegli pattern in base al formato
    pattern = "*.npy" if USE_NPY_IMAGES else IMG_GLOB
    batch_results = process_batches(IMG_DIR, pattern, batch_size=BATCH_SIZE, overlap=BATCH_OVERLAP)
    
    # Accumulatori globali per tutti i batch
    all_extracted_lines = []
    global_accumulator = None
    global_depth_map = None
    
    # Processa ogni batch separatamente
    for batch_result in batch_results:
        batch_idx = batch_result['batch_idx']
        accumulator = batch_result['accumulator']
        depth_map = batch_result['depth_map']
        paths = batch_result['paths']
        sample_gray = batch_result['sample']
        sample_lines = batch_result['sample_lines']
        sample_edges = batch_result['sample_edges']
        
        print(f"\n{'#'*70}")
        print(f"### PROCESSING BATCH {batch_idx + 1}: immagini {batch_result['start_idx']}-{batch_result['end_idx']} ###")
        print(f"{'#'*70}\n")
        
        print(f"Immagini nel batch: {batch_result['num_images']}")
        print(f"Valore massimo nell'accumulatore (voti pixel): {int(accumulator.max())}")
        
        # Info depth map
        valid_depths = depth_map[depth_map != np.inf]
        if len(valid_depths) > 0:
            print(f"Depth map: {len(valid_depths)} pixel con profondità (range: {valid_depths.min():.1f}-{valid_depths.max():.1f} cm)")
        else:
            print("⚠️ Nessuna profondità estratta dai nomi file!")
        
        print(f"\n=== PARAMETRI CORRENTI ===")
        print(f"HOTSPOT_MIN_VOTES: {HOTSPOT_MIN_VOTES}")
        print(f"H_MIN_LINE_LEN: {H_MIN_LINE_LEN}")
        print(f"Kernel pulizia morfologica: 3x3 ellisse")
        print(f"Pixel con almeno 1 voto: {(accumulator > 0).sum()}")
        print(f"Pixel con ≥{HOTSPOT_MIN_VOTES} voti: {(accumulator >= HOTSPOT_MIN_VOTES).sum()}")
        
        # ============================
        # ESTRAZIONE LINEE DA HEATMAP
        # ============================
        
        # Estrai linee dalla heatmap usando skeletonization
        # Prova entrambi i metodi per confronto
        extracted_lines_straight = extract_lines_from_heatmap(
            accumulator, 
            min_votes=HOTSPOT_MIN_VOTES,
            min_line_length=H_MIN_LINE_LEN,
            trace_skeleton=False,  # linee rette
            debug=False  # disabilita debug per batch
        )
        
        extracted_lines_traced = extract_lines_from_heatmap(
            accumulator, 
            min_votes=HOTSPOT_MIN_VOTES,
            min_line_length=H_MIN_LINE_LEN, 
            trace_skeleton=True,  # segue lo scheletro
            debug=False  # disabilita debug per batch
        )
        
        # Usa le linee tracciate come default
        extracted_lines = extracted_lines_traced
        
        print(f"Linee rette estratte: {len(extracted_lines_straight)}")
        print(f"Linee tracciate estratte: {len(extracted_lines_traced)}")
        
        if extracted_lines:
            strengths = [line_data[2] for line_data in extracted_lines]  # strength è sempre al 3° posto
            max_votes_local = int(accumulator.max())
            if max_votes_local > 0:
                strengths_norm = [s / max_votes_local for s in strengths]
                print(f"Forza media (normalizzata 0-1): {np.mean(strengths_norm):.2f} (min: {min(strengths_norm):.2f}, max: {max(strengths_norm):.2f})")
            else:
                print("Forza media (normalizzata 0-1): n/d (max_votes=0)")
        
        # Ordina per forza decrescente
        extracted_lines.sort(key=lambda x: x[2], reverse=True)
        
        # Accumula i risultati globali
        all_extracted_lines.extend(extracted_lines)
        
        # Accumula gli accumulator e depth_map globali
        if global_accumulator is None:
            global_accumulator = accumulator.copy()
            global_depth_map = depth_map.copy()
        else:
            # Somma gli accumulator (gestendo overlap)
            global_accumulator = np.maximum(global_accumulator, accumulator)
            # Per depth map, mantieni il valore minimo (più superficiale)
            global_depth_map = np.minimum(global_depth_map, depth_map)
        
        # ============================
        # PLOT INDIVIDUALI DEL BATCH (opzionali, commentati per default)
        # ============================
        # Decommenta se vuoi vedere i plot di ogni singolo batch
        """
        # 1) Heatmap dell'accumulatore normalizzata (valori 0-1)
        min_votes = accumulator.min()
        max_votes = accumulator.max()
        if max_votes > min_votes:
            accumulator_normalized = (accumulator.astype(np.float32) - min_votes) / (max_votes - min_votes)
        else:
            accumulator_normalized = np.zeros_like(accumulator, dtype=np.float32)
        
        plt.figure(figsize=(8, 6))
        plt.imshow(accumulator_normalized, interpolation="nearest", vmin=0, vmax=1)
        plt.title(f"Batch {batch_idx+1}: Heatmap normalizzata")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.colorbar(label="Probabilità voto (0-1)")
        plt.tight_layout()
        """

else:
    # ============================
    # MODALITÀ PROCESSING COMPLETO (TUTTE LE IMMAGINI INSIEME)
    # ============================
    print(f"\n{'='*80}")
    print(f"MODALITÀ: PROCESSING COMPLETO (tutte le immagini insieme)")
    print(f"{'='*80}\n")
    
    # Processa tutte le immagini insieme
    # Scegli pattern in base al formato
    pattern = "*.npy" if USE_NPY_IMAGES else IMG_GLOB
    accumulator, paths, sample_gray, sample_lines, sample_edges, depth_map = accumulate_from_folder(IMG_DIR, pattern)
    
    print(f"Immagini lette: {len(paths)}")
    print(f"Valore massimo nell'accumulatore (voti pixel): {int(accumulator.max())}")
    
    # Info depth map
    valid_depths = depth_map[depth_map != np.inf]
    if len(valid_depths) > 0:
        print(f"Depth map: {len(valid_depths)} pixel con profondità (range: {valid_depths.min():.1f}-{valid_depths.max():.1f} cm)")
    else:
        print("⚠️ Nessuna profondità estratta dai nomi file!")
    
    print(f"\n=== PARAMETRI CORRENTI ===")
    print(f"HOTSPOT_MIN_VOTES: {HOTSPOT_MIN_VOTES}")
    print(f"H_MIN_LINE_LEN: {H_MIN_LINE_LEN}")
    print(f"Kernel pulizia morfologica: 3x3 ellisse")
    print(f"Pixel con almeno 1 voto: {(accumulator > 0).sum()}")
    print(f"Pixel con ≥{HOTSPOT_MIN_VOTES} voti: {(accumulator >= HOTSPOT_MIN_VOTES).sum()}")
    
    # ============================
    # ESTRAZIONE LINEE DA HEATMAP
    # ============================
    
    # Estrai linee dalla heatmap usando skeletonization
    extracted_lines_straight = extract_lines_from_heatmap(
        accumulator, 
        min_votes=HOTSPOT_MIN_VOTES,
        min_line_length=H_MIN_LINE_LEN,
        trace_skeleton=False,  # linee rette
        debug=False
    )
    
    extracted_lines_traced = extract_lines_from_heatmap(
        accumulator, 
        min_votes=HOTSPOT_MIN_VOTES,
        min_line_length=H_MIN_LINE_LEN, 
        trace_skeleton=True,  # segue lo scheletro
        debug=False
    )
    
    # Usa le linee tracciate come default
    all_extracted_lines = extracted_lines_traced
    
    print(f"Linee rette estratte: {len(extracted_lines_straight)}")
    print(f"Linee tracciate estratte: {len(extracted_lines_traced)}")
    
    if all_extracted_lines:
        strengths = [line_data[2] for line_data in all_extracted_lines]
        max_votes_local = int(accumulator.max())
        if max_votes_local > 0:
            strengths_norm = [s / max_votes_local for s in strengths]
            print(f"Forza media (normalizzata 0-1): {np.mean(strengths_norm):.2f} (min: {min(strengths_norm):.2f}, max: {max(strengths_norm):.2f})")
        else:
            print("Forza media (normalizzata 0-1): n/d (max_votes=0)")
    
    # Ordina per forza decrescente
    all_extracted_lines.sort(key=lambda x: x[2], reverse=True)
    
    # Imposta i dati globali (per compatibilità con il codice di plotting finale)
    global_accumulator = accumulator
    global_depth_map = depth_map

# ============================
# PLOT FINALE COMBINATO (TUTTI I BATCH)
# ============================

print(f"\n{'#'*80}")
print(f"### RISULTATI FINALI COMBINATI DA TUTTI I BATCH ###")
print(f"{'#'*80}\n")

print(f"Totale linee estratte da tutti i batch: {len(all_extracted_lines)}")

if all_extracted_lines and global_accumulator is not None:
    # Ordina tutte le linee per forza
    all_extracted_lines.sort(key=lambda x: x[2], reverse=True)
    
    # Normalizzazione dell'accumulator globale
    min_votes_global = global_accumulator.min()
    max_votes_global = global_accumulator.max()
    if max_votes_global > min_votes_global:
        global_accumulator_normalized = (global_accumulator.astype(np.float32) - min_votes_global) / (max_votes_global - min_votes_global)
    else:
        global_accumulator_normalized = np.zeros_like(global_accumulator, dtype=np.float32)
    
    print(f"Range accumulatore globale: {min_votes_global} - {max_votes_global} voti")
    
    # 1) Heatmap normalizzata globale
    plt.figure(figsize=(8, 6))
    plt.imshow(global_accumulator_normalized, interpolation="nearest", vmin=0, vmax=1)
    plt.title(f"Heatmap Globale Normalizzata (tutti i batch)")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.colorbar(label="Probabilità voto (0-1)")
    plt.tight_layout()
    
    # 2) Heatmap globale con conteggio assoluto
    plt.figure(figsize=(8, 6))
    plt.imshow(global_accumulator, interpolation="nearest")
    plt.title(f"Heatmap Globale (conteggio voti)")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.colorbar(label="Voti (n° immagini)")
    plt.tight_layout()
    
    # 3) Hotspot globale
    if SHOW_HOTSPOTS:
        hotspot_global = (global_accumulator >= HOTSPOT_MIN_VOTES).astype(np.uint8)
        plt.figure(figsize=(8, 6))
        plt.imshow(hotspot_global, interpolation="nearest")
        plt.title(f"Hotspot Globale (≥ {HOTSPOT_MIN_VOTES} voti)")
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.colorbar(label="Maschera (0/1)")
        plt.tight_layout()
    
    # 4) Plot finale: Heatmap + TUTTE le linee estratte
    plt.figure(figsize=(14, 10))
    
    # Subplot 1: Heatmap + tutte le linee
    plt.subplot(1, 2, 1)
    plt.imshow(global_accumulator_normalized, interpolation="nearest", vmin=0, vmax=1, cmap='hot')
    plt.title(f"Heatmap Globale + Tutte le Linee ({len(all_extracted_lines)})")
    
    # Disegna TUTTE le linee estratte da tutti i batch
    for i, line_data in enumerate(all_extracted_lines): 
        color = plt.cm.cool(i / max(1, len(all_extracted_lines)))  # colori diversi
        
        if len(line_data) == 4:  # traced paths: ((x1,y1), (x2,y2), strength, path)
            (x1, y1), (x2, y2), strength, path = line_data
            # Disegna il percorso completo
            if len(path) > 1:
                path_x = [x for y, x in path]
                path_y = [y for y, x in path]
                plt.plot(path_x, path_y, color=color, linewidth=2, alpha=0.7)
            else:
                plt.plot([x1, x2], [y1, y2], color=color, linewidth=2, alpha=0.7)
        else:  # straight lines: ((x1,y1), (x2,y2), strength)
            (x1, y1), (x2, y2), strength = line_data
            plt.plot([x1, x2], [y1, y2], color=color, linewidth=2, alpha=0.7)
        
        # Aggiungi etichetta con profondità (solo per le prime 50 per non sovraffollare)
        if i < 50:
            depth = calculate_line_depth(line_data, global_depth_map)
            mid_x, mid_y = (x1 + x2) / 2, (y1 + y2) / 2
            if depth is not None:
                plt.text(mid_x, mid_y, f'{depth:.0f}', color='white', fontsize=7, 
                        ha='center', va='center', weight='bold')
            else:
                plt.text(mid_x, mid_y, 'N/A', color='white', fontsize=6, 
                        ha='center', va='center', weight='bold')
    
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.colorbar(label="Probabilità voto (0-1)")
    
    # Subplot 2: Solo scheletro globale per debug
    binary_mask_global = (global_accumulator >= HOTSPOT_MIN_VOTES).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    cleaned_global = cv2.morphologyEx(binary_mask_global, cv2.MORPH_OPEN, kernel)
    skeleton_global = skeletonize(cleaned_global.astype(bool)).astype(np.uint8)
    
    plt.subplot(1, 2, 2)
    plt.imshow(skeleton_global, cmap='gray', interpolation="nearest")
    plt.title("Scheletro Globale")
    plt.xlabel("X")
    plt.ylabel("Y")
    
    plt.tight_layout()
    
    # 5) Tabella riassuntiva FINALE
    print("\n" + "="*80)
    print("TUTTE LE LINEE ESTRATTE (ordinate per forza normalizzata)")
    print("="*80)
    print(f"{'#':<3} {'Punto1':<12} {'Punto2':<12} {'Lunghezza':<10} {'Forza(0-1)':<10} {'Profondità(cm)':<15}")
    print("-"*80)
    
    for i, line_data in enumerate(all_extracted_lines[:30]):  # prime 30
        depth = calculate_line_depth(line_data, global_depth_map)
        depth_str = f"{depth:.1f}" if depth is not None else "N/A"
        
        if len(line_data) == 4:  # traced paths
            (x1, y1), (x2, y2), strength, path = line_data
            path_length = len(path)
            euclidean_length = np.hypot(x2 - x1, y2 - y1)
            norm_strength = (strength - min_votes_global) / (max_votes_global - min_votes_global) if max_votes_global > min_votes_global else 0.0
            print(f"{i+1:<3} ({x1:3.0f},{y1:3.0f})<->({x2:3.0f},{y2:3.0f}) {euclidean_length:6.1f}px (path:{path_length:3d}) {norm_strength:6.2f}     {depth_str:<15}")
        else:  # straight lines
            (x1, y1), (x2, y2), strength = line_data
            length = np.hypot(x2 - x1, y2 - y1)
            norm_strength = (strength - min_votes_global) / (max_votes_global - min_votes_global) if max_votes_global > min_votes_global else 0.0
            print(f"{i+1:<3} ({x1:3.0f},{y1:3.0f})<->({x2:3.0f},{y2:3.0f}) {length:8.1f}px {norm_strength:6.2f}     {depth_str:<15}")
    
    # ============================
    # EXPORT JSON PER GEOMETRIE.PY
    # ============================
    print("\n" + "="*80)
    print("EXPORT JSON PER GEOMETRIE.PY")
    print("="*80)
    
    # Crea la struttura JSON richiesta da Geometrie.py
    # Formato: dizionario con chiavi "pipe_N" per ogni tubo
    pipes_json = {}
    
    for idx, line_data in enumerate(all_extracted_lines):
        if len(line_data) == 4:  # traced paths
            (x1, y1), (x2, y2), strength, path = line_data
        else:  # straight lines
            (x1, y1), (x2, y2), strength = line_data
        
        # Ottieni la profondità
        depth = calculate_line_depth(line_data, global_depth_map)
        z_value = depth if depth is not None else 0.0
        
        # Crea la pipe in formato richiesto da read_json di Geometrie.py
        pipe_key = f"pipe_{idx:04d}"
        pipes_json[pipe_key] = {
            "regPoints": [
                [float(x1), float(y1), float(z_value)],  # start point
                [float(x2), float(y2), float(z_value)]   # end point
            ],
            "best_score": float(strength),
            "diametro": 0.0,
            "diametroOutOfFlag": False,
            "diametroRicercaRicorsiva": False,
            "Distanza Percorso": float(np.hypot(x2 - x1, y2 - y1)),
            "edgeIntensities": [],
            "Epoca": 0,
            "MatrixValues": [],
            "mediaValoriBaricentri": 0.0,
            "medianaValueBaricentri": 0.0,
            "points": [],
            "valuesBaricentri": [],
            "dpPoints": [],
            "pcaAngles": [],
            "pcaStartAngles": [],
            "depth_cm": float(depth) if depth is not None else None,
            "source": "multi_PHT"
        }
    
    # Crea la struttura completa con Offset
    output_json = {
        "Offset": {
            "Left": 0,
            "Top": 0,
            "Z": 0
        }
    }
    # Aggiungi le pipes al dizionario
    output_json.update(pipes_json)
    
    # Salva il JSON nella cartella corrente
    import json
    output_file = "final_output_with_analysis.json"
    
    with open(output_file, 'w') as f:
        json.dump(output_json, f, indent=4)
    
    print(f"✅ JSON esportato con successo: {output_file}")
    print(f"   - Numero di pipes: {len(pipes_json)}")
    print(f"   - Percorso TIFF: {IMG_DIR}")
    print(f"   - Pronto per essere usato in Geometrie.py")
    
    print("\n" + "="*80)
    print("TUTTO PRONTO PER GEOMETRIE.PY!")
    print("="*80)
    print("Esegui: python Geometrie.py")
    print("="*80)

# Mostra tutti i plot alla fine
plt.show()
