algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/matrix_crt.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : matrix_crt.hpp
00004 * DESCRIPTION: Multi-modular multiplication of matrices
00005 * COPYRIGHT  : (C) 2009  Joris van der Hoeven and Gregoire Lecerf
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX_MATRIX_CRT_HPP
00014 #define __MMX_MATRIX_CRT_HPP
00015 #include <numerix/modular.hpp>
00016 #include <algebramix/matrix.hpp>
00017 #include <algebramix/crt_blocks.hpp>
00018 namespace mmx {
00019 
00020 /******************************************************************************
00021 * Special variant to avoid collisions with other global moduli
00022 ******************************************************************************/
00023 
// Modulus-storage policy reserved for the CRT matrix product.
//
// modular_matrix_crt<C>::modulus_storage<M> keeps exactly one value of type M
// per (C, M) pair in a function-local static, so setting it does not touch a
// modulus held by any other storage class.  The slot starts out as M ().
template<typename C>
struct modular_matrix_crt {
  template<typename M>
  class modulus_storage {
    // Accessor for the unique slot; the static is built on the first call.
    static inline M& slot () {
      static M current = M ();
      return current;
    }

  public:
    // Return a copy of the modulus currently stored.
    static inline M get_modulus () { return slot (); }

    // Overwrite the stored modulus with p.
    static inline void set_modulus (const M& p) { slot () = p; }
  };
};
00036 
00037 /******************************************************************************
00038 * Variant for matrices with CRT based operations
00039 ******************************************************************************/
00040 
// Matrix variant tag that adds CRT based operations on top of a variant V.
// It derives from V and re-exports the associated variants V provides.
template<typename V>
struct matrix_crt: public V {
  // Variants derived from V by dropping one feature are wrapped again in
  // matrix_crt, so they remain CRT variants.
  typedef matrix_crt<typename V::No_scaled> No_scaled;
  typedef matrix_crt<typename V::No_thread> No_thread;
  typedef matrix_crt<typename V::No_simd> No_simd;

  // These associated variants are taken from V unchanged.
  typedef typename V::Positive Positive;
  typedef typename V::Naive Naive;
  typedef typename V::Vec Vec;
};
00050 
00051 template<typename F, typename V, typename W>
00052 struct implementation<F,V,matrix_crt<W> >:
00053   public implementation<F,V,W> {};
00054 
00055 /******************************************************************************
00056 * Default non-archimedean size helper
00057 ******************************************************************************/
00058 
00059 template<typename C>
00060 struct matrix_crt_multiply_helper {
00061   static const nat dimension_threshold= 7;
00062   static const nat ratio_threshold= 100; // this is a percentage
00063 
00064   typedef crt_naive_transformer<C> crt_transformer;
00065   typedef moduli_helper<C,
00066         modulus<typename crt_transformer::modulus_base,
00067           typename crt_transformer::modulus_base_variant> > moduli_sequence;
00068 
00069   static nat size (const C* s1, nat s1_rs, nat s1_cs,
00070             const C* s2, nat s2_rs, nat s2_cs,
00071             nat r, nat l, nat c) {
00072     nat sz= 0;
00073     for (nat k= 0; k < l; k++) {
00074       nat sz1= 0, sz2= 0;
00075       const C* ss1= s1 + k * s1_cs;
00076       const C* ss2= s2 + k * s2_rs;
00077       for (nat i= 0; i < r; i++, ss1 += s1_rs) sz1= max (sz1, N (*ss1));
00078       for (nat j= 0; j < c; j++, ss2 += s2_cs) sz2= max (sz2, N (*ss2));
00079       sz= max (sz, sz1 + sz2);
00080     }
00081     return sz; }
00082 };
00083 
00084 /******************************************************************************
00085 * Multi-modular multiplication of matrices
00086 ******************************************************************************/
00087 
00088 template<typename V, typename W>
00089 struct implementation<matrix_multiply,V,matrix_crt<W> >:
00090   public implementation<matrix_multiply_base,V>
00091 {
00092   typedef implementation<matrix_multiply,W> Mat;
00093 
00094 template<typename Op, typename D, typename S1, typename S2>
00095 static inline void
00096 mul (D* d, const S1* s1, const S2* s2,
00097      nat r, nat rr, nat l, nat ll, nat c, nat cc) {
00098   Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); }
00099 
00100 template<typename D, typename S1, typename S2>
00101 static inline void
00102 mul (D* d, const S1* s1, const S2* s2,
00103      nat r, nat l, nat c) {
00104   Mat::template mul<mul_op> (d, s1, s2, r, r, l, l, c, c); }
00105 
00106 template<typename C, typename I, typename MV, typename Crter>
00107 static void
00108 mat_direct_crt (matrix<I,MV>* dest, const C* s,
00109                 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) {
00110   nat n= N(crter);
00111   for (nat k= 0; k < n; k++)
00112     dest[k]= matrix<I,MV> (I (), r, c);
00113   I* aux= mmx_new<I> (n);
00114   for (nat i= 0; i < r; i++)
00115     for (nat j= 0; j < c; j++) {
00116       direct_crt (aux, s[i * s_rs + j * s_cs], crter);
00117       for (nat k= 0; k < n ; k++) dest[k](i,j)= aux[k];
00118     }
00119   mmx_delete<I> (aux, n); }
00120 
00121 template<typename C, typename Modulus, typename MW, typename MV, typename Crter>
00122 static void
00123 mat_direct_crt (matrix<modular<Modulus,MW>,MV>* dest, const C* s,
00124                 nat s_rs, nat s_cs, nat r, nat c, Crter& crter) {
00125   typedef modular<Modulus,MW> Modular;
00126   typedef typename Modular::modulus::base I;
00127   nat n= N(crter);
00128   for (nat k= 0; k < n; k++)
00129     dest[k]= matrix<Modular,MV> (Modular (), r, c);
00130   I* aux= mmx_new<I> (n);
00131   for (nat i= 0; i < r; i++)
00132     for (nat j= 0; j < c; j++) {
00133       direct_crt (aux, s[i * s_rs + j * s_cs], crter);
00134       for (nat k= 0; k < n ; k++) dest[k](i,j)= Modular (aux[k], true);
00135     }
00136   mmx_delete<I> (aux, n); }
00137 
00138 template<typename C, typename I, typename MV, typename Crter>
00139 static void
00140 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c,
00141                  const matrix<I,MV>* s, Crter& crter) {
00142   nat n= N(crter);
00143   I* aux= mmx_new<I> (n);
00144   for (nat i= 0; i < r; i++)
00145     for (nat j= 0; j < c; j++) {
00146       for (nat k= 0; k < n ; k++) aux[k]= s[k](i,j);
00147       inverse_crt (d[i * d_rs + j * d_cs], aux, crter);
00148     }
00149   mmx_delete<I> (aux, n); }
00150 
00151 template<typename C, typename Modulus, typename MW, typename MV, typename Crter>
00152 static void
00153 mat_inverse_crt (C* d, nat d_rs, nat d_cs, nat r, nat c,
00154                  const matrix<modular<Modulus,MW>,MV>* s, Crter& crter) {
00155   typedef modular<Modulus,MW> Modular;
00156   typedef typename Modular::modulus::base I;
00157   nat n= N(crter);
00158   I* aux= mmx_new<I> (n);
00159   for (nat i= 0; i < r; i++)
00160     for (nat j= 0; j < c; j++) {
00161       for (nat k= 0; k < n ; k++) aux[k]= * s[k](i,j);
00162       inverse_crt (d[i * d_rs + j * d_cs], aux, crter);
00163     }
00164   mmx_delete<I> (aux, n); }
00165 
00166 template<typename C, typename Crter> static void
00167 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00168      Crter& crter) {
00169   typedef typename Crter::modulus_base I;
00170   typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00171   typedef modular<Modulus,modular_matrix_crt<C> > Modular;
00172   typedef matrix<Modular> Matrix_modular;
00173   nat n= N(crter);
00174   Matrix_modular* mm1= mmx_new<Matrix_modular> (n);
00175   Matrix_modular* mm2= mmx_new<Matrix_modular> (n);
00176   Matrix_modular* mmd= mmx_new<Matrix_modular> (n);
00177   mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00178                   r, l, crter);
00179   mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00180                   l, c, crter);
00181   for (nat k= 0; k < n; k++) {
00182     Modular::set_modulus (crter[k]);
00183     mmd[k]= mm1[k] * mm2[k];
00184   }
00185   mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00186                    r, c, mmd, crter);
00187   mmx_delete<Matrix_modular> (mm1, n);
00188   mmx_delete<Matrix_modular> (mm2, n);
00189   mmx_delete<Matrix_modular> (mmd, n); }
00190 
00191 template<typename C, typename S, typename CV> static void
00192 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00193      crt_naive_transformer<C,S,CV>& crter) {
00194   typedef implementation<crt_project,CV> Crt;
00195   typedef crt_naive_transformer<C,S,CV> Crter;
00196   typedef typename Crter::modulus_base I;
00197   typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00198   typedef modular<Modulus,modular_matrix_crt<C> > Modular;
00199   typedef typename Matrix_variant(Modular) MV;
00200   typedef implementation<matrix_multiply,MV> Mat_mod;
00201 
00202   nat spc1= aligned_size<Modular,V> (r * l);
00203   nat spc2= aligned_size<Modular,V> (l * c);
00204   nat spcd= aligned_size<Modular,V> (r * c);
00205   nat spc= spc1 + spc2 + spcd;
00206   Modular* x1= mmx_new<Modular> (spc);
00207   Modular* x2= x1 + spc1;
00208   Modular* xd= x2 + spc2;
00209 
00210   if (N(crter) == 1) {
00211     Modulus p (crter[0]); Modular::set_modulus (p);
00212     for (nat i= 0; i < r * l; i++)
00213       x1[i]= Modular (Crt::encode (s1[i], p), true);
00214     for (nat i= 0; i < l * c; i++)
00215       x2[i]= Modular (Crt::encode (s2[i], p), true);
00216     Mat_mod::mul (xd, x1, x2, r, l, c);
00217     for (nat i= 0; i < r * c; i++)
00218       d[i]= Crt::decode (C(* xd[i]), crter.P, crter.H);
00219   }
00220   else {
00221     for (nat k= 0; k < N(crter); k++) {
00222       Modulus p (crter[k]); Modular::set_modulus (p);
00223       for (nat i= 0; i < r * l; i++)
00224         x1[i]= Modular (Crt::mod (Crt::encode (s1[i], crter.P), p));
00225       for (nat i= 0; i < l * c; i++)
00226         x2[i]= Modular (Crt::mod (Crt::encode (s2[i], crter.P), p));
00227       Mat_mod::mul (xd, x1, x2, r, l, c);
00228       I m (crter.m[k]), t; C q (crter.q[k]);
00229       for (nat i= 0; i < r * c; i++) {
00230         mul_mod (t, m, * xd[i], p);
00231         if (k == 0) d[i]= t * q; else mul_add (d[i], t, q);
00232       }
00233     }
00234     for (nat i= 0; i < r * c; i++)
00235       d[i]= Crt::decode (Crt::mod (d[i], crter.P), crter.P, crter.H);
00236   }
00237   mmx_delete<Modular> (x1, spc); }
00238 
00239 template<typename C, typename Low, typename High, nat s, typename CV>
00240 static void
00241 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c,
00242      crt_blocks_transformer<Low,High,s,CV>& crter) {
00243   typedef matrix<C,matrix_crt<W> > Matrix;
00244   nat n= crter.high -> size ();
00245   if (n == 1) {
00246     mul (d, s1, s2, r, l, c, * crter.low[0]);
00247     return;
00248   }
00249   Matrix* mm1= mmx_new<Matrix> (n);
00250   Matrix* mm2= mmx_new<Matrix> (n);
00251   Matrix* mmd= mmx_new<Matrix> (n);
00252   mat_direct_crt (mm1, s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00253                   r, l, * crter.high);
00254   mat_direct_crt (mm2, s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00255                   l, c, * crter.high);
00256   for (nat k= 0; k < n; k++)
00257     mul (tab(mmd[k]), tab(mm1[k]), tab(mm2[k]), r, l, c, * crter.low[k]);
00258   mat_inverse_crt (d, Mat::index (1, 0, r, c), Mat::index (0, 1, r, c),
00259                    r, c, mmd, * crter.high);
00260   mmx_delete<Matrix> (mm1, n);
00261   mmx_delete<Matrix> (mm2, n);
00262   mmx_delete<Matrix> (mmd, n); }
00263 
00264 template<typename C> static void
00265 mul (C* d, const C* s1, const C* s2, nat r, nat l, nat c) {
00266   typedef matrix_crt_multiply_helper<C> Matrix_crt;
00267   typedef typename Matrix_crt::crt_transformer Crter;
00268   typedef typename Matrix_crt::moduli_sequence Sequence;
00269   typedef typename Crter::modulus_base I;
00270   typedef modulus<I,typename Crter::modulus_base_variant> Modulus;
00271   static const nat dim_thr= Matrix_crt::dimension_threshold;
00272   static const nat ratio_thr= Matrix_crt::ratio_threshold;;
00273   
00274   if (r <= dim_thr || l <= dim_thr || c <= dim_thr) {
00275     Mat::mul (d, s1, s2, r, l, c);
00276     return;
00277   }
00278   nat sz= matrix_crt_multiply_helper<C>
00279     ::size (s1, Mat::index (1, 0, r, l), Mat::index (0, 1, r, l),
00280             s2, Mat::index (1, 0, l, c), Mat::index (0, 1, l, c),
00281             r, l, c);
00282   nat wd= (ratio_thr * sz) / 100; 
00283 
00284   if (r < wd || l < wd || c < wd) {
00285     Mat::mul (d, s1, s2, r, l, c);
00286     return;
00287   }
00288   vector<Modulus> mods;
00289   if (! Sequence::covering (mods, sz))
00290     // Not enough moduli can be generated
00291     Mat::mul (d, s1, s2, r, l, c);
00292   else {
00293     Crter crter (mods, false);
00294     mul (d, s1, s2, r, l, c, crter); } }
00295 
00296 }; // implementation<matrix_multiply,V,matrix_crt<W> >
00297 
00298 } // namespace mmx
00299 #endif // __MMX_MATRIX_CRT_HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines