algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : fft_naive.hpp 00004 * DESCRIPTION: generic low-level FFT multiplication 00005 * COPYRIGHT : (C) 2005 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__FFT_NAIVE__HPP 00014 #define __MMX__FFT_NAIVE__HPP 00015 #include <algebramix/fft_roots.hpp> 00016 00017 namespace mmx { 00018 00019 /****************************************************************************** 00020 * Naive root arithmetic 00021 ******************************************************************************/ 00022 00023 template<typename CC, typename UU, typename SS> 00024 struct roots_helper { 00025 typedef CC C; 00026 typedef UU U; 00027 typedef SS S; 00028 00029 static U* 00030 create_roots (nat n) { 00031 nat k= primitive_root_max_order<C> (2); (void) k; 00032 VERIFY (k == 0 || n <= k, "maximum order exceeded"); 00033 VERIFY (n >= 2, "size must be at least two"); 00034 U* roots= mmx_new<U> (n); 00035 for (nat i=0; i<n; i+=2) { 00036 U temp = primitive_root<C> (n, bit_mirror (i, n)); 00037 roots[i] = temp; 00038 roots[i+1]= primitive_root<U> (n, i==0? 0: n - bit_mirror (i, n)); 00039 } 00040 return roots; } 00041 00042 static void 00043 destroy_roots (U* u, nat n) { 00044 mmx_delete<U> (u, n); } 00045 00046 static inline void 00047 fft_cross (C* c1, C* c2) { 00048 C temp= (*c2); 00049 *c2 = (*c1) - temp; 00050 *c1 = (*c1) + temp; } 00051 00052 static inline void 00053 dfft_cross (C* c1, C* c2, const U* u) { 00054 C temp= (*u) * (*c2); 00055 *c2 = (*c1) - temp; 00056 *c1 = (*c1) + temp; } 00057 00058 static inline void 00059 ifft_cross (C* c1, C* c2, const U* u) { 00060 C temp= *c2; 00061 *c2 = (*u ) * ((*c1) - temp); 00062 *c1 = (*c1) + temp; } 00063 00064 static inline void 00065 dtft_cross (C* c1, C* c2) { 00066 static S h= invert (S (2)); 00067 fft_cross (c1, c2); 00068 *c1 *= h; 00069 *c2 *= h; } 00070 00071 static inline void 00072 dtft_cross (C* c1, C* c2, const U* u) { 00073 static S h= invert (S (2)); 00074 dfft_cross (c1, c2, u); 00075 *c1 *= h; 00076 *c2 *= h; } 00077 00078 static inline void 00079 itft_flip (C* c1, C* c2, const U* u) { 00080 static S h= invert (S(2)); 00081 C temp= (*u) * (*c2); 00082 *c1 += (*c1) - temp; 00083 *c2 = h * ((*c1) - temp); } 00084 00085 static inline void 00086 itft_flip (C* c1, C* c2) { 00087 static S h= invert (S(2)); 00088 *c1 += (*c1) - (*c2); 00089 *c2 = h * ((*c1) - (*c2)); } 00090 00091 struct fft_mul_sc_op : mul_op {}; 00092 }; 00093 00094 /****************************************************************************** 00095 * The FFT transformer class 00096 ******************************************************************************/ 00097 00098 template<typename C, typename V= std_roots<C> > 00099 class fft_naive_transformer { 00100 public: 00101 typedef implementation<vector_linear,vector_naive> NVec; 00102 typedef typename V::roots_type R; 00103 typedef typename R::U U; 00104 typedef typename R::S S; 00105 00106 nat depth; 00107 nat len; 00108 U* roots; 00109 00110 public: 00111 inline fft_naive_transformer (nat n): 00112 depth (log_2 (n)), len (n), roots (R::create_roots (n)) { 00113 VERIFY (n == ((nat) 1 << depth), "power of two expected"); } 00114 00115 inline ~fft_naive_transformer () { 00116 R::destroy_roots (roots, len); } 00117 00118 inline void 00119 dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) { 00120 // In place direct fft of c[0], c[stride], ..., c[(2^steps-1) stride] 00121 // Only perform steps from step1 until step2-1 00122 // If shift != 0, then roots start at roots + (shift<<1) 00123 for (nat step= step1; step < step2; step++) { 00124 //mmout << "step " << step << ": " << flush_now; 00125 if (step == 0 && shift == 0) { 00126 nat todo= steps - 1; 00127 C* cc= c; 00128 for (nat k= 0; k < ((nat) 1<<todo); k++) { 00129 R::fft_cross (cc, cc + (stride<<todo)); 00130 cc += stride; 00131 } 00132 } 00133 else { 00134 nat todo= steps - 1 - step; 00135 C* cc= c; 00136 U * uu= roots + ((shift >> todo) << 1); 00137 for (nat j= 0; j < ((nat) 1<<step); j++) { 00138 for (nat k= 0; k < ((nat) 1<<todo); k++) { 00139 R::dfft_cross (cc, cc + (stride<<todo), uu); 00140 cc += stride; 00141 } 00142 cc += (stride<<todo); 00143 uu += 2; 00144 } 00145 } 00146 } 00147 } 00148 00149 inline void 00150 ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) { 00151 // In place inverse fft of c[0], c[stride], ..., c[(2^steps-1) stride] 00152 // Only perform steps from step2-1 until step1 00153 // If shift != 0, then roots start at roots + (shift<<1) 00154 for (int step= step2-1; (int) step >= ((int) step1); step--) { 00155 //mmout << "step " << step << ": " << flush_now; 00156 if (step == 0 && shift == 0) { 00157 nat todo= steps - 1; 00158 C* cc= c; 00159 for (nat k= 0; k < ((nat) 1<<todo); k++) { 00160 R::fft_cross (cc, cc + (stride<<todo)); 00161 cc += stride; 00162 } 00163 } 00164 else { 00165 nat todo= steps - 1 - step; 00166 C* cc= c; 00167 U * uu= roots + 1 + ((shift >> todo) << 1); 00168 for (nat j= 0; j < ((nat) 1<<step); j++) { 00169 for (nat k= 0; k < ((nat) 1<<todo); k++) { 00170 R::ifft_cross (cc, cc + (stride<<todo), uu); 00171 cc += stride; 00172 } 00173 cc += (stride<<todo); 00174 uu += 2; 00175 } 00176 } 00177 } 00178 } 00179 00180 inline void 00181 dfft (C* c, nat stride, nat shift, nat steps) { 00182 dfft (c, stride, shift, steps, 0, steps); } 00183 00184 inline void 00185 ifft (C* c, nat stride, nat shift, nat steps) { 00186 ifft (c, stride, shift, steps, 0, steps); } 00187 00188 inline void 00189 direct_transform (C* c) { 00190 dfft (c, 1, 0, depth); } 00191 00192 inline void 00193 inverse_transform (C* c, bool divide=true) { 00194 ifft (c, 1, 0, depth); 00195 if (divide) { 00196 S x= binpow (S (2), depth); 00197 x= invert (x); 00198 NVec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len); 00199 } 00200 } 00201 }; 00202 00203 /****************************************************************************** 00204 * The FFT transformer class 00205 ******************************************************************************/ 00206 00207 template<typename C> inline void 00208 direct_fft (C* dest, nat n) { 00209 fft_naive_transformer<C> ffter (n); 00210 ffter.direct_transform (dest); 00211 } 00212 00213 template<typename C> inline void 00214 inverse_fft (C* dest, nat n) { 00215 fft_naive_transformer<C> ffter (n); 00216 ffter.inverse_transform (dest); 00217 } 00218 00219 } // namespace mmx 00220 #endif //__MMX__FFT_NAIVE__HPP