GitOrigin-RevId: 2b98867e45
tags/v0.5.0
| @@ -14,7 +14,6 @@ | |||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
| #include "src/aarch64/matrix_mul/int16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/int8/strategy.h" | |||
| #include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||
| #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/quint8/strategy.h" | |||
| @@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, | |||
| "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | |||
| aarch64::matmul::gemm_s8_8x12, int8_t, | |||
| int32_t); | |||
| /* ===================== Int8x8x32 Gemv DotProd algo ===================== */ | |||
| namespace { | |||
| void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||
| auto Cptr = kern_param.C<dt_int32>(); | |||
| aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return can_be_treated_as_int8x8x32(kern_size_param) && | |||
| !kern_size_param.trA && !kern_size_param.trB && | |||
| kern_size_param.N == 1 && kern_size_param.LDB == 1; | |||
| } | |||
| bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred( | |||
| const KernSizeParam& kern_size_param) const { | |||
| auto N = kern_size_param.N, LDB = kern_size_param.LDB; | |||
| return (N == 1 && LDB == 1); | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern( | |||
| const KernSizeParam&) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) { | |||
| return int8x8x32_gemv_dotprod_kern; | |||
| } | |||
| MIDOUT_END(); | |||
| return nullptr; | |||
| } | |||
| /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | |||
| namespace { | |||
| @@ -104,21 +104,6 @@ public: | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X32_GEMV_DOTPROD"; | |||
| } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| @@ -174,10 +159,6 @@ public: | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x32Gemv final | |||
| : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; | |||
| #endif | |||
| class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | |||
| @@ -1,116 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||
| #include <cstddef> | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #if __ARM_FEATURE_DOTPROD | |||
| namespace { | |||
| void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
| size_t Astride, size_t Bstride, size_t Cstride) { | |||
| megdnn_assert(N == 1 && Bstride == 1); | |||
| size_t m = 0; | |||
| for (; m + 2 <= M; m += 2) { | |||
| int32_t acc[4]; | |||
| int32x4_t acc_neon = vdupq_n_s32(0); | |||
| size_t k = 0; | |||
| for (; k + 16 <= K; k += 16) { | |||
| int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||
| int64x2_t a1 = | |||
| vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||
| //! the first 8 elements is m, the last 8 elements is m + 1 | |||
| int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||
| int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||
| int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||
| int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||
| int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||
| acc_neon = vdotq_s32(acc_neon, a2, b2); | |||
| acc_neon = vdotq_s32(acc_neon, a3, b3); | |||
| } | |||
| vst1q_s32(acc, acc_neon); | |||
| for (; k + 8 <= K; k += 8) { | |||
| int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
| int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||
| int8x8_t b0 = vld1_s8(B + k); | |||
| uint32x2_t zero = vdup_n_s32(0); | |||
| acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
| zero = vdup_n_s32(0); | |||
| acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||
| } | |||
| for (; k < K; ++k) { | |||
| acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
| acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||
| } | |||
| C[m * Cstride] = acc[0] + acc[1]; | |||
| C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||
| } | |||
| for (; m < M; ++m) { | |||
| int32_t acc[4]; | |||
| int32x4_t acc_neon = vdupq_n_s32(0); | |||
| size_t k = 0; | |||
| for (; k + 16 <= K; k += 16) { | |||
| int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||
| int8x16_t b0 = vld1q_s8(B + k); | |||
| acc_neon = vdotq_s32(acc_neon, a0, b0); | |||
| } | |||
| vst1q_s32(acc, acc_neon); | |||
| for (; k + 8 <= K; k += 8) { | |||
| int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
| int8x8_t b0 = vld1_s8(B + k); | |||
| uint32x2_t zero = vdup_n_s32(0); | |||
| acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
| } | |||
| for (; k < K; ++k) { | |||
| acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
| } | |||
| C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||
| } | |||
| } | |||
| } // namespace | |||
| bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( | |||
| bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||
| size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||
| if (transposeA) | |||
| return false; | |||
| if (transposeB) | |||
| return false; | |||
| MEGDNN_MARK_USED_VAR(K); | |||
| MEGDNN_MARK_USED_VAR(M); | |||
| return (N == 1 && LDB == 1); | |||
| } | |||
| void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, | |||
| const int8_t* __restrict B, | |||
| int32_t* __restrict C, size_t M, | |||
| size_t N, size_t K, size_t Astride, | |||
| size_t Bstride, size_t Cstride) { | |||
| megdnn_assert(N == 1); | |||
| return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -1,34 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <cstddef> | |||
| #include <cstdint> | |||
| #if __ARM_FEATURE_DOTPROD | |||
namespace megdnn {
namespace aarch64 {
namespace matmul {

//! Whether the int8 GEMV path should be chosen for the given problem
//! description (transposition flags, sizes and leading dimensions).
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
                                 size_t N, size_t K, size_t LDA, size_t LDB,
                                 size_t LDC);

//! int8 x int8 -> int32 GEMV; Astride, Bstride and Cstride are the leading
//! dimensions of A, B and C.
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
                    int32_t* __restrict C, size_t M, size_t N, size_t K,
                    size_t Astride, size_t Bstride, size_t Cstride);

}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | |||
| AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; | |||
| AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; | |||
| #else | |||
| AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | |||
| AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | |||
| AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | |||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
| #endif | |||
| AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | |||
| AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | |||
| @@ -63,11 +61,9 @@ public: | |||
| all_algos.emplace_back(&f16_mk8_8x8); | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&int8x8x32_gemv_dotprod); | |||
| all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | |||
| all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | |||
| #else | |||
| all_algos.emplace_back(&int8x8x32_gemv); | |||
| all_algos.emplace_back(&int8x8x32_k4x4x16); | |||
| all_algos.emplace_back(&int8x8x32_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | |||
| @@ -34,14 +34,12 @@ private: | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct | |||
| class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| #else | |||
| class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | |||
| class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||
| class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||
| class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv | |||
| #endif | |||
| class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||
| class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||
| @@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | |||
| return exec_int_8x8x16; | |||
| } | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| /* ===================== Int8x8x32 Gemv algo ===================== */ | |||
| namespace { | |||
| void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| @@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( | |||
| const KernSizeParam&) const { | |||
| return int8x8x32_gemv_kern; | |||
| } | |||
| #endif | |||
| /* ===================== F32 Gemv algo ===================== */ | |||
| namespace { | |||
| @@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| const auto Aptr = kern_param.A<dt_float32>(), | |||
| Bptr = kern_param.B<dt_float32>(); | |||
| auto Cptr = kern_param.C<dt_float32>(); | |||
| arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
| } | |||
| } // anonymous namespace | |||
| @@ -27,11 +27,7 @@ public: | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| }; | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
| protected: | |||
| ~AlgoInt8x8x32Gemv() = default; | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | |||
| @@ -43,7 +39,6 @@ public: | |||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||
| }; | |||
| #endif | |||
| class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | |||
| protected: | |||
| @@ -9,8 +9,6 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| #include <cstddef> | |||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| @@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| namespace { | |||
| void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
| @@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
| C[m * Cstride] = acc0; | |||
| } | |||
| } | |||
| } // namespace | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| namespace { | |||
| void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
| int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
| size_t Astride, size_t Bstride, size_t Cstride) { | |||
| megdnn_assert(N == 1 && Bstride == 1); | |||
| size_t m = 0; | |||
| for (; m + 2 <= M; m += 2) { | |||
| int32_t acc[4]; | |||
| int32x4_t acc_neon = vdupq_n_s32(0); | |||
| size_t k = 0; | |||
| for (; k + 16 <= K; k += 16) { | |||
| int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||
| int64x2_t a1 = | |||
| vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||
| //! the first 8 elements is m, the last 8 elements is m + 1 | |||
| int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||
| int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||
| int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||
| int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||
| int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||
| acc_neon = vdotq_s32(acc_neon, a2, b2); | |||
| acc_neon = vdotq_s32(acc_neon, a3, b3); | |||
| } | |||
| vst1q_s32(acc, acc_neon); | |||
| for (; k + 8 <= K; k += 8) { | |||
| int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
| int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||
| int8x8_t b0 = vld1_s8(B + k); | |||
| uint32x2_t zero = vdup_n_s32(0); | |||
| acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
| zero = vdup_n_s32(0); | |||
| acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||
| } | |||
| for (; k < K; ++k) { | |||
| acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
| acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||
| } | |||
| C[m * Cstride] = acc[0] + acc[1]; | |||
| C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||
| } | |||
| for (; m < M; ++m) { | |||
| int32_t acc[4]; | |||
| int32x4_t acc_neon = vdupq_n_s32(0); | |||
| size_t k = 0; | |||
| for (; k + 16 <= K; k += 16) { | |||
| int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||
| int8x16_t b0 = vld1q_s8(B + k); | |||
| acc_neon = vdotq_s32(acc_neon, a0, b0); | |||
| } | |||
| vst1q_s32(acc, acc_neon); | |||
| for (; k + 8 <= K; k += 8) { | |||
| int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
| int8x8_t b0 = vld1_s8(B + k); | |||
| uint32x2_t zero = vdup_n_s32(0); | |||
| acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
| } | |||
| for (; k < K; ++k) { | |||
| acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
| } | |||
| C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||
| } | |||
| } | |||
| } // namespace | |||
| #endif | |||
| bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, | |||
| size_t M, size_t N, size_t K, | |||
| @@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A, | |||
| } MIDOUT_END(); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -13,7 +13,6 @@ | |||
| #include <cstddef> | |||
| #include <cstdint> | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace matmul { | |||
| @@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, | |||
| } // namespace matmul | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| AlgoF16Gemv f16gemv; | |||
| #endif | |||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
| public: | |||
| AlgoPack() { | |||
| all_algos.emplace_back(&int8x8x16); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| all_algos.emplace_back(&f16gemv); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x32_gemv); | |||
| } | |||
| SmallVector<AlgoBase*> all_algos; | |||
| }; | |||
| @@ -25,9 +25,7 @@ public: | |||
| protected: | |||
| static void* const sm_arm_common_algo_type; | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | |||
| #endif | |||
| class AlgoF32Gemv; // Arm_common F32 Gemv | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class AlgoF16Gemv; | |||
| @@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) { | |||
| __ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { | |||
| return vmovl_u32(vget_high_u32(__p0)); | |||
| } | |||
| __ai int64x2_t vzip1q_s64(int64x2_t& a, int64x2_t& b) { | |||
| return vcombine_s64(vget_low_s64(a), vget_low_s64(b)); | |||
| } | |||
| __ai int64x2_t vzip2q_s64(int64x2_t& a, int64x2_t& b) { | |||
| return vcombine_s64(vget_high_s64(a), vget_high_s64(b)); | |||
| } | |||
| __ai int32_t vaddv_s32(int32x2_t a) { | |||
| return vget_lane_s32(a, 0) + vget_lane_s32(a, 1); | |||
| } | |||
| #endif // MEGDNN_ARMV7 | |||
| //! pack vmovl_low_xx() on armv7 and armv8 | |||
| @@ -134,11 +134,6 @@ public: | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| class MatrixMulImpl::AlgoInt8x8x32Gemv final | |||
| : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; | |||
| #endif | |||
| class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| @@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; | |||
| AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; | |||
| AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
| #endif | |||
| AlgoQuint8K4x8x8 quint8_k4x8x8; | |||
| AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | |||
| AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | |||
| @@ -60,9 +57,6 @@ public: | |||
| all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||
| all_algos.emplace_back(&int8_k6x8x4); | |||
| all_algos.emplace_back(&quint8_k4x8x4); | |||
| #endif | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| all_algos.emplace_back(&int8x8x32_gemv); | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | |||
| all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
| @@ -27,9 +27,6 @@ private: | |||
| class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | |||
| class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | |||
| class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | |||
| #if !__ARM_FEATURE_DOTPROD | |||
| class AlgoInt8x8x32Gemv; // Armv7 Int8x8x32 Gemv | |||
| #endif | |||
| class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | |||
| class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | |||
| class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | |||
| @@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) { | |||
| } | |||
| #endif | |||
| TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||
| Checker<MatrixMul> checker(handle()); | |||
| using Param = MatrixMul::Param; | |||
| checker.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV")); | |||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | |||
| checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||
| auto run = [&](size_t M, size_t K, size_t N) { | |||
| Param param; | |||
| param.transposeA = false; | |||
| param.transposeB = false; | |||
| TensorShape A, B; | |||
| A = TensorShape{M, K}; | |||
| B = TensorShape{K, N}; | |||
| checker.set_param(param) | |||
| .set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .execs({A, B, {}}); | |||
| }; | |||
| // N = 1 | |||
| for (size_t M : {1, 10, 16, 33, 64}) | |||
| for (size_t K : {7, 512, 1024}) | |||
| for (size_t N : {1}) | |||
| run(M, K, N); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||