Bridge++  Version 1.6.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_3d_local.cpp
Go to the documentation of this file.
1 
13 #ifdef USE_FFTWLIB
14 
15 #include "fft_3d_local.h"
16 #include "Field/index_lex.h"
17 #include <cstring>
18 
19 #ifdef USE_OPENMP
21 #endif
22 
23 const std::string FFT_3d_local::class_name = "FFT_3d_local";
24 
25 #ifdef USE_FACTORY_AUTOREGISTER
26 namespace {
27  bool init = FFT_3d_local::register_factory();
28 }
29 #endif
30 
31 //====================================================================
32 FFT_3d_local::FFT_3d_local()
33  : m_ndim(0)
34  , m_vol(0)
35  , m_nv(0)
36  , m_buf_in(NULL)
37  , m_buf_out(NULL)
38  , m_plan_fw(NULL)
39  , m_plan_bw(NULL)
40  , m_direction(UNDEF)
41 {
42  if (check_ok()) {
43  initialize();
44  }
45 }
46 
47 
48 //====================================================================
49 FFT_3d_local::~FFT_3d_local()
50 {
51  finalize();
52 }
53 
54 
55 //====================================================================
56 void FFT_3d_local::set_parameters(const Parameters& params)
57 {
58  // pass parameters to base class.
59  this->FFT::set_parameters(params);
60 
61  std::string direction;
62 
63  if (params.fetch_string("FFT_direction", direction) == 0) {
64  set_parameters(direction);
65  }
66 }
67 
68 
69 //====================================================================
70 void FFT_3d_local::set_parameters(const std::string& direction)
71 {
72  if (direction == "Forward") {
73  m_direction = FORWARD;
74  } else if (direction == "Backward") {
75  m_direction = BACKWARD;
76  } else {
77  m_direction = UNDEF;
78 
79  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
80  exit(EXIT_FAILURE);
81  }
82 }
83 
84 
85 //====================================================================
86 bool FFT_3d_local::check_ok()
87 {
88  int npe_xyz = Communicator::npe(0) * Communicator::npe(1) * Communicator::npe(2);
89 
90  if (npe_xyz > 1) {
91  vout.crucial(m_vl, "%s: incompatible with xyz parallelization.\n", class_name.c_str());
92  exit(EXIT_FAILURE);
93  }
94 
95  return true;
96 }
97 
98 
99 //====================================================================
100 void FFT_3d_local::initialize()
101 {
102 #ifdef USE_OPENMP
103  int thread_ok = fftw_init_threads();
104 
105  if (thread_ok) {
106  fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
107  }
108 #endif
109 }
110 
111 
112 //====================================================================
113 void FFT_3d_local::initialize_plan(const Field& src)
114 {
115  if ((m_nv == src.nin() / 2) && (m_vol == CommonParameters::Lvol() / CommonParameters::Lt())) {
116  vout.detailed(m_vl, "%s: plan recycled.\n", class_name.c_str());
117  return;
118  } else {
119  vout.detailed(m_vl, "%s: create plan.\n", class_name.c_str());
120  }
121 
122  // first, clear pre-existing plan, if any.
123  clear_plan();
124 
125  // assume 3dimensional
126  m_ndim = 3;
127  // row-major fortran order
128  m_nsize[0] = CommonParameters::Lz();
129  m_nsize[1] = CommonParameters::Ly();
130  m_nsize[2] = CommonParameters::Lx();
131 
132  // local volume
133  m_vol = 1;
134  for (int i = 0; i < m_ndim; ++i) {
135  m_vol *= m_nsize[i];
136  }
137 
138  // number of complex elements. run nin at a time.
139  m_nv = src.nin() / 2;
140 
141  m_buf_in = fftw_alloc_complex(m_nv * m_vol);
142  m_buf_out = fftw_alloc_complex(m_nv * m_vol);
143 
144  if ((!m_buf_in) || (!m_buf_out)) {
145  vout.crucial(m_vl, "%s: memory allocation failed.\n", class_name.c_str());
146  exit(EXIT_FAILURE);
147  }
148 
149  m_plan_fw = fftw_plan_many_dft(m_ndim, m_nsize, m_nv,
150  m_buf_in, m_nsize, m_nv, 1,
151  m_buf_out, m_nsize, m_nv, 1,
152  FFTW_FORWARD, FFTW_ESTIMATE);
153 
154  m_plan_bw = fftw_plan_many_dft(m_ndim, m_nsize, m_nv,
155  m_buf_in, m_nsize, m_nv, 1,
156  m_buf_out, m_nsize, m_nv, 1,
157  FFTW_BACKWARD, FFTW_ESTIMATE);
158 
159  if ((!m_plan_fw) || (!m_plan_bw)) {
160  vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
161  exit(EXIT_FAILURE);
162  }
163 }
164 
165 
166 //====================================================================
167 void FFT_3d_local::clear_plan()
168 {
169  if (m_buf_in) fftw_free(m_buf_in);
170  if (m_buf_out) fftw_free(m_buf_out);
171 
172  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
173  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
174 }
175 
176 
177 //====================================================================
178 void FFT_3d_local::finalize()
179 {
180  clear_plan();
181 }
182 
183 
184 //====================================================================
185 void FFT_3d_local::fft(Field& dst, const Field& src, enum Direction dir)
186 {
187  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
188  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
189  exit(EXIT_FAILURE);
190  }
191 
192  initialize_plan(src);
193 
194  int nex = src.nex();
195  int nt = CommonParameters::Nt();
196 
197  Index_lex index;
198 
199  size_t count = m_nv * m_vol; // count in complex numbers
200 
201  for (int iex = 0; iex < nex; ++iex) {
202  for (int it = 0; it < nt; ++it) {
203  memcpy(m_buf_in, src.ptr(0, index.site(0, 0, 0, it), iex), sizeof(double) * count * 2);
204 
205  if (dir == FORWARD) {
206  fftw_execute(m_plan_fw);
207  } else if (dir == BACKWARD) {
208  fftw_execute(m_plan_bw);
209  } else {
210  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
211  exit(EXIT_FAILURE);
212  }
213 
214  memcpy(dst.ptr(0, index.site(0, 0, 0, it), iex), m_buf_out, sizeof(double) * count * 2);
215  }
216  }
217 
218  if (dir == BACKWARD) {
219  scal(dst, 1.0 / m_vol);
220  }
221 }
222 
223 
224 //====================================================================
225 void FFT_3d_local::fft(Field& dst, const Field& src)
226 {
227  return fft(dst, src, m_direction);
228 }
229 
230 
231 //====================================================================
232 void FFT_3d_local::fft(Field& field)
233 {
234  // return fft(field, field, m_direction);
235  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
236  exit(EXIT_FAILURE);
237 }
238 
239 
240 //====================================================================
241 #endif /* USE_FFTWLIB */
242 
243 //====================================================================
244 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static int get_num_threads()
returns available number of threads.
static int npe(const int dir)
logical grid extent
const double * ptr(const int jin, const int site, const int jex) const
Definition: field.h:154
int site(const int &x, const int &y, const int &z, const int &t) const
Definition: index_lex.h:53
Container of Field-type object.
Definition: field.h:46
Class for parameters.
Definition: parameters.h:46
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
int nin() const
Definition: field.h:127
int nex() const
Definition: field.h:129
Lexical site index.
Definition: index_lex.h:34
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Direction
Definition: bridge_defs.h:24
static long_t Lvol()