algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : polynomial_dicho.hpp 00004 * DESCRIPTION: dichotomic algorithms including Karatsuba multiplication 00005 * COPYRIGHT : (C) 2003 Joris van der Hoeven and Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__POLYNOMIAL_DICHO__HPP 00014 #define __MMX__POLYNOMIAL_DICHO__HPP 00015 #include <basix/vector_sort.hpp> 00016 #include <algebramix/polynomial_naive.hpp> 00017 #include <algebramix/crt_polynomial.hpp> 00018 00019 namespace mmx { 00020 #define TMPL template<typename C> 00021 #define TMPLP template <typename Polynomial> 00022 #define Vector vector<C> 00023 00024 /****************************************************************************** 00025 * Variants for dichotomic algorithms on polynomials 00026 ******************************************************************************/ 00027 00028 template<typename V> 00029 struct polynomial_karatsuba: public V { 00030 typedef typename V::Vec Vec; 00031 typedef typename V::Naive Naive; 00032 typedef polynomial_karatsuba<typename V::Positive> Positive; 00033 typedef polynomial_karatsuba<typename V::No_simd> No_simd; 00034 typedef polynomial_karatsuba<typename V::No_thread> No_thread; 00035 typedef polynomial_karatsuba<typename V::No_scaled> No_scaled; 00036 }; 00037 00038 template<typename F, typename V, typename W> 00039 struct implementation<F,V,polynomial_karatsuba<W> >: 00040 public implementation<F,V,W> {}; 00041 00042 template<typename V> 00043 struct polynomial_dicho: public V { 00044 typedef typename V::Vec Vec; 00045 typedef typename V::Naive Naive; 00046 typedef polynomial_dicho<typename V::Positive> Positive; 00047 typedef polynomial_dicho<typename V::No_simd> No_simd; 00048 typedef polynomial_dicho<typename V::No_thread> No_thread; 00049 typedef polynomial_dicho<typename V::No_scaled> No_scaled; 00050 }; 00051 00052 template<typename F, typename V, typename W> 00053 struct implementation<F,V,polynomial_dicho<W> >: 00054 public implementation<F,V,W> {}; 00055 00056 /****************************************************************************** 00057 * Multiplication 00058 ******************************************************************************/ 00059 00060 template<typename V> 00061 struct polynomial_multiply_threshold {}; 00062 00063 template<typename V, typename W> 00064 struct implementation<polynomial_multiply,V,polynomial_karatsuba<W> >: 00065 public implementation<polynomial_linear,V> 00066 { 00067 typedef polynomial_multiply_threshold<polynomial_karatsuba<W> > Th; 00068 typedef implementation<vector_linear,V> Vec; 00069 typedef implementation<polynomial_linear,W> Pol; 00070 typedef implementation<polynomial_multiply,W> Fallback; 00071 00072 TMPL static void 00073 multiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) { 00074 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th)) 00075 Fallback::mul (dest, s1, s2, n1, n2); 00076 else { 00077 nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1; 00078 nat spc= aligned_size<C,V> (3*(p1+p2)); 00079 C* buf= mmx_new<C> (spc); 00080 C* low1= dest; 00081 C* low2= low1 + p1; 00082 C* mid1= buf; 00083 C* mid2= buf + p1; 00084 C* hi1 = mid2 + p2; 00085 C* hi2 = hi1 + p1; 00086 C* Low = mid1; 00087 C* Mid = hi1; 00088 C* Hi = Mid + p1 + p2; 00089 Vec::half_copy (low1 , s1 , p1); 00090 Vec::half_copy (hi1 , s1+1, p1); 00091 Pol::add (mid1 , low1, hi1 , p1); 00092 Vec::half_copy (low2 , s2 , p2); 00093 Vec::half_copy (hi2 , s2+1, p2); 00094 Pol::add (mid2 , low2, hi2 , p2); 00095 multiply (Hi , hi1 , hi2 , p1, p2); 00096 multiply (Mid , mid1, mid2, p1, p2); 00097 multiply (Low , low1, low2, p1, p2); 00098 Pol::sub (Mid , Low , P); 00099 Pol::sub (Mid , Hi , P); 00100 Pol::add (Low+1, Hi , P-1); 00101 Low[P]= Hi[P-1]; 00102 Vec::double_copy (dest , Low, P+1); 00103 Vec::double_copy (dest+1, Mid, P); 00104 mmx_delete<C> (buf, spc); 00105 00106 if ((n1 & 1) != 0) { 00107 dest[(P<<1)+1]= C(0); 00108 Pol::mul_add (dest + (p1<<1), s2, s1[p1<<1], p2<<1); 00109 } 00110 if ((n2 & 1) != 0) { 00111 dest[n1+n2-2]= C(0); 00112 Pol::mul_add (dest + (p2<<1), s1, s2[p2<<1], n1); 00113 } 00114 } 00115 } 00116 00117 TMPL static inline void 00118 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2) { 00119 if (n1 == 0 && n2 == 0) return; 00120 if (n1 == 0 || n2 == 0) 00121 Pol::clear (dest, n1 + n2 - 1); 00122 else 00123 multiply (dest, s1, s2, n1, n2); 00124 } 00125 00126 TMPL static void 00127 square (C* dest, const C* s, nat n) { 00128 if (n == 0) return; 00129 if (n < Threshold(C,Th)) 00130 Fallback::square (dest, s, n); 00131 else { 00132 nat p= n >> 1, P= 2*p-1, spc= aligned_size<C,V> (6*p); 00133 C* buf= mmx_new<C> (spc); 00134 C* low= dest; 00135 C* mid= buf; 00136 C* hi = buf + 2*p; 00137 C* Low= buf; 00138 C* Mid= hi; 00139 C* Hi = Mid + 2*p; 00140 Vec::half_copy (low , s , p); 00141 Vec::half_copy (hi , s+1, p); 00142 Pol::add (mid , low, hi, p); 00143 square (Hi , hi , p); 00144 square (Mid , mid, p); 00145 square (Low , low, p); 00146 Pol::sub (Mid , Low, P); 00147 Pol::sub (Mid , Hi , P); 00148 Pol::add (Low+1, Hi , P-1); 00149 Low[P]= Hi[P-1]; 00150 Vec::double_copy (dest , Low, P+1); 00151 Vec::double_copy (dest+1, Mid, P); 00152 mmx_delete<C> (buf, spc); 00153 00154 if ((n & 1) != 0) { 00155 dest[(P<<1)+1]= dest[2*n-2]= C(0); 00156 Pol::mul_add (dest + (p<<1), s, s[p<<1], p<<1); 00157 Pol::mul_add (dest + (p<<1), s, s[p<<1], n); 00158 } 00159 } 00160 } 00161 00162 TMPL static void 00163 tmultiply (C* dest, const C* s1, const C* s2, nat n1, nat n2) { 00164 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th)) 00165 Fallback::tmul (dest, s1, s2, n1, n2); 00166 else { 00167 nat p1= n1 >> 1, p2= n2 >> 1, P= p1+p2-1; 00168 nat spc= aligned_size<C,V> (3*(p1+P)+1); 00169 C* buf= mmx_new<C> (spc); 00170 C* low1= buf; 00171 C* low2= low1 + p1; 00172 C* mid1= low2 + P+1; 00173 C* mid2= mid1 + p1; 00174 C* hi1 = mid2 + P; 00175 C* hi2 = hi1 + p1; 00176 C* Low = mid1; 00177 C* Mid = hi1; 00178 C* Hi = dest; 00179 Vec::half_copy (low1, s1 , p1); 00180 Vec::half_copy (hi1 , s1+1, p1); 00181 Pol::add (mid1, low1, hi1, p1); 00182 Vec::half_copy (low2, s2 , P+1); 00183 Vec::half_copy (hi2 , s2+1, P); 00184 Pol::add (mid2, low2+1, hi2, P); 00185 tmultiply (Hi , hi1 , hi2 , p1, p2); 00186 tmultiply (Mid, mid1, mid2, p1, p2); 00187 tmultiply (Low, low1, low2, p1, p2+1); 00188 Pol::sub (Mid, Hi , p2); 00189 Pol::sub (Mid, Low+1, p2); 00190 Pol::add (Low, Hi , p2); 00191 Vec::double_copy (dest , Low, p2); 00192 Vec::double_copy (dest+1, Mid, p2); 00193 mmx_delete<C> (buf, spc); 00194 00195 if ((n1 & 1) != 0) 00196 Pol::mul_add (dest, s2+n1-1, s1[n1-1], n2); 00197 00198 if ((n2 & 1) != 0) 00199 dest[n2-1]= Vec::inn_prod (s1, s2+n2-1, n1); 00200 } 00201 } 00202 00203 TMPL static inline void 00204 tmul (C* dest, const C* s1, const C* s2, nat n1, nat n2) { 00205 // transposed multiplication 00206 // s1 has length n1, dest has length n2 and s2 has length n1+n2-1 00207 Pol::clear (dest, n2); 00208 if (n1 == 0 || n2 == 0) return; 00209 tmultiply (dest, s1, s2, n1, n2); 00210 } 00211 00212 }; // implementation<polynomial_multiply,V,polynomial_karatsuba<W> > 00213 00214 /****************************************************************************** 00215 * Division 00216 ******************************************************************************/ 00217 00218 template<typename V> 00219 struct polynomial_divide_threshold {}; 00220 00221 template<typename V, typename BV> 00222 struct implementation<polynomial_divide,V,polynomial_dicho<BV> >: 00223 public implementation<polynomial_multiply,V> 00224 { 00225 typedef polynomial_divide_threshold<polynomial_dicho<BV> > Th; 00226 typedef implementation<polynomial_multiply,V> Pol; 00227 typedef implementation<polynomial_divide,V,BV> Fallback; 00228 00229 TMPL static void 00230 invert_lo (C* dest, const C* src, nat n) { 00231 if (n == 0) return; 00232 if (n == 1) *dest= C(1) / *src; 00233 else { 00234 nat h= (n+1) >> 1; 00235 nat l= n - h; 00236 invert_lo (dest, src, h); 00237 nat buf_size= aligned_size<C,V> (n << 1); 00238 C* buf= mmx_new<C> (buf_size); 00239 C* aux= buf + n; 00240 Pol::mul (buf, src, dest, n, h); 00241 // FIXME: use middle product 00242 Pol::mul (aux, dest, buf + h, l, l); 00243 // FIXME: use truncated product 00244 Pol::neg (dest + h, aux, l); 00245 mmx_delete<C> (buf, buf_size); 00246 } 00247 } 00248 00249 TMPL static void 00250 invert_hi (C* dest, const C* src, nat n) { 00251 if (n == 1) *dest= C(1) / *src; 00252 else { 00253 nat h= (n+1) >> 1; 00254 nat l= n - h; 00255 invert_hi (dest + l, src + l, h); 00256 nat buf_size= aligned_size<C,V> (n << 1); 00257 C* buf= mmx_new<C> (buf_size); 00258 C* aux= buf + l; 00259 Pol::mul (aux, src, dest + l, n, h); 00260 // FIXME: use middle product 00261 Pol::mul (buf, dest + h, aux + h - 1, l, l); 00262 // FIXME: use truncated product 00263 Pol::neg (dest, buf + l - 1, l); 00264 mmx_delete<C> (buf, buf_size); 00265 } 00266 } 00267 00268 TMPL static void 00269 quo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) { 00270 // NOTE: requires C to be a field 00271 if (n1 < n2); 00272 else if (n1 < Threshold(C,Th)) 00273 Fallback::quo_rem (dest, s1, s2, n1, n2); 00274 else { 00275 nat tot= aligned_size<C,V> ((n2 << 1) + n2); 00276 C* buf= mmx_new<C> (tot); 00277 C* inv= buf + (n2 << 1); 00278 nat nq = n1 + 1 - n2; 00279 nat l = min (n2, nq); 00280 invert_hi (inv + n2 - l, s2 + n2 - l, l); 00281 while (n1 >= n2) { 00282 nat nq= n1 + 1 - n2; 00283 nat l = min (n2, nq); 00284 C* dest_hi= dest + nq - l; 00285 Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l); 00286 Pol::copy (dest_hi, buf + l - 1, l); 00287 Pol::mul (buf, dest_hi, s2, l, n2); 00288 Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1); 00289 n1 -= l; 00290 } 00291 mmx_delete<C> (buf, tot); 00292 } 00293 } 00294 00295 TMPL static void 00296 tquo_rem (C* dest, const C* s1, const C* s2, nat n1, nat n2) { 00297 // (s1, n2-1) contains the part corresponding to the remainder, while 00298 // (s2+n2-1, n1-n2+1) occupies the one of the quotient. 00299 // We assume n2>0 and s2[n2-1] != 0. 00300 // (dest, n1) contains the output. 00301 // NOTE: requires C to be a field 00302 if (n1 < n2) 00303 Pol::copy (dest, s1, n1); 00304 else if (n1 < Threshold(C,Th)) 00305 Fallback::tquo_rem (dest, s1, s2, n1, n2); 00306 else { 00307 nat tot= aligned_size<C,V> (3 * n2); 00308 C* buf= mmx_new<C> (tot); 00309 C* inv= buf + 2*n2; 00310 nat nq = n1 + 1 - n2; 00311 nat l = min (n2, nq); 00312 invert_hi (inv + n2 - l, s2 + n2 - l, l); 00313 Pol::neg (inv + n2 - l, inv + n2 - l, l); 00314 Pol::copy (dest, s1, n2-1); 00315 Pol::clear (dest+n2-1, n1-n2+1); 00316 nat m = n2-1; 00317 while (m < n1) { 00318 nat nq= n1 - m; 00319 nat l = min (n2, nq); 00320 Pol::clear (buf , l-1); 00321 Pol::tmul (buf+l-1 , s2 , dest + m - (n2 - 1), n2, l); 00322 Pol::sub (buf+l-1 , s1 + m , l); 00323 Pol::tmul (dest + m, inv + n2 - l, buf, l , l); 00324 m += l; 00325 } 00326 mmx_delete<C> (buf, tot); 00327 } 00328 } 00329 00330 // Pseudo division. C can be any ring. 00331 TMPL static void 00332 pinvert_hi (C* dest, const C* src, nat n) { 00333 if (n == 1) *dest= C(1); 00334 else { 00335 nat h= (n+1) >> 1; 00336 nat l= n - h; 00337 nat tmp_size= aligned_size<C,V> (l); 00338 C* tmp= mmx_new<C> (tmp_size); 00339 pinvert_hi (dest + l, src + l, h); 00340 pinvert_hi (tmp, src + h, l); // FIXME: increase from l to h 00341 nat buf_size= aligned_size<C,V> (n << 1); 00342 C* buf= mmx_new<C> (buf_size); 00343 C* aux= buf + l; 00344 Pol::mul (aux, src, dest + l, n, h); 00345 // FIXME: use middle product 00346 Pol::mul_sc (dest + l, binpow (src[n-1], l), h); 00347 Pol::mul (buf, tmp, aux + h - 1, l, l); 00348 // FIXME: use truncated product 00349 Pol::neg (dest, buf + l - 1, l); 00350 mmx_delete<C> (buf, buf_size); 00351 mmx_delete<C> (tmp, tmp_size); 00352 } 00353 } 00354 00355 TMPL static void 00356 pquo_rem (C* dest, C* s1, const C* s2, nat n1, nat n2) { 00357 if (n1 < n2); 00358 else if (n1 < Threshold(C,Th)) 00359 Fallback::pquo_rem (dest, s1, s2, n1, n2); 00360 else { 00361 nat tot= aligned_size<C,V> ((n2 << 1) + n2); 00362 C* buf= mmx_new<C> (tot); 00363 C* inv= buf + (n2 << 1); 00364 nat tmp_size= aligned_size<C,V> (n2); 00365 C* tmp= mmx_new<C> (tmp_size); 00366 nat nq= n1 + 1 - n2; 00367 nat l = min (n2, nq); 00368 nat l_end= nq % l; 00369 pinvert_hi (inv + n2 - l, s2 + n2 - l, l); 00370 if (l_end != 0) 00371 pinvert_hi (tmp + n2 - l_end, s2 + n2 - l_end, l_end); 00372 while (n1 >= n2) { 00373 nat nq= n1 + 1 - n2; 00374 nat l = min (n2, nq); 00375 C* dest_hi= dest + nq - l; 00376 if (l == l_end) 00377 Pol::mul (buf, s1 + n1 - l, tmp + n2 - l, l, l); 00378 else 00379 Pol::mul (buf, s1 + n1 - l, inv + n2 - l, l, l); 00380 Pol::copy (dest_hi, buf + l - 1, l); 00381 Pol::mul (buf, dest_hi, s2, l, n2); 00382 Pol::mul_sc (s1, binpow (s2[n2-1], l), n1); 00383 Pol::sub (s1 + n1 - (n2 + l - 1), buf, n2 + l - 1); 00384 Pol::mul_sc (dest_hi, binpow (s2[n2-1], nq - l), l); 00385 n1 -= l; 00386 } 00387 mmx_delete<C> (buf, tot); 00388 mmx_delete<C> (tmp, tmp_size); 00389 } 00390 } 00391 00392 }; // implementation<polynomial_divide,V,polynomial_dicho<BV> > 00393 00394 /****************************************************************************** 00395 * Dichotomic Gcd computations 00396 ******************************************************************************/ 00397 00398 // "Modern Computer Algebra", Algorithm 11.4 00399 00400 template<typename V> 00401 struct polynomial_euclidean_threshold {}; 00402 00403 template<typename V, typename BV> 00404 struct implementation<polynomial_euclidean,V,polynomial_dicho<BV> >: 00405 public implementation<polynomial_divide,V> 00406 { 00407 typedef polynomial_euclidean_threshold<polynomial_dicho<BV> > Th; 00408 typedef implementation<vector_linear,V> Vec; 00409 typedef implementation<polynomial_divide,V> Pol; 00410 typedef implementation<polynomial_euclidean,V,BV> Fallback; 00411 00412 private: 00413 00414 TMPL static void 00415 dot_product (C* d, nat& nd, 00416 const C* r0, const C* r1, nat nr0, nat nr1, 00417 const C* s0, const C* s1, nat ns0, nat ns1, 00418 C* t) { 00419 // d = r0 s0 + r1 s1 00420 if ((nr0 == 0 || ns0 == 0) && (nr1 == 0 || ns1 == 0)) { 00421 nd = 0; return; 00422 } 00423 if (nr0 == 0 || ns0 == 0) { 00424 nd = nr1 + ns1 - 1; 00425 Pol::mul (d, r1, s1, nr1, ns1); 00426 Pol::trim (d, nd); 00427 return; 00428 } 00429 if (nr1 == 0 || ns1 == 0) { 00430 nd = nr0 + ns0 - 1; 00431 Pol::mul (d, r0, s0, nr0, ns0); 00432 Pol::trim (d, nd); 00433 } 00434 if (nr0 + ns0 - 1 < nr1 + ns1 - 1) { 00435 nd = nr1 + ns1 - 1; 00436 Pol::mul (d, r1, s1, nr1, ns1); 00437 Pol::mul (t, r0, s0, nr0, ns0); 00438 Pol::add (d, t , nr0 + ns0 - 1); 00439 } 00440 else { 00441 nd = nr0 + ns0 - 1; 00442 Pol::mul (d, r0, s0, nr0, ns0); 00443 Pol::mul (t, r1, s1, nr1, ns1); 00444 Pol::add (d, t , nr1 + ns1 - 1); 00445 } 00446 Pol::trim (d, nd); 00447 } 00448 00449 TMPL static void 00450 matrix_vector_product (C* r0, C* r1, nat& nr0, nat& nr1, 00451 const C* R00, const C* R01, const C* R10, const C* R11, 00452 nat nR00, nat nR01, nat nR10, nat nR11, 00453 const C* s0, const C* s1, nat n0, nat n1, 00454 C* tp) { 00455 // (r0,r1) = R . (s0,s1) 00456 dot_product (r0, nr0, R00, R01, nR00, nR01, s0, s1, n0, n1, tp); 00457 dot_product (r1, nr1, R10, R11, nR10, nR11, s0, s1, n0, n1, tp); 00458 } 00459 00460 TMPL static void 00461 matrix_product (C* Q00, C* Q01, C* Q10, C* Q11, 00462 nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11, 00463 const C* S00, const C* S01, const C* S10, const C* S11, 00464 nat nS00, nat nS01, nat nS10, nat nS11, 00465 const C* R00, const C* R01, const C* R10, const C* R11, 00466 nat nR00, nat nR01, nat nR10, nat nR11, C* tp) { 00467 // Q = S . R 00468 matrix_vector_product (Q00, Q10, nQ00, nQ10, 00469 S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00470 R00, R10, nR00, nR10, tp); 00471 matrix_vector_product (Q01, Q11, nQ01, nQ11, 00472 S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00473 R01, R11, nR01, nR11, tp); 00474 } 00475 00476 TMPL static void 00477 new_matrix (C*& M00, C*& M01, C*& M10, C*& M11, nat l) { 00478 M00= mmx_new<C> (l); M01= mmx_new<C> (l); 00479 M10= mmx_new<C> (l); M11= mmx_new<C> (l); 00480 } 00481 00482 TMPL static void 00483 delete_matrix (C* M00, C* M01, C* M10, C* M11, nat l) { 00484 mmx_delete<C> (M00, l); mmx_delete<C> (M01, l); 00485 mmx_delete<C> (M10, l); mmx_delete<C> (M11, l); 00486 } 00487 00488 TMPL static void 00489 half_gcd (C* Q00, C* Q01, C* Q10, C* Q11, 00490 nat& nQ00, nat& nQ01, nat& nQ10, nat& nQ11, 00491 const C* r0, const C* r1, nat n0, nat n1, nat k, 00492 C* rho, C* tp) { 00493 // Performs the maximum numbers of reductions less than k. 00494 // tp must have size n0 + n1. 00495 // rho is the vector of the leading coefficients of to 00496 // degree min (n0, n1) - 1. 00497 // s1 and s2 are supposed to be monic. 00498 VERIFY (n0 >= n1, "bad input sizes"); 00499 VERIFY (k <= n0, "index k out of range"); 00500 if (n1 == 0 || k < n0 - n1 + 1) { 00501 Q00[0] = 1; nQ00 = 1; nQ01 = 0; 00502 Q11[0] = 1; nQ10 = 0; nQ11 = 1; 00503 return; 00504 } 00505 if (k == 1) { 00506 Q00[0] = 1; nQ00 = 1; nQ01 = 0; 00507 Q10[0] = 1; Q11[0] = - r0[n0-1] / r1[n1-1]; nQ10 = 1; nQ11 = 1; 00508 return; 00509 } 00510 nat h = k >> 1, h2 = (h << 1) - 1; 00511 nat len_R= aligned_size<C,V> (h2); 00512 nat nR00, nR01, nR10, nR11; 00513 C* R00, * R01, * R10, * R11; 00514 new_matrix (R00, R01, R10, R11, len_R); 00515 if (h2 < n0 - n1) 00516 half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11, 00517 r0, r1, n0, n1, h, rho, tp); 00518 else 00519 half_gcd (R00, R01, R10, R11, nR00, nR01, nR10, nR11, 00520 r0 + n0 - h2, r1 + n0 - h2, 00521 h2, h2 - (n0 - n1), h, 00522 rho == NULL ? rho : (rho + (n0 - h2)), tp); 00523 nat len_r= aligned_size<C,V> (n0 + h); 00524 C* rjm1= mmx_new<C> (len_r), * rj = mmx_new<C> (len_r); 00525 nat nj, njm1; 00526 // TODO << use truncated product 00527 matrix_vector_product (rjm1, rj, njm1, nj, 00528 R00, R01, R10, R11, nR00, nR01, nR10, nR11, 00529 r0, r1, n0, n1, tp); 00530 if (nj == 0 || k < n0 - nj + 1) { 00531 Pol::copy (Q00, R00, nR00); nQ00= nR00; 00532 Pol::copy (Q01, R01, nR01); nQ01= nR01; 00533 Pol::copy (Q10, R10, nR10); nQ10= nR10; 00534 Pol::copy (Q11, R11, nR11); nQ11= nR11; 00535 return; 00536 } 00537 nat len_S= aligned_size<C,V> (max (k - (n0 - nj), njm1 - nj + 1)); 00538 nat nS00, nS01, nS10, nS11; 00539 C* S00, * S01, * S10, * S11; 00540 new_matrix (S00, S01, S10, S11, len_S); 00541 00542 nat len_T= aligned_size<C,V> (njm1 - nj + h); 00543 nat nT00, nT01, nT10, nT11; 00544 C* T00, * T01, * T10, * T11; 00545 new_matrix (T00, T01, T10, T11, len_T); 00546 00547 S01[0]= 1; S10[0]= 1; 00548 nS00= 0; nS01= 1; nS10= 1; nS11= njm1 - nj + 1; 00549 quo_rem (S11, rjm1, rj, njm1, nj); 00550 Vec::neg (S11, nS11); 00551 Pol::trim (rj, nj); Pol::trim (rjm1, njm1); 00552 00553 if (njm1 != 0) { 00554 C c= 1 / rjm1[njm1-1]; 00555 if (rho != NULL) rho[njm1-1] = rjm1[njm1-1]; 00556 Pol::mul_sc (rjm1, c, njm1); 00557 Pol::mul_sc (S11, c, nS11); 00558 Pol::mul_sc (S10, c, nS10); 00559 } 00560 matrix_product (T00, T01, T10, T11, nT00, nT01, nT10, nT11, 00561 S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00562 R00, R01, R10, R11, nR00, nR01, nR10, nR11, tp); 00563 00564 if (njm1 == 0 || k < n0 - njm1 + 1) { 00565 Pol::copy (Q00, T00, nT00); nQ00= nT00; 00566 Pol::copy (Q01, T01, nT01); nQ01= nT01; 00567 Pol::copy (Q10, T10, nT10); nQ10= nT10; 00568 Pol::copy (Q11, T11, nT11); nQ11= nT11; 00569 return; 00570 } 00571 h = k - (n0 - nj); h2 = (h << 1) - 1; 00572 if (h2 < nj - njm1 || h2 > nj) 00573 half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00574 rj, rjm1, nj, njm1, h, rho, tp); 00575 else 00576 half_gcd (S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00577 rj + nj - h2, rjm1 + nj - h2, h2, h2 - (nj - njm1), h, 00578 rho == NULL ? rho : (rho + (nj - h2)), tp); 00579 mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r); 00580 matrix_product (Q00, Q01, Q10, Q11, nQ00, nQ01, nQ10, nQ11, 00581 S00, S01, S10, S11, nS00, nS01, nS10, nS11, 00582 T00, T01, T10, T11, nT00, nT01, nT10, nT11, tp); 00583 delete_matrix (R00, R01, R10, R11, len_R); 00584 delete_matrix (S00, S01, S10, S11, len_S); 00585 delete_matrix (T00, T01, T10, T11, len_T); 00586 } 00587 00588 public: 00589 00590 TMPL static inline void 00591 euclidean_sequence (const C* s1, const C* s2, nat n1, nat n2, 00592 C* d1, C* d2, nat& m1 , nat& m2, 00593 C* u1, C* u2, nat& nu1, nat& nu2, 00594 C* v1, C* v2, nat& nv1, nat& nv2, 00595 nat* n, C* rho, C* q, C** r, C** co1, C** co2, nat k= 0) { 00596 Fallback::euclidean_sequence (s1, s2, n1, n2, d1, d2, m1, m2, 00597 u1, u2, nu1, nu2, v1, v2, nv1, nv2, 00598 n, rho, q, r, co1, co2, k); 00599 } 00600 00601 TMPL static void 00602 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2, 00603 C* uu1, C* uu2, nat& nuu1, nat& nuu2) { 00604 // g must have allocated length min (n1, n2) 00605 if (n1 < Threshold(C,Th) || n2 < Threshold(C,Th)) { 00606 Fallback::gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2); return; } 00607 VERIFY (n1>0 && n2>0 && s1[n1-1] != 0 && s2[n2-1] != 0, 00608 "invalid hypothesis for gcd computation"); 00609 nat nu1, nu2, nv1, nv2; 00610 C c1= 1 / s1[n1-1], c2= 1 / s2[n2-1]; 00611 nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2); 00612 C* z1= mmx_new<C> (l1), * z2= mmx_new<C> (l2); 00613 Pol::mul_sc (z1, s1, c1, n1); Pol::mul_sc (z2, s2, c2, n2); 00614 C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1); 00615 C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1); 00616 nat len_tp= aligned_size<C,V> (n1 + n2); 00617 C* tp1= mmx_new<C> (len_tp), * tp2= mmx_new<C> (len_tp); 00618 C* rho= NULL; 00619 if (n1 >= n2) 00620 half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2, 00621 z1, z2, n1, n2, n1, rho, tp1); 00622 else 00623 half_gcd (u2, u1, v2, v1, nu2, nu1, nv2, nv1, 00624 z2, z1, n2, n1, n2, rho, tp1); 00625 VERIFY (nu1 <= n2 && nv1 <= n2 && nu2 <= n1 && nv2 <= n1, "bug"); 00626 dot_product (tp2, n, u1, u2, nu1, nu2, z1, z2, n1, n2, tp1); 00627 VERIFY (n <= min (n1, n2), "bug"); 00628 C c= n == 0 ? 0 : (1 / tp2[n-1]); 00629 Pol::mul_sc (g, tp2, c, n); Pol::clear (g + n, min (n1, n2) - n); 00630 if (uu1 != NULL) { 00631 Pol::mul_sc (uu1, u1, c1 * c, nu1); 00632 Pol::clear (uu1 + nu1, n2 - nu1); nuu1= nu1; } 00633 if (uu2 != NULL) { 00634 Pol::mul_sc (uu2, u2, c2 * c, nu2); 00635 Pol::clear (uu2 + nu2, n1 - nu2); nuu2= nu2; } 00636 mmx_delete<C> (tp1, len_tp); mmx_delete<C> (tp2, len_tp); 00637 mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1); 00638 mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1); 00639 } 00640 00641 TMPL static void 00642 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2) { 00643 nat nuu1, nuu2; 00644 C* uu1= NULL, * uu2= NULL; 00645 gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2); 00646 } 00647 00648 TMPL static void 00649 gcd (C* g, nat& n, const C* s1, const C* s2, nat n1, nat n2, 00650 C* uu1, nat& nuu1) { 00651 C* uu2= NULL; nat nuu2; 00652 gcd (g, n, s1, s2, n1, n2, uu1, uu2, nuu1, nuu2); 00653 } 00654 00655 TMPL static void 00656 pade (C* r, C* t, const C* s, nat m, nat n, nat k) { 00657 // r (resp. t) must have allocated length at least k (resp. n-k+1) 00658 // s = r / t + O(x^n), deg r < k, deg t <= n-k 00659 // "Modern Computer Algebra", Corollary 5.21 00660 if (n < Threshold(C,Th) || k == n || k == 0) { 00661 Fallback::pade (r, t, s, m, n, k); return; } 00662 VERIFY (n > k && n > 0 && m > 0 && m <= n && s[m-1] != 0, 00663 "invalid hypothesis for gcd computation"); 00664 nat nu1, nu2, nv1, nv2; 00665 nat n1= n+1, n2= m; 00666 nat l1= aligned_size<C,V> (n1), l2= aligned_size<C,V> (n2); 00667 C* u1= mmx_new<C> (l2), * u2= mmx_new<C> (l1); 00668 C* v1= mmx_new<C> (l2), * v2= mmx_new<C> (l1); 00669 C* s1= mmx_new<C> (l1); Pol::clear (s1, n); s1[n]= 1; const C* s2= s; 00670 nat len_tp= aligned_size<C,V> (n1 + n2); 00671 C* tp= mmx_new<C> (len_tp); 00672 C* rho= NULL; 00673 half_gcd (u1, u2, v1, v2, nu1, nu2, nv1, nv2, 00674 s1, s2, n1, n2, n-k, rho, tp); 00675 00676 nat len_r= aligned_size<C,V> (2*n1-k); 00677 C* rjm1= mmx_new<C> (len_r), * rj = mmx_new<C> (len_r); 00678 nat nj, njm1; 00679 matrix_vector_product (rjm1, rj, njm1, nj, 00680 u1, u2, v1, v2, nu1, nu2, nv1, nv2, 00681 s1, s2, n1, n2, tp); 00682 if (nj <= k) { 00683 VERIFY (nv2 <= n-k+1, "bug"); 00684 Pol::copy (r, rj, nj); Pol::clear (r + nj, k - nj); 00685 Pol::copy (t, v2, nv2); Pol::clear (t + nv2, n-k+1 - nv2); 00686 } 00687 else { 00688 VERIFY (nj > k, "bug"); 00689 nat nq= njm1 - nj + 1, len_q= aligned_size<C,V> (nq); 00690 C* q= mmx_new<C> (len_q); 00691 quo_rem (q, rjm1, rj, njm1, nj); 00692 Pol::trim (rjm1, njm1); 00693 Pol::mul (tp, q, v2, nq, nv2); 00694 Pol::clear (u2 + nu2, n1 - nu2); nu2= n1; 00695 Pol::sub (u2, tp, nq + nv2 - 1); 00696 Pol::trim (u2, nu2); 00697 VERIFY (njm1 <= k, "bug"); 00698 VERIFY (nu2 <= n-k+1, "bug"); 00699 Pol::copy (r, rjm1, njm1); Pol::clear (r + njm1, k - njm1); 00700 Pol::copy (t, u2, nu2); Pol::clear (t + nu2, n-k+1 - nu2); 00701 mmx_delete<C> (q, len_q); 00702 } 00703 mmx_delete<C> (rjm1, len_r); mmx_delete<C> (rj, len_r); 00704 mmx_delete<C> (s1, l1); mmx_delete<C> (tp, len_tp); 00705 mmx_delete<C> (u1, l2); mmx_delete<C> (u2, l1); 00706 mmx_delete<C> (v1, l2); mmx_delete<C> (v2, l1); 00707 } 00708 00709 }; // implementation<polynomial_euclidean,V,polynomial_dicho<BV> > 00710 00711 /****************************************************************************** 00712 * Efficient evaluation of polynomials 00713 ******************************************************************************/ 00714 00715 template<typename V> 00716 struct polynomial_evaluate_threshold {}; 00717 00718 template<typename V, typename BV> 00719 struct implementation<polynomial_evaluate,V,polynomial_dicho<BV> >: 00720 public implementation<polynomial_divide,V> 00721 { 00722 typedef polynomial_evaluate_threshold<polynomial_dicho<BV> > Th; 00723 typedef implementation<vector_linear,V> Vec; 00724 typedef implementation<polynomial_divide,V> Pol; 00725 typedef implementation<polynomial_evaluate,V,BV> Fallback; 00726 00727 TMPL static inline void 00728 factorials (C* dest, nat n) { 00729 if (n > 0) dest[0]= C(1); 00730 for (nat i=1; i<n; i++) 00731 dest[i]= C(i) * dest[i-1]; 00732 } 00733 00734 TMPL static void 00735 shift (C* dest, const C* s, const C& sh, nat n) { 00736 if (n <= 1 || sh == 0) Pol::copy (dest, s, n); 00737 else { 00738 nat l= aligned_size<C,V> (5 * n); 00739 C* u = mmx_new<C> (l); 00740 C* v = u + n; 00741 C* w = v + n; 00742 C* facts= w + n + n; 00743 factorials (facts, n); 00744 Vec::mul (u, s, facts, n); 00745 Vec::vec_reverse (u, n); 00746 Vec::set (v, 1, n); 00747 if (sh != 0) Pol::q_difference (v, v, sh, n); 00748 Vec::div (v, facts, n); 00749 Pol::mul (w, u, v, n, n); // FIXME: rather use truncated multiplication 00750 Vec::vec_reverse (w, n); 00751 Vec::div (dest, w, facts, n); 00752 mmx_delete<C> (u, l); 00753 } 00754 } 00755 00756 TMPL static inline C 00757 evaluate (const C* p, const C& x, nat l) { 00758 return Fallback::evaluate (p, x, l); 00759 } 00760 00761 TMPL static void 00762 q_binomial (C* dest, const C& q, nat mu) { 00763 dest[mu]= 1; 00764 for (nat i=1; i<=mu; i++) 00765 dest[mu-i]= (C (mu+1-i) * dest[mu+1-i] * q) / C(i); 00766 } 00767 00768 TMPL static void 00769 expand (C** v, const C* p, const C* x, const nat* nu, nat n, nat k) { 00770 nat tot= 0; 00771 for (nat i=0; i<k; i++) 00772 tot += nu[i] + 1; 00773 nat* d= mmx_new<nat> (k); 00774 nat l= aligned_size<C,V> (tot << 1); 00775 C* q= mmx_new<C> (l); 00776 C* r= q + tot; 00777 nat off= 0; 00778 for (nat i=0; i<k; i++) { 00779 q_binomial (q + off, -x[i], nu[i]); 00780 d[i]= nu[i] + 1; 00781 off += d[i]; 00782 } 00783 multi_mod (r, p, q, d, n, k); 00784 off= 0; 00785 for (nat i=0; i<k; i++) { 00786 shift (v[i], r + off, x[i], nu[i]); 00787 off += d[i]; 00788 } 00789 mmx_delete<C> (q, l); 00790 mmx_delete<nat> (d, k); 00791 } 00792 00793 #define C typename scalar_type_helper<Polynomial >::val 00794 00795 struct _vector_sort_by_increasing_degree_op { 00796 TMPLP static bool 00797 op (const Polynomial& p, const Polynomial& q) { 00798 return deg (p) < deg (q); } 00799 TMPLP static bool 00800 not_op (const Polynomial& p, const Polynomial& q) { 00801 return deg (p) >= deg (q); } 00802 }; 00803 00804 template<typename Op, typename Polynomial> static inline vector<Polynomial> 00805 _multi_rem (const Polynomial& p, const vector<Polynomial>& q) { 00806 if (p == 0) return vector<Polynomial> (Polynomial(C(0)), N(q)); 00807 nat n= degree (p); 00808 vector<Polynomial> sorted_q (q), r (Polynomial(C(0)), N(q)); 00809 vector<nat> sigma; // permutation 00810 sort_leq<_vector_sort_by_increasing_degree_op> (sorted_q, sigma); 00811 nat start= 0, sum= 0; 00812 for (nat i= 0; i < N(q); i++) { 00813 sum += degree (sorted_q[i]); 00814 if (sum > n / 2 || i+1 == N(q)) { 00815 Crt_polynomial_transformer(Polynomial) 00816 crter (range (sorted_q, start, i+1)); 00817 vector<Polynomial> tmp; direct_crt (tmp, p, crter); 00818 for (nat j= start; j < i+1; j++) r[sigma[j]]= tmp[j-start]; 00819 start= i+1; sum= 0; 00820 } 00821 } 00822 return r; } 00823 00824 TMPLP static inline vector<Polynomial> 00825 multi_rem (const Polynomial& p, const vector<Polynomial>& q) { 00826 return _multi_rem<rem_op> (p, q); } 00827 00828 TMPLP static inline vector<Polynomial> 00829 multi_prem (const Polynomial& p, const vector<Polynomial>& q) { 00830 return _multi_rem<prem_op> (p, q); } 00831 00832 TMPLP static vector<Polynomial> 00833 multi_gcd (const Polynomial& P, const vector<Polynomial>& Q) { 00834 return binary_map<gcd_op> (rem (P, Q), Q); } 00835 00836 TMPLP static Polynomial 00837 annulator (const Vector& x) { 00838 ASSERT (is_non_scalar (x), "non-scalar xector expected"); 00839 if (N(x) == 0) return 1; 00840 vector<Polynomial> q (Polynomial (), N(x)); 00841 Polynomial z (C(1), 1); 00842 for (nat i= 0; i < N(x); i++) q[i]= z - x[i]; 00843 Crt_polynomial_transformer(Polynomial) crter (q); 00844 return * moduli_product (crter); } 00845 00846 TMPLP static inline Vector 00847 evaluate (const Polynomial& p, const Vector& x) { 00848 ASSERT (is_non_scalar (x), "non-scalar vector expected"); 00849 if (N(x) == 0) return Vector (C(0), 0); 00850 vector<Polynomial> q (Polynomial (), N(x)); 00851 Polynomial z (C(1), 1); 00852 for (nat i= 0; i < N(x); i++) q[i]= z - x[i]; 00853 vector<Polynomial> tmp= multi_rem (p, q); 00854 Vector r (C(0), N(x)); 00855 for (nat i= 0; i < N(x); i++) r[i]= tmp[i][0]; 00856 return r; 00857 } 00858 00859 TMPLP static inline Polynomial 00860 tevaluate (const Vector& v, const Vector& x, nat l) { 00861 ASSERT (is_non_scalar (x), "non-scalar vector expected"); 00862 if (l == 0) return Polynomial (0); 00863 nat n= N(x), ll= aligned_size<C,V> (l); 00864 vector<Polynomial> q (Polynomial (), n); 00865 Polynomial z (C(1), 1); 00866 for (nat i= 0; i < N(v); i++) q[i]= Polynomial (1) - x[i] * z; 00867 Crt_polynomial_transformer(Polynomial) crter (q); 00868 Polynomial num (combine_crt (v, crter)), den (* moduli_product (crter)); 00869 // retrieve the coefficients of the series num / den 00870 C* tmp= mmx_new<C> (ll), * inv= mmx_new<C> (ll); 00871 Pol::clear (tmp, l); Pol::copy (tmp, seg(den), min (N(den), l)); 00872 Pol::invert_lo (inv, tmp, l); 00873 Polynomial b (inv, l, ll), c (num * b); 00874 for (nat i= 0; i < l; i++) tmp[i]= c[i]; 00875 return Polynomial (tmp, l, ll); 00876 } 00877 00878 TMPLP static Polynomial 00879 interpolate (const Vector& v, const Vector& x) { 00880 ASSERT (is_non_scalar (x), "non-scalar vector expected"); 00881 ASSERT (N(v) == N(x), "dimensions don't match"); 00882 if (N(x) == 0) return Polynomial (0); 00883 vector<Polynomial> q (Polynomial (), N(x)); 00884 Polynomial z (C(1), 1); 00885 for (nat i= 0; i < N(x); i++) q[i]= z - x[i]; 00886 Crt_polynomial_transformer(Polynomial) crter (q); 00887 Polynomial ans; inverse_crt (ans, as<vector<Polynomial> > (v), crter); 00888 return ans; 00889 } 00890 00891 TMPLP static Vector 00892 tinterpolate (const Polynomial& p, const Vector& x) { 00893 ASSERT (is_non_scalar (x), "non-scalar vector expected"); 00894 ASSERT (N(p) <= N(x), "dimensions don't match"); 00895 nat n= N(x); 00896 if (n == 0) return Vector (C(0), 0); 00897 vector<Polynomial> q (Polynomial (), n); 00898 Polynomial z (C(1), 1); 00899 for (nat i= 0; i < n; i++) q[i]= z - x[i]; 00900 Crt_polynomial_transformer(Polynomial) crter (q); 00901 Polynomial rp (lshiftz (reverse (p), (int) (n - N(p)))); 00902 Polynomial den (* moduli_product (crter)); 00903 vector<Polynomial> a; direct_crt (a, range (den * rp, n, 2*n), crter); 00904 vector<Polynomial> b; direct_crt (b, derive (den), crter); 00905 Vector v (C(0), n); for (nat i= 0; i < n; i++) v[i]= a[i][0] / b[i][0]; 00906 return v; 00907 } 00908 00909 #undef C 00910 }; // implementation<polynomial_evaluate,V,polynomial_dicho<BV> > 00911 00912 #undef TMPL 00913 #undef TMPLP 00914 #undef Vector 00915 } // namespace mmx 00916 #endif //__MMX__POLYNOMIAL_DICHO__HPP