/*!
        @file    $Id:: field.h #$

        @brief

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-12 16:56:41 #$

        @version $LastChangedRevision: 930 $
*/


#ifndef FIELD_INCLUDED
#define FIELD_INCLUDED

#include <valarray>
#include <string>
#include <assert.h>

#include "commonParameters.h"
#include "communicator.h"
#include "bridge_complex.h"

//#define USE_EXPR_TEMPL

//! Container of Field-type object.

/*!
   This class defines a field-type quantity characterized by three
   size parameters: Nin (on-site degrees of freedom), Nvol (number
   of sites), and Nex (extra degrees of freedom).
   The detailed structure of these degrees of freedom is not
   defined in this class but in its subclasses.
   The expression templates were implemented by J.Noaki.
                                   [28 Dec 2011 H.Matsufuru]
 */
class Field {
  // friend dcomplex ddotc_complex(Field&, Field&);

 public:
  //! Element interpretation. With COMPLEX, daxpy(dcomplex, ...) treats
  //! consecutive elements (2k, 2k+1) as (real, imaginary) pairs; with
  //! REAL every element is an independent real number.
  enum complexness { COMPLEX, REAL };

 protected:
  std::valarray<double> field;  // flat element storage, resized to m_Ntot
  int         m_Nvol; // lattice volume
  int         m_Nin;  // internal d.o.f.
  int         m_Nex;  // external d.o.f.
  complexness m_complexness;
  // total degree of freedom is Nin * Nsite * Nex.
  int m_Ntot;  // total d.o.f.

  //! Flat storage index of element (jin, site, jex):
  //! jin runs fastest, then site, then jex.
  int myindex(const int jin, const int site, const int jex)
  const { return jin + m_Nin * (site + m_Nvol * jex); }

  Bridge::VerboseLevel m_vl;

 public:

  //! Constructs an empty field: all sizes (including the total) are zero.
  //  m_Ntot must be initialized here as well; otherwise ntot() and every
  //  loop bounded by m_Ntot would read an indeterminate value.
  Field() :
    m_Nvol(0), m_Nin(0), m_Nex(0), m_complexness(COMPLEX),
    m_Ntot(0),
    m_vl(CommonParameters::Vlevel()) {  }

  //! Constructs a field holding Nin * Nvol * Nex elements.
  /*!
     @param Nin   on-site (internal) degrees of freedom
     @param Nvol  number of sites
     @param Nex   extra (external) degrees of freedom
     @param cmpl  element interpretation, COMPLEX by default
   */
  Field(const int Nin, const int Nvol, const int Nex,
        const complexness cmpl = COMPLEX) :
    m_Nvol(Nvol), m_Nin(Nin), m_Nex(Nex), m_complexness(cmpl),
    m_vl(CommonParameters::Vlevel())
  {
    m_Ntot = m_Nin * m_Nvol * m_Nex;
    field.resize(m_Ntot);
  }

  //! Changes the size parameters and complexness of this field.
  //! The storage is resized with std::valarray::resize, which
  //! re-initializes the elements; previous contents are not kept.
  void reset(const int Nin, const int Nvol, const int Nex,
             const complexness cmpl = COMPLEX)
  {
    m_Nin         = Nin;
    m_Nvol        = Nvol;
    m_Nex         = Nex;
    m_Ntot        = m_Nin * m_Nvol * m_Nex;
    m_complexness = cmpl;
    field.resize(m_Ntot);
  }

  // Size accessors.
  int nin() const { return m_Nin; }
  int nex() const { return m_Nex; }
  int nvol() const { return m_Nvol; }
  int ntot() const { return m_Ntot; }
  int size() const { return m_Nin * m_Nvol * m_Nex; }
  complexness field_complexness() const { return m_complexness; }

  //! Returns element (jin, site, jex). No bounds checking is done.
  double cmp(const int jin, const int site, const int jex) const
  {
    return field[myindex(jin, site, jex)];
  }

