Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_L_Cmplx.cpp
Go to the documentation of this file.
1 
15 
16 
#ifdef USE_FACTORY
namespace {
  // Factory callback: builds a Solver_BiCGStab_L_Cmplx that applies the given
  // fermion operator. The object is heap-allocated with new; ownership is
  // presumably taken over by the factory's caller — confirm in Solver::Factory.
  Solver *make_bicgstab_l_cmplx(Fopr *fopr)
  {
    Solver *solver = new Solver_BiCGStab_L_Cmplx(fopr);

    return solver;
  }


  // Registers the callback under the key "BiCGStab_L_Cmplx". The call runs
  // during dynamic initialisation of this namespace-scope variable.
  bool registered = Solver::Factory::Register("BiCGStab_L_Cmplx", make_bicgstab_l_cmplx);
}
#endif
28 
29 
30 const std::string Solver_BiCGStab_L_Cmplx::class_name = "Solver_BiCGStab_L_Cmplx";
31 
32 //====================================================================
34 {
35  const string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  int Niter, Nrestart;
41  double Stop_cond;
42  int N_L;
43 
44  int err = 0;
45  err += params.fetch_int("maximum_number_of_iteration", Niter);
46  err += params.fetch_int("maximum_number_of_restart", Nrestart);
47  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
48  err += params.fetch_int("number_of_orthonormal_vectors", N_L);
49 
50  if (err) {
51  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
52  exit(EXIT_FAILURE);
53  }
54 
55  set_parameters(Niter, Nrestart, Stop_cond);
56  set_parameters_L(N_L);
57 }
58 
59 
60 //====================================================================
61 void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
62 {
64 
65  //- print input parameters
66  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
67  vout.general(m_vl, " Niter = %d\n", Niter);
68  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
69  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
70 
71  //- range check
72  int err = 0;
73  err += ParameterCheck::non_negative(Niter);
74  err += ParameterCheck::non_negative(Nrestart);
75  err += ParameterCheck::square_non_zero(Stop_cond);
76 
77  if (err) {
78  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
79  exit(EXIT_FAILURE);
80  }
81 
82  //- store values
83  m_Niter = Niter;
84  m_Nrestart = Nrestart;
85  m_Stop_cond = Stop_cond;
86 }
87 
88 
89 //====================================================================
91 {
92  //- print input parameters
93  vout.general(m_vl, " N_L = %d\n", N_L);
94 
95  //- range check
96  int err = 0;
97  err += ParameterCheck::non_negative(N_L);
98 
99  if (err) {
100  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
101  exit(EXIT_FAILURE);
102  }
103 
104  //- store values
105  m_N_L = N_L;
106 }
107 
108 
109 //====================================================================
111  int& Nconv, double& diff)
112 {
113  double bnorm2 = b.norm2();
114  int bsize = b.size();
115 
116  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
117  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
118  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
119 
120  bool is_converged = false;
121  int Nconv2 = 0;
122  double diff2 = 1.0;
123  double rr;
124 
125  reset_field(b);
126  copy(m_s, b); // s = b;
127  solve_init(b, rr);
128  Nconv2 += 1;
129 
130  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
131 
132 
133  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
134  for (int iter = 0; iter < m_Niter; iter++) {
135  if (rr / bnorm2 < m_Stop_cond) break;
136 
137  solve_step(rr);
138  Nconv2 += 2 * m_N_L;
139 
140  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
141  }
142 
143  //- calculate true residual
144  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
145  axpy(m_s, -1.0, b); // s -= b;
146  diff2 = m_s.norm2();
147 
148  if (diff2 / bnorm2 < m_Stop_cond) {
149  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
150  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
151 
152  is_converged = true;
153  break;
154  } else {
155  //- restart with new approximate solution
156  copy(m_s, m_x); // s = x;
157  solve_init(b, rr);
158 
159  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
160  }
161  }
162 
163 
164  if (!is_converged) {
165  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
166  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
167  exit(EXIT_FAILURE);
168  }
169 
170 
171  copy(xq, m_x); // xq = x;
172 
173 #pragma omp barrier
174 #pragma omp master
175  {
176  diff = sqrt(diff2 / bnorm2);
177  Nconv = Nconv2;
178  }
179 #pragma omp barrier
180 }
181 
182 
183 //====================================================================
185 {
186 #pragma omp barrier
187 #pragma omp master
188  {
189  int Nin = b.nin();
190  int Nvol = b.nvol();
191  int Nex = b.nex();
192 
193  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
194  m_s.reset(Nin, Nvol, Nex);
195  m_x.reset(Nin, Nvol, Nex);
196  m_r_init.reset(Nin, Nvol, Nex);
197  m_v.reset(Nin, Nvol, Nex);
198  }
199 
200  m_u.resize(m_N_L + 1);
201  m_r.resize(m_N_L + 1);
202 
203  for (int i = 0; i < m_N_L + 1; ++i) {
204  m_u[i].reset(Nin, Nvol, Nex);
205  m_r[i].reset(Nin, Nvol, Nex);
206  }
207  }
208 #pragma omp barrier
209 
210  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
211 }
212 
213 
214 //====================================================================
215 void Solver_BiCGStab_L_Cmplx::solve_init(const Field& b, double& rr)
216 {
217  copy(m_x, m_s); // x = s;
218 
219  for (int i = 0; i < m_N_L + 1; ++i) {
220  m_r[i].set(0.0); // r[i] = 0.0;
221  m_u[i].set(0.0); // u[i] = 0.0;
222  }
223 
224  // r[0] = b - A x_0
225  m_fopr->mult(m_v, m_s); // m_v = m_fopr->mult(s);
226  copy(m_r[0], b); // r[0] = b;
227  axpy(m_r[0], -1.0, m_v); // r[0] -= m_v;
228 
229  copy(m_r_init, m_r[0]); // r_init = r[0];
230  rr = m_r[0].norm2(); // rr = r[0] * r[0];
231 
232 #pragma omp barrier
233 #pragma omp master
234  {
235  m_rho_prev = cmplx(-1.0, 0.0);
236 
237  // NB. m_alpha_prev = 0.0 \neq 1.0
238  m_alpha_prev = cmplx(0.0, 0.0);
239  }
240 #pragma omp barrier
241 }
242 
243 
244 //====================================================================
{
  // One BiCGStab(L) cycle.
  //  1. BiCG part: m_N_L sub-steps that extend the bases u[0..N_L] and
  //     r[0..N_L] (u[j+1] = A u[j], r[j+1] = A r[j]) and update x.
  //  2. MR part: modified Gram-Schmidt on r[1..N_L] (coefficients tau, norms
  //     sigma), then back-substitution for gamma / gamma_double_prime, and a
  //     combined update of x, r[0], u[0].
  // On return rr = |r[0]|^2. The recurrence scalars are written back to
  // m_rho_prev / m_alpha_prev by the master thread between barriers.
  // NOTE(review): every thread computes the local scalars itself; this presumably
  // relies on dotc()/norm2() returning the same reduced value on all threads — confirm.
  // NOTE(review): no guard against breakdown (rho_prev2, gamma or sigma[i] == 0).
  dcomplex rho_prev2   = m_rho_prev;
  dcomplex alpha_prev2 = m_alpha_prev;

  //- BiCG part
  for (int j = 0; j < m_N_L; ++j) {
    dcomplex rho = dotc(m_r[j], m_r_init); // dcomplex rho = r[j] * r_init;
    rho = conj(rho);

    dcomplex beta = alpha_prev2 * (rho / rho_prev2);

    rho_prev2 = rho;

    for (int i = 0; i < j + 1; ++i) {
      aypx(-beta, m_u[i], m_r[i]); // u[i] = - beta * u[i] + r[i];
    }

    m_fopr->mult(m_u[j + 1], m_u[j]); // u[j+1] = m_fopr->mult(u[j]);

    dcomplex gamma = dotc(m_u[j + 1], m_r_init);
    alpha_prev2 = rho_prev2 / conj(gamma);

    for (int i = 0; i < j + 1; ++i) {
      axpy(m_r[i], -alpha_prev2, m_u[i + 1]); // r[i] -= alpha_prev * u[i+1];
    }

    m_fopr->mult(m_r[j + 1], m_r[j]); // r[j+1] = m_fopr->mult(r[j]);

    axpy(m_x, alpha_prev2, m_u[0]); // x += alpha_prev * u[0];
  }


  //- MR part: index 0 of sigma / gamma_prime is unused (loops start at j = 1)
  std::vector<double>   sigma(m_N_L + 1);
  std::vector<dcomplex> gamma_prime(m_N_L + 1);

  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
  // tau for a pair i < j is stored at index_ij(i, j); index_ij is defined outside this view.
  std::vector<dcomplex> tau(m_N_L * (m_N_L + 1));
  int ij, ji;

  //- modified Gram-Schmidt: r[j] is orthogonalised against the already
  //- processed r[1..j-1], using the updated r[j] for each projection
  for (int j = 1; j < m_N_L + 1; ++j) {
    for (int i = 1; i < j; ++i) {
      ij = index_ij(i, j);

      dcomplex r_ji = dotc(m_r[j], m_r[i]);
      tau[ij] = conj(r_ji) / sigma[i];   // tau[ij] = (r[j] * r[i]) / sigma[i];
      axpy(m_r[j], -tau[ij], m_r[i]);    // r[j] -= tau[ij] * r[i];
    }

    sigma[j] = m_r[j].norm2(); // sigma[j] = r[j] * r[j];

    dcomplex r_0j = dotc(m_r[0], m_r[j]);
    gamma_prime[j] = conj(r_0j) / sigma[j]; // gamma_prime[j] = (r[0] * r[j]) / sigma[j];
  }


  //- back-substitution: gamma[N_L] = gamma_prime[N_L],
  //- gamma[j] = gamma_prime[j] - sum_{i=j+1}^{N_L} tau(j,i) gamma[i]  for j = N_L-1 .. 1
  std::vector<dcomplex> gamma(m_N_L + 1);
  dcomplex c_tmp;

  gamma[m_N_L] = gamma_prime[m_N_L];

  for (int j = m_N_L - 1; j > 0; --j) {
    c_tmp = cmplx(0.0, 0.0);

    for (int i = j + 1; i < m_N_L + 1; ++i) {
      ji = index_ij(j, i);
      c_tmp += tau[ji] * gamma[i];
    }

    gamma[j] = gamma_prime[j] - c_tmp;
  }


  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
  //- gamma_double_prime[j] = gamma[j+1] + sum_{i=j+1}^{N_L-1} tau(j,i) gamma[i+1]  for j = 1 .. N_L-1
  std::vector<dcomplex> gamma_double_prime(m_N_L);

  for (int j = 1; j < m_N_L; ++j) {
    c_tmp = cmplx(0.0, 0.0);

    for (int i = j + 1; i < m_N_L; ++i) {
      ji = index_ij(j, i);
      c_tmp += tau[ji] * gamma[i + 1];
    }

    gamma_double_prime[j] = gamma[j + 1] + c_tmp;
  }


  //- combined update of x, r[0], u[0]
  axpy(m_x, gamma[1], m_r[0]);                     // x += gamma[ 1] * r[ 0];
  axpy(m_r[0], -gamma_prime[m_N_L], m_r[m_N_L]);   // r[0] -= gamma_prime[m_N_L] * r[m_N_L];
  axpy(m_u[0], -gamma[m_N_L], m_u[m_N_L]);         // u[0] -= gamma[ m_N_L] * u[m_N_L];

  for (int j = 1; j < m_N_L; ++j) {
    axpy(m_x, gamma_double_prime[j], m_r[j]); // x += gamma_double_prime[j] * r[j];
    axpy(m_r[0], -gamma_prime[j], m_r[j]);    // r[0] -= gamma_prime[ j] * r[j];
    axpy(m_u[0], -gamma[j], m_u[j]);          // u[0] -= gamma[ j] * u[j];
  }

  rr = m_r[0].norm2(); // rr = r[0] * r[0];

#pragma omp barrier
#pragma omp master
  {
    m_rho_prev   = rho_prev2;
    m_alpha_prev = alpha_prev2;

    // rho_prev for the next cycle is -gamma_prime[N_L] * rho (gamma[N_L] == gamma_prime[N_L])
    m_rho_prev *= -gamma_prime[m_N_L];
  }
#pragma omp barrier
}
354 
355 
356 //====================================================================
358 {
359  int NPE = CommonParameters::NPE();
361 
362  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
363  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
364  int Nin = m_x.nin();
365  int Nvol = m_x.nvol();
366  int Nex = m_x.nex();
367 
368  double flop_fopr = m_fopr->flop_count();
369 
370  if (flop_fopr < eps) {
371  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0.0.\n", class_name.c_str());
372  return 0.0;
373  }
374 
375  double flop_axpy = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
376  double flop_dotc = static_cast<double>(Nin * Nex * 4) * (Nvol * NPE);
377  double flop_norm = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
378 
379  int N_iter = (m_Nconv_count - 1) / (2 * m_N_L);
380 
381  int N_L_part = 0;
382  for (int j = 0; j < m_N_L; ++j) {
383  for (int i = 0; i < j + 1; ++i) {
384  N_L_part += 1;
385  }
386  }
387 
388  double flop_init = flop_fopr + flop_axpy + flop_norm;
389  double flop_step_BiCG_part = 2 * m_N_L * flop_fopr
390  + 2 * m_N_L * flop_dotc
391  + (m_N_L + 2 * N_L_part) * flop_axpy;
392  double flop_step_L_part = (N_L_part + m_N_L) * flop_dotc
393  + (N_L_part + 3 * m_N_L) * flop_axpy
394  + (m_N_L + 1) * flop_norm;
395  double flop_step = flop_step_BiCG_part + flop_step_L_part;
396  double flop_true_residual = flop_fopr + flop_axpy + flop_norm;
397 
398  double flop = flop_norm + flop_init + flop_step * N_iter + flop_true_residual
399  + flop_init * m_Nrestart_count;
400 
401 
402  return flop;
403 }
404 
405 
406 //====================================================================
407 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:495
static const std::string class_name
void detailed(const char *format,...)
Definition: bridgeIO.cpp:212
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:441
void general(const char *format,...)
Definition: bridgeIO.cpp:195
void set_parameters_L(const int N_L)
BiCGStab(L) algorithm.
Container of Field-type object.
Definition: field.h:39
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:211
int nvol() const
Definition: field.h:116
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
void set_parameters(const Parameters &params)
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:230
virtual double flop_count()
returns the flops per site.
Definition: fopr.h:121
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:461
int nex() const
Definition: field.h:117
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:229
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Base class for linear solver class family.
Definition: solver.h:37
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
int non_negative(const int v)
Definition: checker.cpp:21
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Base class of fermion operator family.
Definition: fopr.h:47
string get_string(const string &key) const
Definition: parameters.cpp:116
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
static int NPE()
int size() const
Definition: field.h:121