Bridge++  Ver. 1.2.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_BiCGStab_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_BiCGStab_Cmplx.h"
15 
16 
17 #ifdef USE_FACTORY
18 namespace {
    //- factory callback: constructs a Solver_BiCGStab_Cmplx on the given
    //  fermion operator with new and returns it as a Solver pointer.
19  Solver *create_object(Fopr *fopr)
20  {
21  return new Solver_BiCGStab_Cmplx(fopr);
22  }
23 
24 
    //- registers create_object with Solver::Factory under the key
    //  "BiCGStab_Cmplx"; the returned flag is stored in init.
25  bool init = Solver::Factory::Register("BiCGStab_Cmplx", create_object);
26 }
27 #endif
28 
29 //- parameter entries
30 namespace {
    //- registers the three parameter entries of this solver in param:
    //  an int iteration limit, a double squared convergence criterion and
    //  a string verbose level. The second arguments (0, 0.0, "NULL") are
    //  presumably the initial/default values -- confirm against Parameters.
31  void append_entry(Parameters& param)
32  {
33  param.Register_int("maximum_number_of_iteration", 0);
34  param.Register_double("convergence_criterion_squared", 0.0);
35 
36  param.Register_string("verbose_level", "NULL");
37  }
38 
39 
    //- only when USE_PARAMETERS_FACTORY is defined: registers append_entry
    //  with ParametersFactory under the key "Solver.BiCGStab_Cmplx".
40 #ifdef USE_PARAMETERS_FACTORY
41  bool init_param = ParametersFactory::Register("Solver.BiCGStab_Cmplx",
42  append_entry);
43 #endif
44 }
45 //- end
46 
47 //- parameters class
49 { append_entry(*this); }
50 //- end
51 
52 const std::string Solver_BiCGStab_Cmplx::class_name = "Solver_BiCGStab_Cmplx";
53 
54 //====================================================================
56 {
    // Reads the solver settings from params and forwards them to
    // set_parameters(Niter, Stop_cond). Aborts if a required entry
    // cannot be fetched.
57  const string str_vlevel = params.get_string("verbose_level");
58 
    // m_vl is the verbose level passed to every vout call in this class
59  m_vl = vout.set_verbose_level(str_vlevel);
60 
61  //- fetch and check input parameters
62  int Niter;
63  double Stop_cond;
64 
    // the return values of both fetches are summed; any non-zero total
    // is treated as "parameter not found"
65  int err = 0;
66  err += params.fetch_int("maximum_number_of_iteration", Niter);
67  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
68 
69  if (err) {
70  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n",
71  class_name.c_str());
72  abort();
73  }
74 
    // range checking and storing happen in the two-argument overload
75  set_parameters(Niter, Stop_cond);
76 }
77 
78 
79 //====================================================================
81  const double Stop_cond)
82 {
    // Prints, range-checks and stores the iteration limit Niter and the
    // convergence threshold Stop_cond. Aborts if a range check fails.
84 
85  //- print input parameters
86  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
87  vout.general(m_vl, " Niter = %d\n", Niter);
88  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
89 
90  //- range check
    // Niter is checked with non_negative, Stop_cond with square_non_zero;
    // the return values are summed and any non-zero total is a failure
91  int err = 0;
92  err += ParameterCheck::non_negative(Niter);
93  err += ParameterCheck::square_non_zero(Stop_cond);
94 
95  if (err) {
96  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
97  abort();
98  }
99 
100  //- store values
    // m_Niter bounds the loop in solve(); m_Stop_cond is the threshold
    // compared there against rr * snorm and diff2 * snorm
101  m_Niter = Niter;
102  m_Stop_cond = Stop_cond;
103 }
104 
105 
106 //====================================================================
108  int& Nconv, double& diff)
109 {
    // Solver driver. Starting from the guess x = b, it calls solve_step up
    // to m_Niter times. Whenever the recursive residual rr, scaled by
    // 1/b.norm2(), drops below m_Stop_cond, the residual A x - b is
    // recomputed explicitly; convergence is accepted only if that one is
    // below m_Stop_cond as well. After each such check the iteration is
    // restarted from the current x via solve_init.
    // Outputs: xq <- x, Nconv <- 2 * (number of steps taken),
    //          diff <- sqrt of norm2() of (A x - b).
    // Aborts if convergence is not reached within m_Niter steps.
    // NOTE: the parallel pragma below is commented out, so the braces
    // that follow only open a plain scope.
110  //#pragma omp parallel
111  {
    // NOTE(review): snorm is not guarded against bnorm2 == 0 (zero source)
    // -- confirm whether callers can pass such a b.
112  double bnorm2 = b.norm2();
113  double snorm = 1.0 / bnorm2;
114  int bsize = b.size();
115 
    // NOTE(review): ith and nth used below are not declared in the lines
    // visible here -- presumably thread id / thread count; confirm.
118 
119 #pragma omp master
120  {
121  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
122  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
123  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
124  vout.paranoiac(m_vl, " number of threads = %d\n", nth);
125  }
126 #pragma omp barrier
127 
128 
    // make the work fields match the layout of b
129  reset_field(b);
130 
    // s holds the starting guess that solve_init copies into x
131  copy(s, b); // s = b;
132 
133  double rr;
    // -1 marks "not converged"; tested after the loop
134  int Nconv2 = -1;
135 
136  solve_init(b, rr);
137 
138  bool is_converged = false;
139 
140 #pragma omp master
141  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
142 #pragma omp barrier
143 
144 
    // the loop always runs m_Niter times; once is_converged is set the
    // body is skipped instead of breaking out (see the disabled break)
145  for (int iter = 0; iter < m_Niter; iter++) {
146  if (!is_converged) {
147  solve_step(rr);
148 
    // the reported count is 2 * (iter + 1) -- presumably because each
    // solve_step applies m_fopr->mult twice; confirm
149 #pragma omp master
150  vout.detailed(m_vl, " iter: %8d %22.15e\n", 2 * (iter + 1), rr * snorm);
151 #pragma omp barrier
152 
153  if (rr * snorm < m_Stop_cond) {
    // recompute the residual explicitly from the current x
154  m_fopr->mult(s, x); // s = m_fopr->mult(x);
155  axpy(s, -1.0, b); // s -= b;
156 
157  double diff2 = s.norm2();
159 
    // note: the integer printed here is nth, not an iteration count
161  if (ith == 0) vout.detailed(m_vl, " iter0: %8d %22.15e\n", nth, diff2 * snorm);
162 
163  if (diff2 * snorm < m_Stop_cond) {
164  Nconv2 = 2 * (iter + 1);
165 
166  // break;
167  is_converged = true;
168  }
169 
    // restart from the current x; this runs whether or not the explicit
    // residual passed the test above
170  copy(s, x); // s = x;
171  solve_init(b, rr);
172 
174  if (ith == 0) vout.detailed(m_vl, " iter1: %8d %22.15e\n", nth, rr * snorm);
175  }
176  }
177  }
    // Nconv2 still -1: the explicit-residual test never succeeded
178  if (Nconv2 == -1) {
179 #pragma omp master
180  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
181 #pragma omp barrier
182  abort();
183  }
184 
    // final residual p = A x - b, used for the reported diff
185  m_fopr->mult(p, x); // p = m_fopr->mult(x);
186  axpy(p, -1.0, b); // p -= b;
187 
188  copy(xq, x); // xq = x;
189 
190  double diff2 = p.norm2();
191 
    // the output arguments are written inside the master construct only
192 #pragma omp master
193  {
194  diff = sqrt(diff2);
195  Nconv = Nconv2;
196  }
197 #pragma omp barrier
198  } // end of parallel region
199 }
200 
201 
202 //====================================================================
204 {
    // Resizes the work fields s, r, x, p, v, t and rh to the layout
    // (nin, nvol, nex) of b. Only s is compared against b; if s differs,
    // all seven fields are reset together.
    // The whole check runs inside an omp master construct and is followed
    // by a barrier.
205 #pragma omp master
206  {
207  int Nin = b.nin();
208  int Nvol = b.nvol();
209  int Nex = b.nex();
210 
211  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
212  s.reset(Nin, Nvol, Nex);
213  r.reset(Nin, Nvol, Nex);
214  x.reset(Nin, Nvol, Nex);
215  p.reset(Nin, Nvol, Nex);
216  v.reset(Nin, Nvol, Nex);
217  t.reset(Nin, Nvol, Nex);
218  rh.reset(Nin, Nvol, Nex);
219 
220  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
221  }
222  }
223 #pragma omp barrier
224 }
225 
226 
227 //====================================================================
    // (Re)starts the iteration from the guess currently held in s.
    //   b  : right-hand side, read only
    //   rr : output, set to r.norm2() of the initial residual r = b - A s
    // On return x = s, r = rh = b - A s, p = v = 0, and rho_prev,
    // alpha_prev, omega_prev are all (1, 0).
228 void Solver_BiCGStab_Cmplx::solve_init(const Field& b, double& rr)
229 {
230  copy(x, s); // x = s;
231 
232  //- r = b - A x_0
233 
    // v is used here as scratch for A s; it is zeroed again below
234  m_fopr->mult(v, s); // v = m_fopr->mult(s);
235  copy(r, b); // r = b;
236  axpy(r, -1.0, v); // r -= v;
    // rh keeps a copy of the initial residual; solve_step only reads it
237  copy(rh, r); // rh = r;
238 
239  rr = r.norm2(); // rr = r * r;
240 
241  p.set(0.0); // p = 0.0
242  v.set(0.0); // v = 0.0
243 
    // the scalar state is written inside the master construct only and is
    // followed by a barrier
244 #pragma omp master
245  {
246  rho_prev = cmplx(1.0, 0.0);
247  alpha_prev = cmplx(1.0, 0.0);
248  omega_prev = cmplx(1.0, 0.0);
249  }
250 #pragma omp barrier
251 }
252 
253 
254 //====================================================================
256 {
    // One iteration step: updates x, r, p, v, s, t and the stored scalars
    // rho_prev / alpha_prev / omega_prev, and writes r.norm2() of the new
    // residual into rr. m_fopr->mult is applied twice per call.
    // NOTE(review): none of the divisions below (by rho_prev * omega_prev,
    // aden, omega_d) is guarded against a zero denominator -- confirm
    // whether breakdown handling is expected elsewhere.
257  dcomplex rho = dotc(rh, r); // rho = rh * r;
258  dcomplex bet = rho * alpha_prev / (rho_prev * omega_prev);
259 
260  // p = r + bet * (p - omega_prev * v);
261  axpy(p, -omega_prev, v); // p += - omega_prev * v;
262  aypx(bet, p, r); // p = bet * p + r;
263 
264  m_fopr->mult(v, p); // v = m_fopr->mult(p);
265 
266  dcomplex aden = dotc(rh, v); // aden = rh * v;
267  dcomplex alpha = rho / aden;
268 
    // s is the intermediate residual r - alpha * v
269  copy(s, r); // s = r
270  axpy(s, -alpha, v); // s += - alpha * v;
271 
272  m_fopr->mult(t, s); // t = m_fopr->mult(s);
273 
    // omega = (t, s) / (t, t); the denominator is taken as a real double
274  double omega_d = dot(t, t); // omega_d = t * t;
275  dcomplex omega_n = dotc(t, s); // omega_n = t * s;
276  dcomplex omega = omega_n / omega_d;
277 
278  axpy(x, omega, s); // x += omega * s;
279  axpy(x, alpha, p); // x += alpha * p;
280 
281  copy(r, s); // r = s
282  axpy(r, -omega, t); // r += - omega * t;
283 
284  rr = r.norm2(); // rr = r * r;
285 
    // keep this step's scalars for the next call; written inside the
    // master construct only, then a barrier
286 #pragma omp master
287  {
288  rho_prev = rho;
289  alpha_prev = alpha;
290  omega_prev = omega;
291  }
292 #pragma omp barrier
293 }
294 
295 
296 //====================================================================
297 //============================================================END=====
void reset_field(const Field &)
BridgeIO vout
Definition: bridgeIO.cpp:207
void detailed(const char *format,...)
Definition: bridgeIO.cpp:50
static int get_num_threads()
returns available number of threads.
void Register_string(const string &, const string &)
Definition: parameters.cpp:352
double norm2() const
Definition: field.cpp:469
static const std::string class_name
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:128
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
virtual const Field mult(const Field &)=0
multiplies a given field by the fermion operator and returns the resulting field.
void general(const char *format,...)
Definition: bridgeIO.cpp:38
void Register_int(const string &, const int)
Definition: parameters.cpp:331
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Container of Field-type object.
Definition: field.h:37
int nvol() const
Definition: field.h:101
void set_parameters(const Parameters &params)
Class for parameters.
Definition: parameters.h:40
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:409
static int get_thread_id()
returns thread id.
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:100
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:98
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:82
static void sync_barrier_all()
barrier among all the threads and nodes.
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:489
int nex() const
Definition: field.h:102
void solve_init(const Field &, double &)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:62
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:193
BiCGStab algorithm with complex variables.
void crucial(const char *format,...)
Definition: bridgeIO.cpp:26
Base class for linear solver class family.
Definition: solver.h:37
static bool Register(const std::string &realm, const creator_callback &cb)
int non_negative(const int v)
Definition: checker.cpp:21
void Register_double(const string &, const double)
Definition: parameters.cpp:324
Base class of fermion operator family.
Definition: fopr.h:39
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:85
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:56
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:191
static void assert_single_thread(const std::string &classname)
assert currently running on single thread.
int size() const
Definition: field.h:106