algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : matrix_threads.hpp 00004 * DESCRIPTION: multi-threaded matrix-multiplication 00005 * COPYRIGHT : (C) 2007 Joris van der Hoeven 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__MATRIX_THREADS__HPP 00014 #define __MMX__MATRIX_THREADS__HPP 00015 #include <algebramix/matrix_unrolled.hpp> 00016 #include <basix/threads.hpp> 00017 00018 namespace mmx { 00019 00020 /****************************************************************************** 00021 * Multi-threaded matrix variant 00022 ******************************************************************************/ 00023 00024 template<typename V> 00025 struct matrix_threads: public V { 00026 typedef typename V::Vec Vec; 00027 typedef typename V::Naive Naive; 00028 typedef matrix_threads<typename V::Positive> Positive; 00029 typedef matrix_threads<typename V::No_simd> No_simd; 00030 typedef typename V::No_thread No_thread; 00031 typedef matrix_threads<typename V::No_scaled> No_scaled; 00032 }; 00033 00034 template<typename F, typename V, typename W> 00035 struct implementation<F,V,matrix_threads<W> >: 00036 public implementation<F,V,W> {}; 00037 00038 /****************************************************************************** 00039 * Multi-threaded matrix-multiplication 00040 ******************************************************************************/ 00041 00042 #ifdef BASIX_ENABLE_THREADS 00043 00044 template<typename V, typename W> 00045 struct implementation<matrix_multiply_base,V,matrix_threads<W> >: 00046 public implementation<matrix_linear,V> 00047 { 00048 static const nat thr= 1024; 00049 static const nat sz = 4; 00050 typedef implementation<matrix_multiply,V,W> Mat; 00051 00052 template<typename Op, typename D, typename S1, typename S2> 00053 struct multiply_task_rep: public task_rep { 00054 D* d; const S1* s1; const S2* s2; 00055 nat r; nat rr; nat l; nat ll; nat c; nat cc; 00056 //D *xd; S1 *xs1; S2 *xs2; 00057 public: 00058 inline multiply_task_rep (D* d2, const S1* s1b, const S2* s2b, 00059 nat r2, nat rr2, nat l2, nat ll2, nat c2, nat cc2): 00060 d (d2), s1 (s1b), s2 (s2b), 00061 r (r2), rr (rr2), l (l2), ll (ll2), c (c2), cc (cc2) 00062 { 00063 /* 00064 xd = mmx_new<D > (aligned_size<C,V> (r * c)); 00065 xs1= mmx_new<S1> (aligned_size<C,V> (r * l)); 00066 xs2= mmx_new<S2> (aligned_size<C,V> (l * c)); 00067 Mat::template mat_unary_stride<id_op> 00068 (xs1, Mat::index (1, 0, r , l ), Mat::index (0, 1, r , l ), 00069 s1 , Mat::index (1, 0, rr, ll), Mat::index (0, 1, rr, ll), r, l); 00070 Mat::template mat_unary_stride<id_op> 00071 (xs2, Mat::index (1, 0, l , c ), Mat::index (0, 1, l , c ), 00072 s2 , Mat::index (1, 0, ll, cc), Mat::index (0, 1, ll, cc), l, c); 00073 */ 00074 } 00075 inline ~multiply_task_rep () { 00076 /* 00077 Mat::template mat_unary_stride<typename Op::nomul_op> 00078 (d , Mat::index (1, 0, rr, cc), Mat::index (0, 1, rr, cc), 00079 xd, Mat::index (1, 0, r , c ), Mat::index (0, 1, r , c ), r, c); 00080 mmx_delete<D > (xd , aligned_size<C,V> (r * c)); 00081 mmx_delete<S1> (xs1, aligned_size<C,V> (r * l)); 00082 mmx_delete<S2> (xs2, aligned_size<C,V> (l * c)); 00083 */ 00084 } 00085 void execute () { 00086 //Mat::template mul<mul_op> (xd, xs1, xs2, r, r, l, l, c, c); 00087 Mat::template mul<mul_op> (d, s1, s2, r, rr, l, ll, c, cc); 00088 } 00089 }; 00090 00091 template<typename Op, typename D, typename S1, typename S2> static inline void 00092 mul (D* d, const S1* s1, const S2* s2, 00093 nat r, nat rr, nat l, nat ll, nat c, nat cc) 00094 { 00095 typedef typename Op::acc_op Acc; 00096 if (r * c < thr) Mat::template mul<Op> (d, s1, s2, r, rr, l, ll, c, cc); 00097 else { 00098 nat tr= r, nr= 1, tc= c, nc= 1, tt= threads_number; 00099 while (tt != 1) { 00100 if ((tt & 1) == 0) { 00101 if (tr > tc) { tr= (tr+1) >> 1; nr <<= 1; } 00102 else { tc= (tc+1) >> 1; nc <<= 1; } 00103 tt >>= 1; 00104 } 00105 else { 00106 if (tr > tc) { tr= (tr+tt-1) / tt; nr *= tt; } 00107 else { tc= (tc+tt-1) / tt; nc *= tt; } 00108 tt= 1; 00109 } 00110 } 00111 tr= sz * ((tr + sz - 1) / sz); 00112 tc= sz * ((tc + sz - 1) / sz); 00113 00114 task tasks[nr*nc]; 00115 for (nat ir=0; ir<nr; ir++) 00116 for (nat ic=0; ic<nc; ic++) { 00117 nat r1= ir * tr, r2= min (r1 + tr, r); 00118 nat c1= ic * tc, c2= min (c1 + tc, c); 00119 if (r1 < r && c1 < c) { 00120 D* td = d + Mat::index (r1, c1, rr, cc); 00121 const S1* ts1= s1 + Mat::index (r1, 0 , rr, ll); 00122 const S2* ts2= s2 + Mat::index (0 , c1, ll, cc); 00123 tasks[Mat::index (ir, ic, nr, nc)]= 00124 new multiply_task_rep<Op,D,S1,S2> 00125 (td, ts1, ts2, r2-r1, rr, l, ll, c2-c1, cc); 00126 } 00127 } 00128 threads_execute (tasks, nr*nc); 00129 } 00130 } 00131 00132 }; // implementation<matrix_multiply_base,V,matrix_threads<W> > 00133 00134 #endif // BASIX_ENABLE_THREADS 00135 00136 } // namespace mmx 00137 #endif //__MMX__MATRIX_THREADS__HPP