| @@ -138,6 +138,63 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||
| return f32_gemv_kern; | |||
| } | |||
| /* ===================== Gevm algo (F32 / Int8) ===================== */ | |||
| namespace { | |||
| void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto LDB = kern_param.LDB; | |||
| const auto Aptr = kern_param.A<dt_float32>(), | |||
| Bptr = kern_param.B<dt_float32>(); | |||
| auto Cptr = kern_param.C<dt_float32>(); | |||
| arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
| } | |||
| void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto LDB = kern_param.LDB; | |||
| const auto Aptr = kern_param.A<dt_int8>(), | |||
| Bptr = kern_param.B<dt_int8>(); | |||
| auto Cptr = kern_param.C<dt_int32>(); | |||
| arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoGevm::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| // enumerate the M, N, K, only usable when preferred | |||
| bool fp32_ok = | |||
| kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
| kern_size_param.B_type == kern_size_param.A_type && | |||
| kern_size_param.C_type == kern_size_param.A_type && | |||
| kern_size_param.A_type == dtype::Float32(); | |||
| return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) && | |||
| preferred(kern_size_param); | |||
| } | |||
| bool MatrixMulImpl::AlgoGevm::preferred( | |||
| const KernSizeParam& kern_size_param) const { | |||
| auto M = kern_size_param.M; | |||
| return kern_size_param.trB && M == 1; | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern( | |||
| const KernSizeParam& kern_size_param) const { | |||
| if (kern_size_param.A_type == dtype::Float32()) { | |||
| return gevm_fp32_kern; | |||
| } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||
| kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) { | |||
| return gevm_int8_kern; | |||
| } else { | |||
| megdnn_assert( | |||
| false, "no avaliable kern got A_type: %s B_type: %s C_type: %s", | |||
| kern_size_param.A_type.name(), kern_size_param.B_type.name(), | |||
| kern_size_param.C_type.name()); | |||
| } | |||
| } | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| /* ===================== F16 Gemv algo ===================== */ | |||
| namespace { | |||
| @@ -70,6 +70,21 @@ public: | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| }; | |||
| #endif | |||
| class MatrixMulImpl::AlgoGevm : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "ARM_COMMON_GEVM"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| }; | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| @@ -27,7 +27,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| AlgoF16Gemv f16gemv; | |||
| #endif | |||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
| AlgoGevm gevm; | |||
| public: | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&int8x8x16); | |||
| @@ -35,6 +36,7 @@ public: | |||
| all_algos.emplace_back(&f16gemv); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x32_gemv); | |||
| all_algos.emplace_back(&gevm); | |||
| } | |||
| SmallVector<AlgoBase*> all_algos; | |||
| }; | |||
| @@ -27,6 +27,7 @@ protected: | |||
| static void* const sm_arm_common_algo_type; | |||
| class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | |||
| class AlgoF32Gemv; // Arm_common F32 Gemv | |||
| class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class AlgoF16Gemv; | |||
| #endif | |||
| @@ -164,6 +164,62 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||
| run(M, K, N); | |||
| } | |||
| TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | |||
| Checker<MatrixMul> checker(handle()); | |||
| using Param = MatrixMul::Param; | |||
| checker.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | |||
| checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||
| auto run = [&](size_t M, size_t K, size_t N) { | |||
| Param param; | |||
| param.transposeA = false; | |||
| param.transposeB = true; | |||
| TensorShape A, B; | |||
| A = TensorShape{M, K}; | |||
| B = TensorShape{N, K}; | |||
| checker.set_param(param) | |||
| .set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .execs({A, B, {}}); | |||
| }; | |||
| // M = 1 | |||
| for (size_t N : {1, 10, 16, 33, 64}) | |||
| for (size_t K : {7, 512, 1024}) | |||
| for (size_t M : {1}) | |||
| run(M, K, N); | |||
| } | |||
| TEST_F(ARM_COMMON, FP32_GEVM) { | |||
| Checker<MatrixMul> checker(handle()); | |||
| using Param = MatrixMul::Param; | |||
| checker.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||
| checker.set_epsilon(1e-2); | |||
| auto run = [&](size_t M, size_t K, size_t N) { | |||
| Param param; | |||
| param.transposeA = false; | |||
| param.transposeB = true; | |||
| TensorShape A, B; | |||
| A = TensorShape{M, K}; | |||
| B = TensorShape{N, K}; | |||
| checker.set_param(param).execs({A, B, {}}); | |||
| }; | |||
| // M = 1 | |||
| for (size_t M : {1}) | |||
| for (size_t K : {1000, 4096, 25088}) | |||
| for (size_t N : {1000, 4096}) | |||
| run(M, K, N); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||