/*!
        @file    $Id:: shiftsolver_CG.cpp #$

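        @brief   Multi-shift conjugate gradient solver: solves (A + sigma_i) x_i = b for all shifts sigma_i at once.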

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-19 14:15:23 #$

        @version $LastChangedRevision: 936 $
*/

#include "shiftsolver_CG.h"

//- parameter entries
namespace {
  //! Register every input parameter understood by Shiftsolver_CG
  //! in \p param, each with its default value.
  void append_entry(Parameters& param)
  {
    // iteration control
    param.Register_int("maximum_number_of_iteration", 0);
    param.Register_double("convergence_criterion_squared", 0.0);
    // diagnostics
    param.Register_string("verbose_level", "NULL");
  }
}
//- end

//- parameters class
Parameters_Shiftsolver_CG::Parameters_Shiftsolver_CG()
{
  // fill this parameter set with the Shiftsolver_CG entries and defaults
  append_entry(*this);
}
//- end

//====================================================================
void Shiftsolver_CG::set_parameters(const Parameters& params)
{
  const string str_vlevel = params.get_string("verbose_level");

  m_vl = vout.set_verbose_level(str_vlevel);

  //- fetch and check input parameters
  int    Niter;
  double Stop_cond;

  int err = 0;
  err += params.fetch_int("maximum_number_of_iteration", Niter);
  err += params.fetch_double("convergence_criterion_squared", Stop_cond);

  if (err) {
    vout.crucial(m_vl, "Shiftsolver_CG: fetch error, input parameter not found.\n");
    abort();
  }


  set_parameters(Niter, Stop_cond);
}


//====================================================================
void Shiftsolver_CG::set_parameters(const int Niter, const double Stop_cond)
{
  // Echo the requested values before validating them.
  vout.general(m_vl, "Parameters of Shiftsolver_CG:\n");
  vout.general(m_vl, "  Niter     = %d\n", Niter);
  vout.general(m_vl, "  Stop_cond = %16.8e\n", Stop_cond);

  // Range checks via ParameterCheck; any non-zero status is fatal.
  // (Kept as two statements so the checks run in a fixed order.)
  int check_status = ParameterCheck::non_negative(Niter);
  check_status += ParameterCheck::square_non_zero(Stop_cond);

  if (check_status != 0) {
    vout.crucial(m_vl, "Shiftsolver_CG: parameter range check failed.\n");
    abort();
  }

  // Accept the values.
  m_Stop_cond = Stop_cond;
  m_Niter     = Niter;
}


//====================================================================
void Shiftsolver_CG::solve(std::valarray<Field>& xq,
                           std::valarray<double> sigma,
                           const Field& b,
                           int& Nconv, double& diff)
{
  vout.paranoiac(m_vl, "  Shift CG solver start.\n");

  int Nshift = sigma.size();

  vout.paranoiac(m_vl, "    number of shift = %d\n", Nshift);
  vout.paranoiac(m_vl, "    values of shift:\n");
  for(int i = 0; i<Nshift; ++i){
    vout.paranoiac(m_vl, "    %d  %12.8f\n", i, sigma[i]);
  }

  snorm = 1.0 / b.norm2();

  Nconv = -1;

  int Nin  = b.nin();
  int Nvol = b.nvol();
  int Nex  = b.nex();

  p.resize(Nshift);
  x.resize(Nshift);
  zeta1.resize(Nshift);
  zeta2.resize(Nshift);
  csh2.resize(Nshift);
  pp.resize(Nshift);

  for (int i = 0; i < Nshift; ++i) {
    p[i].reset(Nin, Nvol, Nex);
    x[i].reset(Nin, Nvol, Nex);
    zeta1[i] = 1.0;
    zeta2[i] = 1.0;
    csh2[i]  = sigma[i] - sigma[0];
  }
  s.reset(Nin, Nvol, Nex);
  r.reset(Nin, Nvol, Nex);
  s = b;
  r = b;

  double rr;
  Nshift2 = Nshift;

  solve_init(rr);

  vout.detailed(m_vl, "    iter: %8d  %22.15e\n", 0, rr * snorm);

  for (int iter = 0; iter < m_Niter; iter++) {
    solve_step(rr, sigma);

    vout.detailed(m_vl, "    iter: %8d  %22.15e  %4d\n", (iter + 1), rr * snorm, Nshift2);

    if (rr * snorm < m_Stop_cond) {
      Nconv = iter;
      break;
    }
  }
  if (Nconv == -1) {
    vout.crucial(m_vl, "Shiftsolver_CG not converged.\n");
    abort();
  }


  diff = -1.0;
  for (int i = 0; i < Nshift; ++i) {
    s  = m_fopr->mult(x[i]);
    s += sigma[i] * x[i];
    s -= b;
    double diff1 = s * s;
    diff1 = sqrt(diff1 * snorm);

    vout.paranoiac(m_vl, "    %4d  %22.15e\n", i, diff1);

    if (diff1 > diff) diff = diff1;
  }

  vout.paranoiac(m_vl, "   diff(max) = %22.15e  \n", diff);

  for (int i = 0; i < Nshift; ++i) {
    xq[i] = x[i];
  }
}


//====================================================================
void Shiftsolver_CG::solve_init(double& rr)
{
  const int n_shift = p.size();

  vout.paranoiac(m_vl, "number of shift = %d\n", n_shift);

  // Every shifted system starts from x = 0, so its first search
  // direction is the initial residual, which the caller placed in s.
  for (int ish = 0; ish < n_shift; ++ish) {
    x[ish] = 0.0;
    p[ish] = s;
  }

  // Scalars of the fictitious "previous" step: with alpha_p = 0 the
  // first solve_step() computes alpha_h = 1.
  beta_p  = 1.0;
  alpha_p = 0.0;

  // Residual of the base system and its squared norm.
  r  = s;
  rr = r * r;
}


//====================================================================
//! One iteration of the multi-shift CG recursion.
/*!
  The base system (shift sigma[0]) takes an ordinary CG step, which
  requires the only operator application, m_fopr->mult(p[0]). Shifts
  1 .. Nshift2-1 are then updated from the same residual r through the
  scalar factors zeta (zeta1 = current, zeta2 = previous iteration).

  Afterwards Nshift2 is lowered to (largest ish with pp[ish] > m_Stop_cond) + 1,
  so trailing shifts that already satisfy the criterion stop being
  updated. NOTE(review): this presumably relies on sigma being sorted
  ascending — confirm with callers.

  \param rr     in: |r|^2 of the previous step; out: |r|^2 after this step.
  \param sigma  shift values; only sigma[0] is read (the other shifts
                enter through csh2).
*/
void Shiftsolver_CG::solve_step(double& rr,
                                const std::valarray<double>& sigma)
{
  // --- ordinary CG step for the base shift ---
  s  = m_fopr->mult(p[0]);
  s += sigma[0] * p[0];  // s = (A + sigma[0]) p[0]

  double rr_p = rr;
  double pa_p = s * p[0];
  double beta = -rr_p / pa_p;  // sign convention: beta = -(r,r)/(p,(A+sigma)p)

  x[0] -= beta * p[0];
  r    += beta * s;
  rr    = r * r;

  double alpha = rr / rr_p;

  p[0] *= alpha;
  p[0] += r;

  // Normalised by snorm exactly like pp[ish] below and like the convergence
  // test in solve() (rr * snorm), so the truncation loop compares every
  // entry with m_Stop_cond on the same relative scale (not the raw rr).
  pp[0] = rr * snorm;

  // --- update the remaining active shifts from the same residual r ---
  double alpha_h = 1.0 + alpha_p * beta / beta_p;
  for (int ish = 1; ish < Nshift2; ++ish) {
    double zeta = (alpha_h - csh2[ish] * beta) / zeta1[ish]
                  + (1.0 - alpha_h) / zeta2[ish];
    zeta = 1.0 / zeta;
    double zr      = zeta / zeta1[ish];
    double beta_s  = beta * zr;
    double alpha_s = alpha * zr * zr;

    x[ish] -= beta_s * p[ish];
    p[ish] *= alpha_s;
    p[ish] += zeta * r;

    pp[ish]  = p[ish] * p[ish];
    pp[ish] *= snorm;

    zeta2[ish] = zeta1[ish];
    zeta1[ish] = zeta;
  }

  // --- stop updating trailing shifts that meet the criterion ---
  for (int ish = Nshift2 - 1; ish >= 0; --ish) {

    vout.paranoiac(m_vl, "%4d %16.8e\n", ish, pp[ish]);

    if (pp[ish] > m_Stop_cond) {
      Nshift2 = ish + 1;
      break;
    }
  }

  // remember this step's scalars for the next alpha_h / zeta recursion
  alpha_p = alpha;
  beta_p  = beta;
}


//====================================================================
//============================================================END=====
