/*!
        @file    $Id:: shiftsolver_CG_alt.cpp #$

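        @brief   Multi-shift CG solver: solves (A + sigma_i) x_i = b for all shifts sigma_i using a single CG iteration sequence.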

        @author  <Hideo Matsufuru> hideo.matsufuru@kek.jp(matsufuru)
                 $LastChangedBy: sueda $

        @date    $LastChangedDate:: 2013-07-12 16:56:41 #$

        @version $LastChangedRevision: 930 $
*/


#include "shiftsolver_CG.h"

#include <stdio.h>
#include <iostream>

#include "commonParameters.h"
#include "communicator.h"
#include "bridgeIO.h"
using Bridge::vout;

#include "fopr.h"
#include "field.h"



//********************************************************************
/*!
  Solve the family of shifted systems (A + sigma[i]) x[i] = b,
  i = 0 .. Nshift-1, where A is the operator applied by opr->mult().
  The sigma[0] system is iterated with CG (see solve_step()); the other
  shifts ride along on the same residual via the zeta recursion.

  @param[out] xq     xq[i] receives the solution for shift sigma[i].
                     NOTE(review): xq is indexed, not resized, so it must
                     already hold at least sigma.size() elements.
  @param[in]  sigma  shifts; sigma[0] is the seed system. Must be non-empty
                     (solve_step() reads p[0] / sigma[0] unconditionally).
  @param[in]  b      source field.
  @param[out] Nconv  iteration index at which rr / b.norm2() of the seed
                     system first dropped below enorm.
  @param[out] diff   largest value over all shifts of
                     |(A + sigma[i]) x[i] - b|^2 / b.norm2(), recomputed
                     from the final solutions.

  If the seed system does not reach enorm within Niter iterations the
  program is aborted.
*/
void Shiftsolver_CG::solve(valarray<Field>& xq,
                           const valarray<double> sigma,
                           const Field& b, int& Nconv, double& diff) const
{
  //  vout.general(m_vl, "solver start\n");

  int Nshift = sigma.size();

  /*
  vout.general(m_vl, "number of shift = %d\n", Nshift);
  vout.general(m_vl, " values of shift:\n");
  for(int i = 0; i<Nshift; ++i){
    vout.general(m_vl, "   %d  %12.8f\n", i, sigma[i]);
  }
  */

  // Scale factor so that (squared residual) * snorm is relative to the source.
  double snorm = 1.0 / b.norm2();

  Nconv = -1;
  Field           s(b);
  Field           r(s);
  valarray<Field> p(Nshift);
  valarray<Field> x(Nshift);

  valarray<double> zeta1(Nshift);
  valarray<double> zeta2(Nshift);
  valarray<double> csh2(Nshift);
  valarray<double> pp(Nshift);

  int Nin  = b.nin();
  int Nvol = b.nvol();
  int Nex  = b.nex();
  //  vout.general(m_vl, "Nin, Nvol, Nex = %d, %d, %d\n",Nin,Nvol,Nex);
  for (int i = 0; i < Nshift; ++i) {
    p[i].reset(Nin, Nvol, Nex);
    x[i].reset(Nin, Nvol, Nex);
    zeta1[i] = 1.0;
    zeta2[i] = 1.0;
    // Shift of system i measured from the seed system sigma[0].
    csh2[i]  = sigma[i] - sigma[0];
  }

  double rr;
  double alphap, betap;
  // Number of shifts still being updated; solve_step() only ever lowers it
  // as the higher shifts converge.
  int Nshift2 = Nshift;

  solve_init(x, p, r, s, rr, zeta1, zeta2, csh2, alphap, betap);
  //vout.general(m_vl, "  init: %22.15e\n",rr*snorm);

  for (int it = 0; it < Niter; it++) {
    solve_step(x, p, r, s, rr, zeta1, zeta2, sigma, csh2, alphap, betap,
               Nshift2, snorm, pp);
    // vout.general(m_vl, "%6d  %22.15e  %4d\n",it,rr*snorm,Nshift2);

    if (rr * snorm < enorm) {
      Nconv = it;
      break;
    }
  }

  if (Nconv == -1) {
    // endl (flush) is deliberate: abort() does not flush stream buffers.
    cout << "Not converged." << __FILE__ << "(" << __LINE__ << ")" << endl;
    abort();
  }

  // Recompute the true relative residual of every shifted system from the
  // final solutions and report the worst one; s is reused as scratch.
  // vout.general(m_vl, "  residues of solutions:\n");
  diff = -1.0;
  for (int i = 0; i < Nshift; ++i) {
    opr->mult(s, x[i]);
    s += sigma[i] * x[i];
    s -= b;
    double diff1 = s * s;
    diff1 *= snorm;
    //    vout.general(m_vl, "%6d  %22.15e\n",i,diff1);
    if (diff1 > diff) diff = diff1;
  }
  //vout.general(m_vl, "  diff(max) = %22.15e  \n",diff);

  for (int i = 0; i < Nshift; ++i) {
    xq[i] = x[i];
  }
}

//********************************************************************
  /*!
    Set up the starting state of the multi-shift CG recursion.

    On return every solution x[k] is zero and every search direction
    p[k] is a copy of the source s. The residual r is also a copy of s,
    and rr holds r * r. alphap = 0 and betap = 1 make the first
    alphah = 1 + alphap * beta / betap computed in solve_step() equal 1.

    zeta1, zeta2 and csh2 are part of the signature but are not used here.
  */
  void Shiftsolver_CG::
  solve_init(valarray<Field>& x, valarray<Field>& p,
             Field& r, Field& s, double& rr,
             valarray<double>& zeta1, valarray<double>& zeta2,
             valarray<double>& csh2,
             double& alphap, double& betap) const
  {
    // With all x[k] = 0 the residual of every system is the source itself.
    r  = s;
    rr = r * r;

    const size_t n_shift = p.size();
    for (size_t k = 0; k != n_shift; ++k) {
      x[k] = 0.0;
      p[k] = s;
    }

    // Neutral "previous-step" coefficients for the first zeta update.
    betap  = 1.0;
    alphap = 0.0;
  }


//********************************************************************
  /*!
    Perform one iteration of the multi-shift CG recursion.

    The seed system (shift sigma[0]) is advanced with ordinary CG. The
    shifted systems ish = 1 .. Nshift2-1 reuse its residual r and step
    coefficients through the zeta recursion. Only one opr->mult() call
    is made per iteration.

    Naming note: here `beta` is the (negated) CG step length and `alpha`
    is the search-direction mixing coefficient. These roles are swapped
    relative to the usual textbook notation.

    @param x       solutions; x[0] and x[1 .. Nshift2-1] are updated in place.
    @param p       search directions; same shifts updated in place.
    @param r       residual of the seed system, updated in place.
    @param s       scratch; overwritten with (A + sigma[0]) p[0].
    @param rr      in: r * r of the previous residual; out: r * r of the new one.
    @param zeta1   per-shift zeta of the current iteration (rotated here).
    @param zeta2   per-shift zeta of the previous iteration (rotated here).
    @param sigma   shifts; only sigma[0] is read.
    @param csh2    per-shift coefficient of beta in the zeta recursion
                   (solve() fills it with sigma[ish] - sigma[0]).
    @param alphap  alpha of the previous iteration; overwritten with this one.
    @param betap   beta of the previous iteration; overwritten with this one.
    @param Nshift2 number of shifts still being updated; may be lowered, never raised.
    @param snorm   scale factor applied to |p[ish]|^2 (read only).
    @param pp      per-shift convergence indicator written here:
                   pp[0] = rr, pp[ish] = (p[ish] * p[ish]) * snorm.
  */
  void Shiftsolver_CG::
  solve_step(valarray<Field>& x, valarray<Field>& p,
             Field& r, Field& s, double& rr,
             valarray<double>& zeta1, valarray<double>& zeta2,
             const valarray<double>& sigma, valarray<double>& csh2,
             double& alphap, double& betap,
             int& Nshift2, double& snorm, valarray<double>& pp) const
  {
    // s = (A + sigma[0]) p[0]
    //s = opr->mult(p[0]);
    opr->mult(s, p[0]);
    s += sigma[0] * p[0];

    // beta = -(old r*r) / (s * p[0]); note the sign: x moves by -beta*p,
    // r moves by +beta*s.
    double rrp  = rr;
    double pap  = s * p[0];
    double beta = -rrp / pap;


    x[0] -= beta * p[0];
    r    += beta * s;
    rr    = r * r;

    // alpha = (new r*r) / (old r*r): coefficient mixing the old direction in.
    double alpha = rr / rrp;

    // p[0] <- r + alpha * p[0]
    p[0] *= alpha;
    p[0] += r;

    // NOTE(review): unlike pp[ish] below, pp[0] is not multiplied by snorm.
    // It only matters in the truncation loop at the end, and only once all
    // active higher shifts are already at or below enorm.
    pp[0] = rr;

    // Equals 1 on the first step, since solve_init() sets alphap = 0.
    double alphah = 1.0 + alphap * beta / betap;
    for (int ish = 1; ish < Nshift2; ++ish) {
      // New zeta for this shift from its two previous values.
      double zeta = (alphah - csh2[ish] * beta) / zeta1[ish]
                    + (1.0 - alphah) / zeta2[ish];
      zeta = 1.0 / zeta;
      // Rescale the seed-system coefficients for this shift.
      double zr     = zeta / zeta1[ish];
      double betas  = beta * zr;
      double alphas = alpha * zr * zr;

      // x uses the old p[ish]; p[ish] is updated only afterwards.
      x[ish] -= betas * p[ish];
      p[ish] *= alphas;
      p[ish] += zeta * r;

      pp[ish]  = p[ish] * p[ish];
      pp[ish] *= snorm;

      // Rotate zeta history: previous <- current, current <- new.
      zeta2[ish] = zeta1[ish];
      zeta1[ish] = zeta;
    }

    // Lower Nshift2 to one past the highest shift whose pp still exceeds
    // enorm, so converged trailing shifts are no longer updated. If no
    // entry exceeds enorm, Nshift2 is left unchanged.
    for (int ish = Nshift2 - 1; ish >= 0; --ish) {
      //    vout.general(m_vl, "%4d %16.8e\n",ish,pp[ish]);
      if (pp[ish] > enorm) {
        Nshift2 = ish + 1;
        break;
      }
    }

    // Remember this iteration's coefficients for the next zeta update.
    alphap = alpha;
    betap  = beta;
  }


//********************************************************************
//************************************************************END*****
