const std::string FFT_3d_parallel3d::class_name = "FFT_3d_parallel3d";

#ifdef USE_FACTORY_AUTOREGISTER
namespace {
  bool init = FFT_3d_parallel3d::register_factory();
}
#endif
void FFT_3d_parallel3d::set_parameters(const Parameters& params)
{
  std::string direction;
  params.fetch_string("FFT_direction", direction);

  set_parameters(direction);
}
void FFT_3d_parallel3d::get_parameters(Parameters& params) const
{
  if (m_direction == FORWARD) {
    params.set_string("FFT_direction", "Forward");
  } else if (m_direction == BACKWARD) {
    params.set_string("FFT_direction", "Backward");
  }
}
void FFT_3d_parallel3d::set_parameters(const std::string& direction)
{
  if (direction == "Forward") {
    m_direction = FORWARD;
  } else if (direction == "Backward") {
    m_direction = BACKWARD;
  } else {
    vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
  }
}
FFT_3d_parallel3d::FFT_3d_parallel3d()
  : m_initialized(false)
FFT_3d_parallel3d::FFT_3d_parallel3d(const Parameters& params)
  : m_initialized(false)
{
  set_parameters(params);
}
FFT_3d_parallel3d::~FFT_3d_parallel3d()

bool FFT_3d_parallel3d::check_ok()
void FFT_3d_parallel3d::initialize()
{
  int thread_ok = fftw_init_threads();

  // rank within the 3d (x,y,z) process grid
  int ipe_xyz = ipe_x + npe_x * (ipe_y + npe_y * ipe_z);

  // put all ranks that share a time-slice coordinate into one xyz communicator
  MPI_Comm_split(Communicator_impl::world(), ipe_t, ipe_xyz, &m_comm);

  int local_rank;
  MPI_Comm_rank(m_comm, &local_rank);

  int local_ipe_x = local_rank % npe_x;
  int local_ipe_y = (local_rank / npe_x) % npe_y;
  int local_ipe_z = (local_rank / npe_x / npe_y) % npe_z;
  // the split communicator must reproduce the original (x,y,z) assignment
  if ((local_ipe_x != ipe_x) || (local_ipe_y != ipe_y) || (local_ipe_z != ipe_z)) {
    vout.crucial(m_vl, "%s: split communicator failed.\n", class_name.c_str());
  }

  m_local_rank = local_rank;

  m_local_ipe_x = local_ipe_x;
  m_local_ipe_y = local_ipe_y;
  m_local_ipe_z = local_ipe_z;
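  // Worked example (hypothetical sizes, not from this file): with npe = (2, 2, 2) the
  // rank formula above gives ipe_xyz = 1 + 2 * (0 + 2 * 1) = 5 for grid coordinate
  // (1, 0, 1), and the inverse decomposition recovers x = 5 % 2 = 1, y = (5 / 2) % 2 = 0,
  // z = (5 / 4) % 2 = 1, which is exactly what the consistency check above verifies.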
  // process-grid extent in (x,y,z) and its volume
  m_grid_size.resize(m_ndims);

  m_grid_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_grid_vol *= m_grid_size[i];
  }

  // global spatial lattice extent and its volume
  m_lattice_size.resize(m_ndims);

  m_lattice_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_lattice_vol *= m_lattice_size[i];
  }

  // local (per-rank) extent and its volume
  m_local_size.resize(m_ndims);

  m_local_vol = 1;
  for (int i = 0; i < m_ndims; ++i) {
    m_local_vol *= m_local_size[i];
  }
}
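// Note on the decomposition set up in initialize(): assuming the usual uniform split
// (lattice = grid * local in each direction), the volumes satisfy
// m_lattice_vol == m_grid_vol * m_local_vol, so one rank of the xyz communicator can
// hold a whole spatial lattice assembled from m_grid_vol local patches
// (see create_mpi_datatype() below).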
void FFT_3d_parallel3d::finalize()
void FFT_3d_parallel3d::create_mpi_datatype(int site_dof)
{
  // complex site vector: 2 doubles (re, im) per internal degree of freedom
  MPI_Type_contiguous(2 * site_dof, MPI_DOUBLE, &m_site_vector_type);
  MPI_Type_commit(&m_site_vector_type);

  // one local patch = the m_local_vol site vectors owned by a rank
  MPI_Type_contiguous(m_local_vol, m_site_vector_type, &m_local_patch_type);
  MPI_Type_commit(&m_local_patch_type);

  int size_;
  MPI_Type_size(m_site_vector_type, &size_);

  // the local patch viewed as a subarray of the global lattice; the type is resized
  // to one site vector so that displacements can be given as global site indices
  std::vector<int> local_origin(m_ndims, 0);

  MPI_Datatype type_;
  MPI_Type_create_subarray(m_ndims,
                           &m_lattice_size[0], &m_local_size[0], &local_origin[0],
                           MPI_ORDER_FORTRAN, m_site_vector_type, &type_);

  MPI_Type_create_resized(type_, 0, size_, &m_subarray_type);
  MPI_Type_commit(&m_subarray_type);
  // every rank contributes/receives exactly one patch per work item
  m_sendcounts.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_sendcounts[r] = 1;
  }

  // receive side: where each rank's patch starts inside the global lattice,
  // counted in site vectors (the resized subarray extent)
  m_subarray_displs.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    std::vector<int> coord = grid_rank_to_coord(r);

    // grid coordinate -> lattice coordinate of the patch origin
    for (int j = 0; j < m_ndims; ++j) {
      coord[j] *= m_local_size[j];
    }

    int idx = find_global_index(coord);

    m_subarray_displs[r] = idx;
  }

  // send side: patches are packed one after another, one patch per destination rank
  m_local_patch_displs.resize(m_grid_vol);

  for (int r = 0; r < m_grid_vol; ++r) {
    m_local_patch_displs[r] = r;
  }
}
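// Worked example for the displacements above (hypothetical sizes): for an 8x8x8 lattice
// on a 2x2x2 grid each rank owns a 4x4x4 patch. The rank at grid coordinate (1, 0, 1)
// has its patch origin at lattice coordinate (4, 0, 4), so
// m_subarray_displs[r] = 4 + 8 * (0 + 8 * 4) = 260, counted in site vectors, while on
// the packed side m_local_patch_displs[r] = r simply counts whole patches.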
void FFT_3d_parallel3d::release_mpi_datatype()
{
  int is_finalized = 0;

  MPI_Finalized(&is_finalized);

  if (is_finalized) {
    vout.crucial(m_vl, "%s: MPI has already gone...\n", class_name.c_str());
    return;
  }

  MPI_Type_free(&m_site_vector_type);
  MPI_Type_free(&m_subarray_type);
  MPI_Type_free(&m_local_patch_type);
}
void FFT_3d_parallel3d::create_fft_plan(int site_dof)
{
  // work buffer holding one whole spatial lattice on this rank
  m_buf = fftw_alloc_complex(site_dof * m_lattice_vol);
  if (!m_buf) {
    vout.crucial(m_vl, "%s: buffer allocation failed.\n", class_name.c_str());
  }

  // in-place, interleaved plans: site_dof transforms of the 3d lattice,
  // with stride site_dof between lattice sites and distance 1 between transforms
  m_plan_fw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_FORWARD, FFTW_ESTIMATE);

  m_plan_bw = fftw_plan_many_dft(m_ndims, &m_lattice_size[0], site_dof,
                                 m_buf, NULL, site_dof, 1,
                                 m_buf, NULL, site_dof, 1,
                                 FFTW_BACKWARD, FFTW_ESTIMATE);

  if (!m_plan_fw || !m_plan_bw) {
    vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
  }
}
void FFT_3d_parallel3d::release_fft_plan()
{
  if (m_buf) fftw_free(m_buf);

  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);

  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
}
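// Note on the plans created in create_fft_plan() above: with (stride, dist) = (site_dof, 1)
// FFTW treats m_buf as an array of complex site vectors, component c of site s sitting at
// m_buf[s * site_dof + c]; e.g. for a hypothetical site_dof = 12, component 3 of site 7 is
// element 7 * 12 + 3 = 87. Each of the site_dof transforms then walks the whole 3d lattice
// with stride site_dof.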
void FFT_3d_parallel3d::create_plan(int site_dof)
{
  create_mpi_datatype(site_dof);
  create_fft_plan(site_dof);

  m_site_dof = site_dof;

  m_initialized = true;
}
void FFT_3d_parallel3d::release_plan()
{
  release_mpi_datatype();

  m_initialized = false;
}
bool FFT_3d_parallel3d::need_create_plan(const Field& field)
{
  // the plan depends only on the number of complex components per site
  if (field.nin() / 2 == m_site_dof) {
    return false;
  } else {
    return true;
  }
}
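// Note on need_create_plan(): Field::nin() counts real components per site, so nin() / 2
// is the number of complex components the plan was built for; e.g. a colour vector with
// Nc = 3 (a hypothetical layout, not taken from this file) has nin() = 6 and site_dof = 3.
// Fields sharing that count can reuse the same plan.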
void FFT_3d_parallel3d::fft(Field& dst, const Field& src, enum Direction dir)
{
  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
    vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
  }

  // create, rebuild, or reuse the FFTW plan depending on the field layout
  if (m_initialized == false) {
    vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
    create_plan(src.nin() / 2);
  } else {
    if (need_create_plan(src)) {
      vout.general(m_vl, "%s: discard plan and create new.\n", class_name.c_str());
      release_plan();
      create_plan(src.nin() / 2);
    } else {
      vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
    }
  }
  int ndata = nt * nex;

  // one (time slice, external index) pair is one independent 3d volume to transform
  std::vector<dcomplex *> src_array(ndata, nullptr);
  std::vector<dcomplex *> dst_array(ndata, nullptr);

  int local_vol = m_local_vol;

  int k = 0;
  for (int iex = 0; iex < nex; ++iex) {
    for (int it = 0; it < nt; ++it) {
      src_array[k] = (dcomplex *)(src.ptr(0, local_vol * it, iex));
      dst_array[k] = (dcomplex *)(dst.ptr(0, local_vol * it, iex));
      ++k;
    }
  }
  // distribute the independent 3d volumes round-robin over the ranks of the xyz communicator
  int nblock = m_grid_vol;

  for (int k = 0; k < ndata; k += nblock) {
    bool do_full = (k + nblock <= ndata);
    int  nwork   = do_full ? nblock : (ndata % nblock);

    // gather: local patches -> full lattice in m_buf on the responsible rank
    if (do_full) {
      MPI_Alltoallv(src_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Gatherv(src_array[k + j], 1, m_local_patch_type,
                    m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    j, m_comm);
      }
    }

    if (m_local_rank < nwork) {
      fftw_execute(dir == FORWARD ? m_plan_fw : m_plan_bw);
    }

    // scatter: full lattice in m_buf -> local patches of the destination field
    if (do_full) {
      MPI_Alltoallv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                    dst_array[k], &m_sendcounts[0], &m_local_patch_displs[0], m_local_patch_type,
                    m_comm);
    } else {
      for (int j = 0; j < nwork; ++j) {
        MPI_Scatterv(m_buf, &m_sendcounts[0], &m_subarray_displs[0], m_subarray_type,
                     dst_array[k + j], 1, m_local_patch_type,
                     j, m_comm);
      }
    }
  }

  if (dir == BACKWARD) {
    scal(dst, 1.0 / m_lattice_vol);
  }
}
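// Worked example for the block loop above (hypothetical sizes): with m_grid_vol = 8 ranks
// in the xyz communicator and ndata = 20 volumes, the passes handle volumes 0-7, 8-15 and
// 16-19; in the last pass nwork = 20 % 8 = 4, so only ranks 0-3 receive a volume, execute
// an FFT, and scatter a result back.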
void FFT_3d_parallel3d::fft(Field& dst, const Field& src)
{
  return fft(dst, src, m_direction);
}
void FFT_3d_parallel3d::fft(Field& field)
{
  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
}
std::vector<int> FFT_3d_parallel3d::grid_rank_to_coord(int r)
{
  std::vector<int> coord(m_ndims);

  // peel off the grid coordinates with x running fastest
  for (int i = 0; i < m_ndims; ++i) {
    coord[i] = r % m_grid_size[i];
    r /= m_grid_size[i];
  }

  return coord;
}
int FFT_3d_parallel3d::find_global_index(const std::vector<int>& coord)
{
  assert(coord.size() == m_ndims);

  // global site index with x running fastest: idx = x + Lx * (y + Ly * z)
  int idx = coord[m_ndims - 1];
  for (int i = m_ndims - 2; i >= 0; --i) {
    idx *= m_lattice_size[i];
    idx += coord[i];
  }

  return idx;
}
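// Worked example: for a hypothetical 8x8x8 lattice and coord = (3, 2, 5) the loop above
// computes idx = 5, then 5 * 8 + 2 = 42, then 42 * 8 + 3 = 339, i.e.
// idx = x + Lx * (y + Ly * z) = 3 + 8 * (2 + 8 * 5) = 339.
//
// Minimal usage sketch (assuming the Parameters / Field API used above; the concrete
// field objects are hypothetical):
//
//   Parameters params;
//   params.set_string("FFT_direction", "Forward");
//
//   FFT_3d_parallel3d fft(params);
//   fft.fft(dst, src);   // dst, src: Field objects with matching site layout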