/*!
        @file    $Id:: force_F_Rational.cpp #$

        @brief

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-12 16:56:41 #$

        @version $LastChangedRevision: 930 $
*/

#include "force_F_Rational.h"

#ifdef USE_PARAMETERS_FACTORY
#include "parameters_factory.h"
#endif

using std::valarray;
using Bridge::vout;

//- parameter entries
namespace {
  void append_entry(Parameters& param)
  {
    param.Register_int("number_of_poles", 0);
    param.Register_int("exponent_numerator", 0);
    param.Register_int("exponent_denominator", 0);
    param.Register_double("lower_bound", 0.0);
    param.Register_double("upper_bound", 0.0);
    param.Register_int("maximum_number_of_iteration", 0);
    param.Register_double("convergence_criterion_squared", 0.0);

    param.Register_string("verbose_level", "NULL");
  }


#ifdef USE_PARAMETERS_FACTORY
  bool init_param = ParametersFactory::Register("Force.F_Rational", append_entry);
#endif
}
//- end

//- parameters class
Parameters_Force_F_Rational::Parameters_Force_F_Rational()
{
  // populate this container with the Force_F_Rational entries
  append_entry(*this);
}
//- end

//====================================================================
/*!
  Read the Force_F_Rational parameters from a Parameters container and
  forward them to the typed set_parameters() overload.

  The verbose level is taken from "verbose_level" first, so that the
  diagnostics below honour it. A non-zero return from any fetch_* call
  is treated as a missing entry and aborts the run.
*/
void Force_F_Rational::set_parameters(const Parameters& params)
{
  const string str_vlevel = params.get_string("verbose_level");

  m_vl = vout.set_verbose_level(str_vlevel);

  //- fetch and check input parameters
  //  (initialized so no indeterminate value can ever be read)
  int    Np        = 0;
  int    n_exp     = 0;
  int    d_exp     = 0;
  double x_min     = 0.0;
  double x_max     = 0.0;
  int    Niter     = 0;
  double Stop_cond = 0.0;

  int err = 0;
  err += params.fetch_int("number_of_poles", Np);
  err += params.fetch_int("exponent_numerator", n_exp);
  err += params.fetch_int("exponent_denominator", d_exp);
  err += params.fetch_double("lower_bound", x_min);
  err += params.fetch_double("upper_bound", x_max);
  err += params.fetch_int("maximum_number_of_iteration", Niter);
  err += params.fetch_double("convergence_criterion_squared", Stop_cond);

  if (err) {
    // name the actual class so the log points at the right component
    vout.crucial(m_vl, "Force_F_Rational: fetch error, input parameter not found.\n");
    abort();
  }

  set_parameters(Np, n_exp, d_exp, x_min, x_max, Niter, Stop_cond);
}


//====================================================================
/*!
  Validate and store the Force_F_Rational parameters, then build the
  rational-approximation coefficients via init_parameters().

  @param Np         number of poles
  @param n_exp      exponent numerator
  @param d_exp      exponent denominator
  @param x_min      lower bound (0 is allowed)
  @param x_max      upper bound (0 is allowed)
  @param Niter      maximum number of shift-solver iterations
  @param Stop_cond  squared convergence criterion of the shift solver

  Aborts if any of Np, n_exp, d_exp, Niter or Stop_cond fails its check.
*/
void Force_F_Rational::set_parameters(int Np, int n_exp, int d_exp,
                                      double x_min, double x_max,
                                      int Niter, double Stop_cond)
{
  //- print input parameters
  vout.general(m_vl, "Parameters of Force_F_Rational:\n");
  vout.general(m_vl, "  Np        = %d\n", Np);
  vout.general(m_vl, "  n_exp     = %d\n", n_exp);
  vout.general(m_vl, "  d_exp     = %d\n", d_exp);
  vout.general(m_vl, "  x_min     = %10.6f\n", x_min);
  vout.general(m_vl, "  x_max     = %10.6f\n", x_max);
  vout.general(m_vl, "  Niter     = %d\n", Niter);
  // exponential format: a squared stop condition is usually far below
  // 1e-6 and would print as 0.000000 with a fixed-point format
  vout.general(m_vl, "  Stop_cond = %8.2e\n", Stop_cond);

  //- range check
  int err = 0;
  err += ParameterCheck::non_zero(Np);
  err += ParameterCheck::non_zero(n_exp);
  err += ParameterCheck::non_zero(d_exp);
  // NB. x_min,x_max=0 is allowed.
  err += ParameterCheck::non_zero(Niter);
  err += ParameterCheck::square_non_zero(Stop_cond);

  if (err) {
    vout.crucial(m_vl, "Force_F_Rational: parameter range check failed.\n");
    abort();
  }

  //- store values
  m_Np        = Np;
  m_n_exp     = n_exp;
  m_d_exp     = d_exp;
  m_x_min     = x_min;
  m_x_max     = x_max;
  m_Niter     = Niter;
  m_Stop_cond = Stop_cond;

  // post-process: derive the approximation coefficients
  init_parameters();
}


