| @@ -30,6 +30,7 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, | |||
| auto extra_impl = [](const TensorNDArray& tensors, param::MatrixMul param, | |||
| Handle* handle, size_t pack_size) { | |||
| megdnn_assert((param.format == param::MatrixMul::Format::MK4 || | |||
| param.format == param::MatrixMul::Format::MK4_DOT || | |||
| param.format == param::MatrixMul::Format::MK8) && | |||
| tensors.size() == 3); | |||
| param::MatrixMul new_param = param; | |||
| @@ -41,18 +42,34 @@ void run_matmul_mk_format(Handle* handle, param::MatrixMul::Format format, | |||
| TensorLayoutArray default_layouts, mk4_layouts; | |||
| if (param.transposeA) { | |||
| default_layouts.emplace_back(tensors[0].layout.reshape({K, M})); | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({K / pack_size, M / pack_size, pack_size, | |||
| pack_size}) | |||
| .dimshuffle({0, 2, 1, 3})); | |||
| if (param.format == param::MatrixMul::Format::MK4_DOT) { | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({K / pack_size, M / pack_size, | |||
| pack_size, pack_size}) | |||
| .dimshuffle({0, 3, 1, 2})); | |||
| } else { | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({K / pack_size, M / pack_size, | |||
| pack_size, pack_size}) | |||
| .dimshuffle({0, 2, 1, 3})); | |||
| } | |||
| } else { | |||
| default_layouts.emplace_back(tensors[0].layout.reshape({M, K})); | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({M / pack_size, K / pack_size, pack_size, | |||
| pack_size}) | |||
| .dimshuffle({0, 3, 1, 2})); | |||
| if (param.format == param::MatrixMul::Format::MK4_DOT) { | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({M / pack_size, K / pack_size, | |||
| pack_size, pack_size}) | |||
| .dimshuffle({0, 2, 1, 3})); | |||
| } else { | |||
| mk4_layouts.emplace_back( | |||
| default_layouts.back() | |||
| .reshape({M / pack_size, K / pack_size, | |||
| pack_size, pack_size}) | |||
| .dimshuffle({0, 3, 1, 2})); | |||
| } | |||
| } | |||
| if (param.transposeB) { | |||
| default_layouts.emplace_back(tensors[1].layout.reshape({N, K})); | |||
| @@ -238,6 +255,11 @@ TEST_F(NAIVE, MATRIX_MUL_MK8) { | |||
| dtype::Int16(), dtype::Int16(), dtype::Int32()); | |||
| } | |||
| TEST_F(NAIVE, MATRIX_MUL_MK4_DOT) { /* Runs the shared run_matmul_mk_format helper with the MK4_DOT format on the NAIVE fixture's handle. */ | |||
| run_matmul_mk_format(handle(), param::MatrixMul::Format::MK4_DOT, | |||
| dtype::Int8(), dtype::Int8(), dtype::Int32()); /* NOTE(review): dtypes are presumably A, B, C in that order -- confirm against run_matmul_mk_format's signature. */ | |||
| } | |||
| TEST_F(NAIVE, MATRIX_MUL_BFLOAT16) { | |||
| Checker<MatrixMul> checker(handle(), /* check_dispatch */ false); | |||
| MatrixMul::Param param, fp32_param; | |||