Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_CG.cpp
Go to the documentation of this file.
1 
14 #include "solver_CG.h"
15 
16 
17 #ifdef USE_FACTORY
18 namespace {
19  Solver *create_object(Fopr *fopr)
20  {
21  return new Solver_CG(fopr);
22  }
23 
24 
25  bool init = Solver::Factory::Register("CG", create_object);
26 }
27 #endif
28 
29 
30 const std::string Solver_CG::class_name = "Solver_CG";
31 
32 //====================================================================
34 {
35  const string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  int Niter, Nrestart;
41  double Stop_cond;
42 
43  int err = 0;
44  err += params.fetch_int("maximum_number_of_iteration", Niter);
45  err += params.fetch_int("maximum_number_of_restart", Nrestart);
46  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
47 
48  if (err) {
49  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
50  exit(EXIT_FAILURE);
51  }
52 
53  set_parameters(Niter, Nrestart, Stop_cond);
54 }
55 
56 
57 //====================================================================
58 void Solver_CG::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
59 {
61 
62  //- print input parameters
63  vout.general(m_vl, "%s:\n", class_name.c_str());
64  vout.general(m_vl, " Niter = %d\n", Niter);
65  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
66  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
67 
68  //- range check
69  int err = 0;
70  err += ParameterCheck::non_negative(Niter);
71  err += ParameterCheck::non_negative(Nrestart);
72  err += ParameterCheck::square_non_zero(Stop_cond);
73 
74  if (err) {
75  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
76  exit(EXIT_FAILURE);
77  }
78 
79  //- store values
80  m_Niter = Niter;
81  m_Nrestart = Nrestart;
82  m_Stop_cond = Stop_cond;
83 }
84 
85 
86 //====================================================================
87 void Solver_CG::solve(Field& xq, const Field& b,
88  int& Nconv, double& diff)
89 {
90  double bnorm2 = b.norm2();
91  int bsize = b.size();
92 
93  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
94  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
95  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
96 
97  bool is_converged = false;
98  int Nconv2 = 0;
99  double diff2 = 1.0;
100  double rr;
101 
102  reset_field(b);
103  copy(m_s, b); // s = b;
104  solve_init(b, rr);
105  Nconv2 += 1;
106 
107  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
108 
109 
110  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
111  for (int iter = 0; iter < m_Niter; iter++) {
112  if (rr / bnorm2 < m_Stop_cond) break;
113 
114  solve_step(rr);
115  Nconv2 += 1;
116 
117  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
118  }
119 
120  //- calculate true residual
121  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
122  axpy(m_s, -1.0, b); // s -= b;
123  diff2 = m_s.norm2();
124 
125  if (diff2 / bnorm2 < m_Stop_cond) {
126  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
127  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
128 
129  is_converged = true;
130  break;
131  } else {
132  //- restart with new approximate solution
133  copy(m_s, m_x); // s = x;
134  solve_init(b, rr);
135 
136  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
137  }
138  }
139 
140 
141  if (!is_converged) {
142  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
143  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
144  exit(EXIT_FAILURE);
145  }
146 
147 
148  copy(xq, m_x); // xq = x;
149 
150 #pragma omp barrier
151 #pragma omp master
152  {
153  diff = sqrt(diff2 / bnorm2);
154  Nconv = Nconv2;
155  }
156 #pragma omp barrier
157 }
158 
159 
160 //====================================================================
162 {
163 #pragma omp barrier
164 #pragma omp master
165  {
166  int Nin = b.nin();
167  int Nvol = b.nvol();
168  int Nex = b.nex();
169 
170  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
171  m_s.reset(Nin, Nvol, Nex);
172  m_r.reset(Nin, Nvol, Nex);
173  m_x.reset(Nin, Nvol, Nex);
174  m_p.reset(Nin, Nvol, Nex);
175  }
176  }
177 #pragma omp barrier
178 
179  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
180 }
181 
182 
183 //====================================================================
184 void Solver_CG::solve_init(const Field& b, double& rr)
185 {
186  copy(m_x, m_s); // x = s;
187 
188  // r = b - A x
189  copy(m_r, b); // r = b;
190  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
191  axpy(m_r, -1.0, m_s); // r -= s;
192 
193  copy(m_p, m_r); // p = r;
194  rr = m_r.norm2(); // rr = r * r;
195 }
196 
197 
198 //====================================================================
199 void Solver_CG::solve_step(double& rr)
200 {
201  double rr_prev = rr;
202 
203  m_fopr->mult(m_s, m_p); // s = m_fopr->mult(p);
204 
205  double pap = dot(m_p, m_s); // pap = p * s;
206  double cr = rr_prev / pap;
207 
208  axpy(m_x, cr, m_p); // x += cr * p;
209  axpy(m_r, -cr, m_s); // r -= cr * s;
210 
211  rr = m_r.norm2(); // rr = r * r;
212 
213  double rr_ratio = rr / rr_prev;
214  aypx(rr_ratio, m_p, m_r); // p = (rr / rr_prev) * p + r
215 }
216 
217 
218 //====================================================================
220 {
221  int NPE = CommonParameters::NPE();
223 
224  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
225  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
226  int Nin = m_x.nin();
227  int Nvol = m_x.nvol();
228  int Nex = m_x.nex();
229 
230  double flop_fopr = m_fopr->flop_count() / (Nvol * NPE);
231 
232  if (flop_fopr < eps) {
233  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0.0.\n", class_name.c_str());
234  return 0.0;
235  }
236 
237  double flop_axpy = static_cast<double>(Nin * Nex * 2);
238  double flop_dot = static_cast<double>(Nin * Nex * 2); // (Nin * Nex * 4) for Cmplx
239  double flop_norm = static_cast<double>(Nin * Nex * 2);
240 
241  double flop_init = flop_fopr + flop_axpy + flop_norm;
242  double flop_step = flop_fopr + flop_dot + 3 * flop_axpy + flop_norm;
243  double flop_true_residual = flop_fopr + flop_axpy + flop_norm;
244 
245  double flop = (flop_norm + flop_init + flop_step * m_Nconv_count + flop_true_residual
246  + flop_init * m_Nrestart_count) * (Nvol * NPE);
247 
248 
249  return flop;
250 }
251 
252 
253 //====================================================================
254 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:495
void detailed(const char *format,...)
Definition: bridgeIO.cpp:212
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:441
void solve_init(const Field &, double &)
Definition: solver_CG.cpp:184
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
int m_Nrestart_count
Definition: solver_CG.h:53
void general(const char *format,...)
Definition: bridgeIO.cpp:195
double m_Stop_cond
Definition: solver_CG.h:48
void reset_field(const Field &)
Definition: solver_CG.cpp:161
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Definition: solver_CG.cpp:87
Container of Field-type object.
Definition: field.h:39
static const std::string class_name
Definition: solver_CG.h:41
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:211
void solve_step(double &)
Definition: solver_CG.cpp:199
int nvol() const
Definition: field.h:116
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
Standard Conjugate Gradient solver algorithm.
Definition: solver_CG.h:38
int m_Nrestart
Definition: solver_CG.h:47
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
void set_parameters(const Parameters &params)
Definition: solver_CG.cpp:33
Field m_s
Definition: solver_CG.h:51
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:230
virtual double flop_count()
returns the flops per site.
Definition: fopr.h:121
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:461
Field m_x
Definition: solver_CG.h:51
int nex() const
Definition: field.h:117
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:229
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
Fopr * m_fopr
Definition: solver_CG.h:44
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Base class for linear solver class family.
Definition: solver.h:37
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
Field m_p
Definition: solver_CG.h:51
int non_negative(const int v)
Definition: checker.cpp:21
Field m_r
Definition: solver_CG.h:51
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Base class of fermion operator family.
Definition: fopr.h:47
double flop_count()
Definition: solver_CG.cpp:219
string get_string(const string &key) const
Definition: parameters.cpp:116
int m_Niter
Definition: solver_CG.h:46
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
int m_Nconv_count
Definition: solver_CG.h:54
static int NPE()
int size() const
Definition: field.h:121