Bridge++  Ver. 1.3.x
fft_xyz_3dim.cpp
Go to the documentation of this file.
1 
14 #ifdef USE_FFTWLIB
15 
16 #include "fft_xyz_3dim.h"
17 
18 const std::string FFT_xyz_3dim::class_name = "FFT_xyz_3dim";
19 
20 //====================================================================
21 void FFT_xyz_3dim::init()
22 {
23  //- global lattice size
24  const int Lx = CommonParameters::Lx();
25  const int Ly = CommonParameters::Ly();
26  const int Lz = CommonParameters::Lz();
27 
28 #ifdef USE_OPENMP
29  int threads_ok = fftw_init_threads();
30 #endif
31 
32 #ifdef USE_MPI
33  const int NPE_x = CommonParameters::NPEx();
34  const int NPE_y = CommonParameters::NPEy();
35  // const int NPE_z = CommonParameters::NPEz();
36  const int NPE_t = CommonParameters::NPEt();
37 
38  if ((NPE_x * NPE_y * NPE_t) != 1) {
39  vout.crucial(m_vl, "%s: FFTW supports parallelization only in z-direction.\n",
40  class_name.c_str());
41  exit(EXIT_FAILURE);
42  }
43 
44 
45  fftw_mpi_init();
46 
47 
48  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
49  const ptrdiff_t Lx_p = CommonParameters::Lx();
50  const ptrdiff_t Ly_p = CommonParameters::Ly();
51  const ptrdiff_t Lz_p = CommonParameters::Lz();
52 
53  ptrdiff_t fftw_size_p = fftw_mpi_local_size_3d(Lz_p, Ly_p, Lx_p,
55  &m_Nz_p, &m_z_start_p);
56 
57  m_in = fftw_alloc_complex(fftw_size_p);
58  m_out = fftw_alloc_complex(fftw_size_p);
59 
60  if (!m_in || !m_out) {
61  vout.crucial(m_vl, "%s: failed to allocate memory %d [Byte].\n",
62  class_name.c_str(), (int)fftw_size_p);
63  exit(EXIT_FAILURE);
64  }
65 #else
66  //- allocate m_in,out = m_in,out[Nz][Ly][Lx]
67  const size_t fftw_size = sizeof(fftw_complex) * Lx * Ly * Lz;
68  m_in = (fftw_complex *)fftw_malloc(fftw_size);
69  m_out = (fftw_complex *)fftw_malloc(fftw_size);
70 
71  if (!m_in || !m_out) {
72  vout.crucial(m_vl, "%s: failed to allocate memory %d [Byte].\n",
73  class_name.c_str(), (int)fftw_size);
74  exit(EXIT_FAILURE);
75  }
76 #endif
77 }
78 
79 
80 //====================================================================
81 void FFT_xyz_3dim::tidy_up()
82 {
83  if (m_in) fftw_free(m_in);
84  if (m_out) fftw_free(m_out);
85  if (m_plan) fftw_destroy_plan(m_plan);
86 }
87 
88 
89 //====================================================================
90 void FFT_xyz_3dim::FFT(Field& field, const bool is_forward)
91 {
92  //- global lattice size
93  const int Lx = CommonParameters::Lx();
94  const int Ly = CommonParameters::Ly();
95  const int Lz = CommonParameters::Lz();
96  const int Lt = CommonParameters::Lt();
97 
98  //- local size
99  const int Nz = CommonParameters::Nz();
100 
101  const int Nin = field.nin();
102  const int Nex = field.nex();
103 
104 
105  //- setup FFTW plan
106 #ifdef USE_OPENMP
107  const int Nthread = ThreadManager_OpenMP::get_num_threads();
108  fftw_plan_with_nthreads(Nthread);
109 #endif
110 #ifdef USE_MPI
111  const ptrdiff_t Lx_p = CommonParameters::Lx();
112  const ptrdiff_t Ly_p = CommonParameters::Ly();
113  const ptrdiff_t Lz_p = CommonParameters::Lz();
114 
115  if (is_forward) {
116  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
118  FFTW_FORWARD, FFTW_ESTIMATE);
119  } else {
120  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
122  FFTW_BACKWARD, FFTW_ESTIMATE);
123  }
124 #else
125  if (is_forward) {
126  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
127  FFTW_FORWARD, FFTW_ESTIMATE);
128  } else {
129  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
130  FFTW_BACKWARD, FFTW_ESTIMATE);
131  }
132 #endif
133 
134 
135  // #### Execution main part ####
136  //- Nin is devided by 2, because of complex(i.e. real and imag)
137  for (int in2 = 0; in2 < Nin / 2; ++in2) {
138  for (int t_global = 0; t_global < Lt; t_global++) {
139  for (int ex = 0; ex < Nex; ++ex) {
140  //- input data
141  for (int z = 0; z < Nz; z++) {
142  for (int y_global = 0; y_global < Ly; y_global++) {
143  for (int x_global = 0; x_global < Lx; x_global++) {
144  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
145 
146  int isite = m_index.site(x_global, y_global, z, t_global);
147  int i_real = 2 * in2;
148  int i_imag = 2 * in2 + 1;
149 
150  m_in[isite_xyz_local][0] = field.cmp(i_real, isite, ex);
151  m_in[isite_xyz_local][1] = field.cmp(i_imag, isite, ex);
152  }
153  }
154  }
155 
156 
157  fftw_execute(m_plan);
158 
159 
160  //- output data
161  for (int z = 0; z < Nz; z++) {
162  for (int y_global = 0; y_global < Ly; y_global++) {
163  for (int x_global = 0; x_global < Lx; x_global++) {
164  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
165 
166  int isite = m_index.site(x_global, y_global, z, t_global);
167  int i_real = 2 * in2;
168  int i_imag = 2 * in2 + 1;
169 
170  field.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
171  field.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
172  }
173  }
174  }
175  }
176  }
177  }
178  //- end of global loops
179 }
180 
181 
182 //====================================================================
183 void FFT_xyz_3dim::FFT(Field& field_out, const Field& field_in, const bool is_forward)
184 {
185  //- global lattice size
186  const int Lx = CommonParameters::Lx();
187  const int Ly = CommonParameters::Ly();
188  const int Lz = CommonParameters::Lz();
189  const int Lt = CommonParameters::Lt();
190 
191  //- local size
192  const int Nz = CommonParameters::Nz();
193 
194  const int Nin = field_in.nin();
195  const int Nex = field_in.nex();
196 
197 
198  //- setup FFTW plan
199 #ifdef USE_OPENMP
200  const int Nthread = ThreadManager_OpenMP::get_num_threads();
201  fftw_plan_with_nthreads(Nthread);
202 #endif
203 #ifdef USE_MPI
204  const ptrdiff_t Lx_p = CommonParameters::Lx();
205  const ptrdiff_t Ly_p = CommonParameters::Ly();
206  const ptrdiff_t Lz_p = CommonParameters::Lz();
207 
208  if (is_forward) {
209  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
211  FFTW_FORWARD, FFTW_ESTIMATE);
212  } else {
213  m_plan = fftw_mpi_plan_dft_3d(Lz_p, Ly_p, Lx_p, m_in, m_out,
215  FFTW_BACKWARD, FFTW_ESTIMATE);
216  }
217 #else
218  if (is_forward) {
219  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
220  FFTW_FORWARD, FFTW_ESTIMATE);
221  } else {
222  m_plan = fftw_plan_dft_3d(Lz, Ly, Lx, m_in, m_out,
223  FFTW_BACKWARD, FFTW_ESTIMATE);
224  }
225 #endif
226 
227 
228  // #### Execution main part ####
229  //- Nin is devided by 2, because of complex(i.e. real and imag)
230  for (int in2 = 0; in2 < Nin / 2; ++in2) {
231  for (int t_global = 0; t_global < Lt; t_global++) {
232  for (int ex = 0; ex < Nex; ++ex) {
233  //- input data
234  for (int z = 0; z < Nz; z++) {
235  for (int y_global = 0; y_global < Ly; y_global++) {
236  for (int x_global = 0; x_global < Lx; x_global++) {
237  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
238 
239  int isite = m_index.site(x_global, y_global, z, t_global);
240  int i_real = 2 * in2;
241  int i_imag = 2 * in2 + 1;
242 
243  m_in[isite_xyz_local][0] = field_in.cmp(i_real, isite, ex);
244  m_in[isite_xyz_local][1] = field_in.cmp(i_imag, isite, ex);
245  }
246  }
247  }
248 
249 
250  fftw_execute(m_plan);
251 
252 
253  //- output data
254  for (int z = 0; z < Nz; z++) {
255  for (int y_global = 0; y_global < Ly; y_global++) {
256  for (int x_global = 0; x_global < Lx; x_global++) {
257  int isite_xyz_local = x_global + Lx * (y_global + Ly * z);
258 
259  int isite = m_index.site(x_global, y_global, z, t_global);
260  int i_real = 2 * in2;
261  int i_imag = 2 * in2 + 1;
262 
263  field_out.set(i_real, isite, ex, m_out[isite_xyz_local][0]);
264  field_out.set(i_imag, isite, ex, m_out[isite_xyz_local][1]);
265  }
266  }
267  }
268  }
269  }
270  }
271  //- end of global loops
272 }
273 
274 
275 //==========================================================
276 //==================================================END=====
277 #endif
BridgeIO vout
Definition: bridgeIO.cpp:278
static int get_num_threads()
returns available number of threads.
static int NPEy()
static int NPEt()
void set(const int jin, const int site, const int jex, double v)
Definition: field.h:155
Container of Field-type object.
Definition: field.h:39
double cmp(const int jin, const int site, const int jex) const
Definition: field.h:123
int nin() const
Definition: field.h:115
int nex() const
Definition: field.h:117
static int NPEx()
void crucial(const char *format,...)
Definition: bridgeIO.cpp:48