Bridge++  Ver. 1.2.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
solver_CG.cpp
Go to the documentation of this file.
1 
14 #include "solver_CG.h"
15 
16 
#ifdef USE_FACTORY
namespace {
  // Creator callback for Solver::Factory: builds a CG solver around the
  // given fermion operator. The solver is heap-allocated with `new`;
  // presumably the factory's caller takes ownership — confirm.
  Solver *create_object(Fopr *fopr)
  {
    Solver_CG *solver = new Solver_CG(fopr);

    return solver;
  }

  // Registers the creator under the key "CG" when this namespace-scope
  // variable is initialized.
  bool init = Solver::Factory::Register("CG", create_object);
}
#endif
28 
29 //- parameter entries
30 namespace {
31  void append_entry(Parameters& param)
32  {
33  param.Register_int("maximum_number_of_iteration", 0);
34  param.Register_double("convergence_criterion_squared", 0.0);
35 
36  param.Register_string("verbose_level", "NULL");
37  }
38 
39 
40 #ifdef USE_PARAMETERS_FACTORY
41  bool init_param = ParametersFactory::Register("Solver.CG", append_entry);
42 #endif
43 }
44 //- end
45 
46 //- parameters class
48 //- end
49 
50 const std::string Solver_CG::class_name = "Solver_CG";
51 
52 //====================================================================
54 {
55  const string str_vlevel = params.get_string("verbose_level");
56 
57  m_vl = vout.set_verbose_level(str_vlevel);
58 
59  //- fetch and check input parameters
60  int Niter;
61  double Stop_cond;
62 
63  int err = 0;
64  err += params.fetch_int("maximum_number_of_iteration", Niter);
65  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
66 
67  if (err) {
68  vout.crucial(m_vl, "%s: fetch error, input parameter not found.\n", class_name.c_str());
69  abort();
70  }
71 
72 
73  set_parameters(Niter, Stop_cond);
74 }
75 
76 
77 //====================================================================
78 void Solver_CG::set_parameters(const int Niter, const double Stop_cond)
79 {
81 
82  //- print input parameters
83  vout.general(m_vl, "Parameters of %s:\n", class_name.c_str());
84  vout.general(m_vl, " Niter = %d\n", Niter);
85  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
86 
87  //- range check
88  int err = 0;
89  err += ParameterCheck::non_negative(Niter);
90  err += ParameterCheck::square_non_zero(Stop_cond);
91 
92  if (err) {
93  vout.crucial(m_vl, "%s: parameter range check failed.\n", class_name.c_str());
94  abort();
95  }
96 
97  //- store values
98  m_Niter = Niter;
99  m_Stop_cond = Stop_cond;
100 }
101 
102 
103 //====================================================================
104 void Solver_CG::solve(Field& xq, const Field& b,
105  int& Nconv, double& diff)
106 {
107  //#pragma omp parallel
108  {
109  double bnorm2 = b.norm2();
110  double snorm = 1.0 / bnorm2;
111  int bsize = b.size();
112 
115 
116 #pragma omp master
117  {
118  vout.paranoiac(m_vl, "%s: solver starts\n", class_name.c_str());
119  vout.paranoiac(m_vl, " norm of b = %16.8e\n", bnorm2);
120  vout.paranoiac(m_vl, " size of b = %d\n", bsize);
121  }
122 #pragma omp barrier
123 
124 
125  reset_field(b);
126 
127  copy(s, b); // s = b;
128 
129  double rr;
130  int Nconv2 = -1;
131 
132  solve_init(b, rr);
133 
134  bool is_converged = false;
135 
136 #pragma omp master
137  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
138 #pragma omp barrier
139 
140 
141  for (int iter = 0; iter < m_Niter; iter++) {
142  if (!is_converged) {
143  solve_step(rr);
144 
145 #pragma omp master
146  vout.detailed(m_vl, " iter: %8d %22.15e\n", iter + 1, rr * snorm);
147 #pragma omp barrier
148 
149  if (rr * snorm < m_Stop_cond) {
150  m_fopr->mult(s, x); // s = m_fopr->mult(x);
151  axpy(s, -1.0, b); // s -= b;
152 
153  double diff2 = s.norm2();
155 
156  if (diff2 * snorm < m_Stop_cond) {
157  Nconv2 = iter + 1;
158 
159  // break;
160  is_converged = true;
161  }
162 
163  copy(s, x); // s = x;
164  solve_init(b, rr);
165  }
166  }
167  }
168  if (Nconv2 == -1) {
169 #pragma omp master
170  vout.crucial(m_vl, "%s: not converged.\n", class_name.c_str());
171 #pragma omp barrier
172  abort();
173  }
174 
175  copy(xq, x); // xq = x;
176 
177  m_fopr->mult(p, x); // p = m_fopr->mult(x);
178  axpy(p, -1.0, b); // p -= b;
179 
180  double diff2 = p.norm2();
181 
182 #pragma omp master
183  {
184  diff = sqrt(diff2);
185  Nconv = Nconv2;
186  }
187 #pragma omp barrier
188  } // end of parallel region
189 }
190 
191 
192 //====================================================================
193 void Solver_CG::reset_field(const Field& b)
194 {
195 #pragma omp master
196  {
197  int Nin = b.nin();
198  int Nvol = b.nvol();
199  int Nex = b.nex();
200 
201  if ((s.nin() != Nin) || (s.nvol() != Nvol) || (s.nex() != Nex)) {
202  s.reset(Nin, Nvol, Nex);
203  r.reset(Nin, Nvol, Nex);
204  x.reset(Nin, Nvol, Nex);
205  p.reset(Nin, Nvol, Nex);
206 
207  vout.detailed(m_vl, " %s: field size reset.\n", class_name.c_str());
208  }
209  }
210 #pragma omp barrier
211 }
212 
213 
214 //====================================================================
215 void Solver_CG::solve_init(const Field& b, double& rr)
216 {
217  copy(x, s); // x = s;
218 
219  //- r = b - A x
220  copy(r, b); // r = b;
221  m_fopr->mult(s, x); // s = m_fopr->mult(x);
222  axpy(r, -1.0, s); // r -= s;
223 
224  copy(p, r); // p = r;
225  rr = r.norm2(); // rr = r * r;
226 
227  // #pragma omp master
228  // rr_prev = rr;
229  // #pragma omp barrier
230 }
231 
232 
233 //====================================================================
234 void Solver_CG::solve_step(double& rr)
235 {
236  double rr_prev = rr;
237 
238  m_fopr->mult(s, p); // s = m_fopr->mult(p);
239 
240  double pap = dot(p, s); // pap = p * s;
241  double cr = rr_prev / pap;
242 
243  axpy(x, cr, p); // x += cr * p;
244  axpy(r, -cr, s); // r -= cr * s;
245 
246  rr = r.norm2(); // rr = r * r;
247 
248  double rr_ratio = rr / rr_prev;
249  aypx(rr_ratio, p, r); // p = (rr / rr_prev) * p + r
250 
251  // #pragma omp master
252  // rr_prev = rr;
253  // #pragma omp barrier
254 }
255 
256 
257 //====================================================================
258 //============================================================END=====
BridgeIO vout
Definition: bridgeIO.cpp:207
void detailed(const char *format,...)
Definition: bridgeIO.cpp:50
static int get_num_threads()
returns the available number of threads.
void Register_string(const string &, const string &)
Definition: parameters.cpp:352
double norm2() const
Definition: field.cpp:469
void solve_init(const Field &, double &)
double dot(const Field &y, const Field &x)
Definition: field.cpp:46
virtual const Field mult(const Field &)=0
applies the fermion operator to a given field and returns the resulting field.
void general(const char *format,...)
Definition: bridgeIO.cpp:38
void Register_int(const string &, const int)
Definition: parameters.cpp:331
double m_Stop_cond
Definition: solver_CG.h:49
void reset_field(const Field &)
void solve(Field &solution, const Field &source, int &Nconv, double &diff)
Container of Field-type objects.
Definition: field.h:37
static const std::string class_name
Definition: solver_CG.h:44
void solve_step(double &)
int nvol() const
Definition: field.h:101
Class for parameters.
Definition: parameters.h:40
void copy(Field &y, const Field &x)
copy(y, x): y = x
Definition: field.cpp:409
Standard Conjugate Gradient solver algorithm.
Definition: solver_CG.h:41
Field x
Definition: solver_CG.h:51
static int get_thread_id()
returns the thread id.
int square_non_zero(const double v)
Definition: checker.cpp:41
int nin() const
Definition: field.h:100
void set_parameters(const Parameters &params)
Definition: solver_CG.cpp:53
void reset(const int Nin, const int Nvol, const int Nex, const element_type cmpl=COMPLEX)
Definition: field.h:82
static void sync_barrier_all()
synchronization barrier across all threads and nodes.
Field r
Definition: solver_CG.h:51
void aypx(const double a, Field &y, const Field &x)
aypx(a, y, x): y := a * y + x
Definition: field.cpp:489
int nex() const
Definition: field.h:102
void paranoiac(const char *format,...)
Definition: bridgeIO.cpp:62
void axpy(Field &y, const double a, const Field &x)
axpy(y, a, x): y := a * x + y
Definition: field.cpp:193
Field p
Definition: solver_CG.h:51
Fopr * m_fopr
Definition: solver_CG.h:47
void crucial(const char *format,...)
Definition: bridgeIO.cpp:26
Base class for linear solver class family.
Definition: solver.h:37
static bool Register(const std::string &realm, const creator_callback &cb)
int non_negative(const int v)
Definition: checker.cpp:21
Field s
Definition: solver_CG.h:51
void Register_double(const string &, const double)
Definition: parameters.cpp:324
Base class of fermion operator family.
Definition: fopr.h:39
int fetch_double(const string &key, double &val) const
Definition: parameters.cpp:124
string get_string(const string &key) const
Definition: parameters.cpp:85
int m_Niter
Definition: solver_CG.h:48
int fetch_int(const string &key, int &val) const
Definition: parameters.cpp:141
Bridge::VerboseLevel m_vl
Definition: solver.h:56
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:191
static void assert_single_thread(const std::string &classname)
asserts that execution is currently on a single thread.
int size() const
Definition: field.h:106