Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
16 #include <valarray>
17 using std::valarray;
18 
19 
#ifdef USE_FACTORY
namespace {
  //- Factory hook: builds a GMRES(m) solver bound to the given operator.
  Solver *create_object(Fopr *op)
  {
    Solver *solver = new Solver_GMRES_m_Cmplx(op);

    return solver;
  }


  //- Registers the creator under the key "GMRES_m_Cmplx" at static-initialization time.
  bool is_registered = Solver::Factory::Register("GMRES_m_Cmplx", create_object);
}
#endif
31 
32 
33 const std::string Solver_GMRES_m_Cmplx::class_name = "Solver_GMRES_m_Cmplx";
34 
35 //====================================================================
37 {
38  const string str_vlevel = params.get_string("verbose_level");
39 
40  m_vl = vout.set_verbose_level(str_vlevel);
41 
42  //- fetch and check input parameters
43  int Niter, Nrestart;
44  double Stop_cond;
45  int N_M;
46 
47  int err = 0;
48  err += params.fetch_int("maximum_number_of_iteration", Niter);
49  err += params.fetch_int("maximum_number_of_restart", Nrestart);
50  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
51  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
52 
53  if (err) {
54  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
55  exit(EXIT_FAILURE);
56  }
57 
58  set_parameters(Niter, Nrestart, Stop_cond);
60 }
61 
62 
63 //====================================================================
64 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const int Nrestart, const double Stop_cond)
65 {
67 
68  //- print input parameters
69  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
70  vout.general(m_vl, " Niter = %d\n", Niter);
71  vout.general(m_vl, " Nrestart = %d\n", Nrestart);
72  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
73 
74  //- range check
75  int err = 0;
76  err += ParameterCheck::non_negative(Niter);
77  err += ParameterCheck::non_negative(Nrestart);
78  err += ParameterCheck::square_non_zero(Stop_cond);
79 
80  if (err) {
81  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
82  exit(EXIT_FAILURE);
83  }
84 
85  //- store values
86  m_Niter = Niter;
87  m_Nrestart = Nrestart;
88  m_Stop_cond = Stop_cond;
89 }
90 
91 
92 //====================================================================
94 {
95  //- print input parameters
96  vout.general(m_vl, " N_M = %d\n", N_M);
97 
98  //- range check
99  int err = 0;
100  err += ParameterCheck::non_negative(N_M);
101 
102  if (err) {
103  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n", class_name.c_str());
104  exit(EXIT_FAILURE);
105  }
106 
107  //- store values
108  m_N_M = N_M;
109 }
110 
111 
112 //====================================================================
114  int& Nconv, double& diff)
115 {
116  double bnorm2 = b.norm2();
117  int bsize = b.size();
118 
119  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
120  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
121  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
122 
123  bool is_converged = false;
124  int Nconv2 = 0;
125  double diff2 = 1.0;
126  double rr;
127 
128  reset_field(b);
129  copy(m_s, b); // s = b;
130  solve_init(b, rr);
131  Nconv2 += 1;
132 
133  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
134 
135 
136  for (int i_restart = 0; i_restart < m_Nrestart; i_restart++) {
137  for (int iter = 0; iter < m_Niter; iter++) {
138  if (rr / bnorm2 < m_Stop_cond) break;
139 
140  solve_step(b, rr);
141  Nconv2 += m_N_M;
142 
143  vout.detailed(m_vl, " iter: %8d %22.15e\n", Nconv2, rr / bnorm2);
144  }
145 
146  //- calculate true residual
147  m_fopr->mult(m_s, m_x); // s = m_fopr->mult(x);
148  axpy(m_s, -1.0, b); // s -= b;
149  diff2 = m_s.norm2();
150 
151  if (diff2 / bnorm2 < m_Stop_cond) {
152  vout.detailed(m_vl, "%s: converged.\n", class_name.c_str());
153  vout.detailed(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
154 
155  is_converged = true;
156  break;
157  } else {
158  //- restart with new approximate solution
159  copy(m_s, m_x); // s = x;
160  solve_init(b, rr);
161 
162  vout.detailed(m_vl, "%s: restarted.\n", class_name.c_str());
163  }
164  }
165 
166 
167  if (!is_converged) {
168  vout.crucial(m_vl, "Error at %s: not converged.\n", class_name.c_str());
169  vout.crucial(m_vl, " iter(final): %8d %22.15e\n", Nconv2, diff2 / bnorm2);
170  exit(EXIT_FAILURE);
171  }
172 
173 
174  copy(xq, m_x); // xq = x;
175 
176 #pragma omp barrier
177 #pragma omp master
178  {
179  diff = sqrt(diff2 / bnorm2);
180  Nconv = Nconv2;
181  }
182 #pragma omp barrier
183 }
184 
185 
186 //====================================================================
188 {
189 #pragma omp barrier
190 #pragma omp master
191  {
192  int Nin = b.nin();
193  int Nvol = b.nvol();
194  int Nex = b.nex();
195 
196  if ((m_s.nin() != Nin) || (m_s.nvol() != Nvol) || (m_s.nex() != Nex)) {
197  m_s.reset(Nin, Nvol, Nex);
198  m_r.reset(Nin, Nvol, Nex);
199  m_x.reset(Nin, Nvol, Nex);
200 
201  m_v_tmp.reset(Nin, Nvol, Nex);
202 
203  m_v.resize(m_N_M + 1);
204 
205  for (int i = 0; i < m_N_M + 1; ++i) {
206  m_v[i].reset(Nin, Nvol, Nex);
207  }
208  }
209  }
210 #pragma omp barrier
211 
212  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
213 }
214 
215 
216 //====================================================================
217 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
218 {
219  copy(m_x, m_s); // x = s;
220 
221  for (int i = 0; i < m_N_M + 1; ++i) {
222  m_v[i].set(0.0); // m_v[i] = 0.0;
223  }
224 
225  // r = b - A x_0
226  m_fopr->mult(m_v_tmp, m_s); // v_tmp = m_fopr->mult(s);
227  copy(m_r, b); // r = b;
228  axpy(m_r, -1.0, m_v_tmp); // r -= v_tmp;
229 
230  rr = m_r.norm2(); // rr = r * r;
231 
232 #pragma omp barrier
233 #pragma omp master
234  m_beta_prev = sqrt(rr);
235 #pragma omp barrier
236 
237  //- v[0] = (1.0 / m_beta_prev) * r;
238  copy(m_v[0], m_r); // v[0] = r;
239  scal(m_v[0], (1.0 / m_beta_prev)); // v[0] = (1.0 / beta_p) * v[0];
240 }
241 
242 
243 //====================================================================
244 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
245 {
246  valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
247 
248  h = cmplx(0.0, 0.0);
249  y = cmplx(0.0, 0.0);
250 
251 
252  for (int j = 0; j < m_N_M; ++j) {
253  m_fopr->mult(m_v_tmp, m_v[j]); // v_tmp = m_fopr->mult(v[j]);
254 
255  for (int i = 0; i < j + 1; ++i) {
256  int ij = index_ij(i, j);
257  h[ij] = dotc(m_v_tmp, m_v[i]); // h[ij] = (A v[j], v[i]);
258  }
259 
260  //- v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
261  m_v[j + 1] = m_v_tmp;
262 
263  for (int i = 0; i < j + 1; ++i) {
264  int ij = index_ij(i, j);
265  axpy(m_v[j + 1], -h[ij], m_v[i]); // v[j+1] -= h[ij] * v[i];
266  }
267 
268  double v_norm2 = m_v[j + 1].norm2();
269 
270  int j1j = index_ij(j + 1, j);
271  h[j1j] = cmplx(sqrt(v_norm2), 0.0);
272 
273  scal(m_v[j + 1], 1.0 / sqrt(v_norm2)); // v[j+1] /= sqrt(v_norm2);
274  }
275 
276 
277  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
278  min_J(y, h);
279 
280 
281  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
282  for (int i = 0; i < m_N_M; ++i) {
283  axpy(m_x, y[i], m_v[i]); // x += y[i] * v[i];
284  }
285 
286 
287  // r = b - m_fopr->mult(x);
288  copy(m_s, m_x); // s = x;
289  solve_init(b, rr);
290 }
291 
292 
293 //====================================================================
294 void Solver_GMRES_m_Cmplx::min_J(valarray<dcomplex>& y,
295  valarray<dcomplex>& h)
296 {
297  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
298 
299  valarray<dcomplex> g(m_N_M + 1);
300 
301  g = cmplx(0.0, 0.0);
302  g[0] = cmplx(m_beta_prev, 0.0);
303 
304  for (int i = 0; i < m_N_M; ++i) {
305  int ii = index_ij(i, i);
306  double h_1_r = abs(h[ii]);
307 
308  int i1i = index_ij(i + 1, i);
309  double h_2_r = abs(h[i1i]);
310 
311  double denomi = sqrt(h_1_r * h_1_r + h_2_r * h_2_r);
312 
313  dcomplex cs = h[ii] / denomi;
314  dcomplex sn = h[i1i] / denomi;
315 
316  for (int j = i; j < m_N_M; ++j) {
317  int ij = index_ij(i, j);
318  int i1j = index_ij(i + 1, j);
319 
320  dcomplex const_1_c = conj(cs) * h[ij] + sn * h[i1j];
321  dcomplex const_2_c = -sn * h[ij] + cs * h[i1j];
322 
323  h[ij] = const_1_c;
324  h[i1j] = const_2_c;
325  }
326 
327  dcomplex const_1_c = conj(cs) * g[i] + sn * g[i + 1];
328  dcomplex const_2_c = -sn * g[i] + cs * g[i + 1];
329 
330  g[i] = const_1_c;
331  g[i + 1] = const_2_c;
332  }
333 
334 
335  for (int i = m_N_M - 1; i > -1; --i) {
336  for (int j = i + 1; j < m_N_M; ++j) {
337  int ij = index_ij(i, j);
338  g[i] -= h[ij] * y[j];
339  }
340 
341  int ii = index_ij(i, i);
342  y[i] = g[i] / h[ii];
343  }
344 }
345 
346 
347 //====================================================================
349 {
350  int NPE = CommonParameters::NPE();
352 
353  //- NB1 Nin = 2 * Nc * Nd, Nex = 1 for field_F
354  //- NB2 Nvol = CommonParameters::Nvol()/2 for eo
355  int Nin = m_x.nin();
356  int Nvol = m_x.nvol();
357  int Nex = m_x.nex();
358 
359  double flop_fopr = m_fopr->flop_count();
360 
361  if (flop_fopr < eps) {
362  vout.crucial(m_vl, "Warning at %s: no fopr->flop_count() is available, setting flop = 0.0.\n", class_name.c_str());
363  return 0.0;
364  }
365 
366  double flop_axpy = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
367  double flop_dotc = static_cast<double>(Nin * Nex * 4) * (Nvol * NPE);
368  double flop_norm = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
369  double flop_scal = static_cast<double>(Nin * Nex * 2) * (Nvol * NPE);
370 
371  int N_iter = (m_Nconv_count - 1) / m_N_M;
372 
373  int N_M_part = 0;
374  for (int j = 0; j < m_N_M; ++j) {
375  for (int i = 0; i < j + 1; ++i) {
376  N_M_part += 1;
377  }
378  }
379 
380  double flop_init = flop_fopr + flop_axpy + flop_norm;
381  double flop_step = m_N_M * flop_fopr
382  + N_M_part * flop_dotc
383  + (N_M_part + m_N_M) * flop_axpy
384  + flop_init;
385  double flop_true_residual = flop_fopr + flop_axpy + flop_norm;
386 
387  double flop = flop_norm + flop_init + flop_step * N_iter + flop_true_residual
388  + flop_init * m_Nrestart_count;
389 
390 
391  return flop;
392 }
393 
394 
395 //====================================================================
396 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:282
BridgeIO vout
Definition: bridgeIO.cpp:495
std::vector< Field > m_v
void detailed(const char *format,...)
Definition: bridgeIO.cpp:212
static double epsilon_criterion()
double norm2() const
Definition: field.cpp:441
void general(const char *format,...)
Definition: bridgeIO.cpp:195
static const std::string class_name
Container of Field-type object.
Definition: field.h:39
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:211
int nvol() const
Definition: field.h:116
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:46
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
GMRES(m) algorithm with complex variables.
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:230
virtual double flop_count()
returns the flops per site.
Definition: fopr.h:121
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
int nex() const
Definition: field.h:117
int index_ij(int i, int j)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:229
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Base class for linear solver class family.
Definition: solver.h:37
void reset_field(const Field &)
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
int non_negative(const int v)
Definition: checker.cpp:21
void min_J(std::valarray< dcomplex > &y, std::valarray< dcomplex > &h)
static void assert_single_thread(const std::string &class_name)
assert currently running on single thread.
void set_parameters_GMRES_m(const int N_M)
Base class of fermion operator family.
Definition: fopr.h:47
string get_string(const string &key) const
Definition: parameters.cpp:116
Bridge::VerboseLevel m_vl
Definition: solver.h:63
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131
static int NPE()
void solve_step(const Field &, double &)
void set_parameters(const Parameters &params)
int size() const
Definition: field.h:121