algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/matrix_strassen.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : matrix_strassen.hpp
00004 * DESCRIPTION: Matrix multiplication using Strassen's algorithm
00005 * COPYRIGHT  : (C) 2007  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__MATRIX_STRASSEN__HPP
00014 #define __MMX__MATRIX_STRASSEN__HPP
00015 #include <algebramix/matrix_unrolled.hpp>
00016 
00017 namespace mmx {
00018 
00019 /******************************************************************************
00020 * Variant for Strassen multiplication
00021 ******************************************************************************/
00022 
// Variant tag selecting Strassen multiplication on top of the variant Base.
// Everything not overridden here is inherited from Base.
template<typename Base>
struct matrix_strassen: public Base {
  // Feature-stripped variants are wrapped again, so that they still
  // carry the matrix_strassen tag.
  typedef matrix_strassen<typename Base::No_scaled> No_scaled;
  typedef matrix_strassen<typename Base::No_thread> No_thread;
  typedef matrix_strassen<typename Base::No_simd>   No_simd;

  // These sub-variants are taken over from Base as they are.
  typedef typename Base::Positive Positive;
  typedef typename Base::Naive    Naive;
  typedef typename Base::Vec      Vec;
};
00032 
00033 template<typename F, typename V, typename W>
00034 struct implementation<F,V,matrix_strassen<W> >:
00035   public implementation<F,V,W> {};
00036 
00037 /******************************************************************************
00038 * Strassen multiplication (even-odd decomposition for numerical stability)
00039 ******************************************************************************/
00040 
00041 template<typename V, typename W>
00042 struct implementation<matrix_multiply_base,V,matrix_strassen<W> >:
00043   public implementation<matrix_linear,V>
00044 {
00045   static const nat thr= 128; // use 512 for double >> TODO
00046   typedef implementation<vector_linear,W> Vec;
00047   typedef implementation<matrix_multiply,W> Mat;
00048 
00049   template<typename C> static inline void
00050   mat_load (C* d, nat r, nat c, const C* s, nat rr, nat cc) {
00051     nat j= c;
00052     for (; j!=0; j--, d += Mat::index(0,1,r,c), s += Mat::index(0,2,rr,cc)) {
00053       nat      i = r;
00054       C*       dd= d;
00055       const C* ss= s;
00056       for (; i!=0; i--, dd += Mat::index(1,0,r,c), ss += Mat::index(2,0,rr,cc))
00057         *dd= *ss;
00058     } 
00059   }
00060 
00061   template<typename Op, typename C> static inline void
00062   mat_save (C* d, nat rr, nat cc, const C* s, nat r, nat c) {
00063     typedef typename Op::nomul_op Set;
00064     nat j= c;
00065     for (; j!=0; j--, d += Mat::index(0,2,rr,cc), s += Mat::index(0,1,r,c)) {
00066       nat      i = r;
00067       C*       dd= d;
00068       const C* ss= s;
00069       for (; i!=0; i--, dd += Mat::index(2,0,rr,cc), ss += Mat::index(1,0,r,c))
00070         Set::set_op (*dd, *ss);
00071     } 
00072   }
00073 
00074   template<typename Op, typename D, typename S1, typename S2> static void
00075   mul (D* d, const S1* a, const S2* b,
00076        nat r, nat rr, nat l, nat ll, nat c, nat cc)
00077   {
00078     if (r < thr || l < thr || c < thr) {
00079       Mat::template mul<Op> (d, a, b, r, rr, l, ll, c, cc);
00080       return;
00081     }
00082 
00083     nat hr= r>>1, hl= l>>1, hc= c>>1, fr= hr<<1, fl= hl<<1, fc= hc<<1;
00084     
00085     nat sza= aligned_size<S1,W> (hr * hl);
00086     S1* a11= mmx_new<S1> (5 * sza);
00087     S1* a12= a11 + sza;
00088     S1* a21= a12 + sza;
00089     S1* a22= a21 + sza;
00090     S1* aaa= a22 + sza;
00091 
00092     nat szb= aligned_size<S2,W> (hl * hc);
00093     S2* b11= mmx_new<S2> (5 * szb);
00094     S2* b12= b11 + szb;
00095     S2* b21= b12 + szb;
00096     S2* b22= b21 + szb;
00097     S2* bbb= b22 + szb;
00098 
00099     nat szd= aligned_size<D,W> (hr * hc);
00100     D*  m1 = mmx_new<D> (11 * szd);
00101     D*  m2 = m1  + szd;
00102     D*  m3 = m2  + szd;
00103     D*  m4 = m3  + szd;
00104     D*  m5 = m4  + szd;
00105     D*  m6 = m5  + szd;
00106     D*  m7 = m6  + szd;
00107     D*  d11= m7  + szd;
00108     D*  d12= d11 + szd;
00109     D*  d21= d12 + szd;
00110     D*  d22= d21 + szd;
00111 
00112     mat_load (a11, hr, hl, a                            , rr, ll);
00113     mat_load (a12, hr, hl, a + Mat::index (0, 1, rr, ll), rr, ll);
00114     mat_load (a21, hr, hl, a + Mat::index (1, 0, rr, ll), rr, ll);
00115     mat_load (a22, hr, hl, a + Mat::index (1, 1, rr, ll), rr, ll);
00116     mat_load (b11, hr, hl, b                            , rr, ll);
00117     mat_load (b12, hr, hl, b + Mat::index (0, 1, rr, ll), rr, ll);
00118     mat_load (b21, hr, hl, b + Mat::index (1, 0, rr, ll), rr, ll);
00119     mat_load (b22, hr, hl, b + Mat::index (1, 1, rr, ll), rr, ll);
00120 
00121     Vec::add (aaa, a11, a22, hr * hl);
00122     Vec::add (bbb, b11, b22, hl * hc);
00123     mul<mul_op> (m1, aaa, bbb, hr, hr, hl, hl, hc, hc);
00124     Vec::add (aaa, a21, a22, hr * hl);
00125     mul<mul_op> (m2, aaa, b11, hr, hr, hl, hl, hc, hc);
00126     Vec::sub (bbb, b12, b22, hl * hc);
00127     mul<mul_op> (m3, a11, bbb, hr, hr, hl, hl, hc, hc);
00128     Vec::sub (bbb, b21, b11, hl * hc);
00129     mul<mul_op> (m4, a22, bbb, hr, hr, hl, hl, hc, hc);
00130     Vec::add (aaa, a11, a12, hr * hl);
00131     mul<mul_op> (m5, aaa, b22, hr, hr, hl, hl, hc, hc);
00132     Vec::sub (aaa, a21, a11, hr * hl);
00133     Vec::add (bbb, b11, b12, hl * hc);
00134     mul<mul_op> (m6, aaa, bbb, hr, hr, hl, hl, hc, hc);
00135     Vec::sub (aaa, a12, a22, hr * hl);
00136     Vec::add (bbb, b21, b22, hl * hc);
00137     mul<mul_op> (m7, aaa, bbb, hr, hr, hl, hl, hc, hc);
00138 
00139     Vec::add (d11, m1, m4, hr * hc);
00140     Vec::sub (d11, m5, hr * hc);
00141     Vec::add (d11, m7, hr * hc);
00142     Vec::add (d12, m3, m5, hr * hc);
00143     Vec::add (d21, m2, m4, hr * hc);
00144     Vec::sub (d22, m1, m2, hr * hc);
00145     Vec::add (d22, m3, hr * hc);
00146     Vec::add (d22, m6, hr * hc);
00147 
00148     mat_save<Op> (d                            , rr, cc, d11, hr, hc);
00149     mat_save<Op> (d + Mat::index (0, 1, rr, cc), rr, ll, d12, hr, hc);
00150     mat_save<Op> (d + Mat::index (1, 0, rr, cc), rr, ll, d21, hr, hc);
00151     mat_save<Op> (d + Mat::index (1, 1, rr, cc), rr, ll, d22, hr, hc);
00152 
00153     mmx_delete<S1> (a11, 5  * sza);
00154     mmx_delete<S2> (b11, 5  * szb);
00155     mmx_delete<D > (m1 , 11 * szd);
00156 
00157     mul_complete<Op,W> (d, a, b, r, rr, l, ll, c, cc, fr, fl, fc);
00158   }
00159 }; // implementation<matrix_multiply_base,V,matrix_strassen<W> >
00160 
00161 } // namespace mmx
00162 #endif //__MMX__MATRIX_STRASSEN__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines