algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/polynomial_schonhage_strassen.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : polynomial_schonhage_strassen.hpp
00004 * DESCRIPTION: Schonhage-Strassen fast product
00005 * COPYRIGHT  : (C) 2008  Gregoire Lecerf
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 // Ref. "Modern Computer Algebra", Chapter 8, Algorithm 8.20.
00014 
00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP
00016 #define __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP
00017 #include <algebramix/polynomial.hpp>
00018 #include <algebramix/polynomial_dicho.hpp>
00019 #include <algebramix/polynomial_tft.hpp>
00020 
00021 namespace mmx {
00022 #define TMPL template<typename C>
00023 
00024 /******************************************************************************
00025 * Variant for Schonhage-Strassen multiplication
00026 ******************************************************************************/
00027 
// Default tag for the Th parameter of polynomial_schonhage_strassen_inc;
// the multiplication code passes it to Threshold(C,Th) to pick the size
// below which the inner multiplication is used instead.
struct schonhage_strassen_threshold { };

// Variant layer that inserts Schonhage-Strassen multiplication on top of
// the variant V.  Sub-variants that merely switch off a feature of V
// (No_simd, No_thread, No_scaled) keep this layer on top; the remaining
// sub-variants are taken unchanged from V.
template<typename V, typename Th= schonhage_strassen_threshold>
struct polynomial_schonhage_strassen_inc: public V {
  // Feature-stripped variants: re-wrap V's corresponding variant.
  typedef polynomial_schonhage_strassen_inc<typename V::No_scaled,Th> No_scaled;
  typedef polynomial_schonhage_strassen_inc<typename V::No_thread,Th> No_thread;
  typedef polynomial_schonhage_strassen_inc<typename V::No_simd  ,Th> No_simd;
  // Sub-variants forwarded as-is from V.
  typedef typename V::Positive Positive;
  typedef typename V::Naive    Naive;
  typedef typename V::Vec      Vec;
};
00039 
00040 template<typename F, typename V, typename W, typename Th>
00041 struct implementation<F,V,polynomial_schonhage_strassen_inc<W,Th> >:
00042   public implementation<F,V,W> {};
00043 
00044 DEFINE_VARIANT_1 (typename V, V,
00045                   polynomial_schonhage_strassen,
00046                   polynomial_balanced_tft<
00047                     polynomial_schonhage_strassen_inc<
00048                       polynomial_karatsuba<V> > >)
00049 
00050 /******************************************************************************
00051 * Schonhage-Strassen multiplication
00052 ******************************************************************************/
00053 
00054 template<typename V, typename W, typename Th>
00055 struct implementation<polynomial_multiply,V,
00056                       polynomial_schonhage_strassen_inc<W,Th> >:
00057   public implementation<polynomial_linear,V>
00058 {
00059   typedef implementation<vector_linear,V> Vec;
00060   typedef implementation<polynomial_linear,V> Pol;
00061   typedef implementation<polynomial_multiply,W> Inner;
00062 
00063 private:  
00064   TMPL static inline C**
00065   bivariate_new (nat m2, nat t) {
00066     nat l = aligned_size<C ,V> (m2);
00067     nat ly= aligned_size<C*,V> (t);
00068     C** dest= mmx_new<C*> (ly);
00069     for (nat j = 0; j < t; j++)
00070       dest[j]= mmx_new<C> (l);
00071     return dest; }
00072 
00073   TMPL static inline void
00074   bivariate_delete (C** dest, nat m2, nat t) {
00075     nat l = aligned_size<C ,V> (m2);
00076     nat ly= aligned_size<C*,V> (t);
00077     for (nat j = 0; j < t; j++)
00078       mmx_delete<C> (dest[j], l);
00079     mmx_delete<C*> (dest, ly); }
00080 
00081   TMPL static inline void
00082   bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) {
00083     // dest (x, x^m) = s (x)
00084     // s has size n, and dest has size t in x^m and 2m in x.
00085     // We assume that n <= m * t 
00086     nat i, j, m2= m << 1;
00087     for (i = 0, j = 0; i < n; i += m, j++)
00088       if (n-i >= m) {
00089         Vec::copy  (dest[j], s + i, m);
00090         Vec::clear (dest[j] + m, m);
00091       }
00092       else {
00093         Vec::copy  (dest[j], s + i, n - i);
00094         Vec::clear (dest[j] + n - i, m2 - (n - i));
00095       }
00096     for (; j < t; j++)
00097       Vec::clear (dest[j], m2); }
00098   
00099   TMPL static inline void
00100   bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) {
00101     // dest (x) = s (x, x^m), dest has size n
00102     // We assume that n = m * t
00103     nat i, m2= m << 1;
00104     Vec::clear (dest, n);
00105     for (i = 0; i+1 < t; i++)
00106       Pol::add (dest + i * m, s[i], m2);
00107     Pol::add (dest + i * m, s[i]    , m);
00108     Pol::sub (dest        , s[i] + m, m); }
00109 
00110   TMPL static inline void
00111   negative_cyclic_shift (C* dest, const C* src, nat m2, nat i) {
00112     // multiply by x^i modulo x^m2 + 1
00113     bool negate= ((i / m2) & 1);
00114     i= i % m2;
00115     if (negate) {
00116       Vec::neg   (dest + i, src         , m2 - i);
00117       Vec::copy  (dest    , src + m2 - i, i);
00118     }
00119     else {
00120       Vec::copy  (dest + i, src         , m2 - i);
00121       Vec::neg   (dest    , src + m2 - i, i); } }
00122   
00123   TMPL static inline void
00124   negative_cyclic_shift (C* dest, nat m2, nat i) {
00125     nat l= aligned_size<C,V> (m2);
00126     C* temp= mmx_new<C> (l);
00127     Vec::copy (temp, dest, m2);
00128     negative_cyclic_shift (dest, temp, m2, i);
00129     mmx_delete<C> (temp, l); }
00130 
00131   template<typename C>
00132   struct unptr_helper {};
00133 
00134   template<typename C>
00135   struct unptr_helper<C*> {
00136     typedef C type; };
00137 
00138   template<typename Cp>
00139   struct negative_cyclic_roots_helper {
00140     // modulo x^m2 + 1
00141     typedef Cp  C;
00142     typedef nat U;
00143     typedef typename unptr_helper<Cp>::type CC;
00144     typedef CC  S;
00145 
00146     static inline nat&
00147     dyn_modulus () {
00148       static nat m2= 0;
00149       return m2; }
00150 
00151     static inline C&
00152     get_temp () {
00153       static C temp= NULL;
00154       return temp; }
00155 
00156     static inline nat
00157     primitive_root (nat t, nat i) {
00158       ASSERT (t != 0, "unexpected zero root order");
00159       i = i % t;
00160       nat m2 = dyn_modulus ();
00161       return i * (m2 << 1) / t; }
00162 
00163     static U*
00164     create_roots (nat t) {
00165       nat m2= dyn_modulus ();
00166       nat l= aligned_size<CC,V> (m2);
00167       get_temp ()= mmx_new<CC> (l);
00168       U* roots= mmx_new<U> (t);
00169       for (nat i = 0; i < t; i += 2) {
00170         roots[i]  = primitive_root (t, bit_mirror (i, t));
00171         roots[i+1]= primitive_root (t, i == 0 ? 0 : t - bit_mirror (i, t));
00172       }
00173       return roots; }
00174 
00175     static void
00176     destroy_roots (U* u, nat t) {
00177       C& temp= get_temp ();
00178       if (temp != NULL) {
00179         nat m2= dyn_modulus ();
00180         nat l= aligned_size<CC,V> (m2);
00181         mmx_delete<CC> (temp, l);
00182         temp= NULL;
00183       }
00184       mmx_delete<U> (u, t); }
00185 
00186     static inline void
00187     fft_cross (C* c1, C* c2) {
00188       C temp= get_temp ();
00189       nat m2= dyn_modulus ();
00190       Vec::copy (temp, *c2, m2);
00191       Vec::sub  (*c2, *c1, temp, m2);
00192       Vec::add  (*c1, temp, m2); }
00193     
00194     static inline void
00195     dfft_cross (C* c1, C* c2, const U* u) {
00196       C temp= get_temp ();
00197       nat m2= dyn_modulus ();
00198       negative_cyclic_shift (temp, *c2, m2, *u);
00199       Vec::sub  (*c2, *c1, temp, m2);
00200       Vec::add  (*c1, temp, m2); }
00201 
00202     static inline void
00203     ifft_cross (C* c1, C* c2, const U* u) {
00204       C temp= get_temp ();
00205       nat m2= dyn_modulus ();
00206       Vec::copy (temp, *c1, m2);
00207       Vec::add  (*c1, *c2, m2);
00208       Vec::sub  (*c2, temp, m2);
00209       negative_cyclic_shift (*c2, m2, (*u) + m2); }
00210 
00211     static inline void
00212     dtft_cross (C* c1, C* c2) {
00213       static CC h= invert (CC(2));
00214       nat m2= dyn_modulus ();
00215       fft_cross (c1, c2);
00216       Vec::mul (*c1, h, m2);
00217       Vec::mul (*c2, h, m2); }
00218     
00219     static inline void
00220     dtft_cross (C* c1, C* c2, const U* u) {
00221       static CC h= invert (CC(2));
00222       nat m2= dyn_modulus ();
00223       dfft_cross (c1, c2, u);
00224       Vec::mul (*c1, h, m2);
00225       Vec::mul (*c2, h, m2); }
00226     
00227     static inline void
00228     itft_flip (C* c1, C* c2, const U* u) {
00229       static CC h= invert (CC(2));
00230       C temp= get_temp ();
00231       nat m2= dyn_modulus ();
00232       negative_cyclic_shift (temp, *c2, m2, *u);
00233       Vec::add (*c1, *c1, m2);
00234       Vec::sub (*c1, temp, m2);
00235       Vec::sub (*c2, *c1, temp, m2);
00236       Vec::mul (*c2, h, m2); }
00237     
00238     static inline void
00239     itft_flip (C* c1, C* c2) {
00240       static CC h= invert (CC(2));
00241       nat m2= dyn_modulus ();
00242       Vec::add (*c1, *c1, m2);
00243       Vec::sub (*c1, *c2, m2);
00244       Vec::sub (*c2, *c1, *c2, m2);
00245       Vec::mul (*c2, h, m2); }
00246     
00247   struct fft_mul_sc_op : mul_op {
00248     static inline void
00249     set_op (C& x, const S& y) {
00250       Vec::mul (x, y, dyn_modulus ()); }
00251   };
00252   };
00253 
00254   TMPL
00255   struct negative_cyclic_roots {
00256     typedef negative_cyclic_roots_helper<C> roots_type;
00257   };
00258 
00259   TMPL static void
00260   direct_transform (C** b, nat m2, nat t) {
00261     negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00262     fft_naive_transformer<C*, negative_cyclic_roots<C*> > ffter (t);
00263     ffter.direct_transform (b); } 
00264 
00265   TMPL static void
00266   inverse_transform (C** b, nat m2, nat t, bool divide=true) {
00267     negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00268     fft_naive_transformer<C*, negative_cyclic_roots<C*> > ffter (t);
00269     ffter.inverse_transform (b, divide); } 
00270 
00271   TMPL static void
00272   direct_transform_truncated (C** b, nat m2, nat len) {
00273     negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00274     fft_truncated_transformer<C*,
00275       fft_naive_transformer<C*, negative_cyclic_roots<C*> > > ffter (len);
00276     ffter.dtft (b, 1, len, 0); }
00277 
00278   TMPL static void
00279   inverse_transform_truncated (C** b, nat m2, nat len) {
00280     nat i;
00281     for (i = len; i < next_power_of_two (len); i++)
00282       Vec::clear (b[i], m2);
00283     negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2;
00284     fft_truncated_transformer<C*,
00285       fft_naive_transformer<C*, negative_cyclic_roots<C*> > > ffter (len);
00286     ffter.itft (b, 1, len, 0);
00287     C h= invert (C(next_power_of_two (len)));
00288     for (i = 0; i < len; i++)
00289       Vec::mul (b[i], h, m2);
00290     for (; i < next_power_of_two (len); i++)
00291       Vec::clear (b[i], m2); }
00292 
00293   TMPL static void
00294   variable_dilate (C** b, nat eta, nat m2, nat t) {
00295     // substitute eta * y for y in b of size t
00296     nat m4= m2 << 1;
00297     nat w= 0;
00298     for (nat i = 0; i < t; i++) {
00299       negative_cyclic_shift (b[i], m2, w);
00300       w= (w + eta) % m4; } }
00301 
00302 public:
00303   TMPL static inline nat
00304   mul_negative_cyclic (C* dest, const C* src, nat n, bool shift=true) {
00305     // dest *= src mod (x^n + 1)
00306     nat u= next_power_of_two (n);
00307     ASSERT (u != 0, "maximum size exceeded");
00308     ASSERT (n == u, "power of two expected");
00309     if (n <= max ((nat) 4, (nat) Threshold(C,Th))) {
00310       nat l= aligned_size<C,V> (2*n-1);
00311       C* temp= mmx_new<C> (l);
00312       Inner::mul (temp, dest, src, n, n);
00313       Vec::copy (dest, temp, n);
00314       Vec::sub  (dest, temp + n, n - 1);      
00315       mmx_delete<C> (temp, l);
00316       return 0;
00317     }
00318     else {
00319       nat r = 0;
00320       nat k = log_2 (n);
00321       nat m = (nat) 1 << (k >> 1);
00322       nat t = u / m;
00323       nat m2= m << 1;
00324       C** b1= bivariate_new<C> (m2, t);
00325       C** b2= bivariate_new<C> (m2, t);
00326       bivariate_encode (b1, src , n, m, t);
00327       bivariate_encode (b2, dest, n, m, t);
00328       nat eta= (t == m2) ? 1 : 2;
00329       variable_dilate (b1, eta, m2, t);
00330       variable_dilate (b2, eta, m2, t);
00331       direct_transform (b1, m2, t);
00332       direct_transform (b2, m2, t);
00333       for (nat i = 0; i < t; i++)
00334         r= mul_negative_cyclic (b1[i], b2[i], m2, shift);
00335       bivariate_delete<C> (b2, m2, t);
00336       inverse_transform (b1, m2, t, shift);
00337       variable_dilate (b1, (m2 << 1) - eta, m2, t);
00338       bivariate_decode (dest, (const C**) b1, n, m, t);
00339       bivariate_delete<C> (b1, m2, t);
00340       return r + ((k+1) >> 1); } }
00341 
00342   TMPL static inline void
00343   mul_negative_cyclic_truncated (C* dest, const C* src, nat len) {
00344     // dest *= src mod (x^n + 1)
00345     // TFT is used in first round
00346     nat n= next_power_of_two (len);
00347     ASSERT (n != 0, "maximum size exceeded");
00348     if (n <= max ((nat) 4, (nat) Threshold(C,Th))) {
00349       nat l= aligned_size<C,V> (2*n-1);
00350       C* temp= mmx_new<C> (l);
00351       Inner::mul (temp, dest, src, n, n);
00352       Vec::copy (dest, temp, n);
00353       Vec::sub  (dest, temp + n, n - 1);      
00354       mmx_delete<C> (temp, l);
00355       return;
00356     }
00357     else {
00358       nat k = log_2 (n);
00359       nat m = (nat) 1 << (k >> 1);
00360       nat t = n / m;
00361       nat r = (len + m - 1) / m;
00362       nat m2= m << 1;
00363       C** b1= bivariate_new<C> (m2, t);
00364       C** b2= bivariate_new<C> (m2, t);
00365       bivariate_encode (b1, src , n, m, t);
00366       bivariate_encode (b2, dest, n, m, t);
00367       nat eta= (t == m2) ? 1 : 2;
00368       variable_dilate (b1, eta, m2, r);
00369       variable_dilate (b2, eta, m2, r);
00370       direct_transform_truncated (b1, m2, r);
00371       direct_transform_truncated (b2, m2, r);
00372       for (nat i = 0; i < r; i++)
00373         mul_negative_cyclic (b1[i], b2[i], m2);
00374       bivariate_delete<C> (b2, m2, t);
00375       inverse_transform_truncated (b1, m2, r);
00376       variable_dilate (b1, (m2 << 1) - eta, m2, r);
00377       bivariate_decode (dest, (const C**) b1, n, m, t);
00378       bivariate_delete<C> (b1, m2, t); } }
00379 
00380   TMPL static inline nat
00381   mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) {
00382     nat ret = 0;
00383     nat len = n1 + n2 - 1;
00384     nat n = next_power_of_two (len);
00385     nat l = aligned_size<C,V> (n);
00386     C* t1 = mmx_new<C> (l);
00387     C* t2 = mmx_new<C> (l);
00388     Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, n - n1);
00389     Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, n - n2);
00390     if (shift)
00391       mul_negative_cyclic_truncated (t1, t2, len);
00392     else
00393       ret= mul_negative_cyclic (t1, t2, n, shift);
00394     Vec::copy (dest, t1, len);
00395     mmx_delete<C> (t1, l);
00396     mmx_delete<C> (t2, l);
00397     return ret; }
00398 
00399 }; // implementation<polynomial_multiply,V,polynomial_schonhage_strassen_inc<W,Th> >
00400 
00401 #undef TMPL
00402 } // namespace mmx
00403 #endif //__MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines