Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_CG.cpp
Go to the documentation of this file.
1 
14 #include "solver_CG.h"
15 
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_CG::register_factory();
19 }
20 #endif
21 
22 const std::string Solver_CG::class_name = "Solver_CG";
23 
24 //====================================================================
26 {
27  const string str_vlevel = params.get_string("verbose_level");
28 
29  m_vl = vout.set_verbose_level(str_vlevel);
30 
31  //- fetch and check input parameters
32  int Niter, Nrestart;
33  double Stop_cond;
34  bool use_init_guess;
35 
36  int err = 0;
37  err += params.fetch_int("maximum_number_of_iteration", Niter);
38  err += params.fetch_int("maximum_number_of_restart", Nrestart);
39  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
40  err += params.fetch_bool("use_initial_guess", use_init_guess);
41 
42  if (err) {
43  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
44  exit(EXIT_FAILURE);
45  }
46 
47  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
48 }
49 
50 
51 //====================================================================
52 void Solver_CG::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
53 {
55 
56  //- print input parameters
57  vout.general(m_vl, "%s:\n", class_name.c_str());
58  vout.general(m_vl, " Niter = %d\n", Niter);
59  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
60  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
61 
62  //- range check
63  int err = 0;
64  err += ParameterCheck::non_negative(Niter);
65  err += ParameterCheck::non_negative(Nrestart);
66  err += ParameterCheck::square_non_zero(Stop_cond);
67 
68  if (err) {
69  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
70  exit(EXIT_FAILURE);
71  }
72 
73  //- store values
74  m_Niter = Niter;
75  m_Nrestart = Nrestart;
76  m_Stop_cond = Stop_cond;
77 }
78 
79 
80 //====================================================================
81 void Solver_CG::set_parameters(const int Niter, const int Nrestart, const double Stop_cond, const bool use_init_guess)
82 {
84 
85  //- print input parameters
86  vout.general(m_vl, "%s:\n", class_name.c_str());
87  vout.general(m_vl, " Niter = %d\n", Niter);
88  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
89  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
90  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
91 
92  //- range check
93  int err = 0;
94  err += ParameterCheck::non_negative(Niter);
95  err += ParameterCheck::non_negative(Nrestart);
96  err += ParameterCheck::square_non_zero(Stop_cond);
97 
98  if (err) {
99  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
100  exit(EXIT_FAILURE);
101  }
102 
103  //- store values
104  m_Niter = Niter;
105  m_Nrestart = Nrestart;
106  m_Stop_cond = Stop_cond;
107  m_use_init_guess = use_init_guess;
108 }
109 
110 
111 //====================================================================
112 void Solver_CG::solve(Field& xq, const Field& b,
113  int& Nconv, double& diff)
114 {
115  const double bnorm2 = b.norm2();
116  const int bsize = b.size();
117 
118  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
119  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
120  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
121 
122  bool is_converged = false;
123  int Nconv2 = 0;
124  double diff2 = 1.0; // superficial initialization
125  double rr;
126 
127  int Nconv_unit = 1;
128  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
129  // Nconv_unit = 2;
130  // }
131 
132  reset_field(b);
133 
134  if (m_use_init_guess) {
135  copy(m_s, xq); // s = xq;
136  } else {
137  copy(m_s, b); // s = b;
138  }
139  solve_init(b, rr);
140  Nconv2 += Nconv_unit;
141 
142  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
143 
144 
145  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
146  for (int iter = 0; iter < m_Niter; iter++) {
147  if (rr / bnorm2 < m_Stop_cond) break;
148 
149  solve_step(rr);
150  Nconv2 += Nconv_unit;
151 
152  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
153  }
154 
155  //- calculate true residual
156  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
157  axpy(m_s, -1.0, b); // s -= b;
158  diff2 = m_s.norm2();
159 
160  if (diff2 / bnorm2 < m_Stop_cond) {
161  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
162  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
163 
164  is_converged = true;
165 
166  m_Nrestart_count = i_restart;
167  m_Nconv_count = Nconv2;
168 
169  break;
170  } else {
171  //- restart with new approximate solution
172  copy(m_s, m_x); // s = x;
173  solve_init(b, rr);
174 
175  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
176  }
177  }
178 
179 
180  if (!is_converged) {
181  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
182  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
183  exit(EXIT_FAILURE);
184  }
185 
186 
187  copy(xq, m_x); // xq = x;
188 
189 #pragma omp barrier
190 #pragma omp master
191  {
192  diff = sqrt(diff2 / bnorm2);
193  Nconv = Nconv2;
194  }
195 #pragma omp barrier
196 }
197 
198 
199 //====================================================================
201 {
202 #pragma omp barrier
203 #pragma omp master
204  {
205  const int Nin = b.nin();
206  const int Nvol = b.nvol();
207  const int Nex = b.nex();
208 
209  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
210  m_s.reset(Nin, Nvol, Nex);
211  m_r.reset(Nin, Nvol, Nex);
212  m_x.reset(Nin, Nvol, Nex);
213  m_p.reset(Nin, Nvol, Nex);
214  }
215  }
216 #pragma omp barrier
217 
218  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
219 }
220 
221 
222 //====================================================================
223 void Solver_CG::solve_init(const Field& b, double& rr)
224 {
225  copy(m_x, m_s); // x = s;
226 
227  // r = b - A x
228  copy(m_r, b); // r = b;
229  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
230  axpy(m_r, -1.0, m_s); // r -= s;
231 
232  copy(m_p, m_r); // p = r;
233  rr = m_r.norm2(); // rr = r * r;
234 }
235 
236 
237 //====================================================================
238 void Solver_CG::solve_step(double& rr)
239 {
240  const double rr_prev = rr;
241 
242  m_fopr->mult(m_s, m_p); // s = m_fopr->mult(p);
243 
244  const double pap = dot(m_p, m_s); // pap = p * s;
245  const double cr = rr_prev / pap;
246 
247  axpy(m_x, cr, m_p); // x += cr * p;
248  axpy(m_r, -cr, m_s); // r -= cr * s;
249 
250  rr = m_r.norm2(); // rr = r * r;
251 
252  const double rr_ratio = rr / rr_prev;
253  aypx(rr_ratio, m_p, m_r); // p = (rr / rr_prev) * p + r
254 }
255 
256 
257 //====================================================================
259 {
260  const int NPE = CommonParameters::NPE();
261 
262  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
263  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
264  const int Nin = m_x.nin();
265  const int Nvol = m_x.nvol();
266  const int Nex = m_x.nex();
267  const int e_type = m_x.field_element_type();
268 
269 
270  const double gflop_fopr = m_fopr->flop_count();
271 
272  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
273  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
274  return 0.0;
275  }
276 
277  const double gflop_axpy = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
278  const double gflop_norm = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
279 
280  double gflop_dot = 0.0; // superficial initialization
281  if (e_type == Element_type::REAL) { // element_type REAL = 1
282  gflop_dot = Nin * Nex * 2 * ((Nvol * NPE) / 1.0e+9);
283  } else if (e_type == Element_type::COMPLEX) { // element_type COMPLEX = 2
284  gflop_dot = Nin * Nex * 4 * ((Nvol * NPE) / 1.0e+9);
285  } // NB. The other e_type is not allowed by construction in field.h
286 
287  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
288  const double gflop_step = gflop_fopr + gflop_dot + 3 * gflop_axpy + gflop_norm;
289  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
290 
291  const double gflop = gflop_norm + gflop_init
292  + gflop_step * (m_Nconv_count - 1)
293  + gflop_true_residual * (m_Nrestart_count + 1)
294  + gflop_init * m_Nrestart_count;
295 
296  return gflop;
297 }
298 
299 
300 //====================================================================
301 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:503
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:592
void solve_init(const Field &, double &)
Definition: solver_CG.cpp:223
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
int m_Nrestart_count
Definition: solver_CG.h:54
void general(const char *format,...)
Definition: bridgeIO.cpp:197
double m_Stop_cond
Definition: solver_CG.h:48
void reset_field(const Field &)
Definition: solver_CG.cpp:200
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Definition: solver_CG.cpp:112
Container of Field-type object.
Definition: field.h:45
static const std::string class_name
Definition: solver_CG.h:41
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
void solve_step(double &)
Definition: solver_CG.cpp:238
int nvol() const
Definition: field.h:127
bool m_use_init_guess
Definition: solver_CG.h:49
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:532
int m_Nrestart
Definition: solver_CG.h:47
int square_non_zero(const double v)
int nin() const
Definition: field.h:126
void set_parameters(const Parameters &params)
Definition: solver_CG.cpp:25
Field m_s
Definition: solver_CG.h:52
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
virtual double flop_count()
returns the floating-point operation count in units of 10^9 (GFlop)
Definition: fopr.h:120
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:612
Field m_x
Definition: solver_CG.h:52
int nex() const
Definition: field.h:128
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:235
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:319
Fopr * m_fopr
Definition: solver_CG.h:44
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
virtual void mult(Field &, const Field &)=0
applies the fermion operator to the given field (2nd argument) and stores the result in the 1st argument
element_type field_element_type() const
Definition: field.h:129
Field m_p
Definition: solver_CG.h:52
int non_negative(const int v)
Field m_r
Definition: solver_CG.h:52
static void assert_single_thread(const std::string &class_name)
Asserts that the code is currently running on a single thread.
double flop_count()
Definition: solver_CG.cpp:258
string get_string(const string &key) const
Definition: parameters.cpp:221
int m_Niter
Definition: solver_CG.h:46
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
int m_Nconv_count
Definition: solver_CG.h:55
int size() const
Definition: field.h:132