  //! Returns the element at flat index i. No bounds checking is done.
  double cmp(const int i) const { return field[i]; }

  //! Returns a pointer to the element at flat index i.
  double *ptr(const int i) { return &field[i]; }

  //! Overwrites element (jin, site, jex) with v.
  void set(const int jin, const int site, const int jex, double v)
  {
    field[myindex(jin, site, jex)] = v;
  }

  //! Overwrites the element at flat index i with v.
  void set(const int i, double v) { field[i] = v; }

  //! Adds v to element (jin, site, jex).
  void add(const int jin, const int site, const int jex, double v)
  {
    field[myindex(jin, site, jex)] += v;
  }

  //! Adds v to the element at flat index i.
  void add(const int i, double v) { field[i] += v; }

  //! Copies the extra-index slice exw of w into slice ex of this field.
  //  NOTE(review): w is indexed with this field's Nin/Nvol strides, so
  //  w presumably must have the same Nin and Nvol -- confirm with callers.
  void setpart_ex(int ex, const Field& w, int exw)
  {
    assert(ex >= 0 && ex < m_Nex);
    assert(exw >= 0 && exw < w.nex());
    for (int site = 0; site < m_Nvol; ++site) {
      for (int jin = 0; jin < m_Nin; ++jin) {
        field[myindex(jin, site, ex)] = w.field[myindex(jin, site, exw)];
      }
    }
  }

  //! Adds the extra-index slice exw of w to slice ex of this field.
  //  NOTE(review): same Nin/Nvol assumption on w as in setpart_ex.
  void addpart_ex(int ex, const Field& w, int exw)
  {
    assert(ex >= 0 && ex < m_Nex);
    assert(exw >= 0 && exw < w.nex());
    for (int site = 0; site < m_Nvol; ++site) {
      for (int jin = 0; jin < m_Nin; ++jin) {
        field[myindex(jin, site, ex)] += w.field[myindex(jin, site, exw)];
      }
    }
  }

  //! Adds prf times the extra-index slice exw of w to slice ex of this field.
  //  NOTE(review): same Nin/Nvol assumption on w as in setpart_ex.
  void addpart_ex(int ex, const Field& w, int exw, double prf)
  {
    assert(ex >= 0 && ex < m_Nex);
    assert(exw >= 0 && exw < w.nex());
    for (int site = 0; site < m_Nvol; ++site) {
      for (int jin = 0; jin < m_Nin; ++jin) {
        field[myindex(jin, site, ex)]
          += prf * w.field[myindex(jin, site, exw)];
      }
    }
  }

  //! Square root of the sum of squared elements, where the local sum is
  //! passed through Communicator::reduce_sum (presumably a sum over all
  //! nodes) before taking the root.
  double norm() const
  {
    double a = (field * field).sum();
    double b = Communicator::reduce_sum(a);

    return sqrt(b);
  }

  //  double norm2() const {
  //    double a = (field*field).sum();
  //    double b = Communicator::reduce_sum(a);
  //    return b;
  //  }

  // The following functions were added to test the performance
  // of plain-loop implementations replacing the expression
  // templates.  [8 Apr 2012 H.Matsufuru]

  //! Element-wise product of this field and x summed over all m_Ntot
  //! elements, then passed through Communicator::reduce_sum.
  //! No conjugation is applied; every element is treated as real.
  double ddotc(const Field& x) const
  {
    // ptr() is non-const, hence the casts; both fields are only read.
    double *yp = const_cast<Field *>(this)->ptr(0);
    double *xp = const_cast<Field *>(&x)->ptr(0);
    //#pragma disjoint(*yp,*xp)  // for BG/Q
    //    __alignx(32,yp);
    //    __alignx(32,xp);
    double a = 0.0;

    for (int i = 0; i < m_Ntot; ++i) {
      a += yp[i] * xp[i];
    }
    double b = Communicator::reduce_sum(a);
    return b;
  }

