algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : polynomial_schonhage_triadic.hpp 00004 * DESCRIPTION: Schonhage's triadic cyclic wrapped product 00005 * COPYRIGHT : (C) 2008 Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 // Ref. "Modern Computer Algebra", Chapter 8, Exercise 8.30. 00014 00015 #ifndef __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP 00016 #define __MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP 00017 #include <algebramix/fft_triadic_naive.hpp> 00018 #include <algebramix/polynomial_dicho.hpp> 00019 #include <algebramix/polynomial_balanced.hpp> 00020 00021 namespace mmx { 00022 #define TMPL template<typename C> 00023 00024 /****************************************************************************** 00025 * Variant for triadic Schonhage-Strassen multiplication 00026 ******************************************************************************/ 00027 00028 struct schonhage_triadic_threshold {}; 00029 00030 template<typename V, typename Th= schonhage_triadic_threshold> 00031 struct polynomial_schonhage_triadic_inc: public V { 00032 typedef typename V::Vec Vec; 00033 typedef typename V::Naive Naive; 00034 typedef typename V::Positive Positive; 00035 typedef polynomial_schonhage_triadic_inc<typename V::No_simd,Th> No_simd; 00036 typedef polynomial_schonhage_triadic_inc<typename V::No_thread,Th> No_thread; 00037 typedef polynomial_schonhage_triadic_inc<typename V::No_scaled,Th> No_scaled; 00038 }; 00039 00040 template<typename F, typename V, typename W, typename Th> 00041 struct implementation<F,V,polynomial_schonhage_triadic_inc<W,Th> >: 00042 public implementation<F,V,W> {}; 00043 00044 DEFINE_VARIANT_1 (typename V, V, 00045 polynomial_schonhage_triadic, 00046 polynomial_balanced< 00047 polynomial_schonhage_triadic_inc< 00048 polynomial_karatsuba<V> > >) 00049 00050 /****************************************************************************** 00051 * Triadic Schonhage-Strassen multiplication 00052 ******************************************************************************/ 00053 00054 template<typename V, typename W, typename Th> 00055 struct implementation<polynomial_multiply,V, 00056 polynomial_schonhage_triadic_inc<W,Th> >: 00057 public implementation<polynomial_linear,V> 00058 { 00059 typedef implementation<vector_linear,V> Vec; 00060 typedef implementation<polynomial_linear,V> Pol; 00061 typedef implementation<polynomial_multiply,W> Inner; 00062 00063 private: 00064 TMPL static inline C** 00065 bivariate_new (nat m, nat t) { 00066 nat l = aligned_size<C,V> (m); 00067 C** dest= mmx_new<C*> (t); 00068 for (nat j = 0; j < t; j++) 00069 dest[j]= mmx_new<C> (l); 00070 return dest; } 00071 00072 TMPL static inline void 00073 bivariate_delete (C** dest, nat m, nat t) { 00074 nat l = aligned_size<C,V> (m); 00075 for (nat j = 0; j < t; j++) 00076 mmx_delete<C> (dest[j], l); 00077 mmx_delete<C*> (dest, t); } 00078 00079 TMPL static inline void 00080 bivariate_encode (C** dest, const C*s, nat n, nat m, nat t) { 00081 // dest (x, x^m) = s (x) 00082 // s has size 2n, and dest has size 2t in x^m and 2m in x. 00083 // We assume that n <= m * t 00084 nat i, j, n2= n << 1, m2= m << 1, t2= t << 1; 00085 for (i = 0, j = 0; i < n2; i += m, j++) 00086 if (n2-i >= m) { 00087 Vec::copy (dest[j], s + i, m); 00088 Vec::clear (dest[j] + m, m); 00089 } 00090 else { 00091 Vec::copy (dest[j], s + i, n2 - i); 00092 Vec::clear (dest[j] + n2 - i, m2 - (n2 - i)); 00093 } 00094 for (; j < t2; j++) 00095 Vec::clear (dest[j], m2); } 00096 00097 TMPL static inline void 00098 bivariate_decode (C* dest, const C** s, nat n, nat m, nat t) { 00099 // dest (x) = s (x, x^m) mod (x^(2n) + x^n + 1) 00100 // dest has size 2n, s has size 2t. 00101 // We assume that n = m * t 00102 nat i, n2= n << 1, m2= m << 1, t2= t << 1; 00103 Vec::clear (dest, n2); 00104 for (i = 0; i+1 < t2; i++) 00105 Pol::add (dest + i * m, s[i], m2); 00106 Pol::add (dest + i * m, s[i] , m); 00107 Pol::sub (dest , s[i] + m, m); 00108 Pol::sub (dest + n , s[i] + m, m); } 00109 00110 TMPL static inline void 00111 triadic_shift (C* dest, const C* src, nat i, nat m) { 00112 // multiply by x^i modulo x^(2*m) + x^m + 1 00113 nat m2= m << 1, m3= m2 + m; 00114 i= i % m3; 00115 if (i > m2) { 00116 Vec::copy (dest , src + m3 - i, i - m); 00117 Vec::neg (dest + i - m , src , m3 - i); 00118 Vec::sub (dest + i - m2, src , m3 - i); 00119 } 00120 else if (i > m) { 00121 Vec::neg (dest , src + m2 - i, m); 00122 Vec::neg (dest + m, src + m2 - i, m); 00123 Vec::add (dest , src + m3 - i, i - m); 00124 Vec::add (dest + i, src , m2 - i); 00125 } 00126 else { 00127 Vec::neg (dest , src + m2 - i, i); 00128 Vec::copy (dest + i, src , m2 - i); 00129 Vec::sub (dest + m, src + m2 - i, i); } } 00130 00131 template<typename C> 00132 struct unptr_helper {}; 00133 00134 template<typename C> 00135 struct unptr_helper<C*> { 00136 typedef C type; }; 00137 00138 template<typename Cp> 00139 struct triadic_roots_helper { 00140 // modulo x^(2*m) + x^m + 1 00141 typedef Cp C; 00142 typedef nat U; 00143 typedef typename unptr_helper<Cp>::type CC; 00144 typedef CC S; 00145 00146 static inline nat& 00147 dyn_m () { 00148 static nat m= 0; 00149 return m; } 00150 00151 static inline C& 00152 dyn_temp (nat i) { 00153 static C temp0= NULL; 00154 static C temp1= NULL; 00155 static C temp2= NULL; 00156 static C temp3= NULL; 00157 switch (i) { 00158 case 0: return temp0; 00159 case 1: return temp1; 00160 case 2: return temp2; 00161 case 3: return temp3; 00162 default: ERROR ("index out of range"); return temp0; } } 00163 00164 static inline nat 00165 primitive_root (nat t, nat i) { 00166 if (t == 0) return 1; 00167 i = i % t; 00168 nat m3 = 3 * dyn_m (); 00169 ASSERT (m3 % t == 0, "primitive root out of range"); 00170 return i * (m3 / t); } 00171 00172 static U* 00173 create_roots (nat t) { 00174 nat m2= dyn_m () << 1; 00175 nat l= aligned_size<CC,V> (m2); 00176 for (nat i = 0; i <= 3; i++) 00177 dyn_temp (i)= mmx_new<CC> (l); 00178 U* roots= mmx_new<U> (t); 00179 for (nat i = 0; i < t; i++) 00180 roots[i] = primitive_root (t, digit_mirror_triadic (i, t)); 00181 return roots; } 00182 00183 static U* 00184 create_stoor (nat t) { 00185 U* stoor= mmx_new<U> (t); 00186 for (nat i = 0; i < t; i++) 00187 stoor[i]= primitive_root (t, i == 0 ? 00188 0 : t - digit_mirror_triadic (i, t)); 00189 return stoor; } 00190 00191 static void 00192 destroy_roots (U* u, nat t) { 00193 nat m2= dyn_m () << 1; 00194 nat l= aligned_size<CC,V> (m2); 00195 for (nat i = 0; i <= 3; i++) { 00196 C& temp= dyn_temp (i); 00197 if (temp != NULL) { 00198 mmx_delete<CC> (temp, l); 00199 temp= NULL; 00200 } 00201 } 00202 mmx_delete<U> (u, t); } 00203 00204 static inline void 00205 dfft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) { 00206 C temp0= dyn_temp (0); 00207 C temp1= dyn_temp (1); 00208 C temp2= dyn_temp (2); 00209 C temp3= dyn_temp (3); 00210 nat m= dyn_m (), m2= m << 1; 00211 00212 triadic_shift (temp0, *c3, *u3, m); 00213 Vec::add (temp0, *c2, m2); 00214 triadic_shift (temp3, temp0, *u3, m); 00215 00216 triadic_shift (temp0, *c3, *u2, m); 00217 Vec::add (temp0, *c2, m2); 00218 triadic_shift (temp2, temp0, *u2, m); 00219 00220 triadic_shift (temp0, *c3, *u1, m); 00221 Vec::add (temp0, *c2, m2); 00222 triadic_shift (temp1, temp0, *u1, m); 00223 00224 Vec::add (*c3, *c1, temp3, m2); 00225 Vec::add (*c2, *c1, temp2, m2); 00226 Vec::add (*c1, temp1, m2); } 00227 00228 static inline void 00229 ifft_cross (C* c1, C* c2, C* c3, const U* u1, const U* u2, const U* u3) { 00230 C temp0= dyn_temp (0); 00231 C temp1= dyn_temp (1); 00232 C temp2= dyn_temp (2); 00233 C temp3= dyn_temp (3); 00234 nat m= dyn_m (), m2= m << 1; 00235 00236 triadic_shift (temp1, *c1, *u1, m); 00237 Vec::add (*c1, *c2, m2); 00238 Vec::add (*c1, *c3, m2); 00239 00240 triadic_shift (temp2, *c2, *u2, m); 00241 triadic_shift (temp3, *c3, *u3, m); 00242 triadic_shift (temp0, *c2, *u3, m); 00243 Vec::add (*c2, temp1, temp2, m2); 00244 Vec::add (*c2, temp3, m2); 00245 00246 triadic_shift (temp3, *c3, *u2, m); 00247 Vec::add (temp0, temp3, m2); 00248 Vec::add (temp0, temp1, m2); 00249 triadic_shift (*c3, temp0, *u1, m); } 00250 00251 static inline void 00252 fft_shift (C* dest, S v, nat t) { 00253 nat m2= dyn_m () << 1; 00254 for (nat i = 0; i < t; i++) 00255 Vec::mul (dest[i], v, m2); } 00256 }; 00257 00258 struct triadic_roots { 00259 template<typename C> 00260 struct helper { 00261 typedef triadic_roots_helper<C> roots_type; }; 00262 }; 00263 00264 TMPL static void 00265 direct_transform (C** b, nat m, nat t) { 00266 triadic_roots_helper<C*>::dyn_m ()= m; 00267 fft_triadic_naive_transformer<C*, triadic_roots> ffter (t); 00268 ffter.direct_transform_triadic (b); } 00269 00270 TMPL static void 00271 inverse_transform (C** b, nat m, nat t, bool shift=true) { 00272 triadic_roots_helper<C*>::dyn_m ()= m; 00273 fft_triadic_naive_transformer<C*, triadic_roots> ffter (t); 00274 ffter.inverse_transform_triadic (b,shift); } 00275 00276 TMPL static void 00277 variable_dilate (C** b, nat eta, nat m, nat t) { 00278 // substitute eta * y for y in b of size t 00279 nat m2= m << 1, m3= m2 + m; 00280 nat l= aligned_size<C,V> (m2); 00281 C* temp= mmx_new<C> (l); 00282 nat w= 0; 00283 for (nat i = 0; i < t; i++) { 00284 Vec::copy (temp, b[i], m2); 00285 triadic_shift (b[i], temp, w, m); 00286 w= (w + eta) % m3; 00287 } 00288 mmx_delete<C> (temp, l); } 00289 00290 TMPL static void 00291 bivariate_mod (C** dest, const C** h, nat w, nat m, nat t) { 00292 // dest = h mod (y^t - w) 00293 nat m2= m << 1, m3= m2 + m; 00294 w= w % m3; 00295 for (nat i = 0; i < t; i++) { 00296 triadic_shift (dest[i], h[i+t], w, m); 00297 Vec::add (dest[i], h[i], m2); } } 00298 00299 TMPL static void 00300 bivariate_crt (C** h, const C** h1, const C** h2, nat w, 00301 nat m, nat t, bool shift=true) { 00302 // h1 = h mod (y^t - w), h2 = h mod (y^t - w^2). 00303 nat m2= m << 1, m3= m2 + m, t2= t << 1; 00304 w= w % m3; 00305 nat w2= (w << 1) % m3; 00306 nat l= aligned_size<C,V> (m2); 00307 C* temp= mmx_new<C> (l); 00308 for (nat i = 0; i < t; i++) 00309 Vec::sub (h[i+t], h2[i], h1[i], m2); 00310 for (nat i = 0; i < t; i++) { 00311 triadic_shift (h[i], h1[i], w2, m); 00312 triadic_shift (temp, h2[i], w , m); 00313 Vec::sub (h[i], temp, m2); 00314 } 00315 for (nat i = 0; i < t2; i++) { 00316 triadic_shift (temp, h[i], w, m); 00317 Vec::add (h[i], temp, m2); 00318 Vec::add (h[i], temp, m2); 00319 if (shift) 00320 Vec::mul (h[i], invert (C(3)), m2); 00321 } 00322 mmx_delete<C> (temp, l); } 00323 00324 public: 00325 TMPL static inline nat 00326 mul_triadic (C* dest, const C* src, nat n, bool shift=true) { 00327 // dest *= src mod (x^(2*n) + x^n + 1) 00328 nat n2= n << 1; 00329 nat u = next_power_of_three (n); 00330 ASSERT (u != 0, "maximum size exceeded"); 00331 ASSERT (n == u, "power of three expected"); 00332 if (n <= max ((nat) 9, (nat) Threshold(C,Th))) { 00333 nat l= aligned_size<C,V> (2*n2-1); 00334 C* temp= mmx_new<C> (l); 00335 Inner::mul (temp, dest, src, n2, n2); 00336 Vec::copy (dest , temp , n2); 00337 Vec::add (dest , temp + n2 + n, n - 1); 00338 Vec::sub (dest , temp + n2 , n); 00339 Vec::sub (dest + n, temp + n2 , n); 00340 mmx_delete<C> (temp, l); 00341 return 0; 00342 } 00343 else { 00344 nat r = 0; 00345 nat k = log_3 (n); 00346 nat m = binpow (3, (k+1) >> 1), m2= m << 1, m3= m2 + m; 00347 nat t = n / m, t2= t << 1; 00348 nat eta= (t == m) ? 1 : 3; 00349 // mmout << "n = " << n << ", m = " << m << ", t = " << t << "\n"; 00350 C** hh= bivariate_new<C> (m2, t2); 00351 C** gg= bivariate_new<C> (m2, t2); 00352 C** h1= bivariate_new<C> (m2, t); 00353 C** h2= bivariate_new<C> (m2, t); 00354 C** g = bivariate_new<C> (m2, t); 00355 bivariate_encode (hh, src , n, m, t); 00356 bivariate_encode (gg, dest, n, m, t); 00357 for (nat j = 1; j <= 2; j++) { 00358 C** h= j == 1 ? h1 : h2; 00359 bivariate_mod (h, (const C**) hh, (j*eta*t) % m3, m, t); 00360 bivariate_mod (g, (const C**) gg, (j*eta*t) % m3, m, t); 00361 variable_dilate (h, (j*eta) % m3, m, t); 00362 variable_dilate (g, (j*eta) % m3, m, t); 00363 direct_transform (h, m, t); 00364 direct_transform (g, m, t); 00365 for (nat i = 0; i < t; i++) 00366 r = mul_triadic (h[i], g[i], m, shift); 00367 inverse_transform (h, m, t, shift); 00368 variable_dilate (h, m3 - ((j*eta) % m3), m, t); 00369 } 00370 bivariate_delete<C> (g , m2 , t); 00371 bivariate_delete<C> (gg, m2, t2); 00372 bivariate_crt (hh, (const C**) h1, (const C**) h2, (eta*t) % m3, m, t, shift); 00373 bivariate_delete<C> (h1, m2, t); 00374 bivariate_delete<C> (h2, m2, t); 00375 bivariate_decode (dest, (const C**) hh, n, m, t); 00376 bivariate_delete<C> (hh, m2, t2); 00377 return (shift) ? 0 : r + 1 + (k >> 1); }} 00378 00379 TMPL static inline nat 00380 mul (C* dest, const C* s1, const C* s2, nat n1, nat n2, bool shift=true) { 00381 nat len = n1 + n2 - 1; 00382 nat n = next_power_of_three ((len + 1) >> 1); 00383 nat nn= n << 1; 00384 nat l = aligned_size<C,V> (nn); 00385 C* t1 = mmx_new<C> (l); 00386 C* t2 = mmx_new<C> (l); 00387 Vec::copy (t1, s1, n1); Vec::clear (t1 + n1, nn - n1); 00388 Vec::copy (t2, s2, n2); Vec::clear (t2 + n2, nn - n2); 00389 nat k = mul_triadic (t1, t2, n, shift); 00390 Vec::copy (dest, t1, len); 00391 mmx_delete<C> (t1, l); 00392 mmx_delete<C> (t2, l); 00393 return k; } 00394 00395 }; // implementation<polynomial_multiply,V,polynomial_schonhage_triadic_inc<W,Th> > 00396 00397 #undef TMPL 00398 } // namespace mmx 00399 #endif //__MMX__POLYNOMIAL_SCHONHAGE_TRIADIC__HPP