algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/fft_triadic_threads.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : fft_triadic_threads.hpp
00004 * DESCRIPTION: multi-threaded triadic FFT multiplication
00005 * COPYRIGHT  : (C) 2008  Gregoire Lecerf
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__FFT_TRIADIC_THREADS__HPP
00014 #define __MMX__FFT_TRIADIC_THREADS__HPP
00015 #include <algebramix/fft_triadic_naive.hpp>
00016 #include <basix/threads.hpp>
00017 
00018 namespace mmx {
00019 
00020 /******************************************************************************
00021 * Multi-threaded and triadic FFT transformation
00022 ******************************************************************************/
00023 
00024 template<typename C, typename FFTER= fft_triadic_naive_transformer<C>,
00025          nat thr= 9>
00026 class fft_triadic_threads_transformer {
00027 public:
00028   typedef implementation<vector_linear,vector_naive> NVec;
00029   typedef typename FFTER::V V;
00030   typedef typename V::template helper<C>::roots_type R;
00031   typedef typename R::U U;
00032   static const nat min_reps= 27; // Must be a power of 3
00033 
00034   FFTER* ffter;
00035   nat    depth;
00036   nat    len;
00037   U*     roots;
00038   U*     stoor;
00039 
00040 public:
00041   inline fft_triadic_threads_transformer (nat n):
00042     ffter (new FFTER (n)),
00043     depth (ffter->depth), len (ffter->len),
00044     roots (ffter->roots), stoor (ffter->stoor) {}
00045   inline ~fft_triadic_threads_transformer () { delete ffter; }
00046 
00047   template<typename CC> static inline void
00048   copy (CC* d, nat drs, nat dcs, CC* s, nat srs, nat scs, nat r, nat c) {
00049     for (nat j=0; j<c; j++, d+=dcs, s+=scs) {
00050       CC* dd= d; CC* ss= s;
00051       for (nat i=0; i<r; i++, dd+=drs, ss+=srs)
00052         *dd= *ss;
00053     }
00054   }
00055 
00056   template<typename CC>
00057   struct outer_fft_triadic_task_rep: public task_rep {
00058     FFTER* ffter;
00059     bool direct;
00060     CC *buf;    
00061     nat thr_reps, tot_reps;
00062     CC *c;
00063     nat stride, shift, steps;
00064     nat l;
00065   public:
00066     inline outer_fft_triadic_task_rep
00067      (FFTER* ffter2, bool direct2,
00068       CC* buf2, nat thr_reps2, nat tot_reps2,
00069       CC* c2, nat stride2, nat shift2, nat steps2):
00070         ffter (ffter2), direct (direct2),
00071         buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2),
00072         c (c2), stride (stride2), shift (shift2), steps (steps2) {
00073       l= binpow ((nat) 3, steps); }
00074     inline ~outer_fft_triadic_task_rep () {}
00075     void execute () {
00076       for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) {
00077         CC* aux= buf;
00078         copy (buf, 1, min_reps, c, stride, stride * tot_reps,
00079               min_reps, l);
00080         if (direct)
00081           for (nat j=0; j<min_reps; j++, aux++)
00082             ffter->dfft_triadic (aux, min_reps, shift, steps);
00083         else
00084           for (nat j=0; j<min_reps; j++, aux++)
00085             ffter->ifft_triadic (aux, min_reps, shift, steps);
00086         copy (c, stride, stride * tot_reps, buf, 1, min_reps,
00087               min_reps, l);
00088       }
00089     }
00090   };
00091 
00092   template<typename CC>
00093   struct inner_fft_triadic_task_rep: public task_rep {
00094     FFTER* ffter;
00095     bool direct;
00096     nat start, inc, total;
00097     CC *c;
00098     nat stride, shift, steps;
00099     nat l;
00100   public:
00101     inline inner_fft_triadic_task_rep
00102      (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2,
00103       CC* c2, nat stride2, nat shift2, nat steps2):
00104         ffter (ffter2), direct (direct2),
00105         start (start2), inc (inc2), total (total2),
00106         c (c2), stride (stride2), shift (shift2), steps (steps2) {
00107       l= binpow ((nat) 3, steps); }
00108     inline ~inner_fft_triadic_task_rep () {}
00109     void execute () {
00110       if (direct)
00111         for (nat i=start; i<total; i+=inc)
00112           ffter->dfft_triadic (c + i * l * stride, stride,
00113                                shift + i * l, steps);
00114       else
00115         for (nat i=start; i<total; i+=inc)
00116           ffter->ifft_triadic (c + i * l * stride, stride,
00117                                shift + i * l, steps);   
00118     }
00119   };
00120 
00121   template<typename CC> void
00122   fft_triadic (bool direct, CC* c, nat stride, nat shift, nat steps) {
00123     nat nt= threads_number;
00124     nat half1= steps >> 1;
00125     nat half2= steps - half1;
00126     nat len1= binpow ((nat) 3, half1);
00127     nat len2= binpow ((nat) 3, half2);
00128     if (steps <= thr || len2 <= min_reps || nt == 1) {
00129       if (direct) ffter->dfft_triadic (c, stride, shift, steps);
00130       else ffter->ifft_triadic (c, stride, shift, steps);
00131     }
00132     else {
00133       for (nat stage=0; stage<2; stage++) {
00134         if ((stage == 0) ^ (!direct)) {
00135           // Below is multi-threaded the following loop:
00136           // C*  cc1  = c;
00137           // for (nat i=0; i< len2; i++, cc1 += stride)
00138           //   if (direct)
00139           //     ffter->dfft_triadic (cc1, stride * len2, shift, half1);
00140           //   else
00141           //     ffter->ifft_triadic (cc1, stride * len2, shift, half1);
00142           nat bsz= min_reps * len1; // Recall min_reps < len2
00143           CC* buf= mmx_new<CC> (nt * bsz);
00144           task tasks[nt];
00145           for (nat i=0; i<nt; i++) {
00146             nat sta= min_reps * ((  i   * len2) / (min_reps * nt));
00147             nat end= min_reps * (((i+1) * len2) / (min_reps * nt));
00148             tasks[i]= new outer_fft_triadic_task_rep<CC>
00149               (ffter, direct, buf + i*bsz, end-sta, len2,
00150                c + sta*stride, stride, shift / len2, half1);
00151           }
00152           threads_execute (tasks, nt);
00153           mmx_delete<CC> (buf, nt * bsz);
00154         }
00155         else {
00156           // Below is multi-threaded the following loop:
00157           // CC* cc2= c;
00158           // for (nat i=0; i<len1; i++, cc2 += stride*len2)
00159           //   if (direct)
00160           //     ffter->dfft_triadic (cc2, stride, shift + i*len2, half2);
00161           //   else
00162           //     ffter->ifft_triadic (cc2, stride, shift + i*len2, half2);
00163           task tasks[nt];
00164           for (nat i=0; i<nt; i++)
00165             tasks[i]= new inner_fft_triadic_task_rep<CC>
00166               (ffter, direct, i, nt, len1, c, stride, shift, half2);
00167           threads_execute (tasks, nt);
00168         }
00169       }
00170     }
00171   }
00172 
00173   template<typename CC> inline void
00174   dfft_triadic (CC* c, nat stride, nat shift, nat steps) {
00175     if (steps <= thr) ffter->dfft_triadic (c, stride, shift, steps);
00176     else fft_triadic (true, c, stride, shift, steps);
00177   }
00178 
00179   template<typename CC> inline void
00180   ifft_triadic (CC* c, nat stride, nat shift, nat steps) {
00181     if (steps <= thr) ffter->ifft_triadic (c, stride, shift, steps);
00182     else fft_triadic (false, c, stride, shift, steps);
00183   }
00184 
00185   template<typename CC> inline void
00186   dfft_triadic (CC* c, nat stride, nat shift, nat steps,
00187                 nat step1, nat step2) {
00188     if (step1 == 0 && step2 == steps && steps > thr)
00189       fft_triadic (true, c, stride, shift, steps);
00190     else ffter->dfft_triadic (c, stride, shift, steps, step1, step2);
00191   }
00192 
00193   template<typename CC> inline void
00194   ifft_triadic (CC* c, nat stride, nat shift, nat steps,
00195                 nat step1, nat step2) {
00196     if (step1 == 0 && step2 == steps && steps > thr)
00197       fft_triadic (false, c, stride, shift, steps);
00198     else ffter->ifft_triadic (c, stride, shift, steps, step1, step2);
00199   }
00200 
00201   inline void
00202   direct_transform_triadic (C* c) {
00203     dfft_triadic (c, 1, 0, depth);
00204   }
00205 
00206   inline void
00207   inverse_transform_triadic (C* c, bool shift=true) {
00208     ifft_triadic (c, 1, 0, depth);
00209     if (shift) NVec::mul (c, invert (C (len)), len);
00210   }
00211 };
00212 
00213 } // namespace mmx
00214 #endif //__MMX__FFT_TRIADIC_THREADS__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines