Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_BiCGStab_Cmplx.h"
15 
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_BiCGStab_Cmplx::register_factory();
19 }
20 #endif
21 
22 const std::string Solver_BiCGStab_Cmplx::class_name = "Solver_BiCGStab_Cmplx";
23 
24 //====================================================================
26 {
27  const string str_vlevel = params.get_string("verbose_level");
28 
29  m_vl = vout.set_verbose_level(str_vlevel);
30 
31  //- fetch and check input parameters
32  int Niter, Nrestart;
33  double Stop_cond;
34  bool use_init_guess;
35  double Omega_tolerance;
36 
37  int err = 0;
38  err += params.fetch_int("maximum_number_of_iteration", Niter);
39  err += params.fetch_int("maximum_number_of_restart", Nrestart);
40  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
41  err += params.fetch_bool("use_initial_guess", use_init_guess);
42  err += params.fetch_double("Omega_tolerance", Omega_tolerance);
43 
44  if (err) {
45  vout.crucial(m_vl, "Error at %s: input parameter not found.\n",
46  class_name.c_str());
47  exit(EXIT_FAILURE);
48  }
49 
50  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
51  set_parameters_BiCGStab_series(Omega_tolerance);
52 }
53 
54 
55 //====================================================================
56 void Solver_BiCGStab_Cmplx::set_parameters(const int Niter, const int Nrestart,
57  const double Stop_cond)
58 {
60 
61  //- print input parameters
62  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
63  vout.general(m_vl, " Niter = %d\n", Niter);
64  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
65  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
66 
67  //- range check
68  int err = 0;
69  err += ParameterCheck::non_negative(Niter);
70  err += ParameterCheck::non_negative(Nrestart);
71  err += ParameterCheck::square_non_zero(Stop_cond);
72 
73  if (err) {
74  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
75  exit(EXIT_FAILURE);
76  }
77 
78  //- store values
79  m_Niter = Niter;
80  m_Nrestart = Nrestart;
81  m_Stop_cond = Stop_cond;
82 }
83 
84 
85 //====================================================================
86 void Solver_BiCGStab_Cmplx::set_parameters(const int Niter, const int Nrestart,
87  const double Stop_cond, const bool use_init_guess)
88 {
90 
91  //- print input parameters
92  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
93  vout.general(m_vl, " Niter = %d\n", Niter);
94  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
95  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
96  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
97 
98  //- range check
99  int err = 0;
100  err += ParameterCheck::non_negative(Niter);
101  err += ParameterCheck::non_negative(Nrestart);
102  err += ParameterCheck::square_non_zero(Stop_cond);
103 
104  if (err) {
105  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
106  exit(EXIT_FAILURE);
107  }
108 
109  //- store values
110  m_Niter = Niter;
111  m_Nrestart = Nrestart;
112  m_Stop_cond = Stop_cond;
113  m_use_init_guess = use_init_guess;
114 }
115 
116 
117 //====================================================================
119 {
121 
122  //- print input parameters
123  vout.general(m_vl, " Omega_tolerance = %8.2e\n", Omega_tolerance);
124 
125  //- range check
126  // NB. Omega_tolerance == 0.0 is allowed.
127 
128  //- store values
129  m_Omega_tolerance = Omega_tolerance;
130 }
131 
132 
133 //====================================================================
135  int& Nconv, double& diff)
136 {
137  const double bnorm2 = b.norm2();
138  const int bsize = b.size();
139 
140  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
141  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
142  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
143 
144  bool is_converged = false;
145  int Nconv2 = 0;
146  double diff2 = 1.0; // superficial initialization
147  double rr;
148 
149  int Nconv_unit = 1;
150  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
151  // Nconv_unit = 2;
152  // }
153 
154  reset_field(b);
155 
156  if (m_use_init_guess) {
157  copy(m_s, xq); // s = xq;
158  } else {
159  copy(m_s, b); // s = b;
160  }
161  solve_init(b, rr);
162  Nconv2 += Nconv_unit;
163 
164  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
165 
166 
167  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
168  for (int iter = 0; iter < m_Niter; iter++) {
169  if (rr / bnorm2 < m_Stop_cond) break;
170 
171  solve_step(rr);
172  Nconv2 += 2 * Nconv_unit;
173 
174  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
175  }
176 
177  //- calculate true residual
178  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
179  axpy(m_s, -1.0, b); // s -= b;
180  diff2 = m_s.norm2();
181 
182  if (diff2 / bnorm2 < m_Stop_cond) {
183  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
184  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
185 
186  is_converged = true;
187 
188  m_Nrestart_count = i_restart;
189  m_Nconv_count = Nconv2;
190 
191  break;
192  } else {
193  //- restart with new approximate solution
194  copy(m_s, m_x); // s = x;
195  solve_init(b, rr);
196 
197  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
198  }
199  }
200 
201 
202  if (!is_converged) {
203  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
204  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
205  exit(EXIT_FAILURE);
206  }
207 
208 
209  copy(xq, m_x); // xq = x;
210 
211 #pragma omp barrier
212 #pragma omp master
213  {
214  diff = sqrt(diff2 / bnorm2);
215  Nconv = Nconv2;
216  }
217 #pragma omp barrier
218 }
219 
220 
221 //====================================================================
223 {
224 #pragma omp barrier
225 #pragma omp master
226  {
227  const int Nin = b.nin();
228  const int Nvol = b.nvol();
229  const int Nex = b.nex();
230 
231  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
232  m_s.reset(Nin, Nvol, Nex);
233  m_r.reset(Nin, Nvol, Nex);
234  m_x.reset(Nin, Nvol, Nex);
235  m_p.reset(Nin, Nvol, Nex);
236  m_v.reset(Nin, Nvol, Nex);
237  m_t.reset(Nin, Nvol, Nex);
238  m_rh.reset(Nin, Nvol, Nex);
239  }
240  }
241 #pragma omp barrier
242 
243  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
244 }
245 
246 
247 //====================================================================
248 void Solver_BiCGStab_Cmplx::solve_init(const Field& b, double& rr)
249 {
250  copy(m_x, m_s); // x = s;
251 
252  // r = b - A x_0
253  m_fopr->mult(m_v, m_s); // v = m_fopr->mult(s);
254  copy(m_r, b); // r = b;
255  axpy(m_r, -1.0, m_v); // r -= v;
256  copy(m_rh, m_r); // rh = r;
257 
258  rr = m_r.norm2(); // rr = r * r;
259 
260  m_p.set(0.0); // p = 0.0
261  m_v.set(0.0); // v = 0.0
262 
263 #pragma omp barrier
264 #pragma omp master
265  {
266  m_rho_prev = cmplx(1.0, 0.0);
267  m_alpha_prev = cmplx(1.0, 0.0);
268  m_omega_prev = cmplx(1.0, 0.0);
269  }
270 #pragma omp barrier
271 }
272 
273 
274 //====================================================================
276 {
277  const dcomplex rho = dotc(m_rh, m_r); // rho = rh * r;
278  const dcomplex beta = rho * m_alpha_prev / (m_rho_prev * m_omega_prev);
279 
280  // p = r + beta * (p - m_omega_prev * v);
281  axpy(m_p, -m_omega_prev, m_v); // p += - m_omega_prev * v;
282  aypx(beta, m_p, m_r); // p = beta * p + r;
283 
284  m_fopr->mult(m_v, m_p); // v = m_fopr->mult(p);
285 
286  const dcomplex aden = dotc(m_rh, m_v); // aden = rh * v;
287  const dcomplex alpha = rho / aden;
288 
289  copy(m_s, m_r); // s = r
290  axpy(m_s, -alpha, m_v); // s += - alpha * v;
291 
292  m_fopr->mult(m_t, m_s); // t = m_fopr->mult(s);
293 
294  const dcomplex omega_numer = dotc(m_t, m_s); // omega_numer = t * s;
295  const double omega_denom = dot(m_t, m_t); // omega_denom = t * t;
296 
297  dcomplex omega = omega_numer / omega_denom;
298 
299  const double s_norm2 = m_s.norm2();
300 
301  //- a prescription to improve stability of BiCGStab
302  const double abs_rho = abs(omega_numer) / sqrt(omega_denom * s_norm2);
303  if (abs_rho < m_Omega_tolerance) {
304  omega *= m_Omega_tolerance / abs_rho;
305  }
306 
307  axpy(m_x, omega, m_s); // x += omega * s;
308  axpy(m_x, alpha, m_p); // x += alpha * p;
309 
310  copy(m_r, m_s); // r = s
311  axpy(m_r, -omega, m_t); // r += - omega * t;
312 
313  rr = m_r.norm2(); // rr = r * r;
314 
315 #pragma omp barrier
316 #pragma omp master
317  {
318  m_rho_prev = rho;
319  m_alpha_prev = alpha;
320  m_omega_prev = omega;
321  }
322 #pragma omp barrier
323 }
324 
325 
326 //====================================================================
328 {
329  const int NPE = CommonParameters::NPE();
330 
331  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
332  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
333  const int Nin = m_x.nin();
334  const int Nvol = m_x.nvol();
335  const int Nex = m_x.nex();
336 
337  const double gflop_fopr = m_fopr->flop_count();
338 
339  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
340  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
341  return 0.0;
342  }
343 
344  const double gflop_axpy = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
345  const double gflop_dotc = (Nin * Nex * 4) * ((Nvol * NPE) / 1.0e+9);
346  const double gflop_norm = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
347 
348  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
349  const double gflop_step = 2 * gflop_fopr + 3 * gflop_dotc + 6 * gflop_axpy + 2 * gflop_norm;
350  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
351 
352  const int N_iter = (m_Nconv_count - 1) / 2;
353  const double gflop = gflop_norm + gflop_init + gflop_step * N_iter + gflop_true_residual * (m_Nrestart_count + 1)
354  + gflop_init * m_Nrestart_count;
355 
356 
357  return gflop;
358 }
359 
360 
361 //====================================================================
362 //============================================================END=====
void reset_field(const Field &)
BridgeIO vout
Definition: bridgeIO.cpp:503
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:592
static const std::string class_name
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:175
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
void general(const char *format,...)
Definition: bridgeIO.cpp:197
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Container of Field-type object.
Definition: field.h:45
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
int nvol() const
Definition: field.h:127
void set_parameters(const Parameters &params)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:532
int square_non_zero(const double v)
int nin() const
Definition: field.h:126
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:155
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
virtual double flop_count()
returns the flop in giga unit
Definition: fopr.h:120
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:612
int nex() const
Definition: field.h:128
void solve_init(const Field &, double &)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:235
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:319
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
void set_parameters_BiCGStab_series(const double Omega_tolerance)
int non_negative(const int v)
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
string get_string(const string &key) const
Definition: parameters.cpp:221
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
int size() const
Definition: field.h:132