algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : matrix_unrolled.hpp 00004 * DESCRIPTION: unrolling and subdividing matrix multiplication 00005 * COPYRIGHT : (C) 2007 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__MATRIX_UNROLLED__HPP 00014 #define __MMX__MATRIX_UNROLLED__HPP 00015 #include <algebramix/vector_unrolled.hpp> 00016 #include <algebramix/matrix_fixed.hpp> 00017 00018 namespace mmx { 00019 00020 /****************************************************************************** 00021 * Unrolled variant 00022 ******************************************************************************/ 00023 00024 template<nat sz, typename V=matrix_naive> 00025 struct matrix_unrolled: public V { 00026 typedef vector_unrolled<sz,typename V::Vec> Vec; 00027 typedef typename V::Naive Naive; 00028 typedef matrix_unrolled<sz,typename V::Positive> Positive; 00029 typedef matrix_unrolled<sz,typename V::No_aligned> No_aligned; 00030 typedef matrix_unrolled<sz,typename V::No_simd> No_simd; 00031 typedef matrix_unrolled<sz,typename V::No_thread> No_thread; 00032 typedef matrix_unrolled<sz,typename V::No_scaled> No_scaled; 00033 }; 00034 00035 template<nat sz, typename F, typename V, typename W> 00036 struct implementation<F,V,matrix_unrolled<sz,W> >: 00037 public implementation<F,V,W> {}; 00038 00039 /****************************************************************************** 00040 * Complete a few missing rows, columns, inner products using another algorithm 00041 ******************************************************************************/ 00042 00043 template<typename Op, typename V, typename D, typename S1, typename S2> void 00044 mul_complete (D* dest, const S1* src1, const S2* src2, 00045 nat r, nat rr, nat l, nat ll, nat c, nat cc, 00046 nat hr, nat hl, nat hc) 00047 { 00048 typedef implementation<matrix_multiply,V> Mat; 00049 typedef typename Op::acc_op Acc; 00050 if (hr < r && hl != 0 && hc != 0) 00051 Mat::template mul<Op > (dest + Mat::index (hr, 0, rr, cc), 00052 src1 + Mat::index (hr, 0, rr, ll), 00053 src2, 00054 r-hr, rr, hl, ll, hc, cc); 00055 if (hc < c && hl != 0) 00056 Mat::template mul<Op > (dest + Mat::index (0, hc, rr, cc), 00057 src1, 00058 src2 + Mat::index (0, hc, ll, cc), 00059 r , rr, hl, ll, c-hc, cc); 00060 if (hl < l) 00061 Mat::template mul<Acc> (dest, 00062 src1 + Mat::index (0, hl, rr, ll), 00063 src2 + Mat::index (hl, 0, ll, cc), 00064 r , rr, l-hl, ll, c , cc); 00065 } 00066 00067 /****************************************************************************** 00068 * Using a triple loop using big unrolled blocks 00069 ******************************************************************************/ 00070 00071 template<typename Op, nat ur, nat ul, nat uc, typename V, 00072 typename D, typename S1, typename S2> void 00073 mul_unrolled (D* dest, const S1* src1, const S2* src2, 00074 nat r, nat rr, nat l, nat ll, nat c, nat cc) 00075 { 00076 typedef implementation<matrix_multiply,V> Mat; 00077 typedef implementation<matrix_multiply_base,matrix_naive> NMat; 00078 typedef typename Op::acc_op Acc; 00079 nat nr= r/ur, nl= l/ul, nc= c/uc; 00080 if (nl == 0) 00081 NMat::template clr<Op> (dest, r, rr, c, cc); 00082 else 00083 for (nat ir=0; ir<nr; ir++) 00084 for (nat ic=0; ic<nc; ic++) { 00085 nat il=0; 00086 for (; il<1; il++) 00087 matrix_multiply_helper<Op,D,S1,S2,ur,ul,uc>:: 00088 mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc), 00089 src1 + Mat::index (ir*ur, il*ul, rr, ll), 00090 src2 + Mat::index (il*ul, ic*uc, ll, cc), 00091 rr, ll); 00092 for (; il<nl; il++) 00093 matrix_multiply_helper<Acc,D,S1,S2,ur,ul,uc>:: 00094 mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc), 00095 src1 + Mat::index (ir*ur, il*ul, rr, ll), 00096 src2 + Mat::index (il*ul, ic*uc, ll, cc), 00097 rr, ll); 00098 } 00099 mul_complete<Op,V> (dest, src1, src2, r, rr, l, ll, c, cc, 00100 ur*nr, ul*nl, uc*nc); 00101 } 00102 00103 /****************************************************************************** 00104 * Interface for unrolled multiplication 00105 ******************************************************************************/ 00106 00107 template<nat sz, typename V, typename W> 00108 struct implementation<matrix_multiply_base,V,matrix_unrolled<sz,W> >: 00109 public implementation<matrix_linear,V> 00110 { 00111 const static nat ur= sz; 00112 const static nat ul= sz; 00113 const static nat uc= sz; 00114 00115 template<typename Op, typename D, typename S1, typename S2> 00116 static inline void 00117 mul (D* dest, const S1* src1, const S2* src2, 00118 nat r, nat rr, nat l, nat ll, nat c, nat cc) 00119 { 00120 mul_unrolled<Op,ur,ul,uc,W> (dest, src1, src2, r, rr, l, ll, c, cc); 00121 } 00122 }; // implementation<matrix_multiply,V,matrix_unrolled<sz,BV> > 00123 00124 } // namespace mmx 00125 #endif //__MMX__MATRIX_UNROLLED__HPP