algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : polynomial_schonhage_strassen.hpp 00004 * DESCRIPTION: Schonhage-Strassen fast product 00005 * COPYRIGHT : (C) 2008 Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 // Ref. "Modern Computer Algebra", Chapter 8, Algorithm 8.20. 00014 00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP 00016 #define __MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP 00017 #include <algebramix/polynomial.hpp> 00018 #include <algebramix/polynomial_dicho.hpp> 00019 #include <algebramix/polynomial_tft.hpp> 00020 00021 namespace mmx { 00022 #define TMPL template<typename C> 00023 00024 /****************************************************************************** 00025 * Variant for Schonhage-Strassen multiplication 00026 ******************************************************************************/ 00027 00028 struct schonhage_strassen_threshold {}; 00029 00030 template<typename V, typename T= schonhage_strassen_threshold> 00031 struct polynomial_schonhage_strassen_inc: public V { 00032 typedef typename V::Vec Vec; 00033 typedef typename V::Naive Naive; 00034 typedef typename V::Positive Positive; 00035 typedef polynomial_schonhage_strassen_inc<typename V::No_simd,T> No_simd; 00036 typedef polynomial_schonhage_strassen_inc<typename V::No_thread,T> No_thread; 00037 typedef polynomial_schonhage_strassen_inc<typename V::No_scaled,T> No_scaled; 00038 }; 00039 00040 template<typename F, typename V, typename W, typename Th> 00041 struct implementation<F,V,polynomial_schonhage_strassen_inc<W,Th> >: 00042 public implementation<F,V,W> {}; 00043 00044 DEFINE_VARIANT_1 (typename V, V, 00045 polynomial_schonhage_strassen, 00046 polynomial_balanced_tft< 00047 polynomial_schonhage_strassen_inc< 00048 polynomial_karatsuba<V> > >) 00049 00050 /****************************************************************************** 00051 * Schonhage-Strassen multiplication 00052 ******************************************************************************/ 00053 00054 template<typename V, typename W, typename Th> 00055 struct implementation<polynomial_multiply,V, 00056 polynomial_schonhage_strassen_inc<W,Th> >: 00057 public implementation<polynomial_linear,V> 00058 { 00059 typedef implementation<vector_linear,V> Vec; 00060 typedef implementation<polynomial_linear,V> Pol; 00061 typedef implementation<polynomial_multiply,W> Inner; 00062 00063 private: 00064 TMPL static inline C** 00065 bivariate_new (nat m2, nat t) { 00066 nat l = aligned_size<C ,V> (m2); 00067 nat ly= aligned_size<C*,V> (t); 00068 C** dest= mmx_new<C*> (ly); 00069 for (nat j = 0; j < t; j++) 00070 dest[j]= mmx_new<C> (l); 00071 return dest; } 00072 00073 TMPL static inline void 00074 bivariate_delete (C** dest, nat m2, nat t) { 00075 nat l = aligned_size<C ,V> (m2); 00076 nat ly= aligned_size<C*,V> (t); 00077 for (nat j = 0; j < t; j++) 00078 mmx_delete<C> (dest[j], l); 00079 mmx_delete<C*> (dest, ly); } 00080 00081 TMPL static inline void 00082 bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) { 00083 // dest (x, x^m) = s (x) 00084 // s has size n, and dest has size t in x^m and 2m in x. 00085 // We assume that n <= m * t 00086 nat i, j, m2= m << 1; 00087 for (i = 0, j = 0; i < n; i += m, j++) 00088 if (n-i >= m) { 00089 Vec::copy (dest[j], s + i, m); 00090 Vec::clear (dest[j] + m, m); 00091 } 00092 else { 00093 Vec::copy (dest[j], s + i, n - i); 00094 Vec::clear (dest[j] + n - i, m2 - (n - i)); 00095 } 00096 for (; j < t; j++) 00097 Vec::clear (dest[j], m2); } 00098 00099 TMPL static inline void 00100 bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) { 00101 // dest (x) = s (x, x^m), dest has size n 00102 // We assume that n = m * t 00103 nat i, m2= m << 1; 00104 Vec::clear (dest, n); 00105 for (i = 0; i+1 < t; i++) 00106 Pol::add (dest + i * m, s[i], m2); 00107 Pol::add (dest + i * m, s[i] , m); 00108 Pol::sub (dest , s[i] + m, m); } 00109 00110 TMPL static inline void 00111 negative_cyclic_shift (C* dest, const C* src, nat m2, nat i) { 00112 // multiply by x^i modulo x^m2 + 1 00113 bool negate= ((i / m2) & 1); 00114 i= i % m2; 00115 if (negate) { 00116 Vec::neg (dest + i, src , m2 - i); 00117 Vec::copy (dest , src + m2 - i, i); 00118 } 00119 else { 00120 Vec::copy (dest + i, src , m2 - i); 00121 Vec::neg (dest , src + m2 - i, i); } } 00122 00123 TMPL static inline void 00124 negative_cyclic_shift (C* dest, nat m2, nat i) { 00125 nat l= aligned_size<C,V> (m2); 00126 C* temp= mmx_new<C> (l); 00127 Vec::copy (temp, dest, m2); 00128 negative_cyclic_shift (dest, temp, m2, i); 00129 mmx_delete<C> (temp, l); } 00130 00131 template<typename C> 00132 struct unptr_helper {}; 00133 00134 template<typename C> 00135 struct unptr_helper<C*> { 00136 typedef C type; }; 00137 00138 template<typename Cp> 00139 struct negative_cyclic_roots_helper { 00140 // modulo x^m2 + 1 00141 typedef Cp C; 00142 typedef nat U; 00143 typedef typename unptr_helper<Cp>::type CC; 00144 typedef CC S; 00145 00146 static inline nat& 00147 dyn_modulus () { 00148 static nat m2= 0; 00149 return m2; } 00150 00151 static inline C& 00152 get_temp () { 00153 static C temp= NULL; 00154 return temp; } 00155 00156 static inline nat 00157 primitive_root (nat t, nat i) { 00158 ASSERT (t != 0, "unexpected zero root order"); 00159 i = i % t; 00160 nat m2 = dyn_modulus (); 00161 return i * (m2 << 1) / t; } 00162 00163 static U* 00164 create_roots (nat t) { 00165 nat m2= dyn_modulus (); 00166 nat l= aligned_size<CC,V> (m2); 00167 get_temp ()= mmx_new<CC> (l); 00168 U* roots= mmx_new<U> (t); 00169 for (nat i = 0; i < t; i += 2) { 00170 roots[i] = primitive_root (t, bit_mirror (i, t)); 00171 roots[i+1]= primitive_root (t, i == 0 ? 0 : t - bit_mirror (i, t)); 00172 } 00173 return roots; } 00174 00175 static void 00176 destroy_roots (U* u, nat t) { 00177 C& temp= get_temp (); 00178 if (temp != NULL) { 00179 nat m2= dyn_modulus (); 00180 nat l= aligned_size<CC,V> (m2); 00181 mmx_delete<CC> (temp, l); 00182 temp= NULL; 00183 } 00184 mmx_delete<U> (u, t); } 00185 00186 static inline void 00187 fft_cross (C* c1, C* c2) { 00188 C temp= get_temp (); 00189 nat m2= dyn_modulus (); 00190 Vec::copy (temp, *c2, m2); 00191 Vec::sub (*c2, *c1, temp, m2); 00192 Vec::add (*c1, temp, m2); } 00193 00194 static inline void 00195 dfft_cross (C* c1, C* c2, const U* u) { 00196 C temp= get_temp (); 00197 nat m2= dyn_modulus (); 00198 negative_cyclic_shift (temp, *c2, m2, *u); 00199 Vec::sub (*c2, *c1, temp, m2); 00200 Vec::add (*c1, temp, m2); } 00201 00202 static inline void 00203 ifft_cross (C* c1, C* c2, const U* u) { 00204 C temp= get_temp (); 00205 nat m2= dyn_modulus (); 00206 Vec::copy (temp, *c1, m2); 00207 Vec::add (*c1, *c2, m2); 00208 Vec::sub (*c2, temp, m2); 00209 negative_cyclic_shift (*c2, m2, (*u) + m2); } 00210 00211 static inline void 00212 dtft_cross (C* c1, C* c2) { 00213 static CC h= invert (CC(2)); 00214 nat m2= dyn_modulus (); 00215 fft_cross (c1, c2); 00216 Vec::mul (*c1, h, m2); 00217 Vec::mul (*c2, h, m2); } 00218 00219 static inline void 00220 dtft_cross (C* c1, C* c2, const U* u) { 00221 static CC h= invert (CC(2)); 00222 nat m2= dyn_modulus (); 00223 dfft_cross (c1, c2, u); 00224 Vec::mul (*c1, h, m2); 00225 Vec::mul (*c2, h, m2); } 00226 00227 static inline void 00228 itft_flip (C* c1, C* c2, const U* u) { 00229 static CC h= invert (CC(2)); 00230 C temp= get_temp (); 00231 nat m2= dyn_modulus (); 00232 negative_cyclic_shift (temp, *c2, m2, *u); 00233 Vec::add (*c1, *c1, m2); 00234 Vec::sub (*c1, temp, m2); 00235 Vec::sub (*c2, *c1, temp, m2); 00236 Vec::mul (*c2, h, m2); } 00237 00238 static inline void 00239 itft_flip (C* c1, C* c2) { 00240 static CC h= invert (CC(2)); 00241 nat m2= dyn_modulus (); 00242 Vec::add (*c1, *c1, m2); 00243 Vec::sub (*c1, *c2, m2); 00244 Vec::sub (*c2, *c1, *c2, m2); 00245 Vec::mul (*c2, h, m2); } 00246 00247 struct fft_mul_sc_op : mul_op { 00248 static inline void 00249 set_op (C& x, const S& y) { 00250 Vec::mul (x, y, dyn_modulus ()); } 00251 }; 00252 }; 00253 00254 TMPL 00255 struct negative_cyclic_roots { 00256 typedef negative_cyclic_roots_helper<C> roots_type; 00257 }; 00258 00259 TMPL static void 00260 direct_transform (C** b, nat m2, nat t) { 00261 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2; 00262 fft_naive_transformer<C*, negative_cyclic_roots<C*> > ffter (t); 00263 ffter.direct_transform (b); } 00264 00265 TMPL static void 00266 inverse_transform (C** b, nat m2, nat t, bool divide=true) { 00267 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2; 00268 fft_naive_transformer<C*, negative_cyclic_roots<C*> > ffter (t); 00269 ffter.inverse_transform (b, divide); } 00270 00271 TMPL static void 00272 direct_transform_truncated (C** b, nat m2, nat len) { 00273 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2; 00274 fft_truncated_transformer<C*, 00275 fft_naive_transformer<C*, negative_cyclic_roots<C*> > > ffter (len); 00276 ffter.dtft (b, 1, len, 0); } 00277 00278 TMPL static void 00279 inverse_transform_truncated (C** b, nat m2, nat len) { 00280 nat i; 00281 for (i = len; i < next_power_of_two (len); i++) 00282 Vec::clear (b[i], m2); 00283 negative_cyclic_roots_helper<C*>::dyn_modulus ()= m2; 00284 fft_truncated_transformer<C*, 00285 fft_naive_transformer<C*, negative_cyclic_roots<C*> > > ffter (len); 00286 ffter.itft (b, 1, len, 0); 00287 C h= invert (C(next_power_of_two (len))); 00288 for (i = 0; i < len; i++) 00289 Vec::mul (b[i], h, m2); 00290 for (; i < next_power_of_two (len); i++) 00291 Vec::clear (b[i], m2); } 00292 00293 TMPL static void 00294 variable_dilate (C** b, nat eta, nat m2, nat t) { 00295 // substitute eta * y for y in b of size t 00296 nat m4= m2 << 1; 00297 nat w= 0; 00298 for (nat i = 0; i < t; i++) { 00299 negative_cyclic_shift (b[i], m2, w); 00300 w= (w + eta) % m4; } } 00301 00302 public: 00303 TMPL static inline nat 00304 mul_negative_cyclic (C* dest, const C* src, nat n, bool shift=true) { 00305 // dest *= src mod (x^n + 1) 00306 nat u= next_power_of_two (n); 00307 ASSERT (u != 0, "maximum size exceeded"); 00308 ASSERT (n == u, "power of two expected"); 00309 if (n <= max ((nat) 4, (nat) Threshold(C,Th))) { 00310 nat l= aligned_size<C,V> (2*n-1); 00311 C* temp= mmx_new<C> (l); 00312 Inner::mul (temp, dest, src, n, n); 00313 Vec::copy (dest, temp, n); 00314 Vec::sub (dest, temp + n, n - 1); 00315 mmx_delete<C> (temp, l); 00316 return 0; 00317 } 00318 else { 00319 nat r = 0; 00320 nat k = log_2 (n); 00321 nat m = (nat) 1 << (k >> 1); 00322 nat t = u / m; 00323 nat m2= m << 1; 00324 C** b1= bivariate_new<C> (m2, t); 00325 C** b2= bivariate_new<C> (m2, t); 00326 bivariate_encode (b1, src , n, m, t); 00327 bivariate_encode (b2, dest, n, m, t); 00328 nat eta= (t == m2) ? 1 : 2; 00329 variable_dilate (b1, eta, m2, t); 00330 variable_dilate (b2, eta, m2, t); 00331 direct_transform (b1, m2, t); 00332 direct_transform (b2, m2, t); 00333 for (nat i = 0; i < t; i++) 00334 r= mul_negative_cyclic (b1[i], b2[i], m2, shift); 00335 bivariate_delete<C> (b2, m2, t); 00336 inverse_transform (b1, m2, t, shift); 00337 variable_dilate (b1, (m2 << 1) - eta, m2, t); 00338 bivariate_decode (dest, (const C**) b1, n, m, t); 00339 bivariate_delete<C> (b1, m2, t); 00340 return r + ((k+1) >> 1); } } 00341 00342 TMPL static inline void 00343 mul_negative_cyclic_truncated (C* dest, const C* src, nat len) { 00344 // dest *= src mod (x^n + 1) 00345 // TFT is used in first round 00346 nat n= next_power_of_two (len); 00347 ASSERT (n != 0, "maximum size exceeded"); 00348 if (n <= max ((nat) 4, (nat) Threshold(C,Th))) { 00349 nat l= aligned_size<C,V> (2*n-1); 00350 C* temp= mmx_new<C> (l); 00351 Inner::mul (temp, dest, src, n, n); 00352 Vec::copy (dest, temp, n); 00353 Vec::sub (dest, temp + n, n - 1); 00354 mmx_delete<C> (temp, l); 00355 return; 00356 } 00357 else { 00358 nat k = log_2 (n); 00359 nat m = (nat) 1 << (k >> 1); 00360 nat t = n / m; 00361 nat r = (len + m - 1) / m; 00362 nat m2= m << 1; 00363 C** b1= bivariate_new<C> (m2, t); 00364 C** b2= bivariate_new<C> (m2, t); 00365 bivariate_encode (b1, src , n, m, t); 00366 bivariate_encode (b2, dest, n, m, t); 00367 nat eta= (t == m2) ? 1 : 2; 00368 variable_dilate (b1, eta, m2, r); 00369 variable_dilate (b2, eta, m2, r); 00370 direct_transform_truncated (b1, m2, r); 00371 direct_transform_truncated (b2, m2, r); 00372 for (nat i = 0; i < r; i++) 00373 mul_negative_cyclic (b1[i], b2[i], m2); 00374 bivariate_delete<C> (b2, m2, t); 00375 inverse_transform_truncated (b1, m2, r); 00376 variable_dilate (b1, (m2 << 1) - eta, m2, r); 00377 bivariate_decode (dest, (const C**) b1, n, m, t); 00378 bivariate_delete<C> (b1, m2, t); } } 00379 00380 TMPL static inline nat 00381 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) { 00382 nat ret = 0; 00383 nat len = n1 + n2 - 1; 00384 nat n = next_power_of_two (len); 00385 nat l = aligned_size<C,V> (n); 00386 C* t1 = mmx_new<C> (l); 00387 C* t2 = mmx_new<C> (l); 00388 Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, n - n1); 00389 Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, n - n2); 00390 if (shift) 00391 mul_negative_cyclic_truncated (t1, t2, len); 00392 else 00393 ret= mul_negative_cyclic (t1, t2, n, shift); 00394 Vec::copy (dest, t1, len); 00395 mmx_delete<C> (t1, l); 00396 mmx_delete<C> (t2, l); 00397 return ret; } 00398 00399 }; // implementation<polynomial_multiply,V,polynomial_schonhage_strassen_inc<W,Th> > 00400 00401 #undef TMPL 00402 } // namespace mmx 00403 #endif //__MMX__POLYNOMIAL_SCHONHAGE_STRASSEN__HPP