/*!
        @file    $Id:: solver_CG_alt.cpp #$

        @brief   Implementation of the conjugate gradient (CG) solver Solver_CG.

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-12 16:56:41 #$

        @version $LastChangedRevision: 930 $
*/


#include "solver_CG.h"

#include "communicator.h"
#include "bridgeIO.h"
using Bridge::vout;

#include "fopr.h"


//====================================================================
//! Solve A xq = b with the conjugate gradient method (A = opr).
/*!
  The iteration starts from the initial guess x0 = b.  It stops when
  (r,r) / b.norm2() < enorm.  If that never happens within Niter steps,
  the program aborts.
  NOTE(review): the stopping test assumes b.norm2() is the squared norm,
  consistent with rr = (r,r) -- confirm against Field.

  \param xq     [out] solution vector.
  \param b      [in]  source vector.
  \param Nconv  [out] 0-based index of the CG step at which the criterion
                      was met.  It is also 0 when the initial guess (or a
                      zero source) already satisfies it.
  \param diff   [out] p.norm() of the true residual A xq - b.
*/
void Solver_CG::solve(Field& xq, const Field& b,
                      int& Nconv, double& diff)
{
  //  vout.general(m_vl, "CG solver start\n");

  const double bnorm2 = b.norm2();

  // Zero source: the solution is the zero field.  Handled explicitly
  // because the relative criterion below would divide by zero (inf * 0 = NaN)
  // and the CG step would evaluate 0/0.
  if (bnorm2 == 0.0) {
    xq    = b;
    Nconv = 0;
    diff  = 0.0;
    return;
  }

  double snorm = 1.0 / bnorm2;

  reset_field(b);
  s = b;

  Nconv = -1;
  double rr;

  solve_init(rr);
  //vout.general(m_vl, "  init: %22.15e\n",rr*snorm);

  if (rr * snorm < enorm) {
    // The initial guess is already converged.  Taking a CG step here would
    // compute rr / (p, A p) with p == r, which is 0/0 when r vanishes exactly.
    Nconv = 0;
  } else {
    for (int it = 0; it < Niter; it++) {
      solve_step(rr);
      // vout.general(m_vl, "%6d  %22.15e\n",it,rr*snorm);

      if (rr * snorm < enorm) {
        Nconv = it;
        break;
      }
    }
  }
  if (Nconv == -1) { cout << "Not converged." << __FILE__ << "(" << __LINE__ << ")" << endl; abort(); }

  // True residual |A x - b|, recomputed from scratch rather than taken from
  // the recursively updated r.
  //p = opr->mult(x);
  opr->mult(p, x);
  p   -= b;
  diff = p.norm();

  xq = x;
}


//====================================================================
//! Resize the CG work fields to match the shape of \p b, if needed.
/*!
  Only the shape of s is compared with b.  The work fields s, r, x, p, v
  are always reset together, so they share one shape.  A log message is
  emitted whenever a resize happens.

  \param b  field whose (nin, nvol, nex) shape the work fields must adopt.
*/
void Solver_CG::reset_field(const Field& b)
{
  const int Nin  = b.nin();
  const int Nvol = b.nvol();
  const int Nex  = b.nex();

  const bool same_shape = (s.nin() == Nin)
                          && (s.nvol() == Nvol)
                          && (s.nex() == Nex);

  // Nothing to do when the work fields already have the right shape.
  if (same_shape) return;

  s.reset(Nin, Nvol, Nex);
  r.reset(Nin, Nvol, Nex);
  x.reset(Nin, Nvol, Nex);
  p.reset(Nin, Nvol, Nex);
  v.reset(Nin, Nvol, Nex);

  vout.general(m_vl, "    Solver_CG: field size reset.\n");
}


//====================================================================
//! Initialize the CG iteration.
/*!
  On entry s holds the source vector.  The initial guess is x0 = s.

  On exit:
  - r = s - A x0 (the initial residual);
  - p = r;
  - s holds A x0.

  \param rr  [out] squared norm (r, r) of the initial residual.
*/
void Solver_CG::solve_init(double& rr)
{
  // Start from x0 = b; keep a copy of b in r before s is overwritten.
  x = s;
  r = s;

  // s <- A x0, then r <- b - A x0.
  opr->mult(s, x);
  r -= s;

  // The first search direction is the initial residual.
  p  = r;
  rr = r * r;
}


//====================================================================
void Solver_CG::solve_step(double& rr)
{
  //s = opr->mult(p);
  opr->mult(s, p);

  double pap = p * s;
  double rrp = rr;
  double cr  = rrp / pap;

  v  = p;
  v *= cr;
  x += v;
  //  x += cr*p;

  s *= cr;
  r -= s;

  rr = r * r;
  p *= rr / rrp;
  p += r;
}


//====================================================================
//============================================================END=====
