/*!
        @file    $Id:: solver_BiCGStab.cpp #$

        @brief   BiCGStab iterative solver for A x = b, with A applied through a Fopr

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-19 14:15:23 #$

        @version $LastChangedRevision: 936 $
*/

#include "solver_BiCGStab.h"

#ifdef USE_PARAMETERS_FACTORY
#include "parameters_factory.h"
#endif

#ifdef USE_FACTORY
namespace {
  //- Factory callback: build a BiCGStab solver acting with the given operator.
  //- Ownership of the returned object passes to the caller.
  Solver *create_object(Fopr *fopr)
  {
    Solver *solver = new Solver_BiCGStab(fopr);

    return solver;
  }


  //- Register the callback under the key "BiCGStab"; this runs when the
  //- namespace-scope objects of this translation unit are initialized.
  bool init = Solver::Factory::Register("BiCGStab", create_object);
}
#endif

//- parameter entries
namespace {
  //- Register the parameter entries understood by Solver_BiCGStab.
  //- set_parameters(const Parameters&) fetches them by these exact names,
  //- so the strings must stay in sync with it. The second argument of each
  //- Register_* call is the value the entry is registered with.
  void append_entry(Parameters& param)
  {
    param.Register_int("maximum_number_of_iteration", 0);
    param.Register_double("convergence_criterion_squared", 0.0);

    param.Register_string("verbose_level", "NULL");
  }


#ifdef USE_PARAMETERS_FACTORY
  //- Make the entry set available to ParametersFactory under "Solver.BiCGStab".
  bool init_param = ParametersFactory::Register("Solver.BiCGStab", append_entry);
#endif
}
//- end

//- parameters class
//- Populate this Parameters object with the BiCGStab entries.
Parameters_Solver_BiCGStab::Parameters_Solver_BiCGStab()
{
  append_entry(*this);
}
//- end

//====================================================================
void Solver_BiCGStab::set_parameters(const Parameters& params)
{
  const string str_vlevel = params.get_string("verbose_level");

  m_vl = vout.set_verbose_level(str_vlevel);

  //- fetch and check input parameters
  int    Niter;
  double Stop_cond;

  int err = 0;
  err += params.fetch_int("maximum_number_of_iteration", Niter);
  err += params.fetch_double("convergence_criterion_squared", Stop_cond);

  if (err) {
    vout.crucial(m_vl, "Solver_BiCGStab: fetch error, input parameter not found.\n");
    abort();
  }


  set_parameters(Niter, Stop_cond);
}


//====================================================================
//- Configure the solver directly.
//-   Niter     : maximum number of BiCGStab iterations (checked non-negative)
//-   Stop_cond : convergence criterion on the squared relative residual
//-               (validated by ParameterCheck::square_non_zero)
//- The values are logged first; on a failed range check the run aborts and
//- the stored parameters are left untouched.
void Solver_BiCGStab::set_parameters(const int Niter, const double Stop_cond)
{
  //- echo the requested values
  vout.general(m_vl, "Parameters of Solver_BiCGStab:\n");
  vout.general(m_vl, "  Niter     = %d\n", Niter);
  vout.general(m_vl, "  Stop_cond = %16.8e\n", Stop_cond);

  //- range check (separate statements keep the check order fixed)
  int n_bad = ParameterCheck::non_negative(Niter);
  n_bad += ParameterCheck::square_non_zero(Stop_cond);

  if (n_bad != 0) {
    vout.crucial(m_vl, "Solver_BiCGStab: parameter range check failed.\n");
    abort();
  }

  //- accept
  m_Stop_cond = Stop_cond;
  m_Niter     = Niter;
}


//====================================================================
//- Solve A xq = b with BiCGStab, A being applied through m_fopr->mult().
//-
//- The iteration starts from the initial guess x0 = b. Convergence is
//- judged on the squared relative residual |r|^2 / |b|^2 < m_Stop_cond.
//- When the recursively updated residual passes, the true residual
//- |A x - b| is recomputed. If it also passes, the solver stops; otherwise
//- the iteration restarts from the current x.
//-
//-   xq    [out] solution vector
//-   b     [in]  source vector
//-   Nconv [out] operator applications counted as 2 per iteration at
//-               convergence; 0 when b is identically zero
//-   diff  [out] norm of the true residual |A xq - b|
//-
//- Aborts (vout.crucial + abort) if not converged within m_Niter iterations.
void Solver_BiCGStab::solve(Field& xq, const Field& b,
                            int& Nconv, double& diff)
{
  vout.paranoiac(m_vl, "  BiCGStab solver starts\n");

  reset_field(b);

  // norm2() is a reduction over the whole field; compute it once.
  const double bnorm2 = b.norm2();

  vout.paranoiac(m_vl, "    norm of b = %16.8e\n", bnorm2);
  vout.paranoiac(m_vl, "    size of b = %d\n", b.size());

  // A zero source has the trivial solution x = 0. Without this guard
  // snorm = 1/0 = inf and rr * snorm = 0 * inf = NaN, so the convergence
  // test below can never pass and the run aborts.
  if (bnorm2 == 0.0) {
    vout.detailed(m_vl, "    BiCGStab: zero source, returning x = 0.\n");
    xq    = b;     // take the shape of b ...
    xq    = 0.0;   // ... and zero it
    Nconv = 0;
    diff  = 0.0;   // |A*0 - b| = |b| = 0
    return;
  }

  const double snorm = 1.0 / bnorm2;
  double       rr;

  Nconv = -1;
  s     = b;   // solve_init() copies s into x: initial guess x0 = b

  solve_init(b, rr);

  vout.detailed(m_vl, "    iter: %8d  %22.15e\n", 0, rr * snorm);

  for (int iter = 0; iter < m_Niter; iter++) {
    solve_step(rr);

    vout.detailed(m_vl, "    iter: %8d  %22.15e\n", 2 * (iter + 1), rr * snorm);

    if (rr * snorm < m_Stop_cond) {
      // recursive residual says converged: confirm with the true residual
      s    = m_fopr->mult(x);
      s   -= b;
      diff = s.norm();

      if (diff * diff * snorm < m_Stop_cond) {
        Nconv = 2 * (iter + 1);
        break;
      }

      // true residual still too large: restart from the current x
      s = x;
      solve_init(b, rr);
    }
  }
  if (Nconv == -1) {
    vout.crucial(m_vl, "BiCGStab solver not converged.\n");
    abort();
  }

  // diff already holds |A x - b| for this very x (computed just before the
  // break), so the operator is not applied again here.
  xq = x;
}


//====================================================================
//- Make sure the work fields (s, r, x, p, v, rh) have the same
//- (nin, nvol, nex) layout as b. The layout of s stands in for all of
//- them; all six are reset only when it differs from b.
void Solver_BiCGStab::reset_field(const Field& b)
{
  const int Nin  = b.nin();
  const int Nvol = b.nvol();
  const int Nex  = b.nex();

  const bool layout_matches = (s.nin() == Nin)
                              && (s.nvol() == Nvol)
                              && (s.nex() == Nex);

  if (layout_matches) return;

  s.reset(Nin, Nvol, Nex);
  r.reset(Nin, Nvol, Nex);
  x.reset(Nin, Nvol, Nex);
  p.reset(Nin, Nvol, Nex);
  v.reset(Nin, Nvol, Nex);
  rh.reset(Nin, Nvol, Nex);

  vout.paranoiac(m_vl, "    Solver_BiCGStab: field size reset.\n");
}


//====================================================================
//- (Re)start the BiCGStab recurrence from the guess held in s.
//- Sets x = s, r = b - A s, and the shadow residual rh = r; returns
//- rr = |r|^2. Resets rho_p, alpha_p and omega_p to 1 and clears p and v.
void Solver_BiCGStab::solve_init(const Field& b, double& rr)
{
  // current iterate
  x = s;

  // initial residual, reused as the fixed shadow residual
  v  = m_fopr->mult(s);
  r  = b;
  r -= v;
  rh = r;
  rr = r * r;

  // previous-step coefficients
  omega_p = 1.0;
  alpha_p = 1.0;
  rho_p   = 1.0;

  // search direction and its image under the operator
  v = 0.0;
  p = 0.0;
}


//====================================================================
//- One BiCGStab iteration; applies the operator twice (to p and to s).
//- Uses the state left by solve_init() or the previous call (r, rh, p, v,
//- x, rho_p, alpha_p, omega_p). Updates x and r, stores the new
//- coefficients, and returns rr = |r|^2.
void Solver_BiCGStab::solve_step(double& rr)
{
  const double rho = rh * r;
  const double bet = rho * alpha_p / (rho_p * omega_p);

  p = r + bet * (p - omega_p * v);

  v = m_fopr->mult(p);
  const double aden  = rh * v;
  const double alpha = rho / aden;

  s = r - (alpha * v);

  // Build t directly from the operator result; copy-constructing it from s
  // and then overwriting it copied a full field for nothing.
  Field t(m_fopr->mult(s));

  const double omega_n = t * s;
  const double omega_d = t * t;

  // If s vanished exactly, the half step already solved the system and
  // t = A s = 0, so omega_n / omega_d would be 0/0 = NaN and corrupt x.
  // Taking omega = 0 leaves x += alpha * p and r = s = 0, which the
  // convergence check in solve() then accepts (or restarts from).
  const double omega = (omega_d != 0.0) ? omega_n / omega_d : 0.0;

  x += omega * s + alpha * p;
  r  = s - omega * t;

  rho_p   = rho;
  alpha_p = alpha;
  omega_p = omega;

  rr = r * r;
}


//====================================================================
//============================================================END=====
