Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
// Optional self-registration with the solver factory, compiled in only when
// USE_FACTORY_AUTOREGISTER is defined. `init` is not read anywhere in this
// file; it exists so that register_factory() is called during dynamic
// initialization of this translation unit.
// NOTE(review): `init` is defined before `class_name`, so it is initialized
// first; confirm register_factory() does not read class_name.
16 #ifdef USE_FACTORY_AUTOREGISTER
17 namespace {
18  bool init = Solver_GMRES_m_Cmplx::register_factory();
19 }
20 #endif
21 
// Name used as the prefix of every log/error message emitted by this class.
22 const std::string Solver_GMRES_m_Cmplx::class_name = "Solver_GMRES_m_Cmplx";
23 
24 //====================================================================
26 {
27  const std::string str_vlevel = params.get_string("verbose_level");
28 
29  m_vl = vout.set_verbose_level(str_vlevel);
30 
31  //- fetch and check input parameters
32  int Niter, Nrestart;
33  double Stop_cond;
34  bool use_init_guess;
35  int N_M;
36 
37  int err = 0;
38  err += params.fetch_int("maximum_number_of_iteration", Niter);
39  err += params.fetch_int("maximum_number_of_restart", Nrestart);
40  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
41  err += params.fetch_bool("use_initial_guess", use_init_guess);
42  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
43 
44  if (err) {
45  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
46  exit(EXIT_FAILURE);
47  }
48 
49  set_parameters(Niter, Nrestart, Stop_cond, use_init_guess);
51 }
52 
53 
54 //====================================================================
55 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
56 {
58 
59  //- print input parameters
60  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
61  vout.general(m_vl, " Niter = %d\n", Niter);
62  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
63  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
64 
65  //- range check
66  int err = 0;
67  err += ParameterCheck::non_negative(Niter);
68  err += ParameterCheck::non_negative(Nrestart);
69  err += ParameterCheck::square_non_zero(Stop_cond);
70 
71  if (err) {
72  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
73  exit(EXIT_FAILURE);
74  }
75 
76  //- store values
77  m_Niter = Niter;
78  m_Nrestart = Nrestart;
79  m_Stop_cond = Stop_cond;
80 }
81 
82 
83 //====================================================================
84 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond, const bool use_init_guess)
85 {
87 
88  //- print input parameters
89  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
90  vout.general(m_vl, " Niter = %d\n", Niter);
91  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
92  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
93  vout.general(m_vl, " use_init_guess = %s\n", use_init_guess ? "true" : "false");
94 
95  //- range check
96  int err = 0;
97  err += ParameterCheck::non_negative(Niter);
98  err += ParameterCheck::non_negative(Nrestart);
99  err += ParameterCheck::square_non_zero(Stop_cond);
100 
101  if (err) {
102  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
103  exit(EXIT_FAILURE);
104  }
105 
106  //- store values
107  m_Niter = Niter;
108  m_Nrestart = Nrestart;
109  m_Stop_cond = Stop_cond;
110  m_use_init_guess = use_init_guess;
111 }
112 
113 
114 //====================================================================
116 {
117  //- print input parameters
118  vout.general(m_vl, " N_M = %d\n", N_M);
119 
120  //- range check
121  int err = 0;
122  err += ParameterCheck::non_negative(N_M);
123 
124  if (err) {
125  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
126  exit(EXIT_FAILURE);
127  }
128 
129  //- store values
130  m_N_M = N_M;
131 }
132 
133 
134 //====================================================================
136  int& Nconv, double& diff)
137 {
  // Solves A x = b for xq with restarted GMRES(m), where A is m_fopr and
  // m = m_N_M.
  //   xq    in:  starting vector, used only when m_use_init_guess is true
  //         out: the converged solution
  //   b     source field
  //   Nconv out: iteration count: 1 for the initial residual plus m_N_M
  //              for every solve_step() call
  //   diff  out: sqrt(|A x - b|^2 / |b|^2), from an explicitly recomputed
  //              residual
  // Convergence means the squared relative residual is below m_Stop_cond.
  // If that is not reached within m_Nrestart restart cycles of at most
  // m_Niter steps each, the program is terminated with exit(EXIT_FAILURE).
  //
  // NOTE(review): a zero source (bnorm2 == 0) makes rr / bnorm2 NaN or inf,
  // so neither stopping test can succeed and the run aborts. Presumably
  // callers never pass b = 0 -- confirm, or return xq = 0 early.
138  const double bnorm2 = b.norm2();
139  const int bsize = b.size();
140 
141  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
142  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
143  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
144 
145  bool is_converged = false;
146  int Nconv2 = 0;
147  double diff2 = 1.0; // superficial initialization
148  double rr;
149 
  // Each operator application is counted as Nconv_unit iterations. It is
  // fixed at 1 here; the disabled code below would count 2 for DdagD/DDdag.
150  int Nconv_unit = 1;
151  // if (m_fopr->get_mode() == "DdagD" || m_fopr->get_mode() == "DDdag") {
152  // Nconv_unit = 2;
153  // }
154 
  // (Re)allocate the work fields to match b.
155  reset_field(b);
156 
  // Starting vector x_0: the caller's guess, or b itself.
157  if (m_use_init_guess) {
158  copy(m_s, xq); // s = xq;
159  } else {
160  copy(m_s, b); // s = b;
161  }
162  solve_init(b, rr);
163  Nconv2 += Nconv_unit;
164 
165  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
166 
167 
  // Outer loop: restart cycles. Inner loop: GMRES(m) steps.
  // rr is |b - A x|^2, recomputed by solve_init() at the end of each
  // solve_step().
168  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
169  for (int iter = 0; iter < m_Niter; iter++) {
170  if (rr / bnorm2 < m_Stop_cond) break;
171 
172  solve_step(b, rr);
173  Nconv2 += Nconv_unit * m_N_M;
174 
175  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
176  }
177 
  // Confirm with an explicitly computed residual. s = A x - b has the same
  // norm as b - A x.
178  //- calculate true residual
179  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
180  axpy(m_s, -1.0, b); // s -= b;
181  diff2 = m_s.norm2();
182 
183  if (diff2 / bnorm2 < m_Stop_cond) {
184  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
185  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
186 
187  is_converged = true;
188 
  // Recorded only on convergence; flop_count() reads these two counters.
189  m_Nrestart_count = i_restart;
190  m_Nconv_count = Nconv2;
191 
192  break;
193  } else {
  // Restart from the current x. This also runs after the last restart
  // cycle, where its result is no longer used.
194  //- restart with new approximate solution
195  copy(m_s, m_x); // s = x;
196  solve_init(b, rr);
197 
198  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
199  }
200  }
201 
202 
203  if (!is_converged) {
204  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
205  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
206  exit(EXIT_FAILURE);
207  }
208 
209 
210  copy(xq, m_x); // xq = x;
211 
  // Output arguments are written by the master thread only. The barriers
  // order those writes against the other threads.