//====================================================================
void Force_F_Rational::init_parameters()
{
  m_cl.resize(m_Np);
  m_bl.resize(m_Np);

  // Rational approximation
  double x_min2 = m_x_min * m_x_min;
  double x_max2 = m_x_max * m_x_max;

  Math_Rational rational;
  rational.set_parameters(m_Np, m_n_exp, m_d_exp, x_min2, x_max2);
  rational.get_parameters(m_a0, m_bl, m_cl);

  vout.general(m_vl, " a0 = %18.14f\n", m_a0);
  for (int i = 0; i < m_Np; i++) {
    vout.general(m_vl, " bl[%d] = %18.14f  cl[%d] = %18.14f\n",
                 i, m_bl[i], i, m_cl[i]);
  }
}


//====================================================================
/*!
  Fermion force of the rational-approximated action.

  Takes force1 = force_udiv(eta) and, for each direction mu, forms
  U_mu(x) * force1_mu(x) (via mult_Field_Gnn) and applies at_Field_G.
  The whole field is then scaled by -2.
  NOTE(review): at_Field_G is presumably the traceless anti-hermitian
  projection; confirm in Field_G.

  @param eta  pseudofermion field passed through to force_udiv
  @return     the force, converted to a plain Field
*/
Field Force_F_Rational::force_core(const Field& eta)
{
  int Nvol = CommonParameters::Nvol();
  int Ndim = CommonParameters::Ndim();

  Field_G force(Nvol, Ndim), force1(Nvol, Ndim);

  force1 = force_udiv(eta);

  for (int mu = 0; mu < Ndim; ++mu) {
    force.mult_Field_Gnn(mu, *m_U, mu, force1, mu);
    force.at_Field_G(mu);
  }
  force *= -2.0;

  return static_cast<Field>(force);
}


//====================================================================
/*!
  Link derivative of the rational-approximated action (before the
  multiplication by U done in force_core).

  Puts m_fopr into "DdagD" mode (this mode persists after the call).
  Solves all shifted systems with shifts m_cl at once using
  Shiftsolver_CG, with m_Niter iterations and stop condition m_Stop_cond.
  Returns sum_i m_bl[i] * m_force->force_udiv(psi_i).
  The constant term m_a0 does not enter the force.

  @param eta  source field for the shift solver
  @return     accumulated force, shape (2*Nc*Nc, Nvol, Ndim)
*/
Field Force_F_Rational::force_udiv(const Field& eta)
{
  int Nc   = CommonParameters::Nc();
  int Nvol = CommonParameters::Nvol();
  int Ndim = CommonParameters::Ndim();
  int NinG = 2 * Nc * Nc;

  int NinF  = eta.nin();
  int NvolF = eta.nvol();
  int NexF  = eta.nex();

  // Shiftsolver: one solution vector per pole
  int Nshift = m_Np;

  valarray<Field> psi(Nshift);
  for (int i = 0; i < Nshift; ++i) {
    psi[i].reset(NinF, NvolF, NexF);
  }

  int    Nconv = 0;
  double diff  = 0.0;

  vout.general(m_vl, "    Shift solver in force calculation\n");
  // size() is size_t; %d requires an int argument
  vout.general(m_vl, "      Number of shift values = %d\n",
               static_cast<int>(m_cl.size()));
  m_fopr->set_mode("DdagD");

  {
    // Scoped automatic object instead of new/delete. It is released at
    // the same point as before, and cannot leak on an early exit.
    Shiftsolver_CG solver(m_fopr, m_Niter, m_Stop_cond);

    solver.solve(psi, m_cl, eta, Nconv, diff);
    vout.general(m_vl, "      diff(max) = %22.15e  \n", diff);
  }

  Field force(NinG, Nvol, Ndim);
  Field force1(NinG, Nvol, Ndim);
  force = 0.0;

  // weighted sum of the single-pole derivatives
  for (int i = 0; i < Nshift; ++i) {
    force1 = m_force->force_udiv(psi[i]);
    force += m_bl[i] * force1;
  }

  return force;
}


//====================================================================
//===========================================================END======
