Bridge++  Ver. 1.1.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_GMRES_m_Cmplx.cpp
Go to the documentation of this file.
1 
14 #include "solver_GMRES_m_Cmplx.h"
15 
16 #ifdef USE_PARAMETERS_FACTORY
17 #include "parameters_factory.h"
18 #endif
19 
#ifdef USE_FACTORY
namespace {
  // Factory creator: wraps the given operator in a freshly allocated
  // Solver_GMRES_m_Cmplx and hands it back through the Solver base pointer.
  Solver *create_object(Fopr *fopr_in)
  {
    Solver *const solver = new Solver_GMRES_m_Cmplx(fopr_in);

    return solver;
  }

  // Registers the creator under the key "GMRES_m_Cmplx" at static-initialization time.
  bool init = Solver::Factory::Register("GMRES_m_Cmplx", create_object);
}
#endif
31 
32 //- parameter entries
33 namespace {
34  void append_entry(Parameters& param)
35  {
36  param.Register_int("maximum_number_of_iteration", 0);
37  param.Register_double("convergence_criterion_squared", 0.0);
38 
39  param.Register_int("number_of_orthonormal_vectors", 0);
40 
41  param.Register_string("verbose_level", "NULL");
42  }
43 
44 
45 #ifdef USE_PARAMETERS_FACTORY
46  bool init_param = ParametersFactory::Register("Solver.GMRES_m_Cmplx", append_entry);
47 #endif
48 }
49 //- end
50 
51 //- parameters class
53 //- end
54 
55 //====================================================================
// Body of the Parameters-based setter (its signature line is absent from this
// listing). It sets the verbosity from "verbose_level", fetches the three
// solver parameters, aborts if any fetch reports an error, and forwards
// Niter / Stop_cond to set_parameters(int, double).
// NOTE(review): N_M is fetched below, but no visible line consumes it. The call
// that stores it (presumably the N_M setter) may sit on a line dropped from this
// listing (original line 79) — confirm against the real source.
57 {
58  const string str_vlevel = params.get_string("verbose_level");
59 
60  m_vl = vout.set_verbose_level(str_vlevel);
61 
62  //- fetch and check input parameters
63  int Niter;
64  double Stop_cond;
65  int N_M;
66 
// Each fetch_* adds a nonzero value to err when the key cannot be read.
67  int err = 0;
68  err += params.fetch_int("maximum_number_of_iteration", Niter);
69  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
70  err += params.fetch_int("number_of_orthonormal_vectors", N_M);
71 
72  if (err) {
73  vout.crucial(m_vl, "Solver_GMRES_m_Cmplx: fetch error, input parameter not found.\n");
74  abort();
75  }
76 
77 
// Range checking and storage happen in set_parameters(int, double).
78  set_parameters(Niter, Stop_cond);
80 }
81 
82 
83 //====================================================================
84 void Solver_GMRES_m_Cmplx::set_parameters(const int Niter, const double Stop_cond)
85 {
86  //- print input parameters
87  vout.general(m_vl, "Parameters of Solver_GMRES_m_Cmplx:\n");
88  vout.general(m_vl, " Niter = %d\n", Niter);
89  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
90 
91  //- range check
92  int err = 0;
93  err += ParameterCheck::non_negative(Niter);
94  err += ParameterCheck::square_non_zero(Stop_cond);
95 
96  if (err) {
97  vout.crucial(m_vl, "Solver_GMRES_m_Cmplx: parameter range check failed.\n");
98  abort();
99  }
100 
101  //- store values
102  m_Niter = Niter;
103  m_Stop_cond = Stop_cond;
104 }
105 
106 
107 //====================================================================
// Body of the setter for the restart length N_M. Its signature line is absent
// from this listing; N_M is presumably its int parameter — confirm.
// It logs N_M, aborts unless N_M is non-negative, and stores it in m_N_M.
// NOTE(review): non_negative accepts N_M == 0. With m_N_M == 0, solve_step's
// Arnoldi and update loops do nothing, so x never changes. Consider requiring N_M >= 1.
109 {
110  //- print input parameters
111  vout.general(m_vl, " N_M = %d\n", N_M);
112 
113  //- range check
114  int err = 0;
115  err += ParameterCheck::non_negative(N_M);
116 
117  if (err) {
118  vout.crucial(m_vl, "Solver_GMRES_m_Cmplx: parameter range check failed.\n");
119  abort();
120  }
121 
122  //- store values
123  m_N_M = N_M;
124 }
125 
126 
127 //====================================================================
129  int& Nconv, double& diff)
130 {
// Restarted GMRES(m) driver (the first signature line is absent from this listing).
// It solves A x = b with A = m_fopr, starting from the guess x = b.
// Outputs:
//   xq    - the final solution
//   Nconv - number of Krylov steps, m_N_M * (completed restart cycles)
//   diff  - final |A x - b|
// It aborts if the relative residual criterion is not met within m_Niter cycles.
131  vout.detailed(m_vl, " GMRES_m_Cmplx solver starts\n");
132 
// (Re)allocate the work fields if b's shape differs from theirs.
133  reset_field(b);
134 
135  vout.paranoiac(m_vl, " norm of b = %16.8e\n", b.norm2());
136  vout.paranoiac(m_vl, " size of b = %d\n", b.size());
137 
// snorm normalizes squared residuals. It is compared the same way as diff*diff
// below, so norm2() is presumably the squared norm — confirm.
// NOTE(review): b == 0 makes snorm infinite.
138  double snorm = 1.0 / b.norm2();
139  double rr;
140 
// Nconv stays -1 until convergence is confirmed; s seeds solve_init's initial guess.
141  Nconv = -1;
142  s = b;
143 
144  solve_init(b, rr);
145 
146  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
147 
// One solve_step = one GMRES cycle of m_N_M Krylov steps followed by a restart.
148  for (int iter = 0; iter < m_Niter; iter++) {
149  solve_step(b, rr);
150 
151  vout.detailed(m_vl, " iter: %8d %22.15e\n", m_N_M * (iter + 1), rr * snorm);
152 
// rr looks converged: confirm with an explicitly recomputed residual A x - b.
153  if (rr * snorm < m_Stop_cond) {
154  s = m_fopr->mult(x);
155  s -= b;
156  diff = s.norm();
157 
158  if (diff * diff * snorm < m_Stop_cond) {
159  Nconv = m_N_M * (iter + 1);
160  break;
161  }
162 
// Not confirmed: restart from the current x.
// NOTE(review): solve_step already ended with s = x; solve_init(b, rr),
// so these two lines look redundant.
163  s = x;
164  solve_init(b, rr);
165  }
166  }
167  if (Nconv == -1) {
168  vout.crucial(m_vl, "GMRES_m_Cmplx solver not converged.\n");
169  abort();
170  }
171 
// Report the true final residual norm and return the solution.
172  s = m_fopr->mult(x);
173  s -= b;
174  diff = s.norm();
175 
176  xq = x;
177 }
178 
179 
180 //====================================================================
// Body of the work-field sizing routine (its signature line is absent from
// this listing; the solver driver calls it as reset_field(b)).
// When s's (nin, nvol, nex) shape differs from b's, it resizes s, r, x, v_tmp
// and the m_N_M + 1 Krylov basis vectors v[0..m_N_M] to b's shape.
// NOTE(review): only s's shape is compared. If m_N_M changes while the field
// shape stays the same, v is NOT resized to m_N_M + 1 entries, and solve_step
// would then index past the end of v — confirm whether callers can do this.
182 {
183  int Nin = b.nin();
184  int Nvol = b.nvol();
185  int Nex = b.nex();
186 
187  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
188  s.reset(Nin, Nvol, Nex);
189  r.reset(Nin, Nvol, Nex);
190  x.reset(Nin, Nvol, Nex);
191 
192  v_tmp.reset(Nin, Nvol, Nex);
193 
// The Arnoldi process needs m_N_M + 1 basis vectors (v[0] .. v[m_N_M]).
194  v.resize(m_N_M + 1);
195 
196  for (int i = 0; i < m_N_M + 1; ++i) {
197  v[i].reset(Nin, Nvol, Nex);
198  }
199 
200  vout.paranoiac(m_vl, " Solver_GMRES_m_Cmplx: field size reset.\n");
201  }
202 }
203 
204 
205 //====================================================================
206 void Solver_GMRES_m_Cmplx::solve_init(const Field& b, double& rr)
207 {
208  x = s;
209 
210  v_tmp = m_fopr->mult(x);
211  r = b;
212  r -= v_tmp;
213 
214  rr = r * r;
215  beta_p = sqrt(rr);
216 
217  v[0] = (1.0 / beta_p) * r;
218 }
219 
220 
221 //====================================================================
222 void Solver_GMRES_m_Cmplx::solve_step(const Field& b, double& rr)
223 {
224  int ij;
225  double const_r, const_i;
226  std::valarray<dcomplex> h((m_N_M + 1) * m_N_M), y(m_N_M);
227 
228  h = cmplx(0.0, 0.0);
229  y = cmplx(0.0, 0.0);
230 
231 
232  for (int j = 0; j < m_N_M; ++j) {
233  v_tmp = m_fopr->mult(v[j]);
234 
235  for (int i = 0; i < (j + 1); ++i) {
236  // h[i,j] = (A v[j], v[i]);
237  ij = index_ij(i, j);
238  innerprod_c(const_r, const_i, v_tmp, v[i]);
239  h[ij] = cmplx(const_r, const_i);
240  }
241 
242  // v[j+1] = A v[j] - \Sum_{i=0}^{j-1} h[i,j] * v[i]
243  v[j + 1] = v_tmp;
244 
245  for (int i = 0; i < (j + 1); ++i) {
246  // v[j+1] -= h[i,j] * v[i];
247  ij = index_ij(i, j);
248  mult_c(v_tmp, v[i], real(h[ij]), imag(h[ij]));
249  v[j + 1] -= v_tmp;
250  }
251 
252  const_r = v[j + 1] * v[j + 1];
253 
254  ij = index_ij(j + 1, j);
255  h[ij] = sqrt(const_r);
256  v[j + 1] /= sqrt(const_r);
257  }
258 
259 
260  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
261  min_J(y, h);
262 
263 
264  // x += Sum_{i=0}^{N_M-1} y[i] * v[i];
265  for (int i = 0; i < m_N_M; ++i) {
266  mult_c(v_tmp, v[i], real(y[i]), imag(y[i]));
267  x += v_tmp;
268  }
269 
270 
271  // r = b - m_fopr->mult(x);
272  s = x;
273  solve_init(b, rr);
274 }
275 
276 
277 //====================================================================
278 void Solver_GMRES_m_Cmplx::innerprod_c(double& prod_r, double& prod_i,
279  const Field& v, const Field& w)
280 {
281  // prod = (v,w);
282 
283  int size = w.size();
284 
285  assert(v.size() == size);
286 
287  prod_r = 0.0;
288  prod_i = 0.0;
289 
290  for (int i = 0; i < size; i += 2) {
291  prod_r += v.cmp(i) * w.cmp(i) + v.cmp(i + 1) * w.cmp(i + 1);
292  prod_i += v.cmp(i) * w.cmp(i + 1) - v.cmp(i + 1) * w.cmp(i);
293  }
294 
295  prod_r = Communicator::reduce_sum(prod_r);
296  prod_i = Communicator::reduce_sum(prod_i);
297 }
298 
299 
300 //====================================================================
301 void Solver_GMRES_m_Cmplx::min_J(std::valarray<dcomplex>& y,
302  std::valarray<dcomplex>& h)
303 {
304  // Compute y, which minimizes J := |r_new| = |beta_p - h * y|
305 
306  int ii, i1i, ij, i1j;
307  double const_r, const_1_r, const_2_r;
308  dcomplex cs, sn, const_1_c, const_2_c;
309 
310  std::valarray<dcomplex> g(m_N_M + 1);
311 
312  g = dcomplex(0.0);
313  g[0] = beta_p;
314 
315 
316  for (int i = 0; i < m_N_M; ++i) {
317  ii = index_ij(i, i);
318  const_1_r = abs(h[ii]);
319 
320  i1i = index_ij(i + 1, i);
321  const_2_r = abs(h[i1i]);
322 
323  const_r = sqrt(const_1_r * const_1_r
324  + const_2_r * const_2_r);
325 
326  cs = h[ii] / const_r;
327  sn = h[i1i] / const_r;
328 
329  for (int j = i; j < m_N_M; ++j) {
330  ij = index_ij(i, j);
331  i1j = index_ij(i + 1, j);
332 
333  const_1_c = conj(cs) * h[ij] + sn * h[i1j];
334  const_2_c = -sn * h[ij] + cs * h[i1j];
335 
336  h[ij] = const_1_c;
337  h[i1j] = const_2_c;
338  }
339 
340  const_1_c = conj(cs) * g[i] + sn * g[i + 1];
341  const_2_c = -sn * g[i] + cs * g[i + 1];
342 
343  g[i] = const_1_c;
344  g[i + 1] = const_2_c;
345  }
346 
347 
348  for (int i = m_N_M - 1; i > -1; --i) {
349  for (int j = i + 1; j < m_N_M; ++j) {
350  ij = index_ij(i, j);
351  g[i] -= h[ij] * y[j];
352  }
353 
354  ii = index_ij(i, i);
355  y[i] = g[i] / h[ii];
356  }
357 }
358 
359 
360 //====================================================================
362  const Field& w,
363  const double& prod_r, const double& prod_i)
364 {
// Complex scaling: overwrite v (the first parameter, whose signature line is
// absent from this listing) with (prod_r + i*prod_i) * w.
// Consecutive components (even, odd) are treated as (re, im) pairs.
// v must already have the same size as w (asserted below).
365  // v = dcomplex(prod_r,prod_i) * w;
366 
367  int size = w.size();
368 
369  assert(v.size() == size);
370 
371  double vr, vi;
372  for (int i = 0; i < size; i += 2) {
// Both components of w are read before v is written, so each pair is computed from the original w values.
373  vr = prod_r * w.cmp(i) - prod_i * w.cmp(i + 1);
374  vi = prod_r * w.cmp(i + 1) + prod_i * w.cmp(i);
375 
376  v.set(i, vr);
377  v.set(i + 1, vi);
378  }
379 }
380 
381 
382 //====================================================================
383 //============================================================END=====