algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : series_fast.hpp 00004 * DESCRIPTION: Fast univariate power series 00005 * COPYRIGHT : (C) 2006 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__SERIES_FAST__HPP 00014 #define __MMX__SERIES_FAST__HPP 00015 #include <algebramix/series.hpp> 00016 #include <algebramix/series_vector.hpp> 00017 #include <algebramix/fkt_transform.hpp> 00018 #include <algebramix/fft_naive.hpp> 00019 00020 namespace mmx { 00021 #define TMPL template<typename C,typename V> 00022 #define Series series<C,V> 00023 #define Series_rep series_rep<C,V> 00024 #define Vector vector<C> 00025 #define Series_vector series<Vector,V> 00026 #define Series_vector_rep series_rep<Vector,V> 00027 00028 /****************************************************************************** 00029 * Variant 00030 ******************************************************************************/ 00031 00032 struct series_fast {}; 00033 00034 template<typename F, typename V> 00035 struct implementation<F,V,series_fast>: 00036 public implementation<F,V,series_naive> {}; 00037 00038 /****************************************************************************** 00039 * Multiplication 00040 ******************************************************************************/ 00041 00042 template<typename U> 00043 struct implementation<series_multiply,U,series_fast>: 00044 public implementation<series_abstractions,U> 00045 { 00046 00047 // Setting up the different levels for fast multiplication 00048 #define TRANSFORM_NAIVE 0 00049 #define TRANSFORM_KARATSUBA 1 00050 #define TRANSFORM_FFT 2 00051 00052 template <typename C> 00053 struct level_info { 00054 MMX_ALLOCATORS 00055 typedef C* C_ptr; 00056 00057 nat n; // size of blocks 00058 nat k; // multiplier 00059 nat tsz; // size of transformed blocks 00060 nat type; // type of the transformation 00061 C_ptr head[2]; // buffer for transformed heads 00062 C_ptr tail[2]; // buffer for transformed tails 00063 00064 level_info () { 00065 head[0]= head[1]= NULL; 00066 tail[0]= tail[1]= NULL; 00067 } 00068 ~level_info () { 00069 if (head[0] != NULL) mmx_classical_delete<C> (head[0]); 00070 if (head[1] != NULL) mmx_classical_delete<C> (head[1]); 00071 if (tail[0] != NULL) mmx_classical_delete<C> (tail[0]); 00072 if (tail[1] != NULL) mmx_classical_delete<C> (tail[1]); 00073 } 00074 }; 00075 00076 static nat 00077 get_inside_multiplier (nat n) { 00078 if (n <= (1 << 7)) return 16; 00079 if (n <= (1 << 9)) return 4; 00080 if (n <= (1 << 10)) return 8; 00081 if (n <= (1 << 15)) return 16; 00082 if (n <= (1 << 23)) return 32; 00083 return 64; 00084 } 00085 00086 static nat 00087 get_border_multiplier (nat n) { 00088 if (n <= (1 << 4)) return 2; 00089 if (n <= (1 << 8)) return 4; 00090 if (n <= (1 << 13)) return 8; 00091 if (n <= (1 << 21)) return 16; 00092 return 32; 00093 } 00094 00095 /* Better for semi-relaxed multiplication 00096 static nat 00097 get_border_multiplier (nat n) { 00098 if (n <= (1 << 4)) return 2; 00099 if (n <= (1 << 6)) return 4; 00100 if (n <= (1 << 9)) return 8; 00101 if (n <= (1 << 17)) return 16; 00102 return 32; 00103 } 00104 */ 00105 00106 /* Better for (semi-)relaxed mult. over F_p and k-bit compl. floats 00107 static nat 00108 get_border_multiplier (nat n) { 00109 if (n <= (1 << 3)) return 2; 00110 if (n <= (1 << 4)) return 4; 00111 if (n <= (1 << 8)) return 8; 00112 if (n <= (1 << 16)) return 16; 00113 return 32; 00114 } 00115 */ 00116 00117 static vector<nat> 00118 determine_sizes (nat n) { 00119 if (n <= 16) return vector<nat> (1, 1); 00120 vector<nat> v; 00121 nat k= get_inside_multiplier (n); 00122 // nat k= get_inside_multiplier (n*2); for F_p and k-bit compl. floats 00123 n /= k; 00124 while (n != 1) { 00125 k= get_border_multiplier (n); 00126 v << k; 00127 n /= k; 00128 } 00129 v << 1; 00130 for (nat i=0; i<N(v)/2; i++) 00131 swap (v[i], v[N(v) - 1 - i]); 00132 return v; 00133 } 00134 00135 00136 // Fast multiplication 00137 //static nat mul_count= 0; 00138 00139 TMPL 00140 class nrelax_mul_series_rep: public Series_rep { 00141 public: 00142 typedef typename series_polynomial_helper<C,V >::PV PV; 00143 typedef implementation<polynomial_linear,PV> Pol; 00144 typedef fkt_package<polynomial_naive> Fkt; 00145 00146 protected: 00147 Series f[2]; // the series being multiplied 00148 nat sh[2]; // sh[i]=1 if f[1-i] is fast and 0 otherwise 00149 nat capacity[2]; // capacity for inner transforms 00150 nat nr_levels; // number of levels 00151 level_info<C>* info; // information about each level 00152 nat xnr_levels; // length of xinfo 00153 level_info<C>* xinfo; // storage of info 00154 00155 public: 00156 nrelax_mul_series_rep (const Series& f2, const Series& g2, nat n): 00157 Series_rep (CF(f2)) 00158 { 00159 f[0]= f2; f[1]= g2; 00160 sh[0]= 1; sh[1]= 1; 00161 //sh[0]= 1; sh[1]= 0; 00162 const vector<nat> v= determine_sizes (n); 00163 nr_levels= N(v); 00164 info= mmx_new<level_info<C> > (nr_levels); 00165 nat sz= 1; 00166 for (nat level=0; level< nr_levels; level++) { 00167 bool last= (level == nr_levels - 1); 00168 nat n= sz * v[level]; 00169 nat k= (last? 1: v[level+1]); 00170 info[level].n = n; 00171 info[level].k = k; 00172 nat tsz, type; 00173 if (v[level] == 1) { 00174 tsz = 1; 00175 type= TRANSFORM_NAIVE; 00176 } 00177 else if (v[level] == 2) { 00178 tsz = 3 * info[level-1].tsz; 00179 type= TRANSFORM_KARATSUBA; 00180 } 00181 else { 00182 tsz= 2 * n; 00183 type= TRANSFORM_FFT; 00184 } 00185 info[level].tsz = tsz; 00186 info[level].type= type; 00187 //mmout << "Level " << level << "\n"; 00188 //mmout << " " << "n = " << n << "\n"; 00189 //mmout << " " << "k = " << k << "\n"; 00190 //mmout << " " << "tsz = " << tsz << "\n"; 00191 //mmout << " " << "type= " << type << "\n"; 00192 if (last) { 00193 info[level].tail[0]= mmx_classical_new<C> (2 * tsz); 00194 info[level].tail[1]= mmx_classical_new<C> (2 * tsz); 00195 capacity[0]= 2; 00196 capacity[1]= 2; 00197 } 00198 else { 00199 if (sh[0] != 0) { 00200 info[level].head[0]= mmx_classical_new<C> (k * tsz); 00201 info[level].tail[1]= mmx_classical_new<C> (k * tsz); 00202 } 00203 if (sh[1] != 0) { 00204 info[level].head[1]= mmx_classical_new<C> (k * tsz); 00205 info[level].tail[0]= mmx_classical_new<C> (k * tsz); 00206 } 00207 } 00208 sz *= v[level]; 00209 //mmout << " " << "head= " << info[level].head[0] << ", " << info[level].head[1] << "\n"; 00210 //mmout << " " << "tail= " << info[level].tail[0] << ", " << info[level].tail[1] << "\n"; 00211 } 00212 xinfo= info; 00213 xnr_levels= nr_levels; 00214 if (sh[0] == 0 && sh[1] == 0) { 00215 info += nr_levels-1; 00216 nr_levels= 1; 00217 } 00218 } 00219 00220 ~nrelax_mul_series_rep () { 00221 this->l= allocated (this->l); // not very nice, but necessary correction 00222 mmx_delete<level_info<C> > (xinfo, xnr_levels); } 00223 00224 syntactic expression (const syntactic& z) const { 00225 return flatten (f[0], z) * flatten (f[1], z); } 00226 00227 nat allocated (nat l) { 00228 if (l == 0) return 0; 00229 else return l + 2 * info[nr_levels-1].n; } 00230 void Set_order (nat l2) { // not very nice, but OK... 00231 if (l2 <= this->l) return; 00232 nat old_allocated= allocated (this->l); 00233 nat new_allocated= allocated (l2); 00234 C* b= mmx_new<C> (new_allocated); 00235 Pol::copy (b, this->a, old_allocated); 00236 Pol::clear (b + old_allocated, new_allocated - old_allocated); 00237 mmx_delete<C> (this->a, old_allocated); 00238 this->a= b; 00239 this->l= l2; 00240 } 00241 00242 void Increase_order (nat l) { 00243 Series_rep::Increase_order (l); 00244 increase_order (f[0], l); 00245 increase_order (f[1], l); } 00246 00247 void direct_transform (C* dest, nat ld, const C* src, nat ls, nat type) { 00248 //mmout << " ld= " << ld << ", ls= " << ls << ", type= " << type << "\n"; 00249 if (type == TRANSFORM_NAIVE) dest[0]= src[0]; 00250 else if (type == TRANSFORM_KARATSUBA) { 00251 Pol::copy (dest, src, ls); 00252 Fkt::direct_fkt (dest, ls, ld); 00253 } 00254 else { 00255 Pol::copy (dest, src, ls); 00256 Pol::clear (dest + ls, ls); 00257 fft_naive_transformer<C> ffter (ld); 00258 ffter.direct_transform (dest); 00259 } 00260 } 00261 00262 void inverse_transform (C* dest, nat n, nat tsz, nat type) { 00263 if (type == TRANSFORM_NAIVE); 00264 else if (type == TRANSFORM_KARATSUBA) Fkt::inverse_fkt (dest, tsz, n); 00265 else { 00266 fft_naive_transformer<C> ffter (tsz); 00267 ffter.inverse_transform (dest); 00268 } 00269 } 00270 00271 void direct_transform (nat which) { 00272 for (nat level= 0; level < nr_levels; level++) { 00273 typedef C* C_ptr; 00274 nat n = info[level].n; 00275 nat tsh = sh[0] + sh[1]; 00276 nat cur = this->n + tsh - n * sh[1-which]; 00277 if (cur % n != 0) break; 00278 if (sh[1-which] != 0 && this->n + 1 < n) break; 00279 //mmout << " Transform " << cur/n << " (" << which << ") at level " << level << "\n"; 00280 bool last= (level == nr_levels - 1); 00281 nat k = info[level].k; 00282 nat tsz = info[level].tsz; 00283 nat type= info[level].type; 00284 C_ptr& head= info[level].head[which]; 00285 C_ptr& tail= info[level].tail[which]; 00286 nat Cur= cur / n; 00287 nat Mod= Cur % k; 00288 //mmout << " Mod= " << Mod << "\n"; 00289 if (f[which]->n < this->n + 1 && which == 1) { 00290 mmerr << "\n>>> n= " << this->n 00291 << " and only " << f[which]->n << " terms\n"; 00292 ASSERT (f[which]->n >= this->n + 1, "insufficient number of terms"); 00293 } 00294 const C* seg= (sh[1-which] == 0? 00295 f[which] (this->n, this->n + n): 00296 f[which] (this->n + 1 - n, this->n + 1)); 00297 //mmout << " seg= "; Pol::print (mmout, seg, n); mmout << "\n"; 00298 if (last) { 00299 nat Cap= capacity[which]; 00300 //mmout << " Cur= " << cur << ", Cap= " << Cap << "\n"; 00301 if (Cur >= Cap) { 00302 nat old_cap= Cap * tsz; 00303 nat new_cap= old_cap << 1; 00304 C* a= mmx_classical_new<C> (new_cap); 00305 Pol::copy (a, tail, old_cap); 00306 Pol::clear (a + old_cap, new_cap - old_cap); 00307 mmx_classical_delete<C> (tail); 00308 tail= a; 00309 capacity[which]= Cap << 1; 00310 //mmout << " New capacity= " << capacity[which] << "\n"; 00311 } 00312 //mmout << " -> tail " << Cur << "\n"; 00313 direct_transform (tail + Cur * tsz, tsz, seg, n, type); 00314 //mmout << " tr= "; Pol::print (mmout, tail + Cur * tsz, tsz); mmout << "\n"; 00315 } 00316 else if (Cur < k * sh[which]) { 00317 //mmout << " -> head " << Mod << "\n"; 00318 direct_transform (head + Mod * tsz, tsz, seg, n, type); 00319 //mmout << " tr= "; Pol::print (mmout, head + Mod * tsz, tsz); mmout << "\n"; 00320 } 00321 else if (tail != NULL) { 00322 //mmout << " -> tail " << Mod << "\n"; 00323 direct_transform (tail + Mod * tsz, tsz, seg, n, type); 00324 //mmout << " tr= "; Pol::print (mmout, tail + Mod * tsz, tsz); mmout << "\n"; 00325 } 00326 } 00327 } 00328 00329 inline void accumulate (C* dest, const C* s1, const C* s2, nat len) { 00330 //mul_count += len; 00331 //mmout << " s1= "; Pol::print (mmout, s1, len); mmout << "\n"; 00332 //mmout << " s2= "; Pol::print (mmout, s2, len); mmout << "\n"; 00333 Pol::mul_add (dest, s1, s2, len); 00334 //mmout << " d = "; Pol::print (mmout, dest, len); mmout << "\n"; 00335 } 00336 00337 C next () { 00338 //mmout << "[" << this->n << "]" << flush_now; 00339 //mmout << "Coefficient: " << this->n << "\n"; 00340 (void) f[0][this->n]; 00341 (void) f[1][this->n]; 00342 direct_transform (0); 00343 direct_transform (1); 00344 for (nat level= 0; level < nr_levels; level++) { 00345 nat tsh = sh[0] + sh[1]; 00346 nat cur = this->n + tsh; 00347 nat n = info[level].n; 00348 if (cur % n != 0) break; 00349 nat Cur = cur/n; 00350 if (Cur < tsh) break; 00351 //mmout << " Product at level " << level << "\n"; 00352 bool last= (level == nr_levels - 1); 00353 nat msh = min (sh[0], sh[1]); 00354 nat k = info[level].k; 00355 nat tsz = info[level].tsz; 00356 nat type= info[level].type; 00357 const C* h0 = info[level].head[0]; 00358 const C* h1 = info[level].head[1]; 00359 const C* t0 = info[level].tail[0]; 00360 const C* t1 = info[level].tail[1]; 00361 nat Mod = Cur % k; 00362 00363 //mmout << " tsz= " << tsz << ", l= " << 2*n-1 << "\n"; 00364 C* acc= mmx_new<C> (tsz); 00365 Pol::clear (acc, tsz); 00366 if (last) 00367 for (nat i=sh[0]; i<=Cur-sh[1]; i++) 00368 accumulate (acc, t0 + i*tsz, t1 + (Cur-i)*tsz, tsz); 00369 else { 00370 if (Cur < 2 * msh * k) { 00371 nat start= 1 ; if (Cur > k) start= Cur - k + 1; 00372 nat end = Cur - 1; if (Cur > k) end = k - 1; 00373 for (nat i=start; i<=end; i++) 00374 accumulate (acc, h0 + i*tsz, h1 + (Cur-i)*tsz, tsz); 00375 } 00376 if (Cur >= msh * k && Mod != 0) { 00377 //mmout << " Normal\n"; 00378 if (sh[0] != 0) 00379 for (nat i=1; i<=Mod; i++) 00380 accumulate (acc, h0 + i*tsz, t1 + (Mod-i)*tsz, tsz); 00381 if (sh[1] != 0) 00382 for (nat i=0; i<=Mod-1; i++) 00383 accumulate (acc, t0 + i*tsz, h1 + (Mod-i)*tsz, tsz); 00384 } 00385 } 00386 inverse_transform (acc, n, tsz, type); 00387 //mmout << " acc= "; Pol::print (mmout, acc, 2*n-1); mmout << "\n"; 00388 Pol::add (this->a + this->n, acc, 2*n - 1); 00389 00390 if (Cur >= msh * k && Mod == 0 && !last) { 00391 //mmout << " Extra " << k << "\n"; 00392 for (nat Mod= 0; Mod<k-1; Mod++) { 00393 Pol::clear (acc, tsz); 00394 for (nat i=Mod+1; i<k; i++) { 00395 if (sh[0] != 0) 00396 accumulate (acc, h0 + i*tsz, t1 + (k+Mod-i)*tsz, tsz); 00397 if (sh[1] != 0) 00398 accumulate (acc, t0 + i*tsz, h1 + (k+Mod-i)*tsz, tsz); 00399 } 00400 inverse_transform (acc, n, tsz, type); 00401 Pol::add (this->a + this->n + Mod*n, acc, 2*n - 1); 00402 } 00403 } 00404 mmx_delete<C> (acc, tsz); 00405 } 00406 return this->a [this->n]; 00407 } 00408 }; 00409 00410 TMPL static Series 00411 nrelax_mul (const Series& f, const Series& g, nat n) { 00412 return (Series_rep*) new nrelax_mul_series_rep<C,V> (f, g, n); 00413 } 00414 00415 // Top level interface for fast multiplication 00416 00417 TMPL 00418 class mul_series_rep: public 00419 implementation<series_abstractions,V> 00420 ::template binary_series_rep<mul_op,C,V> { 00421 protected: 00422 Series prod; 00423 nat N; 00424 public: 00425 inline mul_series_rep (const Series& f, const Series& g): 00426 implementation<series_abstractions,V> 00427 ::template binary_series_rep<mul_op,C,V > (f, g), 00428 prod (nrelax_mul (f, g, 1)), N (1) {} 00429 C next () { return prod [this->n]; } 00430 void Increase_order (nat l) { 00431 Series_rep::Increase_order (l); 00432 increase_order (prod, l); 00433 if (l < 2*N) return; 00434 while (l >= 2*N) N= 2*N; 00435 //while (l >= 2*N) N= 4*N; 00436 prod= nrelax_mul (this->f, this->g, N); 00437 } 00438 }; 00439 00440 TMPL static inline Series 00441 ser_mul (const Series& f, const Series& g) { 00442 typedef mul_series_rep<C,V> Mul_rep; 00443 if (is_exact_zero (f) || is_exact_zero (g)) 00444 return Series (CF(f)); 00445 return (Series_rep*) new Mul_rep (f, g); } 00446 00447 TMPL static inline Series 00448 ser_truncate_mul (const Series& f, const Series& g, nat nf, nat ng) { 00449 typedef mul_series_rep<C,V> Mul_rep; 00450 if (is_exact_zero (f) || is_exact_zero (g) || nf == 0 || ng == 0) 00451 return Series (CF(f)); 00452 return (Series_rep*) 00453 new Mul_rep (piecewise (f, Series (CF(f)), nf), 00454 piecewise (g, Series (CF(g)), ng)); } 00455 00456 }; // implementation<series_multiply,U,series_relaxed> 00457 00458 #undef TMPL 00459 #undef Series 00460 #undef Series_rep 00461 #undef Vector 00462 #undef Series_vector 00463 #undef Series_vector_rep 00464 } // namespace mmx 00465 #endif // __MMX__SERIES_FAST__HPP