Bridge++  Ver. 1.3.x
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
16 #include <valarray>
17 using std::valarray;
18 
19 #ifdef USE_FACTORY
20 namespace {
21  Solver *create_object(Fopr *fopr)
22  {
23  return new Solver_GMRES_m_Cmplx(fopr);
24  }
25 
26 
27  bool init = Solver::Factory::Register("GMRES_m_Cmplx", create_object);
28 }
29 #endif
30 
31 //- parameter entries
32 namespace {
33  void append_entry(Parameters& param)
34  {
35  param.Register_int("maximum_number_of_iteration", 0);
36  param.Register_double("convergence_criterion_squared", 0.0);
37 
38  param.Register_int("number_of_orthonormal_vectors", 0);
39 
40  param.Register_string("verbose_level", "NULL");
41  }
42 
43 
44 #ifdef USE_PARAMETERS_FACTORY
45  bool init_param = ParametersFactory::Register("Solver.GMRES_m_Cmplx", append_entry);
46 #endif
47 }
48 //- end
49 
50 //- parameters class
52 //- end
53 
// Name of this class as text; the functions in this file pass
// class_name.c_str() to vout to prefix their messages.
54 const std::string Solver_GMRES_m_Cmplx::class_name = "Solver_GMRES_m_Cmplx";
55 
56 //====================================================================
// Body of the parameter-object setter: reads the verbose level and the
// numerical settings from `params`, aborts if a key is missing, and passes
// the values on to the typed setter.
// NOTE(review): the function header (listing line 57) is absent from this
// listing; presumably it is set_parameters(const Parameters& params) -- confirm.
58 {
  // Verbose level is taken first so that the messages below use it.
59  const string str_vlevel = params.get_string("verbose_level");
60 
61  m_vl = vout.set_verbose_level(str_vlevel);
62 
63  //- fetch and check input parameters
64  int Niter;
65  double Stop_cond;
66  int N_M;
67 
  // Each fetch_* result is accumulated in err; any non-zero total is
  // treated as "parameter not found" below.
68  int err = 0;
69  err += params.fetch_int("maximum_number_of_iteration", Niter);
70  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
71  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
72 
73  if (err) {
74  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
75  exit(EXIT_FAILURE);
76  }
77 
78 
  // NOTE(review): N_M is fetched above but no visible line uses it, and
  // listing line 80 is absent. Presumably a call forwarding N_M (e.g. to
  // set_parameters_GMRES_m) was lost in extraction -- confirm against the
  // original file; without it m_N_M is never set from the parameters.
79  set_parameters(Niter, Stop_cond);
81 }
82 
83 
84 //====================================================================
85 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const double Stop_cond)
86 {
88 
89  //- print input parameters
90  vout.general(m_vl, "%s: input parameters\n", class_name.c_str());
91  vout.general(m_vl, " Niter = %d\n", Niter);
92  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
93 
94  //- range check
95  int err = 0;
96  err += ParameterCheck::non_negative(Niter);
97  err += ParameterCheck::square_non_zero(Stop_cond);
98 
99  if (err) {
100  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
101  exit(EXIT_FAILURE);
102  }
103 
104  //- store values
105  m_Niter = Niter;
106  m_Stop_cond = Stop_cond;
107 }
108 
109 
110 //====================================================================
// Body of the setter for N_M, the number of basis vectors generated per
// solve_step(): prints the value, range-checks it and stores it in m_N_M.
// NOTE(review): the function header (listing line 111) is absent from this
// listing; presumably it is set_parameters_GMRES_m(const int N_M) -- confirm.
112 {
113  //- print input parameters
114  vout.general(m_vl, " N_M = %d\n", N_M);
115 
116  //- range check
117  int err = 0;
  // NOTE(review): non_negative presumably accepts N_M == 0, for which
  // solve_step() builds no vectors -- confirm that this is intended.
118  err += ParameterCheck::non_negative(N_M);
119 
  // A failed range check terminates the program.
120  if (err) {
121  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
122  exit(EXIT_FAILURE);
123  }
124 
125  //- store values
126  m_N_M = N_M;
127 }
128 
129 
130 //====================================================================
// Driver of the solver: repeats solve_step() up to m_Niter times until the
// residual ratio drops below m_Stop_cond, then writes the result to the
// output arguments.
//   xq    : receives a copy of the final iterate x
//   b     : right-hand side
//   Nconv : receives m_N_M * (number of steps) recorded at convergence
//   diff  : receives sqrt of the final value of (A x - b).norm2()
// Terminates through exit(EXIT_FAILURE) if the final check fails.
// NOTE(review): the first line of the signature (listing line 131) is absent
// from this listing; only the tail of the parameter list is visible.
132  int& Nconv, double& diff)
133 {
  // NOTE(review): norm2() is presumably the squared norm (diff is obtained
  // as sqrt(diff2) below). snorm is not guarded against bnorm2 == 0.
134  double bnorm2 = b.norm2();
135  double snorm = 1.0 / bnorm2;
136  int bsize = b.size();
137  double rr;
138 
139  vout.paranoiac(m_vl, "%s: starts\n", class_name.c_str());
140  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
141  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
142 
  // Make the work fields match the layout of b.
143  reset_field(b);
144 
145 
146  // Nconv = -1;
147  int Nconv2 = 0;
  // The initial guess handed to solve_init() through s is b itself.
148  copy(s, b); // s = b;
149 
150  solve_init(b, rr);
151 
152  bool is_converged = false;
153 
154  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
155 
156 
157  for (int iter = 0; iter < m_Niter; iter++) {
158  if (is_converged) break;
159 
160  solve_step(b, rr);
161 
  // The reported iteration count advances by m_N_M per step.
162  vout.detailed(m_vl, " iter: %8d %22.15e\n", m_N_M * (iter + 1), rr * snorm);
163 
164  if (rr * snorm < m_Stop_cond) {
  // Cross-check with an explicitly recomputed residual s = A x - b.
165  m_fopr->mult(s, x); // s = m_fopr->mult(x);
166  axpy(s, -1.0, b); // s -= b;
167 
168  double diff2 = s.norm2();
169 
170  if (diff2 * snorm < m_Stop_cond) {
171  Nconv2 = m_N_M * (iter + 1);
172  is_converged = true;
173  } else {
  // Cross-check failed: reinitialize from the current x and continue.
174  copy(s, x); // s = x;
175  solve_init(b, rr);
176  }
177  }
178  }
179 
180 
  // Final residual s = A x - b, evaluated whether or not the loop converged.
181  m_fopr->mult(s, x); // p = m_fopr->mult(x);
182  axpy(s, -1.0, b); // p -= b;
183 
184  copy(xq, x); // xq = x;
185 
186  double diff2 = s.norm2();
187 
188  if (diff2 * snorm > m_Stop_cond) {
189  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
190  exit(EXIT_FAILURE);
191  }
192 
193 
  // Only the master thread writes the output arguments; barriers surround
  // the write.
  // NOTE(review): Nconv2 is still 0 here unless the in-loop check succeeded.
194 #pragma omp barrier
195 #pragma omp master
196  {
197  diff = sqrt(diff2);
198  Nconv = Nconv2;
199  }
200 #pragma omp barrier
201 }
202 
203 
204 //====================================================================
// Body of the work-field setup: if s does not have the same (nin, nvol, nex)
// as b, resets s, r, x, v_tmp and the m_N_M + 1 entries of v to that layout.
// The resizing is done by the master thread only, between two barriers.
// NOTE(review): the function header (listing line 205) is absent from this
// listing; presumably it is reset_field(const Field& b) -- confirm.
206 {
207 #pragma omp barrier
208 #pragma omp master
209  {
210  int Nin = b.nin();
211  int Nvol = b.nvol();
212  int Nex = b.nex();
213 
  // Only s is compared with b; the other work fields are assumed to share
  // its layout because they are always reset together here.
214  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
215  s.reset(Nin, Nvol, Nex);
216  r.reset(Nin, Nvol, Nex);
217  x.reset(Nin, Nvol, Nex);
218 
219  v_tmp.reset(Nin, Nvol, Nex);
220 
  // v holds one more vector than m_N_M (solve_step() writes v[j + 1]).
221  v.resize(m_N_M + 1);
222 
223  for (int i = 0; i < m_N_M + 1; ++i) {
224  v[i].reset(Nin, Nvol, Nex);
225  }
226  }
227  }
228 #pragma omp barrier
229 
  // NOTE(review): this message is printed on every call, including calls
  // where the size check above found nothing to reset.
