Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_3d_parallel1d.cpp
Go to the documentation of this file.
1 
14 #ifdef USE_FFTWLIB
15 #ifdef USE_MPI
16 
17 #include "fft_3d_parallel1d.h"
18 #include "Field/index_lex.h"
21 
22 #ifdef USE_OPENMP
24 #endif
25 
26 
27 const std::string FFT_3d_parallel1d::class_name = "FFT_3d_parallel1d";
28 
29 #ifdef USE_FACTORY_AUTOREGISTER
30 namespace {
31  bool init = FFT_3d_parallel1d::register_factory();
32 }
33 #endif
34 
35 //====================================================================
36 FFT_3d_parallel1d::FFT_3d_parallel1d()
37  : m_ndim(0)
38  , m_vol(0)
39  , m_nv(0)
40  , m_buf_in(NULL)
41  , m_buf_out(NULL)
42  , m_plan_fw(NULL)
43  , m_plan_bw(NULL)
44  , m_direction(UNDEF)
45 {
46  if (check_ok()) {
47  initialize();
48  }
49 }
50 
51 
52 //====================================================================
53 FFT_3d_parallel1d::~FFT_3d_parallel1d()
54 {
55  finalize();
56 }
57 
58 
59 //====================================================================
60 void FFT_3d_parallel1d::set_parameters(const Parameters& params)
61 {
62  // pass parameters to base class.
63  this->FFT::set_parameters(params);
64 
65  std::string direction;
66 
67  if (params.fetch_string("FFT_direction", direction) == 0) {
68  set_parameters(direction);
69  }
70 }
71 
72 
73 //====================================================================
74 void FFT_3d_parallel1d::set_parameters(const std::string& direction)
75 {
76  if (direction == "Forward") {
77  m_direction = FORWARD;
78  } else if (direction == "Backward") {
79  m_direction = BACKWARD;
80  } else {
81  m_direction = UNDEF;
82 
83  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
84  exit(EXIT_FAILURE);
85  }
86 }
87 
88 
89 //====================================================================
90 bool FFT_3d_parallel1d::check_ok()
91 {
92  int npe_xy = Communicator::npe(0) * Communicator::npe(1);
93 
94  if (npe_xy > 1) {
95  vout.crucial(m_vl, "%s: incompatible with xy parallelization.\n", class_name.c_str());
96  exit(EXIT_FAILURE);
97  }
98 
99  return true;
100 }
101 
102 
103 //====================================================================
104 void FFT_3d_parallel1d::initialize()
105 {
106 #ifdef USE_OPENMP
107  int thread_ok = fftw_init_threads();
108 
109  if (thread_ok) {
110  fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
111  }
112 #endif
113 
114  fftw_mpi_init();
115 
116  // split communicator along t-axis to form 3dim subspaces
117  int ipe_x = Communicator::ipe(0);
118  int ipe_y = Communicator::ipe(1);
119  int ipe_z = Communicator::ipe(2);
120  int ipe_t = Communicator::ipe(3);
121 
122  int npe_x = Communicator::npe(0);
123  int npe_y = Communicator::npe(1);
124  int npe_z = Communicator::npe(2);
125  int npe_t = Communicator::npe(3);
126 
127  int local_rank = ipe_x + npe_x * (ipe_y + npe_y * ipe_z);
128 
129  MPI_Comm_split(Communicator_impl::world(), ipe_t, local_rank, &m_comm);
130 }
131 
132 
133 //====================================================================
// Creates (or reuses) the forward and backward FFTW-MPI plans and their I/O
// buffers for fields whose internal size nin() matches src.
//  - Existing plans are kept when nin()/2 and the local volume are unchanged.
//  - Aborts the run if FFTW's distribution along the first (z) dimension does
//    not coincide with the lattice decomposition along z, or if buffer
//    allocation or planning fails.
// NOTE(review): source line 137 (the right-hand side of `local_vol`) is missing
//   from this listing. Presumably it is the number of local sites in one time
//   slice (fft() copies m_nv * m_vol complex numbers per slice); restore it
//   from the original file.
134 void FFT_3d_parallel1d::initialize_plan(const Field& src)
135 {
136  int local_vol =
138 
// Reuse the current plans when the per-site size and local volume are unchanged.
139  if ((m_nv == src.nin() / 2) && (m_vol == local_vol)) {
140  vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
141  return;
142  } else {
143  vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
144  }
145 
146  // first, clear pre-existing plan, if any.
147  clear_plan();
148 
149  // assume 3dimensional
150  m_ndim = 3;
151  // row-major (C) order as FFTW expects: dimensions listed slowest-first (z, y, x), so x varies fastest
152  m_nsize[0] = CommonParameters::Lz();
153  m_nsize[1] = CommonParameters::Ly();
154  m_nsize[2] = CommonParameters::Lx();
155 
156  // local volume size
157  m_vol = local_vol;
158 
159  // number of complex components per site (nin doubles = nin/2 complex numbers), transformed together as m_nv interleaved FFTs
160  m_nv = src.nin() / 2;
161 
162  // local size
// psize: number of complex elements this process must allocate.
// local_n0 / local_0_start: this process's extent and offset along the
// first (z) dimension under FFTW's default block distribution over m_comm.
163  ptrdiff_t local_n0, local_0_start;
164 
165  ptrdiff_t psize = fftw_mpi_local_size_many(m_ndim, m_nsize, m_nv,
166  FFTW_MPI_DEFAULT_BLOCK, m_comm,
167  &local_n0, &local_0_start);
168 
169  int nz = CommonParameters::Nz();
170  int ipe_z = Communicator::ipe(2);
171 
// fft() copies each process's local z-range as-is, so FFTW's slab must be
// exactly this process's nz planes starting at nz * ipe_z.
172  if ((local_n0 != nz) || (local_0_start != nz * ipe_z)) {
173  vout.crucial(m_vl, "%s: data distribution plan not matched.\n", class_name.c_str());
174  exit(EXIT_FAILURE);
175  }
176 
// FFTW-allocated buffers of psize complex elements; NULL signals failure.
177  m_buf_in = fftw_alloc_complex(psize);
178  m_buf_out = fftw_alloc_complex(psize);
179 
180  if ((!m_buf_in) || (!m_buf_out)) {
181  vout.crucial(m_vl, "%s: memory allocation failed.\n", class_name.c_str());
182  exit(EXIT_FAILURE);
183  }
184 
// Out-of-place plans from m_buf_in to m_buf_out, m_nv transforms at once.
185  m_plan_fw = fftw_mpi_plan_many_dft(m_ndim, m_nsize, m_nv,
186  FFTW_MPI_DEFAULT_BLOCK, FFTW_MPI_DEFAULT_BLOCK,
187  m_buf_in, m_buf_out,
188  m_comm,
189  FFTW_FORWARD,
190  FFTW_ESTIMATE);
191 
192  m_plan_bw = fftw_mpi_plan_many_dft(m_ndim, m_nsize, m_nv,
193  FFTW_MPI_DEFAULT_BLOCK, FFTW_MPI_DEFAULT_BLOCK,
194  m_buf_in, m_buf_out,
195  m_comm,
196  FFTW_BACKWARD,
197  FFTW_ESTIMATE);
198 
199  if ((!m_plan_fw) || (!m_plan_bw)) {
200  vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
201  exit(EXIT_FAILURE);
202  }
203 }
204 
205 
206 //====================================================================
207 void FFT_3d_parallel1d::clear_plan()
208 {
209  if (m_buf_in) fftw_free(m_buf_in);
210  if (m_buf_out) fftw_free(m_buf_out);
211 
212  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
213  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
214 }
215 
216 
217 //====================================================================
218 void FFT_3d_parallel1d::finalize()
219 {
220  clear_plan();
221 
222  MPI_Comm_free(&m_comm);
223 }
224 
225 
226 //====================================================================
227 void FFT_3d_parallel1d::fft(Field& dst, const Field& src, enum Direction dir)
228 {
229  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
230  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
231  exit(EXIT_FAILURE);
232  }
233 
234  initialize_plan(src);
235 
236  int nex = src.nex();
237  int nt = CommonParameters::Nt();
238 
239  Index_lex index;
240 
241  size_t count = m_nv * m_vol; // count in complex numbers
242 
243  for (int iex = 0; iex < nex; ++iex) {
244  for (int it = 0; it < nt; ++it) {
245  memcpy(m_buf_in, src.ptr(0, index.site(0, 0, 0, it), iex), sizeof(double) * count * 2);
246 
247  if (dir == FORWARD) {
248  fftw_execute(m_plan_fw);
249  } else if (dir == BACKWARD) {
250  fftw_execute(m_plan_bw);
251  } else {
252  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
253  exit(EXIT_FAILURE);
254  }
255 
256  memcpy(dst.ptr(0, index.site(0, 0, 0, it), iex), m_buf_out, sizeof(double) * count * 2);
257  }
258  }
259 
260  if (dir == BACKWARD) {
261  size_t global_vol = CommonParameters::Lvol() / CommonParameters::Lt();
262  scal(dst, 1.0 / global_vol);
263  }
264 }
265 
266 
267 //====================================================================
268 void FFT_3d_parallel1d::fft(Field& dst, const Field& src)
269 {
270  return fft(dst, src, m_direction);
271 }
272 
273 
274 //====================================================================
275 void FFT_3d_parallel1d::fft(Field& field)
276 {
277  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
278  exit(EXIT_FAILURE);
279 }
280 
281 
282 //====================================================================
283 #endif /* USE_MPI */
284 #endif /* USE_FFTWLIB */
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
static int get_num_threads()
returns available number of threads.
static int npe(const int dir)
logical grid extent
const double * ptr(const int jin, const int site, const int jex) const
Definition: field.h:153
int site(const int &x, const int &y, const int &z, const int &t) const
Definition: index_lex.h:53
void general(const char *format,...)
Definition: bridgeIO.cpp:197
Container of Field-type object.
Definition: field.h:45
Class for parameters.
Definition: parameters.h:46
static int ipe(const int dir)
logical coordinate of current proc.
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
int nin() const
Definition: field.h:126
static MPI_Comm & world()
retrieves current communicator.
int nex() const
Definition: field.h:128
Lexical site index.
Definition: index_lex.h:34
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Direction
Definition: bridge_defs.h:24
static long_t Lvol()