/*!
        @file    $Id:: solver_BiCGStab_L_Cmplx.cpp #$

        @brief

        @author  <Yusuke Namekawa> namekawa@ccs.tsukuba.ac.jp(namekawa)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-19 14:15:23 #$

        @version $LastChangedRevision: 936 $
*/

#include "solver_BiCGStab_L_Cmplx.h"

using std::valarray;

#ifdef USE_PARAMETERS_FACTORY
#include "parameters_factory.h"
#endif

#ifdef USE_FACTORY
namespace {
  //! Factory hook: builds a BiCGStab(L) solver acting on the operator fopr.
  //! The solver is allocated with new and is not retained here.
  Solver *create_object(Fopr *fopr)
  {
    Solver *const solver = new Solver_BiCGStab_L_Cmplx(fopr);

    return solver;
  }


  //! Registers create_object under the key "BiCGStab_L_Cmplx" while this
  //! translation unit is statically initialised.
  bool init = Solver::Factory::Register("BiCGStab_L_Cmplx", create_object);
}
#endif

//- parameter entries
namespace {
  void append_entry(Parameters& param)
  {
    param.Register_int("maximum_number_of_iteration", 0);
    param.Register_double("convergence_criterion_squared", 0.0);

    param.Register_int("number_of_orthonormal_vectors", 0);

    param.Register_string("verbose_level", "NULL");
  }


#ifdef USE_PARAMETERS_FACTORY
  bool init_param = ParametersFactory::Register("Solver.BiCGStab_L_Cmplx", append_entry);
#endif
}
//- end

//- parameters class
//! Creates a parameter set holding the BiCGStab(L) input keys
//! (registered through append_entry).
Parameters_Solver_BiCGStab_L_Cmplx::Parameters_Solver_BiCGStab_L_Cmplx()
{
  append_entry(*this);
}
//- end

//====================================================================
void Solver_BiCGStab_L_Cmplx::set_parameters(const Parameters& params)
{
  const string str_vlevel = params.get_string("verbose_level");

  m_vl = vout.set_verbose_level(str_vlevel);

  //- fetch and check input parameters
  int    Niter;
  double Stop_cond;
  int    N_L;

  int err = 0;
  err += params.fetch_int("maximum_number_of_iteration", Niter);
  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
  err += params.fetch_int("number_of_orthonormal_vectors", N_L);

  if (err) {
    vout.crucial(m_vl, "Solver_BiCGStab_L_Cmplx: fetch error, input parameter not found.\n");
    abort();
  }


  set_parameters(Niter, Stop_cond);
  set_parameters_L(N_L);
}


//====================================================================
//! Sets the outer-iteration limit and the stopping tolerance.
//!
//! @param Niter      maximum number of outer BiCGStab(L) steps; must be >= 0
//! @param Stop_cond  threshold for |r|^2 / |b|^2 used by solve; checked
//!                   with ParameterCheck::square_non_zero
//! Aborts when either range check fails.
void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const double Stop_cond)
{
  // echo the request before validating it
  vout.general(m_vl, "Parameters of Solver_BiCGStab_L_Cmplx:\n");
  vout.general(m_vl, "  Niter     = %d\n", Niter);
  vout.general(m_vl, "  Stop_cond = %16.8e\n", Stop_cond);

  int n_bad = ParameterCheck::non_negative(Niter);
  n_bad += ParameterCheck::square_non_zero(Stop_cond);

  if (n_bad != 0) {
    vout.crucial(m_vl, "Solver_BiCGStab_L_Cmplx: parameter range check failed.\n");
    abort();
  }

  m_Stop_cond = Stop_cond;
  m_Niter     = Niter;
}


//====================================================================
//! Sets L of BiCGStab(L): the number of BiCG steps per outer step and the
//! degree of the minimal-residual correction (stored in m_N_L).
//!
//! @param N_L  L; must be at least 1.
//! Aborts when the range check fails.
void Solver_BiCGStab_L_Cmplx::set_parameters_L(const int N_L)
{
  //- print input parameters
  vout.general(m_vl, "  N_L   = %d\n", N_L);

  //- range check
  int err = 0;
  err += ParameterCheck::non_negative(N_L);

  // BiCGStab(L) is only defined for L >= 1.  With N_L == 0, solve_step
  // performs no operator application and reads gamma[1] from a valarray
  // of size 1, which is out of bounds.
  if (N_L == 0) {
    vout.crucial(m_vl, "Solver_BiCGStab_L_Cmplx: N_L must be at least 1.\n");
    ++err;
  }

  if (err) {
    vout.crucial(m_vl, "Solver_BiCGStab_L_Cmplx: parameter range check failed.\n");
    abort();
  }

  //- store values
  m_N_L = N_L;
}


//====================================================================
//! Solves A xq = b with BiCGStab(L), A being m_fopr.
//!
//! The iteration starts from the guess x_0 = b and stops once the relative
//! squared residual |A x - b|^2 / |b|^2 drops below m_Stop_cond.  When the
//! recursively updated residual passes that test, the true residual is
//! computed; if it does not confirm convergence, the iteration is restarted
//! from the current x.
//!
//! @param[out] xq     solution vector
//! @param[in]  b      source vector
//! @param[out] Nconv  2 * m_N_L times the number of outer steps taken
//!                    (0 for a zero source)
//! @param[out] diff   |A xq - b| of the returned solution
//! Aborts if no convergence is reached within m_Niter outer steps.
void Solver_BiCGStab_L_Cmplx::solve(Field& xq, const Field& b,
                                    int& Nconv, double& diff)
{
  vout.detailed(m_vl, "  BiCGStab_L_Cmplx solver starts\n");

  reset_field(b);

  // norm2 involves a global reduction: evaluate it once
  const double bnorm2 = b.norm2();

  vout.paranoiac(m_vl, "    norm of b = %16.8e\n", bnorm2);
  vout.paranoiac(m_vl, "    size of b = %d\n", b.size());

  // A zero source has the exact solution x = 0.  Without this guard
  // snorm = 1/0 turns every convergence test into NaN < m_Stop_cond, so the
  // solver would iterate on NaNs and finally abort.
  if (bnorm2 == 0.0) {
    x     = 0.0;
    xq    = x;
    Nconv = 0;
    diff  = 0.0;
    return;
  }

  const double snorm = 1.0 / bnorm2;
  double       rr;

  Nconv = -1;
  s     = b;   // initial guess x_0 = b, picked up by solve_init

  solve_init(b, rr);

  vout.detailed(m_vl, "    iter: %8d  %22.15e\n", 0, rr * snorm);

  for (int iter = 0; iter < m_Niter; iter++) {
    solve_step(rr);

    vout.detailed(m_vl, "    iter: %8d  %22.15e\n", 2 * m_N_L * (iter + 1), rr * snorm);

    if (rr * snorm < m_Stop_cond) {
      // confirm with the true residual A x - b
      s    = m_fopr->mult(x);
      s   -= b;
      diff = s.norm();

      if (diff * diff * snorm < m_Stop_cond) {
        Nconv = 2 * m_N_L * (iter + 1);
        break;
      }

      // recursive residual drifted: restart from the current x
      s = x;
      solve_init(b, rr);
    }
  }
  if (Nconv == -1) {
    vout.crucial(m_vl, "BiCGStab_L_Cmplx solver not converged.\n");
    abort();
  }

  // diff already holds |A x - b| for this very x (computed just before the
  // break above), so no extra operator application is needed.
  xq = x;
}


//====================================================================
//! Brings all work vectors to the shape (nin, nvol, nex) of b.
//!
//! s, x, r_init and v_tmp are reset only when the shape of s differs from
//! that of b.  The vectors u[0..m_N_L] and r[0..m_N_L] are resized to
//! m_N_L + 1 entries and reset on every call, so a changed m_N_L is always
//! honoured.
void Solver_BiCGStab_L_Cmplx::reset_field(const Field& b)
{
  const int Nin  = b.nin();
  const int Nvol = b.nvol();
  const int Nex  = b.nex();

  const bool same_shape = (s.nin() == Nin) && (s.nvol() == Nvol) && (s.nex() == Nex);

  if (!same_shape) {
    s.reset(Nin, Nvol, Nex);
    x.reset(Nin, Nvol, Nex);
    r_init.reset(Nin, Nvol, Nex);
    v_tmp.reset(Nin, Nvol, Nex);

    vout.paranoiac(m_vl, "    Solver_BiCGStab_L_Cmplx: field size reset.\n");
  }

  const int n_vec = m_N_L + 1;

  u.resize(n_vec);
  r.resize(n_vec);

  for (int k = 0; k < n_vec; ++k) {
    u[k].reset(Nin, Nvol, Nex);
    r[k].reset(Nin, Nvol, Nex);
  }
}


//====================================================================
//! (Re)starts the iteration from the guess currently held in s.
//!
//! Sets x = s, r[0] = b - A x, r_init = r[0] (the fixed vector used in the
//! BiCG inner products of solve_step), u[0] = 0, and the scalars
//! rho_p = 1, alpha_p = 0, omega_p = 1.
//!
//! @param[in]  b   source vector
//! @param[out] rr  r[0] * r[0] for the new residual
void Solver_BiCGStab_L_Cmplx::solve_init(const Field& b, double& rr)
{
  // NB. alpha_p starts at 0 (not 1), which makes beta vanish in the first
  // BiCG step of the following solve_step call
  rho_p   = cmplx(1.0, 0.0);
  alpha_p = cmplx(0.0, 0.0);
  omega_p = cmplx(1.0, 0.0);

  x = s;

  // r[0] = b - A x_0
  v_tmp  = m_fopr->mult(x);
  r[0]   = b;
  r[0]  -= v_tmp;
  r_init = r[0];

  u[0] = 0.0;

  rr = r[0] * r[0];
}


