# FUNZIONI PER GESTIONE FILE OGPR
# - ESTRAZIONE DATI HEADER E SALVATAGGIO IN UN JSON
# - ESTRAZIONE DEI DATI DEI RADARGRAMMI
# - ESTRAZIONE DEI DATI GEOREFERENZIATI

import re
import json
import numpy as np
import struct
from shapely.geometry import Polygon
import geopandas as gpd
import os

from GLOBAL import trash_dir

# header file
def extract_header_ogpr(ogpr_file):
    """Extract the main header fields of an OGPR file.

    The header is located by regex search over the raw file bytes (no JSON
    parsing). For every key only the first occurrence is used, except
    ``byteSize`` / ``byteOffset``: their first two occurrences are stored as
    the radar-data block and the georeferencing block respectively.

    Args:
        ogpr_file: path of the OGPR file.

    Returns:
        dict with keys ``byte_size_data``, ``byte_offset_data``,
        ``byte_size_georef``, ``byte_offset_georef``, ``sample_count``,
        ``channels_count``, ``slices_count`` (ints), ``sampling_step_m``,
        ``sampling_time_ns``, ``propagation_velocity``, ``frequency_MHz``
        (floats) and ``polarization`` (str). Keys not found stay ``None``.
    """
    header_data = {
        "byte_size_data": None,
        "byte_offset_data": None,
        "byte_size_georef": None,
        "byte_offset_georef": None,
        "sample_count": None,
        "channels_count": None,
        "slices_count": None,
        "sampling_step_m": None,
        "sampling_time_ns": None,
        "propagation_velocity": None,
        "frequency_MHz": None,
        "polarization": None
    }

    with open(ogpr_file, 'rb') as file:
        content = file.read()

    # Search the raw bytes directly: decoding a large binary file to str
    # (errors='ignore') would build a second, possibly up to 4x larger, copy
    # and could splice digits together across dropped invalid bytes.
    int_value = rb'\s*(\d+)'
    # Decimal number with optional sign, fraction and exponent (e.g. 1.0e8);
    # the previous [\d.]+ silently truncated scientific notation.
    float_value = rb'\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'

    def first_group(pattern):
        # First capture group of the first match, as str (or None).
        match = re.search(pattern, content)
        return match.group(1).decode('ascii') if match else None

    def first_two_ints(pattern):
        # Integer values of the first two matches, padded with None.
        values = []
        for match in re.finditer(pattern, content):
            values.append(int(match.group(1)))
            if len(values) == 2:
                break
        return values + [None] * (2 - len(values))

    # byteSize / byteOffset: 1st occurrence = data block, 2nd = georef block.
    header_data["byte_size_data"], header_data["byte_size_georef"] = \
        first_two_ints(rb'"byteSize":' + int_value)
    header_data["byte_offset_data"], header_data["byte_offset_georef"] = \
        first_two_ints(rb'"byteOffset":' + int_value)

    int_fields = {
        "sample_count": rb'"samplesCount":',
        "channels_count": rb'"channelsCount":',
        "slices_count": rb'"slicesCount":',
    }
    float_fields = {
        "sampling_step_m": rb'"samplingStep_m":',
        "sampling_time_ns": rb'"samplingTime_ns":',
        "propagation_velocity": rb'"propagationVelocity_mPerSec":',
        # Accept both the misspelled key matched so far ("fequency") and the
        # correct spelling ("frequency").
        "frequency_MHz": rb'"fr?equency_MHz":',
    }

    for key, prefix in int_fields.items():
        raw = first_group(prefix + int_value)
        if raw is not None:
            header_data[key] = int(raw)

    for key, prefix in float_fields.items():
        raw = first_group(prefix + float_value)
        if raw is not None:
            header_data[key] = float(raw)

    polarization = first_group(rb'"polarization":\s*"(\w+)"')
    if polarization is not None:
        header_data["polarization"] = polarization

    return header_data

# json subswath
def create_json_subswath(blocks):
    """Build a JSON-serialisable index describing consecutive sub-swaths.

    Args:
        blocks: sequence of ``(start, end)`` pairs of inclusive slice indices.

    Returns:
        dict with ``num_subswath``, ``total_slices_count`` (end index of the
        last block + 1, or 0 when there are no blocks) and
        ``subswath_details`` (one entry per block with id, bounds and size).
    """
    details = [
        {
            "id_subswath": idx,
            "start_index": first,
            "end_index": last,
            "size": last - first + 1,
        }
        for idx, (first, last) in enumerate(blocks)
    ]

    if blocks:
        total_slices = blocks[-1][1] + 1
    else:
        total_slices = 0

    return {
        "num_subswath": len(blocks),
        "total_slices_count": total_slices,
        "subswath_details": details,
    }
