Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_BiCGStab_Cmplx.h"
15 
16 
17 #ifdef USE_FACTORY
18 namespace {
19  Solver *create_object(Fopr *fopr)
20  {
21  return new Solver_BiCGStab_Cmplx(fopr);
22  }
23 
24 
25  bool init = Solver::Factory::Register("BiCGStab_Cmplx", create_object);
26 }
27 #endif
28 
29 
30 const std::string Solver_BiCGStab_Cmplx::class_name = "Solver_BiCGStab_Cmplx";
31 
32 //====================================================================
34 {
35  const string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  int Niter, Nrestart;
41  double Stop_cond;
42 
43  int err = 0;
44  err += params.fetch_int("maximum_number_of_iteration", Niter);
45  err += params.fetch_int("maximum_number_of_restart", Nrestart);
46  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
47 
48  if (err) {
49  vout.crucial(m_vl, "Error at %s: input parameter not found.\n",
50  class_name.c_str());
51  exit(EXIT_FAILURE);
52  }
53 
54  set_parameters(Niter, Nrestart, Stop_cond);
55 }
56 
57 
58 //====================================================================
59 void Solver_BiCGStab_Cmplx::set_parameters(const int Niter, const int Nrestart,
60  const double Stop_cond)
61 {
63 
64  //- print input parameters
65  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
66  vout.general(m_vl, " Niter = %d\n", Niter);
67  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
68  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
69 
70  //- range check
71  int err = 0;
72  err += ParameterCheck::non_negative(Niter);
73  err += ParameterCheck::non_negative(Nrestart);
74  err += ParameterCheck::square_non_zero(Stop_cond);
75 
76  if (err) {
77  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
78  exit(EXIT_FAILURE);
79  }
80 
81  //- store values
82  m_Niter = Niter;
83  m_Nrestart = Nrestart;
84  m_Stop_cond = Stop_cond;
85 }
86 
87 
88 //====================================================================
90  int& Nconv, double& diff)
91 {
92  double bnorm2 = b.norm2();
93  int bsize = b.size();
94 
95  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
96  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
97  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
98 
99  bool is_converged = false;
100  int Nconv2 = 0;
101  double diff2 = 1.0;
102  double rr;
103 
104  reset_field(b);
105  copy(m_s, b); // s = b;
106  solve_init(b, rr);
107  Nconv2 += 1;
108 
109  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
110 
111 
112  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
113  for (int iter = 0; iter < m_Niter; iter++) {
114  if (rr / bnorm2 < m_Stop_cond) break;
115 
116  solve_step(rr);
117  Nconv2 += 2;
118 
119  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
120  }
121 
122  //- calculate true residual
123  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
124  axpy(m_s, -1.0, b); // s -= b;
125  diff2 = m_s.norm2();
126 
127  if (diff2 / bnorm2 < m_Stop_cond) {
128  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
129  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
130 
131  is_converged = true;
132  break;
133  } else {
134  //- restart with new approximate solution
135  copy(m_s, m_x); // s = x;
136  solve_init(b, rr);
137 
138  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
139  }
140  }
141 
142 
143  if (!is_converged) {
144  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
145  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
146  exit(EXIT_FAILURE);
147  }
148 
149 
150  copy(xq, m_x); // xq = x;
151 
152 #pragma omp barrier
153 #pragma omp master
154  {
155  diff = sqrt(diff2 / bnorm2);
156  Nconv = Nconv2;
157  }
158 #pragma omp barrier
159 }
160 
161 
162 //====================================================================
164 {
165 #pragma omp barrier
166 #pragma omp master
167  {
168  int Nin = b.nin();
169  int Nvol = b.nvol();
170  int Nex = b.nex();
171 
172  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
173  m_s.reset(Nin, Nvol, Nex);
174  m_r.reset(Nin, Nvol, Nex);
175  m_x.reset(Nin, Nvol, Nex);
176  m_p.reset(Nin, Nvol, Nex);
177  m_v.reset(Nin, Nvol, Nex);
178  m_t.reset(Nin, Nvol, Nex);
179  m_rh.reset(Nin, Nvol, Nex);
180  }
181  }
182 #pragma omp barrier
183 
184  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
185 }
186 
187 
188 //====================================================================
189 void Solver_BiCGStab_Cmplx::solve_init(const Field& b, double& rr)
190 {
191  copy(m_x, m_s); // x = s;
192 
193  // r = b - A x_0
194  m_fopr->mult(m_v, m_s); // v = m_fopr->mult(s);
195  copy(m_r, b); // r = b;
196  axpy(m_r, -1.0, m_v); // r -= v;
197  copy(m_rh, m_r); // rh = r;
198 
199  rr = m_r.norm2(); // rr = r * r;
200 
201  m_p.set(0.0); // p = 0.0
202  m_v.set(0.0); // v = 0.0
203 
204 #pragma omp barrier
205 #pragma omp master
206  {
207  m_rho_prev = cmplx(1.0, 0.0);
208  m_alpha_prev = cmplx(1.0, 0.0);
209  m_omega_prev = cmplx(1.0, 0.0);
210  }
211 #pragma omp barrier
212 }
213 
214 
215 //====================================================================
217 {
218  dcomplex rho = dotc(m_rh, m_r); // rho = rh * r;
219  dcomplex beta = rho * m_alpha_prev / (m_rho_prev * m_omega_prev);
220 
221  // p = r + beta * (p - m_omega_prev * v);
222  axpy(m_p, -m_omega_prev, m_v); // p += - m_omega_prev * v;
223  aypx(beta, m_p, m_r); // p = beta * p + r;
224 
225  m_fopr->mult(m_v, m_p); // v = m_fopr->mult(p);
226 
227  dcomplex aden = dotc(m_rh, m_v); // aden = rh * v;
228  dcomplex alpha = rho / aden;
229 
230  copy(m_s, m_r); // s = r
231  axpy(m_s, -alpha, m_v); // s += - alpha * v;
232 
233  m_fopr->mult(m_t, m_s); // t = m_fopr->mult(s);
234 
235  double omega_d = dot(m_t, m_t); // omega_d = t * t;
236  dcomplex omega_n = dotc(m_t, m_s); // omega_n = t * s;
237  dcomplex omega = omega_n / omega_d;
238 
239  axpy(m_x, omega, m_s); // x += omega * s;
240  axpy(m_x, alpha, m_p); // x += alpha * p;
241 
242  copy(m_r, m_s); // r = s
243  axpy(m_r, -omega, m_t); // r += - omega * t;
244 
245  rr = m_r.norm2(); // rr = r * r;
246 
247 #pragma omp barrier
248 #pragma omp master
249  {
250  m_rho_prev = rho;
251  m_alpha_prev = alpha;
252  m_omega_prev = omega;
253  }
254 #pragma omp barrier
255 }
256 
257 
258 //====================================================================
260 {
261  int NPE = CommonParameters::NPE();
263 
264  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
265  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
266  int Nin = m_x.nin();
267  int Nvol = m_x.nvol();
268  int Nex = m_x.nex();
269 
270  double flop_fopr = m_fopr->flop_count();
271 
272  if (flop_fopr < eps) {
273  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0.0.\n", class_name.c_str());
274  return 0.0;
275  }
276 
277  double flop_axpy = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
278  double flop_dotc = static_cast<double>(Nin * Nex * 4) * (Nvol * NPE);
279  double flop_norm = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
280 
281  int N_iter = (m_Nconv_count - 1) / 2;
282 
283  double flop_init = flop_fopr + flop_axpy + flop_norm;
284  double flop_step = 2 * flop_fopr + 3 * flop_dotc + 6 * flop_axpy + 2 * flop_norm;
285  double flop_true_residual = flop_fopr + flop_axpy + flop_norm;
286 
287  double flop = flop_norm + flop_init + flop_step * N_iter + flop_true_residual
288  + flop_init * m_Nrestart_count;
289 
290 
291  return flop;
292 }
293 
294 
295 //====================================================================
296 //============================================================END=====
void reset_field(const Field &)
BridgeIO vout
Definition: bridgeIO.cpp:495
void detailed(const char *format,...)
Definition: bridgeIO.cpp:212
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:441
static const std::string class_name
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:164
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
void general(const char *format,...)
Definition: bridgeIO.cpp:195
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Container of Field-type object.
Definition: field.h:39
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:211
int nvol() const
Definition: field.h:116
void set_parameters(const Parameters &params)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:230
virtual double flop_count()
returns the flops per site.
Definition: fopr.h:121
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:461
int nex() const
Definition: field.h:117
void solve_init(const Field &, double &)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:229
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
BiCGStab algorithm with complex variables.
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Base class for linear solver class family.
Definition: solver.h:37
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
int non_negative(const int v)
Definition: checker.cpp:21
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Base class of fermion operator family.
Definition: fopr.h:47
string get_string(const string &key) const
Definition: parameters.cpp:116
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
static int NPE()
int size() const
Definition: field.h:121