Bridge++  Ver. 1.1.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
shiftsolver_CG.cpp
Go to the documentation of this file.
1 
14 #include "shiftsolver_CG.h"
15 
16 //- parameter entries
17 namespace {
18  void append_entry(Parameters& param)
19  {
20  param.Register_int("maximum_number_of_iteration", 0);
21  param.Register_double("convergence_criterion_squared", 0.0);
22 
23  param.Register_string("verbose_level", "NULL");
24  }
25 }
26 //- end
27 
28 //- parameters class
30 //- end
31 
32 //====================================================================
34 {
35  const string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  int Niter;
41  double Stop_cond;
42 
43  int err = 0;
44  err += params.fetch_int("maximum_number_of_iteration", Niter);
45  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
46 
47  if (err) {
48  vout.crucial(m_vl, "Shiftsolver_CG: fetch error, input parameter not found.\n");
49  abort();
50  }
51 
52 
53  set_parameters(Niter, Stop_cond);
54 }
55 
56 
57 //====================================================================
58 void Shiftsolver_CG::set_parameters(const int Niter, const double Stop_cond)
59 {
60  //- print input parameters
61  vout.general(m_vl, "Parameters of Shiftsolver_CG:\n");
62  vout.general(m_vl, " Niter = %d\n", Niter);
63  vout.general(m_vl, " Stop_cond = %16.8e\n", Stop_cond);
64 
65  //- range check
66  int err = 0;
67  err += ParameterCheck::non_negative(Niter);
68  err += ParameterCheck::square_non_zero(Stop_cond);
69 
70  if (err) {
71  vout.crucial(m_vl, "Shiftsolver_CG: parameter range check failed.\n");
72  abort();
73  }
74 
75  //- store values
76  m_Niter = Niter;
77  m_Stop_cond = Stop_cond;
78 }
79 
80 
81 //====================================================================
82 void Shiftsolver_CG::solve(std::valarray<Field>& xq,
83  std::valarray<double> sigma,
84  const Field& b,
85  int& Nconv, double& diff)
86 {
87  vout.paranoiac(m_vl, " Shift CG solver start.\n");
88 
89  int Nshift = sigma.size();
90 
91  vout.paranoiac(m_vl, " number of shift = %d\n", Nshift);
92  vout.paranoiac(m_vl, " values of shift:\n");
93  for(int i = 0; i<Nshift; ++i){
94  vout.paranoiac(m_vl, " %d %12.8f\n", i, sigma[i]);
95  }
96 
97  snorm = 1.0 / b.norm2();
98 
99  Nconv = -1;
100 
101  int Nin = b.nin();
102  int Nvol = b.nvol();
103  int Nex = b.nex();
104 
105  p.resize(Nshift);
106  x.resize(Nshift);
107  zeta1.resize(Nshift);
108  zeta2.resize(Nshift);
109  csh2.resize(Nshift);
110  pp.resize(Nshift);
111 
112  for (int i = 0; i < Nshift; ++i) {
113  p[i].reset(Nin, Nvol, Nex);
114  x[i].reset(Nin, Nvol, Nex);
115  zeta1[i] = 1.0;
116  zeta2[i] = 1.0;
117  csh2[i] = sigma[i] - sigma[0];
118  }
119  s.reset(Nin, Nvol, Nex);
120  r.reset(Nin, Nvol, Nex);
121  s = b;
122  r = b;
123 
124  double rr;
125  Nshift2 = Nshift;
126 
127  solve_init(rr);
128 
129  vout.detailed(m_vl, " iter: %8d %22.15e\n", 0, rr * snorm);
130 
131  for (int iter = 0; iter < m_Niter; iter++) {
132  solve_step(rr, sigma);
133 
134  vout.detailed(m_vl, " iter: %8d %22.15e %4d\n", (iter + 1), rr * snorm, Nshift2);
135 
136  if (rr * snorm < m_Stop_cond) {
137  Nconv = iter;
138  break;
139  }
140  }
141  if (Nconv == -1) {
142  vout.crucial(m_vl, "Shiftsolver_CG not converged.\n");
143  abort();
144  }
145 
146 
147  diff = -1.0;
148  for (int i = 0; i < Nshift; ++i) {
149  s = m_fopr->mult(x[i]);
150  s += sigma[i] * x[i];
151  s -= b;
152  double diff1 = s * s;
153  diff1 = sqrt(diff1 * snorm);
154 
155  vout.paranoiac(m_vl, " %4d %22.15e\n", i, diff1);
156 
157  if (diff1 > diff) diff = diff1;
158  }
159 
160  vout.paranoiac(m_vl, " diff(max) = %22.15e \n", diff);
161 
162  for (int i = 0; i < Nshift; ++i) {
163  xq[i] = x[i];
164  }
165 }
166 
167 
168 //====================================================================
170 {
171  int Nshift = p.size();
172 
173  vout.paranoiac(m_vl, "number of shift = %d\n", Nshift);
174 
175  for (int i = 0; i < Nshift; ++i) {
176  p[i] = s;
177  x[i] = 0.0;
178  }
179 
180  r = s;
181  rr = r * r;
182  alpha_p = 0.0;
183  beta_p = 1.0;
184 }
185 
186 
187 //====================================================================
189  const std::valarray<double>& sigma)
190 {
191  s = m_fopr->mult(p[0]);
192  s += sigma[0] * p[0];
193 
194  double rr_p = rr;
195  double pa_p = s * p[0];
196  double beta = -rr_p / pa_p;
197 
198  x[0] -= beta * p[0];
199  r += beta * s;
200  rr = r * r;
201 
202  double alpha = rr / rr_p;
203 
204  p[0] *= alpha;
205  p[0] += r;
206 
207  pp[0] = rr;
208 
209  double alpha_h = 1.0 + alpha_p * beta / beta_p;
210  for (int ish = 1; ish < Nshift2; ++ish) {
211  double zeta = (alpha_h - csh2[ish] * beta) / zeta1[ish]
212  + (1.0 - alpha_h) / zeta2[ish];
213  zeta = 1.0 / zeta;
214  double zr = zeta / zeta1[ish];
215  double beta_s = beta * zr;
216  double alpha_s = alpha * zr * zr;
217 
218  x[ish] -= beta_s * p[ish];
219  p[ish] *= alpha_s;
220  p[ish] += zeta * r;
221 
222  pp[ish] = p[ish] * p[ish];
223  pp[ish] *= snorm;
224 
225  zeta2[ish] = zeta1[ish];
226  zeta1[ish] = zeta;
227  }
228 
229  for (int ish = Nshift2 - 1; ish >= 0; --ish) {
230 
231  vout.paranoiac(m_vl, "%4d %16.8e\n", ish, pp[ish]);
232 
233  if (pp[ish] > m_Stop_cond) {
234  Nshift2 = ish + 1;
235  break;
236  }
237  }
238 
239  alpha_p = alpha;
240  beta_p = beta;
241 }
242 
243 
244 //====================================================================
245 //============================================================END=====