  //! Sum of squared elements, passed through Communicator::reduce_sum.
  double norm2() const
  {
    //    __alignx(32,&field[0]);
    double a = 0.0;

    for (int i = 0; i < m_Ntot; ++i) {
      a += field[i] * field[i];
    }
    double b = Communicator::reduce_sum(a);
    return b;
  }

  //! y += a * x with a real coefficient, element by element (y is this field).
  void daxpy(double a, const Field& x)
  {
    double *yp = ptr(0);
    double *xp = const_cast<Field *>(&x)->ptr(0);  // read-only access

    //#pragma disjoint(*yp,*xp)  // for BG/Q
    //    __alignx(32,yp);
    //    __alignx(32,xp);
    for (int i = 0; i < m_Ntot; ++i) {
      yp[i] += a * xp[i];
    }
  }

  //! y += a * x with a complex coefficient (y is this field).
  /*!
     For a COMPLEX field the elements (2k, 2k+1) are handled as
     (real, imaginary) pairs and a full complex multiplication is
     applied to m_Ntot / 2 pairs. For a REAL field only the real part
     of a is used. Any other complexness value aborts.
   */
  void daxpy(dcomplex a, const Field& x)
  {
    double *yp = ptr(0);
    double *xp = const_cast<Field *>(&x)->ptr(0);  // read-only access

    assert(x.ntot() == m_Ntot);
    if (m_complexness == COMPLEX) {
      double ar = real(a);
      double ai = imag(a);
      for (int k = 0; k < m_Ntot / 2; ++k) {
        yp[2 * k]     += ar * xp[2 * k] - ai * xp[2 * k + 1];
        yp[2 * k + 1] += ai * xp[2 * k] + ar * xp[2 * k + 1];
      }
    } else if (m_complexness == REAL) {
      double ar = real(a);
      for (int k = 0; k < m_Ntot; ++k) {
        yp[k] += ar * xp[k];
      }
    } else {
      abort();
    }
  }

  //! Multiplies every element by a.
  void dscal(double a)
  {
    double *yp = ptr(0);

    //    __alignx(32,yp);
    for (int i = 0; i < m_Ntot; ++i) {
      yp[i] *= a;
    }
  }

  //! Copies the first m_Ntot elements of x into this field.
  //! The size parameters of this field are not changed.
  void dcopy(const Field& x)
  {
    for (int i = 0; i < m_Ntot; ++i) {
      field[i] = x.field[i];
    }
  }

  //! Sets this field to a * x, element by element, over m_Ntot elements.
  void dcopy(double a, const Field& x)
  {
    double *yp = ptr(0);
    double *xp = const_cast<Field *>(&x)->ptr(0);  // read-only access

    //#pragma disjoint(*yp,*xp)
    //    __alignx(32,yp);
    //    __alignx(32,xp);
    for (int i = 0; i < m_Ntot; ++i) {
      yp[i] = a * xp[i];
    }
  }

  //! Sets every element to zero.
  void clear()
  {
    for (int i = 0; i < m_Ntot; ++i) {
      field[i] = 0.0;
    }
  }

  //! \brief determines the statistics of the field.
  //! The average, maximum value, and deviation are determined
  //! over the global lattice. On-site degrees of freedom are
  //! summed over in quadrature, not averaged.
  //! This function works only on a single node.
  void stat(double& Fave, double& Fmax, double& Fdev);

  //! write field values to a text file.
  void write_text(std::string);

  //! \brief read field values from text file.
  //! Assumes the field size is already defined; if it is
  //! inconsistent with the field in the file, the read aborts.
  //! This function works only on a single node.
  void read_text(std::string);

#ifdef USE_EXPR_TEMPL
  //! Assigns the evaluated result of an expression-template node.
  template<typename T>
  Field& operator=(const T& rhs)
  {
    *this = rhs.eval();
    return *this;
  }

