Bridge++  Ver. 2.0.2
ashiftsolver_CG-tmpl.h
Go to the documentation of this file.
1 
14 #include "ashiftsolver_CG.h"
15 
16 template<typename FIELD, typename FOPR>
18  = "AShiftsolver_CG";
19 
20 //====================================================================
21 template<typename FIELD, typename FOPR>
23  const Parameters& params)
24 {
25  std::string vlevel;
26  if (!params.fetch_string("verbose_level", vlevel)) {
27  m_vl = vout.set_verbose_level(vlevel);
28  }
29 
30  //- fetch and check input parameters
31  int Niter;
32  double Stop_cond;
33 
34  int err = 0;
35  err += params.fetch_int("maximum_number_of_iteration", Niter);
36  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
37 
38  if (err) {
39  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
40  exit(EXIT_FAILURE);
41  }
42 
43  set_parameters(Niter, Stop_cond);
44 }
45 
46 
47 //====================================================================
48 template<typename FIELD, typename FOPR>
50 {
51  params.set_int("maximum_number_of_iteration", m_Niter);
52  params.set_double("convergence_criterion_squared", m_Stop_cond);
53 
54  params.set_string("verbose_level", vout.get_verbose_level(m_vl));
55 }
56 
57 
58 //====================================================================
59 template<typename FIELD, typename FOPR>
61  const int Niter,
62  const double Stop_cond)
63 {
64  //- print input parameters
65  vout.general(m_vl, "%s:\n", class_name.c_str());
66  vout.general(m_vl, " Niter = %d\n", Niter);
67  vout.general(m_vl, " Stop_cond = %8.2e\n", Stop_cond);
68 
69  //- range check
70  int err = 0;
71  err += ParameterCheck::non_negative(Niter);
72  err += ParameterCheck::square_non_zero(Stop_cond);
73 
74  if (err) {
75  vout.crucial(m_vl, "Error at %s: parameter range check failed.\n",
76  class_name.c_str());
77  exit(EXIT_FAILURE);
78  }
79 
80  //- store values
81  m_Niter = Niter;
82  m_Stop_cond = Stop_cond;
83 }
84 
85 
86 //====================================================================
87 template<typename FIELD, typename FOPR>
89  std::vector<FIELD>& xq,
90  const std::vector<double>& sigma,
91  const FIELD& b,
92  int& Nconv,
93  double& diff)
94 {
95  int Nshift = sigma.size();
96 
97  vout.paranoiac(m_vl, " Shift CG solver start.\n");
98  vout.paranoiac(m_vl, " number of shift = %d\n", Nshift);
99  vout.paranoiac(m_vl, " values of shift:\n");
100  for (int i = 0; i < Nshift; ++i) {
101  vout.paranoiac(m_vl, " %d %12.8f\n", i, sigma[i]);
102  }
103 
104  m_snorm = 1.0 / b.norm2();
105 
106  int Nconv2 = -1;
107 
108  reset_field(b, sigma, Nshift);
109 
110  copy(m_s, b);
111  copy(m_r, b);
112 
113  double rr = 0.0;
114 
115  solve_init(rr);
116 
117  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * m_snorm);
118 
119  bool is_converged = false;
120 
121  for (int iter = 0; iter < m_Niter; iter++) {
122  solve_step(rr);
123 
124  Nconv2 += 1;
125 
126  vout.detailed(m_vl, " iter: %8d %22.15e %4d\n",
127  (iter + 1), rr * m_snorm, m_Nshift2);
128 
129  if (rr * m_snorm < m_Stop_cond) {
130  is_converged = true;
131  break;
132  }
133  }
134 
135  if (!is_converged) {
136  vout.crucial(m_vl, "Error at %s: not converged.\n",
137  class_name.c_str());
138  exit(EXIT_FAILURE);
139  }
140 
141 
142  std::vector<double> diffs(Nshift);
143  for (int i = 0; i < Nshift; ++i) {
144  diffs[i] = 0.0;
145  }
146 
147  for (int i = 0; i < Nshift; ++i) {
148  m_fopr->mult(m_s, m_x[i]);
149  axpy(m_s, sigma[i], m_x[i]);
150  axpy(m_s, -1.0, b);
151 
152  double diff1 = sqrt(m_s.norm2() * m_snorm);
153 
154  vout.paranoiac(m_vl, " %4d %22.15e\n", i, diff1);
155 
156  // if (diff1 > diff2) diff2 = diff1;
157  diffs[i] = diff1;
158  }
159 
160 #pragma omp barrier
161 #pragma omp master
162  {
163  double diff2 = -1.0;
164 
165  for (int i = 0; i < Nshift; ++i) {
166  if (diffs[i] > diff2) diff2 = diffs[i];
167  }
168 
169  diff = diff2;
170 
171  Nconv = Nconv2;
172  }
173 #pragma omp barrier
174 
175  for (int i = 0; i < Nshift; ++i) {
176  copy(xq[i], m_x[i]);
177  }
178 
179  vout.paranoiac(m_vl, " diff(max) = %22.15e \n", diff);
180 }
181 
182 
183 //====================================================================
184 template<typename FIELD, typename FOPR>
186 {
187  int Nshift = m_p.size();
188 
189  vout.paranoiac(m_vl, "number of shift = %d\n", Nshift);
190 
191  for (int i = 0; i < Nshift; ++i) {
192  copy(m_p[i], m_s);
193  scal(m_x[i], 0.0);
194  }
195 
196  copy(m_r, m_s);
197  rr = m_r.norm2();
198 
199 #pragma omp barrier
200 #pragma omp master
201  {
202  m_alpha_p = 0.0;
203  m_beta_p = 1.0;
204  }
205 #pragma omp barrier
206 }
207 
208 
209 //====================================================================
210 template<typename FIELD, typename FOPR>
212 {
213  m_fopr->mult(m_s, m_p[0]);
214  axpy(m_s, m_sigma0, m_p[0]);
215 
216  double rr_p = rr;
217  double pa_p = dot(m_s, m_p[0]);
218  double beta = -rr_p / pa_p;
219 
220  axpy(m_x[0], -beta, m_p[0]);
221  axpy(m_r, beta, m_s);
222  rr = m_r.norm2();
223 
224  double alpha = rr / rr_p;
225 
226  aypx(alpha, m_p[0], m_r);
227 
228 #pragma omp barrier
229 #pragma omp master
230  {
231  m_pp[0] = rr;
232  }
233 #pragma omp barrier
234 
235  double alpha_h = 1.0 + m_alpha_p * beta / m_beta_p;
236 
237  for (int ish = 1; ish < m_Nshift2; ++ish) {
238  double zeta = (alpha_h - m_csh2[ish] * beta) / m_zeta1[ish]
239  + (1.0 - alpha_h) / m_zeta2[ish];
240  zeta = 1.0 / zeta;
241  double zr = zeta / m_zeta1[ish];
242  double beta_s = beta * zr;
243  double alpha_s = alpha * zr * zr;
244 
245  axpy(m_x[ish], -beta_s, m_p[ish]);
246  scal(m_p[ish], alpha_s);
247  axpy(m_p[ish], zeta, m_r);
248 
249  double ppr = m_p[ish].norm2();
250 
251 #pragma omp barrier
252 #pragma omp master
253  {
254  m_pp[ish] = ppr * m_snorm;
255 
256  m_zeta2[ish] = m_zeta1[ish];
257  m_zeta1[ish] = zeta;
258  }
259 #pragma omp barrier
260  }
261 
262  int ish1 = m_Nshift2;
263 
264  for (int ish = m_Nshift2 - 1; ish >= 0; --ish) {
265  vout.paranoiac(m_vl, "%4d %16.8e\n", ish, m_pp[ish]);
266  if (m_pp[ish] > m_Stop_cond) {
267  ish1 = ish + 1;
268  break;
269  }
270  }
271 
272 #pragma omp barrier
273 #pragma omp master
274  {
275  m_Nshift2 = ish1;
276 
277  m_alpha_p = alpha;
278  m_beta_p = beta;
279  }
280 #pragma omp barrier
281 }
282 
283 
284 //====================================================================
285 template<typename FIELD, typename FOPR>
287  const std::vector<double>& sigma,
288  const int Nshift)
289 {
290 #pragma omp barrier
291 #pragma omp master
292  {
293  int Nin = b.nin();
294  int Nvol = b.nvol();
295  int Nex = b.nex();
296 
297  m_p.resize(Nshift);
298  m_x.resize(Nshift);
299  m_zeta1.resize(Nshift);
300  m_zeta2.resize(Nshift);
301  m_csh2.resize(Nshift);
302  m_pp.resize(Nshift);
303 
304  for (int i = 0; i < Nshift; ++i) {
305  m_p[i].reset(Nin, Nvol, Nex);
306  m_x[i].reset(Nin, Nvol, Nex);
307  m_zeta1[i] = 1.0;
308  m_zeta2[i] = 1.0;
309  m_csh2[i] = sigma[i] - sigma[0];
310 
311  m_pp[i] = 0.0;
312  }
313 
314  m_s.reset(Nin, Nvol, Nex);
315  m_r.reset(Nin, Nvol, Nex);
316 
317  m_sigma0 = sigma[0];
318 
319  m_Nshift2 = Nshift;
320  }
321 #pragma omp barrier
322 }
323 
324 
325 //====================================================================
326 template<typename FIELD, typename FOPR>
328 {
329  vout.general(m_vl, "Warning at %s: flop_count() not yet implemented.\n",
330  class_name.c_str());
331  return 0.0;
332 }
333 
334 
335 //============================================================END=====
AShiftsolver_CG::get_parameters
void get_parameters(Parameters &params) const
Definition: ashiftsolver_CG-tmpl.h:49
Parameters::set_string
void set_string(const string &key, const string &value)
Definition: parameters.cpp:39
AShiftsolver_CG::set_parameters
void set_parameters(const Parameters &params)
Definition: ashiftsolver_CG-tmpl.h:22
Parameters
Class for parameters.
Definition: parameters.h:46
AShiftsolver_CG::solve_init
void solve_init(double &)
Definition: ashiftsolver_CG-tmpl.h:185
Parameters::set_double
void set_double(const string &key, const double value)
Definition: parameters.cpp:33
Bridge::BridgeIO::detailed
void detailed(const char *format,...)
Definition: bridgeIO.cpp:219
aypx
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:509
AShiftsolver_CG::solve_step
void solve_step(double &)
Definition: ashiftsolver_CG-tmpl.h:211
AShiftsolver_CG::reset_field
void reset_field(const FIELD &b, const std::vector< double > &sigma, const int Nshift)
Definition: ashiftsolver_CG-tmpl.h:286
axpy
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:380
dot
double dot(const Field &y, const Field &x)
Definition: field.cpp:576
ParameterCheck::non_negative
int non_negative(const int v)
Definition: parameterCheck.cpp:21
copy
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:212
AShiftsolver_CG::flop_count
double flop_count()
Definition: ashiftsolver_CG-tmpl.h:327
Bridge::BridgeIO::paranoiac
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:238
AShiftsolver_CG::solve
void solve(std::vector< FIELD > &solution, const std::vector< double > &shift, const FIELD &source, int &Nconv, double &diff)
Definition: ashiftsolver_CG-tmpl.h:88
ParameterCheck::square_non_zero
int square_non_zero(const double v)
Definition: parameterCheck.cpp:43
ashiftsolver_CG.h
Bridge::BridgeIO::set_verbose_level
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:133
Parameters::set_int
void set_int(const string &key, const int value)
Definition: parameters.cpp:36
scal
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:261
Parameters::fetch_string
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
Parameters::fetch_double
int fetch_double(const string &key, double &value) const
Definition: parameters.cpp:327
Bridge::BridgeIO::crucial
void crucial(const char *format,...)
Definition: bridgeIO.cpp:180
AShiftsolver_CG
Multishift Conjugate Gradient solver.
Definition: ashiftsolver_CG.h:32
Parameters::fetch_int
int fetch_int(const string &key, int &value) const
Definition: parameters.cpp:346
Bridge::BridgeIO::general
void general(const char *format,...)
Definition: bridgeIO.cpp:200
Bridge::vout
BridgeIO vout
Definition: bridgeIO.cpp:512
Bridge::BridgeIO::get_verbose_level
static std::string get_verbose_level(const VerboseLevel vl)
Definition: bridgeIO.cpp:154