algebramix_doc 0.3
|
00001 00002 /****************************************************************************** 00003 * MODULE : matrix_fixed.hpp 00004 * DESCRIPTION: inlined low-level matrix operations of a fixed size 00005 * COPYRIGHT : (C) 2007 Joris van der Hoeven and Gregoire Lecerf 00006 ******************************************************************************* 00007 * This software falls under the GNU general public license and comes WITHOUT 00008 * ANY WARRANTY WHATSOEVER. See the file $TEXMACS_PATH/LICENSE for more details. 00009 * If you don't have this file, write to the Free Software Foundation, Inc., 00010 * 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00011 ******************************************************************************/ 00012 00013 #ifndef __MMX__MATRIX_FIXED__HPP 00014 #define __MMX__MATRIX_FIXED__HPP 00015 #include <algebramix/matrix_naive.hpp> 00016 00017 namespace mmx { 00018 00019 /****************************************************************************** 00020 * Inlined matrix multiplication for fixed sizes 00021 ******************************************************************************/ 00022 00023 template<typename Op, nat tp> struct div_type {}; 00024 00025 template<typename Op, typename D, typename S1, typename S2, 00026 nat r, nat l, nat c> 00027 struct matrix_multiply_helper { 00028 static const nat tp= (r*c==1 || (r<l && c<l && l>8))? 3: (r<=c? 2: 1); 00029 00030 template<typename rr, typename cc> static inline void 00031 mul (D* d, const S1* m1, const S2* m2) { 00032 matrix_multiply_helper <div_type<Op,tp>, D, S1, S2, r, l, c>:: 00033 template mul<rr,cc> (d, m1, m2); 00034 } 00035 00036 static inline void 00037 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) { 00038 matrix_multiply_helper <div_type<Op,tp>, D, S1, S2, r, l, c>:: 00039 mul_stride (d, m1, m2, rr, ll); 00040 } 00041 }; 00042 00043 template<typename Op, typename D, typename S1, typename S2, 00044 nat r, nat l, nat c> 00045 struct matrix_multiply_helper<div_type<Op,1>, D, S1, S2, r, l, c> { 00046 // Recurse on a smaller number of rows r 00047 static const nat r1= (r>>1), r2= r-r1; 00048 00049 template<nat rr, nat ll> static inline void 00050 mul (D* d, const S1* m1, const S2* m2) { 00051 { 00052 matrix_multiply_helper <Op, D, S1, S2, r1, l, c>:: 00053 template mul<rr,ll> (d, m1, m2); 00054 matrix_multiply_helper <Op, D, S1, S2, r2, l, c>:: 00055 template mul<rr,ll> (d + r1, m1 + r1, m2); 00056 } 00057 } 00058 00059 static inline void 00060 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) { 00061 { 00062 matrix_multiply_helper <Op, D, S1, S2, r1, l, c>:: 00063 mul_stride (d, m1, m2, rr, ll); 00064 matrix_multiply_helper <Op, D, S1, S2, r2, l, c>:: 00065 mul_stride (d + r1, m1 + r1, m2, rr, ll); 00066 } 00067 } 00068 }; 00069 00070 template<typename Op, typename D, typename S1, typename S2, 00071 nat r, nat l, nat c> 00072 struct matrix_multiply_helper<div_type<Op,2>, D, S1, S2, r, l, c> { 00073 // Recurse on a smaller number of columns c 00074 static const nat c1= (c>>1), c2= c-c1; 00075 00076 template<nat rr, nat ll> static inline void 00077 mul (D* d, const S1* m1, const S2* m2) { 00078 { 00079 matrix_multiply_helper <Op, D, S1, S2, r, l, c1>:: 00080 template mul<rr,ll> (d, m1, m2); 00081 matrix_multiply_helper <Op, D, S1, S2, r, l, c2>:: 00082 template mul<rr,ll> (d + c1 * rr, m1, m2 + c1 * ll); 00083 } 00084 } 00085 00086 static inline void 00087 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) { 00088 { 00089 matrix_multiply_helper <Op, D, S1, S2, r, l, c1>:: 00090 mul_stride (d, m1, m2, rr, ll); 00091 matrix_multiply_helper <Op, D, S1, S2, r, l, c2>:: 00092 mul_stride (d + c1 * rr, m1, m2 + c1 * ll, rr, ll); 00093 } 00094 } 00095 }; 00096 00097 template<typename Op, typename D, typename S1, typename S2, 00098 nat r, nat l, nat c> 00099 struct matrix_multiply_helper<div_type<Op,3>, D, S1, S2, r, l, c> { 00100 // Recurse on a smaller length l of the inner loop 00101 static const nat l1= (l>>1), l2= l-l1; 00102 typedef typename Op::acc_op Acc; 00103 00104 template<nat rr, nat ll> static inline void 00105 mul (D* d, const S1* m1, const S2* m2) { 00106 { 00107 matrix_multiply_helper <Op , D, S1, S2, r, l1, c>:: 00108 template mul<rr,ll> (d, m1, m2); 00109 matrix_multiply_helper <Acc, D, S1, S2, r, l2, c>:: 00110 template mul<rr,ll> (d, m1 + l1 * rr, m2 + l1); 00111 } 00112 } 00113 00114 static inline void 00115 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) { 00116 { 00117 matrix_multiply_helper <Op , D, S1, S2, r, l1, c>:: 00118 mul_stride (d, m1, m2, rr, ll); 00119 matrix_multiply_helper <Acc, D, S1, S2, r, l2, c>:: 00120 mul_stride (d, m1 + l1 * rr, m2 + l1, rr, ll); 00121 } 00122 } 00123 }; 00124 00125 template<typename Op, typename D, typename S1, typename S2> 00126 struct matrix_multiply_helper<Op, D, S1, S2, 1, 1, 1> { 00127 template<nat rr, nat ll> static inline void 00128 mul (D* d, const S1* m1, const S2* m2) { 00129 Op::set_op (*d, *m1, *m2); 00130 } 00131 00132 static inline void 00133 mul_stride (D* d, const S1* m1, const S2* m2, nat rr, nat ll) { 00134 (void) rr; (void) ll; 00135 Op::set_op (*d, *m1, *m2); 00136 } 00137 }; 00138 00139 } // namespace mmx 00140 #endif //__MMX__MATRIX_FIXED__HPP