Bridge++  Version 1.4.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_xyz_3dim.cpp
Go to the documentation of this file.
1 
14 #ifdef USE_FFTWLIB
15 
16 #include "fft_xyz_3dim.h"
17 
18 #ifdef USE_FACTORY
19 namespace {
20  FFT *create_object()
21  {
22  return new FFT_xyz_3dim();
23  }
24 
25 
26  bool init = FFT::Factory::Register("FFT_xyz_3dim", create_object);
27 }
28 #endif
29 
30 const std::string FFT_xyz_3dim::class_name = "FFT_xyz_3dim";
31 
32 //====================================================================
33 void FFT_xyz_3dim::set_parameters(const Parameters& params)
34 {
35  const std::string str_vlevel = params.get_string("verbose_level");
36 
37  m_vl = vout.set_verbose_level(str_vlevel);
38 
39  //- fetch and check input parameters
40  string str_fft_direction;
41 
42  int err = 0;
43  err += params.fetch_string("FFT_direction", str_fft_direction);
44 
45  if (err) {
46  vout.crucial(m_vl, "Error at %s: input parameter not found.\n", class_name.c_str());
47  exit(EXIT_FAILURE);
48  }
49 
50  set_parameters(str_fft_direction);
51 }
52 
53 
54 //====================================================================
55 void FFT_xyz_3dim::set_parameters(const string str_fft_direction)
56 {
57  //- print input parameters
58  vout.general(m_vl, "%s:\n", class_name.c_str());
59  vout.general(m_vl, " FFT_direction = %s\n", str_fft_direction.c_str());
60 
61  //- range check
62 
63  //- store values
64  if (str_fft_direction == "Forward") {
65  m_is_forward = true;
66  } else if (str_fft_direction == "Backward") {
67  m_is_forward = false;
68  } else {
69  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), str_fft_direction.c_str());
70  exit(EXIT_FAILURE);
71  }
72 }
73 
74 
75 //====================================================================
76 void FFT_xyz_3dim::init()
77 {
78  //- global lattice size
79  const int Lx = CommonParameters::Lx();
80  const int Ly = CommonParameters::Ly();
81  const int Lz = CommonParameters::Lz();
82 
83 #ifdef USE_OPENMP
84  int threads_ok = fftw_init_threads();
85 #endif
86 
87 #ifdef USE_MPI
88  const int NPE_x = CommonParameters::NPEx();
89  const int NPE_y = CommonParameters::NPEy();
90  // const int NPE_z = CommonParameters::NPEz();
91  const int NPE_t = CommonParameters::NPEt();
92 
93  if ((NPE_x * NPE_y * NPE_t) != 1) {
94  vout.crucial(m_vl, "Error at %s: FFTW supports parallelization only in z-direction.\n",
95  class_name.c_str());
96  exit(EXIT_FAILURE);
97  }
98 
99 
100  fftw_mpi_init();
101 
102 
103  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
104  const ptrdiff_t Lx_p = CommonParameters::Lx();
105  const ptrdiff_t Ly_p = CommonParameters::Ly();
106  const ptrdiff_t Lz_p = CommonParameters::Lz();
107 
108  ptrdiff_t fftw_size_p = fftw_mpi_local_size_3d(Lz_p, Ly_p, Lx_p,
110  &m_Nz_p, &m_z_start_p);
111 
112  m_in = fftw_alloc_complex(fftw_size_p);
113  m_out = fftw_alloc_complex(fftw_size_p);
114 
115  if (!m_in || !m_out) {
116  vout.crucial(m_vl, "Error at %s: failed to allocate memory %d [Byte].\n",
117  class_name.c_str(), (int)fftw_size_p);
118  exit(EXIT_FAILURE);
119  }
120 #else
121  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
122  const size_t fftw_size = sizeof(fftw_complex) * Lx * Ly * Lz;
123  m_in = (fftw_complex *)fftw_malloc(fftw_size);
124  m_out = (fftw_complex *)fftw_malloc(fftw_size);
125 
126  if (!m_in || !m_out) {
127  vout.crucial(m_vl, "Error at %s: failed to allocate memory %d [Byte].\n",
128  class_name.c_str(), (int)fftw_size);
129  exit(EXIT_FAILURE);
130  }
131 #endif
132 }
133 
134 
135 //====================================================================
136 void FFT_xyz_3dim::tidy_up()
137 {
138  if (m_in) fftw_free(m_in);
139  if (m_out) fftw_free(m_out);
140  if (m_plan) fftw_destroy_plan(m_plan);
141 }
142 
143 
144 //====================================================================
145 void FFT_xyz_3dim::fft(Field& field)
146 {
147  //- global lattice size
148  const int Lx = CommonParameters::Lx();
149  const int Ly = CommonParameters::Ly();
150  const int Lz = CommonParameters::Lz();
151  const int Lt = CommonParameters::Lt();
152  const int Lxyz = Lx * Ly * Lz;
153 
154  //- local size
155  const int Nz = CommonParameters::Nz();
156 
157  const int Nin = field.nin();
158  const int Nex = field.nex();
159 
160 
161  //- setup FFTW plan
162 #ifdef USE_OPENMP
163  const int Nthread = ThreadManager_OpenMP::get_num_threads();
164  fftw_plan_with_nthreads(Nthread);
165 #endif
166 #ifdef USE_MPI
167  const ptrdiff_t Lx_p = CommonParameters::Lx();
168  const ptrdiff_t Ly_p = CommonParameters::Ly();
169  const ptrdiff_t Lz_p = CommonParameters::Lz();
170 
171  if (m_is_forward) {
172  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
174  FFTW_FORWARD, FFTW_ESTIMATE);
175  } else {
176  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
178  FFTW_BACKWARD, FFTW_ESTIMATE);
179  }
180 #else
181  if (m_is_forward) {
182  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
183  FFTW_FORWARD, FFTW_ESTIMATE);
184  } else {
185  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
186  FFTW_BACKWARD, FFTW_ESTIMATE);
187  }
188 #endif
189 
190 
191  // #### Execution main part ####
192  //- Nin is devided by 2, because of complex(i.e. real and imag)
193  for (int in2 = 0; in2 < Nin / 2; ++in2) {
194  for (int t_global = 0; t_global < Lt; t_global++) {
195  for (int ex = 0; ex < Nex; ++ex) {
196  //- input data
197  for (int z = 0; z < Nz; z++) {
198  for (int y_global = 0; y_global < Ly; y_global++) {
199  for (int x_global = 0; x_global < Lx; x_global++) {
200  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
201 
202  int isite = m_index.site(x_global, y_global, z, t_global);
203  int i_real = 2 * in2;
204  int i_imag = 2 * in2 + 1;
205 
206  m_in[isite_xyz_local][0] = field.cmp(i_real, isite, ex);
207  m_in[isite_xyz_local][1] = field.cmp(i_imag, isite, ex);
208  }
209  }
210  }
211 
212 
213  fftw_execute(m_plan);
214 
215 
216  //- output data
217  for (int z = 0; z < Nz; z++) {
218  for (int y_global = 0; y_global < Ly; y_global++) {
219  for (int x_global = 0; x_global < Lx; x_global++) {
220  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
221 
222  int isite = m_index.site(x_global, y_global, z, t_global);
223  int i_real = 2 * in2;
224  int i_imag = 2 * in2 + 1;
225 
226  field.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
227  field.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
228  }
229  }
230  }
231  }
232  }
233  }
234  //- end of global loops
235 
236  //- normailzation for FFTW_BACKWARD
237  if (!m_is_forward) {
238  scal(field, 1.0 / Lxyz);
239  }
240 }
241 
242 
243 //====================================================================
244 void FFT_xyz_3dim::fft(Field& field_out, const Field& field_in)
245 {
246  //- global lattice size
247  const int Lx = CommonParameters::Lx();
248  const int Ly = CommonParameters::Ly();
249  const int Lz = CommonParameters::Lz();
250  const int Lt = CommonParameters::Lt();
251  const int Lxyz = Lx * Ly * Lz;
252 
253  //- local size
254  const int Nz = CommonParameters::Nz();
255 
256  const int Nin = field_in.nin();
257  const int Nex = field_in.nex();
258 
259 
260  //- setup FFTW plan
261 #ifdef USE_OPENMP
262  const int Nthread = ThreadManager_OpenMP::get_num_threads();
263  fftw_plan_with_nthreads(Nthread);
264 #endif
265 #ifdef USE_MPI
266  const ptrdiff_t Lx_p = CommonParameters::Lx();
267  const ptrdiff_t Ly_p = CommonParameters::Ly();
268  const ptrdiff_t Lz_p = CommonParameters::Lz();
269 
270  if (m_is_forward) {
271  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
273  FFTW_FORWARD, FFTW_ESTIMATE);
274  } else {
275  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
277  FFTW_BACKWARD, FFTW_ESTIMATE);
278  }
279 #else
280  if (m_is_forward) {
281  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
282  FFTW_FORWARD, FFTW_ESTIMATE);
283  } else {
284  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
285  FFTW_BACKWARD, FFTW_ESTIMATE);
286  }
287 #endif
288 
289 
290  // #### Execution main part ####
291  //- Nin is devided by 2, because of complex(i.e. real and imag)
292  for (int in2 = 0; in2 < Nin / 2; ++in2) {
293  for (int t_global = 0; t_global < Lt; t_global++) {
294  for (int ex = 0; ex < Nex; ++ex) {
295  //- input data
296  for (int z = 0; z < Nz; z++) {
297  for (int y_global = 0; y_global < Ly; y_global++) {
298  for (int x_global = 0; x_global < Lx; x_global++) {
299  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
300 
301  int isite = m_index.site(x_global, y_global, z, t_global);
302  int i_real = 2 * in2;
303  int i_imag = 2 * in2 + 1;
304 
305  m_in[isite_xyz_local][0] = field_in.cmp(i_real, isite, ex);
306  m_in[isite_xyz_local][1] = field_in.cmp(i_imag, isite, ex);
307  }
308  }
309  }
310 
311 
312  fftw_execute(m_plan);
313 
314 
315  //- output data
316  for (int z = 0; z < Nz; z++) {
317  for (int y_global = 0; y_global < Ly; y_global++) {
318  for (int x_global = 0; x_global < Lx; x_global++) {
319  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
320 
321  int isite = m_index.site(x_global, y_global, z, t_global);
322  int i_real = 2 * in2;
323  int i_imag = 2 * in2 + 1;
324 
325  field_out.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
326  field_out.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
327  }
328  }
329  }
330  }
331  }
332  }
333  //- end of global loops
334 
335  //- normailzation for FFTW_BACKWARD
336  if (!m_is_forward) {
337  scal(field_out, 1.0 / Lxyz);
338  }
339 }
340 
341 
342 //==========================================================
343 //==================================================END=====
344 #endif
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:282
BridgeIO vout
Definition: bridgeIO.cpp:495
static int get_num_threads()
returns available number of threads.
static int NPEy()
static int NPEt()
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:164
void general(const char *format,...)
Definition: bridgeIO.cpp:195
Container of Field-type object.
Definition: field.h:39
double cmp(const int jin, const int site, const int jex) const
Definition: field.h:132
Class for parameters.
Definition: parameters.h:46
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:262
int nin() const
Definition: field.h:115
int nex() const
Definition: field.h:117
static int NPEx()
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
string get_string(const string &key) const
Definition: parameters.cpp:116
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:131