Bridge++  Ver. 1.3.x
solver_BiCGStab_L_Cmplx.cpp
Go to the documentation of this file.
1 
15 
16 
17 
#ifdef USE_FACTORY
namespace {
  //! Factory callback: builds a Solver_BiCGStab_L_Cmplx around @p fopr.
  //! The raw pointer is handed to the caller (ownership presumably passes
  //! to the Solver factory's client -- confirm against Solver::Factory).
  Solver *create_object(Fopr *fopr)
  {
    Solver *solver = new Solver_BiCGStab_L_Cmplx(fopr);

    return solver;
  }


  //! Registers create_object with Solver::Factory under "BiCGStab_L_Cmplx";
  //! runs during dynamic initialization of this namespace-scope variable.
  bool init = Solver::Factory::Register("BiCGStab_L_Cmplx", create_object);
}
#endif
29 
30 //- parameter entries
31 namespace {
32  void append_entry(Parameters& param)
33  {
34  param.Register_int("maximum_number_of_iteration", 0);
35  param.Register_double("convergence_criterion_squared", 0.0);
36 
37  param.Register_int("number_of_orthonormal_vectors", 0);
38 
39  param.Register_string("verbose_level", "NULL");
40  }
41 
42 
43 #ifdef USE_PARAMETERS_FACTORY
44  bool init_param = ParametersFactory::Register("Solver.BiCGStab_L_Cmplx", append_entry);
45 #endif
46 }
47 //- end
48 
49 //- parameters class
51 //- end
52 
53 const std::string Solver_BiCGStab_L_Cmplx::class_name = "Solver_BiCGStab_L_Cmplx";
54 
55 //====================================================================
57 {
58  const string str_vlevel = params.get_string("verbose_level");
59 
60  m_vl = vout.set_verbose_level(str_vlevel);
61 
62  //- fetch and check input parameters
63  int Niter;
64  double Stop_cond;
65  int N_L;
66 
67  int err = 0;
68  err += params.fetch_int("maximum_number_of_iteration", Niter);
69  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
70  err += params.fetch_int("number_of_orthonormal_vectors", N_L);
71 
72  if (err) {
73  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
74  exit(EXIT_FAILURE);
75  }
76 
77 
78  set_parameters(Niter, Stop_cond);
79  set_parameters_L(N_L);
80 }
81 
82 
83 //====================================================================
84 void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const double Stop_cond)
85 {
87 
88  //- print input parameters
89  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
90  vout.general(m_vl, " Niter = %d\n", Niter);
91  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
92 
93  //- range check
94  int err = 0;
95  err += ParameterCheck::non_negative(Niter);
96  err += ParameterCheck::square_non_zero(Stop_cond);
97 
98  if (err) {
99  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
100  exit(EXIT_FAILURE);
101  }
102 
103  //- store values
104  m_Niter = Niter;
105  m_Stop_cond = Stop_cond;
106 }
107 
108 
109 //====================================================================
111 {
112  //- print input parameters
113  vout.general(m_vl, " N_L = %d\n", N_L);
114 
115  //- range check
116  int err = 0;
117  err += ParameterCheck::non_negative(N_L);
118 
119  if (err) {
120  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
121  exit(EXIT_FAILURE);
122  }
123 
124  //- store values
125  m_N_L = N_L;
126 }
127 
128 
129 //====================================================================
131  int& Nconv, double& diff)
132 {
133  double bnorm2 = b.norm2();
134  double snorm = 1.0 / bnorm2;
135  int bsize = b.size();
136  double rr;
137 
138  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
139  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
140  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
141 
142  reset_field(b);
143 
144 
145  // Nconv = -1;
146  int Nconv2 = 0;
147  copy(s, b); // s = b;
148 
149  solve_init(b, rr);
150 
151  bool is_converged = false;
152 
153  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
154 
155 
156  for (int iter = 0; iter < m_Niter; iter++) {
157  if (is_converged) break;
158 
159  solve_step(rr);
160 
161  Nconv2 += 2 * m_N_L;
162 
163  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr * snorm);
164 
165  if (rr * snorm < m_Stop_cond) {
166  m_fopr->mult(s, x); // s = m_fopr->mult(x);
167  axpy(s, -1.0, b); // s -= b;
168 
169  double diff2 = s.norm2();
170 
171  if (diff2 * snorm < m_Stop_cond) {
172  // NB. Nconv is calculated above.
173  is_converged = true;
174  } else {
175  copy(s, x); // s = x;
176  solve_init(b, rr);
177  }
178  }
179  }
180 
181 
182  m_fopr->mult(s, x); // s = m_fopr->mult(x);
183  axpy(s, -1.0, b); // s -= b;
184 
185  copy(xq, x); // xq = x;
186 
187  double diff2 = s.norm2();
188 
189  if (diff2 * snorm > m_Stop_cond) {
190  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
191  exit(EXIT_FAILURE);
192  }
193 
194 
195 #pragma omp barrier
196 #pragma omp master
197  {
198  diff = sqrt(diff2);
199  Nconv = Nconv2;
200  }
201 #pragma omp barrier
202 }
203 
204 
205 //====================================================================
207 {
208 #pragma omp barrier
209 #pragma omp master
210  {
211  int Nin = b.nin();
212  int Nvol = b.nvol();
213  int Nex = b.nex();
214 
215  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
216  s.reset(Nin, Nvol, Nex);
217  x.reset(Nin, Nvol, Nex);
218  r_init.reset(Nin, Nvol, Nex);
219  v_tmp.reset(Nin, Nvol, Nex);
220  }
221 
222  u.resize(m_N_L + 1);
223  r.resize(m_N_L + 1);
224 
225  for (int i = 0; i < m_N_L + 1; ++i) {
226  u[i].reset(Nin, Nvol, Nex);
227  r[i].reset(Nin, Nvol, Nex);
228  }
229  }
230 #pragma omp barrier
231 
232  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
233 }
234 
235 
236 //====================================================================
237 void Solver_BiCGStab_L_Cmplx::solve_init(const Field& b, double& rr)
238 {
239  copy(x, s); // x = s;
240 
241  //- r[0] = b - A x_0
242  m_fopr->mult(v_tmp, s); // v_tmp = m_fopr->mult(s);
243  copy(r[0], b); // r[0] = b;
244  axpy(r[0], -1.0, v_tmp); // r[0] -= v_tmp;
245 
246  copy(r_init, r[0]); // r_init = r[0];
247  rr = r[0].norm2(); // rr = r[0] * r[0];
248 
249  u[0].set(0.0); // u[0] = 0.0;
250 
251 #pragma omp barrier
252 #pragma omp master
253  {
254  rho_prev = cmplx(-1.0, 0.0);
255 
256  // NB. alpha_prev = 0.0 \neq 1.0
257  alpha_prev = cmplx(0.0, 0.0);
258  }
259 #pragma omp barrier
260 }
261 
262 
263 //====================================================================
// One outer cycle of BiCGStab(L); on return rr = r[0].norm2().
//
// BiCG part: m_N_L sweeps. Sweep j updates the directions u[0..j], applies
//   the operator twice (u[j+1] = A u[j], r[j+1] = A r[j]), updates
//   r[0..j] and advances x by alpha * u[0].
// MR part: orthogonalizes r[1..m_N_L] by modified Gram-Schmidt, solves the
//   small triangular system for gamma, and applies the resulting polynomial
//   update to x, r[0] and u[0].
//
// Reads r_init and m_fopr; updates x, r[], u[]. rho_prev/alpha_prev are
// copied to locals here and written back by the master thread at the end.
//
// NOTE(review): the function header (listing line 264) is missing here. From
// the call solve_step(rr) in solve() it is presumably
//   void Solver_BiCGStab_L_Cmplx::solve_step(double& rr)
// -- confirm against the class header.
// NOTE(review): there is no breakdown guard. A zero rho_prev2, gamma or
// sigma[i] yields non-finite values that propagate into x.
265 {
266  dcomplex rho_prev2 = rho_prev;
267  dcomplex alpha_prev2 = alpha_prev;
268 
 // ---- BiCG part ----
269  for (int j = 0; j < m_N_L; ++j) {
270  dcomplex rho = dotc(r[j], r_init); // dcomplex rho = r[j] * r_init;
271  rho = conj(rho);
272 
 // rho_prev2 already carries the -gamma_prime[m_N_L] factor folded in at
 // the end of the previous cycle (see the master block below).
273  dcomplex beta = alpha_prev2 * (rho / rho_prev2);
274 
275  rho_prev2 = rho;
276 
277  for (int i = 0; i < j + 1; ++i) {
278  aypx(-beta, u[i], r[i]); // u[i] = - beta * u[i] + r[i];
279  }
280 
281  m_fopr->mult(u[j + 1], u[j]); // u[j+1] = m_fopr->mult(u[j]);
282 
283  dcomplex gamma = dotc(u[j + 1], r_init);
284  alpha_prev2 = rho_prev2 / conj(gamma);
285 
286  for (int i = 0; i < j + 1; ++i) {
287  axpy(r[i], -alpha_prev2, u[i + 1]); // r[i] -= alpha_prev * u[i+1];
288  }
289 
290  m_fopr->mult(r[j + 1], r[j]); // r[j+1] = m_fopr->mult(r[j]);
291 
292  axpy(x, alpha_prev2, u[0]); // x += alpha_prev * u[0];
293  }
294 
295 
 // ---- MR part ----
 // Index 0 of sigma and gamma_prime is never written; only 1..m_N_L are used.
296  std::vector<double> sigma(m_N_L + 1);
297  std::vector<dcomplex> gamma_prime(m_N_L + 1);
298 
 // tau is addressed through index_ij(i, j), which is defined elsewhere. Only
 // entries with 1 <= i < j <= m_N_L are written.
299  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
300  std::vector<dcomplex> tau(m_N_L * (m_N_L + 1));
301  int ij, ji;
302 
 // Modified Gram-Schmidt on r[1..m_N_L]. sigma[j] is r[j].norm2() after
 // orthogonalization. Assuming dotc(y, x) conjugates y, gamma_prime[j] is the
 // coefficient of r[0] projected onto r[j].
303  for (int j = 1; j < m_N_L + 1; ++j) {
304  for (int i = 1; i < j; ++i) {
305  ij = index_ij(i, j);
306 
307  dcomplex r_ji = dotc(r[j], r[i]);
308  tau[ij] = conj(r_ji) / sigma[i]; // tau[ij] = (r[j] * r[i]) / sigma[i];
309  axpy(r[j], -tau[ij], r[i]); // r[j] -= tau[ij] * r[i];
310  }
311 
312  sigma[j] = r[j].norm2(); // sigma[j] = r[j] * r[j];
313 
314  dcomplex r_0j = dotc(r[0], r[j]);
315  gamma_prime[j] = conj(r_0j) / sigma[j]; // gamma_prime[j] = (r[0] * r[j]) / sigma[j];
316  }
317 
318 
 // Back substitution with the upper-triangular tau:
 //   gamma[m_N_L] = gamma_prime[m_N_L],
 //   gamma[j]     = gamma_prime[j] - sum_{i=j+1..m_N_L} tau(j,i) * gamma[i].
319  std::vector<dcomplex> gamma(m_N_L + 1);
320  dcomplex c_tmp;
321 
322  gamma[m_N_L] = gamma_prime[m_N_L];
323 
324  for (int j = m_N_L - 1; j > 0; --j) {
325  c_tmp = cmplx(0.0, 0.0);
326 
327  for (int i = j + 1; i < m_N_L + 1; ++i) {
328  ji = index_ij(j, i);
329  c_tmp += tau[ji] * gamma[i];
330  }
331 
332  gamma[j] = gamma_prime[j] - c_tmp;
333  }
334 
335 
 // For j = 1..m_N_L-1:
 //   gamma_double_prime[j] = gamma[j+1] + sum_{i=j+1..m_N_L-1} tau(j,i) * gamma[i+1].
 // Index 0 is unused.
336  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
337  std::vector<dcomplex> gamma_double_prime(m_N_L);
338 
339  for (int j = 1; j < m_N_L; ++j) {
340  c_tmp = cmplx(0.0, 0.0);
341 
342  for (int i = j + 1; i < m_N_L; ++i) {
343  ji = index_ij(j, i);
344  c_tmp += tau[ji] * gamma[i + 1];
345  }
346 
347  gamma_double_prime[j] = gamma[j + 1] + c_tmp;
348  }
349 
350 
 // Polynomial update of x, r[0] and u[0]. NB. gamma[1] needs m_N_L >= 1,
 // because gamma has only m_N_L + 1 entries.
351  axpy(x, gamma[1], r[0]); // x += gamma[ 1] * r[ 0];
352  axpy(r[0], -gamma_prime[m_N_L], r[m_N_L]); // r[0] -= gamma_prime[m_N_L] * r[m_N_L];
353  axpy(u[0], -gamma[m_N_L], u[m_N_L]); // u[0] -= gamma[ m_N_L] * u[m_N_L];
354 
355  for (int j = 1; j < m_N_L; ++j) {
356  axpy(x, gamma_double_prime[j], r[j]); // x += gamma_double_prime[j] * r[j];
357  axpy(r[0], -gamma_prime[j], r[j]); // r[0] -= gamma_prime[ j] * r[j];
358  axpy(u[0], -gamma[j], u[j]); // u[0] -= gamma[ j] * u[j];
359  }
360 
361  rr = r[0].norm2(); // rr = r[0] * r[0];
362 
 // Publish the scalars for the next cycle. Multiplying rho_prev by
 // -gamma_prime[m_N_L] (equal to -gamma[m_N_L]) folds that factor into the
 // denominator used for beta.
363 #pragma omp barrier
364 #pragma omp master
365  {
366  rho_prev = rho_prev2;
367  alpha_prev = alpha_prev2;
368  rho_prev *= -gamma_prime[m_N_L];
369  }
370 #pragma omp barrier
371 }
372 
373 
374 //====================================================================
375 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:278
static const std::string class_name
void detailed(const char *format,...)
Definition: bridgeIO.cpp:82
void Register_string(const string &, const string &)
Definition: parameters.cpp:351
double norm2() const
Definition: field.cpp:441
void general(const char *format,...)
Definition: bridgeIO.cpp:65
void Register_int(const string &, const int)
Definition: parameters.cpp:330
void set_parameters_L(const int N_L)
BiCGStab(L) algorithm.
Container of Field-type object.
Definition: field.h:39
int nvol() const
Definition: field.h:116
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:38
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
void set_parameters(const Parameters &params)
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:461
int nex() const
Definition: field.h:117
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:99
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:48
Base class for linear solver class family.
Definition: solver.h:38
static bool Register(const std::string &realm, const creator_callback &cb)
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
int non_negative(const int v)
Definition: checker.cpp:21
void Register_double(const string &, const double)
Definition: parameters.cpp:323
Base class of fermion operator family.
Definition: fopr.h:49
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:87
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:62
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:28
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
static void assert_single_thread(const std::string &classname)
assert currently running on single thread.
int size() const
Definition: field.h:121