Bridge++  Version 1.5.4
fft_3d_parallel3d.cpp
#ifdef USE_FFTWLIB
#ifdef USE_MPI

#include "fft_3d_parallel3d.h"

#ifdef USE_OPENMP
#include "ResourceManager/threadManager_OpenMP.h"
#endif


const std::string FFT_3d_parallel3d::class_name = "FFT_3d_parallel3d";

#ifdef USE_FACTORY_AUTOREGISTER
namespace {
  bool init = FFT_3d_parallel3d::register_factory();
}
#endif

//====================================================================
void FFT_3d_parallel3d::set_parameters(const Parameters& params)
{
  // pass parameters to base class.
  this->FFT::set_parameters(params);

  std::string direction;

  if (params.fetch_string("FFT_direction", direction) == 0) {
    set_parameters(direction);
  }
}


//====================================================================
void FFT_3d_parallel3d::set_parameters(const std::string& direction)
{
  if (direction == "Forward") {
    m_direction = FORWARD;
  } else if (direction == "Backward") {
    m_direction = BACKWARD;
  } else {
    m_direction = UNDEF;

    vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
    exit(EXIT_FAILURE);
  }
}


//====================================================================
FFT_3d_parallel3d::FFT_3d_parallel3d()
  : m_initialized(false)
  , m_direction(UNDEF)
{
  if (check_ok()) {
    initialize();
  }
}


//====================================================================
FFT_3d_parallel3d::~FFT_3d_parallel3d()
{
  finalize();
}


//====================================================================
bool FFT_3d_parallel3d::check_ok()
{
  return true;
}


//====================================================================
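// Set up the process-grid information and split the 4d communicator into
// independent x-y-z subcommunicators, one per time coordinate ipe_t.  The
// split key ipe_xyz orders ranks x-fastest, so the rank within each
// subcommunicator reproduces the (ipe_x, ipe_y, ipe_z) coordinates, which
// is verified below.  The FFT itself is executed serially on one rank over
// a full 3d lattice gathered within the subcommunicator.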
void FFT_3d_parallel3d::initialize()
{
#ifdef USE_OPENMP
  int thread_ok = fftw_init_threads();

  if (thread_ok) {
    fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
  }
#endif

  // fft itself is not mpi-parallelized.

  int ipe_x = Communicator::ipe(0);
  int ipe_y = Communicator::ipe(1);
  int ipe_z = Communicator::ipe(2);
  int ipe_t = Communicator::ipe(3);

  int npe_x = Communicator::npe(0);
  int npe_y = Communicator::npe(1);
  int npe_z = Communicator::npe(2);
  int npe_t = Communicator::npe(3);

  // split communicator along the t direction
  int ipe_xyz = ipe_x + npe_x * (ipe_y + npe_y * (ipe_z));

  MPI_Comm_split(Communicator_impl::world(), ipe_t, ipe_xyz, &m_comm);

  // find rank and coordinate in the subcommunicator
  int local_rank;
  MPI_Comm_rank(m_comm, &local_rank);

  int local_ipe_x = local_rank % npe_x;
  int local_ipe_y = (local_rank / npe_x) % npe_y;
  int local_ipe_z = (local_rank / npe_x / npe_y) % npe_z;

  // consistency check
  if ((local_ipe_x != ipe_x) || (local_ipe_y != ipe_y) || (local_ipe_z != ipe_z)) {
    vout.crucial(m_vl, "%s: split communicator failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }

  m_local_rank = local_rank;

  m_local_ipe_x = local_ipe_x;
  m_local_ipe_y = local_ipe_y;
  m_local_ipe_z = local_ipe_z;

  // lattice and grid information
  m_ndims = 3;

  m_grid_size.resize(m_ndims);
  m_grid_size[0] = CommonParameters::NPEx();
  m_grid_size[1] = CommonParameters::NPEy();
  m_grid_size[2] = CommonParameters::NPEz();

  m_grid_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_grid_vol *= m_grid_size[i];
  }

  m_lattice_size.resize(m_ndims);
  m_lattice_size[0] = CommonParameters::Lx();
  m_lattice_size[1] = CommonParameters::Ly();
  m_lattice_size[2] = CommonParameters::Lz();

  m_lattice_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_lattice_vol *= m_lattice_size[i];
  }

  m_local_size.resize(m_ndims);
  m_local_size[0] = CommonParameters::Nx();
  m_local_size[1] = CommonParameters::Ny();
  m_local_size[2] = CommonParameters::Nz();

  m_local_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_local_vol *= m_local_size[i];
  }
}


//====================================================================
void FFT_3d_parallel3d::finalize()
{
  if (m_initialized) {
    release_plan();
  }
}


//====================================================================
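// Build the MPI derived datatypes used to move data between the node-local
// layout and the single-rank full-lattice layout:
//  - m_site_vector_type: one site's data, 2 * site_dof doubles
//    (site_dof complex numbers);
//  - m_local_patch_type: the contiguous local Nx*Ny*Nz patch of one rank;
//  - m_subarray_type: that patch viewed as a subarray of the Lx*Ly*Lz
//    lattice, resized to the extent of a single site vector so that the
//    displacements passed to MPI_Alltoallv count site vectors.
// m_subarray_displs[r] holds the global site index of rank r's patch origin;
// m_local_patch_displs[r] = r because the patches of consecutive time
// slices are stored contiguously in the Field.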
void FFT_3d_parallel3d::create_mpi_datatype(int site_dof)
{
  // MPI datatypes and gather/scatter parameters

  // site_dof is counted in complex numbers

  MPI_Type_contiguous(2 * site_dof,
                      MPI_DOUBLE,
                      &m_site_vector_type);

  MPI_Type_commit(&m_site_vector_type);

  MPI_Type_contiguous(m_local_vol,
                      m_site_vector_type,
                      &m_local_patch_type);

  MPI_Type_commit(&m_local_patch_type);

  int size_;
  MPI_Type_size(m_site_vector_type, &size_);

  MPI_Datatype type_;
  std::vector<int> local_origin(m_ndims, 0);

  MPI_Type_create_subarray(m_ndims,
                           &m_lattice_size[0],
                           &m_local_size[0],
                           &local_origin[0],
                           MPI_ORDER_FORTRAN,
                           m_site_vector_type,
                           &type_);

  MPI_Type_create_resized(type_, 0, size_, &m_subarray_type);

  MPI_Type_commit(&m_subarray_type);


  m_sendcounts.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_sendcounts[r] = 1;
  }

  m_subarray_displs.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    std::vector<int> coord = grid_rank_to_coord(r);

    // find the global coordinate of the origin of each local patch
    for (int j = 0; j < m_ndims; ++j) {
      coord[j] *= m_local_size[j];
    }

    int idx = find_global_index(coord);

    m_subarray_displs[r] = idx;
  }

  m_local_patch_displs.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_local_patch_displs[r] = r;
  }
}


//====================================================================
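// Free the derived datatypes.  If MPI has already been finalized (e.g. when
// the destructor runs during program shutdown), MPI_Type_free can no longer
// be called, so the types are simply left to the runtime with a message.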
void FFT_3d_parallel3d::release_mpi_datatype()
{
  int is_finalized = 0;

  MPI_Finalized(&is_finalized);

  if (is_finalized) {
    vout.crucial(m_vl, "%s: MPI has already gone...\n", class_name.c_str());
    return;
  }

  MPI_Type_free(&m_site_vector_type);
  MPI_Type_free(&m_subarray_type);
  MPI_Type_free(&m_local_patch_type);
}


//====================================================================
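// Create in-place FFTW plans over the full Lx*Ly*Lz lattice.  The transform
// is repeated site_dof times with stride site_dof and distance 1, i.e. the
// internal components are interleaved at each site and each component is
// transformed independently.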
void FFT_3d_parallel3d::create_fft_plan(int site_dof)
{
  // allocate buffer (run on-the-fly)
  m_buf = fftw_alloc_complex(site_dof * m_lattice_vol);
  if (!m_buf) {
    vout.crucial(m_vl, "%s: buffer allocation failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }

  // create plan

  m_plan_fw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_FORWARD, FFTW_ESTIMATE);

  m_plan_bw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_BACKWARD, FFTW_ESTIMATE);

  if (!m_plan_fw || !m_plan_bw) {
    vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }
}


//====================================================================
void FFT_3d_parallel3d::release_fft_plan()
{
  if (m_buf) fftw_free(m_buf);
  m_buf = NULL;
  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
  m_plan_fw = NULL;
  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
  m_plan_bw = NULL;
}


//====================================================================
void FFT_3d_parallel3d::create_plan(int site_dof)
{
  create_mpi_datatype(site_dof);
  create_fft_plan(site_dof);

  m_site_dof = site_dof;

  m_initialized = true;
}


//====================================================================
void FFT_3d_parallel3d::release_plan()
{
  release_fft_plan();
  release_mpi_datatype();

  m_initialized = false;
}


//====================================================================
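// A new plan is needed only when the number of complex components per site
// (field.nin() / 2) differs from the one the current plan was created for.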
bool FFT_3d_parallel3d::need_create_plan(const Field& field)
{
  if (field.nin() / 2 == m_site_dof) return false;

  return true;
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& dst, const Field& src, enum Direction dir)
{
  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
    vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
    exit(EXIT_FAILURE);
  }

  // check if mpi types and fft plans are recyclable.
  if (m_initialized == false) {
    vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
    create_plan(src.nin() / 2);
  } else {
    if (need_create_plan(src)) {
      vout.general(m_vl, "%s: discard plan and create new.\n", class_name.c_str());
      release_plan();
      create_plan(src.nin() / 2);
    } else {
      vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
    }
  }

  int nex = src.nex();
  int nt  = CommonParameters::Nt();

  int ndata = nt * nex;

#ifdef LIB_CPP11
  std::vector<dcomplex *> src_array(ndata, nullptr);
  std::vector<dcomplex *> dst_array(ndata, nullptr);
#else
  std::vector<dcomplex *> src_array(ndata, NULL);
  std::vector<dcomplex *> dst_array(ndata, NULL);
#endif

  int local_vol = m_local_vol;

  int k = 0;
  for (int iex = 0; iex < nex; ++iex) {
    for (int it = 0; it < nt; ++it) {
      src_array[k] = (dcomplex *)(src.ptr(0, local_vol * it, iex));
      dst_array[k] = (dcomplex *)(dst.ptr(0, local_vol * it, iex));
      ++k;
    }
  }

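  // Process the ndata = Nt * Nex time slices in blocks of m_grid_vol.  For a
  // full block, MPI_Alltoallv gathers one complete x-y-z lattice per rank
  // (rank r of the subcommunicator receives time slice k + r), each active
  // rank transforms its lattice in place with FFTW, and a second
  // MPI_Alltoallv scatters the results back into the local patches of dst.
  // The remainder block uses MPI_Gatherv / MPI_Scatterv with root j, so only
  // the first nwork ranks execute an FFT.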
  int nblock = m_grid_vol;

  for (int k = 0; k < ndata; k += nblock) {
    bool do_full = (k + nblock <= ndata);
    int  nwork   = do_full ? nblock : (ndata % nblock);

    if (do_full) {
      MPI_Alltoallv(src_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Gatherv(src_array[k + j], 1, m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    j, m_comm);
      }
    }

    if (m_local_rank < nwork) {
      fftw_execute(dir == FORWARD ? m_plan_fw : m_plan_bw);
    }

    if (do_full) {
      MPI_Alltoallv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    dst_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Scatterv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                     dst_array[k + j], 1, m_local_patch_type,
                     j, m_comm);
      }
    }
  }

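  // FFTW transforms are unnormalized, so divide by the lattice volume after
  // a backward transform to make forward followed by backward the identity.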
  if (dir == BACKWARD) {
    scal(dst, 1.0 / m_lattice_vol);
  }
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& dst, const Field& src)
{
  return fft(dst, src, m_direction);
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& field)
{
  // return fft(field, field, m_direction);
  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
  exit(EXIT_FAILURE);
}


//====================================================================
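// Convert a rank of the x-y-z subcommunicator into grid coordinates, with x
// running fastest; this matches the key used in MPI_Comm_split above.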
std::vector<int> FFT_3d_parallel3d::grid_rank_to_coord(int r)
{
  std::vector<int> coord(m_ndims);

  for (int i = 0; i < m_ndims; ++i) {
    coord[i] = r % m_grid_size[i];
    r /= m_grid_size[i];
  }

  return coord;
}


//====================================================================
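// Lexicographic (x-fastest) site index of a global lattice coordinate; used
// as the displacement of a rank's patch origin, counted in units of one
// site vector (the extent of m_subarray_type).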
int FFT_3d_parallel3d::find_global_index(const std::vector<int>& coord)
{
  assert(coord.size() == m_ndims);

  int idx = coord[m_ndims - 1];
  for (int i = m_ndims - 2; i >= 0; --i) {
    idx *= m_lattice_size[i];
    idx += coord[i];
  }

  return idx;
}


//====================================================================
#endif /* USE_MPI */
#endif /* USE_FFTWLIB */

//====================================================================
//============================================================END=====
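A minimal usage sketch (illustrative, not part of the listed file), assuming the header fft_3d_parallel3d.h exposes the string overload of set_parameters() shown above, and assuming the Field(nin, nvol, nex) constructor and Field::nvol() from the Field class referenced below:

#include "fft_3d_parallel3d.h"

// Hypothetical helper: forward-transform src into dst on every time slice.
void fft_forward_example(const Field& src)
{
  FFT_3d_parallel3d fft;                        // constructor splits the communicator
  fft.set_parameters("Forward");                // or "Backward"

  Field dst(src.nin(), src.nvol(), src.nex());  // same shape as src (assumed ctor)
  fft.fft(dst, src);                            // uses the stored direction
}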
Referenced symbols:
  void scal(Field& x, const double a): x = a * x.  Definition: field.cpp:433
  BridgeIO vout.  Definition: bridgeIO.cpp:503
  void BridgeIO::general(const char* format, ...).  Definition: bridgeIO.cpp:197
  void BridgeIO::crucial(const char* format, ...).  Definition: bridgeIO.cpp:178
  static int ThreadManager_OpenMP::get_num_threads(): returns the available number of threads.
  static int Communicator::ipe(const int dir): logical coordinate of the current process.
  static int Communicator::npe(const int dir): logical grid extent.
  static MPI_Comm& Communicator_impl::world(): retrieves the current communicator.
  static int CommonParameters::NPEx(), NPEy(), NPEz()
  class Field: container of Field-type object.  Definition: field.h:45
  const double* Field::ptr(const int jin, const int site, const int jex) const.  Definition: field.h:153
  int Field::nin() const.  Definition: field.h:126
  int Field::nex() const.  Definition: field.h:128
  class Parameters: class for parameters.  Definition: parameters.h:46
  int Parameters::fetch_string(const string& key, string& value) const.  Definition: parameters.cpp:378
  enum Direction.  Definition: bridge_defs.h:24