move AlgoF32GemvMK4 from arm_common to fallback
GitOrigin-RevId: 6c065abf99
tags/v1.10.0
| @@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) | |||
| return f32_gemv_kern; | |||
| } | |||
| /* ================== F32 Gemv MK4 algo ================== */ | |||
| namespace { | |||
| void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||
| auto Cptr = kern_param.C<dt_float32>(); | |||
| gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const { | |||
| // enumerate the M, N, K, only usable when preferred | |||
| auto M = kern_size_param.M; | |||
| auto N = kern_size_param.N; | |||
| auto K = kern_size_param.K; | |||
| auto LDB = kern_size_param.LDB; | |||
| return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.B_type == kern_size_param.A_type && | |||
| kern_size_param.C_type == kern_size_param.A_type && | |||
| kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
| !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||
| } | |||
| bool MatrixMulImpl::AlgoF32GemvMK4::preferred( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MEGDNN_MARK_USED_VAR(kern_size_param); | |||
| return true; | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||
| const KernSizeParam&) const { | |||
| return f32_gemv_mk4_kern; | |||
| } | |||
| /* ===================== F32 Gevm algo ===================== */ | |||
| namespace { | |||
| template <typename stype, typename dtype> | |||
| @@ -95,22 +95,6 @@ public: | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||
| }; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | |||
| public: | |||
| @@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | |||
| #endif | |||
| AlgoGevm gevm; | |||
| AlgoF32GemvMK4 f32_gemv_mk4; | |||
| SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
| fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
| @@ -42,7 +41,6 @@ public: | |||
| #endif | |||
| m_all_algos.emplace_back(&int8x8x32_gemv); | |||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
| m_all_algos.emplace_back(&f32_gemv_mk4); | |||
| m_all_algos.emplace_back(&gevm); | |||
| for (auto&& algo : m_all_algos) { | |||
| @@ -34,7 +34,6 @@ public: | |||
| protected: | |||
| class AlgoF32Gemv; // Arm_common F32 Gemv | |||
| class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | |||
| class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | |||
| class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 | |||
| class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | |||
| @@ -17,11 +17,15 @@ | |||
| #include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
| #include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | |||
| MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | |||
| MIDOUT_DECL(megdnn_fb_matmul_naive) | |||
| MIDOUT_DECL(megdnn_fb_gi_exec_fp32) | |||
| MIDOUT_DECL(megdnn_fb_gi_matmul_kern) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| @@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c | |||
| return kern_naive; | |||
| } | |||
| /* ================== F32 Gemv MK4 gi algo ================== */ | |||
| namespace { | |||
| void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||
| auto Cptr = kern_param.C<dt_float32>(); | |||
| gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoF32GiGemvMK4::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| // enumerate the M, N, K, only usable when preferred | |||
| auto M = kern_size_param.M; | |||
| auto N = kern_size_param.N; | |||
| auto K = kern_size_param.K; | |||
| auto LDB = kern_size_param.LDB; | |||
| return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.B_type == kern_size_param.A_type && | |||
| kern_size_param.C_type == kern_size_param.A_type && | |||
| kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
| !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||
| } | |||
| bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MEGDNN_MARK_USED_VAR(kern_size_param); | |||
| return true; | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern( | |||
| const KernSizeParam&) const { | |||
| return gi_f32_gemv_mk4_kern; | |||
| } | |||
| /* ================== F32 Gemm MK4 gi algo ================== */ | |||
| namespace { | |||
| void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto trA = kern_param.trA, trB = kern_param.trB; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
| C_type = kern_param.C_type; | |||
| const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||
| auto Cptr = kern_param.C<float>(); | |||
| matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||
| megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||
| M, N, K, trA, trB, strategy) | |||
| .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.B_type == kern_size_param.A_type && | |||
| kern_size_param.C_type == kern_size_param.A_type && | |||
| kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
| !kern_size_param.trB; | |||
| } | |||
| size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MIDOUT_BEGIN( | |||
| megdnn_fb_gi_matmul_kern, | |||
| midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) { | |||
| auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||
| auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
| auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
| C_type = kern_size_param.C_type; | |||
| matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||
| return megdnn::matmul::GemmInterleaved< | |||
| matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||
| M, N, K, trA, trB, strategy) | |||
| .get_workspace_size(); | |||
| } | |||
| MIDOUT_END(); | |||
| return 0; | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( | |||
| const KernSizeParam&) const { | |||
| return gi_f32_mk4_4x8_kern; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -80,6 +80,34 @@ public: | |||
| DEFAULT) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "FB_GI_F32_GEMV_MK4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "FB_GI_F32_MK4_4x8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
| MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | |||
| }; | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| @@ -16,6 +16,8 @@ namespace matmul { | |||
| namespace fallback { | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
| float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | |||
| } // namespace fallback | |||
| } // namespace matmul | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||
| #include "include/megdnn/oprs.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fp32_gi_sgemv) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| namespace { | |||
| void sgemv_gi_naive_n_mk4( | |||
| const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
| constexpr size_t PACK_SIZE = 4; | |||
| megdnn_assert( | |||
| N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0); | |||
| auto Aptr = A; | |||
| auto Cptr = C; | |||
| size_t m = 0; | |||
| while (m < M) { | |||
| auto Aptr0 = Aptr; | |||
| auto Cptr0 = Cptr; | |||
| GI_FLOAT32_t c[4]; | |||
| #define INIT(step) c[step] = GiBroadcastFloat32(0.0f); | |||
| UNROLL_CALL_RAW(4, INIT) | |||
| #undef INIT | |||
| auto Bptr = B; | |||
| size_t k = 0; | |||
| while (k < K) { | |||
| GI_FLOAT32_t b = GiLoadFloat32(Bptr); | |||
| GI_FLOAT32_V2_t a[2]; | |||
| #if defined(GI_TEST_NAIVE) | |||
| #define LOAD_A(step) \ | |||
| a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||
| a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||
| #elif defined(__arm__) || defined(__aarch64__) | |||
| #define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); | |||
| #else | |||
| #define LOAD_A(step) \ | |||
| a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||
| a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||
| #endif | |||
| UNROLL_CALL_RAW(2, LOAD_A) | |||
| #undef LOAD_A | |||
| #define COMPT(step) \ | |||
| c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4); | |||
| UNROLL_CALL_RAW(4, COMPT) | |||
| #undef COMPT | |||
| Bptr += Bstride; | |||
| Aptr0 += PACK_SIZE * PACK_SIZE; | |||
| k += PACK_SIZE; | |||
| } | |||
| #define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]); | |||
| UNROLL_CALL_RAW(2, ADD_C, 2) | |||
| UNROLL_CALL_RAW(1, ADD_C, 1) | |||
| #undef ADD_C | |||
| GiStoreFloat32(Cptr0, c[0]); | |||
| Aptr += Astride; | |||
| Cptr += Cstride; | |||
| m += PACK_SIZE; | |||
| } | |||
| } | |||
| } // namespace | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| void gi_gemv_like_mk4( | |||
| const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
| megdnn_assert(N == 1 && Bstride == 4); | |||
| MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) { | |||
| return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <cstddef> | |||
| namespace megdnn { | |||
| namespace fallback { | |||
| void gi_gemv_like_mk4( | |||
| const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,349 @@ | |||
| /** | |||
| * \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/general_intrinsic/gi_float.h" | |||
| #include "src/fallback/matrix_mul/generic_strategy.h" | |||
| using namespace megdnn; | |||
| using namespace matmul::fallback; | |||
| namespace { | |||
| void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
| LDB = LDB - 4; | |||
| K = K - 4; | |||
| GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f); | |||
| GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f); | |||
| GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f); | |||
| GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f); | |||
| GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||
| for (; K > 0; K -= 4) { | |||
| d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||
| B = B + LDB; | |||
| d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||
| } | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||
| d16d17 = GiAddFloat32(d16d17, d20d21); | |||
| d18d19 = GiAddFloat32(d18d19, d22d23); | |||
| d16d17 = GiAddFloat32(d16d17, d18d19); | |||
| GiStoreFloat32(C, d16d17); | |||
| C = C + 4; | |||
| } | |||
| void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
| LDB = (LDB - 16); | |||
| K = K - 4; | |||
| GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
| GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
| GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
| GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
| d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
| d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
| d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
| for (; K > 0; K -= 4) { | |||
| d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
| B = B + LDB; | |||
| d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
| d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
| d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
| d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
| d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
| d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||
| d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||
| d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||
| d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
| d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
| d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
| } | |||
| d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
| d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
| d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
| d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
| GiStoreFloat32(C, d16d17); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d18d19); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d20d21); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d22d23); | |||
| C = C + 4; | |||
| } | |||
| void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
| LDB -= 32; | |||
| GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
| d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
| GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
| d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
| d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
| d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
| d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
| d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
| GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
| d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
| d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
| d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
| d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||
| GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
| d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||
| d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||
| d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||
| d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||
| d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||
| GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
| d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||
| GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
| d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||
| d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||
| d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||
| d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||
| d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||
| B = B + LDB; | |||
| K = K - 4; | |||
| for (; K > 0; K -= 4) { | |||
| d8d9 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d10d11 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d12d13 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d14d15 = GiLoadFloat32(A); | |||
| A = A + 4; | |||
| d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
| d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
| d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||
| d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
| d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
| d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
| d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
| d0d1 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d2d3 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||
| d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
| d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||
| d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
| d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
| d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
| d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
| d4d5 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d6d7 = GiLoadFloat32(B); | |||
| B = B + 4; | |||
| d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0); | |||
| d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||
| d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0); | |||
| d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||
| d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||
| d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||
| d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||
| d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||
| d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0); | |||
| d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||
| d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0); | |||
| d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||
| d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||
| d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||
| d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||
| d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||
| B = B + LDB; | |||
| } | |||
| GiStoreFloat32(C, d16d17); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d18d19); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d20d21); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d22d23); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d24d25); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d26d27); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d28d29); | |||
| C = C + 4; | |||
| GiStoreFloat32(C, d30d31); | |||
| C = C + 4; | |||
| } | |||
| } // namespace | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8); | |||
| void gi_sgemm_nopack_4x8::kern( | |||
| const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||
| size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||
| constexpr size_t MB = 4; | |||
| constexpr size_t KB = 4; | |||
| constexpr size_t NB = 8; | |||
| constexpr size_t NB_HALF = 4; | |||
| megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
| for (size_t m = 0; m < M; m += MB) { | |||
| float* output = C + (m / MB) * LDC; | |||
| const float* cur_B = B; | |||
| size_t n = 0; | |||
| for (; n + NB - 1 < N; n += NB) { | |||
| kern_4x8(A, cur_B, LDB, K, output); | |||
| cur_B += KB * NB; | |||
| output += MB * NB; | |||
| } | |||
| if (N - n >= 4) { | |||
| kern_4x4(A, cur_B, LDB, K, output); | |||
| cur_B += KB * NB_HALF; | |||
| output += MB * NB_HALF; | |||
| n += 4; | |||
| } | |||
| while (n < N) { | |||
| kern_4x1(A, cur_B, LDB, K, output); | |||
| cur_B += KB; | |||
| output += MB; | |||
| n++; | |||
| } | |||
| A += LDA; | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoF32K8x12x1 f32_k8x12x1; | |||
| AlgoGemv gemv; | |||
| AlgoNaive naive; | |||
| AlgoF32GiGemvMK4 f32_gemv_mk4; | |||
| AlgoF32GiMK4_4x8 f32_mk4_4x8; | |||
| SmallVector<AlgoBase*> m_all_algos; | |||
| AlgoBase::Mapper m_all_algos_map; | |||
| @@ -44,6 +46,8 @@ public: | |||
| m_all_algos.emplace_back(&gemv); | |||
| m_all_algos.emplace_back(&f32_k8x12x1); | |||
| m_all_algos.emplace_back(&naive); | |||
| m_all_algos.emplace_back(&f32_gemv_mk4); | |||
| m_all_algos.emplace_back(&f32_mk4_4x8); | |||
| for (auto&& algo : m_all_algos) { | |||
| m_all_algos_map.emplace(algo->info().desc, algo); | |||
| } | |||
| @@ -112,6 +112,8 @@ public: | |||
| FB_F32K8x12x1 = 1 << 0, | |||
| FB_GEMV, | |||
| FB_NAIVE, | |||
| FB_GI_F32_GEMV_MK4, | |||
| FB_GI_F32_MK4_4x8, | |||
| #if MEGDNN_X86 | |||
| //! x86 | |||
| @@ -131,7 +133,6 @@ public: | |||
| ARM_COMMON_INT8X8X32_GEMV, | |||
| ARM_COMMON_INT8X8X32_GEMV_MK4, | |||
| ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | |||
| ARM_COMMON_F32_GEMV_MK4, | |||
| ARM_COMMON_F16_GEMV, | |||
| ARM_COMMON_GEVM, | |||
| #if MEGDNN_AARCH64 | |||
| @@ -236,7 +237,9 @@ public: | |||
| }; | |||
| private: | |||
| class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
| class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
| class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||
| class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||
| class AlgoGemv; | |||
| class AlgoNaive; | |||
| class AlgoPack; | |||
| @@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||
| checker.execl({AL, BL, CL}); | |||
| } | |||
| } | |||
| TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { | |||
| matrix_mul::check_matrix_mul( | |||
| dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
| "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | |||
| TaskRecordChecker<MatrixMul> checker(1); | |||
| using Param = MatrixMul::Param; | |||