algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : crt_integer.cpp 00004 * DESCRIPTION: Subroutines for efficient Chinese remaindering with GMP 00005 * COPYRIGHT : (C) 2008 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #include <basix/port.hpp> 00014 #include <algebramix/crt_integer.hpp> 00015 namespace mmx { 00016 #if defined (__GNU_MP__) 00017 typedef mp_limb_t C; 00018 00019 /****************************************************************************** 00020 * Extra routine for debugging 00021 ******************************************************************************/ 00022 00023 /* 00024 static void 00025 mpn_print (const port& out, const C* src, nat n) { 00026 mpz_t i; 00027 i->_mp_alloc= n; 00028 i->_mp_size = n; 00029 i->_mp_d = const_cast<C*> (src); 00030 char* s= mpz_get_str (NULL, 10, i); 00031 out << s; 00032 mmx_free (s, strlen (s) + 1); 00033 } 00034 */ 00035 00036 /****************************************************************************** 00037 * Building the auxiliary information for encoding and decoding 00038 ******************************************************************************/ 00039 00040 static inline nat 00041 mpn_size (const C* src, nat n) { 00042 while (n > 0 && src[n-1] == 0) n--; 00043 return n; 00044 } 00045 00046 static inline void 00047 mpn_clear (C* dest, nat n) { 00048 for (; n != 0; n--) { 00049 *dest= 0; 00050 dest++; 00051 } 00052 } 00053 00054 static inline void 00055 mpn_copy (C* dest, const C* src, nat n) { 00056 for (; n != 0; n--) { 00057 *dest= *src; 00058 dest++; src++; 00059 } 00060 } 00061 00062 static inline void 00063 mpn_copy (C* dest, const C* src, nat n1, nat n2) { 00064 mpn_copy (dest, src, n2); 00065 mpn_clear (dest + n2, n1 - n2); 00066 } 00067 00068 static void 00069 mpn_cofactor (C* d, const C* s1, const C* s2, nat n1, nat n2) { 00070 // Compute the first cofactor [d, n2] of [s1, n1] and [s2, n2]. 00071 // We have 0 <= d < s2 and d s1 + e s2 = 1 for the second cofactor e. 00072 mpz_t src1; 00073 mpz_t src2; 00074 mpz_t g; 00075 mpz_t c1; 00076 mpz_t c2; 00077 src1->_mp_alloc= n1; 00078 src1->_mp_size = mpn_size (s1, n1); 00079 src1->_mp_d = const_cast<C*> (s1); 00080 src2->_mp_alloc= n2; 00081 src2->_mp_size = mpn_size (s2, n2); 00082 src2->_mp_d = const_cast<C*> (s2); 00083 mpz_init (g); 00084 mpz_init (c1); 00085 mpz_init (c2); 00086 mpz_gcdext (g, c1, c2, src1, src2); 00087 if (mpz_sgn (c1) < 0) mpz_add (c1, c1, src2); 00088 mpn_copy (d, c1->_mp_d, n2, c1->_mp_size); 00089 mpz_clear (g); 00090 mpz_clear (c1); 00091 mpz_clear (c2); 00092 } 00093 00094 static void 00095 build_converters (C* enc, const C* src, nat n, nat inc) { 00096 // At the end, [enc, n] contains the product of all moduli. 00097 // If n = n1 + n2 >= 2, then 'enc + inc' and 'enc + inc + n1' 00098 // recursively contains the converter information for 00099 // the first n1 and the second n2 moduli. 00100 // Furthermore, [enc + inc/2, n2] contains the first cofactor 00101 // for the corresponding products 00102 //mmout << "Build " << n << ", " << inc << "\n"; 00103 if (n == 1) *enc= (C) *src; 00104 else { 00105 nat n2= n >> 1; 00106 nat n1= n - n2; 00107 C* dec= enc + (inc >> 1); 00108 build_converters (enc + inc, src, n1, inc); 00109 build_converters (enc + inc + n1, src + n1, n2, inc); 00110 mpn_mul (enc, enc + inc, n1, enc + inc + n1, n2); 00111 mpn_cofactor (dec, enc + inc, enc + inc + n1, n1, n2); 00112 } 00113 } 00114 00115 vector<C> 00116 mpz_setup_crt (const vector<C>& mods) { 00117 nat n= N(mods); 00118 if (n == 0) return vector<C> (); 00119 nat inc= n << 1; 00120 nat steps= 1; 00121 for (nat i= n-1; i != 0; steps++, i >>= 1); 00122 vector<C> cv= fill<C> (0, steps * inc); 00123 C* enc= seg (cv); 00124 build_converters (enc, seg (mods), n, inc); 00125 return cv; 00126 } 00127 00128 integer 00129 mpz_moduli_product (const vector<C>& mods, const vector<mp_limb_t>& cv) { 00130 nat n= N(mods); 00131 if (n == 0) return 1; 00132 integer i= raw_integer (n); 00133 mpn_copy ((*i)->_mp_d, seg (cv), n); 00134 (*i)->_mp_size= mpn_size (seg (cv), n); 00135 return i; 00136 } 00137 00138 /****************************************************************************** 00139 * Encoding and decoding integers using Chinese remaindering 00140 ******************************************************************************/ 00141 00142 inline const C* seg (const integer& i) { return (*i)->_mp_d; } 00143 inline C* seg (integer& i) { return (*i)->_mp_d; } 00144 00145 static void 00146 mpn_mod (C* dest, C* tmp, 00147 const C* s1, const C* s2, nat n1, nat n2) { 00148 // Compute dest := s1 mod s2. The argument tmp contains auxiliary storage 00149 nat eff_n1= n1, eff_n2= n2; 00150 while (eff_n1 > 0 && s1[eff_n1-1] == 0) eff_n1--; 00151 while (eff_n2 > 0 && s2[eff_n2-1] == 0) eff_n2--; 00152 if (eff_n1 < eff_n2) mpn_copy (dest, s1, n2, eff_n1); 00153 else { 00154 mpn_tdiv_qr (tmp, dest, 0, s1, eff_n1, s2, eff_n2); 00155 if (n2 > eff_n2) mpn_clear (dest + eff_n2, n2 - eff_n2); 00156 } 00157 } 00158 00159 static void 00160 encode (C* dest, C* tmp, const C* enc, nat n, nat inc) { 00161 if (n == 1) return; 00162 nat n2= n >> 1; 00163 nat n1= n - n2; 00164 mpn_copy (tmp, dest, n); 00165 mpn_mod (dest, tmp + n, tmp, enc, n, n1); 00166 mpn_mod (dest + n1, tmp + n, tmp, enc + n1, n, n2); 00167 encode (dest, tmp, enc + inc, n1, inc); 00168 encode (dest + n1, tmp, enc + n1 + inc, n2, inc); 00169 } 00170 00171 void 00172 mpz_encode_crt (C* dest, const integer& src, 00173 const vector<C>& mods, const vector<C>& cv) 00174 { 00175 nat n= N(mods); 00176 if (n == 0) return; 00177 const C* enc= seg (cv); 00178 C* tmp = mmx_new<C> (n << 1); 00179 mpn_copy (dest, seg (src), n, limb_size (src)); 00180 encode (dest, tmp, enc + (n << 1), n, n << 1); 00181 if (sign (src) < 0) 00182 for (nat i=0; i<n; i++) 00183 if (dest[i] != 0) 00184 dest[i]= mods[i] - dest[i]; 00185 mmx_delete<C> (tmp, n << 1); 00186 00187 /* 00188 // optional check for correctness 00189 integer check= chinese_decode (dest, mods, cv); 00190 if (check != src) 00191 mmout << "Check failed " << src << " -> " << check << " " << mods << "\n"; 00192 */ 00193 } 00194 00195 static bool 00196 mpn_is_zero (const C* src, nat n) { 00197 for (nat i= 0; i<n; i++) 00198 if (src[i] != 0) return false; 00199 return true; 00200 } 00201 00202 static void 00203 mpn_reconstruct (C* dest, C* temp, 00204 const C* x1, const C* x2, 00205 const C* m1, const C* m2, const C* c1, 00206 nat n1, nat n2) 00207 { 00208 mpn_copy (temp, x2, n1, n2); 00209 C borrow= mpn_sub_n (dest, temp, x1, n1); 00210 if (borrow) mpn_sub_n (dest, x1, temp, n1); 00211 mpn_mul (temp, dest, n1, c1, n2); 00212 mpn_mod (dest, temp + n1 + n2, temp, m2, n1 + n2, n2); 00213 if (borrow && !mpn_is_zero (dest, n2)) { 00214 mpn_copy (temp, dest, n2); 00215 mpn_sub_n (dest, m2, temp, n2); 00216 } 00217 mpn_mul (temp, m1, n1, dest, n2); 00218 mpn_add (dest, temp, n1 + n2, x1, n1); 00219 } 00220 00221 static void 00222 decode (C* dest, C* tmp, const C* src, 00223 const C* enc, nat n, nat inc) 00224 { 00225 if (n == 1) *dest= (C) *src; 00226 else { 00227 //typedef implementation<vector_linear,vector_naive> NVec; 00228 //mmout << "Decode "; NVec::print (mmout, src, n); mmout << "\n"; 00229 nat n2= n >> 1; 00230 nat n1= n - n2; 00231 const C* dec= enc + (inc >> 1); 00232 decode (dest, tmp, src, enc + inc, n1, inc); 00233 decode (dest + n1, tmp, src + n1, enc + n1 + inc, n2, inc); 00234 mpn_copy (tmp, dest, n); 00235 mpn_reconstruct (dest, tmp + n, tmp, tmp + n1, 00236 enc + inc, enc + inc + n1, dec, n1, n2); 00237 //mmout << "Decoded "; mpn_print (mmout, dest, n); mmout << "\n"; 00238 } 00239 } 00240 00241 integer 00242 mpz_decode_crt (const C* src, 00243 const vector<C>& mods, const vector<C>& cv) { 00244 nat n= N(mods); 00245 if (n == 0) return 0; 00246 integer i= raw_integer (n); 00247 const C* enc= seg (cv); 00248 C* tmp = mmx_new<C> (n << 2); 00249 decode (seg (i), tmp, src, enc, n, n << 1); 00250 mpn_sub_n (tmp, enc, seg (i), n); 00251 if (mpn_cmp (seg (i), tmp, n) <= 0) 00252 (*i)->_mp_size= mpn_size (seg (i), n); 00253 else { 00254 mpn_copy (seg (i), tmp, n); 00255 (*i)->_mp_size= -mpn_size (seg (i), n); 00256 } 00257 mmx_delete<C> (tmp, n << 2); 00258 00259 /* 00260 // optional check for correctness 00261 C* check= mmx_new<C> (n); 00262 chinese_encode (check, i, mods, cv); 00263 for (nat k=0; k<n; k++) 00264 if (check[k] != src[k]) 00265 mmout << "Check failed " << src[k] << " -> " << check[k] << "\n"; 00266 mmx_delete<C> (check, n); 00267 */ 00268 00269 return i; 00270 } 00271 00272 #endif // __GNU_MP__ 00273 } // namespace mmx