algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/src/crt_integer.cpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : crt_integer.cpp
00004 * DESCRIPTION: Subroutines for efficient Chinese remaindering with GMP
00005 * COPYRIGHT  : (C) 2008  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #include <basix/port.hpp>
00014 #include <algebramix/crt_integer.hpp>
00015 namespace mmx {
00016 #if defined (__GNU_MP__)
00017 typedef mp_limb_t C;
00018 
00019 /******************************************************************************
00020 * Extra routine for debugging
00021 ******************************************************************************/
00022 
00023 /*
00024 static void
00025 mpn_print (const port& out, const C* src, nat n) {
00026   mpz_t i;
00027   i->_mp_alloc= n;
00028   i->_mp_size = n;
00029   i->_mp_d    = const_cast<C*> (src);
00030   char* s= mpz_get_str (NULL, 10, i);
00031   out << s;
00032   mmx_free (s, strlen (s) + 1);  
00033 }
00034 */
00035 
00036 /******************************************************************************
00037 * Building the auxiliary information for encoding and decoding
00038 ******************************************************************************/
00039 
00040 static inline nat
00041 mpn_size (const C* src, nat n) {
00042   while (n > 0 && src[n-1] == 0) n--;
00043   return n;
00044 }
00045 
00046 static inline void
00047 mpn_clear (C* dest, nat n) {
00048   for (; n != 0; n--) {
00049     *dest= 0;
00050     dest++;
00051   }
00052 }
00053 
00054 static inline void
00055 mpn_copy (C* dest, const C* src, nat n) {
00056   for (; n != 0; n--) {
00057     *dest= *src;
00058     dest++; src++;
00059   }
00060 }
00061 
00062 static inline void
00063 mpn_copy (C* dest, const C* src, nat n1, nat n2) {
00064   mpn_copy (dest, src, n2);
00065   mpn_clear (dest + n2, n1 - n2);
00066 }
00067 
00068 static void
00069 mpn_cofactor (C* d, const C* s1, const C* s2, nat n1, nat n2) {
00070   // Compute the first cofactor [d, n2] of [s1, n1] and [s2, n2].
00071   // We have 0 <= d < s2 and d s1 + e s2 = 1 for the second cofactor e.
00072   mpz_t src1;
00073   mpz_t src2;
00074   mpz_t g;
00075   mpz_t c1;
00076   mpz_t c2;
00077   src1->_mp_alloc= n1;
00078   src1->_mp_size = mpn_size (s1, n1);
00079   src1->_mp_d    = const_cast<C*> (s1);
00080   src2->_mp_alloc= n2;
00081   src2->_mp_size = mpn_size (s2, n2);
00082   src2->_mp_d    = const_cast<C*> (s2);
00083   mpz_init (g);
00084   mpz_init (c1);
00085   mpz_init (c2);
00086   mpz_gcdext (g, c1, c2, src1, src2);
00087   if (mpz_sgn (c1) < 0) mpz_add (c1, c1, src2);
00088   mpn_copy (d, c1->_mp_d, n2, c1->_mp_size);
00089   mpz_clear (g);
00090   mpz_clear (c1);
00091   mpz_clear (c2);
00092 }
00093 
00094 static void
00095 build_converters (C* enc, const C* src, nat n, nat inc) {
00096   // At the end, [enc, n] contains the product of all moduli.
00097   // If n = n1 + n2 >= 2, then 'enc + inc' and 'enc + inc + n1'
00098   // recursively contains the converter information for
00099   // the first n1 and the second n2 moduli.
00100   // Furthermore, [enc + inc/2, n2] contains the first cofactor
00101   // for the corresponding products
00102   //mmout << "Build " << n << ", " << inc << "\n";
00103   if (n == 1) *enc= (C) *src;
00104   else {
00105     nat n2= n >> 1;
00106     nat n1= n - n2;
00107     C* dec= enc + (inc >> 1);
00108     build_converters (enc + inc, src, n1, inc);
00109     build_converters (enc + inc + n1, src + n1, n2, inc);
00110     mpn_mul (enc, enc + inc, n1, enc + inc + n1, n2);
00111     mpn_cofactor (dec, enc + inc, enc + inc + n1, n1, n2);
00112   }
00113 }
00114 
00115 vector<C>
00116 mpz_setup_crt (const vector<C>& mods) {
00117   nat n= N(mods);
00118   if (n == 0) return vector<C> ();
00119   nat inc= n << 1;
00120   nat steps= 1;
00121   for (nat i= n-1; i != 0; steps++, i >>= 1);
00122   vector<C> cv= fill<C> (0, steps * inc);
00123   C* enc= seg (cv);
00124   build_converters (enc, seg (mods), n, inc);
00125   return cv;
00126 }
00127 
00128 integer
00129 mpz_moduli_product (const vector<C>& mods, const vector<mp_limb_t>& cv) {
00130   nat n= N(mods);
00131   if (n == 0) return 1;
00132   integer i= raw_integer (n);
00133   mpn_copy ((*i)->_mp_d, seg (cv), n);
00134   (*i)->_mp_size= mpn_size (seg (cv), n);
00135   return i;
00136 }
00137   
00138 /******************************************************************************
00139 * Encoding and decoding integers using Chinese remaindering
00140 ******************************************************************************/
00141 
00142 inline const C* seg (const integer& i) { return (*i)->_mp_d; }
00143 inline C* seg (integer& i) { return (*i)->_mp_d; }
00144 
00145 static void
00146 mpn_mod (C* dest, C* tmp,
00147          const C* s1, const C* s2, nat n1, nat n2) {
00148   // Compute dest := s1 mod s2. The argument tmp contains auxiliary storage
00149   nat eff_n1= n1, eff_n2= n2;
00150   while (eff_n1 > 0 && s1[eff_n1-1] == 0) eff_n1--;
00151   while (eff_n2 > 0 && s2[eff_n2-1] == 0) eff_n2--;
00152   if (eff_n1 < eff_n2) mpn_copy (dest, s1, n2, eff_n1);
00153   else {
00154     mpn_tdiv_qr (tmp, dest, 0, s1, eff_n1, s2, eff_n2);
00155     if (n2 > eff_n2) mpn_clear (dest + eff_n2, n2 - eff_n2);
00156   }
00157 }
00158 
00159 static void
00160 encode (C* dest, C* tmp, const C* enc, nat n, nat inc) {
00161   if (n == 1) return;
00162   nat n2= n >> 1;
00163   nat n1= n - n2;
00164   mpn_copy (tmp, dest, n);
00165   mpn_mod (dest, tmp + n, tmp, enc, n, n1);
00166   mpn_mod (dest + n1, tmp + n, tmp, enc + n1, n, n2);
00167   encode (dest, tmp, enc + inc, n1, inc);
00168   encode (dest + n1, tmp, enc + n1 + inc, n2, inc);
00169 }
00170 
00171 void
00172 mpz_encode_crt (C* dest, const integer& src,
00173                 const vector<C>& mods, const vector<C>& cv)
00174 {
00175   nat n= N(mods);
00176   if (n == 0) return;
00177   const C* enc= seg (cv);
00178   C* tmp = mmx_new<C> (n << 1);
00179   mpn_copy (dest, seg (src), n, limb_size (src));
00180   encode (dest, tmp, enc + (n << 1), n, n << 1);
00181   if (sign (src) < 0)
00182     for (nat i=0; i<n; i++)
00183       if (dest[i] != 0)
00184         dest[i]= mods[i] - dest[i];
00185   mmx_delete<C> (tmp, n << 1);
00186 
00187   /*
00188   // optional check for correctness
00189   integer check= chinese_decode (dest, mods, cv);
00190   if (check != src)
00191     mmout << "Check failed " << src << " -> " << check << " " << mods << "\n";
00192   */
00193 }
00194 
00195 static bool
00196 mpn_is_zero (const C* src, nat n) {
00197   for (nat i= 0; i<n; i++)
00198     if (src[i] != 0) return false;
00199   return true;
00200 }
00201 
00202 static void
00203 mpn_reconstruct (C* dest, C* temp,
00204                  const C* x1, const C* x2,
00205                  const C* m1, const C* m2, const C* c1,
00206                  nat n1, nat n2)
00207 {
00208   mpn_copy (temp, x2, n1, n2);
00209   C borrow= mpn_sub_n (dest, temp, x1, n1);
00210   if (borrow) mpn_sub_n (dest, x1, temp, n1);
00211   mpn_mul (temp, dest, n1, c1, n2);
00212   mpn_mod (dest, temp + n1 + n2, temp, m2, n1 + n2, n2);
00213   if (borrow && !mpn_is_zero (dest, n2)) {
00214     mpn_copy (temp, dest, n2);
00215     mpn_sub_n (dest, m2, temp, n2);
00216   }
00217   mpn_mul (temp, m1, n1, dest, n2);
00218   mpn_add (dest, temp, n1 + n2, x1, n1);
00219 }
00220 
00221 static void
00222 decode (C* dest, C* tmp, const C* src,
00223         const C* enc, nat n, nat inc)
00224 {
00225   if (n == 1) *dest= (C) *src;
00226   else {
00227     //typedef implementation<vector_linear,vector_naive> NVec;
00228     //mmout << "Decode "; NVec::print (mmout, src, n); mmout << "\n";
00229     nat n2= n >> 1;
00230     nat n1= n - n2;
00231     const C* dec= enc + (inc >> 1);
00232     decode (dest, tmp, src, enc + inc, n1, inc);
00233     decode (dest + n1, tmp, src + n1, enc + n1 + inc, n2, inc);
00234     mpn_copy (tmp, dest, n);
00235     mpn_reconstruct (dest, tmp + n, tmp, tmp + n1,
00236                      enc + inc, enc + inc + n1, dec, n1, n2);
00237     //mmout << "Decoded "; mpn_print (mmout, dest, n); mmout << "\n";
00238   }
00239 }
00240 
00241 integer
00242 mpz_decode_crt (const C* src,
00243                 const vector<C>& mods, const vector<C>& cv) {
00244   nat n= N(mods);
00245   if (n == 0) return 0;
00246   integer i= raw_integer (n);
00247   const C* enc= seg (cv);
00248   C* tmp = mmx_new<C> (n << 2);
00249   decode (seg (i), tmp, src, enc, n, n << 1);
00250   mpn_sub_n (tmp, enc, seg (i), n);
00251   if (mpn_cmp (seg (i), tmp, n) <= 0)
00252     (*i)->_mp_size= mpn_size (seg (i), n);
00253   else {
00254     mpn_copy (seg (i), tmp, n);
00255     (*i)->_mp_size= -mpn_size (seg (i), n);
00256   }
00257   mmx_delete<C> (tmp, n << 2);
00258 
00259   /*
00260   // optional check for correctness
00261   C* check= mmx_new<C> (n);
00262   chinese_encode (check, i, mods, cv);
00263   for (nat k=0; k<n; k++)
00264     if (check[k] != src[k])
00265       mmout << "Check failed " << src[k] << " -> " << check[k] << "\n";
00266   mmx_delete<C> (check, n);
00267   */
00268 
00269   return i;
00270 }
00271 
00272 #endif // __GNU_MP__
00273 } // namespace mmx
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines