#ifndef OBS_DIAGCORR_HPP
#define OBS_DIAGCORR_HPP

#include "latlib/o815/o815.h"

#include "latlib/writeout.h"

#include "latlib/obstat.hpp"

#include <iostream>
#include <complex>
#include <vector>

#include <cmath>

#include <gsl/gsl_complex.h>
#include <gsl/gsl_complex_math.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_eigen.h>
#include <gsl/gsl_linalg.h>
#include <gsl/gsl_cblas.h>
#include <gsl/gsl_blas.h>

using namespace std;

/* Observable "diagcorr".
 *
 * Per configuration (corrCompute/_meas) it records a 4x4 matrix of
 * time-separated correlators of four field bilinears, together with their
 * time-averaged one-point values. At the end of the run (_finish) the
 * connected correlator matrix is diagonalised for every time separation and
 * jackknife errors are written out.
 */
class obs_diagcorr : public o815::obs {

public:
  obs_diagcorr(o815 *_O815);

private:
  // Hooks of o815::obs (presumably called by the o815 driver at start of
  // run, once per measurement, and at end of run -- confirm in o815.h).
  void _start();
  void _meas(bool loadedobs, const int& nthmeas);
  void _finish();

  // Computes the current configuration's correlators from Sim into OM.
  void corrCompute();

  // NOTE(review): declared but not defined anywhere in this header.
  static complex<double> effMass(vector< complex<double> > *preCalculated,
                                 vector< complex<double> > *excludedMeas,
                                 int nmeas,
                                 void *para);
  static void preEffMass(vector< vector< complex<double> > > *allVals,
                         vector< complex<double> > *preCalculated,
                         void *para);

  // Eigen-decomposition of the Hermitian matrix m (m is overwritten):
  // eigenvalues ascending into v, eigenvectors as columns of evec.
  static void cdiag(gsl_vector *v, gsl_matrix_complex *evec, gsl_matrix_complex *m);

  // Simulation whose field Sim->phi is measured.
  sim *Sim;

  // Spatial volume, lsize[0]^3.
  int spatialV;

  // Views into obsMem: OM[i*4+j][tsep] holds correlator (i,j) at time
  // separation tsep; OM[16][i] holds the time-averaged one-point value i.
  complex<double> *OM[16+1];

  // measurements[imeas][itsep]: 4x4 correlator matrix of measurement imeas.
  // NOTE(review): allocated in the constructor, never freed (no destructor).
  gsl_matrix_complex ***measurements;

  // measurements_disco[imeas]: 4-vector of one-point values of measurement imeas.
  // NOTE(review): allocated in the constructor, never freed (no destructor).
  gsl_vector_complex **measurements_disco;
};

/* Registers the observable "diagcorr" with o815 and requests observable
 * memory for 16 correlator series of lsize[1]/2 time separations plus four
 * one-point values. It then allocates, for every measurement, one 4x4 complex
 * matrix per time separation and one complex 4-vector.
 */
obs_diagcorr::obs_diagcorr(o815 *_O815)
  : o815::obs("diagcorr",
              _O815->paraQ->getParaNames() + "tsep" "...",
              _O815,
              sizeof(complex<double>) * ( 16*_O815->comargs.lsize[1]/2 + 1*4 ) )
{
  const int ntsep = _O815->comargs.lsize[1]/2;

  // OM[slot] starts slot*ntsep complex numbers into obsMem; the last slot
  // (OM[16]) only carries the four one-point values.
  for (int slot = 0; slot <= 16; slot++)
    OM[slot] = (complex<double>*)( obsMem + slot*sizeof(complex<double>)*ntsep );

  Sim = (sim*)O815->Sim;
  spatialV = O815->comargs.lsize[0] * O815->comargs.lsize[0] * O815->comargs.lsize[0];

  const int nmeas = O815->comargs.nmeas;
  measurements       = new gsl_matrix_complex**[nmeas];
  measurements_disco = new gsl_vector_complex*[nmeas];

  for (int imeas = 0; imeas < nmeas; imeas++) {
    gsl_matrix_complex **perTsep = new gsl_matrix_complex*[ntsep];
    for (int itsep = 0; itsep < ntsep; itsep++)
      perTsep[itsep] = gsl_matrix_complex_alloc(4,4);
    measurements[imeas] = perTsep;

    measurements_disco[imeas] = gsl_vector_complex_alloc(4);
  }
}

void obs_diagcorr::_start() {
  // No start-of-run work: the per-measurement buffers are allocated in the
  // constructor.
}

void obs_diagcorr::_meas(bool loadedobs, const int& nthmeas) {
  if (!loadedobs)
    corrCompute();

  for (int icorr=0; icorr<4; icorr++)
    for (int jcorr=0; jcorr<4; jcorr++)
      for (int itsep=0; itsep<O815->comargs.lsize[1]/2; itsep++) {
	gsl_matrix_complex_set(measurements[nthmeas][itsep], icorr, jcorr,
			       gsl_complex_rect( OM[icorr*4+jcorr][itsep].real(), OM[icorr*4+jcorr][itsep].imag() ));
      }
  
  for (int icorr=0; icorr<4; icorr++) {
    gsl_vector_complex_set(measurements_disco[nthmeas], icorr,
			   gsl_complex_rect( OM[16][icorr].real(), OM[16][icorr].imag() ) );
  }
};

/* Eigen-decomposition of the complex Hermitian matrix m.
 *
 * On return v holds the eigenvalues in ascending order and column k of evec
 * holds the (unit-norm) eigenvector belonging to v[k]. GSL destroys the
 * diagonal and lower triangle of m, so m must not be reused afterwards.
 * v must have length m->size1 and evec must be m->size1 x m->size1.
 */
void obs_diagcorr::cdiag (gsl_vector *v, gsl_matrix_complex *evec, gsl_matrix_complex *m)
{
  // GSL documents the hermv workspace as being for n-by-n matrices; size it
  // from the actual matrix instead of a fixed guess.
  gsl_eigen_hermv_workspace *wspace = gsl_eigen_hermv_alloc(m->size1);

  gsl_eigen_hermv(m, v, evec, wspace);
  gsl_eigen_hermv_sort(v, evec, GSL_EIGEN_SORT_VAL_ASC);

  gsl_eigen_hermv_free(wspace);
}

void obs_diagcorr::_finish() {
  gsl_matrix_complex *totalval = gsl_matrix_complex_alloc(4,4);
  gsl_matrix_complex *tmpmatrix = gsl_matrix_complex_alloc(4,4);
  gsl_matrix_complex *tmpmatrix2 = gsl_matrix_complex_alloc(4,4);
  gsl_vector *tmpvec = gsl_vector_alloc(4);
  gsl_vector_complex *tmpvecc = gsl_vector_complex_alloc(4);
  gsl_vector *tmpvec2 = gsl_vector_alloc(4);
  gsl_vector *jackres = gsl_vector_alloc(4);
  gsl_vector *jackerrornorm = gsl_vector_alloc(4);
  gsl_vector_complex *totaldisco = gsl_vector_complex_alloc(4);
  gsl_matrix_complex *evecres = gsl_matrix_complex_alloc(4,4);
  gsl_matrix_complex *evecerrornorm = gsl_matrix_complex_alloc(4,4);
  
  gsl_vector_complex_set_zero(totaldisco);
  for (int imeas=0; imeas<O815->comargs.nmeas; imeas++)
    gsl_vector_complex_add( totaldisco, measurements_disco[imeas] );
  
  for (int itsep = 0; itsep < O815->comargs.lsize[1]/2; itsep++) {
    gsl_matrix_complex_set_zero(totalval);
    gsl_vector_set_zero(jackerrornorm);
    gsl_matrix_complex_set_zero(evecerrornorm);

    for (int imeas=0; imeas<O815->comargs.nmeas; imeas++)
      gsl_matrix_complex_add( totalval, measurements[imeas][itsep] );
    
    gsl_matrix_complex_memcpy (tmpmatrix, totalval);
    gsl_matrix_complex_scale (tmpmatrix, gsl_complex_rect(1.0/O815->comargs.nmeas, 0.0) );
    for (int icorr=0; icorr<4; icorr++)
      for (int jcorr=0; jcorr<4; jcorr++) {
	gsl_complex discopart = gsl_complex_mul( gsl_vector_complex_get(totaldisco, icorr),
						 gsl_complex_conjugate( gsl_vector_complex_get(totaldisco, jcorr) )
						 );
	discopart = gsl_complex_mul_real(discopart, pow(1.0/O815->comargs.nmeas,2));
	gsl_matrix_complex_set( tmpmatrix, icorr, jcorr, 
				gsl_complex_sub( gsl_matrix_complex_get( tmpmatrix, icorr, jcorr ),
						 discopart
						 )
				);
      }
    cdiag(jackres, evecres, tmpmatrix);
    
    for (int imeas=0; imeas<O815->comargs.nmeas; imeas++) {
      gsl_matrix_complex_memcpy (tmpmatrix, totalval);
      gsl_matrix_complex_sub (tmpmatrix, measurements[imeas][itsep]);
      gsl_matrix_complex_scale ( tmpmatrix, gsl_complex_rect(1.0/(O815->comargs.nmeas-1), 0.0) );
      gsl_vector_complex_memcpy( tmpvecc, totaldisco );
      gsl_vector_complex_sub( tmpvecc, measurements_disco[imeas] );
      for (int icorr=0; icorr<4; icorr++)
	for (int jcorr=0; jcorr<4; jcorr++) {
	  gsl_complex discopart = gsl_complex_mul( gsl_vector_complex_get(tmpvecc, icorr),
						   gsl_complex_conjugate( gsl_vector_complex_get(tmpvecc, jcorr) )
						   );
	  discopart = gsl_complex_mul_real(discopart, pow(1.0/(O815->comargs.nmeas-1),2));
	  gsl_matrix_complex_set( tmpmatrix, icorr, jcorr, 
				  gsl_complex_sub( gsl_matrix_complex_get( tmpmatrix, icorr, jcorr ),
						   discopart
						   )
				  );
	}
      cdiag(tmpvec, tmpmatrix2, tmpmatrix);
      
      gsl_vector_sub(tmpvec, jackres);
      gsl_vector_memcpy(tmpvec2, tmpvec);
      gsl_vector_mul(tmpvec, tmpvec2);
      gsl_vector_add(jackerrornorm, tmpvec);

      gsl_matrix_complex_sub(tmpmatrix2, evecres);
      for (int i=0; i<4; i++)
	for (int j=0; j<4; j++) {
	  gsl_complex tmperr =
	    gsl_complex_rect( pow(GSL_REAL( gsl_matrix_complex_get( tmpmatrix2, i, j ) ),2),
			      pow(GSL_IMAG( gsl_matrix_complex_get( tmpmatrix2, i, j ) ),2));
	  gsl_matrix_complex_set(tmpmatrix, i, j, tmperr);
	}
      gsl_matrix_complex_add( evecerrornorm, tmpmatrix );
    } 
    gsl_vector_scale( jackerrornorm,
		      (double)(O815->comargs.nmeas-1)/O815->comargs.nmeas );
    gsl_matrix_complex_scale( evecerrornorm, gsl_complex_rect( (double)(O815->comargs.nmeas-1)/O815->comargs.nmeas, 0.0 ) );
    
    *out << O815->paraQ->getParaVals();
    *out << "\t" << itsep;
    for (int iev = 0; iev < 4; iev++) {
      *out << "\t" << gsl_vector_get(jackres,iev)
	   << "\t" << sqrt(gsl_vector_get(jackerrornorm,iev));
      for (int iinter=0; iinter<4; iinter++)
	*out << "\t" << GSL_REAL(gsl_matrix_complex_get(evecres, iinter, iev))
	     << "\t" << sqrt( GSL_REAL(gsl_matrix_complex_get(evecerrornorm, iinter, iev)) );
    }
    *out << endl;
  }

  gsl_matrix_complex_free(totalval);
  gsl_matrix_complex_free(tmpmatrix);
  gsl_matrix_complex_free(evecres);
  gsl_vector_free(tmpvec);
  gsl_vector_complex_free(tmpvecc);
  gsl_vector_free(tmpvec2);
  gsl_vector_free(jackres);
  gsl_vector_free(jackerrornorm);
  gsl_vector_complex_free(totaldisco);
};

void obs_diagcorr::corrCompute()
{
  complex<double> phislice[4][O815->comargs.lsize[1]];

  for (int icorr=0; icorr<4; icorr++)
    OM[16][icorr] = 0;
  
  for (int it = 0; it < O815->comargs.lsize[1]; it++) {
    for (int icorr=0; icorr<4; icorr++)
      phislice[icorr][it] = 0;
    
    for (int ix = 0; ix < spatialV; ix++) {
      phislice[0][it] += conj(Sim->phi[ 0*Sim->lsize4 + it*spatialV + ix ]) * Sim->phi[ 0*Sim->lsize4 + it*spatialV + ix ];
      phislice[1][it] += Sim->phi[ 0*Sim->lsize4 + it*spatialV + ix ] * Sim->phi[ 1*Sim->lsize4 + it*spatialV + ix ];
      phislice[2][it] += conj(Sim->phi[ 0*Sim->lsize4 + it*spatialV + ix ]) * conj(Sim->phi[ 1*Sim->lsize4 + it*spatialV + ix ]);
      phislice[3][it] += conj(Sim->phi[ 1*Sim->lsize4 + it*spatialV + ix ]) * Sim->phi[ 1*Sim->lsize4 + it*spatialV + ix ];
    }

    for (int icorr=0; icorr<4; icorr++) {
      phislice[icorr][it] /= spatialV;
      OM[16][icorr] += phislice[icorr][it];
    }
  }

  for (int icorr = 0; icorr < 4; icorr++) {
    for (int jcorr = 0; jcorr < 4; jcorr++) {
      for (int itsep = 0; itsep < O815->comargs.lsize[1]/2; itsep++) {
	OM[icorr*4+jcorr][itsep] = 0;
	
	for (int it = 0; it < O815->comargs.lsize[1]; it++)
	  OM[icorr*4+jcorr][itsep] += phislice[icorr][ (it+itsep)%O815->comargs.lsize[1] ] * conj( phislice[jcorr][it] );
      
	OM[icorr*4+jcorr][itsep] /= O815->comargs.lsize[1];
      }
    }
    
    OM[16][icorr] /= O815->comargs.lsize[1];
  }
}

#endif