  //! Adds the evaluated result of an expression-template node.
  template<typename T>
  Field& operator+=(const T& rhs)
  {
    *this += rhs.eval();
    return *this;
  }

  //! Subtracts the evaluated result of an expression-template node.
  template<typename T>
  Field& operator-=(const T& rhs)
  {
    *this -= rhs.eval();
    return *this;
  }
#endif

  // Element-wise operators; defined inline after the class.
  // Note that unary minus negates this field IN PLACE and returns *this.
  Field& operator-();
  Field& operator=(const double&);
  Field& operator+=(const Field&);
  Field& operator-=(const Field&);
  Field& operator*=(const double&);
  Field& operator/=(const double&);
  double operator*(const Field& rhs);
};

//! Flips the sign of every element of this field IN PLACE and
//! returns a reference to this field (no negated copy is created).
inline Field& Field::operator-()
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] = -field[idx];
  }
  return *this;
}


//! Fills every element of this field with the scalar r;
//! the size of the field is left unchanged.
inline Field& Field::operator=(const double& r)
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] = r;
  }
  return *this;
}


//! Element-wise addition of rhs into this field.
//! Both fields are expected to hold the same number of elements.
inline Field& Field::operator+=(const Field& rhs)
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] += rhs.field[idx];
  }
  return *this;
}


//! Element-wise subtraction of rhs from this field.
//! Both fields are expected to hold the same number of elements.
inline Field& Field::operator-=(const Field& rhs)
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] -= rhs.field[idx];
  }
  return *this;
}


//! Multiplies every element of this field by the scalar rhs.
inline Field& Field::operator*=(const double& rhs)
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] *= rhs;
  }
  return *this;
}


//! Divides every element of this field by the scalar rhs.
//! No check for a zero divisor is made.
inline Field& Field::operator/=(const double& rhs)
{
  const std::size_t count = field.size();

  for (std::size_t idx = 0; idx < count; ++idx) {
    field[idx] /= rhs;
  }
  return *this;
}


//! Dot product of this field with rhs: the element-wise products are
//! summed locally and the result is passed through
//! Communicator::reduce_sum.
inline double Field::operator*(const Field& rhs)
{
  // The valarray expression is kept as is so that the summation order
  // (and hence the floating-point rounding) is unchanged.
  double local_sum = (field * rhs.field).sum();

  return Communicator::reduce_sum(local_sum);
}


//! Inner product of y and x, returned as a complex number.
/*!
   If y is COMPLEX, consecutive elements (2k, 2k+1) of both fields are
   read as (real, imaginary) pairs and the sum of conj(y_k) * x_k over
   Ntot / 2 pairs is formed. If y is REAL, the plain sum of y_k * x_k
   is returned with zero imaginary part. The partial sums are passed
   through Communicator::reduce_sum. Any other complexness aborts.
   Only the complexness of y is inspected.
 */
inline dcomplex ddotc_complex(const Field& y, const Field& x)
{
  // Field::ptr() is non-const, hence the casts; the data is only read.
  double    *ydata = const_cast<Field *>(&y)->ptr(0);
  double    *xdata = const_cast<Field *>(&x)->ptr(0);
  const int nelem  = y.ntot();

  assert(x.ntot() == nelem);

  const Field::complexness kind = y.field_complexness();

  if (kind == Field::REAL) {
    double acc = 0.0;
    for (int i = 0; i < nelem; ++i) {
      acc += ydata[i] * xdata[i];
    }
    acc = Communicator::reduce_sum(acc);
    return cmplx(acc, 0.0);
  }

  if (kind != Field::COMPLEX) {
    abort();
  }

  double    acc_re = 0.0;
  double    acc_im = 0.0;
  const int npair  = nelem / 2;
  for (int k = 0; k < npair; ++k) {
    const double yr = ydata[2 * k];
    const double yi = ydata[2 * k + 1];
    const double xr = xdata[2 * k];
    const double xi = xdata[2 * k + 1];

    acc_re += yr * xr + yi * xi;
    acc_im += yr * xi - yi * xr;
  }
  acc_re = Communicator::reduce_sum(acc_re);
  acc_im = Communicator::reduce_sum(acc_im);
  return cmplx(acc_re, acc_im);
}