def create_json_subswath_with_overlap(blocks, overlap=30):
    """Build the sub-swath index for blocks that overlap their predecessor.

    Every block after the first is treated as starting ``overlap`` slices
    before its own non-overlapping part; the first block has no overlap.

    Args:
        blocks: sequence of ``(start, end)`` inclusive slice-index pairs that
            already include the overlap.
        overlap: number of leading slices shared with the previous block.

    Returns:
        dict with ``num_subswath``, ``overlap_used``, ``total_slices_count``
        (end index of the last block + 1, or 0 when there are no blocks) and
        ``subswath_details`` (bounds and sizes with and without overlap).
    """
    details = []
    for idx, (lo, hi) in enumerate(blocks):
        trimmed = idx > 0
        core_lo = lo + overlap if trimmed else lo
        details.append({
            "id_subswath": idx,
            "start_index_with_overlap": lo,
            "end_index_with_overlap": hi,
            "size_with_overlap": hi - lo + 1,
            "start_index_no_overlap": core_lo,
            "end_index_no_overlap": hi,
            "size_no_overlap": hi - core_lo + 1,
            "has_overlap": trimmed,
            "overlap_size": overlap if trimmed else 0,
        })

    if not blocks:
        total_slices = 0
    else:
        total_slices = blocks[-1][1] + 1

    return {
        "num_subswath": len(blocks),
        "overlap_used": overlap,
        "total_slices_count": total_slices,
        "subswath_details": details,
    }

# ESTRAZIONE DATI SWATH ------------------------------------------------------------------------------------------------
# totale
def extract_ogpr_data_value(ogpr_file, byte_offset, byte_size):
    """Read a raw block of 16-bit samples from an OGPR file.

    Args:
        ogpr_file: path of the OGPR file.
        byte_offset: absolute file offset where the block starts.
        byte_size: number of bytes to read (must be a multiple of 2, otherwise
            ``numpy.frombuffer`` raises ``ValueError``).

    Returns:
        list of Python ints, one per little-endian signed 16-bit value read.
        If the file ends before ``byte_size`` bytes, fewer values are returned.
    """
    with open(ogpr_file, 'rb') as file:
        file.seek(byte_offset)
        bytes_read = file.read(byte_size)  # single read for the whole block

    # '<i2' = explicitly little-endian int16. Plain np.int16 uses the host's
    # native byte order and would misread the data on big-endian machines.
    data_array = np.frombuffer(bytes_read, dtype='<i2')

    return data_array.tolist()
# subswath
def extract_subswath(ogpr_file_path, start_index, end_index, byte_offset_data, sample_count, channels_count):
    """Load the radar samples of slices ``start_index..end_index`` (inclusive).

    Each slice is stored as ``channels_count * sample_count`` 16-bit values,
    so slice ``k`` begins at ``byte_offset_data + k * slice_stride``.

    Returns:
        NumPy integer array shaped ``(num_slices, channels_count,
        sample_count)``. Its dtype is NumPy's default int, not int16, because
        the samples pass through a Python list.
    """
    n_slices = end_index - start_index + 1
    slice_stride = 2 * sample_count * channels_count  # bytes per slice (int16 = 2 bytes)

    raw_values = extract_ogpr_data_value(
        ogpr_file_path,
        byte_offset_data + start_index * slice_stride,
        n_slices * slice_stride,
    )
    return np.reshape(raw_values, (n_slices, channels_count, sample_count))
# ------------------------------------------------------------------------------------------------

# GEOREF ------------------------------------------------------------------------------------------------
def save_swath_polygon(coordinates, output_path, swath_name, epsg=32632):
    """Write a swath footprint as a single-feature GeoJSON file.

    Args:
        coordinates: sequence of ``(x, y)`` vertices of the polygon ring,
            expressed in the CRS identified by ``epsg``.
        output_path: destination file path.
        swath_name: value stored in the feature's ``name`` attribute.
        epsg: EPSG code assigned to the geometry. The default 32632 is
            WGS 84 / UTM zone 32N (projected metres), not geographic WGS84
            lon/lat; pass another code for surveys in other zones.
    """
    polygon = Polygon(coordinates)
    gdf = gpd.GeoDataFrame({
        'name': [swath_name],
        'geometry': [polygon]
    })
    # Label the geometry with its CRS (no reprojection happens here).
    gdf.set_crs(epsg=epsg, inplace=True)
    gdf.to_file(output_path, driver='GeoJSON')

def extract_swath_gps_ogpr(ogpr_file, header_json_data, *, debug_points=False):
    """Read the georeferencing block of an OGPR file for the whole swath.

    The block (``byte_offset_georef`` .. ``+byte_size_georef``) is parsed as a
    sequence of records, each made of:

    * an 8-byte little-endian unsigned slice ID, then
    * ``channels_count`` sweeps of 64 bytes, each holding two coordinate
      blocks of four little-endian doubles ``(x, y, depth, elevation)``.
      Only the second block of each sweep is used.

    Args:
        ogpr_file: path of the OGPR file.
        header_json_data: dict as returned by ``extract_header_ogpr``; reads
            ``channels_count``, ``slices_count``, ``byte_size_georef`` and
            ``byte_offset_georef`` (missing keys default to 0).
        debug_points: if True, also write the perimeter vertices as a GeoJSON
            point collection to ``trash_dir/swath_points.geojson``.

    Returns:
        ``(depth, georef_matrix_lat, georef_matrix_lon, perimeter)``:

        * ``depth``: depth value of the last sweep read (``None`` if none).
        * ``georef_matrix_lat`` / ``georef_matrix_lon``: float arrays shaped
          ``(channels_count, slices_count)`` holding the y / x values, indexed
          by ``[channel, slice_id]``. Cells with no record are NaN.
        * ``perimeter``: closed ring of ``[x, y]``. It lists the first channel
          in slice order, then the last channel in reverse, then the first
          vertex again. It is empty if no sweep was read.

    Raises:
        IndexError: if a record's slice ID is >= ``slices_count``.
    """
    num_channels = header_json_data.get("channels_count", 0)
    slices_count = header_json_data.get("slices_count", 0)
    byte_size_georef = header_json_data.get("byte_size_georef", 0)
    byte_offset_georef = header_json_data.get("byte_offset_georef", 0)

    end_position = byte_offset_georef + byte_size_georef

    start_list = []  # [x, y] of the first channel, one per slice (one swath edge)
    end_list = []    # [x, y] of the last channel, one per slice (opposite edge)
    depth = None

    with open(ogpr_file, 'rb') as f:
        f.seek(byte_offset_georef)

        print("-- GEOREF --")
        print(f"num_channels: {num_channels}")
        print(f"slices_count: {slices_count}")
        # NaN-filled instead of np.empty: cells not covered by any georef
        # record must not expose uninitialised memory to the caller.
        georef_matrix_lat = np.full((num_channels, slices_count), np.nan)
        georef_matrix_lon = np.full((num_channels, slices_count), np.nan)

        while f.tell() + 8 <= end_position:
            slice_id_bytes = f.read(8)
            if len(slice_id_bytes) < 8:
                break  # no more IDs: end of the block
            # NOTE(review): unsigned '<Q' is assumed; confirm against the OGPR spec.
            slice_id = struct.unpack('<Q', slice_id_bytes)[0]

            for ch in range(num_channels):
                # Stop if a full 64-byte sweep no longer fits in the block.
                if f.tell() + 64 > end_position:
                    break
                sweep_block = f.read(64)
                if len(sweep_block) < 64:
                    break
                # Bytes 32..63 = second coordinate block (x, y, depth, elevation).
                x, y, depth, _elevation = struct.unpack_from('<dddd', sweep_block, 32)
                georef_matrix_lat[ch, slice_id] = y
                georef_matrix_lon[ch, slice_id] = x

                if ch == 0:
                    start_list.append([x, y])
                elif ch == num_channels - 1:
                    end_list.append([x, y])

    # Close the ring: walk the last-channel edge backwards, then return to
    # the first vertex. Skipped when nothing was read (previously IndexError).
    if start_list:
        start_list.extend(end_list[::-1])
        start_list.append(start_list[0])

    if debug_points:
        # Dump every perimeter vertex as a GeoJSON point for inspection.
        geojson_data = {
            'type': 'FeatureCollection',
            'features': [
                {
                    'type': 'Feature',
                    'geometry': {'type': 'Point', 'coordinates': [px, py]},
                    'properties': {'id': idx},
                }
                for idx, (px, py) in enumerate(start_list)
            ],
        }
        points_output_path = os.path.join(trash_dir, "swath_points.geojson")
        with open(points_output_path, 'w') as out:
            json.dump(geojson_data, out, indent=2)

    print("FINE GPS ------------------------------------------------------------")
    return depth, georef_matrix_lat, georef_matrix_lon, start_list

