Bridge++  Ver. 1.3.x
solver_BiCGStab.cpp
Go to the documentation of this file.
1 
14 #include "solver_BiCGStab.h"
15 
16 
17 #ifdef USE_FACTORY
18 namespace {
19  Solver *create_object(Fopr *fopr)
20  {
21  return new Solver_BiCGStab(fopr);
22  }
23 
24 
25  bool init = Solver::Factory::Register("BiCGStab", create_object);
26 }
27 #endif
28 
29 //- parameter entries
30 namespace {
31  void append_entry(Parameters& param)
32  {
33  param.Register_int("maximum_number_of_iteration", 0);
34  param.Register_double("convergence_criterion_squared", 0.0);
35 
36  param.Register_string("verbose_level", "NULL");
37  }
38 
39 
40 #ifdef USE_PARAMETERS_FACTORY
41  bool init_param = ParametersFactory::Register("Solver.BiCGStab", append_entry);
42 #endif
43 }
44 //- end
45 
46 //- parameters class
48 //- end
49 
50 const std::string Solver_BiCGStab::class_name = "Solver_BiCGStab";
51 
52 //====================================================================
54 {
55  const string str_vlevel = params.get_string("verbose_level");
56 
57  m_vl = vout.set_verbose_level(str_vlevel);
58 
59  //- fetch and check input parameters
60  int Niter;
61  double Stop_cond;
62 
63  int err = 0;
64  err += params.fetch_int("maximum_number_of_iteration", Niter);
65  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
66 
67  if (err) {
68  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
69  exit(EXIT_FAILURE);
70  }
71 
72 
73  set_parameters(Niter, Stop_cond);
74 }
75 
76 
77 //====================================================================
78 void Solver_BiCGStab::set_parameters(const int Niter, const double Stop_cond)
79 {
81 
82  //- print input parameters
83  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
84  vout.general(m_vl, " Niter = %d\n", Niter);
85  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
86 
87  //- range check
88  int err = 0;
89  err += ParameterCheck::non_negative(Niter);
90  err += ParameterCheck::square_non_zero(Stop_cond);
91 
92  if (err) {
93  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
94  exit(EXIT_FAILURE);
95  }
96 
97  //- store values
98  m_Niter = Niter;
99  m_Stop_cond = Stop_cond;
100 }
101 
102 
103 //====================================================================
104 void Solver_BiCGStab::solve(Field& xq, const Field& b,
105  int& Nconv, double& diff)
106 {
107  double bnorm2 = b.norm2();
108  double snorm = 1.0 / bnorm2;
109  int bsize = b.size();
110 
111  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
112  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
113  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
114 
115  reset_field(b);
116 
117  copy(s, b); // s = b;
118 
119  double rr;
120  int Nconv2 = -1;
121 
122  solve_init(b, rr);
123 
124  bool is_converged = false;
125 
126  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
127 
128 
129  for (int iter = 0; iter < m_Niter; iter++) {
130  if (is_converged) break;
131 
132  solve_step(rr);
133 
134  vout.detailed(m_vl, " iter: %8d %22.15e\n", 2 * (iter + 1), rr * snorm);
135 
136  if (rr * snorm < m_Stop_cond) {
137  m_fopr->mult(s, x); // s = m_fopr->mult(x);
138  axpy(s, -1.0, b); // s -= b;
139 
140  double diff2 = s.norm2();
141 
142  if (diff2 * snorm < m_Stop_cond) {
143  Nconv2 = 2 * (iter + 1);
144  is_converged = true;
145  } else {
146  copy(s, x); // s = x;
147  solve_init(b, rr);
148  }
149  }
150  }
151 
152 
153  m_fopr->mult(p, x); // p = m_fopr->mult(x);
154  axpy(p, -1.0, b); // p -= b;
155 
156  copy(xq, x); // xq = x;
157 
158  double diff2 = p.norm2();
159 
160  if (diff2 * snorm > m_Stop_cond) {
161  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
162  exit(EXIT_FAILURE);
163  }
164 
165 
166 #pragma omp barrier
167 #pragma omp master
168  {
169  diff = sqrt(diff2);
170  Nconv = Nconv2;
171  }
172 #pragma omp barrier
173 }
174 
175 
176 //====================================================================
178 {
179 #pragma omp barrier
180 #pragma omp master
181  {
182  int Nin = b.nin();
183  int Nvol = b.nvol();
184  int Nex = b.nex();
185 
186  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
187  s.reset(Nin, Nvol, Nex);
188  r.reset(Nin, Nvol, Nex);
189  x.reset(Nin, Nvol, Nex);
190  p.reset(Nin, Nvol, Nex);
191  v.reset(Nin, Nvol, Nex);
192  t.reset(Nin, Nvol, Nex);
193  rh.reset(Nin, Nvol, Nex);
194  }
195  }
196 #pragma omp barrier
197 
198  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
199 }
200 
201 
202 //====================================================================
203 void Solver_BiCGStab::solve_init(const Field& b, double& rr)
204 {
205  copy(x, s); // x = s;
206 
207  //- r = b - A x_0
208  m_fopr->mult(v, s); // v = m_fopr->mult(s);
209  copy(r, b); // r = b;
210  axpy(r, -1.0, v); // r -= v;
211  copy(rh, r); // rh = r;
212 
213  rr = r.norm2(); // rr = r * r;
214 
215  p.set(0.0); // p = 0.0
216  v.set(0.0); // v = 0.0
217 
218 #pragma omp barrier
219 #pragma omp master
220  {
221  rho_prev = 1.0;
222  alpha_prev = 1.0;
223  omega_prev = 1.0;
224  }
225 #pragma omp barrier
226 }
227 
228 
229 //====================================================================
231 {
232  double rho = dot(rh, r); // double rho = rh * r;
233  double bet = rho * alpha_prev / (rho_prev * omega_prev);
234 
235  // p = r + bet * (p - omega_prev * v);
236  axpy(p, -omega_prev, v); // p += - omega_prev * v;
237  aypx(bet, p, r); // p = bet * p + r;
238 
239  m_fopr->mult(v, p); // v = m_fopr->mult(p);
240 
241  double aden = dot(rh, v); // dcomplex aden = rh * v;
242  double alpha = rho / aden;
243 
244  copy(s, r); // s = r
245  axpy(s, -alpha, v); // s += - alpha * v;
246 
247  m_fopr->mult(t, s); // t = m_fopr->mult(s);
248 
249  double omega_d = dot(t, t); // omega_d = t * t;
250  double omega_n = dot(t, s); // omega_n = t * s;
251  double omega = omega_n / omega_d;
252 
253  axpy(x, omega, s); // x += omega * s;
254  axpy(x, alpha, p); // x += alpha * p;
255 
256  copy(r, s); // r = s
257  axpy(r, -omega, t); // r += - omega * t;
258 
259  rr = r.norm2(); // rr = r * r;
260 
261 #pragma omp barrier
262 #pragma omp master
263  {
264  rho_prev = rho;
265  alpha_prev = alpha;
266  omega_prev = omega;
267  }
268 #pragma omp barrier
269 }
270 
271 
272 //====================================================================
273 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:278
void detailed(const char *format,...)
Definition: bridgeIO.cpp:82
void Register_string(const string &, const string &)
Definition: parameters.cpp:351
double norm2() const
Definition: field.cpp:441
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:155
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
void general(const char *format,...)
Definition: bridgeIO.cpp:65
void Register_int(const string &, const int)
Definition: parameters.cpp:330
Container of Field-type object.
Definition: field.h:39
int nvol() const
Definition: field.h:116
void set_parameters(const Parameters &params)
Class for parameters.
Definition: parameters.h:38
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
void solve_init(const Field &, double &)
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
void aypx(const double a, Field &y, const Field &x)
aypx(y, a, x): y := a * y + x
Definition: field.cpp:461
int nex() const
Definition: field.h:117
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:99
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:48
static const std::string class_name
Base class for linear solver class family.
Definition: solver.h:38
static bool Register(const std::string &realm, const creator_callback &cb)
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
BiCGStab algorithm.
void solve_step(double &)
int non_negative(const int v)
Definition: checker.cpp:21
void Register_double(const string &, const double)
Definition: parameters.cpp:323
void reset_field(const Field &)
Base class of fermion operator family.
Definition: fopr.h:49
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:87
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:62
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:28
static void assert_single_thread(const std::string &classname)
assert currently running on single thread.
int size() const
Definition: field.h:121