230  vout.paranoiac(m_vl, " %s: field size reset.\n", class_name.c_str());
231 }
232 
233 
234 //====================================================================
235 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
236 {
237  copy(x, s); // x = s;
238 
239  //- r = b - A x_0
240  m_fopr->mult(v_tmp, s); // v_tmp = m_fopr->mult(s);
241  copy(r, b); // r = b;
242  axpy(r, -1.0, v_tmp); // r -= v_tmp;
243 
244  rr = r.norm2(); // rr = r * r;
245 
246 #pragma omp barrier
247 #pragma omp master
248  beta_prev = sqrt(rr);
249 #pragma omp barrier
250 
251  //- v[0] = (1.0 / beta_prev) * r;
252  copy(v[0], r); // v[0] = r;
253  scal(v[0], (1.0 / beta_prev)); // v[0] = (1.0 / beta_p) * v[0];
254 }
255 
256 
257 //====================================================================
258 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
259 {
260  valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
261 
262  h = cmplx(0.0, 0.0);
263  y = cmplx(0.0, 0.0);
264 
265 
266  for (int j = 0; j < m_N_M; ++j) {
267  m_fopr->mult(v_tmp, v[j]); // v_tmp = m_fopr->mult(v[j]);
268 
269  for (int i = 0; i < j + 1; ++i) {
270  int ij = index_ij(i, j);
271  h[ij] = dotc(v_tmp, v[i]); // h[ij] = (A v[j], v[i]);
272  }
273 
274  //- v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
275  v[j + 1] = v_tmp;
276 
277  for (int i = 0; i < j + 1; ++i) {
278  int ij = index_ij(i, j);
279  axpy(v[j + 1], -h[ij], v[i]); // v[j+1] -= h[ij] * v[i];
280  }
281 
282  double v_norm2 = v[j + 1].norm2();
283 
284  int j1j = index_ij(j + 1, j);
285  h[j1j] = sqrt(v_norm2);
286 
287  scal(v[j + 1], 1.0 / sqrt(v_norm2)); // v[j+1] /= sqrt(v_norm2);
288  }
289 
290 
291  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
292  min_J(y, h);
293 
294 
295  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
296  for (int i = 0; i < m_N_M; ++i) {
297  axpy(x, y[i], v[i]); // x += y[i] * v[i];
298  }
299 
300 
301  // r = b - m_fopr->mult(x);
302  copy(s, x); // s = x;
303  solve_init(b, rr);
304 }
305 
306 
307 //====================================================================
308 void Solver_GMRES_m_Cmplx::min_J(valarray<dcomplex>& y,
309  valarray<dcomplex>& h)
310 {
311  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
312 
313  valarray<dcomplex> g(m_N_M + 1);
314 
315  g = dcomplex(0.0);
316  g[0] = beta_prev;
317 
318  for (int i = 0; i < m_N_M; ++i) {
319  int ii = index_ij(i, i);
320  double h_1_r = abs(h[ii]);
321 
322  int i1i = index_ij(i + 1, i);
323  double h_2_r = abs(h[i1i]);
324 
325  double denomi = sqrt(h_1_r * h_1_r + h_2_r * h_2_r);
326 
327  dcomplex cs = h[ii] / denomi;
328  dcomplex sn = h[i1i] / denomi;
329 
330  for (int j = i; j < m_N_M; ++j) {
331  int ij = index_ij(i, j);
332  int i1j = index_ij(i + 1, j);
333 
334  dcomplex const_1_c = conj(cs) * h[ij] + sn * h[i1j];
335  dcomplex const_2_c = -sn * h[ij] + cs * h[i1j];
336 
337  h[ij] = const_1_c;
338  h[i1j] = const_2_c;
339  }
340 
341  dcomplex const_1_c = conj(cs) * g[i] + sn * g[i + 1];
342  dcomplex const_2_c = -sn * g[i] + cs * g[i + 1];
343 
344  g[i] = const_1_c;
345  g[i + 1] = const_2_c;
346  }
347 
348 
349  for (int i = m_N_M - 1; i > -1; --i) {
350  for (int j = i + 1; j < m_N_M; ++j) {
351  int ij = index_ij(i, j);
352  g[i] -= h[ij] * y[j];
353  }
354 
355  int ii = index_ij(i, i);
356  y[i] = g[i] / h[ii];
357  }
358 }
359 
360 
361 //====================================================================
362 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:282
BridgeIO vout
Definition: bridgeIO.cpp:278
void detailed(const char *format,...)
Definition: bridgeIO.cpp:82
void Register_string(const string &, const string &)
Definition: parameters.cpp:351
double norm2() const
Definition: field.cpp:441
void general(const char *format,...)
Definition: bridgeIO.cpp:65
void Register_int(const string &, const int)
Definition: parameters.cpp:330
static const std::string class_name
Container of Field-type object.
Definition: field.h:39
int nvol() const
Definition: field.h:116
void solve_init(const Field &, double &)
Class for parameters.
Definition: parameters.h:38
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:381
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:115
GMRES(m) algorithm with complex variables.
dcomplex dotc(const Field &y, const Field &x)
Definition: field.cpp:92
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:84
int nex() const
Definition: field.h:117
int index_ij(int i, int j)
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:99
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:168
void crucial(const char *format,...)
Definition: bridgeIO.cpp:48
Base class for linear solver class family.
Definition: solver.h:38
static bool Register(const std::string &realm, const creator_callback &cb)
void reset_field(const Field &)
virtual void mult(Field &, const Field &)=0
multiplies fermion operator to a given field (2nd argument)
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
int non_negative(const int v)
Definition: checker.cpp:21
std::vector< Field > v
void Register_double(const string &, const double)
Definition: parameters.cpp:323
void min_J(std::valarray< dcomplex > &y, std::valarray< dcomplex > &h)
void set_parameters_GMRES_m(const int N_M)
Base class of fermion operator family.
Definition: fopr.h:49
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:87
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:62
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:28
static void assert_single_thread(const std::string &classname)
assert currently running on single thread.
void solve_step(const Field &, double &)
void set_parameters(const Parameters &params)
int size() const
Definition: field.h:121