GitOrigin-RevId: a049c33f2b
tags/v1.0.0-rc1
| @@ -23,6 +23,9 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/matrix_mul/gemm_impl.h" | |||
| #if MGB_ENABLE_CPUINFO | |||
| #include "cpuinfo.h" | |||
| #endif | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_aarch64_matmul_kern) | |||
| @@ -80,6 +83,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||
| } | |||
| MIDOUT_END(); | |||
| }; | |||
| return f32_kern_8x12; | |||
| } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | |||
| @@ -837,6 +841,159 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, | |||
| aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | |||
| int16_t); | |||
| /* ===================== Int8x8x16 K16x12x4 algo ===================== */ | |||
| namespace { | |||
| void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto trA = kern_param.trA, trB = kern_param.trB; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
| C_type = kern_param.C_type; | |||
| const auto Aptr = kern_param.A<dt_int8>(), | |||
| Bptr = kern_param.B<dt_int8>(); | |||
| auto Cptr = kern_param.C<dt_int16>(); | |||
| aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| megdnn::matmul::GemmInterleaved< | |||
| aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, | |||
| strategy) | |||
| .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
| kern_param.workspace_ptr); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return can_be_treated_as_int8x8x16(kern_size_param) && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| !kern_size_param.trA && !kern_size_param.trB && | |||
| kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||
| } | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred( | |||
| const KernSizeParam&) const { | |||
| #if !MGB_ENABLE_CPUINFO | |||
| return false; | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| bool little_core = arch == cpuinfo_uarch_cortex_a53 || | |||
| arch == cpuinfo_uarch_cortex_a55; | |||
| return little_core; | |||
| #endif | |||
| } | |||
| size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) { | |||
| auto M = kern_size_param.M, N = kern_size_param.N, | |||
| K = kern_size_param.K; | |||
| auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
| auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
| C_type = kern_size_param.C_type; | |||
| aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| return megdnn::matmul::GemmInterleaved< | |||
| matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, | |||
| strategy) | |||
| .get_workspace_size(); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | |||
| const KernSizeParam&) const { | |||
| return int8x8x16_mk4_16x12x4_kern; | |||
| } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||
| AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, | |||
| "AlgoInt8x8x16MK4_16x12x4Impl"_hash, | |||
| aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); | |||
| /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ | |||
| namespace { | |||
| void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) { | |||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
| auto trA = kern_param.trA, trB = kern_param.trB; | |||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
| auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
| C_type = kern_param.C_type; | |||
| const auto Aptr = kern_param.A<dt_int8>(), | |||
| Bptr = kern_param.B<dt_int8>(); | |||
| auto Cptr = kern_param.C<dt_int16>(); | |||
| aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| megdnn::matmul::GemmInterleaved< | |||
| aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, | |||
| strategy) | |||
| .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
| kern_param.workspace_ptr); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| } // anonymous namespace | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable( | |||
| const KernSizeParam& kern_size_param) const { | |||
| return can_be_treated_as_int8x8x16(kern_size_param) && | |||
| kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
| kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
| !kern_size_param.trA && !kern_size_param.trB && | |||
| kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||
| } | |||
| bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred( | |||
| const KernSizeParam&) const { | |||
| #if !MGB_ENABLE_CPUINFO | |||
| return false; | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| bool little_core = arch == cpuinfo_uarch_cortex_a53 || | |||
| arch == cpuinfo_uarch_cortex_a55; | |||
| return !little_core; | |||
| #endif | |||
| } | |||
| size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( | |||
| const KernSizeParam& kern_size_param) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||
| midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) { | |||
| auto M = kern_size_param.M, N = kern_size_param.N, | |||
| K = kern_size_param.K; | |||
| auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
| auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
| C_type = kern_size_param.C_type; | |||
| aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, | |||
| B_type, C_type); | |||
| return megdnn::matmul::GemmInterleaved< | |||
| matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, | |||
| strategy) | |||
| .get_workspace_size(); | |||
| } | |||
| MIDOUT_END(); | |||
| } | |||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( | |||
| const KernSizeParam&) const { | |||
| return int8x8x16_mk4_4x4x8_kern; | |||
| } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, | |||
| megdnn_aarch64_matmul_kern, | |||
| "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, | |||
| aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, | |||
| int8_t, int16_t); | |||
| /* ===================== Int16x16x32 K12x8x1 algo ===================== */ | |||
| namespace { | |||
| void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -121,12 +122,9 @@ public: | |||
| #else | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X32_MK4_4X4X16"; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -188,6 +186,36 @@ public: | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||
| } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| void* type() const override { return sm_arm_common_algo_type; } | |||
| PackMode packmode() const override { return PackMode::DEFAULT; } | |||
| MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
| }; | |||
| class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <cmath> | |||
| @@ -993,8 +994,8 @@ static inline void interleave_4x1_4_s(const int32_t*& inptr0, | |||
| template <typename T> | |||
| static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, | |||
| const T*& inptr2, const T*& inptr3, | |||
| T*& outptr) { | |||
| const T*& inptr2, const T*& inptr3, | |||
| T*& outptr) { | |||
| static_assert(sizeof(T) == 4, "only support size == 4"); | |||
| asm volatile( | |||
| "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" | |||
| @@ -1140,8 +1141,8 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, | |||
| "stp q2, q6, [%[outptr], #64]\n" | |||
| "stp q3, q7, [%[outptr], #96]\n" | |||
| : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), | |||
| [ outptr ] "+r"(outptr) | |||
| : | |||
| [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); | |||
| } | |||
| @@ -1153,7 +1154,7 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) { | |||
| "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | |||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | |||
| : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||
| : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "memory"); | |||
| } | |||
| @@ -1550,7 +1551,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||
| "stp q2, q6, [%[outptr], #96] \n" | |||
| "stp q10, q3, [%[outptr], #128] \n" | |||
| "stp q7, q11, [%[outptr], #160] \n" | |||
| : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||
| : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "memory"); | |||
| @@ -1564,7 +1565,7 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||
| asm volatile( | |||
| "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | |||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | |||
| : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||
| : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "memory"); | |||
| } | |||
| @@ -1681,13 +1682,12 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, | |||
| "st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" | |||
| "st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" | |||
| "st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" | |||
| : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
| [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | |||
| [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||
| [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), | |||
| [inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9), | |||
| [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), | |||
| [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||
| [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), | |||
| [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), | |||
| [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| @@ -1972,6 +1972,135 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, | |||
| : "v0", "v1", "v2", "v3", "v4", "memory"); | |||
| } | |||
| static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0, | |||
| const int8_t* inptr1, | |||
| const int8_t* inptr2, | |||
| const int8_t* inptr3, | |||
| int16_t* outptr) { | |||
| int8x16_t row0 = vld1q_s8(inptr0); | |||
| int16x8_t row0_01 = vmovl_low_s8(row0); | |||
| int16x8_t row0_23 = vmovl_high_s8(row0); | |||
| int16x4_t row0_0 = vget_low_s16(row0_01); | |||
| int16x4_t row0_1 = vget_high_s16(row0_01); | |||
| int16x4_t row0_2 = vget_low_s16(row0_23); | |||
| int16x4_t row0_3 = vget_high_s16(row0_23); | |||
| int8x16_t row1 = vld1q_s8(inptr1); | |||
| int16x8_t row1_01 = vmovl_low_s8(row1); | |||
| int16x8_t row1_23 = vmovl_high_s8(row1); | |||
| int16x4_t row1_0 = vget_low_s16(row1_01); | |||
| int16x4_t row1_1 = vget_high_s16(row1_01); | |||
| int16x4_t row1_2 = vget_low_s16(row1_23); | |||
| int16x4_t row1_3 = vget_high_s16(row1_23); | |||
| int8x16_t row2 = vld1q_s8(inptr2); | |||
| int16x8_t row2_01 = vmovl_low_s8(row2); | |||
| int16x8_t row2_23 = vmovl_high_s8(row2); | |||
| int16x4_t row2_0 = vget_low_s16(row2_01); | |||
| int16x4_t row2_1 = vget_high_s16(row2_01); | |||
| int16x4_t row2_2 = vget_low_s16(row2_23); | |||
| int16x4_t row2_3 = vget_high_s16(row2_23); | |||
| int8x16_t row3 = vld1q_s8(inptr3); | |||
| int16x8_t row3_01 = vmovl_low_s8(row3); | |||
| int16x8_t row3_23 = vmovl_high_s8(row3); | |||
| int16x4_t row3_0 = vget_low_s16(row3_01); | |||
| int16x4_t row3_1 = vget_high_s16(row3_01); | |||
| int16x4_t row3_2 = vget_low_s16(row3_23); | |||
| int16x4_t row3_3 = vget_high_s16(row3_23); | |||
| vst1_s16(outptr, row0_0); | |||
| vst1_s16(outptr + 1 * 4, row1_0); | |||
| vst1_s16(outptr + 2 * 4, row2_0); | |||
| vst1_s16(outptr + 3 * 4, row3_0); | |||
| vst1_s16(outptr + 4 * 4, row0_1); | |||
| vst1_s16(outptr + 5 * 4, row1_1); | |||
| vst1_s16(outptr + 6 * 4, row2_1); | |||
| vst1_s16(outptr + 7 * 4, row3_1); | |||
| vst1_s16(outptr + 8 * 4, row0_2); | |||
| vst1_s16(outptr + 9 * 4, row1_2); | |||
| vst1_s16(outptr + 10 * 4, row2_2); | |||
| vst1_s16(outptr + 11 * 4, row3_2); | |||
| vst1_s16(outptr + 12 * 4, row0_3); | |||
| vst1_s16(outptr + 13 * 4, row1_3); | |||
| vst1_s16(outptr + 14 * 4, row2_3); | |||
| vst1_s16(outptr + 15 * 4, row3_3); | |||
| }; | |||
| static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, | |||
| const int8_t* inptr1, | |||
| int16_t* outptr) { | |||
| int8x16_t row0 = vld1q_s8(inptr0); | |||
| int16x8_t row0_01 = vmovl_low_s8(row0); | |||
| int16x8_t row0_23 = vmovl_high_s8(row0); | |||
| int16x4_t row0_0 = vget_low_s16(row0_01); | |||
| int16x4_t row0_1 = vget_high_s16(row0_01); | |||
| int16x4_t row0_2 = vget_low_s16(row0_23); | |||
| int16x4_t row0_3 = vget_high_s16(row0_23); | |||
| int8x16_t row1 = vld1q_s8(inptr1); | |||
| int16x8_t row1_01 = vmovl_low_s8(row1); | |||
| int16x8_t row1_23 = vmovl_high_s8(row1); | |||
| int16x4_t row1_0 = vget_low_s16(row1_01); | |||
| int16x4_t row1_1 = vget_high_s16(row1_01); | |||
| int16x4_t row1_2 = vget_low_s16(row1_23); | |||
| int16x4_t row1_3 = vget_high_s16(row1_23); | |||
| vst1_s16(outptr, row0_0); | |||
| vst1_s16(outptr + 1 * 4, row1_0); | |||
| vst1_s16(outptr + 2 * 4, row0_1); | |||
| vst1_s16(outptr + 3 * 4, row1_1); | |||
| vst1_s16(outptr + 4 * 4, row0_2); | |||
| vst1_s16(outptr + 5 * 4, row1_2); | |||
| vst1_s16(outptr + 6 * 4, row0_3); | |||
| vst1_s16(outptr + 7 * 4, row1_3); | |||
| }; | |||
| static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, | |||
| int count) { | |||
| for (; count >= 32; count -= 32) { | |||
| int8x8_t in0 = vld1_s8(inptr); | |||
| int8x8_t in1 = vld1_s8(inptr + 1 * 8); | |||
| int8x8_t in2 = vld1_s8(inptr + 2 * 8); | |||
| int8x8_t in3 = vld1_s8(inptr + 3 * 8); | |||
| vst1q_s16(outptr, vmovl_s8(in0)); | |||
| vst1q_s16(outptr + 1 * 8, vmovl_s8(in1)); | |||
| vst1q_s16(outptr + 2 * 8, vmovl_s8(in2)); | |||
| vst1q_s16(outptr + 3 * 8, vmovl_s8(in3)); | |||
| inptr += 32; | |||
| outptr += 32; | |||
| } | |||
| for (; count >= 8; count -= 8) { | |||
| int8x8_t in0 = vld1_s8(inptr); | |||
| vst1q_s16(outptr, vmovl_s8(in0)); | |||
| inptr += 8; | |||
| outptr += 8; | |||
| } | |||
| for (; count > 0; --count) { | |||
| *outptr++ = (int16_t)(*inptr++); | |||
| } | |||
| } | |||
| static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { | |||
| static const uint8_t src_idx_buffer[16] = {0, 4, 8, 12, 1, 5, 9, 13, | |||
| 2, 6, 10, 14, 3, 7, 11, 15}; | |||
| static const uint8x16_t vtbl = vld1q_u8(&src_idx_buffer[0]); | |||
| int8x8x4_t input = vld4_s8(inptr0); | |||
| int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl); | |||
| vst1_s8(outptr, input.val[0]); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 8), | |||
| vreinterpretq_s32_s8(input2), 0); | |||
| vst1_s8(outptr + 1 * 12, input.val[1]); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 1 * 12 + 8), | |||
| vreinterpretq_s32_s8(input2), 1); | |||
| vst1_s8(outptr + 2 * 12, input.val[2]); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 2 * 12 + 8), | |||
| vreinterpretq_s32_s8(input2), 2); | |||
| vst1_s8(outptr + 3 * 12, input.val[3]); | |||
| vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 3 * 12 + 8), | |||
| vreinterpretq_s32_s8(input2), 3); | |||
| } | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -6,42 +6,55 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | |||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
| #include "src/common/utils.h" | |||
| #if MGB_ENABLE_CPUINFO | |||
| #include "cpuinfo.h" | |||
| #endif | |||
| using namespace megdnn; | |||
| using namespace aarch64; | |||
| using namespace aarch64::matmul; | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | |||
| void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, bool transpose_A) const { | |||
| void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
| int k0, int kmax, bool transpose_A) const { | |||
| if (transpose_A) { | |||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| } else { | |||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| } | |||
| } | |||
| void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
| int k0, int kmax, bool transpose_B) const { | |||
| if (transpose_B) { | |||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| } else { | |||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| } | |||
| } | |||
| void sgemm_4x16::kern(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, size_t LDC, | |||
| bool is_first_k, const float*, float*) const { | |||
| void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
| const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| @@ -61,15 +74,17 @@ void sgemm_4x16::kern(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K16; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| matmul_general_4x16::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -80,8 +95,8 @@ void sgemm_4x16::kern(const float* packA, const float* packB, | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | |||
| void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, bool transpose_A) const { | |||
| void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
| int k0, int kmax, bool transpose_A) const { | |||
| if (transpose_A) { | |||
| matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| @@ -102,16 +117,10 @@ void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
| } | |||
| } | |||
| void sgemm_8x12::kern(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, size_t LDC, | |||
| bool is_first_k, const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| template <typename gemm_class> | |||
| static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| constexpr size_t A_INTERLEAVE4 = 4; | |||
| constexpr size_t B_INTERLEAVE = 12; | |||
| @@ -126,16 +135,14 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | |||
| matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -146,17 +153,16 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_general_8x12::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -164,6 +170,33 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||
| } | |||
| } | |||
| void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
| const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| #if !MGB_ENABLE_CPUINFO | |||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| if (arch == cpuinfo_uarch_cortex_a53) { | |||
| sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | |||
| sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| } else { | |||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| } | |||
| #endif | |||
| } | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | |||
| void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||
| @@ -180,25 +213,17 @@ void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||
| matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, size_t LDC, | |||
| bool is_first_k, const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| template <typename gemm_name> | |||
| static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| const int K12 = K * 12; | |||
| const int K8 = K * 8; | |||
| const int K4 = K * 4; | |||
| constexpr size_t PACK_C_SIZE = 4; | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| constexpr size_t A_INTERLEAVE4 = 4; | |||
| constexpr size_t B_INTERLEAVE = 12; | |||
| const int K12 = K * 12; | |||
| const int K8 = K * 8; | |||
| const int K4 = K * 4; | |||
| size_t m = 0; | |||
| for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { | |||
| float* output = C + (m / PACK_C_SIZE * LDC); | |||
| @@ -206,15 +231,14 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | |||
| matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE * PACK_C_SIZE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| for (; n < N; n += 4) { | |||
| gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4 * PACK_C_SIZE; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -225,19 +249,45 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE * PACK_C_SIZE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4 * PACK_C_SIZE; | |||
| cur_packB += K4; | |||
| } | |||
| packA += K4; | |||
| } | |||
| } | |||
| void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, | |||
| bool is_first_k, const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| #if !MGB_ENABLE_CPUINFO | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| if (arch == cpuinfo_uarch_cortex_a53) { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| } else { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| } | |||
| #endif | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/fallback/matrix_mul/gemm_common.h" | |||
| @@ -0,0 +1,387 @@ | |||
| /** | |||
| * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <inttypes.h> | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul_mk4_4x4x8_a72 { | |||
| //! optimize for A72 | |||
| // clang-format off | |||
| /** | |||
| * Overview of register layout: | |||
| * | |||
| * A 4x4x8 cell of Lhs is stored in 8bit in q0-q3, q4-q7 | |||
| * A 4x4x8 cell of Rhs is stored in 8bit in q8-q11, q12-q15 | |||
| * A 4x4 block of accumulators is stored in 16bit in q16-q31 | |||
| * | |||
| * +------------------------+ | |||
| * | q8 | q9 | q10 | q11 | | |||
| * Rhs +------------------------+ | |||
| * Lhs | | | | | | |||
| * +--------+ - - - - +------------------------+ | |||
| * | q0 | | q16 | q20 | q24 | q28 | | |||
| * | q1 | | q17 | q21 | q25 | q29 | | |||
| * | q2 | | q18 | q22 | q26 | q30 | | |||
| * | q3 | | q19 | q23 | q27 | q31 | | |||
| * +--------+ - - - - +------------------------+ | |||
| * | |||
| * Accumulator | |||
| */ | |||
| // clang-format on | |||
| static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool, int remain_n) { | |||
| K = div_ceil(K, 8); | |||
| int oddk = (K & 1); | |||
| K = ((K + 1) / 2) - 1; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| LDC = LDC * sizeof(int8_t); | |||
| // clang-format off | |||
| #define STORE_LINE(reg0) \ | |||
| "cmp w10, #0 \n" \ | |||
| "beq 101f\n" \ | |||
| "st1 {v" reg0 ".4h}, [x0], #8\n" \ | |||
| "subs w10, w10, #1\n" | |||
| #define STORE_C \ | |||
| "mov w10, %w[remain_n]\n" \ | |||
| STORE_LINE("16") \ | |||
| STORE_LINE("20") \ | |||
| STORE_LINE("24") \ | |||
| STORE_LINE("28") | |||
| // clang-format on | |||
| register int16_t* outptr asm("x0") = output; | |||
| asm volatile( | |||
| // load accumulator C | |||
| "1:\n" | |||
| "eor v16.16b, v16.16b, v16.16b\n" | |||
| "eor v17.16b, v17.16b, v17.16b\n" | |||
| "eor v18.16b, v18.16b, v18.16b\n" | |||
| "eor v19.16b, v19.16b, v19.16b\n" | |||
| "eor v20.16b, v20.16b, v20.16b\n" | |||
| "eor v21.16b, v21.16b, v21.16b\n" | |||
| "eor v22.16b, v22.16b, v22.16b\n" | |||
| "eor v23.16b, v23.16b, v23.16b\n" | |||
| "eor v24.16b, v24.16b, v24.16b\n" | |||
| "eor v25.16b, v25.16b, v25.16b\n" | |||
| "eor v26.16b, v26.16b, v26.16b\n" | |||
| "eor v27.16b, v27.16b, v27.16b\n" | |||
| "eor v28.16b, v28.16b, v28.16b\n" | |||
| "eor v29.16b, v29.16b, v29.16b\n" | |||
| "eor v30.16b, v30.16b, v30.16b\n" | |||
| "eor v31.16b, v31.16b, v31.16b\n" | |||
| "2: \n" | |||
| "ld1 {v0.8b, v1.8b}, [%[a_ptr]], #16\n" | |||
| "ld1 {v2.8b, v3.8b}, [%[a_ptr]], #16\n" | |||
| "ld1 {v8.8b, v9.8b}, [%[b_ptr]], #16\n" | |||
| "ld1 {v10.8b, v11.8b}, [%[b_ptr]], #16\n" | |||
| "cmp %w[K], #0\n" | |||
| "beq 4f\n" | |||
| "3: \n" | |||
| //! k = 0 | |||
| "smlal v16.8h, v0.8b, v8.8b\n" | |||
| "ld1 {v4.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v17.8h, v1.8b, v8.8b\n" | |||
| "smlal v18.8h, v2.8b, v8.8b\n" | |||
| "ld1 {v5.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v19.8h, v3.8b, v8.8b\n" | |||
| "smlal v20.8h, v0.8b, v9.8b\n" | |||
| "ld1 {v6.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v21.8h, v1.8b, v9.8b\n" | |||
| "smlal v22.8h, v2.8b, v9.8b\n" | |||
| "ld1 {v7.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v23.8h, v3.8b, v9.8b\n" | |||
| "smlal v24.8h, v0.8b, v10.8b\n" | |||
| "ld1 {v12.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v25.8h, v1.8b, v10.8b\n" | |||
| "smlal v26.8h, v2.8b, v10.8b\n" | |||
| "ld1 {v13.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v27.8h, v3.8b, v10.8b\n" | |||
| "smlal v28.8h, v0.8b, v11.8b\n" | |||
| "ld1 {v14.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v29.8h, v1.8b, v11.8b\n" | |||
| "smlal v30.8h, v2.8b, v11.8b\n" | |||
| "ld1 {v15.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v31.8h, v3.8b, v11.8b\n" | |||
| //! k = 8 | |||
| "smlal v16.8h, v4.8b, v12.8b\n" | |||
| "ld1 {v0.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v17.8h, v5.8b, v12.8b\n" | |||
| "smlal v18.8h, v6.8b, v12.8b\n" | |||
| "ld1 {v1.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v19.8h, v7.8b, v12.8b\n" | |||
| "smlal v20.8h, v4.8b, v13.8b\n" | |||
| "ld1 {v2.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v21.8h, v5.8b, v13.8b\n" | |||
| "smlal v22.8h, v6.8b, v13.8b\n" | |||
| "ld1 {v3.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v23.8h, v7.8b, v13.8b\n" | |||
| "smlal v24.8h, v4.8b, v14.8b\n" | |||
| "ld1 {v8.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v25.8h, v5.8b, v14.8b\n" | |||
| "smlal v26.8h, v6.8b, v14.8b\n" | |||
| "ld1 {v9.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v27.8h, v7.8b, v14.8b\n" | |||
| "smlal v28.8h, v4.8b, v15.8b\n" | |||
| "ld1 {v10.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v29.8h, v5.8b, v15.8b\n" | |||
| "smlal v30.8h, v6.8b, v15.8b\n" | |||
| "ld1 {v11.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v31.8h, v7.8b, v15.8b\n" | |||
| "subs %w[K], %w[K], #1\n" | |||
| "bne 3b\n" | |||
| "4:\n" | |||
| "cmp %w[oddk], #1\n" | |||
| "beq 5f\n" | |||
| //! even tail | |||
| //! k = 0 | |||
| "smlal v16.8h, v0.8b, v8.8b\n" | |||
| "ld1 {v4.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v17.8h, v1.8b, v8.8b\n" | |||
| "smlal v18.8h, v2.8b, v8.8b\n" | |||
| "ld1 {v5.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v19.8h, v3.8b, v8.8b\n" | |||
| "smlal v20.8h, v0.8b, v9.8b\n" | |||
| "ld1 {v6.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v21.8h, v1.8b, v9.8b\n" | |||
| "smlal v22.8h, v2.8b, v9.8b\n" | |||
| "ld1 {v7.8b}, [%[a_ptr]], #8\n" | |||
| "smlal v23.8h, v3.8b, v9.8b\n" | |||
| "smlal v24.8h, v0.8b, v10.8b\n" | |||
| "ld1 {v12.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v25.8h, v1.8b, v10.8b\n" | |||
| "smlal v26.8h, v2.8b, v10.8b\n" | |||
| "ld1 {v13.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v27.8h, v3.8b, v10.8b\n" | |||
| "smlal v28.8h, v0.8b, v11.8b\n" | |||
| "ld1 {v14.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v29.8h, v1.8b, v11.8b\n" | |||
| "smlal v30.8h, v2.8b, v11.8b\n" | |||
| "ld1 {v15.8b}, [%[b_ptr]], #8\n" | |||
| "smlal v31.8h, v3.8b, v11.8b\n" | |||
| //! k = 8 | |||
| "smlal v16.8h, v4.8b, v12.8b\n" | |||
| "smlal v17.8h, v5.8b, v12.8b\n" | |||
| "smlal v18.8h, v6.8b, v12.8b\n" | |||
| "smlal v19.8h, v7.8b, v12.8b\n" | |||
| "smlal v20.8h, v4.8b, v13.8b\n" | |||
| "smlal v21.8h, v5.8b, v13.8b\n" | |||
| "smlal v22.8h, v6.8b, v13.8b\n" | |||
| "smlal v23.8h, v7.8b, v13.8b\n" | |||
| "smlal v24.8h, v4.8b, v14.8b\n" | |||
| "smlal v25.8h, v5.8b, v14.8b\n" | |||
| "smlal v26.8h, v6.8b, v14.8b\n" | |||
| "smlal v27.8h, v7.8b, v14.8b\n" | |||
| "smlal v28.8h, v4.8b, v15.8b\n" | |||
| "smlal v29.8h, v5.8b, v15.8b\n" | |||
| "smlal v30.8h, v6.8b, v15.8b\n" | |||
| "smlal v31.8h, v7.8b, v15.8b\n" | |||
| "b 6f\n" | |||
| "5:\n" | |||
| //! odd tail | |||
| "smlal v16.8h, v0.8b, v8.8b\n" | |||
| "smlal v17.8h, v1.8b, v8.8b\n" | |||
| "smlal v18.8h, v2.8b, v8.8b\n" | |||
| "smlal v19.8h, v3.8b, v8.8b\n" | |||
| "smlal v20.8h, v0.8b, v9.8b\n" | |||
| "smlal v21.8h, v1.8b, v9.8b\n" | |||
| "smlal v22.8h, v2.8b, v9.8b\n" | |||
| "smlal v23.8h, v3.8b, v9.8b\n" | |||
| "smlal v24.8h, v0.8b, v10.8b\n" | |||
| "smlal v25.8h, v1.8b, v10.8b\n" | |||
| "smlal v26.8h, v2.8b, v10.8b\n" | |||
| "smlal v27.8h, v3.8b, v10.8b\n" | |||
| "smlal v28.8h, v0.8b, v11.8b\n" | |||
| "smlal v29.8h, v1.8b, v11.8b\n" | |||
| "smlal v30.8h, v2.8b, v11.8b\n" | |||
| "smlal v31.8h, v3.8b, v11.8b\n" | |||
| "6:\n" | |||
| //! reduce | |||
| "addp v16.8h, v16.8h, v17.8h\n" | |||
| "addp v18.8h, v18.8h, v19.8h\n" | |||
| "addp v20.8h, v20.8h, v21.8h\n" | |||
| "addp v22.8h, v22.8h, v23.8h\n" | |||
| "addp v24.8h, v24.8h, v25.8h\n" | |||
| "addp v26.8h, v26.8h, v27.8h\n" | |||
| "addp v16.8h, v16.8h, v18.8h\n" | |||
| "addp v28.8h, v28.8h, v29.8h\n" | |||
| "addp v30.8h, v30.8h, v31.8h\n" | |||
| "addp v20.8h, v20.8h, v22.8h\n" | |||
| "addp v16.8h, v16.8h, v16.8h\n" | |||
| "addp v20.8h, v20.8h, v20.8h\n" | |||
| "addp v24.8h, v24.8h, v26.8h\n" | |||
| "addp v24.8h, v24.8h, v24.8h\n" | |||
| "addp v28.8h, v28.8h, v30.8h\n" | |||
| "addp v28.8h, v28.8h, v28.8h\n" | |||
| "cmp %w[remain_n], #4\n" | |||
| "bne 7f\n" | |||
| "st1 {v16.4h}, [x0], #8\n" | |||
| "st1 {v20.4h}, [x0], #8\n" | |||
| "st1 {v24.4h}, [x0], #8\n" | |||
| "st1 {v28.4h}, [x0], #8\n" | |||
| "b 101f\n" | |||
| "7:\n" STORE_C | |||
| "101:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [remain_n] "+r"(remain_n) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "x8", "x9", "x10", "cc", "memory"); | |||
| #undef STORE_C | |||
| #undef STORE_LINE | |||
| } | |||
| static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||
| int8x8x4_t in0 = vld4_s8(inptr); | |||
| vst1_s8(outptr + 0 * 8, in0.val[0]); | |||
| vst1_s8(outptr + 1 * 8, in0.val[1]); | |||
| vst1_s8(outptr + 2 * 8, in0.val[2]); | |||
| vst1_s8(outptr + 3 * 8, in0.val[3]); | |||
| } | |||
| static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||
| dt_int8* outptr) { | |||
| int8x16_t in0 = vld1q_s8(inptr); | |||
| int8x16_t in1 = vld1q_s8(inptr2); | |||
| int32x4x2_t in_x2 = { | |||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
| } | |||
| static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | |||
| int8x16_t in0 = vld1q_s8(inptr); | |||
| int8x16_t in1 = vdupq_n_s8(0); | |||
| int32x4x2_t in_x2 = { | |||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
| } | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||
| int ldin, int m0, int mmax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_m = 4; | |||
| constexpr int pack_k = 8; | |||
| constexpr int pack_size = 4; | |||
| const int ksize = kmax - k0; | |||
| const int remain_k = ksize % pack_k; | |||
| const int kend = kmax - remain_k; | |||
| int8_t tmpbuff[pack_m * pack_k]{0}; | |||
| for (int m_idx = m0; m_idx < mmax; m_idx += pack_m) { | |||
| const int8_t* inptr0 = in + m_idx / pack_size * ldin + k0; | |||
| for (int k_idx = k0; k_idx < kend; k_idx += pack_k) { | |||
| transpose_8x4_b(inptr0, out); | |||
| inptr0 += pack_m * pack_k; | |||
| out += pack_m * pack_k; | |||
| } | |||
| if (remain_k > 0) { | |||
| int8x16_t tmp = vld1q_s8(inptr0); | |||
| vst1q_s8(&tmpbuff[0], tmp); | |||
| transpose_8x4_b(&tmpbuff[0], out); | |||
| inptr0 += pack_m * pack_size; | |||
| out += pack_m * pack_k; | |||
| } | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int n0, int nmax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_n = 4; | |||
| constexpr int pack_k = 8; | |||
| constexpr int pack_size = 4; | |||
| const int ksize = kmax - k0; | |||
| const int packed_ksize = round_up(ksize, pack_k); | |||
| const int remain_k = ksize % pack_k; | |||
| const int kend = kmax - remain_k; | |||
| const int nsize = nmax - n0; | |||
| const int remain_n = nsize % pack_n; | |||
| const int nend = nmax - remain_n; | |||
| const int stride_input = pack_size * nsize; | |||
| int8_t tmpbuff[pack_n * pack_k]{0}; | |||
| int8_t tmpbuff2[pack_n * pack_k]{0}; | |||
| for (int k_idx = k0; k_idx < kend; k_idx += pack_k) { | |||
| const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
| const int8_t* inptr2 = inptr + stride_input; | |||
| int8_t* outptr = out + k_idx * pack_n; | |||
| for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { | |||
| interleve_8x4_b(inptr, inptr2, outptr); | |||
| inptr += pack_n * pack_size; | |||
| inptr2 += pack_n * pack_size; | |||
| outptr += pack_n * packed_ksize; | |||
| } | |||
| if (remain_n > 0) { | |||
| memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); | |||
| memcpy(&tmpbuff2[0], inptr2, remain_n * pack_size * sizeof(int8_t)); | |||
| interleve_8x4_b(&tmpbuff[0], &tmpbuff2[0], outptr); | |||
| outptr += pack_n * packed_ksize; | |||
| } | |||
| } | |||
| if (remain_k > 0) { | |||
| const int8_t* inptr = in + kend / pack_size * ldin + n0 * pack_size; | |||
| int8_t* outptr = out + kend * pack_n; | |||
| for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { | |||
| interleve_8x4_b_pad(inptr, outptr); | |||
| inptr += pack_n * pack_size; | |||
| outptr += pack_n * packed_ksize; | |||
| } | |||
| if (remain_n > 0) { | |||
| memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); | |||
| interleve_8x4_b_pad(&tmpbuff[0], outptr); | |||
| outptr += pack_n * packed_ksize; | |||
| } | |||
| } | |||
| } | |||
| } // namespace matmul_mk4_4x4x8_a72 | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,12 +6,15 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| @@ -197,4 +200,161 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
| packA += K4; | |||
| } | |||
| } | |||
| // ===========================gemm_s8x8x16_mk4_16x12================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||
| ymax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| const dt_int8* packB, size_t M, size_t N, | |||
| size_t K, dt_int16* C, size_t LDC, | |||
| bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| constexpr size_t pack_size = 4; | |||
| constexpr size_t pack_m = 16; | |||
| constexpr size_t pack_n = 12; | |||
| const size_t remain_n = N % pack_n; | |||
| size_t remain_m = M % pack_m; | |||
| size_t m_idx = 0; | |||
| for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| packA += pack_m * K; | |||
| } | |||
| if (remain_m >= 8) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| packA += 8 * K; | |||
| m_idx += 8; | |||
| remain_m -= 8; | |||
| } | |||
| if (remain_m == 4) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| } | |||
| } | |||
| // ===========================gemm_s8x8x16_mk4_4x4_a72================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||
| k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||
| k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, | |||
| const dt_int16*, dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| constexpr size_t pack_size = 4; | |||
| constexpr size_t pack_m = 4; | |||
| constexpr size_t pack_n = 4; | |||
| constexpr size_t pack_k = 8; | |||
| const size_t remain_n = N % pack_n; | |||
| const size_t nend = N - remain_n; | |||
| const size_t packed_k = round_up(K, pack_k); | |||
| for (size_t m_idx = 0; m_idx < M; m_idx += pack_m) { | |||
| int16_t* output = C + (m_idx / pack_size * LDC); | |||
| const int8_t* cur_packB = packB; | |||
| for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | |||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * packed_k; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * packed_k; | |||
| } | |||
| packA += pack_m * packed_k; | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -20,6 +21,11 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||
| gemm_s8x8x16_8x8); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, | |||
| gemm_s8x8x16_4x4); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, | |||
| gemm_s8x8x16_mk4_4x4_a72); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, | |||
| 16, 12, 4, false, false, | |||
| gemm_s8x8x16_mk4_16x12_a53); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -6,10 +6,11 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/opr_impl.h" | |||
| #include "src/aarch64/matrix_mul/algos.h" | |||
| #include "src/aarch64/matrix_mul/opr_impl.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| @@ -36,6 +37,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
| #endif | |||
| AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | |||
| AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | |||
| AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | |||
| AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; | |||
| AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | |||
| AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | |||
| @@ -70,6 +73,8 @@ public: | |||
| #endif | |||
| all_algos.emplace_back(&int8x8x16_k4x4x16); | |||
| all_algos.emplace_back(&int8x8x16_k8x8x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||
| all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||
| all_algos.emplace_back(&int16x16x32_k12x8x1); | |||
| all_algos.emplace_back(&int16x16x32_mk8_8x8); | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/arm_common/matrix_mul/opr_impl.h" | |||
| @@ -21,28 +22,30 @@ public: | |||
| SmallVector<AlgoBase*> algo_pack() override; | |||
| private: | |||
| class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||
| class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 | |||
| class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 | |||
| class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 | |||
| class AlgoF32Gemv; // Aarch64 F32 Gemv | |||
| class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||
| class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 | |||
| class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 | |||
| class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 | |||
| class AlgoF32Gemv; // Aarch64 F32 Gemv | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 | |||
| class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | |||
| // 8x12x4 DotProduct | |||
| #else | |||
| class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | |||
| class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||
| class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||
| class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||
| class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||
| #endif | |||
| class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||
| class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||
| class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||
| class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||
| class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x4 | |||
| class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 Kernel 4x4x8 | |||
| class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | |||
| class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | |||
| @@ -52,7 +55,7 @@ private: | |||
| // 8x8x4 DotProduct | |||
| class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | |||
| #else | |||
| class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
| class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||
| #endif | |||
| class AlgoPack; | |||
| @@ -214,7 +214,6 @@ void* const ConvBiasImpl::sm_arm_common_algo_type = | |||
| bool ConvBiasImpl::is_matmul_quantized_prefer( | |||
| const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
| // fallback::ConvBiasImpl::NCBKernParam conv_ncb_param; | |||
| fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( | |||
| param, 0, param::MatrixMul::Format::DEFAULT, {}, 0, | |||
| BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); | |||
| @@ -9,8 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #ifdef MGB_ENABLE_CPUINFO_CHECK | |||
| #include "src/common/utils.h" | |||
| #if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||
| #include "cpuinfo_arch_vendor.h" | |||
| @@ -11,8 +11,8 @@ | |||
| */ | |||
| #pragma once | |||
| #ifdef MGB_ENABLE_CPUINFO_CHECK | |||
| #include "src/common/utils.h" | |||
| #if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||
| #include <cpuinfo.h> | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "test/aarch64/fixture.h" | |||
| @@ -16,6 +17,7 @@ | |||
| #include "test/common/matrix_mul.h" | |||
| #include "test/common/rng.h" | |||
| #include "test/arm_common/cpuinfo_help.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| @@ -24,6 +26,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) { | |||
| dtype::Float32{}, handle(), | |||
| "AARCH64_F32K8X12X1"); | |||
| } | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A53) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||
| matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||
| dtype::Float32{}, handle(), | |||
| "AARCH64_F32K8X12X1"); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A55) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||
| matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||
| dtype::Float32{}, handle(), | |||
| "AARCH64_F32K8X12X1"); | |||
| } | |||
| #endif | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { | |||
| matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||
| @@ -36,6 +52,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { | |||
| dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
| "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A53) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||
| matrix_mul::check_matrix_mul( | |||
| dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
| "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A55) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||
| matrix_mul::check_matrix_mul( | |||
| dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
| "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| #endif | |||
| TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | |||
| matrix_mul::check_matrix_mul( | |||
| @@ -92,6 +122,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { | |||
| std::move(args)); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { | |||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
| handle(), "AARCH64_INT8X8X16_MK4_4X4X8", | |||
| param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16) { | |||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
| handle(), "AARCH64_INT8X8X16_MK4_16X12X4", | |||
| param::MatrixMul::Format::MK4, 1); | |||
| } | |||
| TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) { | |||
| matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | |||
| handle(), "AARCH64_INT8X8X32_K8X8X8"); | |||
| @@ -172,6 +214,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K4X16) { | |||
| }; | |||
| run(256, 256, 128); | |||
| run(384, 384, 384); | |||
| for (size_t k = 4; k <= 256; k *= 8) { | |||
| for (size_t m = 4; m <= 256; m *= 4) { | |||
| @@ -235,7 +278,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) { | |||
| int32_used / int_used); | |||
| }; | |||
| run(256, 256, 128); | |||
| run(256, 256, 256); | |||
| for (size_t k = 4; k <= 256; k *= 8) { | |||
| for (size_t m = 4; m <= 256; m *= 4) { | |||
| @@ -297,6 +340,62 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) { | |||
| } | |||
| } | |||
| TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { | |||
| constexpr size_t RUNS = 50; | |||
| param::MatrixMul param; | |||
| param.transposeA = false; | |||
| param.transposeB = false; | |||
| Benchmarker<MatrixMul> benchmarker(handle()); | |||
| Benchmarker<MatrixMul> benchmarker_mk4(handle()); | |||
| Benchmarker<MatrixMul> benchmarker_mk4_16x12(handle()); | |||
| benchmarker.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| benchmarker.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16")); | |||
| param.format = MatrixMul::Param::Format::MK4; | |||
| benchmarker_mk4.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8")); | |||
| benchmarker_mk4.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| benchmarker_mk4_16x12.set_before_exec_callback( | |||
| AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_16X12X4")); | |||
| benchmarker_mk4_16x12.set_times(RUNS) | |||
| .set_dtype(0, dtype::Int8{}) | |||
| .set_dtype(1, dtype::Int8{}) | |||
| .set_dtype(2, dtype::Int16{}) | |||
| .set_param(param) | |||
| .set_display(false); | |||
| auto run = [&](size_t M, size_t N, size_t K) { | |||
| auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
| auto mk_used = benchmarker_mk4.exec( | |||
| {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
| RUNS; | |||
| auto mk4_16x12_used = | |||
| benchmarker_mk4_16x12.exec( | |||
| {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
| RUNS; | |||
| float computations = 2.f * M * K * N * 1e-6; | |||
| printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " | |||
| "%f Gflops speedup: %f, mk4_16x12 %f Gflops speedup: %f\n", | |||
| M, K, N, default_used, computations / default_used, mk_used, | |||
| computations / mk_used, default_used / mk_used, | |||
| computations / mk4_16x12_used, default_used / mk4_16x12_used); | |||
| }; | |||
| run(384, 384, 384); | |||
| } | |||
| TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | |||
| constexpr size_t RUNS = 50; | |||
| param::MatrixMul param; | |||
| @@ -350,9 +449,11 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | |||
| run(256, 256, 128); | |||
| for (size_t k = 4; k <= 16; k *= 2) { | |||
| for (size_t m = 4; m <= 64; m *= 2) { | |||
| for (size_t n = 4; n <= 64; n *= 2) { | |||
| run(256, 256, 256); | |||
| for (size_t k = 4; k <= 256; k *= 4) { | |||
| for (size_t m = 4; m <= 256; m *= 4) { | |||
| for (size_t n = 4; n <= 256; n *= 4) { | |||
| run(m, n, k); | |||
| } | |||
| } | |||
| @@ -736,15 +736,21 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | |||
| } | |||
| #endif | |||
| #if MEGDNN_ARMV7 | |||
| TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { | |||
| #if MEGDNN_ARMV7 | |||
| const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; | |||
| const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; | |||
| printf("compare %s vs %s \n", default_algo, mk4_algo); | |||
| BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | |||
| dtype::Int8(), dtype::Int16()); | |||
| } | |||
| #else | |||
| const char* default_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"; | |||
| const char* mk4_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"; | |||
| printf("compare %s vs %s \n", default_algo, mk4_algo); | |||
| BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | |||
| dtype::Int8(), dtype::Int16()); | |||
| #endif | |||
| } | |||
| TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) { | |||
| BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44", | |||
| @@ -14,6 +14,8 @@ | |||
| #include "test/common/benchmarker.h" | |||
| #include "test/common/conv_bias.h" | |||
| #include "test/arm_common/cpuinfo_help.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| using namespace conv_bias; | |||
| @@ -487,11 +489,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| handle(), "S8_CHAN_WISE_STRD2_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON, | |||
| CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { | |||
| TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { | |||
| Checker<ConvBias> checker(handle()); | |||
| checker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
| "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||
| checker.set_dtype(0, dtype::Int8()); | |||
| checker.set_dtype(1, dtype::Int8()); | |||
| checker.set_dtype(2, dtype::Int16()); | |||
| @@ -505,8 +506,8 @@ TEST_F(ARM_COMMON, | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) { | |||
| Checker<ConvBias> checker(handle()); | |||
| checker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
| "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||
| checker.set_dtype(0, dtype::Int8()); | |||
| checker.set_dtype(1, dtype::Int8()); | |||
| checker.set_dtype(2, dtype::Int16()); | |||
| @@ -1803,8 +1804,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2_PREPROCESS) { | |||
| handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | |||
| dtype::Float32(), dtype::Float32(), name); | |||
| #if MEGDNN_AARCH64 | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1") | |||
| cb("IM2COLMATMUL:AARCH64_F32K4X16X1") | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1") cb("IM2COLMATMUL:AARCH64_F32K4X16X1") | |||
| #elif MEGDNN_ARMV7 | |||
| cb("IM2COLMATMUL:ARMV7_F32") | |||
| #endif | |||
| @@ -1858,6 +1858,94 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) { | |||
| #undef cb | |||
| } | |||
| //! CPUINFO ralated test | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A55) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||
| #define cb(name,stride) \ | |||
| check_conv_bias( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ | |||
| handle(), name); | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) | |||
| #undef cb | |||
| } | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A53) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||
| #define cb(name,stride) \ | |||
| check_conv_bias( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ | |||
| handle(), name); | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) | |||
| cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) | |||
| #undef cb | |||
| } | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A55) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
| {2, 3, 7}, 1, false, false, false, false, false, true, true); | |||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
| args = get_nchw44_conv_bias_args( | |||
| {2, 3, 7}, 2, false, false, false, false, false, true, true); | |||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
| } | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A53) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
| {2, 3, 7}, 1, false, false, false, false, false, true, true); | |||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
| args = get_nchw44_conv_bias_args( | |||
| {2, 3, 7}, 2, false, false, false, false, false, true, true); | |||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||
| } | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A55) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||
| check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||
| } | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 | |||
| #if MGB_ENABLE_CPUINFO | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A53) { | |||
| CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||
| check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||
| } | |||
| #endif | |||
| #endif | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| @@ -2216,7 +2304,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| @@ -2247,7 +2336,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPRO | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| @@ -2276,19 +2364,21 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
| #if MEGDNN_AARCH64 | |||
| cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | |||
| cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | |||
| cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||
| cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"); | |||
| cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_16X12X4"); | |||
| #elif MEGDNN_ARMV7 | |||
| cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||
| cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | |||
| cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | |||
| cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); | |||
| #endif | |||
| cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||
| #undef cb | |||
| #undef cb_nchw44 | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| @@ -2311,7 +2401,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| @@ -2415,8 +2506,9 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||
| } | |||
| } | |||
| void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args, | |||
| Handle* handle, const char* algo_name) { | |||
| void checker_conv_bias_int8x8x32_preprocess( | |||
| std::vector<conv_bias::TestArg> args, Handle* handle, | |||
| const char* algo_name) { | |||
| using namespace conv_bias; | |||
| Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
| @@ -2461,7 +2553,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||
| @@ -2490,7 +2583,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||
| @@ -2541,7 +2635,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| @@ -2678,7 +2771,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); | |||
| @@ -2722,7 +2816,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
| {2, 4, 7}, 1, false, false, false, false, false, true,true); | |||
| {2, 4, 7}, 1, false, false, false, false, false, true, true); | |||
| #define cb(name) \ | |||
| check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \ | |||
| dtype::Float32(), dtype::Float32(), \ | |||
| @@ -2748,7 +2842,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||
| {3}, 2, false, false, false, false, false, true, true, false); | |||
| @@ -2884,12 +2979,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) { | |||
| NormalRNG rng(1); | |||
| #if MEGDNN_AARCH64 | |||
| check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | |||
| dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||
| "CONV1x1:AARCH64_F16_K8X24X1:48"); | |||
| dtype::Float16{}, dtype::Float16{}, | |||
| dtype::Float16{}, | |||
| "CONV1x1:AARCH64_F16_K8X24X1:48"); | |||
| #elif MEGDNN_ARMV7 | |||
| check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | |||
| dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||
| "CONV1x1:AARCH32_F16_K4X16X1:24"); | |||
| dtype::Float16{}, dtype::Float16{}, | |||
| dtype::Float16{}, | |||
| "CONV1x1:AARCH32_F16_K4X16X1:24"); | |||
| #endif | |||
| } | |||
| @@ -2951,7 +3048,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) { | |||
| #undef cb | |||
| } | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| @@ -3074,7 +3170,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||
| cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | |||
| #endif | |||
| #undef cb | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
| @@ -3095,6 +3190,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
| #if MEGDNN_AARCH64 | |||
| cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | |||
| cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | |||
| cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_4X4X8:48"); | |||
| cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_16X12X4:48"); | |||
| #elif MEGDNN_ARMV7 | |||
| cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | |||
| cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | |||
| @@ -3128,11 +3225,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||
| #if MEGDNN_AARCH64 | |||
| cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | |||
| cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | |||
| cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test | |||
| cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test | |||
| #elif MEGDNN_ARMV7 | |||
| cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | |||
| cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | |||
| cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test | |||
| cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test | |||
| #endif | |||
| #undef cb | |||
| } | |||
| @@ -3245,11 +3342,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| check_conv_bias_preprocess(get_nchw44_conv_bias_args({1}, 1, true, false, false), \ | |||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
| dtype::QuantizedS8(60.25f), name); | |||
| #define cb(name) \ | |||
| check_conv_bias_preprocess( \ | |||
| get_nchw44_conv_bias_args({1}, 1, true, false, false), handle(), \ | |||
| &rng, epsilon, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), \ | |||
| dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), name); | |||
| #if MEGDNN_AARCH64 | |||
| cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); | |||
| #elif MEGDNN_ARMV7 | |||
| @@ -9,7 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #ifdef MGB_ENABLE_CPUINFO_CHECK | |||
| #include "src/common/utils.h" | |||
| #if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||
| #include <cpuinfo.h> | |||
| #include <inttypes.h> | |||
| #include "gtest/gtest.h" | |||
| @@ -18,7 +19,6 @@ namespace megdnn { | |||
| namespace test { | |||
| TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { | |||
| ASSERT_TRUE(cpuinfo_initialize()); | |||
| int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980"); | |||
| @@ -68,7 +68,6 @@ TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { | |||
| } | |||
| TEST(ARM_RUNTIME, CPUINFO_SDM8150) { | |||
| ASSERT_TRUE(cpuinfo_initialize()); | |||
| int right_soc = | |||
| @@ -119,7 +118,6 @@ TEST(ARM_RUNTIME, CPUINFO_SDM8150) { | |||
| } | |||
| TEST(ARM_RUNTIME, CPUINFO_SDM660) { | |||
| ASSERT_TRUE(cpuinfo_initialize()); | |||
| int right_soc = | |||
| @@ -173,4 +171,3 @@ TEST(ARM_RUNTIME, CPUINFO_SDM660) { | |||
| } // namespace megdnn | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/test/arm_common/cpuinfo_help.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/common/utils.h" | |||
| #include "test/arm_common/cpuinfo_help.h" | |||
| #if MGB_ENABLE_CPUINFO | |||
| std::mutex CpuInfoTmpReplace::m_cpuinfo_lock; | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * \file dnn/test/arm_common/cpuinfo_help.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <mutex> | |||
| #include <vector> | |||
| #include "src/common/utils.h" | |||
| #if MGB_ENABLE_CPUINFO | |||
| #include "cpuinfo.h" | |||
| extern const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map; | |||
| class CpuInfoTmpReplace { | |||
| public: | |||
| CpuInfoTmpReplace(enum cpuinfo_uarch arch) { | |||
| m_cpuinfo_lock.lock(); | |||
| for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { | |||
| m_arch_bak_vec.push_back(cpuinfo_linux_cpu_to_core_map[i]->uarch); | |||
| ((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i]->uarch = | |||
| arch; | |||
| } | |||
| } | |||
| ~CpuInfoTmpReplace() { | |||
| if (m_arch_bak_vec.size() > 0) { | |||
| for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { | |||
| ((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i] | |||
| ->uarch = m_arch_bak_vec[i]; | |||
| } | |||
| } | |||
| m_cpuinfo_lock.unlock(); | |||
| } | |||
| private: | |||
| static std::mutex m_cpuinfo_lock; | |||
| std::vector<cpuinfo_uarch> m_arch_bak_vec; | |||
| }; | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -9,7 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #ifdef MGB_ENABLE_CPUINFO_CHECK | |||
| #include "src/common/utils.h" | |||
| #if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||
| #include <cpuinfo.h> | |||
| #include <inttypes.h> | |||
| #include "gtest/gtest.h" | |||
| @@ -18,14 +19,12 @@ namespace megdnn { | |||
| namespace test { | |||
| TEST(X86_RUNTIME, CPUINFO_XEON6130) { | |||
| ASSERT_TRUE(cpuinfo_initialize()); | |||
| int right_cpu = | |||
| strcmp(cpuinfo_get_package(0)->name, "Intel Xeon Gold 6130"); | |||
| if (!right_cpu) { | |||
| ASSERT_TRUE(cpuinfo_get_processors()); | |||
| ASSERT_TRUE(cpuinfo_has_x86_avx2()); | |||
| @@ -44,4 +43,3 @@ TEST(X86_RUNTIME, CPUINFO_XEON6130) { | |||
| } // namespace megdnn | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||