Bridge++  Ver. 2.0.2
solver_CG.cpp
Go to the documentation of this file.
1 
14 #include "solver_CG.h"
15 
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_CG::register_factory();
19 }
20 #endif
21 
22 const std::string Solver_CG::class_name = "Solver_CG";
23 
24 //====================================================================
26 {
27  std::string vlevel;
28  if (!params.fetch_string("verbose_level", vlevel)) {
29  m_vl = vout.set_verbose_level(vlevel);
30  }
31 
32  //- fetch and check input parameters
33  int Niter, Nrestart;
34  double Stop_cond;
35  bool use_init_guess;
36 
37  int err = 0;
38  err += params.fetch_int("maximum_number_of_iteration", Niter);
39  err += params.fetch_int("maximum_number_of_restart", Nrestart);
40  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
41  err += params.fetch_bool("use_initial_guess", use_init_guess);
42 
43  if (err) {
44  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
45  exit(EXIT_FAILURE);
46  }
47 
48  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
49 }
50 
51 
52 //====================================================================
54 {
55  params.set_int("maximum_number_of_iteration", m_Niter);
56  params.set_int("maximum_number_of_restart", m_Nrestart);
57  params.set_double("convergence_criterion_squared", m_Stop_cond);
58  params.set_bool("use_initial_guess", m_use_init_guess);
59 
60  params.set_string("verbose_level", vout.get_verbose_level(m_vl));
61 }
62 
63 
64 //====================================================================
65 void Solver_CG::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
66 {
68 
69  //- print input parameters
70  vout.general(m_vl, "%s:\n", class_name.c_str());
71  vout.general(m_vl, " Niter = %d\n", Niter);
72  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
73  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
74 
75  //- range check
76  int err = 0;
77  err += ParameterCheck::non_negative(Niter);
78  err += ParameterCheck::non_negative(Nrestart);
79  err += ParameterCheck::square_non_zero(Stop_cond);
80 
81  if (err) {
82  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
83  exit(EXIT_FAILURE);
84  }
85 
86  //- store values
87  m_Niter = Niter;
88  m_Nrestart = Nrestart;
89  m_Stop_cond = Stop_cond;
90 }
91 
92 
93 //====================================================================
94 void Solver_CG::set_parameters(const int Niter, const int Nrestart, const double Stop_cond, const bool use_init_guess)
95 {
97 
98  //- print input parameters
99  vout.general(m_vl, "%s:\n", class_name.c_str());
100  vout.general(m_vl, " Niter = %d\n", Niter);
101  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
102  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
103  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
104 
105  //- range check
106  int err = 0;
107  err += ParameterCheck::non_negative(Niter);
108  err += ParameterCheck::non_negative(Nrestart);
109  err += ParameterCheck::square_non_zero(Stop_cond);
110 
111  if (err) {
112  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
113  exit(EXIT_FAILURE);
114  }
115 
116  //- store values
117  m_Niter = Niter;
118  m_Nrestart = Nrestart;
119  m_Stop_cond = Stop_cond;
120  m_use_init_guess = use_init_guess;
121 }
122 
123 
124 //====================================================================
125 void Solver_CG::solve(Field& xq, const Field& b,
126  int& Nconv, double& diff)
127 {
128  const double bnorm2 = b.norm2();
129  const int bsize = b.size();
130 
131  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
132  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
133  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
134 
135  bool is_converged = false;
136  int Nconv2 = 0;
137  double diff2 = 1.0; // superficial initialization
138  double rr;
139 
140  int Nconv_unit = 1;
141  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
142  // Nconv_unit = 2;
143  // }
144 
145  reset_field(b);
146 
147  if (m_use_init_guess) {
148  copy(m_s, xq); // s = xq;
149  } else {
150  copy(m_s, b); // s = b;
151  }
152  solve_init(b, rr);
153  Nconv2 += Nconv_unit;
154 
155  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
156 
157 
158  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
159  for (int iter = 0; iter < m_Niter; iter++) {
160  if (rr / bnorm2 < m_Stop_cond) break;
161 
162  solve_step(rr);
163  Nconv2 += Nconv_unit;
164 
165  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
166  }
167 
168  //- calculate true residual
169  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
170  axpy(m_s, -1.0, b); // s -= b;
171  diff2 = m_s.norm2();
172 
173  if (diff2 / bnorm2 < m_Stop_cond) {
174  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
175  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
176 
177  is_converged = true;
178 
179  m_Nrestart_count = i_restart;
180  m_Nconv_count = Nconv2;
181 
182  break;
183  } else {
184  //- restart with new approximate solution
185  copy(m_s, m_x); // s = x;
186  solve_init(b, rr);
187 
188  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
189  }
190  }
191 
192 
193  if (!is_converged) {
194  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
195  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
196  exit(EXIT_FAILURE);
197  }
198 
199 
200  copy(xq, m_x); // xq = x;
201 
202 #pragma omp barrier
203 #pragma omp master
204  {
205  diff = sqrt(diff2 / bnorm2);
206  Nconv = Nconv2;
207  }
208 #pragma omp barrier
209 }
210 
211 
212 //====================================================================
214 {
215 #pragma omp barrier
216 #pragma omp master
217  {
218  const int Nin = b.nin();
219  const int Nvol = b.nvol();
220  const int Nex = b.nex();
221 
222  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
223  m_s.reset(Nin, Nvol, Nex);
224  m_r.reset(Nin, Nvol, Nex);
225  m_x.reset(Nin, Nvol, Nex);
226  m_p.reset(Nin, Nvol, Nex);
227  }
228  }
229 #pragma omp barrier
230 
231  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
232 }
233 
234 
235 //====================================================================
236 void Solver_CG::solve_init(const Field& b, double& rr)
237 {
238  copy(m_x, m_s); // x = s;
239 
240  // r = b - A x
241  copy(m_r, b); // r = b;
242  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
243  axpy(m_r, -1.0, m_s); // r -= s;
244 
245  copy(m_p, m_r); // p = r;
246  rr = m_r.norm2(); // rr = r * r;
247 }
248 
249 
250 //====================================================================
251 void Solver_CG::solve_step(double& rr)
252 {
253  const double rr_prev = rr;
254 
255  m_fopr->mult(m_s, m_p); // s = m_fopr->mult(p);
256 
257  const double pap = dot(m_p, m_s); // pap = p * s;
258  const double cr = rr_prev / pap;
259 
260  axpy(m_x, cr, m_p); // x += cr * p;
261  axpy(m_r, -cr, m_s); // r -= cr * s;
262 
263  rr = m_r.norm2(); // rr = r * r;
264 
265  const double rr_ratio = rr / rr_prev;
266  aypx(rr_ratio, m_p, m_r); // p = (rr / rr_prev) * p + r
267 }
268 
269 
270 //====================================================================
272 {
273  const int NPE = CommonParameters::NPE();
274 
275  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
276  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
277  const int Nin = m_x.nin();
278  const int Nvol = m_x.nvol();
279  const int Nex = m_x.nex();
280  const int e_type = m_x.field_element_type();
281 
282 
283  const double gflop_fopr = m_fopr->flop_count();
284 
285  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
286  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
287  return 0.0;
288  }
289 
290  const double gflop_axpy = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
291  const double gflop_norm = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
292 
293  double gflop_dot = 0.0; // superficial initialization
294  if (e_type == Element_type::REAL) { // element_type REAL = 1
295  gflop_dot = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
296  } else if (e_type == Element_type::COMPLEX) { // element_type COMPLEX = 2
297  gflop_dot = Nin * Nex * 4 * ((Nvol * NPE) / 1.0e+9);
298  } // NB. The other e_type is not allowed by construction in field.h
299 
300  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
301  const double gflop_step = gflop_fopr + gflop_dot + 3 * gflop_axpy + gflop_norm;
302  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
303 
304  const double gflop = gflop_norm + gflop_init
305  + gflop_step * (m_Nconv_count - 1)
306  + gflop_true_residual * (m_Nrestart_count + 1)
307  + gflop_init * m_Nrestart_count;
308 
309  return gflop;
310 }
311 
312 
313 //====================================================================
314 //============================================================END=====
Parameters::set_bool
void set_bool(const string &key, const bool value)
Definition: parameters.cpp:30
Solver_CG::m_fopr
Fopr * m_fopr
Definition: solver_CG.h:47
Parameters::set_string
void set_string(const string &key, const string &value)
Definition: parameters.cpp:39
Solver_CG::m_Nrestart
int m_Nrestart
Definition: solver_CG.h:50
Solver_CG::m_vl
Bridge::VerboseLevel m_vl
Definition: solver_CG.h:44
AFopr::mult
virtual void mult(AFIELD &, const AFIELD &)
multiplies fermion operator to a given field.
Definition: afopr.h:95
Parameters
Class for parameters.
Definition: parameters.h:46
Parameters::set_double
void set_double(const string &key, const double value)
Definition: parameters.cpp:33
Solver_CG::reset_field
void reset_field(const Field &)
Definition: solver_CG.cpp:213
Bridge::BridgeIO::detailed
void detailed(const char *format,...)
Definition: bridgeIO.cpp:219
Field::nex
int nex() const
Definition: field.h:128
aypx
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:509
Solver_CG::get_parameters
void get_parameters(Parameters &params) const
Definition: solver_CG.cpp:53
axpy
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:380
dot
double dot(const Field &y, const Field &x)
Definition: field.cpp:576
Solver_CG::solve
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Definition: solver_CG.cpp:125
ParameterCheck::non_negative
int non_negative(const int v)
Definition: parameterCheck.cpp:21
Field::nin
int nin() const
Definition: field.h:126
Parameters::fetch_bool
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
copy
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:212
Solver_CG::m_Nrestart_count
int m_Nrestart_count
Definition: solver_CG.h:57
Bridge::BridgeIO::paranoiac
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:238
Field::norm2
double norm2() const
Definition: field.cpp:113
Solver_CG::m_s
Field m_s
Definition: solver_CG.h:55
Field::size
int size() const
Definition: field.h:132
AFopr::flop_count
virtual double flop_count()
returns the number of floating point operations.
Definition: afopr.h:160
Solver_CG::class_name
static const std::string class_name
Definition: solver_CG.h:41
ParameterCheck::square_non_zero
int square_non_zero(const double v)
Definition: parameterCheck.cpp:43
Solver_CG::m_x
Field m_x
Definition: solver_CG.h:55
Element_type::REAL
@ REAL
Definition: bridge_defs.h:43
Field::nvol
int nvol() const
Definition: field.h:127
CommonParameters::NPE
static int NPE()
Definition: commonParameters.h:101
Field::reset
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
Field::field_element_type
element_type field_element_type() const
Definition: field.h:129
Solver_CG::m_p
Field m_p
Definition: solver_CG.h:55
Element_type::COMPLEX
@ COMPLEX
Definition: bridge_defs.h:43
Bridge::BridgeIO::set_verbose_level
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:133
Solver_CG::m_Niter
int m_Niter
Definition: solver_CG.h:49
Parameters::set_int
void set_int(const string &key, const int value)
Definition: parameters.cpp:36
Solver_CG::m_Stop_cond
double m_Stop_cond
Definition: solver_CG.h:51
Parameters::fetch_string
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
Parameters::fetch_double
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
Solver_CG::solve_init
void solve_init(const Field &, double &)
Definition: solver_CG.cpp:236
Solver_CG::set_parameters
void set_parameters(const Parameters &params)
Definition: solver_CG.cpp:25
Solver_CG::solve_step
void solve_step(double &)
Definition: solver_CG.cpp:251
solver_CG.h
Bridge::BridgeIO::crucial
void crucial(const char *format,...)
Definition: bridgeIO.cpp:180
Field
Container of Field-type object.
Definition: field.h:46
Solver_CG::flop_count
double flop_count()
Definition: solver_CG.cpp:271
Parameters::fetch_int
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
Solver_CG::m_Nconv_count
int m_Nconv_count
Definition: solver_CG.h:58
Bridge::BridgeIO::general
void general(const char *format,...)
Definition: bridgeIO.cpp:200
ThreadManager::assert_single_thread
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Definition: threadManager.cpp:372
Bridge::vout
BridgeIO vout
Definition: bridgeIO.cpp:512
Bridge::BridgeIO::get_verbose_level
static std::string get_verbose_level(const VerboseLevel vl)
Definition: bridgeIO.cpp:154
Solver_CG::m_use_init_guess
bool m_use_init_guess
Definition: solver_CG.h:52
Solver_CG::m_r
Field m_r
Definition: solver_CG.h:55
CommonParameters::epsilon_criterion
static double epsilon_criterion()
Definition: commonParameters.h:119