algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/polynomial_dicho.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : polynomial_dicho.hpp
00004 * DESCRIPTION: dichotomic algorithms including Karatsuba multiplication
00005 * COPYRIGHT  : (C) 2003  Joris van der Hoeven and Gregoire Lecerf
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__POLYNOMIAL_DICHO__HPP
00014 #define __MMX__POLYNOMIAL_DICHO__HPP
00015 #include <basix/vector_sort.hpp>
00016 #include <algebramix/polynomial_naive.hpp>
00017 #include <algebramix/crt_polynomial.hpp>
00018 
00019 namespace mmx {
// Local shorthands: TMPL introduces a template over the coefficient type C,
// TMPLP a template over a polynomial type, and Vector abbreviates vector<C>.
#define TMPL  template<typename C>
#define TMPLP template <typename Polynomial>
#define Vector vector<C>
00023 
00024 /******************************************************************************
00025 * Variants for dichotomic algorithms on polynomials
00026 ******************************************************************************/
00027 
// Variant tag enabling Karatsuba multiplication on top of the variant V.
// Everything else is inherited from V; the derived sub-variants re-wrap
// V's sub-variants so that the Karatsuba layer survives such switches.
template<typename V>
struct polynomial_karatsuba: public V {
  // Forwarded unchanged from V.
  typedef typename V::Naive Naive;
  typedef typename V::Vec Vec;
  // Sub-variants keep the Karatsuba wrapper.
  typedef polynomial_karatsuba<typename V::No_scaled> No_scaled;
  typedef polynomial_karatsuba<typename V::No_thread> No_thread;
  typedef polynomial_karatsuba<typename V::No_simd>   No_simd;
  typedef polynomial_karatsuba<typename V::Positive>  Positive;
};
00037 
00038 template<typename F, typename V, typename W>
00039 struct implementation<F,V,polynomial_karatsuba<W> >:
00040     public implementation<F,V,W> {};
00041 
// Variant tag enabling the dichotomic (divide and conquer) algorithms for
// division, gcd and evaluation on top of the variant V.  The derived
// sub-variants re-wrap V's sub-variants to keep the dichotomic layer.
template<typename V>
struct polynomial_dicho: public V {
  // Forwarded unchanged from V.
  typedef typename V::Naive Naive;
  typedef typename V::Vec Vec;
  // Sub-variants keep the dichotomic wrapper.
  typedef polynomial_dicho<typename V::No_scaled> No_scaled;
  typedef polynomial_dicho<typename V::No_thread> No_thread;
  typedef polynomial_dicho<typename V::No_simd>   No_simd;
  typedef polynomial_dicho<typename V::Positive>  Positive;
};
00051 
00052 template<typename F, typename V, typename W>
00053 struct implementation<F,V,polynomial_dicho<W> >:
00054   public implementation<F,V,W> {};
00055 
00056 /******************************************************************************
00057 * Multiplication
00058 ******************************************************************************/
00059 
// Empty tag type used as the key of the Threshold(C, ...) lookup that
// decides when Karatsuba multiplication falls back to the base variant.
template<typename V>
struct polynomial_multiply_threshold {
};
00062 
00063 template<typename V, typename W>
00064 struct implementation<polynomial_multiply,V,polynomial_karatsuba<W> >:
00065   public implementation<polynomial_linear,V>
00066 {
00067   typedef polynomial_multiply_threshold<polynomial_karatsuba<W> > Th;
00068   typedef implementation<vector_linear,V> Vec;
00069   typedef implementation<polynomial_linear,W> Pol;
00070   typedef implementation<polynomial_multiply,W> Fallback;
00071 
00072 TMPL static void
00073 multiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00074   if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00075     Fallback::mul (dest, s1, s2, n1, n2);
00076   else {
00077     nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1;
00078     nat spc= aligned_size<C,V> (3*(p1+p2));
00079     C* buf= mmx_new<C> (spc);
00080     C* low1= dest;
00081     C* low2= low1 + p1;
00082     C* mid1= buf;
00083     C* mid2= buf + p1;
00084     C* hi1 = mid2 + p2;
00085     C* hi2 = hi1 + p1;
00086     C* Low = mid1;
00087     C* Mid = hi1;
00088     C* Hi  = Mid + p1 + p2;
00089     Vec::half_copy (low1 , s1  , p1);
00090     Vec::half_copy (hi1  , s1+1, p1);
00091     Pol::add       (mid1 , low1, hi1 , p1);
00092     Vec::half_copy (low2 , s2  , p2);
00093     Vec::half_copy (hi2  , s2+1, p2);
00094     Pol::add       (mid2 , low2, hi2 , p2);
00095     multiply       (Hi   , hi1 , hi2 , p1, p2);
00096     multiply       (Mid  , mid1, mid2, p1, p2);
00097     multiply       (Low  , low1, low2, p1, p2);
00098     Pol::sub       (Mid  , Low , P);
00099     Pol::sub       (Mid  , Hi  , P);
00100     Pol::add       (Low+1, Hi  , P-1);
00101     Low[P]= Hi[P-1];
00102     Vec::double_copy (dest  , Low, P+1);
00103     Vec::double_copy (dest+1, Mid, P);
00104     mmx_delete<C> (buf, spc);
00105 
00106     if ((n1 & 1) != 0) {
00107       dest[(P<<1)+1]= C(0);
00108       Pol::mul_add (dest + (p1<<1), s2, s1[p1<<1], p2<<1);
00109     }
00110     if ((n2 & 1) != 0) {
00111       dest[n1+n2-2]= C(0);
00112       Pol::mul_add (dest + (p2<<1), s1, s2[p2<<1], n1);
00113     }
00114   }
00115 }
00116 
00117 TMPL static inline void
00118 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00119   if (n1 == 0 && n2 == 0) return;
00120   if (n1 == 0 || n2 == 0)
00121     Pol::clear (dest, n1 + n2 - 1);
00122   else
00123     multiply (dest, s1, s2, n1, n2);
00124 }
00125 
00126 TMPL static void
00127 square (C* dest, const C* s, nat n) {
00128   if (n == 0) return;
00129   if (n < Threshold(C,Th))
00130     Fallback::square (dest, s, n);
00131   else {
00132     nat p= n >> 1, P= 2*p-1, spc= aligned_size<C,V> (6*p);
00133     C* buf= mmx_new<C> (spc);
00134     C* low= dest;
00135     C* mid= buf;
00136     C* hi = buf + 2*p;
00137     C* Low= buf;
00138     C* Mid= hi;
00139     C* Hi = Mid + 2*p;
00140     Vec::half_copy (low  , s  , p);
00141     Vec::half_copy (hi   , s+1, p);
00142     Pol::add       (mid  , low, hi, p);
00143     square         (Hi   , hi , p);
00144     square         (Mid  , mid, p);
00145     square         (Low  , low, p);
00146     Pol::sub       (Mid  , Low, P);
00147     Pol::sub       (Mid  , Hi , P);
00148     Pol::add       (Low+1, Hi , P-1);
00149     Low[P]= Hi[P-1];
00150     Vec::double_copy (dest  , Low, P+1);
00151     Vec::double_copy (dest+1, Mid, P);
00152     mmx_delete<C> (buf, spc);
00153 
00154     if ((n & 1) != 0) {
00155       dest[(P<<1)+1]= dest[2*n-2]= C(0);
00156       Pol::mul_add (dest + (p<<1), s, s[p<<1], p<<1);
00157       Pol::mul_add (dest + (p<<1), s, s[p<<1], n);
00158     }
00159   }
00160 }
00161 
00162 TMPL static void
00163 tmultiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00164   if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th))
00165     Fallback::tmul (dest, s1, s2, n1, n2);
00166   else {
00167     nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1;
00168     nat spc= aligned_size<C,V> (3*(p1+P)+1);
00169     C* buf= mmx_new<C> (spc);
00170     C* low1= buf;
00171     C* low2= low1 + p1;
00172     C* mid1= low2 + P+1;
00173     C* mid2= mid1 + p1;
00174     C* hi1 = mid2 + P;
00175     C* hi2 = hi1  + p1;
00176     C* Low = mid1;
00177     C* Mid = hi1;
00178     C* Hi  = dest;
00179     Vec::half_copy (low1, s1  , p1);
00180     Vec::half_copy (hi1 , s1+1, p1);
00181     Pol::add       (mid1, low1, hi1, p1);
00182     Vec::half_copy (low2, s2  , P+1);
00183     Vec::half_copy (hi2 , s2+1, P);
00184     Pol::add       (mid2, low2+1, hi2, P);
00185     tmultiply (Hi , hi1 , hi2 , p1, p2);
00186     tmultiply (Mid, mid1, mid2, p1, p2);
00187     tmultiply (Low, low1, low2, p1, p2+1);
00188     Pol::sub (Mid, Hi   , p2);
00189     Pol::sub (Mid, Low+1, p2);
00190     Pol::add (Low, Hi   , p2);
00191     Vec::double_copy (dest  , Low, p2);
00192     Vec::double_copy (dest+1, Mid, p2);
00193     mmx_delete<C> (buf, spc);
00194 
00195     if ((n1 & 1) != 0)
00196       Pol::mul_add (dest, s2+n1-1, s1[n1-1], n2);
00197 
00198     if ((n2 & 1) != 0)
00199       dest[n2-1]= Vec::inn_prod (s1, s2+n2-1, n1);
00200   }
00201 }
00202 
00203 TMPL static inline void
00204 tmul (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00205   // transposed multiplication
00206   // s1 has length n1, dest has length n2 and s2 has length n1+n2-1
00207   Pol::clear (dest, n2);
00208   if (n1 == 0 || n2 == 0) return;
00209   tmultiply (dest, s1, s2, n1, n2);
00210 }
00211 
00212 }; // implementation<polynomial_multiply,V,polynomial_karatsuba<W> >
00213 
00214 /******************************************************************************
00215 * Division
00216 ******************************************************************************/
00217 
// Empty tag type used as the key of the Threshold(C, ...) lookup below
// which dichotomic division falls back to the base variant.
template<typename V>
struct polynomial_divide_threshold {
};
00220   
00221 template<typename V, typename BV>
00222 struct implementation<polynomial_divide,V,polynomial_dicho<BV> >:
00223   public implementation<polynomial_multiply,V>
00224 {
00225   typedef polynomial_divide_threshold<polynomial_dicho<BV> > Th;
00226   typedef implementation<polynomial_multiply,V> Pol;
00227   typedef implementation<polynomial_divide,V,BV> Fallback;
00228 
00229 TMPL static void
00230 invert_lo (C* dest, const C* src, nat n) {
00231   if (n == 0) return;
00232   if (n == 1) *dest= C(1) / *src;
00233   else {
00234     nat h= (n+1) >> 1;
00235     nat l= n - h;
00236     invert_lo (dest, src, h);
00237     nat buf_size= aligned_size<C,V> (n << 1);
00238     C* buf= mmx_new<C> (buf_size);
00239     C* aux= buf + n;
00240     Pol::mul (buf, src, dest, n, h);
00241       // FIXME: use middle product
00242     Pol::mul (aux, dest, buf + h, l, l);
00243       // FIXME: use truncated product
00244     Pol::neg (dest + h, aux, l);
00245     mmx_delete<C> (buf, buf_size);
00246   }
00247 }
00248 
00249 TMPL static void
00250 invert_hi (C* dest, const C* src, nat n) {
00251   if (n == 1) *dest= C(1) / *src;
00252   else {
00253     nat h= (n+1) >> 1;
00254     nat l= n - h;
00255     invert_hi (dest + l, src + l, h);
00256     nat buf_size= aligned_size<C,V> (n << 1);
00257     C* buf= mmx_new<C> (buf_size);
00258     C* aux= buf + l;
00259     Pol::mul (aux, src, dest + l, n, h);
00260       // FIXME: use middle product
00261     Pol::mul (buf, dest + h, aux + h - 1, l, l);
00262       // FIXME: use truncated product
00263     Pol::neg (dest, buf + l - 1, l);
00264     mmx_delete<C> (buf, buf_size);
00265   }
00266 }
00267 
00268 TMPL static void
00269 quo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) {
00270   // NOTE: requires C to be a field
00271   if (n1 < n2);
00272   else if (n1 < Threshold(C,Th))
00273     Fallback::quo_rem (dest, s1, s2, n1, n2);
00274   else {
00275     nat tot= aligned_size<C,V> ((n2 << 1) + n2);
00276     C* buf= mmx_new<C> (tot);
00277     C* inv= buf + (n2 << 1);
00278     nat nq = n1 + 1 - n2;
00279     nat l  = min (n2, nq);
00280     invert_hi (inv + n2 - l, s2 + n2 - l, l);
00281     while (n1 >= n2) {
00282       nat nq= n1 + 1 - n2;
00283       nat l = min (n2, nq);
00284       C* dest_hi= dest + nq - l;
00285       Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l);
00286       Pol::copy (dest_hi, buf + l - 1, l);
00287       Pol::mul (buf, dest_hi, s2, l, n2);
00288       Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1);
00289       n1 -= l;
00290     }
00291     mmx_delete<C> (buf, tot);
00292   }
00293 }
00294 
00295 TMPL static void
00296 tquo_rem (C* dest, const C* s1, const C* s2, nat n1, nat n2) {
00297   // (s1, n2-1) contains the part corresponding to the remainder, while
00298   // (s2+n2-1, n1-n2+1) occupies the one of the quotient.
00299   // We assume n2>0 and s2[n2-1] != 0.
00300   // (dest, n1) contains the output.
00301   // NOTE: requires C to be a field
00302   if (n1 < n2)
00303     Pol::copy (dest, s1, n1);
00304   else if (n1 < Threshold(C,Th))
00305     Fallback::tquo_rem (dest, s1, s2, n1, n2);
00306   else {
00307     nat tot= aligned_size<C,V> (3 * n2);
00308     C* buf= mmx_new<C> (tot);
00309     C*  inv= buf + 2*n2;
00310     nat nq = n1 + 1 - n2;
00311     nat l  = min (n2, nq);
00312     invert_hi (inv + n2 - l, s2 + n2 - l, l);
00313     Pol::neg (inv + n2 - l, inv + n2 - l, l);
00314     Pol::copy (dest, s1, n2-1);
00315     Pol::clear (dest+n2-1, n1-n2+1);
00316     nat m = n2-1;
00317     while (m < n1) {
00318       nat nq= n1 - m;
00319       nat l = min (n2, nq);
00320       Pol::clear (buf     , l-1);
00321       Pol::tmul  (buf+l-1 , s2          , dest + m - (n2 - 1), n2, l);
00322       Pol::sub   (buf+l-1 , s1 + m      , l);      
00323       Pol::tmul  (dest + m, inv + n2 - l, buf, l , l);
00324       m += l;
00325     }
00326     mmx_delete<C> (buf, tot);
00327   }
00328 }
00329 
00330 // Pseudo division. C can be any ring.
00331 TMPL static void
00332 pinvert_hi (C* dest, const C* src, nat n) {
00333   if (n == 1) *dest= C(1);
00334   else {
00335     nat h= (n+1) >> 1;
00336     nat l= n - h;
00337     nat tmp_size= aligned_size<C,V> (l);
00338     C* tmp= mmx_new<C> (tmp_size);
00339     pinvert_hi (dest + l, src + l, h);
00340     pinvert_hi (tmp, src + h, l); // FIXME: increase from l to h
00341     nat buf_size= aligned_size<C,V> (n << 1);
00342     C* buf= mmx_new<C> (buf_size);
00343     C* aux= buf + l;
00344     Pol::mul (aux, src, dest + l, n, h);
00345       // FIXME: use middle product
00346     Pol::mul_sc (dest + l, binpow (src[n-1], l), h);
00347     Pol::mul (buf, tmp, aux + h - 1, l, l);
00348       // FIXME: use truncated product
00349     Pol::neg (dest, buf + l - 1, l);
00350     mmx_delete<C> (buf, buf_size);
00351     mmx_delete<C> (tmp, tmp_size);
00352   }
00353 }
00354 
00355 TMPL static void
00356 pquo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) {
00357   if (n1 < n2);
00358   else if (n1 < Threshold(C,Th))
00359     Fallback::pquo_rem (dest, s1, s2, n1, n2);
00360   else {
00361     nat tot= aligned_size<C,V> ((n2 << 1) + n2);
00362     C* buf= mmx_new<C> (tot);
00363     C* inv= buf + (n2 << 1);
00364     nat tmp_size= aligned_size<C,V> (n2);
00365     C* tmp= mmx_new<C> (tmp_size);
00366     nat nq= n1 + 1 - n2;
00367     nat l = min (n2, nq);
00368     nat l_end= nq % l;
00369     pinvert_hi (inv + n2 - l, s2 + n2 - l, l);
00370     if (l_end != 0)
00371       pinvert_hi (tmp + n2 - l_end, s2 + n2 - l_end, l_end);
00372     while (n1 >= n2) {
00373       nat nq= n1 + 1 - n2;
00374       nat l = min (n2, nq);
00375       C* dest_hi= dest + nq - l;
00376       if (l == l_end)
00377         Pol::mul (buf, s1 + n1 - l, tmp + n2 - l, l, l);
00378       else
00379         Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l);
00380       Pol::copy (dest_hi, buf + l - 1, l);
00381       Pol::mul (buf, dest_hi, s2, l, n2);
00382       Pol::mul_sc (s1, binpow (s2[n2-1], l), n1);
00383       Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1);
00384       Pol::mul_sc (dest_hi, binpow (s2[n2-1], nq - l), l);
00385       n1 -= l;
00386     }
00387     mmx_delete<C> (buf, tot);
00388     mmx_delete<C> (tmp, tmp_size);
00389   }
00390 }
00391 
00392 }; // implementation<polynomial_divide,V,polynomial_dicho<BV> >
00393 
00394 /******************************************************************************
00395 * Dichotomic Gcd computations
00396 ******************************************************************************/
00397 
00398 // "Modern Computer Algebra", Algorithm 11.4
00399 
// Empty tag type used as the key of the Threshold(C, ...) lookup below
// which the half-gcd algorithms fall back to the base variant.
template<typename V>
struct polynomial_euclidean_threshold {
};
00402 
00403 template<typename V, typename BV>
00404 struct implementation<polynomial_euclidean,V,polynomial_dicho<BV> >:
00405   public implementation<polynomial_divide,V>
00406 {
00407   typedef polynomial_euclidean_threshold<polynomial_dicho<BV> > Th;
00408   typedef implementation<vector_linear,V> Vec;
00409   typedef implementation<polynomial_divide,V> Pol;
00410   typedef implementation<polynomial_euclidean,V,BV> Fallback;
00411 
00412 private:
00413 
00414 TMPL static void
00415 dot_product (C* d, nat& nd,
00416              const C* r0, const C* r1, nat nr0, nat nr1,
00417              const C* s0, const C* s1, nat ns0, nat ns1,
00418              C* t) {
00419   // d = r0 s0 + r1 s1
00420   if ((nr0 == 0 || ns0 == 0) && (nr1 == 0 || ns1 == 0)) {
00421     nd = 0; return;
00422   }
00423   if (nr0 == 0 || ns0 == 0) {
00424     nd = nr1 + ns1 - 1;
00425     Pol::mul (d, r1, s1, nr1, ns1);
00426     Pol::trim (d, nd);
00427     return;
00428   }
00429   if (nr1 == 0 || ns1 == 0) {
00430     nd = nr0 + ns0 - 1;
00431     Pol::mul (d, r0, s0, nr0, ns0);
00432     Pol::trim (d, nd);
00433   }
00434   if (nr0 + ns0 - 1 < nr1 + ns1 - 1) {
00435     nd = nr1 + ns1 - 1;
00436     Pol::mul (d, r1, s1, nr1, ns1);
00437     Pol::mul (t, r0, s0, nr0, ns0);
00438     Pol::add (d, t , nr0 + ns0 - 1);
00439   }
00440   else {
00441     nd = nr0 + ns0 - 1;
00442     Pol::mul (d, r0, s0, nr0, ns0);
00443     Pol::mul (t, r1, s1, nr1, ns1);
00444     Pol::add (d, t , nr1 + ns1 - 1);
00445   }
00446   Pol::trim (d, nd);
00447 }
00448   
00449 TMPL static void
00450 matrix_vector_product (C* r0, C* r1, nat& nr0, nat& nr1,
00451                        const C* R00, const C* R01, const C* R10, const C* R11,
00452                        nat nR00, nat nR01, nat nR10, nat nR11,
00453                        const C* s0, const C* s1, nat n0, nat n1,
00454                        C* tp) {
00455   // (r0,r1) = R . (s0,s1)
00456   dot_product (r0, nr0, R00, R01, nR00, nR01, s0, s1, n0, n1, tp);
00457   dot_product (r1, nr1, R10, R11, nR10, nR11, s0, s1, n0, n1, tp);
00458 }
00459 
00460 TMPL static void
00461 matrix_product (C* Q00, C* Q01, C* Q10, C* Q11,
00462                 nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11,
00463                 const C* S00, const C* S01, const C* S10, const C* S11,
00464                 nat nS00, nat nS01, nat nS10, nat nS11,
00465                 const C* R00, const C* R01, const C* R10, const C* R11,
00466                 nat nR00, nat nR01, nat nR10, nat nR11, C* tp) {
00467   // Q = S . R
00468   matrix_vector_product (Q00, Q10, nQ00, nQ10,
00469                          S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00470                          R00, R10, nR00, nR10, tp);
00471   matrix_vector_product (Q01, Q11, nQ01, nQ11, 
00472                          S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00473                          R01, R11, nR01, nR11, tp);
00474 }
00475 
00476 TMPL static void
00477 new_matrix (C*& M00, C*& M01, C*& M10, C*& M11, nat l) {
00478   M00= mmx_new<C> (l); M01= mmx_new<C> (l);
00479   M10= mmx_new<C> (l); M11= mmx_new<C> (l);
00480 }
00481 
00482 TMPL static void
00483 delete_matrix (C* M00, C* M01, C* M10, C* M11, nat l) {
00484   mmx_delete<C> (M00, l); mmx_delete<C> (M01, l);
00485   mmx_delete<C> (M10, l); mmx_delete<C> (M11, l);
00486 }
00487 
00488 TMPL static void
00489 half_gcd (C* Q00, C* Q01, C* Q10, C* Q11,
00490           nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11,
00491           const C* r0, const C* r1, nat n0, nat n1, nat k,
00492           C* rho, C* tp) {
00493   // Performs the maximum numbers of reductions less than k.
00494   // tp must have size n0 + n1.
00495   // rho is the vector of the leading coefficients of to
00496   // degree min (n0, n1) - 1.
00497   // s1 and s2 are supposed to be monic.
00498   VERIFY (n0 >= n1, "bad input sizes");
00499   VERIFY (k  <= n0, "index k out of range");
00500   if (n1 == 0 || k < n0 - n1 + 1) {
00501     Q00[0] = 1; nQ00 = 1; nQ01 = 0;
00502     Q11[0] = 1; nQ10 = 0; nQ11 = 1;
00503     return;
00504   }
00505   if (k == 1) {
00506     Q00[0] = 1; nQ00 = 1; nQ01 = 0;
00507     Q10[0] = 1; Q11[0] = - r0[n0-1] / r1[n1-1]; nQ10 = 1; nQ11 = 1;
00508     return;
00509   }
00510   nat h = k >> 1, h2 = (h << 1) - 1;
00511   nat len_R= aligned_size<C,V> (h2);
00512   nat nR00, nR01, nR10, nR11;
00513   C* R00, * R01, * R10, * R11;
00514   new_matrix (R00, R01, R10, R11, len_R);
00515   if (h2 < n0 - n1)
00516     half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00517               r0, r1, n0, n1, h, rho, tp);
00518   else
00519     half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00520               r0 + n0 - h2, r1 + n0 - h2,
00521               h2, h2 - (n0 - n1), h,
00522               rho == NULL ? rho : (rho + (n0 - h2)), tp);    
00523   nat len_r= aligned_size<C,V> (n0 + h);
00524   C* rjm1= mmx_new<C> (len_r), * rj  = mmx_new<C> (len_r);
00525   nat nj, njm1;
00526   // TODO << use truncated product
00527   matrix_vector_product (rjm1, rj, njm1, nj,
00528                          R00, R01, R10, R11, nR00, nR01, nR10, nR11,
00529                          r0, r1, n0, n1, tp);
00530   if (nj == 0 || k < n0 - nj + 1) {
00531     Pol::copy (Q00, R00, nR00); nQ00= nR00;
00532     Pol::copy (Q01, R01, nR01); nQ01= nR01;
00533     Pol::copy (Q10, R10, nR10); nQ10= nR10;
00534     Pol::copy (Q11, R11, nR11); nQ11= nR11;
00535     return;
00536   }
00537   nat len_S= aligned_size<C,V> (max (k - (n0 - nj), njm1 - nj + 1));
00538   nat nS00, nS01, nS10, nS11;
00539   C* S00, * S01, * S10, * S11;
00540   new_matrix (S00, S01, S10, S11, len_S);
00541 
00542   nat len_T= aligned_size<C,V> (njm1 - nj + h);
00543   nat nT00, nT01, nT10, nT11;
00544   C* T00, * T01, * T10, * T11;
00545   new_matrix (T00, T01, T10, T11, len_T);
00546 
00547   S01[0]= 1; S10[0]= 1;  
00548   nS00= 0; nS01= 1; nS10= 1; nS11= njm1 - nj + 1;
00549   quo_rem (S11, rjm1, rj, njm1, nj);
00550   Vec::neg (S11, nS11);
00551   Pol::trim (rj, nj); Pol::trim (rjm1, njm1);
00552 
00553   if (njm1 != 0) {
00554     C c= 1 / rjm1[njm1-1];
00555     if (rho != NULL) rho[njm1-1] = rjm1[njm1-1];
00556     Pol::mul_sc (rjm1, c, njm1);
00557     Pol::mul_sc (S11, c, nS11);
00558     Pol::mul_sc (S10, c, nS10);
00559   }
00560   matrix_product (T00, T01, T10, T11, nT00, nT01, nT10, nT11,
00561                   S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00562                   R00, R01, R10, R11, nR00, nR01, nR10, nR11, tp);
00563 
00564   if (njm1 == 0 || k < n0 - njm1 + 1) {
00565     Pol::copy (Q00, T00, nT00); nQ00= nT00;
00566     Pol::copy (Q01, T01, nT01); nQ01= nT01;
00567     Pol::copy (Q10, T10, nT10); nQ10= nT10;
00568     Pol::copy (Q11, T11, nT11); nQ11= nT11;
00569     return;
00570   }
00571   h = k - (n0 - nj); h2 = (h << 1) - 1;
00572   if (h2 < nj - njm1 || h2 > nj)
00573     half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00574               rj, rjm1, nj, njm1, h, rho, tp);
00575   else
00576     half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00577               rj + nj - h2, rjm1 + nj - h2, h2, h2 - (nj - njm1), h,
00578               rho == NULL ? rho : (rho + (nj - h2)), tp);
00579   mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r);
00580   matrix_product (Q00, Q01, Q10, Q11, nQ00, nQ01, nQ10, nQ11,
00581                   S00, S01, S10, S11, nS00, nS01, nS10, nS11,
00582                   T00, T01, T10, T11, nT00, nT01, nT10, nT11, tp);
00583   delete_matrix (R00, R01, R10, R11, len_R);
00584   delete_matrix (S00, S01, S10, S11, len_S);
00585   delete_matrix (T00, T01, T10, T11, len_T);
00586 }
00587   
00588 public:
00589 
00590 TMPL static inline void
00591 euclidean_sequence (const C* s1, const C* s2, nat n1, nat n2,
00592                     C* d1, C* d2, nat& m1 , nat& m2,
00593                     C* u1, C* u2, nat& nu1, nat& nu2,
00594                     C* v1, C* v2, nat& nv1, nat& nv2,
00595                     nat* n, C* rho, C* q, C** r, C** co1, C** co2, nat k= 0) {
00596   Fallback::euclidean_sequence (s1, s2, n1, n2, d1, d2, m1, m2,
00597                                 u1, u2, nu1, nu2, v1, v2, nv1, nv2,
00598                                 n, rho, q, r, co1, co2, k);
00599 }
00600 
00601 TMPL static void
00602 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2,
00603      C* uu1, C* uu2, nat& nuu1, nat& nuu2) {
00604   // g must have allocated length min (n1, n2)
00605   if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th)) {
00606     Fallback::gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2); return; }
00607   VERIFY (n1>0 && n2>0 && s1[n1-1] != 0 && s2[n2-1] != 0,
00608           "invalid hypothesis for gcd computation");
00609   nat nu1, nu2, nv1, nv2;
00610   C c1= 1 / s1[n1-1], c2= 1 / s2[n2-1];
00611   nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2);
00612   C* z1= mmx_new<C> (l1), * z2= mmx_new<C> (l2);
00613   Pol::mul_sc (z1, s1, c1, n1); Pol::mul_sc (z2, s2, c2, n2); 
00614   C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1);
00615   C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1);
00616   nat len_tp= aligned_size<C,V> (n1 + n2);
00617   C* tp1= mmx_new<C> (len_tp), * tp2= mmx_new<C> (len_tp);
00618   C* rho= NULL;
00619   if (n1 >= n2)
00620     half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00621               z1, z2, n1, n2, n1, rho, tp1);
00622   else
00623     half_gcd (u2, u1, v2, v1, nu2, nu1, nv2, nv1,
00624               z2, z1, n2, n1, n2, rho, tp1);
00625   VERIFY (nu1 <= n2 && nv1 <= n2 && nu2 <= n1 && nv2 <= n1, "bug");
00626   dot_product (tp2, n, u1, u2, nu1, nu2, z1, z2, n1, n2, tp1);
00627   VERIFY (n <= min (n1, n2), "bug");
00628   C c= n == 0 ? 0 : (1 / tp2[n-1]);
00629   Pol::mul_sc (g, tp2, c, n); Pol::clear (g + n, min (n1, n2) - n);
00630   if (uu1 != NULL) {
00631     Pol::mul_sc (uu1, u1, c1 * c, nu1);
00632     Pol::clear (uu1 + nu1, n2 - nu1); nuu1= nu1; }
00633   if (uu2 != NULL) {
00634     Pol::mul_sc (uu2, u2, c2 * c, nu2);
00635     Pol::clear (uu2 + nu2, n1 - nu2); nuu2= nu2; }
00636   mmx_delete<C> (tp1, len_tp); mmx_delete<C> (tp2, len_tp);
00637   mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1);
00638   mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1);
00639 }  
00640 
00641 TMPL static void
00642 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2) {
00643   nat nuu1, nuu2;
00644   C* uu1= NULL, * uu2= NULL;
00645   gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2);
00646 }
00647 
00648 TMPL static void
00649 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2,
00650      C* uu1, nat& nuu1) {
00651   C* uu2= NULL; nat nuu2;
00652   gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2);
00653 }  
00654 
00655 TMPL static void
00656 pade (C* r, C* t, const C* s, nat m, nat n, nat k) {
00657   // r (resp. t) must have allocated length at least k (resp. n-k+1)
00658   // s = r / t + O(x^n), deg r < k, deg t <= n-k
00659   // "Modern Computer Algebra", Corollary 5.21
00660   if (n < Threshold(C,Th) || k == n || k == 0) {
00661     Fallback::pade (r, t, s, m, n, k); return; }
00662   VERIFY (n > k && n > 0 && m > 0 && m <= n && s[m-1] != 0,
00663           "invalid hypothesis for gcd computation");
00664   nat nu1, nu2, nv1, nv2;
00665   nat n1= n+1, n2= m;
00666   nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2);
00667   C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1);
00668   C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1);
00669   C* s1= mmx_new<C> (l1); Pol::clear (s1, n); s1[n]= 1; const C* s2= s;
00670   nat len_tp= aligned_size<C,V> (n1 + n2);
00671   C* tp= mmx_new<C> (len_tp);
00672   C* rho= NULL;
00673   half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00674             s1, s2, n1, n2, n-k, rho, tp);
00675 
00676   nat len_r= aligned_size<C,V> (2*n1-k);
00677   C* rjm1= mmx_new<C> (len_r), * rj  = mmx_new<C> (len_r);
00678   nat nj, njm1;
00679   matrix_vector_product (rjm1, rj, njm1, nj,
00680                          u1, u2, v1, v2, nu1, nu2, nv1, nv2,
00681                          s1, s2, n1, n2, tp);
00682   if (nj <= k) {
00683     VERIFY (nv2 <= n-k+1, "bug");
00684     Pol::copy (r, rj, nj); Pol::clear (r + nj, k - nj);
00685     Pol::copy (t, v2, nv2); Pol::clear (t + nv2, n-k+1 - nv2);
00686   }
00687   else {
00688     VERIFY (nj > k, "bug");
00689     nat nq= njm1 - nj + 1, len_q= aligned_size<C,V> (nq);
00690     C* q= mmx_new<C> (len_q);
00691     quo_rem (q, rjm1, rj, njm1, nj);
00692     Pol::trim (rjm1, njm1);
00693     Pol::mul (tp, q, v2, nq, nv2);
00694     Pol::clear (u2 + nu2, n1 - nu2); nu2= n1;
00695     Pol::sub (u2, tp, nq + nv2 - 1);
00696     Pol::trim (u2, nu2);
00697     VERIFY (njm1 <= k, "bug");
00698     VERIFY (nu2 <= n-k+1, "bug");
00699     Pol::copy (r, rjm1, njm1); Pol::clear (r + njm1, k - njm1);
00700     Pol::copy (t, u2, nu2); Pol::clear (t + nu2, n-k+1 - nu2);
00701     mmx_delete<C> (q, len_q);
00702   }
00703   mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r);
00704   mmx_delete<C> (s1, l1); mmx_delete<C> (tp, len_tp);
00705   mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1);
00706   mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1);
00707 }
00708 
00709 }; // implementation<polynomial_euclidean,V,polynomial_dicho<BV> >
00710 
00711 /******************************************************************************
00712 * Efficient evaluation of polynomials
00713 ******************************************************************************/
00714 
// Empty tag type used as the key of the Threshold(C, ...) lookup for the
// dichotomic evaluation algorithms.
template<typename V>
struct polynomial_evaluate_threshold {
};
00717 
00718 template<typename V, typename BV>
00719 struct implementation<polynomial_evaluate,V,polynomial_dicho<BV> >:
00720   public implementation<polynomial_divide,V>
00721 {
00722   typedef polynomial_evaluate_threshold<polynomial_dicho<BV> > Th;
00723   typedef implementation<vector_linear,V> Vec;
00724   typedef implementation<polynomial_divide,V> Pol;
00725   typedef implementation<polynomial_evaluate,V,BV> Fallback;
00726 
00727 TMPL static inline void
00728 factorials (C* dest, nat n) {
00729   if (n > 0) dest[0]= C(1);
00730   for (nat i=1; i<n; i++)
00731     dest[i]= C(i) * dest[i-1];
00732 }
00733 
00734 TMPL static void
00735 shift (C* dest, const C* s, const C& sh, nat n) {
00736   if (n <= 1 || sh == 0) Pol::copy (dest, s, n);
00737   else {
00738     nat l= aligned_size<C,V> (5 * n);
00739     C* u    = mmx_new<C> (l);
00740     C* v    = u + n;
00741     C* w    = v + n;
00742     C* facts= w + n + n;
00743     factorials (facts, n);
00744     Vec::mul (u, s, facts, n);
00745     Vec::vec_reverse (u, n);
00746     Vec::set (v, 1, n);
00747     if (sh != 0) Pol::q_difference (v, v, sh, n);
00748     Vec::div (v, facts, n);
00749     Pol::mul (w, u, v, n, n); // FIXME: rather use truncated multiplication
00750     Vec::vec_reverse (w, n);
00751     Vec::div (dest, w, facts, n);
00752     mmx_delete<C> (u, l);
00753   }
00754 }
00755 
00756 TMPL static inline C
00757 evaluate (const C* p, const C& x, nat l) {
00758   return Fallback::evaluate (p, x, l);
00759 }
00760 
00761 TMPL static void
00762 q_binomial (C* dest, const C& q, nat mu) {
00763   dest[mu]= 1;
00764   for (nat i=1; i<=mu; i++)
00765     dest[mu-i]= (C (mu+1-i) * dest[mu+1-i] * q) / C(i);
00766 }
00767 
00768 TMPL static void
00769 expand (C** v, const C* p, const C* x, const nat* nu, nat n, nat k) {
00770   nat tot= 0;
00771   for (nat i=0; i<k; i++)
00772     tot += nu[i] + 1;
00773   nat* d= mmx_new<nat> (k);
00774   nat l= aligned_size<C,V> (tot << 1);
00775   C* q= mmx_new<C> (l);
00776   C* r= q + tot;
00777   nat off= 0;
00778   for (nat i=0; i<k; i++) {
00779     q_binomial (q + off, -x[i], nu[i]);
00780     d[i]= nu[i] + 1;
00781     off += d[i];
00782   }
00783   multi_mod (r, p, q, d, n, k);
00784   off= 0;
00785   for (nat i=0; i<k; i++) {
00786     shift (v[i], r + off, x[i], nu[i]);
00787     off += d[i];
00788   }
00789   mmx_delete<C> (q, l);
00790   mmx_delete<nat> (d, k);
00791 }
00792 
// From here on, C denotes the scalar type of the Polynomial template
// parameter introduced by TMPLP.
#define C typename scalar_type_helper<Polynomial >::val
00794 
00795 struct _vector_sort_by_increasing_degree_op {
00796   TMPLP static bool
00797   op (const Polynomial& p, const Polynomial& q) {
00798     return deg (p) < deg (q); }
00799   TMPLP static bool
00800   not_op (const Polynomial& p, const Polynomial& q) {
00801     return deg (p) >= deg (q); }
00802 };
00803 
00804 template<typename Op, typename Polynomial> static inline vector<Polynomial>
00805 _multi_rem (const Polynomial& p, const vector<Polynomial>& q) {
00806   if (p == 0) return vector<Polynomial> (Polynomial(C(0)), N(q));
00807   nat n= degree (p);
00808   vector<Polynomial> sorted_q (q), r (Polynomial(C(0)), N(q));
00809   vector<nat> sigma; // permutation
00810   sort_leq<_vector_sort_by_increasing_degree_op> (sorted_q, sigma);
00811   nat start= 0, sum= 0;
00812   for (nat i= 0; i < N(q); i++) {
00813     sum += degree (sorted_q[i]);
00814     if (sum > n / 2 || i+1 == N(q)) {
00815       Crt_polynomial_transformer(Polynomial)
00816         crter (range (sorted_q, start, i+1));
00817       vector<Polynomial> tmp; direct_crt (tmp, p, crter);
00818       for (nat j= start; j < i+1; j++) r[sigma[j]]= tmp[j-start];
00819       start= i+1; sum= 0;
00820     }
00821   }
00822   return r; }
00823 
00824 TMPLP static inline vector<Polynomial>
00825 multi_rem (const Polynomial& p, const vector<Polynomial>& q) {
00826   return _multi_rem<rem_op> (p, q); }
00827 
00828 TMPLP static inline vector<Polynomial>
00829 multi_prem (const Polynomial& p, const vector<Polynomial>& q) {
00830   return _multi_rem<prem_op> (p, q); }
00831 
00832 TMPLP static vector<Polynomial>
00833 multi_gcd (const Polynomial& P, const vector<Polynomial>& Q) {
00834   return binary_map<gcd_op> (rem (P, Q), Q); }
00835   
00836 TMPLP static Polynomial
00837 annulator (const Vector& x) {
00838   ASSERT (is_non_scalar (x), "non-scalar xector expected");
00839   if (N(x) == 0) return 1;
00840   vector<Polynomial> q (Polynomial (), N(x));
00841   Polynomial z (C(1), 1);
00842   for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00843   Crt_polynomial_transformer(Polynomial) crter (q);
00844   return * moduli_product (crter); }
00845 
00846 TMPLP static inline Vector
00847 evaluate (const Polynomial& p, const Vector& x) {
00848   ASSERT (is_non_scalar (x), "non-scalar vector expected");
00849   if (N(x) == 0) return Vector (C(0), 0);
00850   vector<Polynomial> q (Polynomial (), N(x));
00851   Polynomial z (C(1), 1);
00852   for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00853   vector<Polynomial> tmp= multi_rem (p, q);
00854   Vector r (C(0), N(x));
00855   for (nat i= 0; i < N(x); i++) r[i]= tmp[i][0];
00856   return r;
00857 }
00858 
00859 TMPLP static inline Polynomial
00860 tevaluate (const Vector& v, const Vector& x, nat l) {
00861   ASSERT (is_non_scalar (x), "non-scalar vector expected");
00862   if (l == 0) return Polynomial (0);
00863   nat n= N(x), ll= aligned_size<C,V> (l);
00864   vector<Polynomial> q (Polynomial (), n);
00865   Polynomial z (C(1), 1);
00866   for (nat i= 0; i < N(v); i++) q[i]= Polynomial (1) - x[i] * z;
00867   Crt_polynomial_transformer(Polynomial) crter (q);
00868   Polynomial num (combine_crt (v, crter)), den (* moduli_product (crter));
00869   // retrieve the coefficients of the series num / den
00870   C* tmp= mmx_new<C> (ll), * inv= mmx_new<C> (ll);
00871   Pol::clear (tmp, l); Pol::copy (tmp, seg(den), min (N(den), l));
00872   Pol::invert_lo (inv, tmp, l);
00873   Polynomial b (inv, l, ll), c (num * b);
00874   for (nat i= 0; i < l; i++) tmp[i]= c[i];
00875   return Polynomial (tmp, l, ll);
00876 }
00877 
00878 TMPLP static Polynomial
00879 interpolate (const Vector& v, const Vector& x) {
00880   ASSERT (is_non_scalar (x), "non-scalar vector expected");
00881   ASSERT (N(v) == N(x), "dimensions don't match");
00882   if (N(x) == 0) return Polynomial (0);
00883   vector<Polynomial> q (Polynomial (), N(x));
00884   Polynomial z (C(1), 1);
00885   for (nat i= 0; i < N(x); i++) q[i]= z - x[i];
00886   Crt_polynomial_transformer(Polynomial) crter (q);
00887   Polynomial ans; inverse_crt (ans, as<vector<Polynomial> > (v), crter);
00888   return ans;
00889 }
00890 
00891 TMPLP static Vector
00892 tinterpolate (const Polynomial& p, const Vector& x) {
00893   ASSERT (is_non_scalar (x), "non-scalar vector expected");
00894   ASSERT (N(p) <= N(x), "dimensions don't match");
00895   nat n= N(x);
00896   if (n == 0) return Vector (C(0), 0);
00897   vector<Polynomial> q (Polynomial (), n);
00898   Polynomial z (C(1), 1);
00899   for (nat i= 0; i < n; i++) q[i]= z - x[i];
00900   Crt_polynomial_transformer(Polynomial) crter (q);
00901   Polynomial rp (lshiftz (reverse (p), (int) (n - N(p))));
00902   Polynomial den (* moduli_product (crter));
00903   vector<Polynomial> a; direct_crt (a, range (den * rp, n, 2*n), crter);
00904   vector<Polynomial> b; direct_crt (b, derive (den), crter);
00905   Vector v (C(0), n); for (nat i= 0; i < n; i++) v[i]= a[i][0] / b[i][0];
00906   return v;
00907 }
00908 
00909 #undef C
00910 }; // implementation<polynomial_evaluate,V,polynomial_dicho<BV> >
00911 
00912 #undef TMPL
00913 #undef TMPLP
00914 #undef Vector
00915 } // namespace mmx
00916 #endif //__MMX__POLYNOMIAL_DICHO__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines