algebramix_doc 0.3
/Users/mourrain/Devel/mmx/algebramix/include/algebramix/matrix_threads.hpp
Go to the documentation of this file.
00001 
00002 /******************************************************************************
00003 * MODULE     : matrix_threads.hpp
00004 * DESCRIPTION: multi-threaded matrix-multiplication
00005 * COPYRIGHT  : (C) 2007  Joris van der Hoeven
00006 *******************************************************************************
00007 * This software falls under the GNU general public license and comes WITHOUT
00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details.
00009 * If you don't have this file, write to the Free Software Foundation, Inc.,
00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00011 ******************************************************************************/
00012 
00013 #ifndef __MMX__MATRIX_THREADS__HPP
00014 #define __MMX__MATRIX_THREADS__HPP
#include <vector>
#include <algebramix/matrix_unrolled.hpp>
#include <basix/threads.hpp>
00017 
00018 namespace mmx {
00019 
00020 /******************************************************************************
00021 * Multi-threaded matrix variant
00022 ******************************************************************************/
00023 
// Variant wrapper selecting multi-threaded matrix routines on top of the
// variant V.  Everything is inherited from V; the nested variant selectors
// are re-exported so that:
//   - Vec, Naive and No_thread are forwarded from V as they are (in
//     particular No_thread is V's own selector, i.e. the threading layer
//     is dropped);
//   - Positive, No_simd and No_scaled are V's corresponding selectors,
//     wrapped again in matrix_threads so that they stay multi-threaded.
template<typename V>
struct matrix_threads: public V {
  // Selectors forwarded without the threading wrapper.
  typedef typename V::Vec       Vec;
  typedef typename V::Naive     Naive;
  typedef typename V::No_thread No_thread;
  // Selectors that keep the threading wrapper.
  typedef matrix_threads<typename V::Positive>  Positive;
  typedef matrix_threads<typename V::No_simd>   No_simd;
  typedef matrix_threads<typename V::No_scaled> No_scaled;
};
00033 
00034 template<typename F, typename V, typename W>
00035 struct implementation<F,V,matrix_threads<W> >:
00036   public implementation<F,V,W> {};
00037 
00038 /******************************************************************************
00039 * Multi-threaded matrix-multiplication
00040 ******************************************************************************/
00041 
00042 #ifdef BASIX_ENABLE_THREADS
00043 
00044 template<typename V, typename W>
00045 struct implementation<matrix_multiply_base,V,matrix_threads<W> >:
00046   public implementation<matrix_linear,V>
00047 {
00048   static const nat thr= 1024;
00049   static const nat sz = 4;
00050   typedef implementation<matrix_multiply,V,W> Mat;
00051 
00052 template<typename Op, typename D, typename S1, typename S2>
00053 struct multiply_task_rep: public task_rep {
00054   D* d; const S1* s1; const S2* s2;
00055   nat r; nat rr; nat l; nat ll; nat c; nat cc;
00056   //D *xd; S1 *xs1; S2 *xs2;
00057 public:
00058   inline multiply_task_rep (D* d2, const S1* s1b, const S2* s2b,
00059                             nat r2, nat rr2, nat l2, nat ll2, nat c2, nat cc2):
00060     d (d2), s1 (s1b), s2 (s2b),
00061     r (r2), rr (rr2), l (l2), ll (ll2), c (c2), cc (cc2)
00062   {
00063     /*
00064     xd = mmx_new<D > (aligned_size<C,V> (r * c));
00065     xs1= mmx_new<S1> (aligned_size<C,V> (r * l));
00066     xs2= mmx_new<S2> (aligned_size<C,V> (l * c));
00067     Mat::template mat_unary_stride<id_op>
00068       (xs1, Mat::index (1, 0, r , l ), Mat::index (0, 1, r , l ),
00069        s1 , Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll), r, l);
00070     Mat::template mat_unary_stride<id_op>
00071       (xs2, Mat::index (1, 0, l , c ), Mat::index (0, 1, l , c ),
00072        s2 , Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc), l, c);
00073     */
00074   }
00075   inline ~multiply_task_rep () {
00076     /*
00077     Mat::template mat_unary_stride<typename Op::nomul_op>
00078       (d , Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc),
00079        xd, Mat::index (1, 0, r , c ), Mat::index (0, 1, r , c ), r, c);
00080     mmx_delete<D > (xd , aligned_size<C,V> (r * c));
00081     mmx_delete<S1> (xs1, aligned_size<C,V> (r * l));
00082     mmx_delete<S2> (xs2, aligned_size<C,V> (l * c));
00083     */
00084   }
00085   void execute () {
00086     //Mat::template mul<mul_op> (xd, xs1, xs2, r, r, l, l, c, c);
00087     Mat::template mul<mul_op> (d, s1, s2, r, rr, l, ll, c, cc);
00088   }
00089 };
00090 
00091 template<typename Op, typename D, typename S1, typename S2> static inline void
00092 mul (D* d, const S1* s1, const S2* s2,
00093      nat r, nat rr, nat l, nat ll, nat c, nat cc)
00094 {
00095   typedef typename Op::acc_op Acc;
00096   if (r * c < thr) Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc);
00097   else {
00098     nat tr= r, nr= 1, tc= c, nc= 1, tt= threads_number;
00099     while (tt != 1) {
00100       if ((tt & 1) == 0) {
00101         if (tr > tc) { tr= (tr+1) >> 1; nr <<= 1; }
00102         else { tc= (tc+1) >> 1; nc <<= 1; }
00103         tt >>= 1;
00104       }
00105       else {
00106         if (tr > tc) { tr= (tr+tt-1) / tt; nr *= tt; }
00107         else { tc= (tc+tt-1) / tt; nc *= tt; }
00108         tt= 1;
00109       }
00110     }
00111     tr= sz * ((tr + sz - 1) / sz);
00112     tc= sz * ((tc + sz - 1) / sz);
00113 
00114     task tasks[nr*nc];
00115     for (nat ir=0; ir<nr; ir++)
00116       for (nat ic=0; ic<nc; ic++) {
00117         nat r1= ir * tr, r2= min (r1 + tr, r);
00118         nat c1= ic * tc, c2= min (c1 + tc, c);
00119         if (r1 < r && c1 < c) {
00120           D*        td = d  + Mat::index (r1, c1, rr, cc);
00121           const S1* ts1= s1 + Mat::index (r1, 0 , rr, ll);
00122           const S2* ts2= s2 + Mat::index (0 , c1, ll, cc);
00123           tasks[Mat::index (ir, ic, nr, nc)]=
00124             new multiply_task_rep<Op,D,S1,S2>
00125                   (td, ts1, ts2, r2-r1, rr, l, ll, c2-c1, cc);
00126         }
00127       }
00128     threads_execute (tasks, nr*nc);
00129   }
00130 }
00131 
00132 }; // implementation<matrix_multiply_base,V,matrix_threads<W> >
00133 
00134 #endif // BASIX_ENABLE_THREADS
00135 
00136 } // namespace mmx
00137 #endif //__MMX__MATRIX_THREADS__HPP
 All Classes Namespaces Files Functions Variables Typedefs Friends Defines