algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/polynomial_schonhage_triadic.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : polynomial_schonhage_triadic.hpp
00004 * DESCRIPTION: Schonhage's triadic cyclic wrapped product
00005 * COPYRIGHT  : (C) 2008  Gregoire Lecerf
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 // Ref. "Modern Computer Algebra", Chapter 8, Exercise 8.30.
00014 
00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00016 #define __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
00017 #include <algebramix/fft_triadic_naive.hpp>
00018 #include <algebramix/polynomial_dicho.hpp>
00019 #include <algebramix/polynomial_balanced.hpp>
00020 
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023 
00024 /******************************************************************************
00025 * Variant for triadic Schonhage-Strassen multiplication
00026 ******************************************************************************/
00027 
// Default threshold tag of the triadic Schonhage-Strassen product.  It is the
// default value of the Th parameter below, which the multiplication code
// passes to Threshold(C,Th) to decide when to fall back to the inner product.
struct schonhage_triadic_threshold {};

// Variant layer placed on top of an existing polynomial variant V.
// Th is the threshold tag attached to this layer.
template<typename V, typename Th= schonhage_triadic_threshold>
struct polynomial_schonhage_triadic_inc: public V {
  // Helper variants taken over unchanged from V.
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;

  // The same layer (same threshold tag) rebuilt over the restricted
  // flavours exported by V.
  typedef polynomial_schonhage_triadic_inc<typename V::No_scaled,Th> No_scaled;
  typedef polynomial_schonhage_triadic_inc<typename V::No_thread,Th> No_thread;
  typedef polynomial_schonhage_triadic_inc<typename V::No_simd,Th>   No_simd;
};
00039 
00040 template<typename F, typename V, typename W, typename Th>
00041 struct implementation<F,V,polynomial_schonhage_triadic_inc<W,Th> >:
00042   public implementation<F,V,W> {};
00043 
00044 DEFINE_VARIANT_1 (typename V, V,
00045                   polynomial_schonhage_triadic,
00046                   polynomial_balanced<
00047                     polynomial_schonhage_triadic_inc<
00048                       polynomial_karatsuba<V> > >)
00049 
00050 /******************************************************************************
00051 * Triadic Schonhage-Strassen multiplication
00052 ******************************************************************************/
00053 
00054 template<typename V, typename W, typename Th>
00055 struct implementation<polynomial_multiply,V,
00056                       polynomial_schonhage_triadic_inc<W,Th> >:
00057   public implementation<polynomial_linear,V>
00058 {
00059   typedef implementation<vector_linear,V> Vec;
00060   typedef implementation<polynomial_linear,V> Pol;
00061   typedef implementation<polynomial_multiply,W> Inner;
00062 
00063 private:  
00064   TMPL static inline C**
00065   bivariate_new (nat m, nat t) {
00066     nat l = aligned_size<C,V> (m);
00067     C** dest= mmx_new<C*> (t);
00068     for (nat j = 0; j < t; j++)
00069       dest[j]= mmx_new<C> (l);
00070     return dest; }
00071 
00072   TMPL static inline void
00073   bivariate_delete (C** dest, nat m, nat t) {
00074     nat l = aligned_size<C,V> (m);
00075     for (nat j = 0; j < t; j++)
00076       mmx_delete<C> (dest[j], l);
00077     mmx_delete<C*> (dest, t); }
00078 
00079   TMPL static inline void
00080   bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) {
00081     // dest (x, x^m) = s (x)
00082     // s has size 2n, and dest has size 2t in x^m and 2m in x.
00083     // We assume that n <= m * t 
00084     nat i, j, n2= n << 1, m2= m << 1, t2= t << 1;
00085     for (i = 0, j = 0; i < n2; i += m, j++)
00086       if (n2-i >= m) {
00087         Vec::copy  (dest[j], s + i, m);
00088         Vec::clear (dest[j] + m, m);
00089       }
00090       else {
00091         Vec::copy  (dest[j], s + i, n2 - i);
00092         Vec::clear (dest[j] + n2 - i, m2 - (n2 - i));
00093       }
00094     for (; j < t2; j++)
00095       Vec::clear (dest[j], m2); }
00096   
00097   TMPL static inline void
00098   bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) {
00099     // dest (x) = s (x, x^m) mod (x^(2n) + x^n + 1)
00100     // dest has size 2n, s has size 2t.
00101     // We assume that n = m * t
00102     nat i, n2= n << 1, m2= m << 1, t2= t << 1;
00103     Vec::clear (dest, n2);
00104     for (i = 0; i+1 < t2; i++)
00105       Pol::add (dest + i * m, s[i], m2);
00106     Pol::add (dest + i * m, s[i]    , m);
00107     Pol::sub (dest        , s[i] + m, m);
00108     Pol::sub (dest + n    , s[i] + m, m); }
00109 
00110   TMPL static inline void
00111   triadic_shift (C* dest, const C* src, nat i, nat m) {
00112     // multiply by x^i modulo x^(2*m) + x^m + 1
00113     nat m2= m << 1, m3= m2 + m;
00114     i= i % m3;
00115     if (i > m2) {
00116       Vec::copy (dest         , src + m3 - i, i - m);
00117       Vec::neg  (dest + i - m , src         , m3 - i);
00118       Vec::sub  (dest + i - m2, src         , m3 - i);
00119     }
00120     else if (i > m) {
00121       Vec::neg  (dest    , src + m2 - i, m);
00122       Vec::neg  (dest + m, src + m2 - i, m);
00123       Vec::add  (dest    , src + m3 - i, i - m);
00124       Vec::add  (dest + i, src         , m2 - i);
00125     }
00126     else {
00127       Vec::neg  (dest    , src + m2 - i, i);
00128       Vec::copy (dest + i, src         , m2 - i);
00129       Vec::sub  (dest + m, src + m2 - i, i); } }
00130   
00131   template<typename C>
00132   struct unptr_helper {};
00133 
00134   template<typename C>
00135   struct unptr_helper<C*> {
00136     typedef C type; };
00137 
00138   template<typename Cp>
00139   struct triadic_roots_helper {
00140     // modulo x^(2*m) + x^m + 1
00141     typedef Cp  C;
00142     typedef nat U;
00143     typedef typename unptr_helper<Cp>::type CC;
00144     typedef CC  S;
00145 
00146     static inline nat&
00147     dyn_m () {
00148       static nat m= 0;
00149       return m; }
00150 
00151     static inline C&
00152     dyn_temp (nat i) {
00153       static C temp0= NULL;
00154       static C temp1= NULL;
00155       static C temp2= NULL;
00156       static C temp3= NULL;
00157       switch (i) {
00158       case 0: return temp0;
00159       case 1: return temp1;
00160       case 2: return temp2;
00161       case 3: return temp3;
00162       default: ERROR ("index out of range"); return temp0; } }
00163 
00164     static inline nat
00165     primitive_root (nat t, nat i) {
00166       if (t == 0) return 1;
00167       i = i % t;
00168       nat m3 = 3 * dyn_m ();
00169       ASSERT (m3 % t == 0, "primitive root out of range");
00170       return i * (m3 / t); }
00171 
00172     static U*
00173     create_roots (nat t) {
00174       nat m2= dyn_m () << 1;
00175       nat l= aligned_size<CC,V> (m2);
00176       for (nat i = 0; i <= 3; i++)
00177         dyn_temp (i)= mmx_new<CC> (l);
00178       U* roots= mmx_new<U> (t);
00179       for (nat i = 0; i < t; i++)
00180         roots[i]  = primitive_root (t, digit_mirror_triadic (i, t));
00181       return roots; }
00182 
00183     static U*
00184     create_stoor (nat t) {
00185       U* stoor= mmx_new<U> (t);
00186       for (nat i = 0; i < t; i++)
00187         stoor[i]= primitive_root (t, i == 0 ?
00188                                   0 : t - digit_mirror_triadic (i, t));
00189       return stoor; }
00190 
00191     static void
00192     destroy_roots (U* u, nat t) {
00193       nat m2= dyn_m () << 1;
00194       nat l= aligned_size<CC,V> (m2);
00195       for (nat i = 0; i <= 3; i++) {
00196         C& temp= dyn_temp (i);
00197         if (temp != NULL) {
00198           mmx_delete<CC> (temp, l);
00199           temp= NULL;
00200         }
00201       }
00202       mmx_delete<U> (u, t); }
00203 
00204     static inline void
00205     dfft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00206       C temp0= dyn_temp (0);
00207       C temp1= dyn_temp (1);
00208       C temp2= dyn_temp (2);
00209       C temp3= dyn_temp (3);
00210       nat m= dyn_m (), m2= m << 1;
00211 
00212       triadic_shift (temp0, *c3, *u3, m);
00213       Vec::add (temp0, *c2, m2);
00214       triadic_shift (temp3, temp0, *u3, m);
00215 
00216       triadic_shift (temp0, *c3, *u2, m);
00217       Vec::add (temp0, *c2, m2);
00218       triadic_shift (temp2, temp0, *u2, m);
00219 
00220       triadic_shift (temp0, *c3, *u1, m);
00221       Vec::add (temp0, *c2, m2);
00222       triadic_shift (temp1, temp0, *u1, m);
00223 
00224       Vec::add (*c3, *c1, temp3, m2);
00225       Vec::add (*c2, *c1, temp2, m2);
00226       Vec::add (*c1, temp1, m2); }
00227 
00228     static inline void
00229     ifft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) {
00230       C temp0= dyn_temp (0);
00231       C temp1= dyn_temp (1);
00232       C temp2= dyn_temp (2);
00233       C temp3= dyn_temp (3);
00234       nat m= dyn_m (), m2= m << 1;
00235 
00236       triadic_shift (temp1, *c1, *u1, m);
00237       Vec::add (*c1, *c2, m2);
00238       Vec::add (*c1, *c3, m2);
00239 
00240       triadic_shift (temp2, *c2, *u2, m);
00241       triadic_shift (temp3, *c3, *u3, m);
00242       triadic_shift (temp0, *c2, *u3, m);
00243       Vec::add (*c2, temp1, temp2, m2);
00244       Vec::add (*c2, temp3, m2);
00245 
00246       triadic_shift (temp3, *c3, *u2, m);
00247       Vec::add (temp0, temp3, m2);
00248       Vec::add (temp0, temp1, m2);
00249       triadic_shift (*c3, temp0, *u1, m); }
00250 
00251     static inline void
00252     fft_shift (C* dest, S v, nat t) {
00253       nat m2= dyn_m () << 1;
00254       for (nat i = 0; i < t; i++)
00255         Vec::mul (dest[i], v, m2); }
00256   };
00257 
00258   struct triadic_roots {
00259     template<typename C>
00260     struct helper {
00261       typedef triadic_roots_helper<C> roots_type; };
00262   };
00263 
00264   TMPL static void
00265   direct_transform (C** b, nat m, nat t) {
00266     triadic_roots_helper<C*>::dyn_m ()= m;
00267     fft_triadic_naive_transformer<C*, triadic_roots> ffter (t);
00268     ffter.direct_transform_triadic (b); } 
00269 
00270   TMPL static void
00271   inverse_transform (C** b, nat m, nat t, bool shift=true) {
00272     triadic_roots_helper<C*>::dyn_m ()= m;
00273     fft_triadic_naive_transformer<C*, triadic_roots> ffter (t);
00274     ffter.inverse_transform_triadic (b,shift); } 
00275 
00276   TMPL static void
00277   variable_dilate (C** b, nat eta, nat m, nat t) {
00278     // substitute eta * y for y in b of size t
00279     nat m2= m << 1, m3= m2 + m;
00280     nat l= aligned_size<C,V> (m2);
00281     C* temp= mmx_new<C> (l);
00282     nat w= 0;
00283     for (nat i = 0; i < t; i++) {
00284       Vec::copy (temp, b[i], m2);
00285       triadic_shift (b[i], temp, w, m);
00286       w= (w + eta) % m3;
00287     }
00288     mmx_delete<C> (temp, l); }
00289 
00290   TMPL static void
00291   bivariate_mod (C** dest, const C** h, nat w, nat m, nat t) {
00292     // dest = h mod (y^t - w)
00293     nat m2= m << 1, m3= m2 + m;
00294     w= w % m3;
00295     for (nat i = 0; i < t; i++) {
00296       triadic_shift (dest[i], h[i+t], w, m);
00297       Vec::add (dest[i], h[i], m2); } }
00298 
00299   TMPL static void
00300   bivariate_crt (C** h, const C** h1, const C** h2, nat w,
00301                  nat m, nat t, bool shift=true) {
00302     // h1 = h mod (y^t - w), h2 = h mod (y^t - w^2). 
00303     nat m2= m << 1, m3= m2 + m, t2= t << 1;
00304     w= w % m3;
00305     nat w2= (w << 1) % m3;
00306     nat l= aligned_size<C,V> (m2);
00307     C* temp= mmx_new<C> (l);
00308     for (nat i = 0; i < t; i++) 
00309       Vec::sub (h[i+t], h2[i], h1[i], m2);
00310     for (nat i = 0; i < t; i++) {
00311       triadic_shift (h[i], h1[i], w2, m);
00312       triadic_shift (temp, h2[i], w , m);
00313       Vec::sub (h[i], temp, m2);
00314     }
00315     for (nat i = 0; i < t2; i++) {
00316       triadic_shift (temp, h[i], w, m);
00317       Vec::add (h[i], temp, m2);
00318       Vec::add (h[i], temp, m2);
00319       if (shift)
00320         Vec::mul (h[i], invert (C(3)), m2);
00321     }
00322     mmx_delete<C> (temp, l); }
00323   
00324 public:
00325   TMPL static inline nat
00326   mul_triadic (C* dest, const C* src, nat n, bool shift=true) {
00327     // dest *= src mod (x^(2*n) + x^n + 1)
00328     nat n2= n << 1;
00329     nat u = next_power_of_three (n);
00330     ASSERT (u != 0, "maximum size exceeded");
00331     ASSERT (n == u, "power of three expected");
00332     if (n <= max ((nat) 9, (nat) Threshold(C,Th))) {
00333       nat l= aligned_size<C,V> (2*n2-1);
00334       C* temp= mmx_new<C> (l);
00335       Inner::mul (temp, dest, src, n2, n2);
00336       Vec::copy (dest    , temp         , n2);
00337       Vec::add  (dest    , temp + n2 + n, n - 1);
00338       Vec::sub  (dest    , temp + n2    , n);
00339       Vec::sub  (dest + n, temp + n2    , n);
00340       mmx_delete<C> (temp, l);
00341       return 0;
00342     }
00343     else {
00344       nat r = 0;
00345       nat k = log_3 (n);
00346       nat m = binpow (3, (k+1) >> 1), m2= m << 1, m3= m2 + m;
00347       nat t = n / m, t2= t << 1;
00348       nat eta= (t == m) ? 1 : 3;
00349       // mmout << "n = " << n << ", m = " << m << ", t = " << t << "\n";
00350       C** hh= bivariate_new<C> (m2, t2);
00351       C** gg= bivariate_new<C> (m2, t2);
00352       C** h1= bivariate_new<C> (m2, t);
00353       C** h2= bivariate_new<C> (m2, t);
00354       C** g = bivariate_new<C> (m2, t);
00355       bivariate_encode (hh, src , n, m, t);
00356       bivariate_encode (gg, dest, n, m, t);
00357       for (nat j = 1; j <= 2; j++) {
00358         C** h= j == 1 ? h1 : h2;
00359         bivariate_mod (h, (const C**) hh, (j*eta*t) % m3, m, t);
00360         bivariate_mod (g, (const C**) gg, (j*eta*t) % m3, m, t);
00361         variable_dilate (h, (j*eta) % m3, m, t);
00362         variable_dilate (g, (j*eta) % m3, m, t);
00363         direct_transform (h, m, t);
00364         direct_transform (g, m, t);
00365         for (nat i = 0; i < t; i++)
00366           r = mul_triadic (h[i], g[i], m, shift);
00367         inverse_transform (h, m, t, shift);
00368         variable_dilate (h, m3 - ((j*eta) % m3), m, t);
00369       }
00370       bivariate_delete<C> (g , m2 , t);
00371       bivariate_delete<C> (gg, m2, t2);
00372       bivariate_crt (hh, (const C**) h1, (const C**) h2, (eta*t) % m3, m, t, shift);
00373       bivariate_delete<C> (h1, m2, t);
00374       bivariate_delete<C> (h2, m2, t); 
00375       bivariate_decode (dest, (const C**) hh, n, m, t);
00376       bivariate_delete<C> (hh, m2, t2);
00377       return (shift) ? 0 : r + 1 + (k >> 1); }}
00378 
00379   TMPL static inline nat
00380   mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) {
00381     nat len = n1 + n2 - 1;
00382     nat n = next_power_of_three ((len + 1) >> 1);
00383     nat nn= n << 1;
00384     nat l = aligned_size<C,V> (nn);
00385     C* t1 = mmx_new<C> (l);
00386     C* t2 = mmx_new<C> (l);
00387     Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, nn - n1);
00388     Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, nn - n2);
00389     nat k = mul_triadic (t1, t2, n, shift);
00390     Vec::copy (dest, t1, len);
00391     mmx_delete<C> (t1, l);
00392     mmx_delete<C> (t2, l);
00393     return k; }
00394 
00395 }; // implementation<polynomial_multiply,V,polynomial_schonhage_triadic_inc<W,Th> >
00396 
00397 #undef TMPL
00398 } // namespace mmx
00399 #endif //__MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines