algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/fft_simd.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : fft_simd.hpp
00004 * DESCRIPTION: FFT using SIMD SSE operations
00005 * COPYRIGHT  : (C) 2007  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__FFT_SIMD__HPP
00014 #define __MMX__FFT_SIMD__HPP
00015 #include <numerix/simd.hpp>
00016 #include <numerix/sse.hpp>
00017 #include <algebramix/fft_naive.hpp>
00018 
00019 namespace mmx {
00020 
00021 // Fallback for when no simd type is available
00022 template<typename C,
00023          typename FFTER= fft_naive_transformer<C>,
00024          typename FFTER_SIMD= fft_naive_transformer<typename Simd_type(C)>,
00025          nat thr= 2>
00026 class fft_simd_transformer {
00027 public:
00028   typedef typename FFTER::R R;
00029   typedef typename R::U U;
00030   typedef typename R::S S;
00031   
00032   FFTER* ffter;
00033   nat    depth;
00034   nat    len;
00035   U*     roots;
00036 
00037 public:
00038   inline fft_simd_transformer (nat n):
00039     ffter (new FFTER (n)),
00040     depth (ffter->depth), len (ffter->len), roots (ffter->roots)  {}
00041 
00042   inline ~fft_simd_transformer () { delete ffter; }
00043 
00044   template<typename CC> inline void
00045   dfft (CC* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00046     ffter->dfft (c, stride, shift, steps, step1, step2); }
00047 
00048   template<typename CC> inline void
00049   ifft (CC* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00050     ffter->ifft (c, stride, shift, steps, step1, step2); }
00051 
00052   template<typename CC> inline void
00053   dfft (CC* c, nat stride, nat shift, nat steps) {
00054     dfft (c, stride, shift, steps, 0, steps); }
00055 
00056   template<typename CC> inline void
00057   ifft (CC* c, nat stride, nat shift, nat steps) {
00058     ifft (c, stride, shift, steps, 0, steps); }
00059 
00060   inline void
00061   direct_transform (C* c) {
00062     dfft (c, 1, 0, depth); }
00063 
00064   inline void
00065   inverse_transform (C* c, bool divide=true) {
00066     typedef implementation<vector_linear,vector_naive> NVec;
00067     ifft (c, 1, 0, depth);
00068     if (divide) NVec::mul (c, invert (binpow (S (2), depth)), len); }
00069 };
00070   
00071 #ifdef NUMERIX_ENABLE_SIMD
00072 
00073 /******************************************************************************
00074 * Support for SSE instructions
00075 ******************************************************************************/
00076 
00077 #ifdef __SSE2__
00078 
00079 inline void
00080 simd_encode (complex<double>* c, nat n) {
00081   double  temp;
00082   double* r= (double*) ((void*) c);
00083   for (; n!=0; r+=8, n-=4) {
00084     temp= r[1];
00085     r[1]= r[2];
00086     r[2]= temp;
00087     temp= r[5];
00088     r[5]= r[6];
00089     r[6]= temp;
00090   }
00091 }
00092 
00093 inline void
00094 simd_decode (complex<double>* c, nat n) {
00095   double  temp;
00096   double* r= (double*) ((void*) c);
00097   for (; n!=0; r+=8, n-=4) {
00098     temp= r[1];
00099     r[1]= r[2];
00100     r[2]= temp;
00101     temp= r[5];
00102     r[5]= r[6];
00103     r[6]= temp;
00104   }
00105 }
00106 
00107 STMPL
00108 struct roots_helper<complex<sse_double> >:
00109   public roots_helper<complex<double> > {
00110 
00111   typedef complex<sse_double> C;
00112   typedef complex<double>     U;
00113   typedef double              S;
00114 
00115   static inline void
00116   fft_cross (C* c1, C* c2) {
00117     double* z1= (double*) ((void*) c1);
00118     double* z2= (double*) ((void*) c2);
00119     sse_double re_z1= simd_load_aligned (z1);
00120     sse_double im_z1= simd_load_aligned (z1+2);
00121     sse_double re_z2= simd_load_aligned (z2);
00122     sse_double im_z2= simd_load_aligned (z2+2);
00123     simd_save_aligned (z2  , re_z1 - re_z2);
00124     simd_save_aligned (z2+2, im_z1 - im_z2);
00125     simd_save_aligned (z1  , re_z1 + re_z2);
00126     simd_save_aligned (z1+2, im_z1 + im_z2);
00127   }
00128 
00129   static inline void
00130   dfft_cross (C* c1, C* c2, const U* u) {
00131     double* z1= (double*) ((void*) c1);
00132     double* z2= (double*) ((void*) c2);
00133     double* u1= (double*) ((void*) u );
00134     sse_double re_z1= simd_load_aligned (z1);
00135     sse_double im_z1= simd_load_aligned (z1+2);
00136     sse_double re_z2= simd_load_aligned (z2);
00137     sse_double im_z2= simd_load_aligned (z2+2);
00138     sse_double re_u1= simd_load_duplicate (u1);
00139     sse_double im_u1= simd_load_duplicate (u1+1);
00140     sse_double re_u2= re_u1 * re_z2 - im_u1 * im_z2;
00141     sse_double im_u2= re_u1 * im_z2 + im_u1 * re_z2;
00142     simd_save_aligned (z2  , re_z1 - re_u2);
00143     simd_save_aligned (z2+2, im_z1 - im_u2);
00144     simd_save_aligned (z1  , re_z1 + re_u2);
00145     simd_save_aligned (z1+2, im_z1 + im_u2);
00146   }
00147 
00148   static inline void
00149   ifft_cross (C* c1, C* c2, const U* u) {
00150     double* z1= (double*) ((void*) c1);
00151     double* z2= (double*) ((void*) c2);
00152     double* u1= (double*) ((void*) u );
00153     sse_double re_z1= simd_load_aligned (z1);
00154     sse_double im_z1= simd_load_aligned (z1+2);
00155     sse_double re_z2= simd_load_aligned (z2);
00156     sse_double im_z2= simd_load_aligned (z2+2);
00157     sse_double re_u1= simd_load_duplicate (u1);
00158     sse_double im_u1= simd_load_duplicate (u1+1);
00159     sse_double re_u2= re_z1 - re_z2;
00160     sse_double im_u2= im_z1 - im_z2;
00161     simd_save_aligned (z2  , re_u1 * re_u2 - im_u1 * im_u2);
00162     simd_save_aligned (z2+2, re_u1 * im_u2 + im_u1 * re_u2);
00163     simd_save_aligned (z1  , re_z1 + re_z2);
00164     simd_save_aligned (z1+2, im_z1 + im_z2);
00165   }
00166 };
00167 
00168 STMPL
00169 struct std_roots<complex<sse_double> > {
00170   typedef complex<sse_double> C;
00171   typedef cached_roots_helper<roots_helper<C> > roots_type;
00172 };
00173 
00174 /******************************************************************************
00175 * The FFT transformer class
00176 ******************************************************************************/
00177 
00178 template<typename FFTER, typename FFTER_SIMD, nat thr>
00179 class fft_simd_transformer<complex<double>, FFTER, FFTER_SIMD, thr> {
00180 public:
00181   typedef complex<double> C;
00182   typedef typename FFTER::R R;
00183   typedef typename R::U U;
00184   typedef typename R::S S;
00185   
00186   FFTER* ffter;
00187   nat    depth;
00188   nat    len;
00189   U*     roots;
00190 
00191   typedef complex<sse_double> C_SIMD;
00192   FFTER_SIMD* ffter_simd;
00193 
00194 public:
00195   inline fft_simd_transformer (nat n):
00196     ffter (new FFTER (n)),
00197     depth (ffter->depth), len (ffter->len), roots (ffter->roots),
00198     ffter_simd (new FFTER_SIMD (n)) {}
00199 
00200   inline ~fft_simd_transformer () { delete ffter; delete ffter_simd; }
00201 
00202   inline void
00203   dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00204     if (steps <= thr || stride != 1)
00205       ffter->dfft (c, stride, shift, steps, step1, step2);
00206     else {
00207       if (steps == step2) {
00208         simd_encode (c, 1 << steps);
00209         ffter_simd->dfft ((C_SIMD*) ((void*) c), 1, shift>>1,
00210                           steps-1, step1, steps-1);
00211         simd_decode (c, 1 << steps);
00212         ffter->dfft (c, 1, shift, steps, steps-1, steps);
00213       }
00214       else {
00215         simd_encode (c, 1 << steps);
00216         ffter_simd->dfft ((C_SIMD*) ((void*) c), 1,
00217                           shift>>1, steps-1, step1, step2);
00218         simd_decode (c, 1 << steps);
00219       }
00220     }
00221   }
00222 
00223   inline void
00224   ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00225     if (steps <= thr || stride != 1)
00226       ffter->ifft (c, stride, shift, steps, step1, step2);
00227     else {
00228       if (steps == step2) {
00229         ffter->ifft (c, 1, shift, steps, steps-1, steps);
00230         simd_encode (c, 1 << steps);
00231         ffter_simd->ifft ((C_SIMD*) ((void*) c), 1, shift>>1,
00232                           steps-1, step1, steps-1);
00233         simd_decode (c, 1 << steps);
00234       }
00235       else {
00236         simd_encode (c, 1 << steps);
00237         ffter_simd->ifft ((C_SIMD*) ((void*) c), 1,
00238                           shift>>1, steps-1, step1, step2);
00239         simd_decode (c, 1 << steps);
00240       }
00241     }
00242   }
00243 
00244   inline void
00245   dfft (C* c, nat stride, nat shift, nat steps) {
00246     dfft (c, stride, shift, steps, 0, steps); }
00247 
00248   inline void
00249   ifft (C* c, nat stride, nat shift, nat steps) {
00250     ifft (c, stride, shift, steps, 0, steps); }
00251 
00252   inline void
00253   direct_transform (C* c) {
00254     dfft (c, 1, 0, ffter->depth); }
00255 
00256   inline void
00257   inverse_transform (C* c, bool divide=true) {
00258     typedef implementation<vector_linear,vector_naive> NVec;
00259     ifft (c, 1, 0, depth);
00260     if (divide) NVec::mul (c, invert (binpow (S (2), depth)), len); }
00261 };
00262 
00263 #endif //__SSE2__
00264 #endif // NUMERIX_ENABLE_SIMD
00265 } // namespace mmx
00266 #endif //__MMX__FFT_SIMD__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines