Bridge++  Ver. 1.2.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
16 
#ifdef USE_FACTORY
namespace {
  // Factory callback: constructs a GMRES(m) solver bound to the given
  // fermion operator and hands it back through the Solver base pointer.
  Solver *create_object(Fopr *fopr)
  {
    Solver *const solver = new Solver_GMRES_m_Cmplx(fopr);

    return solver;
  }

  // Static-initialization side effect: makes the solver available under the
  // factory key "GMRES_m_Cmplx".
  bool init = Solver::Factory::Register("GMRES_m_Cmplx", create_object);
}
#endif
28 
29 //- parameter entries
30 namespace {
31  void append_entry(Parameters& param)
32  {
33  param.Register_int("maximum_number_of_iteration", 0);
34  param.Register_double("convergence_criterion_squared", 0.0);
35 
36  param.Register_int("number_of_orthonormal_vectors", 0);
37 
38  param.Register_string("verbose_level", "NULL");
39  }
40 
41 
42 #ifdef USE_PARAMETERS_FACTORY
43  bool init_param = ParametersFactory::Register("Solver.GMRES_m_Cmplx", append_entry);
44 #endif
45 }
46 //- end
47 
48 //- parameters class
50 //- end
51 
52 const std::string Solver_GMRES_m_Cmplx::class_name = "Solver_GMRES_m_Cmplx";
53 
54 //====================================================================
56 {
57  const string str_vlevel = params.get_string("verbose_level");
58 
59  m_vl = vout.set_verbose_level(str_vlevel);
60 
61  //- fetch and check input parameters
62  int Niter;
63  double Stop_cond;
64  int N_M;
65 
66  int err = 0;
67  err += params.fetch_int("maximum_number_of_iteration", Niter);
68  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
69  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
70 
71  if (err) {
72  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
73  abort();
74  }
75 
76 
77  set_parameters(Niter, Stop_cond);
79 }
80 
81 
82 //====================================================================
83 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const double Stop_cond)
84 {
86 
87  //- print input parameters
88  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
89  vout.general(m_vl, " Niter = %d\n", Niter);
90  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
91 
92  //- range check
93  int err = 0;
94  err += ParameterCheck::non_negative(Niter);
95  err += ParameterCheck::square_non_zero(Stop_cond);
96 
97  if (err) {
98  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
99  abort();
100  }
101 
102  //- store values
103  m_Niter = Niter;
104  m_Stop_cond = Stop_cond;
105 }
106 
107 
108 //====================================================================
110 {
111  //- print input parameters
112  vout.general(m_vl, " N_M = %d\n", N_M);
113 
114  //- range check
115  int err = 0;
116  err += ParameterCheck::non_negative(N_M);
117 
118  if (err) {
119  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
120  abort();
121  }
122 
123  //- store values
124  m_N_M = N_M;
125 }
126 
127 
128 //====================================================================
130  int& Nconv, double& diff)
131 {
132  //#pragma omp parallel
133  {
134  double bnorm2 = b.norm2();
135  double snorm = 1.0 / bnorm2;
136  int bsize = b.size();
137  double rr;
138 
141 
142 #pragma omp master
143  {
144  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
145  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
146  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
147  }
148 #pragma omp barrier
149 
150 
151  reset_field(b);
152 
153 
154  // Nconv = -1;
155  int Nconv2 = 0;
156  copy(s, b); // s = b;
157 
158  solve_init(b, rr);
159 
160  bool is_converged = false;
161 
162 #pragma omp master
163  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
164 #pragma omp barrier
165 
166 
167  for (int iter = 0; iter < m_Niter; iter++) {
168  if (!is_converged) {
169  solve_step(b, rr);
170 
171 #pragma omp master
172  vout.detailed(m_vl, " iter: %8d %22.15e\n", m_N_M * (iter + 1), rr * snorm);
173 #pragma omp barrier
174 
175  if (rr * snorm < m_Stop_cond) {
176  m_fopr->mult(s, x); // s = m_fopr->mult(x);
177  axpy(s, -1.0, b); // s -= b;
178 
179  double diff2 = s.norm2();
181 
183  if (ith == 0) vout.detailed(m_vl, " iter0: %8d %22.15e\n", nth, diff2 * snorm);
184 
185  if (diff2 * snorm < m_Stop_cond) {
186  Nconv2 = m_N_M * (iter + 1);
187 
188  // break;
189  is_converged = true;
190  }
191 
192  copy(s, x); // s = x;
193  solve_init(b, rr);
194 
196  if (ith == 0) vout.detailed(m_vl, " iter1: %8d %22.15e\n", nth, rr * snorm);
197  }
198  }
199  }
200 
201  if (Nconv2 == -1) {
202 #pragma omp master
203  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
204 #pragma omp barrier
205  abort();
206  }
207 
208  m_fopr->mult(s, x); // p = m_fopr->mult(x);
209  axpy(s, -1.0, b); // p -= b;
210 
211  copy(xq, x); // xq = x;
212 
213  double diff2 = s.norm2();
214 
215 #pragma omp master
216  {
217  diff = sqrt(diff2);
218  Nconv = Nconv2;
219  }
220 #pragma omp barrier
221  } // end of parallel region
222 }
223 
224 
225 //====================================================================
227 {
228 #pragma omp master
229  {
230  int Nin = b.nin();
231  int Nvol = b.nvol();
232  int Nex = b.nex();
233 
234  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
235  s.reset(Nin, Nvol, Nex);
236  r.reset(Nin, Nvol, Nex);
237  x.reset(Nin, Nvol, Nex);
238 
239  v_tmp.reset(Nin, Nvol, Nex);
240 
241  v.resize(m_N_M + 1);
242 
243  for (int i = 0; i < m_N_M + 1; ++i) {
244  v[i].reset(Nin, Nvol, Nex);
245  }
246 
247  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
248  }
249  }
250 #pragma omp barrier
251 }
252 
253 
254 //====================================================================
255 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
256 {
257  copy(x, s); // x = s;
258 
259  //- r = b - A x_0
260  m_fopr->mult(v_tmp, s); // v_tmp = m_fopr->mult(s);
261  copy(r, b); // r = b;
262  axpy(r, -1.0, v_tmp); // r -= v_tmp;
263 
264  rr = r.norm2(); // rr = r * r;
265 
266 #pragma omp master
267  beta_prev = sqrt(rr);
268 #pragma omp barrier
269 
270  //- v[0] = (1.0 / beta_prev) * r;
271  copy(v[0], r); // v[0] = r;
272  scal(v[0], (1.0 / beta_prev)); // v[0] = (1.0 / beta_p) * v[0];
273 }
274 
275 
276 //====================================================================
277 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
278 {
279  std::valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
280 
281  h = cmplx(0.0, 0.0);
282  y = cmplx(0.0, 0.0);
283 
284 
285  for (int j = 0; j < m_N_M; ++j) {
286  m_fopr->mult(v_tmp, v[j]); // v_tmp = m_fopr->mult(v[j]);
287 
288  for (int i = 0; i < j + 1; ++i) {
289  int ij = index_ij(i, j);
290  h[ij] = dotc(v_tmp, v[i]); // h[ij] = (A v[j], v[i]);
291  }
292 
293  //- v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
294  v[j + 1] = v_tmp;
295 
296  for (int i = 0; i < j + 1; ++i) {
297  int ij = index_ij(i, j);
298  axpy(v[j + 1], -h[ij], v[i]); // v[j+1] -= h[ij] * v[i];
299  }
300 
301  double v_norm2 = v[j + 1].norm2();
302 
303  int j1j = index_ij(j + 1, j);
304  h[j1j] = sqrt(v_norm2);
305 
306  scal(v[j + 1], 1.0 / sqrt(v_norm2)); // v[j+1] /= sqrt(v_norm2);
307  }
308 
309 
310  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
311  min_J(y, h);
312 
313 
314  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
315  for (int i = 0; i < m_N_M; ++i) {
316  axpy(x, y[i], v[i]); // x += y[i] * v[i];
317  }
318 
319 
320  // r = b - m_fopr->mult(x);
321  copy(s, x); // s = x;
322  solve_init(b, rr);
323 }
324 
325 
326 //====================================================================
327 void Solver_GMRES_m_Cmplx::min_J(std::valarray<dcomplex>& y,
328  std::valarray<dcomplex>& h)
329 {
330  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
331 
332  std::valarray<dcomplex> g(m_N_M + 1);
333 
334  g = dcomplex(0.0);
335  g[0] = beta_prev;
336 
337  for (int i = 0; i < m_N_M; ++i) {
338  int ii = index_ij(i, i);
339  double h_1_r = abs(h[ii]);
340 
341  int i1i = index_ij(i + 1, i);
342  double h_2_r = abs(h[i1i]);
343 
344  double denomi = sqrt(h_1_r * h_1_r + h_2_r * h_2_r);
345 
346  dcomplex cs = h[ii] / denomi;
347  dcomplex sn = h[i1i] / denomi;
348 
349  for (int j = i; j < m_N_M; ++j) {
350  int ij = index_ij(i, j);
351  int i1j = index_ij(i + 1, j);
352 
353  dcomplex const_1_c = conj(cs) * h[ij] + sn * h[i1j];
354  dcomplex const_2_c = -sn * h[ij] + cs * h[i1j];
355 
356  h[ij] = const_1_c;
357  h[i1j] = const_2_c;
358  }
359 
360  dcomplex const_1_c = conj(cs) * g[i] + sn * g[i + 1];
361  dcomplex const_2_c = -sn * g[i] + cs * g[i + 1];
362 
363  g[i] = const_1_c;
364  g[i + 1] = const_2_c;
365  }
366 
367 
368  for (int i = m_N_M - 1; i > -1; --i) {
369  for (int j = i + 1; j < m_N_M; ++j) {
370  int ij = index_ij(i, j);
371  g[i] -= h[ij] * y[j];
372  }
373 
374  int ii = index_ij(i, i);
375  y[i] = g[i] / h[ii];
376  }
377 }
378 
379 
380 //====================================================================
381 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:310
BridgeIO vout
Definition: bridgeIO.cpp:207
void detailed(const char *format,...)
Definition: bridgeIO.cpp:50
static int get_num_threads()
returns available number of threads.
void Register_string(const string &, const string &)
Definition: parameters.cpp:352
double norm2() const
Definition: field.cpp:469
virtual const Field mult(const Field &)=0
multiplies fermion operator to a given field and returns the resultant field.
void general(const char *format,...)
Definition: bridgeIO.cpp:38
void Register_int(const string &, const int)
Definition: parameters.cpp:331
static const std::string class_name
Container of Field-type object.
Definition: field.h:37
int nvol() const
Definition: field.h:101
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:40
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:409
static int get_thread_id()
returns thread id.
std::valarray< Field > v
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:100
GMRES(m) algorithm with complex variables.
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:98
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:82
static void sync_barrier_all()
barrier among all the threads and nodes.
int nex() const
Definition: field.h:102
int index_ij(int i, int j)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:62
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:193
void crucial(const char *format,...)
Definition: bridgeIO.cpp:26
Base class for linear solver class family.
Definition: solver.h:37
static bool Register(const std::string &realm, const creator_callback &cb)
void reset_field(const Field &)
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
int non_negative(const int v)
Definition: checker.cpp:21
void Register_double(const string &, const double)
Definition: parameters.cpp:324
void min_J(std::valarray< dcomplex > &y, std::valarray< dcomplex > &h)
void set_parameters_GMRES_m(const int N_M)
Base class of fermion operator family.
Definition: fopr.h:39
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:85
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:56
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:191
static void assert_single_thread(const std::string &classname)
assert currently running on single thread.
void solve_step(const Field &, double &)
void set_parameters(const Parameters &params)
int size() const
Definition: field.h:106