GitOrigin-RevId: 7c6fbdfa97
tags/v0.5.0
| @@ -433,7 +433,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
| 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | ||||
| Doc('MK8', 'Split 8 from M and K, better for neon compute:' | Doc('MK8', 'Split 8 from M and K, better for neon compute:' | ||||
| '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | ||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))')) | |||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||||
| Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' | |||||
| 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' | |||||
| 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | |||||
| ) | ) | ||||
| (pdef('Winograd', 'winograd param used in convbias'). | (pdef('Winograd', 'winograd param used in convbias'). | ||||
| @@ -186,6 +186,8 @@ size_t MatrixMulForward::pack_size(const Param::Format format) { | |||||
| return 1; | return 1; | ||||
| case Param::Format::MK4: | case Param::Format::MK4: | ||||
| return 4; | return 4; | ||||
| case Param::Format::MK4_DOT: | |||||
| return 4; | |||||
| case Param::Format::MK8: | case Param::Format::MK8: | ||||
| return 8; | return 8; | ||||
| default: | default: | ||||
| @@ -82,6 +82,35 @@ void run_matrix_mul_mk4_tpl(const itype* A, const itype* B, otype* C, size_t M, | |||||
| } | } | ||||
| } | } | ||||
| template <typename itype, typename otype, bool transA, bool transB, | |||||
| typename comp_type = otype> | |||||
| void run_matrix_mul_mk4_dot_tpl(const itype* A, const itype* B, otype* C, | |||||
| size_t M, size_t N, size_t K, size_t LDA, | |||||
| size_t LDB, size_t LDC, const DType& A_type, | |||||
| const DType& B_type) { | |||||
| Getter<itype, comp_type> getterA(A_type), getterB(B_type); | |||||
| for (size_t m = 0; m < M; ++m) { | |||||
| for (size_t n = 0; n < N; ++n) { | |||||
| comp_type res[4] = {comp_type(0)}; | |||||
| for (size_t k = 0; k < K; ++k) { | |||||
| for (size_t i = 0; i < 4; i++) { | |||||
| comp_type av, bv; | |||||
| for (size_t j = 0; j < 4; j++) { | |||||
| av = transA ? getterA(A[k * LDA + m * 16 + 4 * i + j]) | |||||
| : getterA(A[m * LDA + k * 16 + 4 * i + j]), | |||||
| bv = transB ? getterB(B[n * LDB + k * 4 + j]) | |||||
| : getterB(B[k * LDB + n * 4 + j]); | |||||
| res[i] += av * bv; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < 4; i++) { | |||||
| C[m * LDC + n * 4 + i] = res[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename itype, typename otype, bool transA, bool transB, | template <typename itype, typename otype, bool transA, bool transB, | ||||
| typename comp_type = otype> | typename comp_type = otype> | ||||
| void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | void run_matrix_mul_mk8_tpl(const itype* A, const itype* B, otype* C, size_t M, | ||||
| @@ -38,22 +38,27 @@ void dispatch_ta_tb(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | auto LDA = A.layout.stride[0], LDB = B.layout.stride[0], | ||||
| LDC = C.layout.stride[0]; | LDC = C.layout.stride[0]; | ||||
| #define cb(_itype, _otype, _comp_type) \ | |||||
| if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK4) { \ | |||||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK8) { \ | |||||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| #define cb(_itype, _otype, _comp_type) \ | |||||
| if (param.format == param::MatrixMul::Format::DEFAULT) { \ | |||||
| return run_matrix_mul_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK4) { \ | |||||
| return run_matrix_mul_mk4_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK4_DOT) { \ | |||||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } else if (param.format == param::MatrixMul::Format::MK8) { \ | |||||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| A.compatible_ptr<_itype>(), B.compatible_ptr<_itype>(), \ | |||||
| C.compatible_ptr<_otype>(), M, N, K, LDA, LDB, LDC, \ | |||||
| A.layout.dtype, B.layout.dtype); \ | |||||
| } | } | ||||
| if (A.layout.dtype == dtype::Float32()) { | if (A.layout.dtype == dtype::Float32()) { | ||||