Bridge++  Version 1.6.1
fft_3d_parallel3d.cpp
#ifdef USE_FFTWLIB
#ifdef USE_MPI

#include "fft_3d_parallel3d.h"

#ifdef USE_OPENMP
#include "ResourceManager/threadManager_OpenMP.h"  // header for ThreadManager_OpenMP (path assumed)
#endif


const std::string FFT_3d_parallel3d::class_name = "FFT_3d_parallel3d";

#ifdef USE_FACTORY_AUTOREGISTER
namespace {
  bool init = FFT_3d_parallel3d::register_factory();
}
#endif
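
// Implementation note: the 3D FFT itself is not MPI-parallelized.  The
// process grid is split into sub-communicators, one per time coordinate
// ipe_t.  For each (t, ex) slice the local x-y-z patches of all ranks in a
// sub-communicator are gathered onto a single rank, FFTW performs the full
// Lx*Ly*Lz transform there, and the result is scattered back to the
// original layout.  Backward transforms are normalized by 1/(Lx*Ly*Lz) at
// the end of fft().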

//====================================================================
void FFT_3d_parallel3d::set_parameters(const Parameters& params)
{
  // pass parameters to base class.
  this->FFT::set_parameters(params);

  std::string direction;

  if (params.fetch_string("FFT_direction", direction) == 0) {
    set_parameters(direction);
  }
}


//====================================================================
void FFT_3d_parallel3d::set_parameters(const std::string& direction)
{
  if (direction == "Forward") {
    m_direction = FORWARD;
  } else if (direction == "Backward") {
    m_direction = BACKWARD;
  } else {
    m_direction = UNDEF;

    vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
    exit(EXIT_FAILURE);
  }
}


//====================================================================
FFT_3d_parallel3d::FFT_3d_parallel3d()
  : m_initialized(false)
  , m_direction(UNDEF)
{
  if (check_ok()) {
    initialize();
  }
}


//====================================================================
FFT_3d_parallel3d::~FFT_3d_parallel3d()
{
  finalize();
}


//====================================================================
bool FFT_3d_parallel3d::check_ok()
{
  return true;
}


//====================================================================
void FFT_3d_parallel3d::initialize()
{
#ifdef USE_OPENMP
  int thread_ok = fftw_init_threads();

  if (thread_ok) {
    fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
  }
#endif

  // fft itself is not mpi-parallelized.

  int ipe_x = Communicator::ipe(0);
  int ipe_y = Communicator::ipe(1);
  int ipe_z = Communicator::ipe(2);
  int ipe_t = Communicator::ipe(3);

  int npe_x = Communicator::npe(0);
  int npe_y = Communicator::npe(1);
  int npe_z = Communicator::npe(2);
  int npe_t = Communicator::npe(3);

  // split communicator along the t direction
  int ipe_xyz = ipe_x + npe_x * (ipe_y + npe_y * (ipe_z));

  MPI_Comm_split(Communicator_impl::world(), ipe_t, ipe_xyz, &m_comm);
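
  // color = ipe_t groups all processes with the same time coordinate into
  // one sub-communicator; key = ipe_xyz orders the ranks within it x-fastest,
  // so local ranks match the lexicographic grid coordinates recovered below.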

  // find rank and coordinate in subcommunicator
  int local_rank;
  MPI_Comm_rank(m_comm, &local_rank);

  int local_ipe_x = local_rank % npe_x;
  int local_ipe_y = (local_rank / npe_x) % npe_y;
  int local_ipe_z = (local_rank / npe_x / npe_y) % npe_z;

  // consistency check
  if ((local_ipe_x != ipe_x) || (local_ipe_y != ipe_y) || (local_ipe_z != ipe_z)) {
    vout.crucial(m_vl, "%s: split communicator failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }

  m_local_rank = local_rank;

  m_local_ipe_x = local_ipe_x;
  m_local_ipe_y = local_ipe_y;
  m_local_ipe_z = local_ipe_z;

  // lattice and grid information
  m_ndims = 3;

  m_grid_size.resize(m_ndims);
  m_grid_size[0] = CommonParameters::NPEx();
  m_grid_size[1] = CommonParameters::NPEy();
  m_grid_size[2] = CommonParameters::NPEz();

  m_grid_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_grid_vol *= m_grid_size[i];
  }

  m_lattice_size.resize(m_ndims);
  m_lattice_size[0] = CommonParameters::Lx();
  m_lattice_size[1] = CommonParameters::Ly();
  m_lattice_size[2] = CommonParameters::Lz();

  m_lattice_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_lattice_vol *= m_lattice_size[i];
  }

  m_local_size.resize(m_ndims);
  m_local_size[0] = CommonParameters::Nx();
  m_local_size[1] = CommonParameters::Ny();
  m_local_size[2] = CommonParameters::Nz();

  m_local_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_local_vol *= m_local_size[i];
  }
}


//====================================================================
void FFT_3d_parallel3d::finalize()
{
  if (m_initialized) {
    release_plan();
  }
}


//====================================================================
void FFT_3d_parallel3d::create_mpi_datatype(int site_dof)
{
  // MPI datatypes and gather/scatter parameters

  // site_dof is counted in complex numbers

  MPI_Type_contiguous(2 * site_dof,
                      MPI_DOUBLE,
                      &m_site_vector_type);

  MPI_Type_commit(&m_site_vector_type);

  MPI_Type_contiguous(m_local_vol,
                      m_site_vector_type,
                      &m_local_patch_type);

  MPI_Type_commit(&m_local_patch_type);

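  // m_subarray_type describes one rank's local patch embedded in the full
  // Lx*Ly*Lz lattice (Fortran order: x fastest).  Resizing its extent to a
  // single site vector lets the Alltoallv/Gatherv displacements below be
  // given in units of sites, i.e. as the global index of each patch origin.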
  int size_;
  MPI_Type_size(m_site_vector_type, &size_);

  MPI_Datatype type_;
  std::vector<int> local_origin(m_ndims, 0);

  MPI_Type_create_subarray(m_ndims,
                           &m_lattice_size[0],
                           &m_local_size[0],
                           &local_origin[0],
                           MPI_ORDER_FORTRAN,
                           m_site_vector_type,
                           &type_);

  MPI_Type_create_resized(type_, 0, size_, &m_subarray_type);

  MPI_Type_commit(&m_subarray_type);


  m_sendcounts.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_sendcounts[r] = 1;
  }

  m_subarray_displs.resize(m_grid_vol);

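  // Illustration (hypothetical numbers): on an 8x8x8 lattice with a 2x2x2
  // process grid (local size 4x4x4), rank r = 3 has grid coordinate
  // (1, 1, 0), so its patch origin is (4, 4, 0) and its displacement is
  // 4 + 8*(4 + 8*0) = 36 site vectors from the start of the gathered buffer.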
  for (int r = 0; r < m_grid_vol; ++r) {
    std::vector<int> coord = grid_rank_to_coord(r);

    // find global coordinate of origin of each local patch
    for (int j = 0; j < m_ndims; ++j) {
      coord[j] *= m_local_size[j];
    }

    int idx = find_global_index(coord);

    m_subarray_displs[r] = idx;
  }

  m_local_patch_displs.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_local_patch_displs[r] = r;
  }
}


//====================================================================
void FFT_3d_parallel3d::release_mpi_datatype()
{
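  // guard against calling MPI_Type_free after MPI_Finalize: the destructor
  // (via finalize/release_plan) may run during static destruction, after
  // MPI has already been shut down.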
  int is_finalized = 0;

  MPI_Finalized(&is_finalized);

  if (is_finalized) {
    vout.crucial(m_vl, "%s: MPI has already gone...\n", class_name.c_str());
    return;
  }

  MPI_Type_free(&m_site_vector_type);
  MPI_Type_free(&m_subarray_type);
  MPI_Type_free(&m_local_patch_type);
}


//====================================================================
void FFT_3d_parallel3d::create_fft_plan(int site_dof)
{
  // allocate work buffer (done on the fly, when the plan is created)
  m_buf = fftw_alloc_complex(site_dof * m_lattice_vol);
  if (!m_buf) {
    vout.crucial(m_vl, "%s: buffer allocation failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }

  // create plans

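  // howmany = site_dof transforms over the same buffer: component j of each
  // site sits at stride site_dof (istride/ostride), and successive transforms
  // start 1 element apart (idist/odist), matching the site-major complex
  // layout of the gathered field.  Both plans transform m_buf in place.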
  m_plan_fw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_FORWARD, FFTW_ESTIMATE);

  m_plan_bw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_BACKWARD, FFTW_ESTIMATE);

  if (!m_plan_fw || !m_plan_bw) {
    vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
    exit(EXIT_FAILURE);
  }
}


//====================================================================
void FFT_3d_parallel3d::release_fft_plan()
{
  if (m_buf) fftw_free(m_buf);
  m_buf = NULL;
  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
  m_plan_fw = NULL;
  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
  m_plan_bw = NULL;
}


//====================================================================
void FFT_3d_parallel3d::create_plan(int site_dof)
{
  create_mpi_datatype(site_dof);
  create_fft_plan(site_dof);

  m_site_dof = site_dof;

  m_initialized = true;
}


//====================================================================
void FFT_3d_parallel3d::release_plan()
{
  release_fft_plan();
  release_mpi_datatype();

  m_initialized = false;
}


//====================================================================
bool FFT_3d_parallel3d::need_create_plan(const Field& field)
{
  if (field.nin() / 2 == m_site_dof) return false;

  return true;
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& dst, const Field& src, enum Direction dir)
{
  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
    vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
    exit(EXIT_FAILURE);
  }

  // check if mpi types and fft plans are recyclable.
  if (m_initialized == false) {
    vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
    create_plan(src.nin() / 2);
  } else {
    if (need_create_plan(src)) {
      vout.general(m_vl, "%s: discard plan and create new.\n", class_name.c_str());
      release_plan();
      create_plan(src.nin() / 2);
    } else {
      vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
    }
  }

  int nex = src.nex();
  int nt = CommonParameters::Nt();

  int ndata = nt * nex;

#ifdef LIB_CPP11
  std::vector<dcomplex *> src_array(ndata, nullptr);
  std::vector<dcomplex *> dst_array(ndata, nullptr);
#else
  std::vector<dcomplex *> src_array(ndata, NULL);
  std::vector<dcomplex *> dst_array(ndata, NULL);
#endif

  int local_vol = m_local_vol;

  int k = 0;
  for (int iex = 0; iex < nex; ++iex) {
    for (int it = 0; it < nt; ++it) {
      src_array[k] = (dcomplex *)(src.ptr(0, local_vol * it, iex));
      dst_array[k] = (dcomplex *)(dst.ptr(0, local_vol * it, iex));
      ++k;
    }
  }

  int nblock = m_grid_vol;

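  // Process the nt*nex independent 3D volumes in blocks of m_grid_vol.
  // In a full block, MPI_Alltoallv sends each rank's local patch of volume
  // j to rank j of the sub-communicator (one complete lattice per rank),
  // each receiving rank runs the local FFTW plan on its gathered volume,
  // and a second Alltoallv returns the transformed patches.  A remainder
  // block of nwork < m_grid_vol volumes falls back to Gatherv/Scatterv
  // with ranks 0 .. nwork-1 as roots.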
  for (int k = 0; k < ndata; k += nblock) {
    bool do_full = (k + nblock <= ndata);
    int nwork = do_full ? nblock : (ndata % nblock);

    if (do_full) {
      MPI_Alltoallv(src_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Gatherv(src_array[k + j], 1, m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    j, m_comm);
      }
    }

    if (m_local_rank < nwork) {
      fftw_execute(dir == FORWARD ? m_plan_fw : m_plan_bw);
    }

    if (do_full) {
      MPI_Alltoallv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    dst_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Scatterv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                     dst_array[k + j], 1, m_local_patch_type,
                     j, m_comm);
      }
    }
  }

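  // FFTW's backward transform is unnormalized, so rescale by the inverse
  // 3D lattice volume to make forward followed by backward the identity.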
  if (dir == BACKWARD) {
    scal(dst, 1.0 / m_lattice_vol);
  }
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& dst, const Field& src)
{
  return fft(dst, src, m_direction);
}


//====================================================================
void FFT_3d_parallel3d::fft(Field& field)
{
  // return fft(field, field, m_direction);
  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
  exit(EXIT_FAILURE);
}
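
// Usage sketch (illustrative only; assumes the Parameters setter is named
// set_string, and that dst has the same layout as src):
//
//   FFT_3d_parallel3d fft;
//   Parameters params;
//   params.set_string("FFT_direction", "Forward");
//   fft.set_parameters(params);
//   fft.fft(dst, src);    // dst, src : Field; dst must not alias src,
//                         // since the in-place overload above aborts.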


//====================================================================
std::vector<int> FFT_3d_parallel3d::grid_rank_to_coord(int r)
{
  std::vector<int> coord(m_ndims);

  for (int i = 0; i < m_ndims; ++i) {
    coord[i] = r % m_grid_size[i];
    r /= m_grid_size[i];
  }

  return coord;
}


//====================================================================
int FFT_3d_parallel3d::find_global_index(const std::vector<int>& coord)
{
  assert(coord.size() == m_ndims);

  int idx = coord[m_ndims - 1];
  for (int i = m_ndims - 2; i >= 0; --i) {
    idx *= m_lattice_size[i];
    idx += coord[i];
  }

  return idx;
}
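
// Both helpers use x-fastest lexicographic ordering: grid_rank_to_coord
// inverts rank = gx + npe_x*(gy + npe_y*gz), and find_global_index returns
// site = x + Lx*(y + Ly*z), consistent with the MPI_ORDER_FORTRAN subarray
// and the key used in MPI_Comm_split.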


//====================================================================
#endif /* USE_MPI */
#endif /* USE_FFTWLIB */

//====================================================================
//============================================================END=====