Bridge++  Ver. 2.0.2
solver_BiCGStab_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_BiCGStab_Cmplx.h"
15 
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_BiCGStab_Cmplx::register_factory();
19 }
20 #endif
21 
22 const std::string Solver_BiCGStab_Cmplx::class_name = "Solver_BiCGStab_Cmplx";
23 
24 //====================================================================
26 {
27  std::string vlevel;
28  if (!params.fetch_string("verbose_level", vlevel)) {
29  m_vl = vout.set_verbose_level(vlevel);
30  }
31 
32  //- fetch and check input parameters
33  int Niter, Nrestart;
34  double Stop_cond;
35  bool use_init_guess;
36  double Omega_tolerance;
37 
38  int err = 0;
39  err += params.fetch_int("maximum_number_of_iteration", Niter);
40  err += params.fetch_int("maximum_number_of_restart", Nrestart);
41  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
42  err += params.fetch_bool("use_initial_guess", use_init_guess);
43  err += params.fetch_double("Omega_tolerance", Omega_tolerance);
44 
45  if (err) {
46  vout.crucial(m_vl, "Error at %s: input parameter not found.\n",
47  class_name.c_str());
48  exit(EXIT_FAILURE);
49  }
50 
51  // set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
52  // set_parameters_BiCGStab_series(Omega_tolerance);
53  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess, Omega_tolerance);
54 }
55 
56 
57 //====================================================================
59 {
60  params.set_int("maximum_number_of_iteration", m_Niter);
61  params.set_int("maximum_number_of_restart", m_Nrestart);
62  params.set_double("convergence_criterion_squared", m_Stop_cond);
63  params.set_bool("use_initial_guess", m_use_init_guess);
64  params.set_double("Omega_tolerance", m_Omega_tolerance);
65 
66  params.set_string("verbose_level", vout.get_verbose_level(m_vl));
67 }
68 
69 
70 //====================================================================
71 void Solver_BiCGStab_Cmplx::set_parameters(const int Niter, const int Nrestart,
72  const double Stop_cond)
73 {
75 
76  //- print input parameters
77  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
78  vout.general(m_vl, " Niter = %d\n", Niter);
79  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
80  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
81 
82  //- range check
83  int err = 0;
84  err += ParameterCheck::non_negative(Niter);
85  err += ParameterCheck::non_negative(Nrestart);
86  err += ParameterCheck::square_non_zero(Stop_cond);
87 
88  if (err) {
89  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
90  exit(EXIT_FAILURE);
91  }
92 
93  //- store values
94  m_Niter = Niter;
95  m_Nrestart = Nrestart;
96  m_Stop_cond = Stop_cond;
97 }
98 
99 
100 //====================================================================
101 void Solver_BiCGStab_Cmplx::set_parameters(const int Niter, const int Nrestart,
102  const double Stop_cond, const bool use_init_guess)
103 {
105 
106  //- print input parameters
107  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
108  vout.general(m_vl, " Niter = %d\n", Niter);
109  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
110  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
111  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
112 
113  //- range check
114  int err = 0;
115  err += ParameterCheck::non_negative(Niter);
116  err += ParameterCheck::non_negative(Nrestart);
117  err += ParameterCheck::square_non_zero(Stop_cond);
118 
119  if (err) {
120  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
121  exit(EXIT_FAILURE);
122  }
123 
124  //- store values
125  m_Niter = Niter;
126  m_Nrestart = Nrestart;
127  m_Stop_cond = Stop_cond;
128  m_use_init_guess = use_init_guess;
129 }
130 
131 
132 //====================================================================
134 {
136 
137  //- print input parameters
138  vout.general(m_vl, " Omega_tolerance = %8.2e\n", Omega_tolerance);
139 
140  //- range check
141  // NB. Omega_tolerance == 0.0 is allowed.
142 
143  //- store values
144  m_Omega_tolerance = Omega_tolerance;
145 }
146 
147 
148 //====================================================================
150  const int Nrestart,
151  const double Stop_cond,
152  const bool use_init_guess,
153  const double Omega_tolerance)
154 {
156 
157  //- print input parameters
158  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
159  vout.general(m_vl, " Niter = %d\n", Niter);
160  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
161  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
162  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
163  vout.general(m_vl, " Omega_tolerance = %8.2e\n", Omega_tolerance);
164 
165  //- range check
166  int err = 0;
167  err += ParameterCheck::non_negative(Niter);
168  err += ParameterCheck::non_negative(Nrestart);
169  err += ParameterCheck::square_non_zero(Stop_cond);
170 
171  // NB. Omega_tolerance == 0.0 is allowed.
172 
173  if (err) {
174  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
175  exit(EXIT_FAILURE);
176  }
177 
178  //- store values
179  m_Niter = Niter;
180  m_Nrestart = Nrestart;
181  m_Stop_cond = Stop_cond;
182  m_use_init_guess = use_init_guess;
183 
184  m_Omega_tolerance = Omega_tolerance;
185 }
186 
187 
188 //====================================================================
190  int& Nconv, double& diff)
191 {
192  const double bnorm2 = b.norm2();
193  const int bsize = b.size();
194 
195  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
196  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
197  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
198 
199  bool is_converged = false;
200  int Nconv2 = 0;
201  double diff2 = 1.0; // superficial initialization
202  double rr;
203 
204  int Nconv_unit = 1;
205  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
206  // Nconv_unit = 2;
207  // }
208 
209  reset_field(b);
210 
211  if (m_use_init_guess) {
212  copy(m_s, xq); // s = xq;
213  } else {
214  copy(m_s, b); // s = b;
215  }
216  solve_init(b, rr);
217  Nconv2 += Nconv_unit;
218 
219  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
220 
221 
222  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
223  for (int iter = 0; iter < m_Niter; iter++) {
224  if (rr / bnorm2 < m_Stop_cond) break;
225 
226  solve_step(rr);
227  Nconv2 += 2 * Nconv_unit;
228 
229  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
230  }
231 
232  //- calculate true residual
233  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
234  axpy(m_s, -1.0, b); // s -= b;
235  diff2 = m_s.norm2();
236 
237  if (diff2 / bnorm2 < m_Stop_cond) {
238  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
239  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
240 
241  is_converged = true;
242 
243  m_Nrestart_count = i_restart;
244  m_Nconv_count = Nconv2;
245 
246  break;
247  } else {
248  //- restart with new approximate solution
249  copy(m_s, m_x); // s = x;
250  solve_init(b, rr);
251 
252  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
253  }
254  }
255 
256 
257  if (!is_converged) {
258  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
259  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
260  exit(EXIT_FAILURE);
261  }
262 
263 
264  copy(xq, m_x); // xq = x;
265 
266 #pragma omp barrier
267 #pragma omp master
268  {
269  diff = sqrt(diff2 / bnorm2);
270  Nconv = Nconv2;
271  }
272 #pragma omp barrier
273 }
274 
275 
276 //====================================================================
278 {
279 #pragma omp barrier
280 #pragma omp master
281  {
282  const int Nin = b.nin();
283  const int Nvol = b.nvol();
284  const int Nex = b.nex();
285 
286  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
287  m_s.reset(Nin, Nvol, Nex);
288  m_r.reset(Nin, Nvol, Nex);
289  m_x.reset(Nin, Nvol, Nex);
290  m_p.reset(Nin, Nvol, Nex);
291  m_v.reset(Nin, Nvol, Nex);
292  m_t.reset(Nin, Nvol, Nex);
293  m_rh.reset(Nin, Nvol, Nex);
294  }
295  }
296 #pragma omp barrier
297 
298  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
299 }
300 
301 
302 //====================================================================
303 void Solver_BiCGStab_Cmplx::solve_init(const Field& b, double& rr)
304 {
305  copy(m_x, m_s); // x = s;
306 
307  // r = b - A x_0
308  m_fopr->mult(m_v, m_s); // v = m_fopr->mult(s);
309  copy(m_r, b); // r = b;
310  axpy(m_r, -1.0, m_v); // r -= v;
311  copy(m_rh, m_r); // rh = r;
312 
313  rr = m_r.norm2(); // rr = r * r;
314 
315  m_p.set(0.0); // p = 0.0
316  m_v.set(0.0); // v = 0.0
317 
318 #pragma omp barrier
319 #pragma omp master
320  {
321  m_rho_prev = cmplx(1.0, 0.0);
322  m_alpha_prev = cmplx(1.0, 0.0);
323  m_omega_prev = cmplx(1.0, 0.0);
324  }
325 #pragma omp barrier
326 }
327 
328 
329 //====================================================================
331 {
332  const dcomplex rho = dotc(m_rh, m_r); // rho = rh * r;
333  const dcomplex beta = rho * m_alpha_prev / (m_rho_prev * m_omega_prev);
334 
335  // p = r + beta * (p - m_omega_prev * v);
336  axpy(m_p, -m_omega_prev, m_v); // p += - m_omega_prev * v;
337  aypx(beta, m_p, m_r); // p = beta * p + r;
338 
339  m_fopr->mult(m_v, m_p); // v = m_fopr->mult(p);
340 
341  const dcomplex aden = dotc(m_rh, m_v); // aden = rh * v;
342  const dcomplex alpha = rho / aden;
343 
344  copy(m_s, m_r); // s = r
345  axpy(m_s, -alpha, m_v); // s += - alpha * v;
346 
347  m_fopr->mult(m_t, m_s); // t = m_fopr->mult(s);
348 
349  const dcomplex omega_numer = dotc(m_t, m_s); // omega_numer = t * s;
350  const double omega_denom = dot(m_t, m_t); // omega_denom = t * t;
351 
352  dcomplex omega = omega_numer / omega_denom;
353 
354  const double s_norm2 = m_s.norm2();
355 
356  //- a prescription to improve stability of BiCGStab
357  const double abs_rho = abs(omega_numer) / sqrt(omega_denom * s_norm2);
358  if (abs_rho < m_Omega_tolerance) {
359  omega *= m_Omega_tolerance / abs_rho;
360  }
361 
362  axpy(m_x, omega, m_s); // x += omega * s;
363  axpy(m_x, alpha, m_p); // x += alpha * p;
364 
365  copy(m_r, m_s); // r = s
366  axpy(m_r, -omega, m_t); // r += - omega * t;
367 
368  rr = m_r.norm2(); // rr = r * r;
369 
370 #pragma omp barrier
371 #pragma omp master
372  {
373  m_rho_prev = rho;
374  m_alpha_prev = alpha;
375  m_omega_prev = omega;
376  }
377 #pragma omp barrier
378 }
379 
380 
381 //====================================================================
383 {
384  const int NPE = CommonParameters::NPE();
385 
386  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
387  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
388  const int Nin = m_x.nin();
389  const int Nvol = m_x.nvol();
390  const int Nex = m_x.nex();
391 
392  const double gflop_fopr = m_fopr->flop_count();
393 
394  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
395  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
396  return 0.0;
397  }
398 
399  const double gflop_axpy = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
400  const double gflop_dotc = (Nin * Nex * 4) * ((Nvol * NPE) / 1.0e+9);
401  const double gflop_norm = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
402 
403  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
404  const double gflop_step = 2 * gflop_fopr + 3 * gflop_dotc + 6 * gflop_axpy + 2 * gflop_norm;
405  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
406 
407  const int N_iter = (m_Nconv_count - 1) / 2;
408  const double gflop = gflop_norm + gflop_init + gflop_step * N_iter + gflop_true_residual * (m_Nrestart_count + 1)
409  + gflop_init * m_Nrestart_count;
410 
411 
412  return gflop;
413 }
414 
415 
416 //====================================================================
417 //============================================================END=====
Parameters::set_bool
void set_bool(const string &key, const bool value)
Definition: parameters.cpp:30
Solver_BiCGStab_Cmplx::m_rho_prev
dcomplex m_rho_prev
Definition: solver_BiCGStab_Cmplx.h:59
Solver_BiCGStab_Cmplx::class_name
static const std::string class_name
Definition: solver_BiCGStab_Cmplx.h:45
Solver_BiCGStab_Cmplx::reset_field
void reset_field(const Field &)
Definition: solver_BiCGStab_Cmplx.cpp:277
Parameters::set_string
void set_string(const string &key, const string &value)
Definition: parameters.cpp:39
Solver_BiCGStab_Cmplx::m_Stop_cond
double m_Stop_cond
Definition: solver_BiCGStab_Cmplx.h:53
Solver_BiCGStab_Cmplx::m_s
Field m_s
Definition: solver_BiCGStab_Cmplx.h:60
Solver_BiCGStab_Cmplx::m_r
Field m_r
Definition: solver_BiCGStab_Cmplx.h:60
Solver_BiCGStab_Cmplx::m_t
Field m_t
Definition: solver_BiCGStab_Cmplx.h:60
Field::set
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:175
AFopr::mult
virtual void mult(AFIELD &, const AFIELD &)
multiplies fermion operator to a given field.
Definition: afopr.h:95
Parameters
Class for parameters.
Definition: parameters.h:46
Solver_BiCGStab_Cmplx::m_use_init_guess
bool m_use_init_guess
Definition: solver_BiCGStab_Cmplx.h:54
Solver_BiCGStab_Cmplx::m_alpha_prev
dcomplex m_alpha_prev
Definition: solver_BiCGStab_Cmplx.h:59
Parameters::set_double
void set_double(const string &key, const double value)
Definition: parameters.cpp:33
Solver_BiCGStab_Cmplx::m_rh
Field m_rh
Definition: solver_BiCGStab_Cmplx.h:60
Bridge::BridgeIO::detailed
void detailed(const char *format,...)
Definition: bridgeIO.cpp:219
Field::nex
int nex() const
Definition: field.h:128
Solver_BiCGStab_Cmplx::get_parameters
void get_parameters(Parameters &params) const
Definition: solver_BiCGStab_Cmplx.cpp:58
aypx
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:509
Solver_BiCGStab_Cmplx::set_parameters
void set_parameters(const Parameters &params)
Definition: solver_BiCGStab_Cmplx.cpp:25
axpy
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:380
dot
double dot(const Field &y, const Field &x)
Definition: field.cpp:576
Solver_BiCGStab_Cmplx::solve
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Definition: solver_BiCGStab_Cmplx.cpp:189
ParameterCheck::non_negative
int non_negative(const int v)
Definition: parameterCheck.cpp:21
Field::nin
int nin() const
Definition: field.h:126
Parameters::fetch_bool
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
copy
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:212
Bridge::BridgeIO::paranoiac
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:238
Solver_BiCGStab_Cmplx::m_Nconv_count
int m_Nconv_count
Definition: solver_BiCGStab_Cmplx.h:63
Field::norm2
double norm2() const
Definition: field.cpp:113
Solver_BiCGStab_Cmplx::m_Nrestart
int m_Nrestart
Definition: solver_BiCGStab_Cmplx.h:52
Solver_BiCGStab_Cmplx::solve_step
void solve_step(double &)
Definition: solver_BiCGStab_Cmplx.cpp:330
Solver_BiCGStab_Cmplx::m_x
Field m_x
Definition: solver_BiCGStab_Cmplx.h:60
Field::size
int size() const
Definition: field.h:132
Solver_BiCGStab_Cmplx::solve_init
void solve_init(const Field &, double &)
Definition: solver_BiCGStab_Cmplx.cpp:303
AFopr::flop_count
virtual double flop_count()
returns the number of floating point operations.
Definition: afopr.h:160
Solver_BiCGStab_Cmplx::m_p
Field m_p
Definition: solver_BiCGStab_Cmplx.h:60
ParameterCheck::square_non_zero
int square_non_zero(const double v)
Definition: parameterCheck.cpp:43
Solver_BiCGStab_Cmplx::set_parameters_BiCGStab_series
DEPRECATED void set_parameters_BiCGStab_series(const double Omega_tolerance)
Definition: solver_BiCGStab_Cmplx.cpp:133
Solver_BiCGStab_Cmplx::m_omega_prev
dcomplex m_omega_prev
Definition: solver_BiCGStab_Cmplx.h:59
Field::nvol
int nvol() const
Definition: field.h:127
dotc
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:712
CommonParameters::NPE
static int NPE()
Definition: commonParameters.h:101
Field::reset
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
Solver_BiCGStab_Cmplx::flop_count
double flop_count()
Definition: solver_BiCGStab_Cmplx.cpp:382
Bridge::BridgeIO::set_verbose_level
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:133
solver_BiCGStab_Cmplx.h
Solver_BiCGStab_Cmplx::m_Niter
int m_Niter
Definition: solver_BiCGStab_Cmplx.h:52
Solver_BiCGStab_Cmplx::m_Omega_tolerance
double m_Omega_tolerance
Definition: solver_BiCGStab_Cmplx.h:56
Parameters::set_int
void set_int(const string &key, const int value)
Definition: parameters.cpp:36
Parameters::fetch_string
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
Parameters::fetch_double
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
Bridge::BridgeIO::crucial
void crucial(const char *format,...)
Definition: bridgeIO.cpp:180
Solver_BiCGStab_Cmplx::m_Nrestart_count
int m_Nrestart_count
Definition: solver_BiCGStab_Cmplx.h:62
Field
Container of Field-type object.
Definition: field.h:46
Solver_BiCGStab_Cmplx::m_vl
Bridge::VerboseLevel m_vl
Definition: solver_BiCGStab_Cmplx.h:48
Solver_BiCGStab_Cmplx::m_fopr
Fopr * m_fopr
Definition: solver_BiCGStab_Cmplx.h:50
Parameters::fetch_int
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
Bridge::BridgeIO::general
void general(const char *format,...)
Definition: bridgeIO.cpp:200
ThreadManager::assert_single_thread
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Definition: threadManager.cpp:372
Solver_BiCGStab_Cmplx::m_v
Field m_v
Definition: solver_BiCGStab_Cmplx.h:60
Bridge::vout
BridgeIO vout
Definition: bridgeIO.cpp:512
Bridge::BridgeIO::get_verbose_level
static std::string get_verbose_level(const VerboseLevel vl)
Definition: bridgeIO.cpp:154
CommonParameters::epsilon_criterion
static double epsilon_criterion()
Definition: commonParameters.h:119