//----------------------------------------------------------------

//! Field-level wrapper of Communicator::exchange: forwards the raw
//! storage (starting at element 0) of both buffers together with the
//! remaining arguments, and returns the communicator's result.
inline int exchange(int count, Field *recv_buf, Field *send_buf, int idir, int ipm, int tag)
{
  double *recv_data = recv_buf->ptr(0);
  double *send_data = send_buf->ptr(0);

  return Communicator::exchange(count, recv_data, send_data, idir, ipm, tag);
}


//! Field-level wrapper of Communicator::send_1to1: forwards the raw
//! storage (starting at element 0) of both buffers together with the
//! remaining arguments, and returns the communicator's result.
inline int send_1to1(int count, Field *recv_buf, Field *send_buf, int p_to, int p_from, int tag)
{
  double *recv_data = recv_buf->ptr(0);
  double *send_data = send_buf->ptr(0);

  return Communicator::send_1to1(count, recv_data, send_data, p_to, p_from, tag);
}


//----------------------------------------------------------------

#ifdef USE_EXPR_TEMPL

//! Operation tag for expression templates: element-wise sum of two fields.
struct Add
{
  //! Returns a new field equal to lhs + rhs; the operands are untouched.
  static Field calc(const Field& lhs, const Field& rhs)
  {
    Field total(lhs);

    total += rhs;
    return total;
  }
};

//! Operation tag for expression templates: element-wise difference of two fields.
struct Sub
{
  //! Returns a new field equal to lhs - rhs; the operands are untouched.
  static Field calc(const Field& lhs, const Field& rhs)
  {
    Field diff(lhs);

    diff -= rhs;
    return diff;
  }
};

//! Operation tag for expression templates: product of a field and a scalar.
struct Mul
{
  //! Returns a new field equal to lhs scaled by rhs (field * scalar).
  static Field calc(const Field& lhs, const double& rhs)
  {
    Field scaled(lhs);

    scaled *= rhs;
    return scaled;
  }

  //! Returns a new field equal to rhs scaled by lhs (scalar * field).
  static Field calc(const double& lhs, const Field& rhs)
  {
    Field scaled(rhs);

    scaled *= lhs;
    return scaled;
  }
};

//! Expression-template node combining two sub-expressions with Op.
//! Only references to the operands are stored, so a node must not
//! outlive the objects it was built from.
template<typename L, typename Op, typename R>
class AdSbMc {
 public:
  AdSbMc(const L& left, const R& right) : m_left(left), m_right(right) {}

  //! Evaluates both sub-expressions and combines them with Op::calc.
  Field eval() const
  {
    return Op::calc(m_left.eval(), m_right.eval());
  }

 private:
  const L& m_left;
  const R& m_right;
};

template<typename L, typename Op>
class AdSbMc<L, Op, Field> {
 private:
  const L&     lhs;
  const Field& rhs;
 public:
  AdSbMc(const L& Lhs, const Field& Rhs) : lhs(Lhs), rhs(Rhs) {}
  Field eval() const { return Op::calc(lhs.eval(), rhs); }
};

template<typename Op, typename R>
class AdSbMc<Field, Op, R> {
 private:
  const Field& lhs;
  const R&     rhs;
 public:
  AdSbMc(const Field& Lhs, const R& Rhs) : lhs(Lhs), rhs(Rhs) {}
  Field eval() const { return Op::calc(lhs, rhs.eval()); }
};

