algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : matrix_crt.hpp 00004 * DESCRIPTION: Multi-modular multiplication of matrices 00005 * COPYRIGHT : (C) 2009 Joris van der Hoeven and Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX_MATRIX_CRT_HPP 00014 #define __MMX_MATRIX_CRT_HPP 00015 #include <numerix/modular.hpp> 00016 #include <algebramix/matrix.hpp> 00017 #include <algebramix/crt_blocks.hpp> 00018 namespace mmx { 00019 00020 /****************************************************************************** 00021 * Special variant to avoid colisions with other global moduli 00022 ******************************************************************************/ 00023 00024 template<typename C> 00025 struct modular_matrix_crt { 00026 template<typename M> 00027 class modulus_storage { 00028 static inline M& dyn_modulus () { 00029 static M modulus = M (); 00030 return modulus; } 00031 public: 00032 static inline void set_modulus (const M& p) { dyn_modulus () = p; } 00033 static inline M get_modulus () { return dyn_modulus (); } 00034 }; 00035 }; 00036 00037 /****************************************************************************** 00038 * Variant for matrices with CRT based operations 00039 ******************************************************************************/ 00040 00041 template<typename V> 00042 struct matrix_crt: public V { 00043 typedef typename V::Vec Vec; 00044 typedef typename V::Naive Naive; 00045 typedef typename V::Positive Positive; 00046 typedef matrix_crt<typename V::No_simd> No_simd; 00047 typedef matrix_crt<typename V::No_thread> No_thread; 00048 typedef matrix_crt<typename V::No_scaled> No_scaled; 00049 }; 00050 00051 template<typename F, typename V, typename W> 00052 struct implementation<F,V,matrix_crt<W> >: 00053 public implementation<F,V,W> {}; 00054 00055 /****************************************************************************** 00056 * Default non-archimedian size helper 00057 ******************************************************************************/ 00058 00059 template<typename C> 00060 struct matrix_crt_multiply_helper { 00061 static const nat dimension_threshold= 7; 00062 static const nat ratio_threshold= 100; // this is a percentage 00063 00064 typedef crt_naive_transformer<C> crt_transformer; 00065 typedef moduli_helper<C, 00066 modulus<typename crt_transformer::modulus_base, 00067 typename crt_transformer::modulus_base_variant> > moduli_sequence; 00068 00069 static nat size (const C* s1, nat s1_rs, nat s1_cs, 00070 const C* s2, nat s2_rs, nat s2_cs, 00071 nat r, nat l, nat c) { 00072 nat sz= 0; 00073 for (nat k= 0; k < l; k++) { 00074 nat sz1= 0, sz2= 0; 00075 const C* ss1= s1 + k * s1_cs; 00076 const C* ss2= s2 + k * s2_rs; 00077 for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1)); 00078 for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2)); 00079 sz= max (sz, sz1 + sz2); 00080 } 00081 return sz; } 00082 }; 00083 00084 /****************************************************************************** 00085 * Multi-modular multiplication of matrices 00086 ******************************************************************************/ 00087 00088 template<typename V, typename W> 00089 struct implementation<matrix_multiply,V,matrix_crt<W> >: 00090 public implementation<matrix_multiply_base,V> 00091 { 00092 typedef implementation<matrix_multiply,W> Mat; 00093 00094 template<typename Op, typename D, typename S1, typename S2> 00095 static inline void 00096 mul (D* d, const S1* s1, const S2* s2, 00097 nat r, nat rr, nat l, nat ll, nat c, nat cc) { 00098 Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); } 00099 00100 template<typename D, typename S1, typename S2> 00101 static inline void 00102 mul (D* d, const S1* s1, const S2* s2, 00103 nat r, nat l, nat c) { 00104 Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); } 00105 00106 template<typename C, typename I, typename MV, typename Crter> 00107 static void 00108 mat_direct_crt (matrix<I,MV>* dest, const C* s, 00109 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) { 00110 nat n= N(crter); 00111 for (nat k= 0; k < n; k++) 00112 dest[k]= matrix<I,MV> (I (), r, c); 00113 I* aux= mmx_new<I> (n); 00114 for (nat i= 0; i < r; i++) 00115 for (nat j= 0; j < c; j++) { 00116 direct_crt (aux, s[i * s_rs + j * s_cs], crter); 00117 for (nat k= 0; k < n ; k++) dest[k](i,j)= aux[k]; 00118 } 00119 mmx_delete<I> (aux, n); } 00120 00121 template<typename C, typename Modulus, typename MW, typename MV, typename Crter> 00122 static void 00123 mat_direct_crt (matrix<modular<Modulus,MW>,MV>* dest, const C* s, 00124 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) { 00125 typedef modular<Modulus,MW> Modular; 00126 typedef typename Modular::modulus::base I; 00127 nat n= N(crter); 00128 for (nat k= 0; k < n; k++) 00129 dest[k]= matrix<Modular,MV> (Modular (), r, c); 00130 I* aux= mmx_new<I> (n); 00131 for (nat i= 0; i < r; i++) 00132 for (nat j= 0; j < c; j++) { 00133 direct_crt (aux, s[i * s_rs + j * s_cs], crter); 00134 for (nat k= 0; k < n ; k++) dest[k](i,j)= Modular (aux[k], true); 00135 } 00136 mmx_delete<I> (aux, n); } 00137 00138 template<typename C, typename I, typename MV, typename Crter> 00139 static void 00140 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c, 00141 const matrix<I,MV>* s, Crter& crter) { 00142 nat n= N(crter); 00143 I* aux= mmx_new<I> (n); 00144 for (nat i= 0; i < r; i++) 00145 for (nat j= 0; j < c; j++) { 00146 for (nat k= 0; k < n ; k++) aux[k]= s[k](i,j); 00147 inverse_crt (d[i * d_rs + j * d_cs], aux, crter); 00148 } 00149 mmx_delete<I> (aux, n); } 00150 00151 template<typename C, typename Modulus, typename MW, typename MV, typename Crter> 00152 static void 00153 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c, 00154 const matrix<modular<Modulus,MW>,MV>* s, Crter& crter) { 00155 typedef modular<Modulus,MW> Modular; 00156 typedef typename Modular::modulus::base I; 00157 nat n= N(crter); 00158 I* aux= mmx_new<I> (n); 00159 for (nat i= 0; i < r; i++) 00160 for (nat j= 0; j < c; j++) { 00161 for (nat k= 0; k < n ; k++) aux[k]= * s[k](i,j); 00162 inverse_crt (d[i * d_rs + j * d_cs], aux, crter); 00163 } 00164 mmx_delete<I> (aux, n); } 00165 00166 template<typename C, typename Crter> static void 00167 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c, 00168 Crter& crter) { 00169 typedef typename Crter::modulus_base I; 00170 typedef modulus<I,typename Crter::modulus_base_variant> Modulus; 00171 typedef modular<Modulus,modular_matrix_crt<C> > Modular; 00172 typedef matrix<Modular> Matrix_modular; 00173 nat n= N(crter); 00174 Matrix_modular* mm1= mmx_new<Matrix_modular> (n); 00175 Matrix_modular* mm2= mmx_new<Matrix_modular> (n); 00176 Matrix_modular* mmd= mmx_new<Matrix_modular> (n); 00177 mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l), 00178 r, l, crter); 00179 mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c), 00180 l, c, crter); 00181 for (nat k= 0; k < n; k++) { 00182 Modular::set_modulus (crter[k]); 00183 mmd[k]= mm1[k] * mm2[k]; 00184 } 00185 mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c), 00186 r, c, mmd, crter); 00187 mmx_delete<Matrix_modular> (mm1, n); 00188 mmx_delete<Matrix_modular> (mm2, n); 00189 mmx_delete<Matrix_modular> (mmd, n); } 00190 00191 template<typename C, typename S, typename CV> static void 00192 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c, 00193 crt_naive_transformer<C,S,CV>& crter) { 00194 typedef implementation<crt_project,CV> Crt; 00195 typedef crt_naive_transformer<C,S,CV> Crter; 00196 typedef typename Crter::modulus_base I; 00197 typedef modulus<I,typename Crter::modulus_base_variant> Modulus; 00198 typedef modular<Modulus,modular_matrix_crt<C> > Modular; 00199 typedef typename Matrix_variant(Modular) MV; 00200 typedef implementation<matrix_multiply,MV> Mat_mod; 00201 00202 nat spc1= aligned_size<Modular,V> (r * l); 00203 nat spc2= aligned_size<Modular,V> (l * c); 00204 nat spcd= aligned_size<Modular,V> (r * c); 00205 nat spc= spc1 + spc2 + spcd; 00206 Modular* x1= mmx_new<Modular> (spc); 00207 Modular* x2= x1 + spc1; 00208 Modular* xd= x2 + spc2; 00209 00210 if (N(crter) == 1) { 00211 Modulus p (crter[0]); Modular::set_modulus (p); 00212 for (nat i= 0; i < r * l; i++) 00213 x1[i]= Modular (Crt::encode (s1[i], p), true); 00214 for (nat i= 0; i < l * c; i++) 00215 x2[i]= Modular (Crt::encode (s2[i], p), true); 00216 Mat_mod::mul (xd, x1, x2, r, l, c); 00217 for (nat i= 0; i < r * c; i++) 00218 d[i]= Crt::decode (C(* xd[i]), crter.P, crter.H); 00219 } 00220 else { 00221 for (nat k= 0; k < N(crter); k++) { 00222 Modulus p (crter[k]); Modular::set_modulus (p); 00223 for (nat i= 0; i < r * l; i++) 00224 x1[i]= Modular (Crt::mod (Crt::encode (s1[i], crter.P), p)); 00225 for (nat i= 0; i < l * c; i++) 00226 x2[i]= Modular (Crt::mod (Crt::encode (s2[i], crter.P), p)); 00227 Mat_mod::mul (xd, x1, x2, r, l, c); 00228 I m (crter.m[k]), t; C q (crter.q[k]); 00229 for (nat i= 0; i < r * c; i++) { 00230 mul_mod (t, m, * xd[i], p); 00231 if (k == 0) d[i]= t * q; else mul_add (d[i], t, q); 00232 } 00233 } 00234 for (nat i= 0; i < r * c; i++) 00235 d[i]= Crt::decode (Crt::mod (d[i], crter.P), crter.P, crter.H); 00236 } 00237 mmx_delete<Modular> (x1, spc); } 00238 00239 template<typename C, typename Low, typename High, nat s, typename CV> 00240 static void 00241 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c, 00242 crt_blocks_transformer<Low,High,s,CV>& crter) { 00243 typedef matrix<C,matrix_crt<W> > Matrix; 00244 nat n= crter.high -> size (); 00245 if (n == 1) { 00246 mul (d, s1, s2, r, l, c, * crter.low[0]); 00247 return; 00248 } 00249 Matrix* mm1= mmx_new<Matrix> (n); 00250 Matrix* mm2= mmx_new<Matrix> (n); 00251 Matrix* mmd= mmx_new<Matrix> (n); 00252 mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l), 00253 r, l, * crter.high); 00254 mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c), 00255 l, c, * crter.high); 00256 for (nat k= 0; k < n; k++) 00257 mul (tab(mmd[k]), tab(mm1[k]), tab(mm2[k]), r, l, c, * crter.low[k]); 00258 mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c), 00259 r, c, mmd, * crter.high); 00260 mmx_delete<Matrix> (mm1, n); 00261 mmx_delete<Matrix> (mm2, n); 00262 mmx_delete<Matrix> (mmd, n); } 00263 00264 template<typename C> static void 00265 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) { 00266 typedef matrix_crt_multiply_helper<C> Matrix_crt; 00267 typedef typename Matrix_crt::crt_transformer Crter; 00268 typedef typename Matrix_crt::moduli_sequence Sequence; 00269 typedef typename Crter::modulus_base I; 00270 typedef modulus<I,typename Crter::modulus_base_variant> Modulus; 00271 static const nat dim_thr= Matrix_crt::dimension_threshold; 00272 static const nat ratio_thr= Matrix_crt::ratio_threshold;; 00273 00274 if (r <= dim_thr || l <= dim_thr || c <= dim_thr) { 00275 Mat::mul (d, s1, s2, r, l, c); 00276 return; 00277 } 00278 nat sz= matrix_crt_multiply_helper<C> 00279 ::size (s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l), 00280 s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c), 00281 r, l, c); 00282 nat wd= (ratio_thr * sz) / 100; 00283 00284 if (r < wd || l < wd || c < wd) { 00285 Mat::mul (d, s1, s2, r, l, c); 00286 return; 00287 } 00288 vector<Modulus> mods; 00289 if (! Sequence::covering (mods, sz)) 00290 // Not enough moduli can be generated 00291 Mat::mul (d, s1, s2, r, l, c); 00292 else { 00293 Crter crter (mods, false); 00294 mul (d, s1, s2, r, l, c, crter); } } 00295 00296 }; // implementation<matrix_multiply,V,matrix_crt<W> > 00297 00298 } // namespace mmx 00299 #endif // __MMX_MATRIX_CRT_HPP