#include "SpLocal.hh"
#include <chrono>
#include <cmath>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Constructor: initialises the matrix, localMaxList and statCoppie members
// directly from the corresponding arguments; no other work is done here.
SpLocal::SpLocal(Matrix3D<float> &matrix, const std::vector<std::vector<Point3D>> &localMaxList, const std::vector<std::vector<Stats>> &statCoppie)
    : matrix(matrix)
    , localMaxList(localMaxList)
    , statCoppie(statCoppie)
{
}
void SpLocal::shortestPathOnMatrix(int sizeX, int sizeY, int sizeZ)
{
    cout << "AAAAAAAAAAAAAAAAAAAAAAAAAA"<< endl;
    auto start = chrono::high_resolution_clock::now(); 
    int lato = 150; 
    int maxNeighbors = 30; 
    int numThreads = 8; 
    mutex mtxPoints, mtxEdges; // garantisce l'accesso sicuro a risorse condivise tra i thread
    pathPoints.resize(matrix.getX()); 

    // Preparazione parallelismo, divido la lista dei massimi locali in base al numero di thread 
    thread **threads = new thread *[numThreads]; 
    int size = localMaxList.size(); 
    int tempSize = (size + numThreads -1) / numThreads; 

    // Creazione dei thread 
    for (int i=0; i < numThreads; i++)
    {
        int start = i * tempSize; 
        int end = (i + 1) * tempSize; 
        if (end >= size) 
        {
            end = size; 
        }
        threads[i] = new thread([this, tempSize, lato, maxNeighbors, &mtxPoints, &mtxEdges, i, start, end, sizeX, sizeY, sizeZ]() {
            this->shortestPathOnThread(i, start, end, lato, maxNeighbors, mtxPoints, mtxEdges, sizeX, sizeY, sizeZ); 
        }); 
    } 

    // libero la memoria 
    for (int i=0; i < numThreads; i++)
    {
        threads[i]->join(); 
    }

    for (int i=0; i<numThreads; i++)
    {
        delete threads[i]; 
    }
    delete[] threads; 

    //calcolo dei tempi
    auto end = chrono::high_resolution_clock::now(); 
    auto duration = chrono::duration_cast<chrono::milliseconds>(end-start); 
    cout << "Time taken by shortestPathOnMatrix: " << duration.count() << "milliseconds" << endl; 
    cout << "AAAAAAAAAAAAAAAAAAAAAAAAAA"<< endl;
}

/**
 * Worker body: for every local maximum in localMaxList[from, to), seeds a
 * MinHeap-driven search at the center of a lato^3 index cube and lets
 * updatePaths run the relaxation loop until the heap is empty.
 *
 * @param id            worker index (not used here)
 * @param from, to      half-open range of outer indices into localMaxList
 * @param lato          side of the local index cube; cost/prev hold lato^3 entries
 * @param maxNeighbors  not used here
 * @param mtxPoints, mtxEdges  mutexes guarding the shared pathPoints / lines
 * @param sizeX, sizeY, sizeZ  volume dimensions forwarded to updatePaths
 */
void SpLocal::shortestPathOnThread(int id, int from, int to, int lato, int maxNeighbors, mutex &mtxPoints, mutex &mtxEdges, int sizeX, int sizeY, int sizeZ)
{
    (void)id;
    (void)maxNeighbors;

    const int cells = lato * lato * lato;

    // Per-search scratch buffers. They are allocated once per worker and reset
    // for each maximum, instead of new[]/delete[] of lato^3 entries per maximum.
    // They are also released automatically if anything throws.
    std::vector<float> cost; // cost to reach each cell
    std::vector<int> prev;   // predecessor id of each cell (-1 = none)

    for (int xid = from; xid < to; xid++)
    {
        for (const Point3D &max : localMaxList[xid])
        {
            MinHeap heap(cells);
            cost.assign(static_cast<std::size_t>(cells), static_cast<float>(heap.INF));
            prev.assign(static_cast<std::size_t>(cells), -1);

            // Seed node. Reference point and target are the same, so every offset
            // cancels and the id is the cube center. The arguments follow getId's
            // (z, y, x) order.
            const long idS = getId(max.getZ(), max.getY(), max.getX(), max.getZ(), max.getY(), max.getX(), lato);
            cost[idS] = 0;
            heap.insert(idS, 0);

            updatePaths(idS, heap, cost.data(), prev.data(), lato, sizeX, sizeY, sizeZ, mtxPoints, mtxEdges);
        }
    }
}

void SpLocal::updatePaths(int from, MinHeap &heap, float *cost, int *prev, int lato, int sizeX, int sizeY, int sizeZ,
mutex &mtxPoints, mutex &mtxEdges)
{
    //cout<<"Begin UpdatePaths"<<endl;
    while (!heap.isEmpty())
    {
        long idU = heap.extractMinID(); 
        Point3D u = getZYX(sizeX, sizeY, sizeZ, idU);


        for (int i = -1; i <= 1; i++)
        {
            for (int j = -1; j <= 1; j++)
            {
                for (int k = -1; k <= 1; k++)
                {
                    if (i != 0 || j != 0 || k != 0)
                    {
                        int x = u.getX() + k;
                        int y = u.getY() + j;
                        int z = u.getZ() + i;

                        if (x >= 0 && y >= 0 && z >= 0 && x < sizeX && y < sizeY && z < sizeZ) // controllo margini matrice 
                        {                            
                            float delta = statCoppie[u.getX()][x].valMax; // calcolo con statCoppie
                            long idV = getId(u.getZ(), u.getY(), u.getX(), z, y, x, lato); // l'ID del vicino
                            updateEdge(idU, idV, sizeX, sizeY, sizeZ,delta, delta, heap, cost, prev, mtxPoints, mtxEdges); 
                        }
                    }
                }
            }
        }
    }
    //  //cout << "Scorro path points" << pathPoints.size() <<endl;
    //  lines.resize(pathPoints.size()+1);
    //  // Scorri attraverso i path points per creare le linee
    //  for (int xid = 0; xid < pathPoints.size(); xid++)
    //  {
    //      //cout<<"xid: "<< xid << endl;
    //      const auto& points = pathPoints[xid];
    //      //cout<<"Points size: "<< points.size() << endl;
    //      //lines[xid].clear();
    //      // Se ci sono almeno due punti, crea una linea tra i successivi
    //      lock_guard<mutex> lock(mtxEdges);
    //      if (points.size()>=2)
    //      {
    //          for (int i = 1; i < points.size()-1; i++)
    //          {
                 
    //              //cout<<"ciclo: "<< xid <<"- "<<i<<endl;
    //              Point3D u = points[i - 1];  // Punto precedente
    //              Point3D v = points[i];      // Punto corrente
    //              //cout<<"u (x,y,z): "<<u.getX()<< " "<<u.getY()<< " "<<u.getZ()<<endl;
    //              //cout<<"v (x,y,z): "<<v.getX()<< " "<<v.getY()<< " "<<v.getZ()<<endl;
    //              // Usa il costo calcolato dallo heap per creare la linea
    //              // (Assumiamo che i costi siano memorizzati nell'array `cost`)
    //              //long idU = getId(u.getZ(), u.getY(), u.getX(), u.getZ(), u.getY(), u.getX(), lato);
    //              //long idV = getId(v.getZ(), v.getY(), v.getX(), v.getZ(), v.getY(), v.getX(), lato);
    //              float costUV = 0.5;  // Costo da u a v (o da v a u a seconda della direzione)
    //              //cout<<"CostUV: " <<costUV<<endl;
    //              // Creazione della linea con il costo
    //              Line line(u, v, costUV);
    //              lines[xid].push_back(line);
    //          }
    //      }
    //  }
}

void SpLocal::updateEdge(int idU, int idV, int sizeX, int sizeY, int sizeZ, float direction, float delta, MinHeap &heap, float *cost, int *prev,
    mutex &mtxPoints, mutex &mtxEdges)
{
    float newCost = cost[idU] + (delta + 0.001); 

    // aggiorna il costo se minore del costo attuale 
    if (newCost < cost[idV])
    {
        //cout<<"updating edge... "<<endl;
        cost[idV] = newCost; 
        prev[idV] = idU; // origine del nodo V
        if (heap.containsID(idV))
        {
            heap.decreaseKey(idV, newCost); 
        }
        else
        {
            heap.insert(idV, newCost); 
        }

        {
            lock_guard<mutex> lock(mtxPoints);
            Point3D v = getZYX(sizeX, sizeY, sizeZ, idV);  // Nodo V
            Point3D u = getZYX(sizeX, sizeY, sizeZ, idU);  // Nodo U

            if (idU >= pathPoints.size()) {
                pathPoints.resize(idU + 1);
            }
            
            if (pathPoints[idU].empty() || 
                pathPoints[idU].back().getX() != v.getX() || 
                pathPoints[idU].back().getY() != v.getY() || 
                pathPoints[idU].back().getZ() != v.getZ()) {
                pathPoints[idU].push_back(v);  // Aggiungi il punto V al percorso di U
            }


            // Inserisci l'arco (edge) in `lines` se il percorso ha almeno due punti
            {
                lock_guard<mutex> lock(mtxEdges);

                if (pathPoints[idU].size() > 1) {
                    Line t(u, v, newCost);  // Crea l'arco con il costo calcolato
                    if (idU >= lines.size()) {
                        lines.resize(idU + 1);
                    }
                    lines[idU].push_back(t);  // Inserisci l'arco nella lista
                    cout<<"Inserito arco tra (x,y,z): "
                    <<u.getX()<<" "<<u.getY()<<" "<<u.getZ()<<
                    " e (x,y,z) "
                    <<v.getX()<<" "<<v.getY()<<" "<<v.getZ()<<
                    endl;
                }
            }
        }

    }
}

/**
 * Linear index of point (z, y, x) inside a lato^3 cube centered on (mz, my, mx).
 *
 * Each offset is shifted by lato/2, so the reference point maps to
 * (lato/2, lato/2, lato/2). The id is dx + dy*lato + dz*lato^2, with x
 * varying fastest.
 *
 * The arithmetic is done in long so large cubes do not overflow int before
 * the result is widened. No bounds check is performed: points outside the
 * cube yield ids that are out of range or that alias other cells.
 */
long SpLocal::getId(int mz, int my, int mx, int z, int y, int x, int lato)
{
    const long half = lato / 2;
    const long side = lato;
    const long dx = static_cast<long>(x) - mx + half;
    const long dy = static_cast<long>(y) - my + half;
    const long dz = static_cast<long>(z) - mz + half;
    return dx + side * (dy + side * dz);
}

/**
 * Decodes a linear id into coordinates, assuming id = x + y*sizeX + z*sizeX*sizeY
 * in terms of this function's own sizeY and sizeX parameters (sizeZ is not used).
 * Returns Point3D(z, y, x).
 *
 * NOTE(review): callers in this file pass (sizeX, sizeY, sizeZ), so the value
 * they call sizeZ is used here as the row width. Confirm the intended argument
 * order, and confirm the coordinate order Point3D's constructor expects.
 */
Point3D SpLocal::getZYX(int sizeZ, int sizeY, int sizeX, long id)
{
    (void)sizeZ;

    const int planeSize = sizeY * sizeX;     // cells per z-plane
    const int z = id / planeSize;
    const int planeStart = z * planeSize;    // id of the first cell of plane z
    const long offsetInPlane = id - planeStart;
    const int y = offsetInPlane / sizeX;
    const int x = offsetInPlane - sizeX * y;

    return Point3D(z, y, x);
}

/// Returns the pathPoints member by non-const reference. It holds the
/// per-node point lists that updateEdge fills (indexed by source node id).
vector<vector<Point3D>> &SpLocal::getPathPoints()
{
    return this->pathPoints;
}

/// Returns the lines member by non-const reference. It holds the per-node
/// edge lists that updateEdge fills (indexed by source node id).
vector<vector<Line>> &SpLocal::getLines()
{
    return this->lines;
}