//====================================================================
//! One outer cycle of BiCGStab(L) with L = m_N_L: L BiCG steps followed by a
//! minimal-residual (MR) correction of degree L, following the BiCGStab(l)
//! scheme of Sleijpen and Fokkema.  Each call applies m_fopr->mult
//! 2 * m_N_L times.
//!
//! State carried between calls: x, r[0], u[0], r_init and the scalars
//! rho_p, alpha_p, omega_p (initialised by solve_init).
//!
//! Complex inner products: innerprod_c(re, im, v, w) yields sum conj(v) * w,
//! and cmplx(const_r, -const_i) conjugates that result.  Hence e.g. rho
//! below equals sum conj(r_init) * r[j].
//!
//! @param[out] rr  r[0] * r[0] for the updated residual r[0].
void Solver_BiCGStab_L_Cmplx::solve_step(double& rr)
{
  double const_r, const_i;

  // start of a cycle: rho_0 <- -omega * rho_0
  rho_p *= -omega_p;

  //- BiCG part, j = 0 .. m_N_L-1 (two operator applications per j)
  for (int j = 0; j < m_N_L; ++j) {
    // rho = sum conj(r_init) * r[j]
    innerprod_c(const_r, const_i, r[j], r_init);
    dcomplex rho = cmplx(const_r, -const_i);

    // beta = alpha * rho / rho_old; zero for j = 0 of the first call after
    // solve_init, which sets alpha_p = 0
    dcomplex beta = alpha_p * (rho / rho_p);
    rho_p = rho;

    for (int i = 0; i < j + 1; ++i) {
      // u[i] = r[i] - beta * u[i];
      mult_c(v_tmp, u[i], real(beta), imag(beta));
      u[i] = r[i] - v_tmp;
    }

    u[j + 1] = m_fopr->mult(u[j]);

    // gamma = sum conj(r_init) * u[j+1]  (scalar local to this BiCG step,
    // unrelated to the gamma array of the MR part below)
    innerprod_c(const_r, const_i, u[j + 1], r_init);
    dcomplex gamma = cmplx(const_r, -const_i);

    alpha_p = rho_p / gamma;

    for (int i = 0; i < j + 1; ++i) {
      // r[i] -= alpha_p * u[i+1];
      mult_c(v_tmp, u[i + 1], real(alpha_p), imag(alpha_p));
      r[i] -= v_tmp;
    }

    r[j + 1] = m_fopr->mult(r[j]);

    // x += alpha_p * u[0];
    mult_c(v_tmp, u[0], real(alpha_p), imag(alpha_p));
    x += v_tmp;
  }


  //- MR part: modified Gram-Schmidt on r[1..m_N_L], then subtract from r[0]
  //- its components along the orthogonalised r[1..m_N_L].
  valarray<double>   sigma(m_N_L + 1);
  valarray<dcomplex> gamma_prime(m_N_L + 1);

  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
  // Flat storage; index_ij (declared outside this file section) maps (i, j)
  // to a position in it.
  valarray<dcomplex> tau(m_N_L * (m_N_L + 1));
  int                ij, ji;

  for (int j = 1; j < m_N_L + 1; ++j) {
    for (int i = 1; i < j; ++i) {
      ij = index_ij(i, j);

      // tau[i,j] = (sum conj(r[i]) * r[j]) / sigma[i]
      innerprod_c(const_r, const_i, r[j], r[i]);
      tau[ij] = cmplx(const_r, -const_i) / sigma[i];

      // r[j] -= tau[i,j] * r[i]  (removes the component of r[j] along r[i])
      mult_c(v_tmp, r[i], real(tau[ij]), imag(tau[ij]));
      r[j] -= v_tmp;
    }


    // Field dot product of the orthogonalised r[j] with itself, used as
    // its squared norm
    sigma[j] = r[j] * r[j];


    // gamma_prime[j] = (sum conj(r[j]) * r[0]) / sigma[j]:
    // coefficient of r[0] along r[j]
    innerprod_c(const_r, const_i, r[0], r[j]);
    gamma_prime[j] = cmplx(const_r, -const_i) / sigma[j];
  }


  //- back substitution:
  //-   gamma[m_N_L] = gamma_prime[m_N_L]
  //-   gamma[j]     = gamma_prime[j] - sum_{i=j+1}^{m_N_L} tau[j,i] * gamma[i]
  valarray<dcomplex> gamma(m_N_L + 1);
  dcomplex           c_tmp;

  gamma[m_N_L] = gamma_prime[m_N_L];
  omega_p      = gamma[m_N_L];   // used by the next call's rho_p update

  for (int j = m_N_L - 1; j > 0; --j) {
    c_tmp = cmplx(0.0, 0.0);

    for (int i = j + 1; i < m_N_L + 1; ++i) {
      ji     = index_ij(j, i);
      c_tmp += tau[ji] * gamma[i];
    }

    gamma[j] = gamma_prime[j] - c_tmp;
  }


  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
  // gamma_double_prime[j] = gamma[j+1] + sum_{i=j+1}^{m_N_L-1} tau[j,i] * gamma[i+1]
  valarray<dcomplex> gamma_double_prime(m_N_L);

  for (int j = 1; j < m_N_L; ++j) {
    c_tmp = cmplx(0.0, 0.0);

    for (int i = j + 1; i < m_N_L; ++i) {
      ji     = index_ij(j, i);
      c_tmp += tau[ji] * gamma[i + 1];
    }

    gamma_double_prime[j] = gamma[j + 1] + c_tmp;
  }

  //- apply the MR correction; first the terms not covered by the loop below:
  // x    += gamma[          1] * r[    0];
  // r[0] -= gamma_prime[m_N_L] * r[m_N_L];
  // u[0] -= gamma[      m_N_L] * u[m_N_L];

  mult_c(v_tmp, r[0], real(gamma[1]), imag(gamma[1]));
  x += v_tmp;

  mult_c(v_tmp, r[m_N_L], real(gamma_prime[m_N_L]), imag(gamma_prime[m_N_L]));
  r[0] -= v_tmp;

  mult_c(v_tmp, u[m_N_L], real(gamma[m_N_L]), imag(gamma[m_N_L]));
  u[0] -= v_tmp;


  for (int j = 1; j < m_N_L; ++j) {
    // x    += gamma_double_prime[j] * r[j];
    // r[0] -= gamma_prime[       j] * r[j];
    // u[0] -= gamma[             j] * u[j];

    mult_c(v_tmp, r[j], real(gamma_double_prime[j]), imag(gamma_double_prime[j]));
    x += v_tmp;

    mult_c(v_tmp, r[j], real(gamma_prime[j]), imag(gamma_prime[j]));
    r[0] -= v_tmp;

    mult_c(v_tmp, u[j], real(gamma[j]), imag(gamma[j]));
    u[0] -= v_tmp;
  }

  // recursive residual norm, checked by solve against m_Stop_cond
  rr = r[0] * r[0];
}


//====================================================================
//! Complex inner product prod = sum_k conj(v_k) * w_k.
//!
//! Components are read as consecutive (re, im) pairs.  The local partial
//! sums are combined with Communicator::reduce_sum (presumably across all
//! ranks).
//!
//! @param[out] prod_r  real part of the product
//! @param[out] prod_i  imaginary part of the product
//! @param[in]  v       left (conjugated) vector
//! @param[in]  w       right vector; must have the same size as v
void Solver_BiCGStab_L_Cmplx::innerprod_c(double& prod_r, double& prod_i,
                                          const Field& v, const Field& w)
{
  const int size = w.size();

  assert(v.size() == size);

  double re = 0.0;
  double im = 0.0;

  for (int k = 0; k < size; k += 2) {
    const double v_re = v.cmp(k);
    const double v_im = v.cmp(k + 1);
    const double w_re = w.cmp(k);
    const double w_im = w.cmp(k + 1);

    re += v_re * w_re + v_im * w_im;
    im += v_re * w_im - v_im * w_re;
  }

  prod_r = Communicator::reduce_sum(re);
  prod_i = Communicator::reduce_sum(im);
}


//====================================================================
//! Complex scaling v = (prod_r + i * prod_i) * w.
//!
//! Components are treated as consecutive (re, im) pairs.  v must already
//! have the size of w.  Each pair of w is read completely before the
//! corresponding pair of v is written, so v and w may be the same Field.
void Solver_BiCGStab_L_Cmplx::mult_c(Field& v,
                                     const Field& w,
                                     const double& prod_r, const double& prod_i)
{
  const int size = w.size();

  assert(v.size() == size);

  for (int k = 0; k < size; k += 2) {
    const double w_re = w.cmp(k);
    const double w_im = w.cmp(k + 1);

    v.set(k, prod_r * w_re - prod_i * w_im);
    v.set(k + 1, prod_r * w_im + prod_i * w_re);
  }
}


//====================================================================
//============================================================END=====
