algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/matrix_unrolled.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : matrix_unrolled.hpp
00004 * DESCRIPTION: unrolling and subdividing matrix multiplication
00005 * COPYRIGHT  : (C) 2007  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__MATRIX_UNROLLED__HPP
00014 #define __MMX__MATRIX_UNROLLED__HPP
00015 #include <algebramix/vector_unrolled.hpp>
00016 #include <algebramix/matrix_fixed.hpp>
00017 
00018 namespace mmx {
00019 
00020 /******************************************************************************
00021 * Unrolled variant
00022 ******************************************************************************/
00023 
00024 template<nat sz, typename V=matrix_naive>
00025 struct matrix_unrolled: public V {
00026   typedef vector_unrolled<sz,typename V::Vec> Vec;
00027   typedef typename V::Naive Naive;
00028   typedef matrix_unrolled<sz,typename V::Positive> Positive;
00029   typedef matrix_unrolled<sz,typename V::No_aligned> No_aligned;
00030   typedef matrix_unrolled<sz,typename V::No_simd> No_simd;
00031   typedef matrix_unrolled<sz,typename V::No_thread> No_thread;
00032   typedef matrix_unrolled<sz,typename V::No_scaled> No_scaled;
00033 };
00034 
00035 template<nat sz, typename F, typename V, typename W>
00036 struct implementation<F,V,matrix_unrolled<sz,W> >:
00037   public implementation<F,V,W> {};
00038 
00039 /******************************************************************************
00040 * Complete the leftover rows, columns and inner products using another algorithm
00041 ******************************************************************************/
00042 
00043 template<typename Op, typename V, typename D, typename S1, typename S2> void
00044 mul_complete (D* dest, const S1* src1, const S2* src2,
00045               nat r, nat rr, nat l, nat ll, nat c, nat cc,
00046               nat hr, nat hl, nat hc)
00047 {
00048   typedef implementation<matrix_multiply,V> Mat;
00049   typedef typename Op::acc_op Acc;
00050   if (hr < r && hl != 0 && hc != 0)
00051     Mat::template mul<Op > (dest + Mat::index (hr, 0, rr, cc),
00052                             src1 + Mat::index (hr, 0, rr, ll),
00053                             src2,
00054                             r-hr, rr, hl, ll, hc, cc);
00055   if (hc < c && hl != 0)
00056     Mat::template mul<Op > (dest + Mat::index (0, hc, rr, cc),
00057                             src1,
00058                             src2 + Mat::index (0, hc, ll, cc),
00059                             r , rr, hl, ll, c-hc, cc);
00060   if (hl < l)
00061     Mat::template mul<Acc> (dest,
00062                             src1 + Mat::index (0, hl, rr, ll),
00063                             src2 + Mat::index (hl, 0, ll, cc),
00064                             r , rr, l-hl, ll, c , cc);
00065 }
00066 
00067 /******************************************************************************
00068 * Triple loop over big unrolled blocks
00069 ******************************************************************************/
00070 
00071 template<typename Op, nat ur, nat ul, nat uc, typename V,
00072          typename D, typename S1, typename S2> void
00073 mul_unrolled (D* dest, const S1* src1, const S2* src2,
00074               nat r, nat rr, nat l, nat ll, nat c, nat cc)
00075 {
00076   typedef implementation<matrix_multiply,V> Mat;
00077   typedef implementation<matrix_multiply_base,matrix_naive> NMat;
00078   typedef typename Op::acc_op Acc;
00079   nat nr= r/ur, nl= l/ul, nc= c/uc;
00080   if (nl == 0)
00081     NMat::template clr<Op> (dest, r, rr, c, cc);
00082   else
00083     for (nat ir=0; ir<nr; ir++)
00084       for (nat ic=0; ic<nc; ic++) {
00085         nat il=0;
00086         for (; il<1; il++)
00087           matrix_multiply_helper<Op,D,S1,S2,ur,ul,uc>::
00088             mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc),
00089                         src1 + Mat::index (ir*ur, il*ul, rr, ll),
00090                         src2 + Mat::index (il*ul, ic*uc, ll, cc),
00091                         rr, ll);
00092         for (; il<nl; il++)
00093           matrix_multiply_helper<Acc,D,S1,S2,ur,ul,uc>::
00094             mul_stride (dest + Mat::index (ir*ur, ic*uc, rr, cc),
00095                         src1 + Mat::index (ir*ur, il*ul, rr, ll),
00096                         src2 + Mat::index (il*ul, ic*uc, ll, cc),
00097                         rr, ll);
00098       }
00099   mul_complete<Op,V> (dest, src1, src2, r, rr, l, ll, c, cc,
00100                       ur*nr, ul*nl, uc*nc);
00101 }
00102 
00103 /******************************************************************************
00104 * Interface for unrolled multiplication
00105 ******************************************************************************/
00106 
00107 template<nat sz, typename V, typename W>
00108 struct implementation<matrix_multiply_base,V,matrix_unrolled<sz,W> >:
00109   public implementation<matrix_linear,V>
00110 {
00111   const static nat ur= sz;
00112   const static nat ul= sz;
00113   const static nat uc= sz;
00114 
00115   template<typename Op, typename D, typename S1, typename S2>
00116   static inline void
00117   mul (D* dest, const S1* src1, const S2* src2,
00118        nat r, nat rr, nat l, nat ll, nat c, nat cc)
00119   {
00120     mul_unrolled<Op,ur,ul,uc,W> (dest, src1, src2, r, rr, l, ll, c, cc);
00121   }
00122 }; // implementation<matrix_multiply,V,matrix_unrolled<sz,BV> >
00123 
00124 } // namespace mmx
00125 #endif //__MMX__MATRIX_UNROLLED__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines