Bridge++  Ver. 2.0.2
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
#ifdef USE_FACTORY_AUTOREGISTER
// When USE_FACTORY_AUTOREGISTER is defined, call
// Solver_GMRES_m_Cmplx::register_factory() once during static initialization.
// `init` only stores the result of that call.
namespace {
 bool init = Solver_GMRES_m_Cmplx::register_factory();
}
#endif

// Class name, used as the "%s" prefix of the messages printed by this class.
const std::string Solver_GMRES_m_Cmplx::class_name = "Solver_GMRES_m_Cmplx";
23 
24 //====================================================================
26 {
27  std::string vlevel;
28  if (!params.fetch_string("verbose_level", vlevel)) {
29  m_vl = vout.set_verbose_level(vlevel);
30  }
31 
32  //- fetch and check input parameters
33  int Niter, Nrestart;
34  double Stop_cond;
35  bool use_init_guess;
36  int N_M;
37 
38  int err = 0;
39  err += params.fetch_int("maximum_number_of_iteration", Niter);
40  err += params.fetch_int("maximum_number_of_restart", Nrestart);
41  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
42  err += params.fetch_bool("use_initial_guess", use_init_guess);
43  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
44 
45  if (err) {
46  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
47  exit(EXIT_FAILURE);
48  }
49 
50  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
52 }
53 
54 
55 //====================================================================
57 {
58  params.set_int("maximum_number_of_iteration", m_Niter);
59  params.set_int("maximum_number_of_restart", m_Nrestart);
60  params.set_double("convergence_criterion_squared", m_Stop_cond);
61  params.set_bool("use_initial_guess", m_use_init_guess);
62  params.set_int("number_of_orthonormal_vectors", m_N_M);
63 
64  params.set_string("verbose_level", vout.get_verbose_level(m_vl));
65 }
66 
67 
68 //====================================================================
69 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
70 {
72 
73  //- print input parameters
74  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
75  vout.general(m_vl, " Niter = %d\n", Niter);
76  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
77  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
78 
79  //- range check
80  int err = 0;
81  err += ParameterCheck::non_negative(Niter);
82  err += ParameterCheck::non_negative(Nrestart);
83  err += ParameterCheck::square_non_zero(Stop_cond);
84 
85  if (err) {
86  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
87  exit(EXIT_FAILURE);
88  }
89 
90  //- store values
91  m_Niter = Niter;
92  m_Nrestart = Nrestart;
93  m_Stop_cond = Stop_cond;
94 }
95 
96 
97 //====================================================================
98 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond, const bool use_init_guess)
99 {
101 
102  //- print input parameters
103  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
104  vout.general(m_vl, " Niter = %d\n", Niter);
105  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
106  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
107  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
108 
109  //- range check
110  int err = 0;
111  err += ParameterCheck::non_negative(Niter);
112  err += ParameterCheck::non_negative(Nrestart);
113  err += ParameterCheck::square_non_zero(Stop_cond);
114 
115  if (err) {
116  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
117  exit(EXIT_FAILURE);
118  }
119 
120  //- store values
121  m_Niter = Niter;
122  m_Nrestart = Nrestart;
123  m_Stop_cond = Stop_cond;
124  m_use_init_guess = use_init_guess;
125 }
126 
127 
128 //====================================================================
130 {
131  //- print input parameters
132  vout.general(m_vl, " N_M = %d\n", N_M);
133 
134  //- range check
135  int err = 0;
136  err += ParameterCheck::non_negative(N_M);
137 
138  if (err) {
139  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
140  exit(EXIT_FAILURE);
141  }
142 
143  //- store values
144  m_N_M = N_M;
145 }
146 
147 
148 //====================================================================
150  const int Nrestart,
151  const double Stop_cond,
152  const bool use_init_guess,
153  const int N_M)
154 {
156 
157  //- print input parameters
158  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
159  vout.general(m_vl, " Niter = %d\n", Niter);
160  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
161  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
162  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
163 
164  vout.general(m_vl, " N_M = %d\n", N_M);
165 
166  //- range check
167  int err = 0;
168  err += ParameterCheck::non_negative(Niter);
169  err += ParameterCheck::non_negative(Nrestart);
170  err += ParameterCheck::square_non_zero(Stop_cond);
171 
172  err += ParameterCheck::non_negative(N_M);
173 
174  if (err) {
175  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
176  exit(EXIT_FAILURE);
177  }
178 
179  //- store values
180  m_Niter = Niter;
181  m_Nrestart = Nrestart;
182  m_Stop_cond = Stop_cond;
183  m_use_init_guess = use_init_guess;
184 
185  m_N_M = N_M;
186 }
187 
188 
189 //====================================================================
191  int& Nconv, double& diff)
192 {
193  const double bnorm2 = b.norm2();
194  const int bsize = b.size();
195 
196  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
197  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
198  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
199 
200  bool is_converged = false;
201  int Nconv2 = 0;
202  double diff2 = 1.0; // superficial initialization
203  double rr;
204 
205  int Nconv_unit = 1;
206  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
207  // Nconv_unit = 2;
208  // }
209 
210  reset_field(b);
211 
212  if (m_use_init_guess) {
213  copy(m_s, xq); // s = xq;
214  } else {
215  copy(m_s, b); // s = b;
216  }
217  solve_init(b, rr);
218  Nconv2 += Nconv_unit;
219 
220  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
221 
222 
223  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
224  for (int iter = 0; iter < m_Niter; iter++) {
225  if (rr / bnorm2 < m_Stop_cond) break;
226 
227  solve_step(b, rr);
228  Nconv2 += Nconv_unit * m_N_M;
229 
230  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
231  }
232 
233  //- calculate true residual
234  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
235  axpy(m_s, -1.0, b); // s -= b;
236  diff2 = m_s.norm2();
237 
238  if (diff2 / bnorm2 < m_Stop_cond) {
239  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
240  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
241 
242  is_converged = true;
243 
244  m_Nrestart_count = i_restart;
245  m_Nconv_count = Nconv2;
246 
247  break;
248  } else {
249  //- restart with new approximate solution
250  copy(m_s, m_x); // s = x;
251  solve_init(b, rr);
252 
253  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
254  }
255  }
256 
257 
258  if (!is_converged) {
259  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
260  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
261  exit(EXIT_FAILURE);
262  }
263 
264 
265  copy(xq, m_x); // xq = x;
266 
267 #pragma omp barrier
268 #pragma omp master
269  {
270  diff = sqrt(diff2 / bnorm2);
271  Nconv = Nconv2;
272  }
273 #pragma omp barrier
274 }
275 
276 
277 //====================================================================
279 {
280 #pragma omp barrier
281 #pragma omp master
282  {
283  const int Nin = b.nin();
284  const int Nvol = b.nvol();
285  const int Nex = b.nex();
286 
287  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
288  m_s.reset(Nin, Nvol, Nex);
289  m_r.reset(Nin, Nvol, Nex);
290  m_x.reset(Nin, Nvol, Nex);
291 
292  m_v_tmp.reset(Nin, Nvol, Nex);
293 
294  m_v.resize(m_N_M + 1);
295 
296  for (int i = 0; i < m_N_M + 1; ++i) {
297  m_v[i].reset(Nin, Nvol, Nex);
298  }
299  }
300  }
301 #pragma omp barrier
302 
303  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
304 }
305 
306 
307 //====================================================================
308 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
309 {
310  copy(m_x, m_s); // x = s;
311 
312  for (int i = 0; i < m_N_M + 1; ++i) {
313  m_v[i].set(0.0); // m_v[i] = 0.0;
314  }
315 
316  // r = b - A x_0
317  m_fopr->mult(m_v_tmp, m_s); // v_tmp = m_fopr->mult(s);
318  copy(m_r, b); // r = b;
319  axpy(m_r, -1.0, m_v_tmp); // r -= v_tmp;
320 
321  rr = m_r.norm2(); // rr = r * r;
322 
323 #pragma omp barrier
324 #pragma omp master
325  m_beta_prev = sqrt(rr);
326 #pragma omp barrier
327 
328  //- v[0] = (1.0 / m_beta_prev) * r;
329  copy(m_v[0], m_r); // v[0] = r;
330  scal(m_v[0], (1.0 / m_beta_prev)); // v[0] = (1.0 / beta_p) * v[0];
331 }
332 
333 
334 //====================================================================
335 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
336 {
337  std::valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
338 
339  h = cmplx(0.0, 0.0);
340  y = cmplx(0.0, 0.0);
341 
342 
343  for (int j = 0; j < m_N_M; ++j) {
344  m_fopr->mult(m_v_tmp, m_v[j]); // v_tmp = m_fopr->mult(v[j]);
345 
346  for (int i = 0; i < j + 1; ++i) {
347  int ij = index_ij(i, j);
348  h[ij] = dotc(m_v[i], m_v_tmp); // h[ij] = (v[i], A v[j]);
349  }
350 
351  //- v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
352  m_v[j + 1] = m_v_tmp;
353 
354  for (int i = 0; i < j + 1; ++i) {
355  int ij = index_ij(i, j);
356  axpy(m_v[j + 1], -h[ij], m_v[i]); // v[j+1] -= h[ij] * v[i];
357  }
358 
359  double v_norm2 = m_v[j + 1].norm2();
360 
361  int j1j = index_ij(j + 1, j);
362  h[j1j] = cmplx(sqrt(v_norm2), 0.0);
363 
364  scal(m_v[j + 1], 1.0 / sqrt(v_norm2)); // v[j+1] /= sqrt(v_norm2);
365  }
366 
367 
368  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
369  min_J(y, h);
370 
371 
372  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
373  for (int i = 0; i < m_N_M; ++i) {
374  axpy(m_x, y[i], m_v[i]); // x += y[i] * v[i];
375  }
376 
377 
378  // r = b - m_fopr->mult(x);
379  copy(m_s, m_x); // s = x;
380  solve_init(b, rr);
381 }
382 
383 
384 //====================================================================
385 void Solver_GMRES_m_Cmplx::min_J(std::valarray<dcomplex>& y,
386  std::valarray<dcomplex>& h)
387 {
388  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
389 
390  std::valarray<dcomplex> g(m_N_M + 1);
391 
392  g = cmplx(0.0, 0.0);
393  g[0] = cmplx(m_beta_prev, 0.0);
394 
395  for (int i = 0; i < m_N_M; ++i) {
396  int ii = index_ij(i, i);
397  double h_1_r = abs(h[ii]);
398 
399  int i1i = index_ij(i + 1, i);
400  double h_2_r = abs(h[i1i]);
401 
402  double denomi = sqrt(h_1_r * h_1_r + h_2_r * h_2_r);
403 
404  dcomplex cs = h[ii] / denomi;
405  dcomplex sn = h[i1i] / denomi;
406 
407  for (int j = i; j < m_N_M; ++j) {
408  int ij = index_ij(i, j);
409  int i1j = index_ij(i + 1, j);
410 
411  dcomplex const_1_c = conj(cs) * h[ij] + sn * h[i1j];
412  dcomplex const_2_c = -sn * h[ij] + cs * h[i1j];
413 
414  h[ij] = const_1_c;
415  h[i1j] = const_2_c;
416  }
417 
418  dcomplex const_1_c = conj(cs) * g[i] + sn * g[i + 1];
419  dcomplex const_2_c = -sn * g[i] + cs * g[i + 1];
420 
421  g[i] = const_1_c;
422  g[i + 1] = const_2_c;
423  }
424 
425 
426  for (int i = m_N_M - 1; i > -1; --i) {
427  for (int j = i + 1; j < m_N_M; ++j) {
428  int ij = index_ij(i, j);
429  g[i] -= h[ij] * y[j];
430  }
431 
432  int ii = index_ij(i, i);
433  y[i] = g[i] / h[ii];
434  }
435 }
436 
437 
438 //====================================================================
440 {
441  const int NPE = CommonParameters::NPE();
442 
443  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
444  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
445  const int Nin = m_x.nin();
446  const int Nvol = m_x.nvol();
447  const int Nex = m_x.nex();
448 
449  const double gflop_fopr = m_fopr->flop_count();
450 
451  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
452  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
453  return 0.0;
454  }
455 
456  const double gflop_axpy = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
457  const double gflop_dotc = (Nin * Nex * 4) * ((Nvol * NPE) / 1.0e+9);
458  const double gflop_norm = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
459  const double gflop_scal = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
460 
461  int N_M_part = 0;
462  for (int j = 0; j < m_N_M; ++j) {
463  for (int i = 0; i < j + 1; ++i) {
464  N_M_part += 1;
465  }
466  }
467 
468  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm + gflop_scal;
469  const double gflop_step = m_N_M * gflop_fopr + N_M_part * gflop_dotc
470  + (N_M_part + m_N_M) * gflop_axpy
471  + m_N_M * gflop_scal
472  + gflop_init;
473  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
474 
475  const int N_iter = (m_Nconv_count - 1) / m_N_M;
476  const double gflop = gflop_norm + gflop_init
477  + gflop_step * N_iter
478  + gflop_true_residual * (m_Nrestart_count + 1)
479  + gflop_init * m_Nrestart_count;
480 
481  return gflop;
482 }
483 
484 
485 //====================================================================
486 //============================================================END=====
Parameters::set_bool
void set_bool(const string &key, const bool value)
Definition: parameters.cpp:30
Solver_GMRES_m_Cmplx::solve_step
void solve_step(const Field &, double &)
Definition: solver_GMRES_m_Cmplx.cpp:335
Solver_GMRES_m_Cmplx::class_name
static const std::string class_name
Definition: solver_GMRES_m_Cmplx.h:44
Solver_GMRES_m_Cmplx::m_v_tmp
Field m_v_tmp
Definition: solver_GMRES_m_Cmplx.h:62
Solver_GMRES_m_Cmplx::set_parameters_GMRES_m
DEPRECATED void set_parameters_GMRES_m(const int N_M)
Definition: solver_GMRES_m_Cmplx.cpp:129
Parameters::set_string
void set_string(const string &key, const string &value)
Definition: parameters.cpp:39
Solver_GMRES_m_Cmplx::m_v
std::vector< Field > m_v
Definition: solver_GMRES_m_Cmplx.h:61
AFopr::mult
virtual void mult(AFIELD &, const AFIELD &)
multiplies fermion operator to a given field.
Definition: afopr.h:95
Parameters
Class for parameters.
Definition: parameters.h:46
Parameters::set_double
void set_double(const string &key, const double value)
Definition: parameters.cpp:33
Bridge::BridgeIO::detailed
void detailed(const char *format,...)
Definition: bridgeIO.cpp:219
Field::nex
int nex() const
Definition: field.h:128
Solver_GMRES_m_Cmplx::m_s
Field m_s
Definition: solver_GMRES_m_Cmplx.h:62
Solver_GMRES_m_Cmplx::flop_count
double flop_count()
Definition: solver_GMRES_m_Cmplx.cpp:439
axpy
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:380
Solver_GMRES_m_Cmplx::m_N_M
int m_N_M
Definition: solver_GMRES_m_Cmplx.h:56
Solver_GMRES_m_Cmplx::m_Stop_cond
double m_Stop_cond
Definition: solver_GMRES_m_Cmplx.h:53
ParameterCheck::non_negative
int non_negative(const int v)
Definition: parameterCheck.cpp:21
Solver_GMRES_m_Cmplx::m_beta_prev
double m_beta_prev
Definition: solver_GMRES_m_Cmplx.h:59
Field::nin
int nin() const
Definition: field.h:126
Parameters::fetch_bool
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
copy
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:212
Solver_GMRES_m_Cmplx::solve_init
void solve_init(const Field &, double &)
Definition: solver_GMRES_m_Cmplx.cpp:308
Bridge::BridgeIO::paranoiac
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:238
Solver_GMRES_m_Cmplx::m_Nrestart_count
int m_Nrestart_count
Definition: solver_GMRES_m_Cmplx.h:64
Field::norm2
double norm2() const
Definition: field.cpp:113
solver_GMRES_m_Cmplx.h
Solver_GMRES_m_Cmplx::m_Nrestart
int m_Nrestart
Definition: solver_GMRES_m_Cmplx.h:52
Solver_GMRES_m_Cmplx::m_fopr
Fopr * m_fopr
Definition: solver_GMRES_m_Cmplx.h:49
Solver_GMRES_m_Cmplx::index_ij
int index_ij(const int i, const int j)
Definition: solver_GMRES_m_Cmplx.h:118
Solver_GMRES_m_Cmplx::get_parameters
void get_parameters(Parameters &params) const
Definition: solver_GMRES_m_Cmplx.cpp:56
Field::size
int size() const
Definition: field.h:132
Solver_GMRES_m_Cmplx::m_r
Field m_r
Definition: solver_GMRES_m_Cmplx.h:62
AFopr::flop_count
virtual double flop_count()
returns the number of floating point operations.
Definition: afopr.h:160
ParameterCheck::square_non_zero
int square_non_zero(const double v)
Definition: parameterCheck.cpp:43
Field::nvol
int nvol() const
Definition: field.h:127
dotc
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:712
Solver_GMRES_m_Cmplx::m_Niter
int m_Niter
Definition: solver_GMRES_m_Cmplx.h:51
CommonParameters::NPE
static int NPE()
Definition: commonParameters.h:101
Solver_GMRES_m_Cmplx::reset_field
void reset_field(const Field &)
Definition: solver_GMRES_m_Cmplx.cpp:278
Field::reset
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
Solver_GMRES_m_Cmplx::min_J
void min_J(std::valarray< dcomplex > &y, std::valarray< dcomplex > &h)
Definition: solver_GMRES_m_Cmplx.cpp:385
Solver_GMRES_m_Cmplx::m_x
Field m_x
Definition: solver_GMRES_m_Cmplx.h:62
Solver_GMRES_m_Cmplx::m_Nconv_count
int m_Nconv_count
Definition: solver_GMRES_m_Cmplx.h:65
Solver_GMRES_m_Cmplx::solve
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Definition: solver_GMRES_m_Cmplx.cpp:190
Bridge::BridgeIO::set_verbose_level
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:133
Solver_GMRES_m_Cmplx::m_vl
Bridge::VerboseLevel m_vl
Definition: solver_GMRES_m_Cmplx.h:47
Parameters::set_int
void set_int(const string &key, const int value)
Definition: parameters.cpp:36
scal
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:261
Parameters::fetch_string
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
Parameters::fetch_double
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
Solver_GMRES_m_Cmplx::m_use_init_guess
bool m_use_init_guess
Definition: solver_GMRES_m_Cmplx.h:54
Bridge::BridgeIO::crucial
void crucial(const char *format,...)
Definition: bridgeIO.cpp:180
Field
Container of Field-type object.
Definition: field.h:46
Solver_GMRES_m_Cmplx::set_parameters
void set_parameters(const Parameters &params)
Definition: solver_GMRES_m_Cmplx.cpp:25
Parameters::fetch_int
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
Bridge::BridgeIO::general
void general(const char *format,...)
Definition: bridgeIO.cpp:200
ThreadManager::assert_single_thread
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
Definition: threadManager.cpp:372
Bridge::vout
BridgeIO vout
Definition: bridgeIO.cpp:512
Bridge::BridgeIO::get_verbose_level
static std::string get_verbose_level(const VerboseLevel vl)
Definition: bridgeIO.cpp:154
CommonParameters::epsilon_criterion
static double epsilon_criterion()
Definition: commonParameters.h:119