const std::string FFT_3d_parallel3d::class_name = "FFT_3d_parallel3d";
#ifdef USE_FACTORY_AUTOREGISTER
namespace { bool init = FFT_3d_parallel3d::register_factory(); }
#endif
void FFT_3d_parallel3d::set_parameters(const Parameters& params)
{
  this->FFT::set_parameters(params);

  std::string direction;

  if (params.fetch_string("FFT_direction", direction) == 0) {
    set_parameters(direction);
  }
}
void FFT_3d_parallel3d::set_parameters(const std::string& direction)
{
  if (direction == "Forward") {
    m_direction = FORWARD;
  } else if (direction == "Backward") {
    m_direction = BACKWARD;
  } else {
    vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n",
                 class_name.c_str(), direction.c_str());
  }
}
FFT_3d_parallel3d::FFT_3d_parallel3d()
  : m_initialized(false)
{ /* ... */ }

FFT_3d_parallel3d::~FFT_3d_parallel3d() { /* ... */ }

bool FFT_3d_parallel3d::check_ok() { /* ... */ }

void FFT_3d_parallel3d::initialize()
{
  int thread_ok = fftw_init_threads();
  // ... (thread-count setup, ipe_*/npe_* queried from the Communicator, and
  //      the MPI_Comm_split that defines m_comm)

  int ipe_xyz = ipe_x + npe_x * (ipe_y + npe_y * (ipe_z));

  int local_rank;
  MPI_Comm_rank(m_comm, &local_rank);

  int local_ipe_x = local_rank % npe_x;
  int local_ipe_y = (local_rank / npe_x) % npe_y;
  int local_ipe_z = (local_rank / npe_x / npe_y) % npe_z;

  if ((local_ipe_x != ipe_x) || (local_ipe_y != ipe_y) || (local_ipe_z != ipe_z)) {
    vout.crucial(m_vl, "%s: split communicator failed.\n", class_name.c_str());
  }

  m_local_rank = local_rank;

  m_local_ipe_x = local_ipe_x;
  m_local_ipe_y = local_ipe_y;
  m_local_ipe_z = local_ipe_z;
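  // The check above relies on the x-fastest rank ordering of the split
  // communicator, rank = ipe_x + npe_x * (ipe_y + npe_y * ipe_z), i.e. the
  // same ordering used to build ipe_xyz.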
  m_grid_size.resize(m_ndims);
  // ... (process-grid extents per direction)

  m_grid_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_grid_vol *= m_grid_size[i];
  }

  m_lattice_size.resize(m_ndims);
  // ... (global lattice extents in x, y, z)

  m_lattice_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_lattice_vol *= m_lattice_size[i];
  }

  m_local_size.resize(m_ndims);
  // ... (local extents per rank in x, y, z)

  m_local_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_local_vol *= m_local_size[i];
  }
}
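// For each direction the decomposition satisfies
//   m_lattice_size[i] = m_grid_size[i] * m_local_size[i],
// so m_lattice_vol = m_grid_vol * m_local_vol: each rank owns one local patch
// and m_grid_vol patches tile one full 3D lattice.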
void FFT_3d_parallel3d::finalize()
{ /* ... */ }
void FFT_3d_parallel3d::create_mpi_datatype(int site_dof)
{
  // one lattice site: site_dof complex numbers stored as 2 * site_dof doubles
  MPI_Type_contiguous(2 * site_dof, MPI_DOUBLE, &m_site_vector_type);
  MPI_Type_commit(&m_site_vector_type);

  // contiguous patch of the local volume owned by one rank
  MPI_Type_contiguous(m_local_vol, m_site_vector_type, &m_local_patch_type);
  MPI_Type_commit(&m_local_patch_type);

  int size_;
  MPI_Type_size(m_site_vector_type, &size_);

  // the same patch viewed as a sub-block embedded in the full lattice
  std::vector<int> local_origin(m_ndims, 0);

  MPI_Datatype type_;
  MPI_Type_create_subarray(m_ndims,
                           &m_lattice_size[0], &m_local_size[0], &local_origin[0],
                           MPI_ORDER_FORTRAN, m_site_vector_type, &type_);
  MPI_Type_create_resized(type_, 0, size_, &m_subarray_type);
  MPI_Type_commit(&m_subarray_type);

  m_sendcounts.resize(m_grid_vol);
  for (int r = 0; r < m_grid_vol; ++r) {
    m_sendcounts[r] = 1;
  }

  m_subarray_displs.resize(m_grid_vol);
  for (int r = 0; r < m_grid_vol; ++r) {
    std::vector<int> coord = grid_rank_to_coord(r);

    // lattice coordinate of the origin of rank r's patch
    for (int j = 0; j < m_ndims; ++j) {
      coord[j] *= m_local_size[j];
    }

    int idx = find_global_index(coord);

    m_subarray_displs[r] = idx;
  }

  m_local_patch_displs.resize(m_grid_vol);
  for (int r = 0; r < m_grid_vol; ++r) {
    m_local_patch_displs[r] = r;
  }
}
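// Summary of the derived datatypes built above, as used in fft() below:
//  - m_site_vector_type : one lattice site, 2 * site_dof doubles (complex);
//  - m_local_patch_type : the m_local_vol sites owned by one rank, contiguous;
//  - m_subarray_type    : the same patch viewed as a sub-block of the full
//                         lattice, resized so displacements count site vectors;
//  - m_sendcounts / m_subarray_displs / m_local_patch_displs : per-rank counts
//    and displacements for MPI_Alltoallv / MPI_Gatherv / MPI_Scatterv.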
void FFT_3d_parallel3d::release_mpi_datatype()
{
  int is_finalized = 0;
  MPI_Finalized(&is_finalized);

  if (is_finalized) {
    vout.crucial(m_vl, "%s: MPI has already gone...\n", class_name.c_str());
    return;
  }

  MPI_Type_free(&m_site_vector_type);
  MPI_Type_free(&m_subarray_type);
  MPI_Type_free(&m_local_patch_type);
}
void FFT_3d_parallel3d::create_fft_plan(int site_dof)
{
  // transform buffer holding one full 3D lattice with site_dof complex
  // components per site
  m_buf = fftw_alloc_complex(site_dof * m_lattice_vol);
  if (!m_buf) {
    vout.crucial(m_vl, "%s: buffer allocation failed.\n", class_name.c_str());
  }

  m_plan_fw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_FORWARD, FFTW_ESTIMATE);

  m_plan_bw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_BACKWARD, FFTW_ESTIMATE);

  if (!m_plan_fw || !m_plan_bw) {
    vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
  }
}
void FFT_3d_parallel3d::release_fft_plan()
{
  if (m_buf) fftw_free(m_buf);
  m_buf = NULL;
  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
  m_plan_fw = NULL;
  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
  m_plan_bw = NULL;
}
void FFT_3d_parallel3d::create_plan(int site_dof)
{
  create_mpi_datatype(site_dof);
  create_fft_plan(site_dof);

  m_site_dof = site_dof;

  m_initialized = true;
}
void FFT_3d_parallel3d::release_plan()
{
  release_fft_plan();
  release_mpi_datatype();

  m_initialized = false;
}
bool FFT_3d_parallel3d::need_create_plan(const Field& field)
{
  if (field.nin() / 2 == m_site_dof) return false;

  return true;
}
void FFT_3d_parallel3d::fft(Field& dst, const Field& src, enum Direction dir)
{
  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
    vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
    exit(EXIT_FAILURE);
  }

  // check plan and create it if necessary
  if (m_initialized == false) {
    vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
    create_plan(src.nin() / 2);
  } else {
    if (need_create_plan(src)) {
      vout.general(m_vl, "%s: discard plan and create new.\n", class_name.c_str());
      release_plan();
      create_plan(src.nin() / 2);
    } else {
      vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
    }
  }
  int nex = src.nex();
  int nt  = CommonParameters::Nt();

  int ndata = nt * nex;

  std::vector<dcomplex *> src_array(ndata, nullptr);
  std::vector<dcomplex *> dst_array(ndata, nullptr);

  int local_vol = m_local_vol;

  int k = 0;
  for (int iex = 0; iex < nex; ++iex) {
    for (int it = 0; it < nt; ++it) {
      src_array[k] = (dcomplex *)(src.ptr(0, local_vol * it, iex));
      dst_array[k] = (dcomplex *)(dst.ptr(0, local_vol * it, iex));
      ++k;
    }
  }
  int nblock = m_grid_vol;

  for (int k = 0; k < ndata; k += nblock) {
    bool do_full = (k + nblock <= ndata);
    int  nwork   = do_full ? nblock : (ndata % nblock);

    // distribute the slices: one full 3D lattice per rank
    if (do_full) {
      MPI_Alltoallv(src_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Gatherv(src_array[k + j], 1, m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    j, m_comm);
      }
    }

    if (m_local_rank < nwork) {
      fftw_execute(dir == FORWARD ? m_plan_fw : m_plan_bw);
    }

    // redistribute the transformed slices back to their owners
    if (do_full) {
      MPI_Alltoallv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    dst_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Scatterv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                     dst_array[k + j], 1, m_local_patch_type,
                     j, m_comm);
      }
    }
  }

  if (dir == BACKWARD) {
    scal(dst, 1.0 / m_lattice_vol);
  }
}
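// Communication pattern of fft(): the nt * nex three-dimensional slices are
// processed m_grid_vol at a time. For a full block every rank receives one
// complete lattice (MPI_Alltoallv), runs the single-node FFTW plan on it, and
// the results are redistributed; a partial trailing block falls back to
// MPI_Gatherv / MPI_Scatterv with ranks 0..nwork-1 doing the transforms.
// FFTW transforms are unnormalized, so the backward direction is rescaled by
// 1 / m_lattice_vol so that forward followed by backward is the identity.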
void FFT_3d_parallel3d::fft(Field& dst, const Field& src)
{
  return fft(dst, src, m_direction);
}

void FFT_3d_parallel3d::fft(Field& field)
{
  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
  exit(EXIT_FAILURE);
}
std::vector<int> FFT_3d_parallel3d::grid_rank_to_coord(int r)
{
  std::vector<int> coord(m_ndims);

  for (int i = 0; i < m_ndims; ++i) {
    coord[i] = r % m_grid_size[i];
    r /= m_grid_size[i];
  }

  return coord;
}

int FFT_3d_parallel3d::find_global_index(const std::vector<int>& coord)
{
  assert(coord.size() == m_ndims);

  int idx = coord[m_ndims - 1];
  for (int i = m_ndims - 2; i >= 0; --i) {
    idx *= m_lattice_size[i];
    idx += coord[i];
  }

  return idx;
}
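// Index conventions: find_global_index() returns the lexicographic site index
// with the x direction running fastest, idx = x + Lx * (y + Ly * z), and
// grid_rank_to_coord() applies the inverse mapping to a rank of the process
// grid. Example: coord = (1, 2, 0) on a 4 x 4 x 4 lattice gives idx = 9.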
// Referenced library helpers (signatures as documented elsewhere in the library):
//   void scal(Field& x, const double a);                        // x = a * x
//   static int get_num_threads();                               // available number of threads
//   static int npe(const int dir);                              // logical grid extent
//   static int ipe(const int dir);                              // logical coordinate of current proc
//   static MPI_Comm& world();                                   // retrieves current communicator
//   const double* ptr(const int jin, const int site, const int jex) const;  // Field data access
//   int fetch_string(const string& key, string& value) const;   // Parameters lookup
//   void general(const char* format, ...);                      // vout logging
//   void crucial(const char* format, ...);                      // vout logging