GitOrigin-RevId: b6af21e8e3
tags/v1.1.0
| @@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, | |||
| int32_t); | |||
| #endif | |||
| /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | |||
| namespace { | |||
| void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto trA = kern_param.trA, trB = kern_param.trB; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
| C_type = kern_param.C_type; | |||
| const auto Aptr = kern_param.A<dt_int8>(), | |||
| Bptr = kern_param.B<dt_int8>(); | |||
| auto Cptr = kern_param.C<dt_int16>(); | |||
| aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| megdnn::matmul::GemmInterleaved< | |||
| aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, | |||
| strategy) | |||
| .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
| kern_param.workspace_ptr); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return can_be_treated_as_int8x8x16(kern_size_param) && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| !kern_size_param.trA && !kern_size_param.trB && | |||
| kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||
| } | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred( | |||
| const KernSizeParam&) const { | |||
| return true; | |||
| } | |||
| size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) { | |||
| auto M = kern_size_param.M, N = kern_size_param.N, | |||
| K = kern_size_param.K; | |||
| auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
| auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
| C_type = kern_size_param.C_type; | |||
| aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| return megdnn::matmul::GemmInterleaved< | |||
| matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, | |||
| strategy) | |||
| .get_workspace_size(); | |||
| } | |||
| MIDOUT_END(); | |||
| return 0; | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( | |||
| const KernSizeParam&) const { | |||
| return int8x8x16_mk4_8x8x8_kern; | |||
| } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | |||
| megdnn_aarch64_matmul_kern, | |||
| "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, | |||
| aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, | |||
| int16_t); | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -202,6 +202,22 @@ public: | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
| } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| @@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { | |||
| vreinterpretq_s32_s8(input2), 3); | |||
| } | |||
| template <typename T> | |||
| static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1, | |||
| T*& outptr) { | |||
| static_assert( | |||
| std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | |||
| "transpose_8x4_1_b only support uint8_t and int8_t"); | |||
| asm volatile( | |||
| "ld1 {v0.4s}, [%[inptr0]], #16\n" | |||
| "ld1 {v1.4s}, [%[inptr1]], #16\n" | |||
| "ld1 {v2.4s}, [%[inptr0]], #16\n" | |||
| "ld1 {v3.4s}, [%[inptr1]], #16\n" | |||
| "zip1 v4.4s, v0.4s, v1.4s \n" | |||
| "zip2 v5.4s, v0.4s, v1.4s \n" | |||
| "zip1 v6.4s, v2.4s, v3.4s\n" | |||
| "zip2 v7.4s, v2.4s, v3.4s\n" | |||
| "st1 {v4.4s},[%[outptr]],#16\n" | |||
| "st1 {v5.4s},[%[outptr]],#16\n" | |||
| "st1 {v6.4s},[%[outptr]],#16\n" | |||
| "st1 {v7.4s},[%[outptr]],#16\n" | |||
| : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); | |||
| } | |||
| template <typename T> | |||
| static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, | |||
| T* outptr) { | |||
| static_assert( | |||
| std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | |||
| "transpose_8x4_1_b only support uint8_t and int8_t"); | |||
| asm volatile( | |||
| "ld4 {v0.8b-v3.8b}, [%[inptr0]], #32\n" | |||
| "ld4 {v4.8b-v7.8b}, [%[inptr1]], #32\n" | |||
| "st1 {v0.2s},[%[outptr]],#8\n" | |||
| "st1 {v1.2s},[%[outptr]],#8\n" | |||
| "st1 {v2.2s},[%[outptr]],#8\n" | |||
| "st1 {v3.2s},[%[outptr]],#8\n" | |||
| "st1 {v4.2s},[%[outptr]],#8\n" | |||
| "st1 {v5.2s},[%[outptr]],#8\n" | |||
| "st1 {v6.2s},[%[outptr]],#8\n" | |||
| "st1 {v7.2s},[%[outptr]],#8\n" | |||
| : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); | |||
| } | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -13,6 +13,7 @@ | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
| @@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
| } | |||
| } | |||
| // ===========================gemm_s8x8x16_mk4_8x8x8================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||
| ymax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| constexpr size_t pack_size = 4; | |||
| constexpr size_t pack_m = 8; | |||
| constexpr size_t pack_n = 8; | |||
| const size_t remain_n = N % pack_n; | |||
| size_t remain_m = M % pack_m; | |||
| K = round_up<size_t>(K, 8); | |||
| size_t KSIZE8 = K * pack_n; | |||
| size_t m_idx = 0; | |||
| for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_m, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += KSIZE8; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_m, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += KSIZE8; | |||
| } | |||
| packA += KSIZE8; | |||
| } | |||
| if (remain_m == 4) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, 4, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, 4, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, | |||
| 16, 12, 4, false, false, | |||
| gemm_s8x8x16_mk4_16x12_a53); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false, | |||
| gemm_s8x8x16_mk4_8x8x8); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | |||
| AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | |||
| AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; | |||
| AlgoInt8x8x16MK4_K8x8x8 int8x8x16_mk4_k8x8x8; | |||
| AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | |||
| AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | |||
| @@ -73,6 +74,7 @@ public: | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
| all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
| @@ -57,6 +57,7 @@ private: | |||
| #else | |||
| class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
| #endif | |||
| class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 | |||
| class AlgoPack; | |||
| }; | |||
| @@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { | |||
| std::move(args)); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) { | |||
| std::vector<matrix_mul::TestArg> args; | |||
| for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}) | |||
| for (size_t n : | |||
| {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24}) | |||
| for (size_t k : | |||
| {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, | |||
| 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29}) | |||
| args.emplace_back(m, n, k, 0); | |||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
| handle(), "AARCH64_INT8X8X16_MK4_K8X8X8", | |||
| param::MatrixMul::Format::MK4, 1, 1e-3, | |||
| std::move(args)); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { | |||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
| handle(), "AARCH64_INT8X8X16_MK4_4X4X8", | |||
| @@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { | |||
| run(384, 384, 384); | |||
| } | |||
| TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) { | |||
| constexpr size_t RUNS = 50; | |||
| param::MatrixMul param; | |||
| param.transposeA = false; | |||
| param.transposeB = false; | |||
| Benchmarker<MatrixMul> benchmarker(handle()); | |||
| Benchmarker<MatrixMul> benchmarker_mk4(handle()); | |||
| Benchmarker<MatrixMul> benchmarker_mk4_4x4x8(handle()); | |||
| benchmarker.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| benchmarker.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16")); | |||
| param.format = MatrixMul::Param::Format::MK4; | |||
| benchmarker_mk4.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>( | |||
| "AARCH64_INT8X8X16_MK4_K8X8X8" | |||
| )); | |||
| benchmarker_mk4.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| benchmarker_mk4_4x4x8.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8")); | |||
| benchmarker_mk4_4x4x8.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| auto run = [&](size_t M, size_t N, size_t K) { | |||
| auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
| auto mk_used = benchmarker_mk4.exec( | |||
| {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
| RUNS; | |||
| auto mk4_4x4x8_used = | |||
| benchmarker_mk4_4x4x8.exec( | |||
| {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
| RUNS; | |||
| float computations = 2.f * M * K * N * 1e-6; | |||
| printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " | |||
| "%f Gflops speedup: %f, mk4_4x4x8 %f Gflops %f ms speedup: %f\n", | |||
| M, K, N, default_used, computations / default_used, mk_used, | |||
| computations / mk_used, default_used / mk_used, | |||
| computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used); | |||
| }; | |||
| run(384, 384, 384); | |||
| run(512, 512, 512); | |||
| run(1024, 1024, 384); | |||
| run(256, 256, 384); | |||
| for(int m = 32; m <= 512;m*=2) | |||
| for(int n = 32; n <= 512;n*=2) | |||
| for(int k = 32; k < 512;k*=2){ | |||
| run(m,n,k); | |||
| } | |||
| } | |||
| TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | |||
| constexpr size_t RUNS = 50; | |||
| param::MatrixMul param; | |||