//! Node for (sub-expression) * scalar.
//! Only references to the operands are stored.
template<typename L>
class AdSbMc<L, Mul, double> {
 public:
  AdSbMc(const L& left, const double& factor) : m_left(left), m_factor(factor) {}

  //! Evaluates the sub-expression and scales the result by the scalar.
  Field eval() const
  {
    return Mul::calc(m_left.eval(), m_factor);
  }

 private:
  const L&      m_left;
  const double& m_factor;
};

template<>
class AdSbMc<Field, Mul, double> {
 private:
  const Field&  lhs;
  const double& rhs;
 public:
  AdSbMc(const Field& Lhs, const double& Rhs) : lhs(Lhs), rhs(Rhs) {}
  Field eval() const { return Mul::calc(lhs, rhs); }
};

//! Node for scalar * (sub-expression).
//! Only references to the operands are stored.
template<typename R>
class AdSbMc<double, Mul, R> {
 public:
  AdSbMc(const double& factor, const R& right) : m_factor(factor), m_right(right) {}

  //! Evaluates the sub-expression and scales the result by the scalar.
  Field eval() const
  {
    return Mul::calc(m_factor, m_right.eval());
  }

 private:
  const double& m_factor;
  const R&      m_right;
};

template<>
class AdSbMc<double, Mul, Field> {
 private:
  const double& lhs;
  const Field&  rhs;
 public:
  AdSbMc(const double& Lhs, const Field& Rhs) : lhs(Lhs), rhs(Rhs) {}
  Field eval() const { return Mul::calc(lhs, rhs); }
};

template<typename Op>
class AdSbMc<Field, Op, Field> {
 private:
  const Field& lhs;
  const Field& rhs;
 public:
  AdSbMc(const Field& Lhs, const Field& Rhs) : lhs(Lhs), rhs(Rhs) {}
  Field eval() const { return Op::calc(lhs, rhs); }
};

//! Builds an Add node referring to lhs and rhs; nothing is computed
//! until eval() is called on the node.
//! Note: the template is unconstrained in L and R.
template<typename L, typename R>
AdSbMc<L, Add, R> operator+(const L& lhs, const R& rhs)
{
  typedef AdSbMc<L, Add, R> Node;

  return Node(lhs, rhs);
}


//! Builds a Sub node referring to lhs and rhs; nothing is computed
//! until eval() is called on the node.
//! Note: the template is unconstrained in L and R.
template<typename L, typename R>
AdSbMc<L, Sub, R> operator-(const L& lhs, const R& rhs)
{
  typedef AdSbMc<L, Sub, R> Node;

  return Node(lhs, rhs);
}


//! Builds a Mul node for scalar * rhs; nothing is computed until
//! eval() is called on the node.
template<typename R>
AdSbMc<double, Mul, R> operator*(const double& lhs, const R& rhs)
{
  typedef AdSbMc<double, Mul, R> Node;

  return Node(lhs, rhs);
}


#else  // USE_EXPR_TEMPL

//! Returns a copy of v with every element multiplied by s;
//! v itself is not modified.
inline Field operator*(const Field& v, const double s)
{
  Field scaled(v);

  scaled *= s;
  return scaled;
}


//! Scalar-on-the-left form: returns a copy of v with every element
//! multiplied by s; v itself is not modified.
inline Field operator*(const double s, const Field& v)
{
  Field scaled(v);

  scaled *= s;
  return scaled;
}


//! Returns a new field holding the element-wise sum lhs + rhs;
//! neither operand is modified.
inline Field operator+(const Field& lhs, const Field& rhs)
{
  Field total(lhs);

  total += rhs;
  return total;
}


//! Returns a new field holding the element-wise difference lhs - rhs;
//! neither operand is modified.
inline Field operator-(const Field& lhs, const Field& rhs)
{
  Field diff(lhs);

  diff -= rhs;
  return diff;
}
#endif  // USE_EXPR_TEMPL
#endif
