algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/series_fast.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : series_fast.hpp
00004 * DESCRIPTION: Fast univariate power series
00005 * COPYRIGHT  : (C) 2006  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__SERIES_FAST__HPP
00014 #define __MMX__SERIES_FAST__HPP
00015 #include <algebramix/series.hpp>
00016 #include <algebramix/series_vector.hpp>
00017 #include <algebramix/fkt_transform.hpp>
00018 #include <algebramix/fft_naive.hpp>
00019 
00020 namespace mmx {
// Shorthands used throughout this header; every one of them is #undef'd
// again at the end of the file.
// NOTE(review): Vector, Series_vector and Series_vector_rep do not appear to
// be used in this header -- confirm before removing them.
#define TMPL template<typename C,typename V>
#define Series series<C,V>
#define Series_rep series_rep<C,V>
#define Vector vector<C>
#define Series_vector series<Vector,V>
#define Series_vector_rep series_rep<Vector,V>
00027 
00028 /******************************************************************************
00029 * Variant
00030 ******************************************************************************/
00031 
00032 struct series_fast {};
00033 
00034 template<typename F, typename V>
00035 struct implementation<F,V,series_fast>:
00036     public implementation<F,V,series_naive> {};
00037 
00038 /******************************************************************************
00039 * Multiplication
00040 ******************************************************************************/
00041 
00042 template<typename U>
00043 struct implementation<series_multiply,U,series_fast>:
00044   public implementation<series_abstractions,U>
00045 {
00046 
// Setting up the different levels for fast multiplication.
// Each level of nrelax_mul_series_rep multiplies its blocks through one of
// these transforms; the chosen value is stored in level_info::type.
// NOTE(review): unlike TMPL, Series, ... these macros are not #undef'd at the
// end of this header, so they remain visible in every including file.
#define TRANSFORM_NAIVE      0
#define TRANSFORM_KARATSUBA  1
#define TRANSFORM_FFT        2
00051 
00052 template <typename C>
00053 struct level_info {
00054 MMX_ALLOCATORS
00055   typedef C* C_ptr;
00056 
00057   nat n;          // size of blocks
00058   nat k;          // multiplier
00059   nat tsz;        // size of transformed blocks
00060   nat type;       // type of the transformation
00061   C_ptr head[2];  // buffer for transformed heads
00062   C_ptr tail[2];  // buffer for transformed tails
00063 
00064   level_info () {
00065     head[0]= head[1]= NULL;
00066     tail[0]= tail[1]= NULL;
00067   }
00068   ~level_info () {
00069     if (head[0] != NULL) mmx_classical_delete<C> (head[0]);
00070     if (head[1] != NULL) mmx_classical_delete<C> (head[1]);
00071     if (tail[0] != NULL) mmx_classical_delete<C> (tail[0]);
00072     if (tail[1] != NULL) mmx_classical_delete<C> (tail[1]);
00073   }
00074 };
00075 
00076 static nat
00077 get_inside_multiplier (nat n) {
00078   if (n <= (1 << 7)) return 16;
00079   if (n <= (1 << 9)) return 4;
00080   if (n <= (1 << 10)) return 8;
00081   if (n <= (1 << 15)) return 16;
00082   if (n <= (1 << 23)) return 32;
00083   return 64;
00084 }
00085 
00086 static nat
00087 get_border_multiplier (nat n) {
00088   if (n <= (1 << 4)) return 2;
00089   if (n <= (1 << 8)) return 4;
00090   if (n <= (1 << 13)) return 8;
00091   if (n <= (1 << 21)) return 16;
00092   return 32;
00093 }
00094 
00095   /* Better for semi-relaxed multiplication
00096 static nat
00097 get_border_multiplier (nat n) {
00098   if (n <= (1 << 4)) return 2;
00099   if (n <= (1 << 6)) return 4;
00100   if (n <= (1 << 9)) return 8;
00101   if (n <= (1 << 17)) return 16;
00102   return 32;
00103 }
00104   */
00105 
00106   /* Better for (semi-)relaxed mult. over F_p and k-bit compl. floats
00107 static nat
00108 get_border_multiplier (nat n) {
00109   if (n <= (1 << 3)) return 2;
00110   if (n <= (1 << 4)) return 4;
00111   if (n <= (1 << 8)) return 8;
00112   if (n <= (1 << 16)) return 16;
00113   return 32;
00114 }
00115   */
00116 
00117 static vector<nat>
00118 determine_sizes (nat n) {
00119   if (n <= 16) return vector<nat> (1, 1);
00120   vector<nat> v;
00121   nat k= get_inside_multiplier (n);
00122   // nat k= get_inside_multiplier (n*2); for F_p and k-bit compl. floats
00123   n /= k;
00124   while (n != 1) {
00125     k= get_border_multiplier (n);
00126     v << k;
00127     n /= k;
00128   }
00129   v << 1;
00130   for (nat i=0; i<N(v)/2; i++)
00131     swap (v[i], v[N(v) - 1 - i]);
00132   return v;
00133 }
00134 
00135 
00136 // Fast multiplication
00137 //static nat mul_count= 0;
00138 
// Series representation of the product f2 * g2, computed blockwise over a
// hierarchy of levels whose block sizes come from determine_sizes (n).
// Blocks of both factors are transformed (naive copy, Karatsuba/FKT or FFT,
// depending on the level), products are accumulated pointwise on the
// transformed blocks, and the inverse transforms are added into this->a.
TMPL
class nrelax_mul_series_rep: public Series_rep {
public:
  typedef typename series_polynomial_helper<C,V >::PV PV;
  typedef implementation<polynomial_linear,PV> Pol;
  typedef fkt_package<polynomial_naive> Fkt;
  
protected:
  Series f[2];           // the series being multiplied
  nat sh[2];             // sh[i]=1 if f[1-i] is fast and 0 otherwise
  nat capacity[2];       // capacity for inner transforms
  nat nr_levels;         // number of levels
  level_info<C>* info;   // information about each level
  nat xnr_levels;        // length of xinfo
  level_info<C>* xinfo;  // storage of info

public:
  // Sets up the level table for a product sized by n and allocates the
  // transform buffers of every level.
  nrelax_mul_series_rep (const Series& f2, const Series& g2, nat n):
    Series_rep (CF(f2))
  {
    f[0]= f2; f[1]= g2;
    sh[0]= 1; sh[1]= 1;
    //sh[0]= 1; sh[1]= 0;
    const vector<nat> v= determine_sizes (n);
    nr_levels= N(v);
    info= mmx_new<level_info<C> > (nr_levels);
    nat sz= 1;
    for (nat level=0; level< nr_levels; level++) {
      bool last= (level == nr_levels - 1);
      // Block size of this level (shadows the constructor argument n).
      nat n= sz * v[level];
      nat k= (last? 1: v[level+1]);
      info[level].n = n;
      info[level].k = k;
      nat tsz, type;
      if (v[level] == 1) {
        tsz = 1;
        type= TRANSFORM_NAIVE;
      }
      else if (v[level] == 2) {
        // v[0] is always 1 (see determine_sizes), so level >= 1 here.
        tsz = 3 * info[level-1].tsz;
        type= TRANSFORM_KARATSUBA;
      }
      else {
        // FFT on zero-padded blocks: twice the block size.
        tsz= 2 * n;
        type= TRANSFORM_FFT;
      }
      info[level].tsz = tsz;
      info[level].type= type;
      //mmout << "Level " << level << "\n";
      //mmout << "  " << "n   = " << n << "\n";
      //mmout << "  " << "k   = " << k << "\n";
      //mmout << "  " << "tsz = " << tsz << "\n";
      //mmout << "  " << "type= " << type << "\n";
      if (last) {
        // Last level: room for 2 transformed blocks per factor; the tails are
        // grown on demand in direct_transform (nat).
        info[level].tail[0]= mmx_classical_new<C> (2 * tsz);
        info[level].tail[1]= mmx_classical_new<C> (2 * tsz);
        capacity[0]= 2;
        capacity[1]= 2;
      }
      else {
        // Other levels: k transformed blocks per buffer.
        if (sh[0] != 0) {
          info[level].head[0]= mmx_classical_new<C> (k * tsz);
          info[level].tail[1]= mmx_classical_new<C> (k * tsz);
        }
        if (sh[1] != 0) {
          info[level].head[1]= mmx_classical_new<C> (k * tsz);
          info[level].tail[0]= mmx_classical_new<C> (k * tsz);
        }
      }
      sz *= v[level];
      //mmout << "  " << "head= " << info[level].head[0] << ", " << info[level].head[1] << "\n";
      //mmout << "  " << "tail= " << info[level].tail[0] << ", " << info[level].tail[1] << "\n";
    }
    // Keep the original allocation so the destructor can release it even if
    // info is advanced below.
    xinfo= info;
    xnr_levels= nr_levels;
    if (sh[0] == 0 && sh[1] == 0) {
      // Only the last level is used when neither factor is shifted.
      info += nr_levels-1;
      nr_levels= 1;
    }
  }

  ~nrelax_mul_series_rep () {
    // Set_order allocates allocated (l) coefficients but records only l;
    // presumably the base destructor frees this->a using this->l, hence the fix-up.
    this->l= allocated (this->l); // not very nice, but necessary correction
    mmx_delete<level_info<C> > (xinfo, xnr_levels); }

  syntactic expression (const syntactic& z) const {
    return flatten (f[0], z) * flatten (f[1], z); }

  // Number of coefficients actually allocated for order l: l plus twice the
  // last-level block size, presumably so that the 2n-1 coefficient writes
  // done in next () ahead of this->n stay in bounds.
  nat allocated (nat l) {
    if (l == 0) return 0;
    else return l + 2 * info[nr_levels-1].n; }
  // Reallocates this->a to allocated (l2) coefficients, keeping the old ones
  // and zero-filling the rest.
  void Set_order (nat l2) { // not very nice, but OK...
    if (l2 <= this->l) return;
    nat old_allocated= allocated (this->l);
    nat new_allocated= allocated (l2);
    C* b= mmx_new<C> (new_allocated);
    Pol::copy (b, this->a, old_allocated);
    Pol::clear (b + old_allocated, new_allocated - old_allocated);
    mmx_delete<C> (this->a, old_allocated);
    this->a= b;
    this->l= l2;
  }

  // Raises the order of this series and of both factors.
  void Increase_order (nat l) {
    Series_rep::Increase_order (l);
    increase_order (f[0], l);
    increase_order (f[1], l); }

  // Transforms the ls coefficients at src into ld transformed coefficients
  // at dest.  The FFT branch zero-pads src to 2*ls entries, which matches
  // ld == tsz == 2*n as set up in the constructor.
  void direct_transform (C* dest, nat ld, const C* src, nat ls, nat type) {
    //mmout << "    ld= " << ld << ", ls= " << ls << ", type= " << type << "\n";
    if (type == TRANSFORM_NAIVE) dest[0]= src[0];
    else if (type == TRANSFORM_KARATSUBA) {
      Pol::copy (dest, src, ls);
      Fkt::direct_fkt (dest, ls, ld);
    }
    else {
      Pol::copy (dest, src, ls);
      Pol::clear (dest + ls, ls);
      fft_naive_transformer<C> ffter (ld);
      ffter.direct_transform (dest);
    }
  }

  // In-place inverse of direct_transform for a block of size n whose
  // transform has tsz entries; nothing to undo for the naive level.
  void inverse_transform (C* dest, nat n, nat tsz, nat type) {
    if (type == TRANSFORM_NAIVE);
    else if (type == TRANSFORM_KARATSUBA) Fkt::inverse_fkt (dest, tsz, n);
    else {
      fft_naive_transformer<C> ffter (tsz);
      ffter.inverse_transform (dest);
    }
  }

  // For every level at which the current coefficient completes a block of
  // f[which], transform that block and store it: at the last level in the
  // growing tail buffer (indexed by block number Cur), otherwise in the head
  // buffer for the first k*sh[which] blocks and in the tail buffer
  // (indexed cyclically by Mod = Cur % k) afterwards.
  void direct_transform (nat which) {
    for (nat level= 0; level < nr_levels; level++) {
      typedef C* C_ptr;
      nat    n   = info[level].n;
      nat    tsh = sh[0] + sh[1];
      nat    cur = this->n + tsh - n * sh[1-which];
      if (cur % n != 0) break;
      if (sh[1-which] != 0 && this->n + 1 < n) break;
      //mmout << "  Transform " << cur/n << " (" << which << ") at level " << level << "\n";
      bool   last= (level == nr_levels - 1);
      nat    k   = info[level].k;
      nat    tsz = info[level].tsz;
      nat    type= info[level].type;
      C_ptr& head= info[level].head[which];
      C_ptr& tail= info[level].tail[which];
      nat Cur= cur / n;
      nat Mod= Cur % k;
      //mmout << "    Mod= " << Mod << "\n";
      if (f[which]->n < this->n + 1 && which == 1) {
        mmerr << "\n>>> n= " << this->n
             << " and only " << f[which]->n << " terms\n";
        ASSERT (f[which]->n >= this->n + 1, "insufficient number of terms");
      }
      // The block of n coefficients to transform: starting at this->n when
      // the other factor is not shifted, otherwise ending at this->n.
      const C* seg= (sh[1-which] == 0?
                     f[which] (this->n, this->n + n):
                     f[which] (this->n + 1 - n, this->n + 1));
      //mmout << "    seg= "; Pol::print (mmout, seg, n); mmout << "\n";
      if (last) {
        nat Cap= capacity[which];
        //mmout << "    Cur= " << cur << ", Cap= " << Cap << "\n";
        if (Cur >= Cap) {
          // Double the capacity of the last-level tail, keeping its contents.
          nat old_cap= Cap * tsz;
          nat new_cap= old_cap << 1;
          C* a= mmx_classical_new<C> (new_cap);
          Pol::copy (a, tail, old_cap);
          Pol::clear (a + old_cap, new_cap - old_cap);
          mmx_classical_delete<C> (tail);
          tail= a;
          capacity[which]= Cap << 1;
          //mmout << "    New capacity= " << capacity[which] << "\n";
        }
        //mmout << "    -> tail " << Cur << "\n";
        direct_transform (tail + Cur * tsz, tsz, seg, n, type);
        //mmout << "    tr= "; Pol::print (mmout, tail + Cur * tsz, tsz); mmout << "\n";
      }
      else if (Cur < k * sh[which]) {
        //mmout << "    -> head " << Mod << "\n";
        direct_transform (head + Mod * tsz, tsz, seg, n, type);
        //mmout << "    tr= "; Pol::print (mmout, head + Mod * tsz, tsz); mmout << "\n";
      }
      else if (tail != NULL) {
        //mmout << "    -> tail " << Mod << "\n";
        direct_transform (tail + Mod * tsz, tsz, seg, n, type);
        //mmout << "    tr= "; Pol::print (mmout, tail + Mod * tsz, tsz); mmout << "\n";
      }
    }
  }

  // dest[i] += s1[i] * s2[i] for i < len (pointwise product of transforms).
  inline void accumulate (C* dest, const C* s1, const C* s2, nat len) {
    //mul_count += len;
    //mmout << "      s1= "; Pol::print (mmout, s1, len); mmout << "\n";
    //mmout << "      s2= "; Pol::print (mmout, s2, len); mmout << "\n";
    Pol::mul_add (dest, s1, s2, len);
    //mmout << "      d = "; Pol::print (mmout, dest, len); mmout << "\n";
  }

  // Produces coefficient this->n of the product.  First forces coefficient
  // this->n of both factors and transforms newly completed blocks.  Then,
  // for each level whose block boundary is reached, it accumulates pointwise
  // products of transformed blocks, inverse-transforms them and adds the
  // resulting 2n-1 coefficients into this->a starting at this->n.
  C next () {
    //mmout << "[" << this->n << "]" << flush_now;
    //mmout << "Coefficient: " << this->n << "\n";
    (void) f[0][this->n];
    (void) f[1][this->n];
    direct_transform (0);
    direct_transform (1);
    for (nat level= 0; level < nr_levels; level++) {
      nat   tsh = sh[0] + sh[1];
      nat   cur = this->n + tsh;
      nat   n   = info[level].n;
      if (cur % n != 0) break;
      nat   Cur = cur/n;
      if (Cur < tsh) break;
      //mmout << "  Product at level " << level << "\n";
      bool  last= (level == nr_levels - 1);
      nat   msh = min (sh[0], sh[1]);
      nat   k   = info[level].k;
      nat   tsz = info[level].tsz;
      nat   type= info[level].type;
      const C* h0 = info[level].head[0];
      const C* h1 = info[level].head[1];
      const C* t0 = info[level].tail[0];
      const C* t1 = info[level].tail[1];
      nat   Mod = Cur % k;

      //mmout << "    tsz= " << tsz << ", l= " << 2*n-1 << "\n";
      C* acc= mmx_new<C> (tsz);
      Pol::clear (acc, tsz);
      if (last)
        // Last level: all pairs of blocks whose indices add up to Cur.
        for (nat i=sh[0]; i<=Cur-sh[1]; i++)
          accumulate (acc, t0 + i*tsz, t1 + (Cur-i)*tsz, tsz);
      else {
        // Products of two head blocks.
        if (Cur < 2 * msh * k) {
          nat start= 1      ; if (Cur > k) start= Cur - k + 1;
          nat end  = Cur - 1; if (Cur > k) end  = k - 1;
          for (nat i=start; i<=end; i++)
            accumulate (acc, h0 + i*tsz, h1 + (Cur-i)*tsz, tsz);
        }
        // Head blocks of one factor against tail blocks of the other.
        if (Cur >= msh * k && Mod != 0) {
          //mmout << "    Normal\n";
          if (sh[0] != 0)
            for (nat i=1; i<=Mod; i++)
              accumulate (acc, h0 + i*tsz, t1 + (Mod-i)*tsz, tsz);
          if (sh[1] != 0)
            for (nat i=0; i<=Mod-1; i++)
              accumulate (acc, t0 + i*tsz, h1 + (Mod-i)*tsz, tsz);
        }
      }
      inverse_transform (acc, n, tsz, type);
      //mmout << "    acc= "; Pol::print (mmout, acc, 2*n-1); mmout << "\n";
      Pol::add (this->a + this->n, acc, 2*n - 1);

      // At a multiple of k, add the k-1 remaining head x tail contributions,
      // each shifted by Mod*n (this inner Mod shadows the outer one).
      if (Cur >= msh * k && Mod == 0 && !last) {
        //mmout << "    Extra " << k << "\n";
        for (nat Mod= 0; Mod<k-1; Mod++) {
          Pol::clear (acc, tsz);
          for (nat i=Mod+1; i<k; i++) {
            if (sh[0] != 0)
              accumulate (acc, h0 + i*tsz, t1 + (k+Mod-i)*tsz, tsz);
            if (sh[1] != 0)
              accumulate (acc, t0 + i*tsz, h1 + (k+Mod-i)*tsz, tsz);
          }
          inverse_transform (acc, n, tsz, type);
          Pol::add (this->a + this->n + Mod*n, acc, 2*n - 1);
        }
      }
      mmx_delete<C> (acc, tsz);
    }
    return this->a [this->n];
  }
};
00409 
00410 TMPL static Series
00411 nrelax_mul (const Series& f, const Series& g, nat n) {
00412   return (Series_rep*) new nrelax_mul_series_rep<C,V> (f, g, n);
00413 }
00414 
00415 // Top level interface for fast multiplication
00416 
00417 TMPL
00418 class mul_series_rep: public 
00419   implementation<series_abstractions,V>
00420     ::template binary_series_rep<mul_op,C,V> {
00421 protected:
00422   Series prod;
00423   nat    N;
00424 public:
00425   inline mul_series_rep (const Series& f, const Series& g):
00426     implementation<series_abstractions,V>
00427       ::template binary_series_rep<mul_op,C,V > (f, g),
00428     prod (nrelax_mul (f, g, 1)), N (1) {}
00429   C next () { return prod [this->n]; }
00430   void Increase_order (nat l) {
00431     Series_rep::Increase_order (l);
00432     increase_order (prod, l);
00433     if (l < 2*N) return;
00434     while (l >= 2*N) N= 2*N;
00435     //while (l >= 2*N) N= 4*N;
00436     prod= nrelax_mul (this->f, this->g, N);
00437   }
00438 };
00439 
00440 TMPL static inline Series
00441 ser_mul (const Series& f, const Series& g) {
00442   typedef mul_series_rep<C,V> Mul_rep;
00443   if (is_exact_zero (f) || is_exact_zero (g))
00444     return Series (CF(f));
00445   return (Series_rep*) new Mul_rep (f, g); }
00446 
00447 TMPL static inline Series
00448 ser_truncate_mul (const Series& f, const Series& g, nat nf, nat ng) {
00449   typedef mul_series_rep<C,V> Mul_rep;
00450   if (is_exact_zero (f) || is_exact_zero (g) || nf == 0 || ng == 0)
00451     return Series (CF(f));
00452   return (Series_rep*)
00453     new Mul_rep (piecewise (f, Series (CF(f)), nf),
00454                  piecewise (g, Series (CF(g)), ng)); }
00455 
}; // implementation<series_multiply,U,series_fast>
00457 
// Release the shorthand macros defined at the top of this header.
// NOTE(review): TRANSFORM_NAIVE/KARATSUBA/FFT are not released here and stay
// defined for including files -- confirm whether other headers rely on them.
#undef TMPL
#undef Series
#undef Series_rep
#undef Vector
#undef Series_vector
#undef Series_vector_rep
00464 } // namespace mmx
00465 #endif // __MMX__SERIES_FAST__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines