Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_IDS_L_Cmplx.cpp
Go to the documentation of this file.
1 
15 
16 
17 #ifdef USE_FACTORY
18 namespace {
19  Solver *create_object(Fopr *fopr)
20  {
21  return new Solver_BiCGStab_IDS_L_Cmplx(fopr);
22  }
23 
24 
25  bool init = Solver::Factory::Register("BiCGStab_IDS_L_Cmplx", create_object);
26 }
27 #endif
28 
29 
30 const std::string Solver_BiCGStab_IDS_L_Cmplx::class_name = "Solver_BiCGStab_IDS_L_Cmplx";
31 
32 //====================================================================
34 {
 // set_parameters(params): the signature line of this overload is not part
 // of this listing (the cross-reference shows
 // `void set_parameters(const Parameters &params)`).
 //
 // Reads the solver settings from `params` and forwards them to
 // set_parameters(Niter, Nrestart, Stop_cond) and
 // set_parameters_DS_L(N_L, Tol_L), which do the range checks and store them.
 // Keys read:
 //   verbose_level                        -> m_vl
 //   maximum_number_of_iteration          -> Niter
 //   maximum_number_of_restart            -> Nrestart
 //   convergence_criterion_squared        -> Stop_cond
 //   number_of_orthonormal_vectors        -> N_L
 //   tolerance_for_DynamicSelection_of_L  -> Tol_L
 // A non-zero return from any fetch_* call is reported as
 // "input parameter not found" and terminates via exit(EXIT_FAILURE).
35  const string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  int Niter, Nrestart;
41  double Stop_cond;
42  int N_L;
43  double Tol_L;
44 
 // err accumulates the return codes of the fetches; any non-zero total aborts.
45  int err = 0;
46  err += params.fetch_int("maximum_number_of_iteration", Niter);
47  err += params.fetch_int("maximum_number_of_restart", Nrestart);
48  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
49  err += params.fetch_int("number_of_orthonormal_vectors", N_L);
50  err += params.fetch_double("tolerance_for_DynamicSelection_of_L", Tol_L);
51 
52  if (err) {
53  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
54  exit(EXIT_FAILURE);
55  }
56 
57 
 // Range checking and storing happen in the typed setters.
58  set_parameters(Niter, Nrestart, Stop_cond);
59  set_parameters_DS_L(N_L, Tol_L);
60 }
61 
62 
63 //====================================================================
64 void Solver_BiCGStab_IDS_L_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
65 {
67 
68  //- print input parameters
69  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
70  vout.general(m_vl, " Niter = %d\n", Niter);
71  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
72  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
73 
74  //- range check
75  int err = 0;
76  err += ParameterCheck::non_negative(Niter);
77  err += ParameterCheck::non_negative(Nrestart);
78  err += ParameterCheck::square_non_zero(Stop_cond);
79 
80  if (err) {
81  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
82  exit(EXIT_FAILURE);
83  }
84 
85  //- store values
86  m_Niter = Niter;
87  m_Nrestart = Nrestart;
88  m_Stop_cond = Stop_cond;
89 }
90 
91 
92 //====================================================================
93 void Solver_BiCGStab_IDS_L_Cmplx::set_parameters_DS_L(const int N_L, const double Tol_L)
94 {
95  //- print input parameters
96  vout.general(m_vl, " N_L = %d\n", N_L);
97  vout.general(m_vl, " Tol_L = %16.8e\n", Tol_L);
98 
99  //- range check
100  int err = 0;
101  err += ParameterCheck::non_negative(N_L);
102  err += ParameterCheck::square_non_zero(Tol_L);
103 
104  if (err) {
105  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
106  exit(EXIT_FAILURE);
107  }
108 
109  //- store values
110  m_N_L = N_L;
111  m_Tol_L = Tol_L;
112 }
113 
114 
115 //====================================================================
117  int& Nconv, double& diff)
118 {
 // solve(xq, b, Nconv, diff): the first signature line is not part of this
 // listing (the cross-reference shows
 // `void solve(Field &solution, const Field &source, int &Nconv, double &diff)`).
 //
 // Solves m_fopr * x = b with the dynamically-selected-L BiCGStab cycle.
 // - The initial guess is b itself (m_s = b before solve_init).
 // - Up to m_Niter solve_step() calls are made per restart cycle, for up to
 //   m_Nrestart cycles.
 // - After each cycle the true residual |A x - b|^2 / |b|^2 is compared with
 //   m_Stop_cond; if it is not met, the cycle restarts from the current x.
 // On success: xq = x, diff = sqrt(|A x - b|^2 / |b|^2), Nconv = Nconv2.
 // On failure: prints an error and calls exit(EXIT_FAILURE).
 // Nconv2 is 1 (the operator application in the first solve_init) plus
 // 2 * m_N_L_prev per step (solve_step applies m_fopr twice per inner step).
 // The true-residual mults and the restart re-initialisations are not added.
 // NOTE(review): this function does not update m_Nconv_count, m_Niter_count,
 // m_Nrestart_count or m_N_L_part_count, which flop_count() reads -- confirm
 // where they are maintained.
119  double bnorm2 = b.norm2();
120  int bsize = b.size();
121 
122  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
123  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
124  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
125 
 // NOTE(review): bnorm2 == 0 (zero source) is not guarded; every ratio
 // below divides by it.
126  bool is_converged = false;
127  int Nconv2 = 0;
128  double diff2 = 1.0;
129  double rr;
130 
131  reset_field(b);
132  copy(m_s, b); // s = b;
133  solve_init(b, rr);
134  Nconv2 += 1;
135 
136  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
137 
138 
 // NOTE(review): with m_Nrestart == 0 this loop never runs and diff2 keeps
 // its initial value 1.0 for the final check.
139  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
140  for (int iter = 0; iter < m_Niter; iter++) {
141  if (rr / bnorm2 < m_Stop_cond) break;
142 
143  solve_step(rr);
 // NOTE(review): m_N_L_prev is read after solve_step() has returned. By then
 // its c_phi criterion may already have set it to m_N_L for the NEXT step,
 // so this can over-count the step just taken.
144  Nconv2 += 2 * m_N_L_prev;
145 
146  vout.paranoiac(m_vl, " iter,N_L: %8d %8d\n", Nconv2, m_N_L_prev);
147  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
148  }
149 
150  //- calculate true residual
151  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
152  axpy(m_s, -1.0, b); // s -= b;
153  diff2 = m_s.norm2();
154 
155  if (diff2 / bnorm2 < m_Stop_cond) {
156  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
157  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
158 
159  is_converged = true;
160  break;
161  } else {
162  //- restart with new approximate solution
163  copy(m_s, m_x); // s = x;
164  solve_init(b, rr);
165 
166  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
167  }
168 
 // Redundant: the converged branch above already breaks out of the loop.
169  if (is_converged) break;
170  }
171 
172 
173  if (diff2 / bnorm2 > m_Stop_cond) {
174  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
175  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
176  exit(EXIT_FAILURE);
177  }
178 
179 
180  copy(xq, m_x); // xq = x;
181 
 // Output scalars are written once, by the master thread, between barriers.
182 #pragma omp barrier
183 #pragma omp master
184  {
185  diff = sqrt(diff2 / bnorm2);
186  Nconv = Nconv2;
187  }
188 #pragma omp barrier
189 }
190 
191 
192 //====================================================================
194 {
 // reset_field(b): the signature line is not part of this listing.
 // Sizes the work fields to match the shape (nin, nvol, nex) of b:
 // - m_s, m_x, m_r_init and m_v are reallocated only when m_s's shape differs.
 // - m_u and m_r are resized to m_N_L+1 entries, and every entry is reset on
 //   every call.
 // The allocation is done by the OpenMP master thread between barriers.
195 #pragma omp barrier
196 #pragma omp master
197  {
198  int Nin = b.nin();
199  int Nvol = b.nvol();
200  int Nex = b.nex();
201 
 // Shape of m_s stands in for all four single work fields.
202  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
203  m_s.reset(Nin, Nvol, Nex);
204  m_x.reset(Nin, Nvol, Nex);
205  m_r_init.reset(Nin, Nvol, Nex);
206  m_v.reset(Nin, Nvol, Nex);
207  }
208 
 // Krylov vectors u[0..N_L] and r[0..N_L]; m_N_L may have changed since the
 // last call, so these are resized and reset unconditionally.
209  m_u.resize(m_N_L + 1);
210  m_r.resize(m_N_L + 1);
211 
212  for (int i = 0; i < m_N_L + 1; ++i) {
213  m_u[i].reset(Nin, Nvol, Nex);
214  m_r[i].reset(Nin, Nvol, Nex);
215  }
216  }
217 #pragma omp barrier
218 
219  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
220 }
221 
222 
223 //====================================================================
225 {
 // solve_init(b, rr): the signature line is not part of this listing (the
 // cross-reference shows `void solve_init(const Field &, double &)`).
 // Starts a BiCGStab cycle from the initial guess currently in m_s:
 //   x = s,  u[i] = r[i] = 0,  r[0] = b - A s,  r_init = r[0],
 //   rr = |r[0]|^2.
 // It then resets the recurrence scalars and sets m_N_L_prev = m_N_L, so the
 // first step uses the full L.
229  for (int i = 0; i < m_N_L + 1; ++i) {
230  m_r[i].set(0.0); // r[i] = 0.0;
231  m_u[i].set(0.0); // u[i] = 0.0;
231  }
253 
254 
255 //====================================================================
257 {
258  dcomplex rho_prev2 = m_rho_prev;
259  dcomplex alpha_prev2 = m_alpha_prev;
260 
261  int N_L_tmp = 0; // superficial initialization
262  dcomplex c_Rayleigh_prev = cmplx(0.0, 0.0);
263 
264  bool is_converged_L = false;
265 
266 
267  for (int j = 0; j < m_N_L_prev; ++j) {
268  if (!is_converged_L) {
269  dcomplex rho = dotc(m_r[j], m_r_init); // rho = r[j] * r_init;
270  rho = conj(rho);
271 
272  dcomplex beta = alpha_prev2 * (rho / rho_prev2);
273 
274  rho_prev2 = rho;
275 
276  for (int i = 0; i < j + 1; ++i) {
277  aypx(-beta, m_u[i], m_r[i]); // u[i] = - beta * u[i] + r[i];
278  }
279 
280  m_fopr->mult(m_u[j + 1], m_u[j]); // u[j+1] = m_fopr->mult(u[j]);
281 
282  dcomplex gamma = dotc(m_u[j + 1], m_r_init);
283 
284  alpha_prev2 = rho_prev2 / conj(gamma);
285 
286  for (int i = 0; i < j + 1; ++i) {
287  axpy(m_r[i], -alpha_prev2, m_u[i + 1]); // r[i] -= alpha_prev * u[i+1];
288  }
289 
290  m_fopr->mult(m_r[j + 1], m_r[j]); // r[j+1] = m_fopr->mult(r[j]);
291 
292  axpy(m_x, alpha_prev2, m_u[0]); // x += alpha_prev * u[0];
293 
294 
295  //- calculate the Rayleigh quotient for N_L_tmp
296 
297  // dcomplex c_Rayleigh = (r[j] * r[j+1]) / (r[j] * r[j]);
298  double r_tmp = m_r[j].norm2(); // r_tmp = r[j] * r[j];
299  dcomplex c_Rayleigh = dotc(m_r[j], m_r[j + 1]) / r_tmp; // c_Rayleigh = r[j] * r[j+1] / r_tmp;
300 
301  dcomplex c_E = (c_Rayleigh - c_Rayleigh_prev) / c_Rayleigh;
302 
303  // #pragma omp master
304  c_Rayleigh_prev = c_Rayleigh;
305  // #pragma omp barrier
306 
307  N_L_tmp = j + 1;
308 
309  // vout.paranoiac(m_vl, "N_L_tmp,abs(c_E),m_Tol_L = %d %f %f\n",N_L_tmp,abs(c_E),m_Tol_L);
310 
311  if (abs(c_E) < m_Tol_L) {
312  // vout.paranoiac(m_vl, "N_L_tmp = %d\n",N_L_tmp);
313  // break;
314  is_converged_L = true;
315  }
316  }
317  }
318 
319 
320  std::vector<double> sigma(m_N_L + 1);
321  std::vector<dcomplex> gamma_prime(m_N_L + 1);
322 
323  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
324  std::vector<dcomplex> tau(m_N_L * (m_N_L + 1));
325  int ij, ji;
326 
327  for (int j = 1; j < N_L_tmp + 1; ++j) {
328  for (int i = 1; i < j; ++i) {
329  ij = index_ij(i, j);
330 
331  dcomplex r_ji = dotc(m_r[j], m_r[i]);
332  tau[ij] = conj(r_ji) / sigma[i]; // tau[ij] = (r[j] * r[i]) / sigma[i];
333  axpy(m_r[j], -tau[ij], m_r[i]); // r[j] -= tau[ij] * r[i];
334  }
335 
336  sigma[j] = m_r[j].norm2(); // sigma[j] = r[j] * r[j];
337 
338  dcomplex r_0j = dotc(m_r[0], m_r[j]);
339  gamma_prime[j] = conj(r_0j) / sigma[j]; // gamma_prime[j] = (r[0] * r[j]) / sigma[j];
340  }
341 
342 
343  std::vector<dcomplex> gamma(m_N_L + 1);
344  dcomplex c_tmp;
345 
346  gamma[N_L_tmp] = gamma_prime[N_L_tmp];
347 
348  for (int j = N_L_tmp - 1; j > 0; --j) {
349  c_tmp = cmplx(0.0, 0.0);
350 
351  for (int i = j + 1; i < N_L_tmp + 1; ++i) {
352  ji = index_ij(j, i);
353  c_tmp += tau[ji] * gamma[i];
354  }
355 
356  gamma[j] = gamma_prime[j] - c_tmp;
357  }
358 
359 
360  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
361  std::vector<dcomplex> gamma_double_prime(m_N_L);
362 
363  for (int j = 1; j < N_L_tmp; ++j) {
364  c_tmp = cmplx(0.0, 0.0);
365 
366  for (int i = j + 1; i < N_L_tmp; ++i) {
367  ji = index_ij(j, i);
368  c_tmp += tau[ji] * gamma[i + 1];
369  }
370 
371  gamma_double_prime[j] = gamma[j + 1] + c_tmp;
372  }
373 
374 
375  axpy(m_x, gamma[1], m_r[0]); // x += gamma[ 1] * r[ 0];
376  axpy(m_r[0], -gamma_prime[N_L_tmp], m_r[N_L_tmp]); // r[0] -= gamma_prime[N_L_tmp] * r[N_L_tmp];
377  axpy(m_u[0], -gamma[N_L_tmp], m_u[N_L_tmp]); // u[0] -= gamma[ N_L_tmp] * u[N_L_tmp];
378 
379  for (int j = 1; j < N_L_tmp; ++j) {
380  axpy(m_x, gamma_double_prime[j], m_r[j]); // x += gamma_double_prime[j] * r[j];
381  axpy(m_r[0], -gamma_prime[j], m_r[j]); // r[0] -= gamma_prime[ j] * r[j];
382  axpy(m_u[0], -gamma[j], m_u[j]); // u[0] -= gamma[ j] * u[j];
383  }
384 
385  rr = m_r[0].norm2(); // rr = r[0] * r[0];
386 
387 
388  //- calculate the residual difference for N_L_prev
389  double c_phi = abs(rr - m_rr_prev) / m_rr_prev;
390 
391  // NB. another criterion for c_phi is possible, m_Tol_phi /= m_Tol_L
392  if (c_phi < m_Tol_L) {
393  N_L_tmp = m_N_L;
394  }
395 
396 
397 #pragma omp barrier
398 #pragma omp master
399  {
400  m_rho_prev = rho_prev2;
401  m_alpha_prev = alpha_prev2;
402  m_N_L_prev = N_L_tmp;
403  m_rr_prev = rr;
404 
405  m_rho_prev *= -gamma_prime[N_L_tmp];
406  }
407 #pragma omp barrier
408 }
409 
410 
411 //====================================================================
413 {
414  int NPE = CommonParameters::NPE();
416 
417  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
418  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
419  int Nin = m_x.nin();
420  int Nvol = m_x.nvol();
421  int Nex = m_x.nex();
422 
423  double flop_fopr = m_fopr->flop_count();
424 
425  if (flop_fopr < eps) {
426  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0.0.\n", class_name.c_str());
427  return 0.0;
428  }
429 
430  double flop_axpy = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
431  double flop_dotc = static_cast<double>(Nin * Nex * 4) * (Nvol * NPE);
432  double flop_norm = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
433 
434  int N_L_prev_total = (m_Nconv_count - 1) / 2;
435 
436  double flop_init = flop_fopr + flop_axpy + flop_norm;
437  double flop_step_BiCG_part = 2 * N_L_prev_total * flop_fopr
438  + 3 * N_L_prev_total * flop_dotc
439  + (N_L_prev_total + 2 * m_N_L_part_count) * flop_axpy
440  + N_L_prev_total * flop_norm;
441  double flop_step_L_part = (m_N_L_part_count + N_L_prev_total) * flop_dotc
442  + (m_N_L_part_count + 3 * N_L_prev_total) * flop_axpy
443  + (N_L_prev_total + m_Niter_count) * flop_norm;
444  double flop_step = flop_step_BiCG_part + flop_step_L_part;
445  double flop_true_residual = flop_fopr + flop_axpy + flop_norm;
446 
447  double flop = flop_norm + flop_init + flop_step + flop_true_residual
448  + flop_init * m_Nrestart_count;
449 
450 
451  return flop;
452 }
453 
454 
455 //====================================================================
456 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:495
void detailed(const char *format,...)
Definition: bridgeIO.cpp:212
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:441
void set_parameters_DS_L(const int N_L, const double Tol_L)
void general(const char *format,...)
Definition: bridgeIO.cpp:195
Container of Field-type object.
Definition: field.h:39
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:211
int nvol() const
Definition: field.h:116
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
void solve_init(const Field &, double &)
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:230
virtual double flop_count()
returns the flops per site.
Definition: fopr.h:121
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:461
int nex() const
Definition: field.h:117
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:229
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Base class for linear solver class family.
Definition: solver.h:37
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
int non_negative(const int v)
Definition: checker.cpp:21
void set_parameters(const Parameters &params)
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Base class of fermion operator family.
Definition: fopr.h:47
string get_string(const string &key) const
Definition: parameters.cpp:116
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
static int NPE()
int size() const
Definition: field.h:121