algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : matrix_strassen.hpp 00004 * DESCRIPTION: Matrix multiplication using Strassen's algorithm 00005 * COPYRIGHT : (C) 2007 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__MATRIX_STRASSEN__HPP 00014 #define __MMX__MATRIX_STRASSEN__HPP 00015 #include <algebramix/matrix_unrolled.hpp> 00016 00017 namespace mmx { 00018 00019 /****************************************************************************** 00020 * Variant for Strassen multiplication 00021 ******************************************************************************/ 00022 00023 template<typename V> 00024 struct matrix_strassen: public V { 00025 typedef typename V::Vec Vec; 00026 typedef typename V::Naive Naive; 00027 typedef typename V::Positive Positive; 00028 typedef matrix_strassen<typename V::No_simd> No_simd; 00029 typedef matrix_strassen<typename V::No_thread> No_thread; 00030 typedef matrix_strassen<typename V::No_scaled> No_scaled; 00031 }; 00032 00033 template<typename F, typename V, typename W> 00034 struct implementation<F,V,matrix_strassen<W> >: 00035 public implementation<F,V,W> {}; 00036 00037 /****************************************************************************** 00038 * Strassen multiplication (even-odd decomposition for numerical stability) 00039 ******************************************************************************/ 00040 00041 template<typename V, typename W> 00042 struct implementation<matrix_multiply_base,V,matrix_strassen<W> >: 00043 public implementation<matrix_linear,V> 00044 { 00045 static const nat thr= 128; // use 512 for double >> TODO 00046 typedef implementation<vector_linear,W> Vec; 00047 typedef implementation<matrix_multiply,W> Mat; 00048 00049 template<typename C> static inline void 00050 mat_load (C* d, nat r, nat c, const C* s, nat rr, nat cc) { 00051 nat j= c; 00052 for (; j!=0; j--, d += Mat::index(0,1,r,c), s += Mat::index(0,2,rr,cc)) { 00053 nat i = r; 00054 C* dd= d; 00055 const C* ss= s; 00056 for (; i!=0; i--, dd += Mat::index(1,0,r,c), ss += Mat::index(2,0,rr,cc)) 00057 *dd= *ss; 00058 } 00059 } 00060 00061 template<typename Op, typename C> static inline void 00062 mat_save (C* d, nat rr, nat cc, const C* s, nat r, nat c) { 00063 typedef typename Op::nomul_op Set; 00064 nat j= c; 00065 for (; j!=0; j--, d += Mat::index(0,2,rr,cc), s += Mat::index(0,1,r,c)) { 00066 nat i = r; 00067 C* dd= d; 00068 const C* ss= s; 00069 for (; i!=0; i--, dd += Mat::index(2,0,rr,cc), ss += Mat::index(1,0,r,c)) 00070 Set::set_op (*dd, *ss); 00071 } 00072 } 00073 00074 template<typename Op, typename D, typename S1, typename S2> static void 00075 mul (D* d, const S1* a, const S2* b, 00076 nat r, nat rr, nat l, nat ll, nat c, nat cc) 00077 { 00078 if (r < thr || l < thr || c < thr) { 00079 Mat::template mul<Op> (d, a, b, r, rr, l, ll, c, cc); 00080 return; 00081 } 00082 00083 nat hr= r>>1, hl= l>>1, hc= c>>1, fr= hr<<1, fl= hl<<1, fc= hc<<1; 00084 00085 nat sza= aligned_size<S1,W> (hr * hl); 00086 S1* a11= mmx_new<S1> (5 * sza); 00087 S1* a12= a11 + sza; 00088 S1* a21= a12 + sza; 00089 S1* a22= a21 + sza; 00090 S1* aaa= a22 + sza; 00091 00092 nat szb= aligned_size<S2,W> (hl * hc); 00093 S2* b11= mmx_new<S2> (5 * szb); 00094 S2* b12= b11 + szb; 00095 S2* b21= b12 + szb; 00096 S2* b22= b21 + szb; 00097 S2* bbb= b22 + szb; 00098 00099 nat szd= aligned_size<D,W> (hr * hc); 00100 D* m1 = mmx_new<D> (11 * szd); 00101 D* m2 = m1 + szd; 00102 D* m3 = m2 + szd; 00103 D* m4 = m3 + szd; 00104 D* m5 = m4 + szd; 00105 D* m6 = m5 + szd; 00106 D* m7 = m6 + szd; 00107 D* d11= m7 + szd; 00108 D* d12= d11 + szd; 00109 D* d21= d12 + szd; 00110 D* d22= d21 + szd; 00111 00112 mat_load (a11, hr, hl, a , rr, ll); 00113 mat_load (a12, hr, hl, a + Mat::index (0, 1, rr, ll), rr, ll); 00114 mat_load (a21, hr, hl, a + Mat::index (1, 0, rr, ll), rr, ll); 00115 mat_load (a22, hr, hl, a + Mat::index (1, 1, rr, ll), rr, ll); 00116 mat_load (b11, hr, hl, b , rr, ll); 00117 mat_load (b12, hr, hl, b + Mat::index (0, 1, rr, ll), rr, ll); 00118 mat_load (b21, hr, hl, b + Mat::index (1, 0, rr, ll), rr, ll); 00119 mat_load (b22, hr, hl, b + Mat::index (1, 1, rr, ll), rr, ll); 00120 00121 Vec::add (aaa, a11, a22, hr * hl); 00122 Vec::add (bbb, b11, b22, hl * hc); 00123 mul<mul_op> (m1, aaa, bbb, hr, hr, hl, hl, hc, hc); 00124 Vec::add (aaa, a21, a22, hr * hl); 00125 mul<mul_op> (m2, aaa, b11, hr, hr, hl, hl, hc, hc); 00126 Vec::sub (bbb, b12, b22, hl * hc); 00127 mul<mul_op> (m3, a11, bbb, hr, hr, hl, hl, hc, hc); 00128 Vec::sub (bbb, b21, b11, hl * hc); 00129 mul<mul_op> (m4, a22, bbb, hr, hr, hl, hl, hc, hc); 00130 Vec::add (aaa, a11, a12, hr * hl); 00131 mul<mul_op> (m5, aaa, b22, hr, hr, hl, hl, hc, hc); 00132 Vec::sub (aaa, a21, a11, hr * hl); 00133 Vec::add (bbb, b11, b12, hl * hc); 00134 mul<mul_op> (m6, aaa, bbb, hr, hr, hl, hl, hc, hc); 00135 Vec::sub (aaa, a12, a22, hr * hl); 00136 Vec::add (bbb, b21, b22, hl * hc); 00137 mul<mul_op> (m7, aaa, bbb, hr, hr, hl, hl, hc, hc); 00138 00139 Vec::add (d11, m1, m4, hr * hc); 00140 Vec::sub (d11, m5, hr * hc); 00141 Vec::add (d11, m7, hr * hc); 00142 Vec::add (d12, m3, m5, hr * hc); 00143 Vec::add (d21, m2, m4, hr * hc); 00144 Vec::sub (d22, m1, m2, hr * hc); 00145 Vec::add (d22, m3, hr * hc); 00146 Vec::add (d22, m6, hr * hc); 00147 00148 mat_save<Op> (d , rr, cc, d11, hr, hc); 00149 mat_save<Op> (d + Mat::index (0, 1, rr, cc), rr, ll, d12, hr, hc); 00150 mat_save<Op> (d + Mat::index (1, 0, rr, cc), rr, ll, d21, hr, hc); 00151 mat_save<Op> (d + Mat::index (1, 1, rr, cc), rr, ll, d22, hr, hc); 00152 00153 mmx_delete<S1> (a11, 5 * sza); 00154 mmx_delete<S2> (b11, 5 * szb); 00155 mmx_delete<D > (m1 , 11 * szd); 00156 00157 mul_complete<Op,W> (d, a, b, r, rr, l, ll, c, cc, fr, fl, fc); 00158 } 00159 }; // implementation<matrix_multiply_base,V,matrix_strassen<W> > 00160 00161 } // namespace mmx 00162 #endif //__MMX__MATRIX_STRASSEN__HPP