def extract_subswath_swath_gps_ogpr(ogpr_file, header_json_data, start_index, end_index, *, debug_points=False):
    """Read georeferencing data for slices ``start_index..end_index`` only.

    The whole georef block is scanned with the same record layout used for
    the full swath: an 8-byte little-endian unsigned slice ID followed by
    ``channels_count`` 64-byte sweeps of two ``(x, y, depth, elevation)``
    little-endian double blocks. Only the second block is used. Records whose
    slice ID falls outside the window are read but not stored.

    Args:
        ogpr_file: path of the OGPR file.
        header_json_data: dict as returned by ``extract_header_ogpr``; reads
            ``channels_count``, ``byte_size_georef`` and
            ``byte_offset_georef`` (missing keys default to 0).
        start_index: first global slice index of the window (inclusive).
        end_index: last global slice index of the window (inclusive).
        debug_points: if True, also write the perimeter vertices as a GeoJSON
            point collection to ``trash_dir/swath_points.geojson``.

    Returns:
        ``(georef_matrix_lat, georef_matrix_lon, perimeter)``:

        * ``georef_matrix_lat`` / ``georef_matrix_lon``: float arrays shaped
          ``(channels_count, end_index - start_index + 1)`` holding the y / x
          values, indexed by ``[channel, slice_id - start_index]``. Cells with
          no record are NaN.
        * ``perimeter``: closed ``[x, y]`` ring. It lists the first channel in
          slice order, then the last channel reversed, then the first vertex
          again. It is empty if no slice fell inside the window.
    """
    num_channels = header_json_data.get("channels_count", 0)
    slices_count = end_index - start_index + 1
    byte_size_georef = header_json_data.get("byte_size_georef", 0)
    byte_offset_georef = header_json_data.get("byte_offset_georef", 0)

    end_position = byte_offset_georef + byte_size_georef

    start_list = []  # [x, y] of the first channel per in-window slice
    end_list = []    # [x, y] of the last channel per in-window slice

    with open(ogpr_file, 'rb') as f:
        f.seek(byte_offset_georef)

        print("-- SUBSWATH GEOREF --")
        print(f"num_channels: {num_channels}")
        print(f"slices_count (subswath): {slices_count}")
        # NaN-filled instead of np.empty: cells not covered by any record
        # must not expose uninitialised memory to the caller.
        georef_matrix_lat = np.full((num_channels, slices_count), np.nan)
        georef_matrix_lon = np.full((num_channels, slices_count), np.nan)

        while f.tell() + 8 <= end_position:
            slice_id_bytes = f.read(8)
            if len(slice_id_bytes) < 8:
                break  # no more IDs: end of the block
            # NOTE(review): unsigned '<Q' is assumed; confirm against the OGPR spec.
            slice_id = struct.unpack('<Q', slice_id_bytes)[0]
            in_window = start_index <= slice_id <= end_index
            local_idx = slice_id - start_index  # global -> window-local column

            for ch in range(num_channels):
                # Stop if a full 64-byte sweep no longer fits in the block.
                if f.tell() + 64 > end_position:
                    break
                sweep_block = f.read(64)
                if len(sweep_block) < 64:
                    break
                if not in_window:
                    continue  # sweep consumed only to stay aligned on records
                # Bytes 32..63 = second coordinate block (x, y, depth, elevation).
                x, y, _depth, _elevation = struct.unpack_from('<dddd', sweep_block, 32)
                georef_matrix_lat[ch, local_idx] = y
                georef_matrix_lon[ch, local_idx] = x

                # Perimeter: outermost channels (first and last).
                if ch == 0:
                    start_list.append([x, y])
                elif ch == num_channels - 1:
                    end_list.append([x, y])

    # Close the ring: walk the last-channel edge backwards, then return to
    # the first vertex. Skipped when no slice was in the window (previously
    # this raised IndexError).
    if start_list:
        start_list.extend(end_list[::-1])
        start_list.append(start_list[0])

    if debug_points:
        # Dump every perimeter vertex as a GeoJSON point for inspection.
        geojson_data = {
            'type': 'FeatureCollection',
            'features': [
                {
                    'type': 'Feature',
                    'geometry': {'type': 'Point', 'coordinates': [px, py]},
                    'properties': {'id': idx},
                }
                for idx, (px, py) in enumerate(start_list)
            ],
        }
        points_output_path = os.path.join(trash_dir, "swath_points.geojson")
        with open(points_output_path, 'w') as out:
            json.dump(geojson_data, out, indent=2)

    print("FINE SUBSWATH GEOREF ------------------------------------------------------------")
    return georef_matrix_lat, georef_matrix_lon, start_list