Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_L_Cmplx.cpp
Go to the documentation of this file.
1 
15 
// Optional self-registration: when USE_FACTORY_AUTOREGISTER is defined, the
// initializer of this file-local variable calls register_factory().
// `init` is not referenced anywhere else in this listing.
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_BiCGStab_L_Cmplx::register_factory();
19 }
20 #endif
21 
// Class name string; used below as the prefix of log and error messages.
22 const std::string Solver_BiCGStab_L_Cmplx::class_name = "Solver_BiCGStab_L_Cmplx";
23 
24 //====================================================================
26 {
    // Configures the solver from `params`. (The signature line of this
    // definition is not part of this listing.)
    //  - "verbose_level" is read with get_string() and converted into m_vl.
    //  - The six keys below are read with fetch_*(); the return values are
    //    summed, and a non-zero sum is reported as a missing input parameter
    //    and terminates the program with EXIT_FAILURE.
    //  - The fetched values are then passed on to set_parameters(),
    //    set_parameters_BiCGStab_series() and set_parameters_L(), which
    //    perform the range checks and store the values.
27  const string str_vlevel = params.get_string("verbose_level");
28 
29  m_vl = vout.set_verbose_level(str_vlevel);
30 
31  //- fetch and check input parameters
32  int Niter, Nrestart;
33  double Stop_cond;
34  bool use_init_guess;
35  double Omega_tolerance;
36  int N_L;
37 
    // err accumulates the status codes returned by the fetch_*() calls.
38  int err = 0;
39  err += params.fetch_int("maximum_number_of_iteration", Niter);
40  err += params.fetch_int("maximum_number_of_restart", Nrestart);
41  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
42  err += params.fetch_bool("use_initial_guess", use_init_guess);
43  err += params.fetch_double("Omega_tolerance", Omega_tolerance);
44  err += params.fetch_int("number_of_orthonormal_vectors", N_L);
45 
46  if (err) {
47  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
48  exit(EXIT_FAILURE);
49  }
50 
    // Hand the raw values to the individual setters (four-argument overload
    // of set_parameters, so use_init_guess is stored as well).
51  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
52  set_parameters_BiCGStab_series(Omega_tolerance);
53  set_parameters_L(N_L);
54 }
55 
56 
57 //====================================================================
// Sets the iteration limits and the stopping criterion.
//   Niter     : stored in m_Niter;     must pass ParameterCheck::non_negative.
//   Nrestart  : stored in m_Nrestart;  must pass ParameterCheck::non_negative.
//   Stop_cond : stored in m_Stop_cond; must pass ParameterCheck::square_non_zero.
// The values are logged first; a failed range check is reported through
// vout.crucial and terminates the program with EXIT_FAILURE.
// This overload does not assign m_use_init_guess.
// NOTE(review): the embedded listing numbers jump from 59 to 61 at the top of
// the body, so one source line is not visible in this view.
58 void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
59 {
61 
62  //- print input parameters
63  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
64  vout.general(m_vl, " Niter = %d\n", Niter);
65  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
66  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
67 
68  //- range check
    // Each check contributes to err; any non-zero total is treated as failure.
69  int err = 0;
70  err += ParameterCheck::non_negative(Niter);
71  err += ParameterCheck::non_negative(Nrestart);
72  err += ParameterCheck::square_non_zero(Stop_cond);
73 
74  if (err) {
75  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
76  exit(EXIT_FAILURE);
77  }
78 
79  //- store values
80  m_Niter = Niter;
81  m_Nrestart = Nrestart;
82  m_Stop_cond = Stop_cond;
83 }
84 
85 
86 //====================================================================
// Sets the iteration limits, the stopping criterion and the initial-guess flag.
//   Niter          : stored in m_Niter;     must pass ParameterCheck::non_negative.
//   Nrestart       : stored in m_Nrestart;  must pass ParameterCheck::non_negative.
//   Stop_cond      : stored in m_Stop_cond; must pass ParameterCheck::square_non_zero.
//   use_init_guess : stored in m_use_init_guess (no range check applies).
// The values are logged first; a failed range check is reported through
// vout.crucial and terminates the program with EXIT_FAILURE.
// NOTE(review): the embedded listing numbers jump from 88 to 90 at the top of
// the body, so one source line is not visible in this view.
87 void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond, const bool use_init_guess)
88 {
90 
91  //- print input parameters
92  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
93  vout.general(m_vl, " Niter = %d\n", Niter);
94  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
95  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
96  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
97 
98  //- range check
    // Each check contributes to err; any non-zero total is treated as failure.
99  int err = 0;
100  err += ParameterCheck::non_negative(Niter);
101  err += ParameterCheck::non_negative(Nrestart);
102  err += ParameterCheck::square_non_zero(Stop_cond);
103 
104  if (err) {
105  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
106  exit(EXIT_FAILURE);
107  }
108 
109  //- store values
110  m_Niter = Niter;
111  m_Nrestart = Nrestart;
112  m_Stop_cond = Stop_cond;
113  m_use_init_guess = use_init_guess;
114 }
115 
116 
117 //====================================================================
119 {
    // Logs Omega_tolerance and stores it in m_Omega_tolerance without any
    // range check. (The signature line of this definition and embedded
    // listing line 120 are not part of this view.)
121 
122  //- print input parameters
123  vout.general(m_vl, " Omega_tolerance = %8.2e\n", Omega_tolerance);
124 
125  //- range check
126  // NB. Omega_tolerance == 0.0 is allowed.
127 
128  //- store values
129  m_Omega_tolerance = Omega_tolerance;
130 }
131 
132 
133 //====================================================================
135 {
    // Logs N_L, range-checks it with ParameterCheck::non_negative and stores
    // it in m_N_L. A failed check is reported through vout.crucial and
    // terminates the program with EXIT_FAILURE. (The signature line of this
    // definition is not part of this listing.)
    // NOTE(review): non_negative presumably accepts N_L == 0. Elsewhere in
    // this file m_N_L is used as the divisor 2 * m_N_L and as the index
    // gamma_prime[m_N_L] of a loop that starts at 1 -- confirm whether a
    // strictly positive check is intended here.
136  //- print input parameters
137  vout.general(m_vl, " N_L = %d\n", N_L);
138 
139  //- range check
140  int err = 0;
141  err += ParameterCheck::non_negative(N_L);
142 
143  if (err) {
144  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
145  exit(EXIT_FAILURE);
146  }
147 
148  //- store values
149  m_N_L = N_L;
150 }
151 
152 
153 //====================================================================
155  int& Nconv, double& diff)
156 {
    // Driver of the iterative solve. (The first line of the signature is not
    // part of this listing; the body uses `xq` as the solution field and `b`
    // as the source field.)
    //
    // Flow:
    //   1. reset_field(b) sizes the work fields, then the start vector m_s is
    //      taken from xq (m_use_init_guess) or from b, and solve_init() runs.
    //   2. Up to m_Nrestart outer passes, each with up to m_Niter calls of
    //      solve_step(); the inner loop stops early once
    //      rr / bnorm2 < m_Stop_cond.
    //   3. After each inner loop the residual is recomputed from m_x with the
    //      operator. If it satisfies the criterion the loop ends; otherwise
    //      the iteration is re-initialised from the current m_x.
    //   4. Without convergence the program is terminated with EXIT_FAILURE.
    //
    // Outputs: xq receives m_x; Nconv receives the accumulated counter Nconv2;
    // diff receives sqrt(diff2 / bnorm2).
    //
    // NOTE(review): all stopping tests divide by bnorm2 = b.norm2(); a source
    // with zero norm would make these ratios non-finite -- confirm callers
    // never pass such a field.
    // NOTE(review): with m_Nrestart == 0 (accepted by the non_negative check
    // in set_parameters) the outer loop never executes, is_converged stays
    // false and the function always ends in the failure exit below.
157  const double bnorm2 = b.norm2();
158  const int bsize = b.size();
159 
160  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
161  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
162  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
163 
164  bool is_converged = false;
165  int Nconv2 = 0;
166  double diff2 = 1.0; // superficial initialization
167  double rr;
168 
    // Weight of one operator application in the Nconv2 counter. It is fixed
    // to 1 because the mode-dependent doubling below is commented out.
169  int Nconv_unit = 1;
170  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
171  // Nconv_unit = 2;
172  // }
173 
174  reset_field(b);
175 
176  if (m_use_init_guess) {
177  copy(m_s, xq); // s = xq;
178  } else {
179  copy(m_s, b); // s = b;
180  }
181  solve_init(b, rr);
    // Count the initial call of solve_init().
182  Nconv2 += Nconv_unit;
183 
184  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
185 
186 
187  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
188  for (int iter = 0; iter < m_Niter; iter++) {
189  if (rr / bnorm2 < m_Stop_cond) break;
190 
191  solve_step(rr);
    // Each solve_step() is counted as 2 * m_N_L units.
192  Nconv2 += 2 * Nconv_unit * m_N_L;
193 
194  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
195  }
196 
197  //- calculate true residual
    // The residual is rebuilt from m_x rather than taken from the value rr
    // that solve_step() returns.
198  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
199  axpy(m_s, -1.0, b); // s -= b;
200  diff2 = m_s.norm2();
201 
202  if (diff2 / bnorm2 < m_Stop_cond) {
203  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
204  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
205 
206  is_converged = true;
207 
    // Remember the counters; the flop estimate in this file reads them.
208  m_Nrestart_count = i_restart;
209  m_Nconv_count = Nconv2;
210 
211  break;
212  } else {
213  //- restart with new approximate solution
214  copy(m_s, m_x); // s = x;
215  solve_init(b, rr);
216 
217  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
218  }
219  }
220 
221 
222  if (!is_converged) {
223  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
224  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
225  exit(EXIT_FAILURE);
226  }
227 
228 
229  copy(xq, m_x); // xq = x;
230 
    // The output references are written inside an omp master region that is
    // bracketed by barriers.
231 #pragma omp barrier
232 #pragma omp master
233  {
234  diff = sqrt(diff2 / bnorm2);
235  Nconv = Nconv2;
236  }
237 #pragma omp barrier
238 }
239 
240 
241 //====================================================================
243 {
    // Sizes the solver's work fields to match the field `b`. (The signature
    // line of this definition is not part of this listing; solve() calls it
    // as reset_field(b).)
    // All resizing is done inside an omp master region bracketed by barriers.
244 #pragma omp barrier
245 #pragma omp master
246  {
247  const int Nin = b.nin();
248  const int Nvol = b.nvol();
249  const int Nex = b.nex();
250 
    // The single work fields are reset only when m_s does not already have
    // the dimensions of b; m_s is the only field whose size is inspected.
251  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
252  m_s.reset(Nin, Nvol, Nex);
253  m_x.reset(Nin, Nvol, Nex);
254  m_r_init.reset(Nin, Nvol, Nex);
255  m_v.reset(Nin, Nvol, Nex);
256  }
257 
    // m_u and m_r hold m_N_L + 1 fields each; unlike the fields above they
    // are resized and reset on every call.
258  m_u.resize(m_N_L + 1);
259  m_r.resize(m_N_L + 1);
260 
261  for (int i = 0; i < m_N_L + 1; ++i) {
262  m_u[i].reset(Nin, Nvol, Nex);
263  m_r[i].reset(Nin, Nvol, Nex);
264  }
265  }
266 #pragma omp barrier
267 
268  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
269 }
270 
271 
272 //====================================================================
// Initialises one run of the iteration from the start vector held in m_s.
//   b  : source field.
//   rr : output; receives the squared norm of the initial residual r[0].
// On return m_x holds the start vector, m_r[0] and m_r_init hold b - A*m_s,
// all other m_r[i] and all m_u[i] are zero, and the scalars m_rho_prev and
// m_alpha_prev are reset. The caller must have filled m_s beforehand.
273 void Solver_BiCGStab_L_Cmplx::solve_init(const Field& b, double& rr)
274 {
275  copy(m_x, m_s); // x = s;
276 
277  for (int i = 0; i < m_N_L + 1; ++i) {
278  m_r[i].set(0.0); // r[i] = 0.0;
279  m_u[i].set(0.0); // u[i] = 0.0;
280  }
281 
282  // r[0] = b - A x_0
283  m_fopr->mult(m_v, m_s); // m_v = m_fopr->mult(s);
284  copy(m_r[0], b); // r[0] = b;
285  axpy(m_r[0], -1.0, m_v); // r[0] -= m_v;
286 
287  copy(m_r_init, m_r[0]); // r_init = r[0];
288  rr = m_r[0].norm2(); // rr = r[0] * r[0];
289 
    // The member scalars are written inside an omp master region bracketed by
    // barriers.
290 #pragma omp barrier
291 #pragma omp master
292  {
293  m_rho_prev = cmplx(-1.0, 0.0);
294 
295  // NB. m_alpha_prev = 0.0 \neq 1.0
    // With alpha_prev == 0 the first beta computed in the step routine
    // (alpha_prev * rho / rho_prev) is zero.
296  m_alpha_prev = cmplx(0.0, 0.0);
297  }
298 #pragma omp barrier
299 }
300 
301 
302 //====================================================================
304 {
    // One iteration step; updates m_x, m_r[], m_u[] and returns the squared
    // norm of the new m_r[0] through `rr`. (The signature line of this
    // definition is not part of this listing; solve() calls it as
    // solve_step(rr).)
    //
    // Structure:
    //   (a) a loop of m_N_L sub-steps, each applying the operator twice and
    //       updating x with the step length alpha;
    //   (b) a pass over r[1..m_N_L] that removes from each r[j] its components
    //       along the earlier r[i] and computes the coefficients gamma_prime;
    //   (c) back-substitution for gamma and gamma_double_prime;
    //   (d) the combined update of x, r[0] and u[0].
    //
    // The member scalars m_rho_prev / m_alpha_prev are read into locals here
    // and written back only at the end, inside an omp master region.
305  dcomplex rho_prev2 = m_rho_prev;
306  dcomplex alpha_prev2 = m_alpha_prev;
307 
    // ---- (a) m_N_L sub-steps --------------------------------------------
308  for (int j = 0; j < m_N_L; ++j) {
309  dcomplex rho = dotc(m_r[j], m_r_init); // dcomplex rho = r[j] * r_init;
310  rho = conj(rho);
311 
312  dcomplex beta = alpha_prev2 * (rho / rho_prev2);
313 
314  rho_prev2 = rho;
315 
316  for (int i = 0; i < j + 1; ++i) {
317  aypx(-beta, m_u[i], m_r[i]); // u[i] = - beta * u[i] + r[i];
318  }
319 
320  m_fopr->mult(m_u[j + 1], m_u[j]); // u[j+1] = m_fopr->mult(u[j]);
321 
    // This scalar `gamma` is local to the loop body; it is unrelated to the
    // vector `gamma` declared further below.
322  dcomplex gamma = dotc(m_u[j + 1], m_r_init);
323  alpha_prev2 = rho_prev2 / conj(gamma);
324 
325  for (int i = 0; i < j + 1; ++i) {
326  axpy(m_r[i], -alpha_prev2, m_u[i + 1]); // r[i] -= alpha_prev * u[i+1];
327  }
328 
329  m_fopr->mult(m_r[j + 1], m_r[j]); // r[j+1] = m_fopr->mult(r[j]);
330 
331  axpy(m_x, alpha_prev2, m_u[0]); // x += alpha_prev * u[0];
332  }
333 
334 
    // ---- (b) coefficients from r[1..m_N_L] ------------------------------
    // Index 0 of sigma and gamma_prime is allocated but never accessed; the
    // loops below start at 1.
335  std::vector<double> sigma(m_N_L + 1);
336  std::vector<dcomplex> gamma_prime(m_N_L + 1);
337 
338  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
    // tau is stored flat and addressed through index_ij(), whose layout is
    // defined outside this listing.
339  std::vector<dcomplex> tau(m_N_L * (m_N_L + 1));
340 
341  const double sigma_0 = m_r[0].norm2();
342 
343  for (int j = 1; j < m_N_L + 1; ++j) {
344  for (int i = 1; i < j; ++i) {
345  int ij = index_ij(i, j);
346 
347  dcomplex r_ji = dotc(m_r[j], m_r[i]);
348  tau[ij] = conj(r_ji) / sigma[i]; // tau[ij] = (r[j] * r[i]) / sigma[i];
349  axpy(m_r[j], -tau[ij], m_r[i]); // r[j] -= tau[ij] * r[i];
350  }
351 
352  sigma[j] = m_r[j].norm2(); // sigma[j] = r[j] * r[j];
353 
354  dcomplex r_0j = dotc(m_r[0], m_r[j]);
355  gamma_prime[j] = conj(r_0j) / sigma[j]; // gamma_prime[j] = (r[0] * r[j]) / sigma[j];
356 
357  //- a prescription to improve stability of BiCGStab(L)
    // abs_rho is |r_0j| normalised by the norms of r[0] and r[j]; when it is
    // below m_Omega_tolerance, gamma_prime[j] is rescaled by the ratio
    // m_Omega_tolerance / abs_rho.
    // NOTE(review): abs_rho == 0 with a positive m_Omega_tolerance would
    // divide by zero here -- confirm this cannot occur in practice.
358  double abs_rho = abs(r_0j) / sqrt(sigma[j] * sigma_0);
359  if (abs_rho < m_Omega_tolerance) {
360  gamma_prime[j] *= m_Omega_tolerance / abs_rho;
361  }
362  }
363 
364 
    // ---- (c) back-substitution ------------------------------------------
365  std::vector<dcomplex> gamma(m_N_L + 1);
    // NOTE(review): this outer c_tmp is never used; both loops below declare
    // their own c_tmp, which shadows it.
366  dcomplex c_tmp;
367 
368  gamma[m_N_L] = gamma_prime[m_N_L];
369 
    // gamma[j] = gamma_prime[j] - sum_{i > j} tau[j,i] * gamma[i], evaluated
    // from j = m_N_L - 1 down to 1.
370  for (int j = m_N_L - 1; j > 0; --j) {
371  dcomplex c_tmp = cmplx(0.0, 0.0);
372 
373  for (int i = j + 1; i < m_N_L + 1; ++i) {
374  int ji = index_ij(j, i);
375  c_tmp += tau[ji] * gamma[i];
376  }
377 
378  gamma[j] = gamma_prime[j] - c_tmp;
379  }
380 
381 
382  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
383  std::vector<dcomplex> gamma_double_prime(m_N_L);
384 
    // gamma_double_prime[j] = gamma[j+1] + sum_{j < i < m_N_L} tau[j,i] * gamma[i+1]
385  for (int j = 1; j < m_N_L; ++j) {
386  dcomplex c_tmp = cmplx(0.0, 0.0);
387 
388  for (int i = j + 1; i < m_N_L; ++i) {
389  int ji = index_ij(j, i);
390  c_tmp += tau[ji] * gamma[i + 1];
391  }
392 
393  gamma_double_prime[j] = gamma[j + 1] + c_tmp;
394  }
395 
396 
    // ---- (d) combined update of x, r[0], u[0] ---------------------------
397  axpy(m_x, gamma[1], m_r[0]); // x += gamma[ 1] * r[ 0];
398  axpy(m_r[0], -gamma_prime[m_N_L], m_r[m_N_L]); // r[0] -= gamma_prime[m_N_L] * r[m_N_L];
399  axpy(m_u[0], -gamma[m_N_L], m_u[m_N_L]); // u[0] -= gamma[ m_N_L] * u[m_N_L];
400 
401  for (int j = 1; j < m_N_L; ++j) {
402  axpy(m_x, gamma_double_prime[j], m_r[j]); // x += gamma_double_prime[j] * r[j];
403  axpy(m_r[0], -gamma_prime[j], m_r[j]); // r[0] -= gamma_prime[ j] * r[j];
404  axpy(m_u[0], -gamma[j], m_u[j]); // u[0] -= gamma[ j] * u[j];
405  }
406 
407  rr = m_r[0].norm2(); // rr = r[0] * r[0];
408 
    // Write the scalars back for the next step: m_rho_prev becomes
    // rho_prev2 * (-gamma_prime[m_N_L]), m_alpha_prev becomes alpha_prev2.
409 #pragma omp barrier
410 #pragma omp master
411  {
412  m_rho_prev = rho_prev2;
413  m_alpha_prev = alpha_prev2;
414 
415  m_rho_prev *= -gamma_prime[m_N_L];
416  }
417 #pragma omp barrier
418 }
419 
420 
421 //====================================================================
423 {
    // Returns an estimate, in units of 1e9 operations, of the work done by
    // the most recent converged solve. It is built from the operator's own
    // flop_count() and the counters m_Nconv_count / m_Nrestart_count that
    // solve() stores on convergence. (The signature line of this definition
    // is not part of this listing; the name is presumably flop_count --
    // TODO confirm.)
    // Returns 0.0, after a warning, when the operator reports a flop count
    // below CommonParameters::epsilon_criterion().
424  const int NPE = CommonParameters::NPE();
425 
426  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
427  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
428  const int Nin = m_x.nin();
429  const int Nvol = m_x.nvol();
430  const int Nex = m_x.nex();
431 
432  const double gflop_fopr = m_fopr->flop_count();
433 
434  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
435  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
436  return 0.0;
437  }
438 
    // Per-call cost of the field primitives.
    // NOTE(review): Nvol * NPE and Nin * Nex * k are multiplied as int before
    // the conversion to double; for very large volumes the int product could
    // overflow -- consider converting to double first.
