algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/fft_threads.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : fft_threads.hpp
00004 * DESCRIPTION: multi-threaded FFT
00005 * COPYRIGHT  : (C) 2007  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__FFT_THREADS__HPP
00014 #define __MMX__FFT_THREADS__HPP
00015 #include <basix/threads.hpp>
00016 #include <algebramix/vector_threads.hpp>
00017 #include <algebramix/fft_blocks.hpp>
00018 
00019 namespace mmx {
00020 
00021 /******************************************************************************
00022 * Multi-threaded and dichotomic FFT transformation
00023 ******************************************************************************/
00024 
00025 template<typename C,
00026          typename FFTER= fft_blocks_transformer<C,
00027                                                 fft_naive_transformer<C>,
00028                                                 8, 6, 10>,
00029          nat thr= 14>
00030 class fft_threads_transformer {
00031 public:
00032   typedef implementation<vector_linear,vector_threads<vector_naive> > Vec;
00033   typedef typename FFTER::R R;
00034   typedef typename R::U U;
00035   typedef typename R::S S;
00036   static const nat min_reps= 16;
00037 
00038   FFTER* ffter;
00039   nat    depth;
00040   nat    len;
00041   U*     roots;
00042 
00043 public:
00044   inline fft_threads_transformer (nat n):
00045     ffter (new FFTER (n)),
00046     depth (ffter->depth), len (ffter->len), roots (ffter->roots) {}
00047   inline ~fft_threads_transformer () { delete ffter; }
00048 
00049   static inline void
00050   copy (C* d, nat drs, nat dcs, C* s, nat srs, nat scs, nat r, nat c) {
00051     for (nat j=0; j<c; j++, d+=dcs, s+=scs) {
00052       C* dd= d; C* ss= s;
00053       for (nat i=0; i<r; i++, dd+=drs, ss+=srs)
00054         *dd= *ss;
00055     }
00056   }
00057 
00058   struct outer_fft_task_rep: public task_rep {
00059     FFTER* ffter;
00060     bool direct;
00061     C *buf;    
00062     nat thr_reps, tot_reps;
00063     C *c;
00064     nat stride, shift, steps;
00065   public:
00066     inline outer_fft_task_rep
00067      (FFTER* ffter2, bool direct2,
00068       C* buf2, nat thr_reps2, nat tot_reps2,
00069       C* c2, nat stride2, nat shift2, nat steps2):
00070         ffter (ffter2), direct (direct2),
00071         buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2),
00072         c (c2), stride (stride2), shift (shift2), steps (steps2) {}
00073     inline ~outer_fft_task_rep () {}
00074     void execute () {
00075       for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) {
00076         C* aux= buf;
00077         copy (buf, 1, min_reps, c, stride, stride * tot_reps,
00078               min_reps, 1 << steps);
00079         if (direct)
00080           for (nat j=0; j<min_reps; j++, aux++)
00081             ffter->dfft (aux, min_reps, shift, steps);
00082         else
00083           for (nat j=0; j<min_reps; j++, aux++)
00084             ffter->ifft (aux, min_reps, shift, steps);
00085         copy (c, stride, stride * tot_reps, buf, 1, min_reps,
00086               min_reps, 1 << steps);
00087       }
00088     }
00089   };
00090 
00091   struct inner_fft_task_rep: public task_rep {
00092     FFTER* ffter;
00093     bool direct;
00094     nat start, inc, total;
00095     C *c;
00096     nat stride, shift, steps;
00097   public:
00098     inline inner_fft_task_rep
00099      (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2,
00100       C* c2, nat stride2, nat shift2, nat steps2):
00101         ffter (ffter2), direct (direct2),
00102         start (start2), inc (inc2), total (total2),
00103         c (c2), stride (stride2), shift (shift2), steps (steps2) {}
00104     inline ~inner_fft_task_rep () {}
00105     void execute () {
00106       if (direct)
00107         for (nat i=start; i<total; i+=inc)
00108           ffter->dfft (c + (i<<steps) * stride, stride,
00109                        shift + ((i << steps) >> 1), steps);
00110       else
00111         for (nat i=start; i<total; i+=inc)
00112           ffter->ifft (c + (i<<steps) * stride, stride,
00113                        shift + ((i << steps) >> 1), steps);     
00114     }
00115   };
00116 
00117   void
00118   fft (bool direct, C* c, nat stride, nat shift, nat steps) {
00119     if (steps <= thr) {
00120       if (direct) ffter->dfft (c, stride, shift, steps);
00121       else ffter->ifft (c, stride, shift, steps);
00122     }
00123     else {
00124       nat half1= min(log_2 (threads_number), steps);
00125       nat half2= steps - half1;
00126 
00127       for (nat stage=0; stage<2; stage++) {
00128         if ((stage == 0) ^ (!direct)) {
00129           //C*  cc1  = c;
00130           //for (nat i=0; i<((nat) 1<<half2); i++, cc1 += stride)
00131           //  ffter->fft (direct, cc1, stride << half2,
00132           //              shift >> half2, half1);
00133 
00134           nat nt = threads_number;
00135           nat bsz= (min_reps << half1);
00136           C* buf= mmx_new<C> (nt * bsz);
00137           task tasks[nt];
00138           for (nat i=0; i<nt; i++) {
00139             nat tot= 1 << half2;
00140             nat sta= min_reps * ((  i   * tot) / (min_reps * nt));
00141             nat end= min_reps * (((i+1) * tot) / (min_reps * nt));
00142             tasks[i]= new outer_fft_task_rep
00143               (ffter, direct, buf + i*bsz, end-sta, tot,
00144                c + sta*stride, stride, shift >> half2, half1);
00145           }
00146           threads_execute (tasks, nt);
00147           mmx_delete<C> (buf, nt * bsz);
00148 
00149         }
00150         else {
00151           //C* cc2= c;
00152           //for (nat i=0; i<((nat) 1<<half1); i++, cc2 += (stride << half2))
00153           //  ffter->fft (direct, cc2, stride,
00154           //              shift + ((i << half2) >> 1), half2);
00155           nat nt= threads_number;
00156           task tasks[nt];
00157           for (nat i=0; i<nt; i++)
00158             tasks[i]= new inner_fft_task_rep
00159               (ffter, direct, i, nt, 1 << half1, c, stride, shift, half2);
00160           threads_execute (tasks, nt);
00161         }
00162       }
00163     }
00164   }
00165 
00166   inline void
00167   dfft (C* c, nat stride, nat shift, nat steps) {
00168     if (steps <= thr) ffter->dfft (c, stride, shift, steps);
00169     else fft (true, c, stride, shift, steps);
00170   }
00171 
00172   inline void
00173   ifft (C* c, nat stride, nat shift, nat steps) {
00174     if (steps <= thr) ffter->ifft (c, stride, shift, steps);
00175     else fft (false, c, stride, shift, steps);
00176   }
00177 
00178   inline void
00179   dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00180     if (step1 == 0 && step2 == steps && steps > thr)
00181       fft (true, c, stride, shift, steps);
00182     else ffter->dfft (c, stride, shift, steps, step1, step2);
00183   }
00184 
00185   inline void
00186   ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
00187     if (step1 == 0 && step2 == steps && steps > thr)
00188       fft (false, c, stride, shift, steps);
00189     else ffter->ifft (c, stride, shift, steps, step1, step2);
00190   }
00191 
00192   inline void
00193   direct_transform (C* c) {
00194     dfft (c, 1, 0, depth); }
00195 
00196   inline void
00197   inverse_transform (C* c, bool divide=true) {
00198     ifft (c, 1, 0, depth);
00199     if (divide) {
00200       S x= binpow (S (2), depth);
00201       x= invert (x);
00202       Vec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
00203     }
00204   }
00205 };
00206 
00207 } // namespace mmx
00208 #endif //__MMX__FFT_THREADS__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines