Bridge++  Ver. 2.0.2
fft_3d_parallel3d.cpp
Go to the documentation of this file.
1 
13 #ifdef USE_FFTWLIB
14 #ifdef USE_MPI
15 
16 #include "fft_3d_parallel3d.h"
19 
20 #ifdef USE_OPENMP
22 #endif
23 
24 
25 const std::string FFT_3d_parallel3d::class_name = "FFT_3d_parallel3d";
26 
27 #ifdef USE_FACTORY_AUTOREGISTER
28 namespace {
29  bool init = FFT_3d_parallel3d::register_factory();
30 }
31 #endif
32 
33 //====================================================================
34 void FFT_3d_parallel3d::set_parameters(const Parameters& params)
35 {
36  std::string vlevel;
37  if (!params.fetch_string("verbose_level", vlevel)) {
38  m_vl = vout.set_verbose_level(vlevel);
39  }
40 
41  std::string direction;
42  if (!params.fetch_string("FFT_direction", direction)) {
43  set_parameters(direction);
44  }
45 }
46 
47 
48 //====================================================================
49 void FFT_3d_parallel3d::get_parameters(Parameters& params) const
50 {
51  if (m_direction == FORWARD) {
52  params.set_string("FFT_direction", "Forward");
53  } else if (m_direction == BACKWARD) {
54  params.set_string("FFT_direction", "Backward");
55  } else {
56  params.set_string("FFT_direction", "None");
57  }
58 
59  params.set_string("verbose_level", vout.get_verbose_level(m_vl));
60 }
61 
62 
63 //====================================================================
64 void FFT_3d_parallel3d::set_parameters(const std::string& direction)
65 {
66  if (direction == "Forward") {
67  m_direction = FORWARD;
68  } else if (direction == "Backward") {
69  m_direction = BACKWARD;
70  } else {
71  m_direction = UNDEF;
72 
73  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
74  exit(EXIT_FAILURE);
75  }
76 }
77 
78 
79 //====================================================================
80 FFT_3d_parallel3d::FFT_3d_parallel3d()
81  : m_vl(CommonParameters::Vlevel())
82  , m_initialized(false)
83  , m_direction(UNDEF)
84 {
85  if (check_ok()) {
86  initialize();
87  }
88 }
89 
90 
91 //====================================================================
92 FFT_3d_parallel3d::FFT_3d_parallel3d(const Parameters& params)
93  : m_vl(CommonParameters::Vlevel())
94  , m_initialized(false)
95  , m_direction(UNDEF)
96 {
97  if (check_ok()) {
98  initialize();
99  }
100  set_parameters(params);
101 }
102 
103 
104 //====================================================================
105 FFT_3d_parallel3d::~FFT_3d_parallel3d()
106 {
107  finalize();
108 }
109 
110 
111 //====================================================================
112 bool FFT_3d_parallel3d::check_ok()
113 {
114  return true;
115 }
116 
117 
118 //====================================================================
119 void FFT_3d_parallel3d::initialize()
120 {
121 #ifdef USE_OPENMP
122  int thread_ok = fftw_init_threads();
123 
124  if (thread_ok) {
125  fftw_plan_with_nthreads(ThreadManager::get_num_threads());
126  }
127 #endif
128 
129  // fft itself is not mpi-parallelized.
130 
131  int ipe_x = Communicator::ipe(0);
132  int ipe_y = Communicator::ipe(1);
133  int ipe_z = Communicator::ipe(2);
134  int ipe_t = Communicator::ipe(3);
135 
136  int npe_x = Communicator::npe(0);
137  int npe_y = Communicator::npe(1);
138  int npe_z = Communicator::npe(2);
139  int npe_t = Communicator::npe(3);
140 
141  // split communicator along t directions
142  int ipe_xyz = ipe_x + npe_x * (ipe_y + npe_y * (ipe_z));
143 
144  MPI_Comm_split(Communicator_impl::world(), ipe_t, ipe_xyz, &m_comm);
145 
146  // find rank and coordinate in subcommunicator
147  int local_rank;
148  MPI_Comm_rank(m_comm, &local_rank);
149 
150  int local_ipe_x = local_rank % npe_x;
151  int local_ipe_y = (local_rank / npe_x) % npe_y;
152  int local_ipe_z = (local_rank / npe_x / npe_y) % npe_z;
153 
154  // just for check
155  if ((local_ipe_x != ipe_x) || (local_ipe_y != ipe_y) || (local_ipe_z != ipe_z)) {
156  vout.crucial(m_vl, "%s: split commnicator failed.\n", class_name.c_str());
157  exit(EXIT_FAILURE);
158  }
159 
160  m_local_rank = local_rank;
161 
162  m_local_ipe_x = local_ipe_x;
163  m_local_ipe_y = local_ipe_y;
164  m_local_ipe_z = local_ipe_z;
165 
166  // lattice and grid information
167  m_ndims = 3;
168 
169  m_grid_size.resize(m_ndims);
170  m_grid_size[0] = CommonParameters::NPEx();
171  m_grid_size[1] = CommonParameters::NPEy();
172  m_grid_size[2] = CommonParameters::NPEz();
173 
174  m_grid_vol = 1;
175  for (int i = 0; i < m_ndims; ++i) {
176  m_grid_vol *= m_grid_size[i];
177  }
178 
179  m_lattice_size.resize(m_ndims);
180  m_lattice_size[0] = CommonParameters::Lx();
181  m_lattice_size[1] = CommonParameters::Ly();
182  m_lattice_size[2] = CommonParameters::Lz();
183 
184  m_lattice_vol = 1;
185  for (int i = 0; i < m_ndims; ++i) {
186  m_lattice_vol *= m_lattice_size[i];
187  }
188 
189  m_local_size.resize(m_ndims);
190  m_local_size[0] = CommonParameters::Nx();
191  m_local_size[1] = CommonParameters::Ny();
192  m_local_size[2] = CommonParameters::Nz();
193 
194  m_local_vol = 1;
195  for (int i = 0; i < m_ndims; ++i) {
196  m_local_vol *= m_local_size[i];
197  }
198 }
199 
200 
201 //====================================================================
202 void FFT_3d_parallel3d::finalize()
203 {
204  if (m_initialized) {
205  release_plan();
206  }
207 }
208 
209 
210 //====================================================================
211 void FFT_3d_parallel3d::create_mpi_datatype(int site_dof)
212 {
213  // MPI datatypes and gather/scatter parameters
214 
215  // assume site_dof in complex
216 
217  MPI_Type_contiguous(2 * site_dof,
218  MPI_DOUBLE,
219  &m_site_vector_type);
220 
221  MPI_Type_commit(&m_site_vector_type);
222 
223  MPI_Type_contiguous(m_local_vol,
224  m_site_vector_type,
225  &m_local_patch_type);
226 
227  MPI_Type_commit(&m_local_patch_type);
228 
229  int size_;
230  MPI_Type_size(m_site_vector_type, &size_);
231 
232  MPI_Datatype type_;
233  std::vector<int> local_origin(m_ndims, 0);
234 
235  MPI_Type_create_subarray(m_ndims,
236  &m_lattice_size[0],
237  &m_local_size[0],
238  &local_origin[0],
239  MPI_ORDER_FORTRAN,
240  m_site_vector_type,
241  &type_);
242 
243  MPI_Type_create_resized(type_, 0, size_, &m_subarray_type);
244 
245  MPI_Type_commit(&m_subarray_type);
246 
247 
248  m_sendcounts.resize(m_grid_vol);
249 
250  for (int r = 0; r < m_grid_vol; ++r) {
251  m_sendcounts[r] = 1;
252  }
253 
254  m_subarray_displs.resize(m_grid_vol);
255 
256  for (int r = 0; r < m_grid_vol; ++r) {
257  std::vector<int> coord = grid_rank_to_coord(r);
258 
259  // find global coordinate of origin of each local patch
260  for (int j = 0; j < m_ndims; ++j) {
261  coord[j] *= m_local_size[j];
262  }
263 
264  int idx = find_global_index(coord);
265 
266  m_subarray_displs[r] = idx;
267  }
268 
269  m_local_patch_displs.resize(m_grid_vol);
270 
271  for (int r = 0; r < m_grid_vol; ++r) {
272  m_local_patch_displs[r] = r;
273  }
274 }
275 
276 
277 //====================================================================
278 void FFT_3d_parallel3d::release_mpi_datatype()
279 {
280  int is_finalized = 0;
281 
282  MPI_Finalized(&is_finalized);
283 
284  if (is_finalized) {
285  vout.crucial(m_vl, "%s: MPI has already gone...\n", class_name.c_str());
286  return;
287  }
288 
289  MPI_Type_free(&m_site_vector_type);
290  MPI_Type_free(&m_subarray_type);
291  MPI_Type_free(&m_local_patch_type);
292 }
293 
294 
295 //====================================================================
296 void FFT_3d_parallel3d::create_fft_plan(int site_dof)
297 {
298  // allocate buffer (run on-the-fly)
299  m_buf = fftw_alloc_complex(site_dof * m_lattice_vol);
300  if (!m_buf) {
301  vout.crucial(m_vl, "%s: buffer allocation failed.\n", class_name.c_str());
302  exit(EXIT_FAILURE);
303  }
304 
305  // create plan
306 
307  m_plan_fw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
308  m_buf, NULL, site_dof, 1,
309  m_buf, NULL, site_dof, 1,
310  FFTW_FORWARD, FFTW_ESTIMATE);
311 
312  m_plan_bw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
313  m_buf, NULL, site_dof, 1,
314  m_buf, NULL, site_dof, 1,
315  FFTW_BACKWARD, FFTW_ESTIMATE);
316 
317  if (!m_plan_fw || !m_plan_bw) {
318  vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
319  exit(EXIT_FAILURE);
320  }
321 }
322 
323 
324 //====================================================================
325 void FFT_3d_parallel3d::release_fft_plan()
326 {
327  if (m_buf) fftw_free(m_buf);
328  m_buf = NULL;
329  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
330  m_plan_fw = NULL;
331  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
332  m_plan_bw = NULL;
333 }
334 
335 
336 //====================================================================
337 void FFT_3d_parallel3d::create_plan(int site_dof)
338 {
339  create_mpi_datatype(site_dof);
340  create_fft_plan(site_dof);
341 
342  m_site_dof = site_dof;
343 
344  m_initialized = true;
345 }
346 
347 
348 //====================================================================
349 void FFT_3d_parallel3d::release_plan()
350 {
351  release_fft_plan();
352  release_mpi_datatype();
353 
354  m_initialized = false;
355 }
356 
357 
358 //====================================================================
359 bool FFT_3d_parallel3d::need_create_plan(const Field& field)
360 {
361  if (field.nin() / 2 == m_site_dof) return false;
362 
363  return true;
364 }
365 
366 
367 //====================================================================
368 void FFT_3d_parallel3d::fft(Field& dst, const Field& src, enum Direction dir)
369 {
370  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
371  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
372  exit(EXIT_FAILURE);
373  }
374 
375  // check if mpi types and fft plans are recyclable.
376  if (m_initialized == false) {
377  vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
378  create_plan(src.nin() / 2);
379  } else {
380  if (need_create_plan(src)) {
381  vout.general(m_vl, "%s: discard plan and create new.\n", class_name.c_str());
382  release_plan();
383  create_plan(src.nin() / 2);
384  } else {
385  vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
386  }
387  }
388 
389  int nex = src.nex();
390  int nt = CommonParameters::Nt();
391 
392  int ndata = nt * nex;
393 
394  std::vector<dcomplex *> src_array(ndata, nullptr);
395  std::vector<dcomplex *> dst_array(ndata, nullptr);
396 
397  int local_vol = m_local_vol;
398 
399  int k = 0;
400  for (int iex = 0; iex < nex; ++iex) {
401  for (int it = 0; it < nt; ++it) {
402  src_array[k] = (dcomplex *)(src.ptr(0, local_vol * it, iex));
403  dst_array[k] = (dcomplex *)(dst.ptr(0, local_vol * it, iex));
404  ++k;
405  }
406  }
407 
408  int nblock = m_grid_vol;
409 
410  for (int k = 0; k < ndata; k += nblock) {
411  bool do_full = (k + nblock <= ndata);
412  int nwork = do_full ? nblock : (ndata % nblock);
413 
414  if (do_full) {
415  MPI_Alltoallv(src_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
416  m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
417  m_comm);
418  } else {
419  for (int j = 0; j < nwork; ++j) {
420  MPI_Gatherv(src_array[k + j], 1, m_local_patch_type,
421  m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
422  j, m_comm);
423  }
424  }
425 
426  if (m_local_rank < nwork) {
427  fftw_execute(dir == FORWARD ? m_plan_fw : m_plan_bw);
428  }
429 
430  if (do_full) {
431  MPI_Alltoallv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
432  dst_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
433  m_comm);
434  } else {
435  for (int j = 0; j < nwork; ++j) {
436  MPI_Scatterv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
437  dst_array[k + j], 1, m_local_patch_type,
438  j, m_comm);
439  }
440  }
441  }
442 
443  if (dir == BACKWARD) {
444  scal(dst, 1.0 / m_lattice_vol);
445  }
446 }
447 
448 
449 //====================================================================
450 void FFT_3d_parallel3d::fft(Field& dst, const Field& src)
451 {
452  return fft(dst, src, m_direction);
453 }
454 
455 
456 //====================================================================
457 void FFT_3d_parallel3d::fft(Field& field)
458 {
459  // return fft(field, field, m_direction);
460  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
461  exit(EXIT_FAILURE);
462 }
463 
464 
465 //====================================================================
466 std::vector<int> FFT_3d_parallel3d::grid_rank_to_coord(int r)
467 {
468  std::vector<int> coord(m_ndims);
469 
470  for (int i = 0; i < m_ndims; ++i) {
471  coord[i] = r % m_grid_size[i];
472  r /= m_grid_size[i];
473  }
474 
475  return coord;
476 }
477 
478 
479 //====================================================================
480 int FFT_3d_parallel3d::find_global_index(const std::vector<int>& coord)
481 {
482  assert(coord.size() == m_ndims);
483 
484  int idx = coord[m_ndims - 1];
485  for (int i = m_ndims - 2; i >= 0; --i) {
486  idx *= m_lattice_size[i];
487  idx += coord[i];
488  }
489 
490  return idx;
491 }
492 
493 
494 //====================================================================
495 #endif /* USE_MPI */
496 #endif /* USE_FFTWLIB */
497 
498 //====================================================================
499 //============================================================END=====
CommonParameters::Ny
static int Ny()
Definition: commonParameters.h:106
CommonParameters::Nz
static int Nz()
Definition: commonParameters.h:107
Parameters::set_string
void set_string(const string &key, const string &value)
Definition: parameters.cpp:39
CommonParameters
Common parameter class: provides parameters as singleton.
Definition: commonParameters.h:42
ThreadManager::get_num_threads
static int get_num_threads()
returns available number of threads.
Definition: threadManager.cpp:246
Parameters
Class for parameters.
Definition: parameters.h:46
fft_3d_parallel3d.h
Field::nex
int nex() const
Definition: field.h:128
CommonParameters::Ly
static int Ly()
Definition: commonParameters.h:92
Direction
Direction
Definition: bridge_defs.h:24
Field::nin
int nin() const
Definition: field.h:126
CommonParameters::Nx
static int Nx()
Definition: commonParameters.h:105
CommonParameters::Lx
static int Lx()
Definition: commonParameters.h:91
communicator_mpi.h
CommonParameters::Lz
static int Lz()
Definition: commonParameters.h:93
CommonParameters::Nt
static int Nt()
Definition: commonParameters.h:108
Communicator::npe
static int npe(const int dir)
logical grid extent
Definition: communicator.cpp:112
CommonParameters::NPEz
static int NPEz()
Definition: commonParameters.h:99
AIndex_eo_qxs::idx
int idx(const int in, const int Nin, const int ist, const int Nx2, const int Ny, const int leo, const int Nvol2, const int ex)
Definition: aindex_eo.h:27
threadManager.h
CommonParameters::NPEy
static int NPEy()
Definition: commonParameters.h:98
Field::ptr
const double * ptr(const int jin, const int site, const int jex) const
Definition: field.h:153
Bridge::BridgeIO::set_verbose_level
static VerboseLevel set_verbose_level(const std::string &str)
Definition: bridgeIO.cpp:133
CommonParameters::NPEx
static int NPEx()
Definition: commonParameters.h:97
scal
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:261
Communicator::ipe
static int ipe(const int dir)
logical coordinate of current proc.
Definition: communicator.cpp:105
Parameters::fetch_string
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
Bridge::BridgeIO::crucial
void crucial(const char *format,...)
Definition: bridgeIO.cpp:180
Field
Container of Field-type object.
Definition: field.h:46
communicator.h
Bridge::BridgeIO::general
void general(const char *format,...)
Definition: bridgeIO.cpp:200
Bridge::vout
BridgeIO vout
Definition: bridgeIO.cpp:512
Bridge::BridgeIO::get_verbose_level
static std::string get_verbose_level(const VerboseLevel vl)
Definition: bridgeIO.cpp:154