Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_xyz_3dim.cpp
Go to the documentation of this file.
1 
14 #ifdef USE_FFTWLIB
15 
16 #include "fft_xyz_3dim.h"
17 
18 #ifdef USE_FACTORY_AUTOREGISTER
19 namespace {
20  bool init = FFT_xyz_3dim::register_factory();
21 }
22 #endif
23 
24 const std::string FFT_xyz_3dim::class_name = "FFT_xyz_3dim";
25 
26 //====================================================================
27 void FFT_xyz_3dim::set_parameters(const Parameters& params)
28 {
29  const std::string str_vlevel = params.get_string("verbose_level");
30 
31  m_vl = vout.set_verbose_level(str_vlevel);
32 
33  //- fetch and check input parameters
34  string str_fft_direction;
35 
36  int err = 0;
37  err += params.fetch_string("FFT_direction", str_fft_direction);
38 
39  if (err) {
40  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
41  exit(EXIT_FAILURE);
42  }
43 
44  set_parameters(str_fft_direction);
45 }
46 
47 
48 //====================================================================
49 void FFT_xyz_3dim::set_parameters(const string& str_fft_direction)
50 {
51  //- print input parameters
52  vout.general(m_vl, "%s:\n", class_name.c_str());
53  vout.general(m_vl, " FFT_direction = %s\n", str_fft_direction.c_str());
54 
55  //- range check
56 
57  //- store values
58  if (str_fft_direction == "Forward") {
59  m_is_forward = true;
60  } else if (str_fft_direction == "Backward") {
61  m_is_forward = false;
62  } else {
63  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), str_fft_direction.c_str());
64  exit(EXIT_FAILURE);
65  }
66 }
67 
68 
69 //====================================================================
70 void FFT_xyz_3dim::init()
71 {
72  //- global lattice size
73  const int Lx = CommonParameters::Lx();
74  const int Ly = CommonParameters::Ly();
75  const int Lz = CommonParameters::Lz();
76 
77 #ifdef USE_OPENMP
78  int threads_ok = fftw_init_threads();
79 #endif
80 
81 #ifdef USE_MPI
82  const int NPE_x = CommonParameters::NPEx();
83  const int NPE_y = CommonParameters::NPEy();
84  // const int NPE_z = CommonParameters::NPEz();
85  const int NPE_t = CommonParameters::NPEt();
86 
87  if ((NPE_x * NPE_y * NPE_t) != 1) {
88  vout.crucial(m_vl, "Error at %s: FFTW supports parallelization only in z-direction.\n",
89  class_name.c_str());
90  exit(EXIT_FAILURE);
91  }
92 
93 
94  fftw_mpi_init();
95 
96 
97  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
98  const ptrdiff_t Lx_p = CommonParameters::Lx();
99  const ptrdiff_t Ly_p = CommonParameters::Ly();
100  const ptrdiff_t Lz_p = CommonParameters::Lz();
101 
102  ptrdiff_t fftw_size_p = fftw_mpi_local_size_3d(Lz_p, Ly_p, Lx_p,
104  &m_Nz_p, &m_z_start_p);
105 
106  m_in = fftw_alloc_complex(fftw_size_p);
107  m_out = fftw_alloc_complex(fftw_size_p);
108 
109  if (!m_in || !m_out) {
110  vout.crucial(m_vl, "Error at %s: failed to allocate memory %d [Byte].\n",
111  class_name.c_str(), (int)fftw_size_p);
112  exit(EXIT_FAILURE);
113  }
114 #else
115  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
116  const size_t fftw_size = sizeof(fftw_complex) * Lx * Ly * Lz;
117  m_in = (fftw_complex *)fftw_malloc(fftw_size);
118  m_out = (fftw_complex *)fftw_malloc(fftw_size);
119 
120  if (!m_in || !m_out) {
121  vout.crucial(m_vl, "Error at %s: failed to allocate memory %d [Byte].\n",
122  class_name.c_str(), (int)fftw_size);
123  exit(EXIT_FAILURE);
124  }
125 #endif
126 }
127 
128 
129 //====================================================================
130 void FFT_xyz_3dim::tidy_up()
131 {
132  if (m_in) fftw_free(m_in);
133  if (m_out) fftw_free(m_out);
134  if (m_plan) fftw_destroy_plan(m_plan);
135 }
136 
137 
138 //====================================================================
139 void FFT_xyz_3dim::fft(Field& field)
140 {
141  //- global lattice size
142  const int Lx = CommonParameters::Lx();
143  const int Ly = CommonParameters::Ly();
144  const int Lz = CommonParameters::Lz();
145  const int Lt = CommonParameters::Lt();
146  const int Lxyz = Lx * Ly * Lz;
147 
148  //- local size
149  const int Nz = CommonParameters::Nz();
150 
151  const int Nin = field.nin();
152  const int Nex = field.nex();
153 
154 
155  //- setup FFTW plan
156 #ifdef USE_OPENMP
157  const int Nthread = ThreadManager_OpenMP::get_num_threads();
158  fftw_plan_with_nthreads(Nthread);
159 #endif
160 #ifdef USE_MPI
161  const ptrdiff_t Lx_p = CommonParameters::Lx();
162  const ptrdiff_t Ly_p = CommonParameters::Ly();
163  const ptrdiff_t Lz_p = CommonParameters::Lz();
164 
165  if (m_is_forward) {
166  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
168  FFTW_FORWARD, FFTW_ESTIMATE);
169  } else {
170  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
172  FFTW_BACKWARD, FFTW_ESTIMATE);
173  }
174 #else
175  if (m_is_forward) {
176  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
177  FFTW_FORWARD, FFTW_ESTIMATE);
178  } else {
179  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
180  FFTW_BACKWARD, FFTW_ESTIMATE);
181  }
182 #endif
183 
184 
185  // #### Execution main part ####
186  //- Nin is devided by 2, because of complex(i.e. real and imag)
187  for (int in2 = 0; in2 < Nin / 2; ++in2) {
188  for (int t_global = 0; t_global < Lt; t_global++) {
189  for (int ex = 0; ex < Nex; ++ex) {
190  //- input data
191  for (int z = 0; z < Nz; z++) {
192  for (int y_global = 0; y_global < Ly; y_global++) {
193  for (int x_global = 0; x_global < Lx; x_global++) {
194  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
195 
196  int isite = m_index.site(x_global, y_global, z, t_global);
197  int i_real = 2 * in2;
198  int i_imag = 2 * in2 + 1;
199 
200  m_in[isite_xyz_local][0] = field.cmp(i_real, isite, ex);
201  m_in[isite_xyz_local][1] = field.cmp(i_imag, isite, ex);
202  }
203  }
204  }
205 
206 
207  fftw_execute(m_plan);
208 
209 
210  //- output data
211  for (int z = 0; z < Nz; z++) {
212  for (int y_global = 0; y_global < Ly; y_global++) {
213  for (int x_global = 0; x_global < Lx; x_global++) {
214  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
215 
216  int isite = m_index.site(x_global, y_global, z, t_global);
217  int i_real = 2 * in2;
218  int i_imag = 2 * in2 + 1;
219 
220  field.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
221  field.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
222  }
223  }
224  }
225  }
226  }
227  }
228  //- end of global loops
229 
230  //- normailzation for FFTW_BACKWARD
231  if (!m_is_forward) {
232  scal(field, 1.0 / Lxyz);
233  }
234 }
235 
236 
237 //====================================================================
238 void FFT_xyz_3dim::fft(Field& field_out, const Field& field_in)
239 {
240  //- global lattice size
241  const int Lx = CommonParameters::Lx();
242  const int Ly = CommonParameters::Ly();
243  const int Lz = CommonParameters::Lz();
244  const int Lt = CommonParameters::Lt();
245  const int Lxyz = Lx * Ly * Lz;
246 
247  //- local size
248  const int Nz = CommonParameters::Nz();
249 
250  const int Nin = field_in.nin();
251  const int Nex = field_in.nex();
252 
253 
254  //- setup FFTW plan
255 #ifdef USE_OPENMP
256  const int Nthread = ThreadManager_OpenMP::get_num_threads();
257  fftw_plan_with_nthreads(Nthread);
258 #endif
259 #ifdef USE_MPI
260  const ptrdiff_t Lx_p = CommonParameters::Lx();
261  const ptrdiff_t Ly_p = CommonParameters::Ly();
262  const ptrdiff_t Lz_p = CommonParameters::Lz();
263 
264  if (m_is_forward) {
265  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
267  FFTW_FORWARD, FFTW_ESTIMATE);
268  } else {
269  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
271  FFTW_BACKWARD, FFTW_ESTIMATE);
272  }
273 #else
274  if (m_is_forward) {
275  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
276  FFTW_FORWARD, FFTW_ESTIMATE);
277  } else {
278  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
279  FFTW_BACKWARD, FFTW_ESTIMATE);
280  }
281 #endif
282 
283 
284  // #### Execution main part ####
285  //- Nin is devided by 2, because of complex(i.e. real and imag)
286  for (int in2 = 0; in2 < Nin / 2; ++in2) {
287  for (int t_global = 0; t_global < Lt; t_global++) {
288  for (int ex = 0; ex < Nex; ++ex) {
289  //- input data
290  for (int z = 0; z < Nz; z++) {
291  for (int y_global = 0; y_global < Ly; y_global++) {
292  for (int x_global = 0; x_global < Lx; x_global++) {
293  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
294 
295  int isite = m_index.site(x_global, y_global, z, t_global);
296  int i_real = 2 * in2;
297  int i_imag = 2 * in2 + 1;
298 
299  m_in[isite_xyz_local][0] = field_in.cmp(i_real, isite, ex);
300  m_in[isite_xyz_local][1] = field_in.cmp(i_imag, isite, ex);
301  }
302  }
303  }
304 
305 
306  fftw_execute(m_plan);
307 
308 
309  //- output data
310  for (int z = 0; z < Nz; z++) {
311  for (int y_global = 0; y_global < Ly; y_global++) {
312  for (int x_global = 0; x_global < Lx; x_global++) {
313  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
314 
315  int isite = m_index.site(x_global, y_global, z, t_global);
316  int i_real = 2 * in2;
317  int i_imag = 2 * in2 + 1;
318 
319  field_out.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
320  field_out.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
321  }
322  }
323  }
324  }
325  }
326  }
327  //- end of global loops
328 
329  //- normailzation for FFTW_BACKWARD
330  if (!m_is_forward) {
331  scal(field_out, 1.0 / Lxyz);
332  }
333 }
334 
335 
336 //====================================================================
337 void FFT_xyz_3dim::fft(Field& field_out, const Field& field_in, const Direction dir)
338 {
339  // save state
340  bool backup_fwbw = m_is_forward;
341 
342  // find direction and set
343  if (dir == FORWARD) {
344  m_is_forward = true;
345  } else if (dir == BACKWARD) {
346  m_is_forward = false;
347  } else {
348  vout.crucial(m_vl, "%s: unknown direction %d. failed.\n", class_name.c_str(), dir);
349  exit(EXIT_FAILURE);
350  }
351 
352  // delegate to another method
353  fft(field_out, field_in);
354 
355  // restore state
356  m_is_forward = backup_fwbw;
357 }
358 
359 
360 //==========================================================
361 //==================================================END=====
362 #endif
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
static int get_num_threads()
returns available number of threads.
static int NPEy()
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:175
void general(const char *format,...)
Definition: bridgeIO.cpp:197
Container of Field-type object.
Definition: field.h:45
double cmp(const int jin, const int site, const int jex) const
Definition: field.h:143
Class for parameters.
Definition: parameters.h:46
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
int nin() const
Definition: field.h:126
static MPI_Comm & world()
retrieves current communicator.
int nex() const
Definition: field.h:128
static int NPEx()
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Direction
Definition: bridge_defs.h:24
string get_string(const string &key) const
Definition: parameters.cpp:221
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131