Bridge++  Version 1.6.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_3d_parallel1d.cpp
Go to the documentation of this file.
1 
13 #ifdef USE_FFTWLIB
14 #ifdef USE_MPI
15 
16 #include "fft_3d_parallel1d.h"
17 #include "Field/index_lex.h"
20 
21 #ifdef USE_OPENMP
23 #endif
24 
25 
26 const std::string FFT_3d_parallel1d::class_name = "FFT_3d_parallel1d";
27 
28 #ifdef USE_FACTORY_AUTOREGISTER
29 namespace {
30  bool init = FFT_3d_parallel1d::register_factory();
31 }
32 #endif
33 
34 //====================================================================
35 FFT_3d_parallel1d::FFT_3d_parallel1d()
36  : m_ndim(0)
37  , m_vol(0)
38  , m_nv(0)
39  , m_buf_in(NULL)
40  , m_buf_out(NULL)
41  , m_plan_fw(NULL)
42  , m_plan_bw(NULL)
43  , m_direction(UNDEF)
44 {
45  if (check_ok()) {
46  initialize();
47  }
48 }
49 
50 
51 //====================================================================
52 FFT_3d_parallel1d::~FFT_3d_parallel1d()
53 {
54  finalize();
55 }
56 
57 
58 //====================================================================
59 void FFT_3d_parallel1d::set_parameters(const Parameters& params)
60 {
61  // pass parameters to base class.
62  this->FFT::set_parameters(params);
63 
64  std::string direction;
65 
66  if (params.fetch_string("FFT_direction", direction) == 0) {
67  set_parameters(direction);
68  }
69 }
70 
71 
72 //====================================================================
73 void FFT_3d_parallel1d::set_parameters(const std::string& direction)
74 {
75  if (direction == "Forward") {
76  m_direction = FORWARD;
77  } else if (direction == "Backward") {
78  m_direction = BACKWARD;
79  } else {
80  m_direction = UNDEF;
81 
82  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
83  exit(EXIT_FAILURE);
84  }
85 }
86 
87 
88 //====================================================================
89 bool FFT_3d_parallel1d::check_ok()
90 {
91  int npe_xy = Communicator::npe(0) * Communicator::npe(1);
92 
93  if (npe_xy > 1) {
94  vout.crucial(m_vl, "%s: incompatible with xy parallelization.\n", class_name.c_str());
95  exit(EXIT_FAILURE);
96  }
97 
98  return true;
99 }
100 
101 
102 //====================================================================
103 void FFT_3d_parallel1d::initialize()
104 {
105 #ifdef USE_OPENMP
106  int thread_ok = fftw_init_threads();
107 
108  if (thread_ok) {
109  fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
110  }
111 #endif
112 
113  fftw_mpi_init();
114 
115  // split communicator along t-axis to form 3dim subspaces
116  int ipe_x = Communicator::ipe(0);
117  int ipe_y = Communicator::ipe(1);
118  int ipe_z = Communicator::ipe(2);
119  int ipe_t = Communicator::ipe(3);
120 
121  int npe_x = Communicator::npe(0);
122  int npe_y = Communicator::npe(1);
123  int npe_z = Communicator::npe(2);
124  int npe_t = Communicator::npe(3);
125 
126  int local_rank = ipe_x + npe_x * (ipe_y + npe_y * ipe_z);
127 
128  MPI_Comm_split(Communicator_impl::world(), ipe_t, local_rank, &m_comm);
129 }
130 
131 
132 //====================================================================
133 void FFT_3d_parallel1d::initialize_plan(const Field& src)
134 {
135  int local_vol =
137 
138  if ((m_nv == src.nin() / 2) && (m_vol == local_vol)) {
139  vout.general(m_vl, "%s: plan recycled.\n", class_name.c_str());
140  return;
141  } else {
142  vout.general(m_vl, "%s: create plan.\n", class_name.c_str());
143  }
144 
145  // first, clear pre-existing plan, if any.
146  clear_plan();
147 
148  // assume 3dimensional
149  m_ndim = 3;
150  // row-major fortran order
151  m_nsize[0] = CommonParameters::Lz();
152  m_nsize[1] = CommonParameters::Ly();
153  m_nsize[2] = CommonParameters::Lx();
154 
155  // local volume size
156  m_vol = local_vol;
157 
158  // number of complex elements. run nin at a time.
159  m_nv = src.nin() / 2;
160 
161  // local size
162  ptrdiff_t local_n0, local_0_start;
163 
164  ptrdiff_t psize = fftw_mpi_local_size_many(m_ndim, m_nsize, m_nv,
165  FFTW_MPI_DEFAULT_BLOCK, m_comm,
166  &local_n0, &local_0_start);
167 
168  int nz = CommonParameters::Nz();
169  int ipe_z = Communicator::ipe(2);
170 
171  if ((local_n0 != nz) || (local_0_start != nz * ipe_z)) {
172  vout.crucial(m_vl, "%s: data distribution plan not matched.\n", class_name.c_str());
173  exit(EXIT_FAILURE);
174  }
175 
176  m_buf_in = fftw_alloc_complex(psize);
177  m_buf_out = fftw_alloc_complex(psize);
178 
179  if ((!m_buf_in) || (!m_buf_out)) {
180  vout.crucial(m_vl, "%s: memory allocation failed.\n", class_name.c_str());
181  exit(EXIT_FAILURE);
182  }
183 
184  m_plan_fw = fftw_mpi_plan_many_dft(m_ndim, m_nsize, m_nv,
185  FFTW_MPI_DEFAULT_BLOCK, FFTW_MPI_DEFAULT_BLOCK,
186  m_buf_in, m_buf_out,
187  m_comm,
188  FFTW_FORWARD,
189  FFTW_ESTIMATE);
190 
191  m_plan_bw = fftw_mpi_plan_many_dft(m_ndim, m_nsize, m_nv,
192  FFTW_MPI_DEFAULT_BLOCK, FFTW_MPI_DEFAULT_BLOCK,
193  m_buf_in, m_buf_out,
194  m_comm,
195  FFTW_BACKWARD,
196  FFTW_ESTIMATE);
197 
198  if ((!m_plan_fw) || (!m_plan_bw)) {
199  vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
200  exit(EXIT_FAILURE);
201  }
202 }
203 
204 
205 //====================================================================
206 void FFT_3d_parallel1d::clear_plan()
207 {
208  if (m_buf_in) fftw_free(m_buf_in);
209  if (m_buf_out) fftw_free(m_buf_out);
210 
211  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
212  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
213 }
214 
215 
216 //====================================================================
217 void FFT_3d_parallel1d::finalize()
218 {
219  clear_plan();
220 
221  MPI_Comm_free(&m_comm);
222 }
223 
224 
225 //====================================================================
226 void FFT_3d_parallel1d::fft(Field& dst, const Field& src, enum Direction dir)
227 {
228  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
229  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
230  exit(EXIT_FAILURE);
231  }
232 
233  initialize_plan(src);
234 
235  int nex = src.nex();
236  int nt = CommonParameters::Nt();
237 
238  Index_lex index;
239 
240  size_t count = m_nv * m_vol; // count in complex numbers
241 
242  for (int iex = 0; iex < nex; ++iex) {
243  for (int it = 0; it < nt; ++it) {
244  memcpy(m_buf_in, src.ptr(0, index.site(0, 0, 0, it), iex), sizeof(double) * count * 2);
245 
246  if (dir == FORWARD) {
247  fftw_execute(m_plan_fw);
248  } else if (dir == BACKWARD) {
249  fftw_execute(m_plan_bw);
250  } else {
251  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
252  exit(EXIT_FAILURE);
253  }
254 
255  memcpy(dst.ptr(0, index.site(0, 0, 0, it), iex), m_buf_out, sizeof(double) * count * 2);
256  }
257  }
258 
259  if (dir == BACKWARD) {
260  size_t global_vol = CommonParameters::Lvol() / CommonParameters::Lt();
261  scal(dst, 1.0 / global_vol);
262  }
263 }
264 
265 
266 //====================================================================
267 void FFT_3d_parallel1d::fft(Field& dst, const Field& src)
268 {
269  return fft(dst, src, m_direction);
270 }
271 
272 
273 //====================================================================
274 void FFT_3d_parallel1d::fft(Field& field)
275 {
276  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
277  exit(EXIT_FAILURE);
278 }
279 
280 
281 //====================================================================
282 #endif /* USE_MPI */
283 #endif /* USE_FFTWLIB */
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
static int get_num_threads()
returns available number of threads.
static int npe(const int dir)
logical grid extent
const double * ptr(const int jin, const int site, const int jex) const
Definition: field.h:154
int site(const int &x, const int &y, const int &z, const int &t) const
Definition: index_lex.h:53
void general(const char *format,...)
Definition: bridgeIO.cpp:197
Container of Field-type object.
Definition: field.h:46
Class for parameters.
Definition: parameters.h:46
static int ipe(const int dir)
logical coordinate of current proc.
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
int nin() const
Definition: field.h:127
static MPI_Comm & world()
retrieves current communicator.
int nex() const
Definition: field.h:129
Lexical site index.
Definition: index_lex.h:34
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Direction
Definition: bridge_defs.h:24
static long_t Lvol()