algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : fft_triadic_threads.hpp 00004 * DESCRIPTION: multi-threaded triadic FFT multiplication 00005 * COPYRIGHT : (C) 2008 Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__FFT_TRIADIC_THREADS__HPP 00014 #define __MMX__FFT_TRIADIC_THREADS__HPP 00015 #include <algebramix/fft_triadic_naive.hpp> 00016 #include <basix/threads.hpp> 00017 00018 namespace mmx { 00019 00020 /****************************************************************************** 00021 * Multi-threaded and triadic FFT transformation 00022 ******************************************************************************/ 00023 00024 template<typename C, typename FFTER= fft_triadic_naive_transformer<C>, 00025 nat thr= 9> 00026 class fft_triadic_threads_transformer { 00027 public: 00028 typedef implementation<vector_linear,vector_naive> NVec; 00029 typedef typename FFTER::V V; 00030 typedef typename V::template helper<C>::roots_type R; 00031 typedef typename R::U U; 00032 static const nat min_reps= 27; // Must be a power of 3 00033 00034 FFTER* ffter; 00035 nat depth; 00036 nat len; 00037 U* roots; 00038 U* stoor; 00039 00040 public: 00041 inline fft_triadic_threads_transformer (nat n): 00042 ffter (new FFTER (n)), 00043 depth (ffter->depth), len (ffter->len), 00044 roots (ffter->roots), stoor (ffter->stoor) {} 00045 inline ~fft_triadic_threads_transformer () { delete ffter; } 00046 00047 template<typename CC> static inline void 00048 copy (CC* d, nat drs, nat dcs, CC* s, nat srs, nat scs, nat r, nat c) { 00049 for (nat j=0; j<c; j++, d+=dcs, s+=scs) { 00050 CC* dd= d; CC* ss= s; 00051 for (nat i=0; i<r; i++, dd+=drs, ss+=srs) 00052 *dd= *ss; 00053 } 00054 } 00055 00056 template<typename CC> 00057 struct outer_fft_triadic_task_rep: public task_rep { 00058 FFTER* ffter; 00059 bool direct; 00060 CC *buf; 00061 nat thr_reps, tot_reps; 00062 CC *c; 00063 nat stride, shift, steps; 00064 nat l; 00065 public: 00066 inline outer_fft_triadic_task_rep 00067 (FFTER* ffter2, bool direct2, 00068 CC* buf2, nat thr_reps2, nat tot_reps2, 00069 CC* c2, nat stride2, nat shift2, nat steps2): 00070 ffter (ffter2), direct (direct2), 00071 buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2), 00072 c (c2), stride (stride2), shift (shift2), steps (steps2) { 00073 l= binpow ((nat) 3, steps); } 00074 inline ~outer_fft_triadic_task_rep () {} 00075 void execute () { 00076 for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) { 00077 CC* aux= buf; 00078 copy (buf, 1, min_reps, c, stride, stride * tot_reps, 00079 min_reps, l); 00080 if (direct) 00081 for (nat j=0; j<min_reps; j++, aux++) 00082 ffter->dfft_triadic (aux, min_reps, shift, steps); 00083 else 00084 for (nat j=0; j<min_reps; j++, aux++) 00085 ffter->ifft_triadic (aux, min_reps, shift, steps); 00086 copy (c, stride, stride * tot_reps, buf, 1, min_reps, 00087 min_reps, l); 00088 } 00089 } 00090 }; 00091 00092 template<typename CC> 00093 struct inner_fft_triadic_task_rep: public task_rep { 00094 FFTER* ffter; 00095 bool direct; 00096 nat start, inc, total; 00097 CC *c; 00098 nat stride, shift, steps; 00099 nat l; 00100 public: 00101 inline inner_fft_triadic_task_rep 00102 (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2, 00103 CC* c2, nat stride2, nat shift2, nat steps2): 00104 ffter (ffter2), direct (direct2), 00105 start (start2), inc (inc2), total (total2), 00106 c (c2), stride (stride2), shift (shift2), steps (steps2) { 00107 l= binpow ((nat) 3, steps); } 00108 inline ~inner_fft_triadic_task_rep () {} 00109 void execute () { 00110 if (direct) 00111 for (nat i=start; i<total; i+=inc) 00112 ffter->dfft_triadic (c + i * l * stride, stride, 00113 shift + i * l, steps); 00114 else 00115 for (nat i=start; i<total; i+=inc) 00116 ffter->ifft_triadic (c + i * l * stride, stride, 00117 shift + i * l, steps); 00118 } 00119 }; 00120 00121 template<typename CC> void 00122 fft_triadic (bool direct, CC* c, nat stride, nat shift, nat steps) { 00123 nat nt= threads_number; 00124 nat half1= steps >> 1; 00125 nat half2= steps - half1; 00126 nat len1= binpow ((nat) 3, half1); 00127 nat len2= binpow ((nat) 3, half2); 00128 if (steps <= thr || len2 <= min_reps || nt == 1) { 00129 if (direct) ffter->dfft_triadic (c, stride, shift, steps); 00130 else ffter->ifft_triadic (c, stride, shift, steps); 00131 } 00132 else { 00133 for (nat stage=0; stage<2; stage++) { 00134 if ((stage == 0) ^ (!direct)) { 00135 // Below is multi-threaded the following loop: 00136 // C* cc1 = c; 00137 // for (nat i=0; i< len2; i++, cc1 += stride) 00138 // if (direct) 00139 // ffter->dfft_triadic (cc1, stride * len2, shift, half1); 00140 // else 00141 // ffter->ifft_triadic (cc1, stride * len2, shift, half1); 00142 nat bsz= min_reps * len1; // Recall min_reps < len2 00143 CC* buf= mmx_new<CC> (nt * bsz); 00144 task tasks[nt]; 00145 for (nat i=0; i<nt; i++) { 00146 nat sta= min_reps * (( i * len2) / (min_reps * nt)); 00147 nat end= min_reps * (((i+1) * len2) / (min_reps * nt)); 00148 tasks[i]= new outer_fft_triadic_task_rep<CC> 00149 (ffter, direct, buf + i*bsz, end-sta, len2, 00150 c + sta*stride, stride, shift / len2, half1); 00151 } 00152 threads_execute (tasks, nt); 00153 mmx_delete<CC> (buf, nt * bsz); 00154 } 00155 else { 00156 // Below is multi-threaded the following loop: 00157 // CC* cc2= c; 00158 // for (nat i=0; i<len1; i++, cc2 += stride*len2) 00159 // if (direct) 00160 // ffter->dfft_triadic (cc2, stride, shift + i*len2, half2); 00161 // else 00162 // ffter->ifft_triadic (cc2, stride, shift + i*len2, half2); 00163 task tasks[nt]; 00164 for (nat i=0; i<nt; i++) 00165 tasks[i]= new inner_fft_triadic_task_rep<CC> 00166 (ffter, direct, i, nt, len1, c, stride, shift, half2); 00167 threads_execute (tasks, nt); 00168 } 00169 } 00170 } 00171 } 00172 00173 template<typename CC> inline void 00174 dfft_triadic (CC* c, nat stride, nat shift, nat steps) { 00175 if (steps <= thr) ffter->dfft_triadic (c, stride, shift, steps); 00176 else fft_triadic (true, c, stride, shift, steps); 00177 } 00178 00179 template<typename CC> inline void 00180 ifft_triadic (CC* c, nat stride, nat shift, nat steps) { 00181 if (steps <= thr) ffter->ifft_triadic (c, stride, shift, steps); 00182 else fft_triadic (false, c, stride, shift, steps); 00183 } 00184 00185 template<typename CC> inline void 00186 dfft_triadic (CC* c, nat stride, nat shift, nat steps, 00187 nat step1, nat step2) { 00188 if (step1 == 0 && step2 == steps && steps > thr) 00189 fft_triadic (true, c, stride, shift, steps); 00190 else ffter->dfft_triadic (c, stride, shift, steps, step1, step2); 00191 } 00192 00193 template<typename CC> inline void 00194 ifft_triadic (CC* c, nat stride, nat shift, nat steps, 00195 nat step1, nat step2) { 00196 if (step1 == 0 && step2 == steps && steps > thr) 00197 fft_triadic (false, c, stride, shift, steps); 00198 else ffter->ifft_triadic (c, stride, shift, steps, step1, step2); 00199 } 00200 00201 inline void 00202 direct_transform_triadic (C* c) { 00203 dfft_triadic (c, 1, 0, depth); 00204 } 00205 00206 inline void 00207 inverse_transform_triadic (C* c, bool shift=true) { 00208 ifft_triadic (c, 1, 0, depth); 00209 if (shift) NVec::mul (c, invert (C (len)), len); 00210 } 00211 }; 00212 00213 } // namespace mmx 00214 #endif //__MMX__FFT_TRIADIC_THREADS__HPP