Bridge++  Ver. 1.2.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_L_Cmplx.cpp
Go to the documentation of this file.
1 
15 
16 using std::valarray;
17 
18 
#ifdef USE_FACTORY
namespace {
  //! Creator callback handed to Solver::Factory: builds a BiCGStab(L)
  //! solver bound to the supplied fermion operator.
  Solver *create_object(Fopr *fopr)
  {
    Solver *const solver = new Solver_BiCGStab_L_Cmplx(fopr);

    return solver;
  }


  //! Registers create_object under the key "BiCGStab_L_Cmplx"; this runs
  //! during dynamic initialisation of this translation unit.
  bool init = Solver::Factory::Register("BiCGStab_L_Cmplx", create_object);
}
#endif
30 
31 //- parameter entries
32 namespace {
33  void append_entry(Parameters& param)
34  {
35  param.Register_int("maximum_number_of_iteration", 0);
36  param.Register_double("convergence_criterion_squared", 0.0);
37 
38  param.Register_int("number_of_orthonormal_vectors", 0);
39 
40  param.Register_string("verbose_level", "NULL");
41  }
42 
43 
44 #ifdef USE_PARAMETERS_FACTORY
45  bool init_param = ParametersFactory::Register("Solver.BiCGStab_L_Cmplx", append_entry);
46 #endif
47 }
48 //- end
49 
50 //- parameters class
52 //- end
53 
54 const std::string Solver_BiCGStab_L_Cmplx::class_name = "Solver_BiCGStab_L_Cmplx";
55 
56 //====================================================================
58 {
  // Configure the solver from a Parameters object: set the verbosity first,
  // fetch the three numeric settings, then delegate range checking and storage
  // to set_parameters(Niter, Stop_cond) and set_parameters_L(N_L).
  // NOTE(review): the signature line (file line 57) is not shown in this listing.
59  const string str_vlevel = params.get_string("verbose_level");
60 
61  m_vl = vout.set_verbose_level(str_vlevel);
62 
63  //- fetch and check input parameters
64  int Niter;
65  double Stop_cond;
66  int N_L;
67 
68  int err = 0;
  // Return codes are summed so every key is attempted; any nonzero total is
  // reported below as "input parameter not found" and aborts.
69  err += params.fetch_int("maximum_number_of_iteration", Niter);
70  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
71  err += params.fetch_int("number_of_orthonormal_vectors", N_L);
72 
73  if (err) {
74  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
75  abort();
76  }
77 
78 
  // Range checks and storage happen in the typed setters.
79  set_parameters(Niter, Stop_cond);
80  set_parameters_L(N_L);
81 }
82 
83 
84 //====================================================================
// Range-check and store the iteration limit and the stopping criterion.
// Niter must pass ParameterCheck::non_negative; Stop_cond is checked with
// ParameterCheck::square_non_zero. On failure a crucial message is printed and
// the process aborts. solve() compares Stop_cond against |r|^2 / |b|^2, so it is
// a squared relative tolerance (hence the key "convergence_criterion_squared").
85 void Solver_BiCGStab_L_Cmplx::set_parameters(const int Niter, const double Stop_cond)
86 {
  // NOTE(review): file line 87 is omitted from this listing.
88 
89  //- print input parameters
90  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
91  vout.general(m_vl, " Niter = %d\n", Niter);
92  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
93 
94  //- range check
95  int err = 0;
96  err += ParameterCheck::non_negative(Niter);
97  err += ParameterCheck::square_non_zero(Stop_cond);
98 
99  if (err) {
100  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
101  abort();
102  }
103 
104  //- store values
105  m_Niter = Niter;
106  m_Stop_cond = Stop_cond;
107 }
108 
109 
110 //====================================================================
112 {
  // Range-check and store N_L, the "L" of BiCGStab(L). solve_step runs m_N_L
  // BiCG sweeps, followed by a least-squares update over r[1..m_N_L].
  // NOTE(review): the signature line (file line 111) is not shown in this listing.
  // NOTE(review): non_negative() accepts N_L == 0. In that case:
  //   - solve_step allocates gamma with m_N_L + 1 == 1 element and then reads
  //     gamma[1], which is out of range;
  //   - solve advances Nconv2 by 2 * m_N_L == 0 per iteration.
  // Requiring N_L >= 1 looks necessary; confirm and tighten the check.
113  //- print input parameters
114  vout.general(m_vl, " N_L = %d\n", N_L);
115 
116  //- range check
117  int err = 0;
118  err += ParameterCheck::non_negative(N_L);
119 
120  if (err) {
121  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
122  abort();
123  }
124 
125  //- store values
126  m_N_L = N_L;
127 }
128 
129 
130 //====================================================================
132  int& Nconv, double& diff)
133 {
  // Solve A xq = b with restarted BiCGStab(L), starting from the guess s = b.
  // On return:
  //   - xq holds the solution;
  //   - Nconv holds the accumulated count (2 * m_N_L per solve_step call);
  //   - diff holds the absolute true residual |A xq - b|.
  // Aborts if the final squared relative residual exceeds m_Stop_cond.
  // NOTE(review): the first signature line (file line 131) and lines 141-142,
  // 184, 186 and 198 are missing from this listing. `ith` and `nth` used below
  // are presumably declared on the missing lines.
134  //#pragma omp parallel
135  {
136  double bnorm2 = b.norm2();
137  double snorm = 1.0 / bnorm2;
  // NOTE(review): snorm = 1/|b|^2 is not guarded. A zero source b makes it
  // infinite, and the convergence tests below then evaluate to NaN.
138  int bsize = b.size();
139  double rr;
140 
143 
144 #pragma omp master
145  {
146  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
147  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
148  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
149  }
150 #pragma omp barrier
151 
152 
153  reset_field(b);
154 
155 
156  // Nconv = -1;
157  int Nconv2 = 0;
158  copy(s, b); // s = b;
159 
160  solve_init(b, rr);
161 
162  bool is_converged = false;
163 
164 #pragma omp master
165  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
166 #pragma omp barrier
167 
168 
  // The loop always runs m_Niter times. Once converged, the body is skipped
  // instead of break-ing (see the commented-out break below), presumably so
  // every thread executes the same barriers. TODO confirm.
169  for (int iter = 0; iter < m_Niter; iter++) {
170  if (!is_converged) {
171  solve_step(rr);
172 
  // Each solve_step makes 2 * m_N_L m_fopr->mult calls. The extra mults
  // done on restart are not counted.
173  Nconv2 += 2 * m_N_L;
174 
175 #pragma omp master
176  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr * snorm);
177 #pragma omp barrier
178 
  // The recursively updated residual looks converged: verify with the true
  // residual s = A x - b. Then restart from x (s = x) in either case.
179  if (rr * snorm < m_Stop_cond) {
180  m_fopr->mult(s, x); // s = m_fopr->mult(x);
181  axpy(s, -1.0, b); // s -= b;
182 
183  double diff2 = s.norm2();
185 
  // NOTE(review): `nth` fills the %8d slot that the "iter:" messages use for
  // Nconv2. This looks like a leftover debug print; confirm the intended value
  // (the same applies to "iter1" below).
187  if (ith == 0) vout.detailed(m_vl, " iter0: %8d %22.15e\n", nth, diff2 * snorm);
188 
189  if (diff2 * snorm < m_Stop_cond) {
190  // NB. Nconv is calculated above.
191  // break;
192  is_converged = true;
193  }
194 
195  copy(s, x); // s = x;
196  solve_init(b, rr);
197 
199  if (ith == 0) vout.detailed(m_vl, " iter1: %8d %22.15e\n", nth, rr * snorm);
200  }
201  }
202  }
203 
204 
  // Final true residual s = A x - b; copy the solution out to xq.
205  m_fopr->mult(s, x); // s = m_fopr->mult(x);
206  axpy(s, -1.0, b); // s -= b;
207 
208  copy(xq, x); // xq = x;
209 
210  double diff2 = s.norm2();
211 
212 #pragma omp master
213  {
214  diff = sqrt(diff2);
215  Nconv = Nconv2;
216  }
217 #pragma omp barrier
218 
  // Not converged within m_Niter iterations: report and abort.
219  if (diff2 * snorm > m_Stop_cond) {
220 #pragma omp master
221  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
222 #pragma omp barrier
223  abort();
224  }
225  } // end of parallel region
226 }
227 
228 
229 //====================================================================
231 {
  // (Re)allocate the work fields to match b's layout (Nin, Nvol, Nex). Runs on
  // the master thread only, followed by a barrier.
  //   - s, x, r_init and v_tmp are reset only when s's shape differs from b's.
  //   - The u and r arrays are resized to m_N_L + 1, and every element is reset
  //     on each call, whether or not the shape changed.
  // NOTE(review): the signature line (file line 230) is not shown in this
  // listing; solve() calls it as reset_field(b).
232 #pragma omp master
233  {
234  int Nin = b.nin();
235  int Nvol = b.nvol();
236  int Nex = b.nex();
237 
238  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
239  s.reset(Nin, Nvol, Nex);
240  x.reset(Nin, Nvol, Nex);
241 
242  r_init.reset(Nin, Nvol, Nex);
243 
244  v_tmp.reset(Nin, Nvol, Nex);
245 
246  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
247  }
248 
  // u[0..m_N_L] and r[0..m_N_L] are the BiCGStab(L) direction/residual bases.
249  u.resize(m_N_L + 1);
250  r.resize(m_N_L + 1);
251 
252  for (int i = 0; i < m_N_L + 1; ++i) {
253  u[i].reset(Nin, Nvol, Nex);
254  r[i].reset(Nin, Nvol, Nex);
255  }
256  }
257 #pragma omp barrier
258 }
259 
260 
261 //====================================================================
262 void Solver_BiCGStab_L_Cmplx::solve_init(const Field& b, double& rr)
263 {
264  copy(x, s); // x = s;
265 
266  //- r[0] = b - A x_0
267  m_fopr->mult(v_tmp, s); // v_tmp = m_fopr->mult(s);
268  copy(r[0], b); // r[0] = b;
269  axpy(r[0], -1.0, v_tmp); // r[0] -= v_tmp;
270 
271  copy(r_init, r[0]); // r_init = r[0];
272  rr = r[0].norm2(); // rr = r[0] * r[0];
273 
274  u[0].set(0.0); // u[0] = 0.0;
275 
276  // NB. alpha_prev = 0.0 \neq 1.0
277 #pragma omp master
278  {
279  rho_prev = cmplx(-1.0, 0.0);
280  alpha_prev = cmplx(0.0, 0.0);
281  }
282 #pragma omp barrier
283 }
284 
285 
286 //====================================================================
288 {
  // One BiCGStab(L) cycle:
  //   1. m_N_L BiCG steps (the j-loop), building u[0..m_N_L] and r[0..m_N_L].
  //   2. A least-squares (minimal-residual) update of r[0] over r[1..m_N_L],
  //      orthogonalised by modified Gram-Schmidt.
  // On return, rr holds |r[0]|^2, and the master thread has stored rho_prev and
  // alpha_prev for the next cycle. sigma, gamma_prime, gamma and
  // gamma_double_prime are indexed from 1 (element 0 is unused).
  // NOTE(review): the signature line (file line 287) is not shown in this listing.
  // Work on local copies; the members are written back only by the master
  // thread at the end.
289  dcomplex rho_prev2 = rho_prev;
290  dcomplex alpha_prev2 = alpha_prev;
291 
  // --- BiCG part ---
292  for (int j = 0; j < m_N_L; ++j) {
293  dcomplex rho = dotc(r[j], r_init); // dcomplex rho = r[j] * r_init;
294  rho = conj(rho);
295 
296  dcomplex beta = alpha_prev2 * (rho / rho_prev2);
297 
298  rho_prev2 = rho;
299 
  // u[i] = r[i] - beta * u[i], i = 0..j
300  for (int i = 0; i < j + 1; ++i) {
301  aypx(-beta, u[i], r[i]); // u[i] = - beta * u[i] + r[i];
302  }
303 
304  m_fopr->mult(u[j + 1], u[j]); // u[j+1] = m_fopr->mult(u[j]);
305 
306  dcomplex gamma = dotc(u[j + 1], r_init);
307  alpha_prev2 = rho_prev2 / conj(gamma);
308 
  // r[i] -= alpha * u[i+1], i = 0..j
309  for (int i = 0; i < j + 1; ++i) {
310  axpy(r[i], -alpha_prev2, u[i + 1]); // r[i] -= alpha_prev * u[i+1];
311  }
312 
313  m_fopr->mult(r[j + 1], r[j]); // r[j+1] = m_fopr->mult(r[j]);
314 
315  axpy(x, alpha_prev2, u[0]); // x += alpha_prev * u[0];
316  }
317 
318 
  // --- Polynomial (MR) part ---
  // Orthogonalise r[1..L] by modified Gram-Schmidt, recording:
  //   - the coefficients tau(i,j), i < j;
  //   - the squared norms sigma[j];
  //   - gamma_prime[j], the projection coefficient of r[0] on the
  //     orthogonalised r[j].
319  valarray<double> sigma(m_N_L + 1);
320  valarray<dcomplex> gamma_prime(m_N_L + 1);
321 
322  // NB. tau(m_N_L,m_N_L+1), not (m_N_L+1,m_N_L+1)
323  valarray<dcomplex> tau(m_N_L * (m_N_L + 1));
324  int ij, ji;
325 
326  for (int j = 1; j < m_N_L + 1; ++j) {
327  for (int i = 1; i < j; ++i) {
328  ij = index_ij(i, j);
329 
330  dcomplex r_ji = dotc(r[j], r[i]);
331  tau[ij] = conj(r_ji) / sigma[i]; // tau[ij] = (r[j] * r[i]) / sigma[i];
332  axpy(r[j], -tau[ij], r[i]); // r[j] -= tau[ij] * r[i];
333  }
334 
335  sigma[j] = r[j].norm2(); // sigma[j] = r[j] * r[j];
336 
337  dcomplex r_0j = dotc(r[0], r[j]);
338  gamma_prime[j] = conj(r_0j) / sigma[j]; // gamma_prime[j] = (r[0] * r[j]) / sigma[j];
339  }
340 
341 
  // Back substitution, for j = L-1 down to 1:
  //   gamma[j] = gamma_prime[j] - sum_{i=j+1..L} tau(j,i) * gamma[i]
342  valarray<dcomplex> gamma(m_N_L + 1);
343  dcomplex c_tmp;
344 
345  gamma[m_N_L] = gamma_prime[m_N_L];
346 
347  for (int j = m_N_L - 1; j > 0; --j) {
348  c_tmp = cmplx(0.0, 0.0);
349 
350  for (int i = j + 1; i < m_N_L + 1; ++i) {
351  ji = index_ij(j, i);
352  c_tmp += tau[ji] * gamma[i];
353  }
354 
355  gamma[j] = gamma_prime[j] - c_tmp;
356  }
357 
358 
  // For j = 1..L-1:
  //   gamma_double_prime[j] = gamma[j+1] + sum_{i=j+1..L-1} tau(j,i) * gamma[i+1]
359  // NB. gamma_double_prime(m_N_L), not (m_N_L+1)
360  valarray<dcomplex> gamma_double_prime(m_N_L);
361 
362  for (int j = 1; j < m_N_L; ++j) {
363  c_tmp = cmplx(0.0, 0.0);
364 
365  for (int i = j + 1; i < m_N_L; ++i) {
366  ji = index_ij(j, i);
367  c_tmp += tau[ji] * gamma[i + 1];
368  }
369 
370  gamma_double_prime[j] = gamma[j + 1] + c_tmp;
371  }
372 
373 
  // Update solution, residual and search direction with the polynomial coefficients.
  // NOTE(review): with m_N_L == 0, gamma has a single element and gamma[1] is out of range.
374  axpy(x, gamma[1], r[0]); // x += gamma[ 1] * r[ 0];
375  axpy(r[0], -gamma_prime[m_N_L], r[m_N_L]); // r[0] -= gamma_prime[m_N_L] * r[m_N_L];
376  axpy(u[0], -gamma[m_N_L], u[m_N_L]); // u[0] -= gamma[ m_N_L] * u[m_N_L];
377 
378  for (int j = 1; j < m_N_L; ++j) {
379  axpy(x, gamma_double_prime[j], r[j]); // x += gamma_double_prime[j] * r[j];
380  axpy(r[0], -gamma_prime[j], r[j]); // r[0] -= gamma_prime[ j] * r[j];
381  axpy(u[0], -gamma[j], u[j]); // u[0] -= gamma[ j] * u[j];
382  }
383 
384  rr = r[0].norm2(); // rr = r[0] * r[0];
385 
  // Store the state for the next cycle. rho_prev is scaled by
  // -gamma_prime[m_N_L], which equals gamma[m_N_L] (assigned above).
386 #pragma omp master
387  {
388  rho_prev = rho_prev2;
389  alpha_prev = alpha_prev2;
390  rho_prev *= -gamma_prime[m_N_L];
391  }
392 #pragma omp barrier
393 }
394 
395 
396 //====================================================================
397 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:207
static const std::string class_name
void detailed(const char *format,...)
Definition: bridgeIO.cpp:50
static int get_num_threads()
returns the number of available threads.
void Register_string(const string &, const string &)
Definition: parameters.cpp:352
double norm2() const
Definition: field.cpp:469
virtual const Field mult(const Field &)=0
applies the fermion operator to a given field and returns the resulting field.
void general(const char *format,...)
Definition: bridgeIO.cpp:38
void Register_int(const string &, const int)
Definition: parameters.cpp:331
void set_parameters_L(const int N_L)
BiCGStab(L) algorithm.
Container of Field-type object.
Definition: field.h:37
std::valarray< Field > u
int nvol() const
Definition: field.h:101
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:40
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:409
static int get_thread_id()
returns thread id.
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:100
void set_parameters(const Parameters &params)
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:98
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:82
static void sync_barrier_all()
barrier among all the threads and nodes.
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:489
int nex() const
Definition: field.h:102
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:62
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:193
void crucial(const char *format,...)
Definition: bridgeIO.cpp:26
Base class for linear solver class family.
Definition: solver.h:37
static bool Register(const std::string &realm, const creator_callback &cb)
int non_negative(const int v)
Definition: checker.cpp:21
void Register_double(const string &, const double)
Definition: parameters.cpp:324
std::valarray< Field > r
Base class of fermion operator family.
Definition: fopr.h:39
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:85
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:56
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:191
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
static void assert_single_thread(const std::string &classname)
asserts that the code is currently running on a single thread.
int size() const
Definition: field.h:106