Bridge++  Version 1.5.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fft_3d_local.cpp
Go to the documentation of this file.
1 
14 #ifdef USE_FFTWLIB
15 
16 #include "fft_3d_local.h"
17 #include "Field/index_lex.h"
18 #include <cstring>
19 
20 #ifdef USE_OPENMP
22 #endif
23 
24 const std::string FFT_3d_local::class_name = "FFT_3d_local";
25 
26 #ifdef USE_FACTORY_AUTOREGISTER
27 namespace {
28  bool init = FFT_3d_local::register_factory();
29 }
30 #endif
31 
32 //====================================================================
33 FFT_3d_local::FFT_3d_local()
34  : m_ndim(0)
35  , m_vol(0)
36  , m_nv(0)
37  , m_buf_in(NULL)
38  , m_buf_out(NULL)
39  , m_plan_fw(NULL)
40  , m_plan_bw(NULL)
41  , m_direction(UNDEF)
42 {
43  if (check_ok()) {
44  initialize();
45  }
46 }
47 
48 
49 //====================================================================
50 FFT_3d_local::~FFT_3d_local()
51 {
52  finalize();
53 }
54 
55 
56 //====================================================================
57 void FFT_3d_local::set_parameters(const Parameters& params)
58 {
59  // pass parameters to base class.
60  this->FFT::set_parameters(params);
61 
62  std::string direction;
63 
64  if (params.fetch_string("FFT_direction", direction) == 0) {
65  set_parameters(direction);
66  }
67 }
68 
69 
70 //====================================================================
71 void FFT_3d_local::set_parameters(const std::string& direction)
72 {
73  if (direction == "Forward") {
74  m_direction = FORWARD;
75  } else if (direction == "Backward") {
76  m_direction = BACKWARD;
77  } else {
78  m_direction = UNDEF;
79 
80  vout.crucial(m_vl, "Error at %s: unsupported FFT direction \"%s\"\n", class_name.c_str(), direction.c_str());
81  exit(EXIT_FAILURE);
82  }
83 }
84 
85 
86 //====================================================================
87 bool FFT_3d_local::check_ok()
88 {
89  int npe_xyz = Communicator::npe(0) * Communicator::npe(1) * Communicator::npe(2);
90 
91  if (npe_xyz > 1) {
92  vout.crucial(m_vl, "%s: incompatible with xyz parallelization.\n", class_name.c_str());
93  exit(EXIT_FAILURE);
94  }
95 
96  return true;
97 }
98 
99 
100 //====================================================================
101 void FFT_3d_local::initialize()
102 {
103 #ifdef USE_OPENMP
104  int thread_ok = fftw_init_threads();
105 
106  if (thread_ok) {
107  fftw_plan_with_nthreads(ThreadManager_OpenMP::get_num_threads());
108  }
109 #endif
110 }
111 
112 
113 //====================================================================
114 void FFT_3d_local::initialize_plan(const Field& src)
115 {
116  if ((m_nv == src.nin() / 2) && (m_vol == CommonParameters::Lvol() / CommonParameters::Lt())) {
117  vout.detailed(m_vl, "%s: plan recycled.\n", class_name.c_str());
118  return;
119  } else {
120  vout.detailed(m_vl, "%s: create plan.\n", class_name.c_str());
121  }
122 
123  // first, clear pre-existing plan, if any.
124  clear_plan();
125 
126  // assume 3dimensional
127  m_ndim = 3;
128  // row-major fortran order
129  m_nsize[0] = CommonParameters::Lz();
130  m_nsize[1] = CommonParameters::Ly();
131  m_nsize[2] = CommonParameters::Lx();
132 
133  // local volume
134  m_vol = 1;
135  for (int i = 0; i < m_ndim; ++i) {
136  m_vol *= m_nsize[i];
137  }
138 
139  // number of complex elements. run nin at a time.
140  m_nv = src.nin() / 2;
141 
142  m_buf_in = fftw_alloc_complex(m_nv * m_vol);
143  m_buf_out = fftw_alloc_complex(m_nv * m_vol);
144 
145  if ((!m_buf_in) || (!m_buf_out)) {
146  vout.crucial(m_vl, "%s: memory allocation failed.\n", class_name.c_str());
147  exit(EXIT_FAILURE);
148  }
149 
150  m_plan_fw = fftw_plan_many_dft(m_ndim, m_nsize, m_nv,
151  m_buf_in, m_nsize, m_nv, 1,
152  m_buf_out, m_nsize, m_nv, 1,
153  FFTW_FORWARD, FFTW_ESTIMATE);
154 
155  m_plan_bw = fftw_plan_many_dft(m_ndim, m_nsize, m_nv,
156  m_buf_in, m_nsize, m_nv, 1,
157  m_buf_out, m_nsize, m_nv, 1,
158  FFTW_BACKWARD, FFTW_ESTIMATE);
159 
160  if ((!m_plan_fw) || (!m_plan_bw)) {
161  vout.crucial(m_vl, "%s: create plan failed.\n", class_name.c_str());
162  exit(EXIT_FAILURE);
163  }
164 }
165 
166 
167 //====================================================================
168 void FFT_3d_local::clear_plan()
169 {
170  if (m_buf_in) fftw_free(m_buf_in);
171  if (m_buf_out) fftw_free(m_buf_out);
172 
173  if (m_plan_fw) fftw_destroy_plan(m_plan_fw);
174  if (m_plan_bw) fftw_destroy_plan(m_plan_bw);
175 }
176 
177 
178 //====================================================================
179 void FFT_3d_local::finalize()
180 {
181  clear_plan();
182 }
183 
184 
185 //====================================================================
186 void FFT_3d_local::fft(Field& dst, const Field& src, enum Direction dir)
187 {
188  if (not ((dir == FORWARD) || (dir == BACKWARD))) {
189  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
190  exit(EXIT_FAILURE);
191  }
192 
193  initialize_plan(src);
194 
195  int nex = src.nex();
196  int nt = CommonParameters::Nt();
197 
198  Index_lex index;
199 
200  size_t count = m_nv * m_vol; // count in complex numbers
201 
202  for (int iex = 0; iex < nex; ++iex) {
203  for (int it = 0; it < nt; ++it) {
204  memcpy(m_buf_in, src.ptr(0, index.site(0, 0, 0, it), iex), sizeof(double) * count * 2);
205 
206  if (dir == FORWARD) {
207  fftw_execute(m_plan_fw);
208  } else if (dir == BACKWARD) {
209  fftw_execute(m_plan_bw);
210  } else {
211  vout.crucial(m_vl, "%s: unsupported direction. %d\n", class_name.c_str(), dir);
212  exit(EXIT_FAILURE);
213  }
214 
215  memcpy(dst.ptr(0, index.site(0, 0, 0, it), iex), m_buf_out, sizeof(double) * count * 2);
216  }
217  }
218 
219  if (dir == BACKWARD) {
220  scal(dst, 1.0 / m_vol);
221  }
222 }
223 
224 
225 //====================================================================
226 void FFT_3d_local::fft(Field& dst, const Field& src)
227 {
228  return fft(dst, src, m_direction);
229 }
230 
231 
232 //====================================================================
233 void FFT_3d_local::fft(Field& field)
234 {
235  // return fft(field, field, m_direction);
236  vout.crucial(m_vl, "Error at %s: fft on-the-fly unsupported.\n", class_name.c_str());
237  exit(EXIT_FAILURE);
238 }
239 
240 
241 //====================================================================
242 #endif /* USE_FFTWLIB */
243 
244 //====================================================================
245 //============================================================END=====
void scal(Field &x, const double a)
scal(x, a): x = a * x
Definition: field.cpp:433
BridgeIO vout
Definition: bridgeIO.cpp:503
void detailed(const char *format,...)
Definition: bridgeIO.cpp:216
static int get_num_threads()
returns available number of threads.
static int npe(const int dir)
logical grid extent
const double * ptr(const int jin, const int site, const int jex) const
Definition: field.h:153
int site(const int &x, const int &y, const int &z, const int &t) const
Definition: index_lex.h:53
Container of Field-type object.
Definition: field.h:45
Class for parameters.
Definition: parameters.h:46
int fetch_string(const string &key, string &value) const
Definition: parameters.cpp:378
int nin() const
Definition: field.h:126
int nex() const
Definition: field.h:128
Lexical site index.
Definition: index_lex.h:34
void crucial(const char *format,...)
Definition: bridgeIO.cpp:178
Direction
Definition: bridge_defs.h:24
static long_t Lvol()