212 #pragma omp barrier
213 #pragma omp master
214  {
215  diff = sqrt(diff2 / bnorm2);
216  Nconv = Nconv2;
217  }
218 #pragma omp barrier
219 }
220 
221 
222 //====================================================================
224 {
225 #pragma omp barrier
226 #pragma omp master
227  {
228  const int Nin = b.nin();
229  const int Nvol = b.nvol();
230  const int Nex = b.nex();
231 
232  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
233  m_s.reset(Nin, Nvol, Nex);
234  m_r.reset(Nin, Nvol, Nex);
235  m_x.reset(Nin, Nvol, Nex);
236 
237  m_v_tmp.reset(Nin, Nvol, Nex);
238 
239  m_v.resize(m_N_M + 1);
240 
241  for (int i = 0; i < m_N_M + 1; ++i) {
242  m_v[i].reset(Nin, Nvol, Nex);
243  }
244  }
245  }
246 #pragma omp barrier
247 
248  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
249 }
250 
251 
252 //====================================================================
253 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
254 {
255  copy(m_x, m_s); // x = s;
256 
257  for (int i = 0; i < m_N_M + 1; ++i) {
258  m_v[i].set(0.0); // m_v[i] = 0.0;
259  }
260 
261  // r = b - A x_0
262  m_fopr->mult(m_v_tmp, m_s); // v_tmp = m_fopr->mult(s);
263  copy(m_r, b); // r = b;
264  axpy(m_r, -1.0, m_v_tmp); // r -= v_tmp;
265 
266  rr = m_r.norm2(); // rr = r * r;
267 
268 #pragma omp barrier
269 #pragma omp master
270  m_beta_prev = sqrt(rr);
271 #pragma omp barrier
272 
273  //- v[0] = (1.0 / m_beta_prev) * r;
274  copy(m_v[0], m_r); // v[0] = r;
275  scal(m_v[0], (1.0 / m_beta_prev)); // v[0] = (1.0 / beta_p) * v[0];
276 }
277 
278 
279 //====================================================================
280 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
281 {
282  std::valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
283 
284  h = cmplx(0.0, 0.0);
285  y = cmplx(0.0, 0.0);
286 
287 
288  for (int j = 0; j < m_N_M; ++j) {
289  m_fopr->mult(m_v_tmp, m_v[j]); // v_tmp = m_fopr->mult(v[j]);
290 
291  for (int i = 0; i < j + 1; ++i) {
292  int ij = index_ij(i, j);
293  h[ij] = dotc(m_v[i], m_v_tmp); // h[ij] = (v[i], A v[j]);
294  }
295 
296  //- v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
297  m_v[j + 1] = m_v_tmp;
298 
299  for (int i = 0; i < j + 1; ++i) {
300  int ij = index_ij(i, j);
301  axpy(m_v[j + 1], -h[ij], m_v[i]); // v[j+1] -= h[ij] * v[i];
302  }
303 
304  double v_norm2 = m_v[j + 1].norm2();
305 
306  int j1j = index_ij(j + 1, j);
307  h[j1j] = cmplx(sqrt(v_norm2), 0.0);
308 
309  scal(m_v[j + 1], 1.0 / sqrt(v_norm2)); // v[j+1] /= sqrt(v_norm2);
310  }
311 
312 
313  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
314  min_J(y, h);
315 
316 
317  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
318  for (int i = 0; i < m_N_M; ++i) {
319  axpy(m_x, y[i], m_v[i]); // x += y[i] * v[i];
320  }
321 
322 
323  // r = b - m_fopr->mult(x);
324  copy(m_s, m_x); // s = x;
325  solve_init(b, rr);
326 }
327 
328 
329 //====================================================================
330 void Solver_GMRES_m_Cmplx::min_J(std::valarray<dcomplex>& y,
331  std::valarray<dcomplex>& h)
332 {
333  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
334 
335  std::valarray<dcomplex> g(m_N_M + 1);
336 
337  g = cmplx(0.0, 0.0);
338  g[0] = cmplx(m_beta_prev, 0.0);
339 
340  for (int i = 0; i < m_N_M; ++i) {
341  int ii = index_ij(i, i);
342  double h_1_r = abs(h[ii]);
343 
344  int i1i = index_ij(i + 1, i);
345  double h_2_r = abs(h[i1i]);
346 
347  double denomi = sqrt(h_1_r * h_1_r + h_2_r * h_2_r);
348 
349  dcomplex cs = h[ii] / denomi;
350  dcomplex sn = h[i1i] / denomi;
351 
352  for (int j = i; j < m_N_M; ++j) {
353  int ij = index_ij(i, j);
354  int i1j = index_ij(i + 1, j);
355 
356  dcomplex const_1_c = conj(cs) * h[ij] + sn * h[i1j];
357  dcomplex const_2_c = -sn * h[ij] + cs * h[i1j];
358 
359  h[ij] = const_1_c;
360  h[i1j] = const_2_c;
361  }
362 
363  dcomplex const_1_c = conj(cs) * g[i] + sn * g[i + 1];
364  dcomplex const_2_c = -sn * g[i] + cs * g[i + 1];
365 
366  g[i] = const_1_c;
367  g[i + 1] = const_2_c;
368  }
369 
370 
371  for (int i = m_N_M - 1; i > -1; --i) {
372  for (int j = i + 1; j < m_N_M; ++j) {
373  int ij = index_ij(i, j);
374  g[i] -= h[ij] * y[j];
375  }
376 
377  int ii = index_ij(i, i);
378  y[i] = g[i] / h[ii];
379  }
380 }
381 
382 
383 //====================================================================
385 {
386  const int NPE = CommonParameters::NPE();
387 
388  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
389  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
390  const int Nin = m_x.nin();
391  const int Nvol = m_x.nvol();
392  const int Nex = m_x.nex();
393 
394  const double gflop_fopr = m_fopr->flop_count();
395 
396  if (gflop_fopr < CommonParameters::epsilon_criterion()) {
397  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0\n", class_name.c_str());
398  return 0.0;
399  }
400 
401  const double gflop_axpy = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
402  const double gflop_dotc = (Nin * Nex * 4) * ((Nvol * NPE) / 1.0e+9);
403  const double gflop_norm = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
404  const double gflop_scal = (Nin * Nex * 2) * ((Nvol * NPE) / 1.0e+9);
405 
406  int N_M_part = 0;
407  for (int j = 0; j < m_N_M; ++j) {
408  for (int i = 0; i < j + 1; ++i) {
409  N_M_part += 1;
410  }
411  }
412 
413  const double gflop_init = gflop_fopr + gflop_axpy + gflop_norm;
414  const double gflop_step = m_N_M * gflop_fopr + N_M_part * gflop_dotc
415  + (N_M_part + m_N_M) * gflop_axpy
416  + gflop_init;
417  const double gflop_true_residual = gflop_fopr + gflop_axpy + gflop_norm;
418 
419  const int N_iter = (m_Nconv_count - 1) / m_N_M;
420  const double gflop = gflop_norm + gflop_init
421  + gflop_step * N_iter + gflop_true_residual * (m_Nrestart_count + 1)
422  + gflop_init * m_Nrestart_count;
423 
424  return gflop;
425 }
426 
427 
428 //====================================================================
429 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
int fetch_bool(const string &key, bool &value) const
Definition: parameters.cpp:391
std::vector< Field > m_v
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:592
void general(const char *format,...)
Definition: bridgeIO.cpp:197
static const std::string class_name
Container of Field-type object.
Definition: field.h:45
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
int nvol() const
Definition: field.h:127
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:532
int square_non_zero(const double v)
int nin() const
Definition: field.h:126
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:155
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
virtual double flop_count()
returns the flop count in units of GFLOP
Definition: fopr.h:120
int index_ij(const int i, const int j)
int nex() const
Definition: field.h:128
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:235
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=Element_type::COMPLEX)
Definition: field.h:95
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:319
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
void reset_field(const Field &)
virtual void mult(Field &, const Field &)=0
applies the fermion operator to a given field (2nd argument)
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
int non_negative(const int v)
void min_J(std::valarray< dcomplex > &y, std::valarray< dcomplex > &h)
static void assert_single_thread(const std::string &class_name)
asserts that execution is currently on a single thread.
void set_parameters_GMRES_m(const int N_M)
string get_string(const string &key) const
Definition: parameters.cpp:221
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
void solve_step(const Field &, double &)
void set_parameters(const Parameters &params)
int size() const
Definition: field.h:132