439  const double gflop_axpy = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
440  const double gflop_dotc = (Nin * Nex * 4) * ((Nvol * NPE) / 1.0e+9);
441  const double gflop_norm = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
442 
    // N_L_part counts the (j, i) pairs with 0 <= i <= j < m_N_L, i.e.
    // m_N_L * (m_N_L + 1) / 2.
443  int N_L_part = 0;
444  for (int j = 0; j < m_N_L; ++j) {
445  for (int i = 0; i < j + 1; ++i) {
446  N_L_part += 1;
447  }
448  }
449 
450  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
451  const double gflop_step_BiCG_part = 2 * m_N_L * gflop_fopr
452  + 2 * m_N_L * gflop_dotc
453  + (m_N_L + 2 * N_L_part) * gflop_axpy;
454  const double gflop_step_L_part = (N_L_part + m_N_L) * gflop_dotc
455  + (N_L_part + 3 * m_N_L) * gflop_axpy
456  + (m_N_L + 1) * gflop_norm;
457  const double gflop_step = gflop_step_BiCG_part + gflop_step_L_part;
458  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
459 
    // solve() adds 1 to its counter for the initial setup and 2 * m_N_L per
    // step, so this recovers the number of steps taken.
    // NOTE(review): integer division by 2 * m_N_L; m_N_L == 0 passes the
    // range check in set_parameters_L and would divide by zero here.
460  const int N_iter = (m_Nconv_count - 1) / (2 * m_N_L);
461  const double gflop = gflop_norm + gflop_init + gflop_step * N_iter + gflop_true_residual * (m_Nrestart_count + 1)
462  + gflop_init * m_Nrestart_count;
463 
464 
465  return gflop;
466 }
467 
468 
469 //====================================================================
470 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:503
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
static const std::string class_name
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:592
void general(const char *format,...)
Definition: bridgeIO.cpp:197
void set_parameters_L(const int N_L)
Container of Field-type object.
Definition: field.h:45
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
int nvol() const
Definition: field.h:127
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:532
int square_non_zero(const double v)
int nin() const
Definition: field.h:126
void set_parameters(const Parameters &params)
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:155
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
virtual double flop_count()
returns the flop in giga unit
Definition: fopr.h:120
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:612
void set_parameters_BiCGStab_series(const double Omega_tolerance)
int nex() const
Definition: field.h:128
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:235
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
int index_ij(const int i, const int j)
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:319
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
int non_negative(const int v)
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
string get_string(const string &key) const
Definition: parameters.cpp:221
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
int size() const
Definition: field.h:132