Bridge++  Ver. 1.1.x
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fopr_Sign.cpp
Go to the documentation of this file.
1 
14 #include "fopr_Sign.h"
15 
16 #ifdef USE_PARAMETERS_FACTORY
17 #include "parameters_factory.h"
18 #endif
19 
20 using std::valarray;
21 
22 //- parameter entries
23 namespace {
24  void append_entry(Parameters& param)
25  {
26  param.Register_int("number_of_poles", 0);
27  param.Register_double("lower_bound", 0.0);
28  param.Register_double("upper_bound", 0.0);
29  param.Register_int("maximum_number_of_iteration", 0);
30  param.Register_double("convergence_criterion_squared", 0.0);
31 
32  param.Register_string("verbose_level", "NULL");
33  }
34 
35 
36 #ifdef USE_PARAMETERS_FACTORY
37  bool init_param = ParametersFactory::Register("Fopr.Sign", append_entry);
38 #endif
39 }
40 //- end
41 
42 //- parameters class
44 //- end
45 
46 //====================================================================
// Read the Fopr_Sign settings from a Parameters object and forward them to
// set_parameters(int, double, double, int, double).
//   - "verbose_level" is applied first and stored in m_vl.
//   - "number_of_poles", "lower_bound", "upper_bound",
//     "maximum_number_of_iteration" and "convergence_criterion_squared" are
//     fetched; if any fetch reports an error, a crucial message is printed
//     and the program aborts.
// NOTE(review): the signature line of this definition is missing from this
// listing (presumably `void Fopr_Sign::set_parameters(const Parameters& params)`
// -- confirm against the header).
48 {
49  const string str_vlevel = params.get_string("verbose_level");
50 
51  m_vl = vout.set_verbose_level(str_vlevel);
52 
53  //- fetch and check input parameters
54  int Np;
55  double x_min, x_max;
56  int Niter;
57  double Stop_cond;
58 
// Each fetch_* returns non-zero on failure; the sum flags any missing entry.
59  int err = 0;
60  err += params.fetch_int("number_of_poles", Np);
61  err += params.fetch_double("lower_bound", x_min);
62  err += params.fetch_double("upper_bound", x_max);
63  err += params.fetch_int("maximum_number_of_iteration", Niter);
64  err += params.fetch_double("convergence_criterion_squared", Stop_cond);
65 
66  if (err) {
67  vout.crucial(m_vl, "Fopr_Sign: fetch error, input parameter not found.\n");
68  abort();
69  }
70 
// Range checking and coefficient setup are done by the explicit overload.
71  set_parameters(Np, x_min, x_max, Niter, Stop_cond);
72 }
73 
74 
75 //====================================================================
76 void Fopr_Sign::set_parameters(int Np, double x_min, double x_max, int Niter,
77  double Stop_cond)
78 {
79  //- print input parameters
80  vout.general(m_vl, "Fopr_Sign parameters:\n");
81  vout.general(m_vl, " Np = %4d\n", Np);
82  vout.general(m_vl, " x_min = %12.8f\n", x_min);
83  vout.general(m_vl, " x_max = %12.6f\n", x_max);
84  vout.general(m_vl, " Niter = %6d\n", Niter);
85  vout.general(m_vl, " Stop_cond = %12.4e\n", Stop_cond);
86 
87  //- range check
88  int err = 0;
89  err += ParameterCheck::non_zero(Np);
90  // NB. x_min,x_max == 0 is allowed.
91  err += ParameterCheck::non_zero(Niter);
92  err += ParameterCheck::square_non_zero(Stop_cond);
93 
94  if (err) {
95  vout.crucial(m_vl, "Fopr_Sign: parameter range check failed.\n");
96  abort();
97  }
98 
99  //- store values
100  m_Np = Np;
101  m_x_min = x_min;
102  m_x_max = x_max;
103  m_Niter = Niter;
104  m_Stop_cond = Stop_cond;
105 
106  //- post-process
107  m_sigma.resize(m_Np);
108  m_cl.resize(2 * m_Np);
109  m_bl.resize(m_Np);
110 
111  init_parameters();
112 }
113 
114 
115 //====================================================================
116 void Fopr_Sign::set_lowmodes(int Nsbt, valarray<double> *ev,
117  valarray<Field> *vk)
118 {
119  m_Nsbt = Nsbt;
120  m_ev = ev;
121  m_vk = vk;
122 }
123 
124 
125 //====================================================================
// Destructor body: releases the solver object referenced by m_solver.
// NOTE(review): the signature line is missing from this listing (presumably
// `Fopr_Sign::~Fopr_Sign()`); the allocation of m_solver is not visible here
// either -- confirm it is created with `new` so that `delete` matches.
127 {
128  delete m_solver;
129 }
130 
131 
132 //====================================================================
// Build the rational approximation from the stored parameters:
//   - size m_sigma (Np), m_cl (2*Np), m_bl (Np);
//   - obtain cl/bl from Math_Sign_Zolotarev(m_Np, m_x_max / m_x_min);
//   - set the solver shifts m_sigma[i] = cl[2*i] * x_min^2;
//   - print a coefficient table;
//   - allocate m_Np work fields m_xq with the kernel operator's layout.
// NOTE(review): the signature line is missing from this listing (presumably
// `void Fopr_Sign::init_parameters()`), and so is the construction of
// m_solver, which mult() uses and the destructor deletes -- confirm it is
// created here.
134 {
135  m_sigma.resize(m_Np);
136  m_cl.resize(2 * m_Np);
137  m_bl.resize(m_Np);
138 
139  // Zolotarev coefficient defined
// NOTE(review): m_x_min == 0 makes bmax infinite/NaN (set_parameters allows it).
140  double bmax = m_x_max / m_x_min;
141  Math_Sign_Zolotarev sign_func(m_Np, bmax);
142  sign_func.get_sign_parameters(m_cl, m_bl);
143 
// The coefficients are rescaled by x_min^2 to obtain the shifts handed to the
// multi-shift solver in mult().
144  for (int i = 0; i < m_Np; i++) {
145  m_sigma[i] = m_cl[2 * i] * m_x_min * m_x_min;
146  }
147 
// NOTE(review): this table prints cl[i] and cl[i + Np], whereas the shifts
// above use cl[2*i]; the printed columns are therefore not the shifts.
148  for (int i = 0; i < m_Np; i++) {
149  vout.general(m_vl, " %3d %12.4e %12.4e %12.4e\n",
150  i, m_cl[i], m_cl[i + m_Np], m_bl[i]);
151  }
152 
// One solution vector per pole, shaped like the kernel operator's fields.
153  int Nin = m_fopr->field_nin();
154  int Nvol = m_fopr->field_nvol();
155  int Nex = m_fopr->field_nex();
156  m_xq.resize(m_Np);
157  for (int i = 0; i < m_Np; ++i) {
158  m_xq[i].reset(Nin, Nvol, Nex);
159  }
160 
162 }
163 
164 
165 //====================================================================
166 void Fopr_Sign::mult(Field& v, const Field& b)
167 {
168  assert(b.nin() == m_fopr->field_nin());
169  assert(b.nvol() == m_fopr->field_nvol());
170  assert(b.nex() == m_fopr->field_nex());
171 
172  assert(v.nin() == m_fopr->field_nin());
173  assert(v.nvol() == m_fopr->field_nvol());
174  assert(v.nex() == m_fopr->field_nex());
175 
176  // vout.general(m_vl, " Sign function: Nsbt = %d\n",m_Nsbt);
177 
178  // Low-mode subtraction
179  Field b2(b);
180  if (m_Nsbt > 0) subtract_lowmodes(b2);
181 
182  // Shiftsolver
183 
184  int Nshift = m_Np;
185  int Nconv;
186  double diff;
187 
188  //vout.general(m_vl, " Shiftsolver in sign function\n");
189  //vout.general(m_vl, " Number of shift values = %d\n",m_sigma.size());
190 
191  m_fopr->set_mode("DdagD");
192  m_solver->solve(m_xq, m_sigma, b2, Nconv, diff);
193 
194  // Field v(b);
195  Field w(b);
196 
197  v = m_bl[0] * m_xq[0];
198  for (int i = 1; i < m_Np; i++) {
199  v += m_bl[i] * m_xq[i];
200  }
201 
202  w = m_fopr->mult(v);
203  double coeff = m_cl[2 * m_Np - 1] * m_x_min * m_x_min;
204  w += coeff * v;
205 
206  m_fopr->set_mode("H");
207  v = m_fopr->mult(w);
208  v /= m_x_min;
209 
210  if (m_Nsbt > 0) evaluate_lowmodes(v, b);
211 
212  // return v;
213 }
214 
215 
216 //====================================================================
// Remove from w its components along the stored low modes (*m_vk)[k],
// k = 0 .. m_Nsbt-1, in place.  For each k:
//   prd = <vk, w> = sum conj(vk) * w   (local sum, then Communicator::reduce_sum,
//                                        presumably a sum over all processes)
//   w  -= prd * vk
// Field components are treated as interleaved (re, im) pairs along the nin
// index; the program aborts if w.nin() is odd.
// The update is sequential in k, which equals a plain projection only if the
// vk are orthonormal -- presumably guaranteed by the caller; confirm.
// NOTE(review): the signature line is missing from this listing (presumably
// `void Fopr_Sign::subtract_lowmodes(Field& w)`).
218 {
219  if ((w.nin() % 2) != 0) abort();
220 
221  int Nin = w.nin();
222  int Nvol = w.nvol();
223  int Nex = w.nex();
224  double prd_r, prd_i;
225  double v_r, v_i;
226 
227  for (int k = 0; k < m_Nsbt; ++k) {
// complex inner product conj(vk) . w, accumulated as (re, im)
228  prd_r = 0.0;
229  prd_i = 0.0;
230  for (int ex = 0; ex < Nex; ++ex) {
231  for (int iv = 0; iv < Nvol; ++iv) {
232  for (int in = 0; in < Nin; in += 2) {
233  prd_r += (*m_vk)[k].cmp(in, iv, ex) * w.cmp(in, iv, ex)
234  + (*m_vk)[k].cmp(in + 1, iv, ex) * w.cmp(in + 1, iv, ex);
235  prd_i += (*m_vk)[k].cmp(in, iv, ex) * w.cmp(in + 1, iv, ex)
236  - (*m_vk)[k].cmp(in + 1, iv, ex) * w.cmp(in, iv, ex);
237  }
238  }
239  }
240 
241  prd_r = Communicator::reduce_sum(prd_r);
242  prd_i = Communicator::reduce_sum(prd_i);
243 
// w -= prd * vk, written out in real arithmetic
244  for (int ex = 0; ex < Nex; ++ex) {
245  for (int iv = 0; iv < Nvol; ++iv) {
246  for (int in = 0; in < Nin; in += 2) {
247  v_r = w.cmp(in, iv, ex) - prd_r * (*m_vk)[k].cmp(in, iv, ex)
248  + prd_i * (*m_vk)[k].cmp(in + 1, iv, ex);
249  v_i = w.cmp(in + 1, iv, ex) - prd_r * (*m_vk)[k].cmp(in + 1, iv, ex)
250  - prd_i * (*m_vk)[k].cmp(in, iv, ex);
251  w.set(in, iv, ex, v_r);
252  w.set(in + 1, iv, ex, v_i);
253  }
254  }
255  }
256  }
257 }
258 
259 
260 //====================================================================
// Add the exact low-mode part of the sign function to x, in place:
//   x += sum_k sign(ev_k) * <vk, w> * vk,   k = 0 .. m_Nsbt-1,
// where <vk, w> = sum conj(vk) * w is reduced with Communicator::reduce_sum.
// w is only read; mult() calls this as evaluate_lowmodes(v, b) with the
// original input b.  Components are interleaved (re, im) pairs along nin;
// the program aborts if w.nin() is odd.  The loop bounds come from w, so x is
// assumed to have the same layout -- TODO confirm.
// NOTE(review): the signature line is missing from this listing (presumably
// `void Fopr_Sign::evaluate_lowmodes(Field& x, const Field& w)`).
262 {
263  if ((w.nin() % 2) != 0) abort();
264 
265  int Nin = w.nin();
266  int Nvol = w.nvol();
267  int Nex = w.nex();
268  double prd_r, prd_i;
269  double v_r, v_i;
270 
271  for (int k = 0; k < m_Nsbt; ++k) {
// complex inner product conj(vk) . w, accumulated as (re, im)
272  prd_r = 0.0;
273  prd_i = 0.0;
274  for (int ex = 0; ex < Nex; ++ex) {
275  for (int iv = 0; iv < Nvol; ++iv) {
276  for (int in = 0; in < Nin; in += 2) {
277  prd_r += (*m_vk)[k].cmp(in, iv, ex) * w.cmp(in, iv, ex)
278  + (*m_vk)[k].cmp(in + 1, iv, ex) * w.cmp(in + 1, iv, ex);
279  prd_i += (*m_vk)[k].cmp(in, iv, ex) * w.cmp(in + 1, iv, ex)
280  - (*m_vk)[k].cmp(in + 1, iv, ex) * w.cmp(in, iv, ex);
281  }
282  }
283  }
284 
285  prd_r = Communicator::reduce_sum(prd_r);
286  prd_i = Communicator::reduce_sum(prd_i);
287 
// sign of the k-th eigenvalue (+1 / -1).
// NOTE(review): an eigenvalue of exactly 0.0 gives 0/0 = NaN here.
288  double sgn = (*m_ev)[k] / fabs((*m_ev)[k]);
289  prd_r *= sgn;
290  prd_i *= sgn;
291 
// x += prd * vk, written out in real arithmetic
292  for (int ex = 0; ex < Nex; ++ex) {
293  for (int iv = 0; iv < Nvol; ++iv) {
294  for (int in = 0; in < Nin; in += 2) {
295  v_r = x.cmp(in, iv, ex) + prd_r * (*m_vk)[k].cmp(in, iv, ex)
296  - prd_i * (*m_vk)[k].cmp(in + 1, iv, ex);
297  v_i = x.cmp(in + 1, iv, ex) + prd_r * (*m_vk)[k].cmp(in + 1, iv, ex)
298  + prd_i * (*m_vk)[k].cmp(in, iv, ex);
299  x.set(in, iv, ex, v_r);
300  x.set(in + 1, iv, ex, v_i);
301  }
302  }
303  }
304  }
305 }
306 
307 
308 //====================================================================
// Evaluate the scalar rational approximation of sign(x):
//   R(x) = x * (x^2 + cl[2*Np-1]) * sum_l bl[l] / (x^2 + cl[2*l]).
// The coefficients are used without the x_min rescaling applied in
// init_parameters()/mult(), so x is presumably expected in units of x_min --
// TODO confirm.
// NOTE(review): the signature line is missing from this listing (presumably
// `double Fopr_Sign::sign_function(double x)`).
310 {
311 // cl[2*Np], bl[Np]: coefficients of rational approx.
312 
313  double x2R = 0.0;
314 
// partial-fraction sum over the poles
315  for (int l = 0; l < m_Np; l++) {
316  x2R += m_bl[l] / (x * x + m_cl[2 * l]);
317  }
318  x2R = x2R * (x * x + m_cl[2 * m_Np - 1]);
319 
320  return x * x2R;
321 }
322 
323 
324 //====================================================================
325 //============================================================END=====