algebramix_doc 0.3
/******************************************************************************
* MODULE     : fft_threads.hpp
* DESCRIPTION: multi-threaded FFT
* COPYRIGHT  : (C) 2007  Joris van der Hoeven
*******************************************************************************
* This software falls under the GNU general public license and comes WITHOUT
* ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
* If you don't have this file, write to the Free Software Foundation, Inc.,
* 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
******************************************************************************/

#ifndef __MMX__FFT_THREADS__HPP
#define __MMX__FFT_THREADS__HPP
#include <basix/threads.hpp>
#include <algebramix/vector_threads.hpp>
#include <algebramix/fft_blocks.hpp>

namespace mmx {

/******************************************************************************
* Multi-threaded and dichotomic FFT transformation
******************************************************************************/

template<typename C,
         typename FFTER= fft_blocks_transformer<C,
                                                fft_naive_transformer<C>,
                                                8, 6, 10>,
         nat thr= 14>
class fft_threads_transformer {
public:
  typedef implementation<vector_linear,vector_threads<vector_naive> > Vec;
  typedef typename FFTER::R R;
  typedef typename R::U U;
  typedef typename R::S S;
  static const nat min_reps= 16;

  FFTER* ffter;
  nat depth;
  nat len;
  U* roots;

public:
  inline fft_threads_transformer (nat n):
    ffter (new FFTER (n)),
    depth (ffter->depth), len (ffter->len), roots (ffter->roots) {}
  inline ~fft_threads_transformer () { delete ffter; }

  static inline void
  copy (C* d, nat drs, nat dcs, C* s, nat srs, nat scs, nat r, nat c) {
    for (nat j=0; j<c; j++, d+=dcs, s+=scs) {
      C* dd= d; C* ss= s;
      for (nat i=0; i<r; i++, dd+=drs, ss+=srs)
        *dd= *ss;
    }
  }

  struct outer_fft_task_rep: public task_rep {
    FFTER* ffter;
    bool direct;
    C *buf;
    nat thr_reps, tot_reps;
    C *c;
    nat stride, shift, steps;
  public:
    inline outer_fft_task_rep
      (FFTER* ffter2, bool direct2,
       C* buf2, nat thr_reps2, nat tot_reps2,
       C* c2, nat stride2, nat shift2, nat steps2):
        ffter (ffter2), direct (direct2),
        buf (buf2), thr_reps (thr_reps2), tot_reps (tot_reps2),
        c (c2), stride (stride2), shift (shift2), steps (steps2) {}
    inline ~outer_fft_task_rep () {}
    void execute () {
      for (nat i=0; i*min_reps < thr_reps; i++, c += stride * min_reps) {
        C* aux= buf;
        copy (buf, 1, min_reps, c, stride, stride * tot_reps,
              min_reps, 1 << steps);
        if (direct)
          for (nat j=0; j<min_reps; j++, aux++)
            ffter->dfft (aux, min_reps, shift, steps);
        else
          for (nat j=0; j<min_reps; j++, aux++)
            ffter->ifft (aux, min_reps, shift, steps);
        copy (c, stride, stride * tot_reps, buf, 1, min_reps,
              min_reps, 1 << steps);
      }
    }
  };

  struct inner_fft_task_rep: public task_rep {
    FFTER* ffter;
    bool direct;
    nat start, inc, total;
    C *c;
    nat stride, shift, steps;
  public:
    inline inner_fft_task_rep
      (FFTER* ffter2, bool direct2, nat start2, nat inc2, nat total2,
       C* c2, nat stride2, nat shift2, nat steps2):
        ffter (ffter2), direct (direct2),
        start (start2), inc (inc2), total (total2),
        c (c2), stride (stride2), shift (shift2), steps (steps2) {}
    inline ~inner_fft_task_rep () {}
    void execute () {
      if (direct)
        for (nat i=start; i<total; i+=inc)
          ffter->dfft (c + (i<<steps) * stride, stride,
                       shift + ((i << steps) >> 1), steps);
      else
        for (nat i=start; i<total; i+=inc)
          ffter->ifft (c + (i<<steps) * stride, stride,
                       shift + ((i << steps) >> 1), steps);
    }
  };

  void
  fft (bool direct, C* c, nat stride, nat shift, nat steps) {
    if (steps <= thr) {
      if (direct) ffter->dfft (c, stride, shift, steps);
      else ffter->ifft (c, stride, shift, steps);
    }
    else {
      nat half1= min (log_2 (threads_number), steps);
      nat half2= steps - half1;

      for (nat stage=0; stage<2; stage++) {
        if ((stage == 0) ^ (!direct)) {
          //C* cc1= c;
          //for (nat i=0; i<((nat) 1<<half2); i++, cc1 += stride)
          //  ffter->fft (direct, cc1, stride << half2,
          //              shift >> half2, half1);

          nat nt = threads_number;
          nat bsz= (min_reps << half1);
          C* buf= mmx_new<C> (nt * bsz);
          task tasks[nt];
          for (nat i=0; i<nt; i++) {
            nat tot= 1 << half2;
            nat sta= min_reps * (( i    * tot) / (min_reps * nt));
            nat end= min_reps * (((i+1) * tot) / (min_reps * nt));
            tasks[i]= new outer_fft_task_rep
              (ffter, direct, buf + i*bsz, end-sta, tot,
               c + sta*stride, stride, shift >> half2, half1);
          }
          threads_execute (tasks, nt);
          mmx_delete<C> (buf, nt * bsz);
        }
        else {
          //C* cc2= c;
          //for (nat i=0; i<((nat) 1<<half1); i++, cc2 += (stride << half2))
          //  ffter->fft (direct, cc2, stride,
          //              shift + ((i << half2) >> 1), half2);
          nat nt= threads_number;
          task tasks[nt];
          for (nat i=0; i<nt; i++)
            tasks[i]= new inner_fft_task_rep
              (ffter, direct, i, nt, 1 << half1, c, stride, shift, half2);
          threads_execute (tasks, nt);
        }
      }
    }
  }

  inline void
  dfft (C* c, nat stride, nat shift, nat steps) {
    if (steps <= thr) ffter->dfft (c, stride, shift, steps);
    else fft (true, c, stride, shift, steps);
  }

  inline void
  ifft (C* c, nat stride, nat shift, nat steps) {
    if (steps <= thr) ffter->ifft (c, stride, shift, steps);
    else fft (false, c, stride, shift, steps);
  }

  inline void
  dfft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
    if (step1 == 0 && step2 == steps && steps > thr)
      fft (true, c, stride, shift, steps);
    else ffter->dfft (c, stride, shift, steps, step1, step2);
  }

  inline void
  ifft (C* c, nat stride, nat shift, nat steps, nat step1, nat step2) {
    if (step1 == 0 && step2 == steps && steps > thr)
      fft (false, c, stride, shift, steps);
    else ffter->ifft (c, stride, shift, steps, step1, step2);
  }

  inline void
  direct_transform (C* c) {
    dfft (c, 1, 0, depth); }

  inline void
  inverse_transform (C* c, bool divide=true) {
    ifft (c, 1, 0, depth);
    if (divide) {
      S x= binpow (S (2), depth);
      x= invert (x);
      Vec::template vec_unary_scalar<typename R::fft_mul_sc_op> (c, x, len);
    }
  }
};

} // namespace mmx
#endif //__MMX__FFT_THREADS__HPP
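
In typical use, only the public direct_transform and inverse_transform methods of the class are called. The sketch below is not part of the library: it assumes a coefficient type C for which the default fft_blocks_transformer can be instantiated, a buffer holding exactly len coefficients of the constructed transformer, and it relies only on the interface shown in the listing above (the constructor argument n is simply forwarded to FFTER (n)); the helper name fft_roundtrip is hypothetical.

  #include <algebramix/fft_threads.hpp>

  namespace mmx {

  // Hypothetical helper, not part of algebramix: run a forward FFT followed
  // by the inverse FFT with division by 2^depth, so that buf is restored
  // (up to rounding or modular equivalence, depending on C).
  template<typename C> void
  fft_roundtrip (C* buf, nat n) {
    fft_threads_transformer<C> f (n);   // n is forwarded to FFTER (n)
    // buf is assumed to hold exactly f.len coefficients
    f.direct_transform (buf);           // in-place forward transform
    f.inverse_transform (buf, true);    // in-place inverse transform, scaled by 1/2^depth
  }

  } // namespace mmx

For transforms of more than thr (by default 14) steps, both calls dispatch to the multi-threaded fft method above; smaller transforms fall through directly to the underlying block transformer.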