#include "contract_4spinor.h"

#include <cassert>

// This implementation applies only to the SU(3) gauge group and the Nd=4 case.
// The constants below are the strides/offsets (in units of double) used to
// index the raw field data in this file.

// Number of colors.
#define NC      3
// Doubles per spinor component: NC colors, each stored as (real, imag).
#define NC2     6
// Not referenced in the visible code; presumably 2 * NC * NC -- TODO confirm.
#define NDF     18
// Number of spinor components.
#define ND      4
// Not referenced in the visible code; presumably NC * ND -- TODO confirm.
#define NCD     12
// Doubles per lattice site (used as the site stride below).
#define NCD2    24
// Offsets of the real part of the three color components inside one spinor
// component; the imaginary part is at offset + 1.
#define C1      0
#define C2      2
#define C3      4

//====================================================================
//! Contracts two spinor fields on a single time slice.
/*!
   Computes

     corr = sum_{id} gm.value(id)
              * sum_{ss, cc} conj( v1[gm.index(id)][cc] ) * v2[id][cc]

   where ss runs over the Nvol/Nt sites of time slice `time`, cc over the
   NC colors and id over the ND spinor components.

   The raw data of both fields is indexed as NCD2 doubles per site, NC2
   doubles per spinor component and (real, imag) pairs per color.
   No reduction over processes is performed in this function.

   \param corr  [out] result of the contraction; overwritten.
   \param gm    gamma matrix supplying index(id) and value(id).
   \param v1    field that enters complex-conjugated.
   \param v2    field that enters unconjugated.
   \param time  time-slice index, must satisfy
                0 <= time < CommonParameters::Nt().
 */
void contract_at_t(dcomplex& corr, const GammaMatrix& gm,
                   const Field_F& v1, const Field_F& v2, int time)
{
  const int Nvol   = v1.nvol();
  const int Nt     = CommonParameters::Nt();
  const int Nvol_s = Nvol / Nt;

  assert(Nvol == v2.nvol());
  assert(v1.nex() == 1);
  assert(v2.nex() == 1);
  // An out-of-range time slice would index outside the field buffers.
  assert(time >= 0 && time < Nt);

  // The data is only read; the const_cast is needed just to reach ptr().
  const double *w1 = const_cast<Field_F *>(&v1)->ptr(0);
  const double *w2 = const_cast<Field_F *>(&v2)->ptr(0);

  int    off1[ND];  // spinor offset into v1: component gm.index(id)
  int    off2[ND];  // spinor offset into v2: component id
  double cr[ND];    // real part accumulated per spinor component
  double ci[ND];    // imaginary part accumulated per spinor component

  for (int id = 0; id < ND; ++id) {
    off1[id] = gm.index(id) * NC2;
    off2[id] = id * NC2;
    cr[id]   = 0.0;
    ci[id]   = 0.0;
  }

  for (int ss = 0; ss < Nvol_s; ++ss) {
    const int site = NCD2 * (ss + time * Nvol_s);

    for (int cc = 0; cc < NC; ++cc) {
      for (int id = 0; id < ND; ++id) {
        const int i1 = 2 * cc + off1[id] + site;  // real part in w1
        const int i2 = 2 * cc + off2[id] + site;  // real part in w2

        // conj(w1) * w2
        cr[id] += w1[i1] * w2[i2] + w1[i1 + 1] * w2[i2 + 1];
        ci[id] += w1[i1] * w2[i2 + 1] - w1[i1 + 1] * w2[i2];
      }
    }
  }

  corr = gm.value(0) * cmplx(cr[0], ci[0])
         + gm.value(1) * cmplx(cr[1], ci[1])
         + gm.value(2) * cmplx(cr[2], ci[2])
         + gm.value(3) * cmplx(cr[3], ci[3]);
}


//====================================================================
//! Color-antisymmetric contraction of three spinor fields on a time slice.
/*!
   Computes

     corr = sum_{id} gm.value(id)
              * sum_{ss} sum_{a,b,c} eps(a,b,c)
                  * v1[id][a] * v2[gm.index(id)][b] * v3[d3][c]

   where ss runs over the Nvol/Nt sites of time slice `time`, id over the
   ND spinor components, and eps is the totally antisymmetric tensor in
   the three color indices. No complex conjugation is applied.

   The raw data of all fields is indexed as NCD2 doubles per site, NC2
   doubles per spinor component and (real, imag) pairs per color.
   No reduction over processes is performed in this function.

   \param corr  [out] result of the contraction; overwritten.
   \param gm    gamma matrix supplying index(id) and value(id).
   \param d3    spinor component of v3 to use, must satisfy 0 <= d3 < ND.
   \param v1    first field, spinor component id.
   \param v2    second field, spinor component gm.index(id).
   \param v3    third field, spinor component d3.
   \param time  time-slice index, must satisfy
                0 <= time < CommonParameters::Nt().
 */
void contract_at_t(dcomplex& corr, const GammaMatrix& gm, int d3,
                   const Field_F& v1, const Field_F& v2, const Field_F& v3,
                   int time)
{
  const int Nvol   = v1.nvol();
  const int Nt     = CommonParameters::Nt();
  const int Nvol_s = Nvol / Nt;

  assert(Nvol == v2.nvol());
  assert(Nvol == v3.nvol());
  assert(v1.nex() == 1);
  assert(v2.nex() == 1);
  assert(v3.nex() == 1);
  // Out-of-range d3 or time would index outside the field buffers.
  assert(d3 >= 0 && d3 < ND);
  assert(time >= 0 && time < Nt);

  // The data is only read; the const_cast is needed just to reach ptr().
  const double *w1 = const_cast<Field_F *>(&v1)->ptr(0);
  const double *w2 = const_cast<Field_F *>(&v2)->ptr(0);
  const double *w3 = const_cast<Field_F *>(&v3)->ptr(0);

  // Color offsets (for w1, w2, w3) of the even and odd permutations of the
  // three colors; even ones enter with +, odd ones with -.
  static const int even_perm[NC][NC] = {
    { C1, C2, C3 }, { C2, C3, C1 }, { C3, C1, C2 }
  };
  static const int odd_perm[NC][NC] = {
    { C3, C2, C1 }, { C2, C1, C3 }, { C1, C3, C2 }
  };

  const int s3 = d3 * NC2;

  int    gmd[ND];
  double cr[ND];
  double ci[ND];

  for (int id = 0; id < ND; ++id) {
    gmd[id] = gm.index(id);
    cr[id]  = 0.0;
    ci[id]  = 0.0;
  }

  for (int ss = 0; ss < Nvol_s; ++ss) {
    const int site = NCD2 * (ss + time * Nvol_s);
    const int iw3  = s3 + site;  // independent of id

    for (int id = 0; id < ND; ++id) {
      const int iw1 = id * NC2 + site;
      const int iw2 = gmd[id] * NC2 + site;

      // Even permutations: add w1 * w2 * w3.
      for (int p = 0; p < NC; ++p) {
        const double *a = w1 + even_perm[p][0] + iw1;
        const double *b = w2 + even_perm[p][1] + iw2;
        const double *c = w3 + even_perm[p][2] + iw3;

        const double pr = a[0] * b[0] - a[1] * b[1];  // Re(a * b)
        const double pi = a[0] * b[1] + a[1] * b[0];  // Im(a * b)

        cr[id] += pr * c[0] - pi * c[1];
        ci[id] += pr * c[1] + pi * c[0];
      }

      // Odd permutations: subtract w1 * w2 * w3.
      for (int p = 0; p < NC; ++p) {
        const double *a = w1 + odd_perm[p][0] + iw1;
        const double *b = w2 + odd_perm[p][1] + iw2;
        const double *c = w3 + odd_perm[p][2] + iw3;

        const double pr = a[0] * b[0] - a[1] * b[1];  // Re(a * b)
        const double pi = a[0] * b[1] + a[1] * b[0];  // Im(a * b)

        cr[id] -= pr * c[0] - pi * c[1];
        ci[id] -= pr * c[1] + pi * c[0];
      }
    }
  }

  corr = cmplx(0.0, 0.0);
  for (int id = 0; id < ND; ++id) {
    corr += cmplx(cr[id], ci[id]) * gm.value(id);
  }
}
