@@ -184,7 +184,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni,
                                            megdnn_x86_matmul_kern, 5,
                                            megdnn_x86_matmul_kern,
                                            "AlgoInt8x8x32Vnni"_hash,
                                            x86::matmul::gemm_int8_vnni_12x32x4,
                                            dt_int8, dt_int32, dt_uint8);
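// The "AlgoInt8x8x32Vnni"_hash literal presumably provides a stable midout
// tag for this registration, in place of the previous hard-coded integer.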
#endif
@@ -318,6 +319,8 @@ void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
    }
}  // namespace
/*************************AlgoInt8x8x16AVX2********************/
void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
@@ -389,9 +392,86 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
            .get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, 8,
        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
/*************************AlgoInt8x8x16SSE********************/
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(2)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<dt_int8>();
        const auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int16>();
        x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
                                                     c_type);
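        // GemmInterleaved runs the strategy's pack_A/pack_B/kern pipeline
        // over cache-sized blocks; the trailing argument is the cacheline
        // size assumed when laying out the packed buffers.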
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16SSE::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s16_sse_4x8x2;
}
bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
        const KernSizeParam& kern_size_param) const {
    bool is_ab_same =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
    bool is_type_ok =
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
    bool is_mode_ok =
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            is_supported(SIMDType::SSE4_1);
    bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
    return is_param_ok;
}
bool MatrixMulImpl::AlgoInt8x8x16SSE::preferred(const KernSizeParam&) const {
    return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16SSE::get_workspace(
        const KernSizeParam& kern_param) const {
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
                                                 c_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
                                            megdnn_x86_matmul_kern,
                                            "AlgoInt8x8x16SSE"_hash,
                                            x86::matmul::gemm_sse_s8s8s16_4x8x2,
                                            dt_int8, dt_int16, dt_int16);
/*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s32_avx2_4x16x2;
@@ -426,8 +506,9 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
            .get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, 8,
        x86::matmul::gemm_avx2_s8s8s32_4x16x2, dt_int8, dt_int32, dt_int16);
        AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
        "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
        dt_int8, dt_int32, dt_int16);
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
        const KernSizeParam&) const {
@@ -463,7 +544,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
            .get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
                                     megdnn_x86_matmul_kern, 8,
                                     megdnn_x86_matmul_kern,
                                     "AlgoInt8x8x32AVX2M2N4K16"_hash,
                                     x86::matmul::gemm_avx2_s8s8s32_2x4x16,
                                     dt_int8, dt_int32);
@@ -501,7 +583,8 @@ size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
            .get_workspace_size();
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
                                            megdnn_x86_matmul_kern, 9,
                                            megdnn_x86_matmul_kern,
                                            "AlgoInt8x8x32SSEM4N8K2"_hash,
                                            x86::matmul::gemm_sse_s8s8s32_4x8x2,
                                            dt_int8, dt_int32, dt_int16);
@@ -76,7 +76,6 @@ class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
private:
    static void gemm_s8s8s16_avx2_4x16x2(
            const MatrixMulImpl::KernParam& kern_param);
    static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
public:
    bool is_reproducible() const override { return true; }
@@ -89,6 +88,22 @@ public:
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x16SSE : public AlgoBase {
private:
    static void gemm_s8s8s16_sse_4x8x2(
            const MatrixMulImpl::KernParam& kern_param);
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "X86_INT8X8X16_SSE"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_x86_algo_type; }
    bool preferred(const KernSizeParam&) const override;
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
@@ -6,10 +6,17 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
 */
#include <immintrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include <cmath>
#include <cstdint>
#include <type_traits>
@@ -21,10 +28,44 @@ namespace x86 {
namespace matmul_sse_4x8x2 {
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a);
template <>
void store_overflow<int16_t>(void* ptr, __m128i a) {
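    // Narrow four int32 lanes to int16: 0x08 is _MM_SHUFFLE(0, 0, 2, 0), so
    // each 16-bit shuffle moves words {0, 2} of its half to the front, and
    // the final epi32 shuffle leaves the four truncated values in the low
    // 64 bits, written with a single 8-byte store.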
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_storel_epi64((__m128i*)ptr, a);
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a) {
    _mm_storeu_si128((__m128i*)(ptr), a);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
void store_overflow(void* ptr, __m128i a, int remain);
template <>
void store_overflow<int16_t>(void* ptr, __m128i a, int remain) {
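    // _mm_continue_mask(n) -- a project helper, assumed here to set the
    // first n bytes of the mask -- lets _mm_maskmoveu_si128 write exactly
    // `remain` elements and leave the rest of the row untouched.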
    __m128i mask = _mm_continue_mask(remain * sizeof(int16_t));
    a = _mm_shufflelo_epi16(a, 0x08);
    a = _mm_shufflehi_epi16(a, 0x08);
    a = _mm_shuffle_epi32(a, 0x08);
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
template <>
void store_overflow<int32_t>(void* ptr, __m128i a, int remain) {
    __m128i mask = _mm_continue_mask(remain * sizeof(int32_t));
    _mm_maskmoveu_si128(a, mask, reinterpret_cast<char*>(ptr));
}
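// Illustrative use: with an int16 C matrix and 3 valid columns left,
// store_overflow<int16_t>(c_row, acc, 3) truncates each 32-bit lane and
// writes exactly 6 bytes.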
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
                                               const int8_t* pack_b_ptr,
                                               int32_t* c_ptr, const int ldc,
                                               CType* c_ptr, const int ldc,
                                               const int k) {
    constexpr int k_step = 2;
@@ -102,20 +143,20 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2(const int16_t* pack_a_ptr,
        pack_a_ptr += 8;
        pack_b_ptr += 16;
    }
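    // Write back the 4x8 tile: row i of C is covered by c_vec[2 * i]
    // (columns 0-3) and c_vec[2 * i + 1] (columns 4-7), narrowed to CType
    // by store_overflow.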
    _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
    _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]);
    _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
    _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
    _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
    _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
    _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
    _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]);
    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 4, c_vec[1]);
    store_overflow<CType>(c_ptr + ldc, c_vec[2]);
    store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
    store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
    store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, const int remain_m) {
    constexpr int k_step = 2;
@@ -194,34 +235,35 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m(
        pack_b_ptr += 16;
    }
    _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
    _mm_storeu_si128((__m128i*)(c_ptr + 4), c_vec[1]);
    store_overflow<CType>(c_ptr, c_vec[0]);
    store_overflow<CType>(c_ptr + 4, c_vec[1]);
    switch (remain_m) {
        case 2:
            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            break;
        case 3:
            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
            break;
        case 4:
            _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
            _mm_storeu_si128((__m128i*)(c_ptr + ldc + 4), c_vec[3]);
            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
            _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc + 4), c_vec[5]);
            _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
            _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc + 4), c_vec[7]);
            store_overflow<CType>(c_ptr + ldc, c_vec[2]);
            store_overflow<CType>(c_ptr + ldc + 4, c_vec[3]);
            store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
            store_overflow<CType>(c_ptr + 2 * ldc + 4, c_vec[5]);
            store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
            store_overflow<CType>(c_ptr + 3 * ldc + 4, c_vec[7]);
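            // no break: execution falls through to the empty default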
        default:
            break;
    }
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_n) {
    constexpr int k_step = 2;
@@ -301,10 +343,10 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
    }
    if (remain_n >= 4) {
        _mm_storeu_si128((__m128i*)(c_ptr), c_vec[0]);
        _mm_storeu_si128((__m128i*)(c_ptr + ldc), c_vec[2]);
        _mm_storeu_si128((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
        _mm_storeu_si128((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
        store_overflow<CType>(c_ptr, c_vec[0]);
        store_overflow<CType>(c_ptr + ldc, c_vec[2]);
        store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4]);
        store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6]);
        c_ptr += 4;
        remain_n -= 4;
        c_vec[0] = c_vec[1];
@@ -312,35 +354,16 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_n(
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
    switch (remain_n) {
        case 0:
            break;
        case 1:
            *(c_ptr) = _mm_extract_epi32(c_vec[0], 0);
            *(c_ptr + ldc) = _mm_extract_epi32(c_vec[2], 0);
            *(c_ptr + 2 * ldc) = _mm_extract_epi32(c_vec[4], 0);
            *(c_ptr + 3 * ldc) = _mm_extract_epi32(c_vec[6], 0);
            break;
        case 2:
        case 3:
            _mm_storel_epi64((__m128i*)(c_ptr), c_vec[0]);
            _mm_storel_epi64((__m128i*)(c_ptr + ldc), c_vec[2]);
            _mm_storel_epi64((__m128i*)(c_ptr + 2 * ldc), c_vec[4]);
            _mm_storel_epi64((__m128i*)(c_ptr + 3 * ldc), c_vec[6]);
            break;
    }
    if (remain_n == 3) {
        *(c_ptr + 2) = _mm_extract_epi32(c_vec[0], 2);
        *(c_ptr + ldc + 2) = _mm_extract_epi32(c_vec[2], 2);
        *(c_ptr + 2 * ldc + 2) = _mm_extract_epi32(c_vec[4], 2);
        *(c_ptr + 3 * ldc + 2) = _mm_extract_epi32(c_vec[6], 2);
    }
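    // Masked stores cover the remaining (< 4) columns of each row.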
    store_overflow<CType>(c_ptr, c_vec[0], remain_n);
    store_overflow<CType>(c_ptr + ldc, c_vec[2], remain_n);
    store_overflow<CType>(c_ptr + 2 * ldc, c_vec[4], remain_n);
    store_overflow<CType>(c_ptr + 3 * ldc, c_vec[6], remain_n);
}
template <typename CType>
MEGDNN_ATTRIBUTE_TARGET("sse4.1")
static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, int32_t* c_ptr,
        const int16_t* pack_a_ptr, const int8_t* pack_b_ptr, CType* c_ptr,
        const int ldc, const int k, int remain_m, int remain_n) {
    constexpr int k_step = 2;
@@ -421,8 +444,7 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
    int index_array[4]{0, 2, 4, 6};
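    // Row m accumulates into c_vec[2 * m]; index_array maps the row index to
    // that accumulator for the remainder loops below.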
    if (remain_n >= 4) {
        for (int m = 0; m < remain_m; ++m) {
            _mm_storeu_si128((__m128i*)(c_ptr + m * ldc),
                             c_vec[index_array[m]]);
            store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]]);
        }
        c_ptr += 4;
        remain_n -= 4;
@@ -431,29 +453,8 @@ static inline void kern_gemm_s8s8s32_sse_4x8x2_remain_m_n(
        c_vec[4] = c_vec[5];
        c_vec[6] = c_vec[7];
    }
    switch (remain_n) {
        case 0:
            break;
        case 1:
            for (int m = 0; m < remain_m; ++m) {
                *(c_ptr + m * ldc) =
                        _mm_extract_epi32(c_vec[index_array[m]], 0);
            }
            break;
        case 2:
        case 3:
            for (int m = 0; m < remain_m; ++m) {
                _mm_storel_epi64((__m128i*)(c_ptr + m * ldc),
                                 c_vec[index_array[m]]);
            }
            break;
    }
    if (remain_n == 3) {
        for (int m = 0; m < remain_m; ++m) {
            *(c_ptr + m * ldc + 2) =
                    _mm_extract_epi32(c_vec[index_array[m]], 2);
        }
    for (int m = 0; m < remain_m; ++m) {
        store_overflow<CType>(c_ptr + m * ldc, c_vec[index_array[m]], remain_n);
    }
}
@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
 */
#include "src/common/utils.h"
@@ -18,11 +19,9 @@ using namespace megdnn;
using namespace x86;
using namespace x86::matmul;
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2);
void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool transpose) const {
static inline void gemm_packa(dt_int16* out, const dt_int8* in, int ldin,
                              int y0, int ymax, int k0, int kmax,
                              bool transpose) {
    if (transpose) {
        matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_at(out, in, ldin, y0,
                                                         ymax, k0, kmax);
@@ -31,10 +30,8 @@ void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                                         ymax, k0, kmax);
    }
}
void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool transpose) const {
static inline void gemm_packb(dt_int8* out, const dt_int8* in, int ldin, int x0,
                              int xmax, int k0, int kmax, bool transpose) {
    if (transpose) {
        matmul_sse_4x8x2::gemm_s8s8s32_sse_4x8x2_pack_bt(out, in, ldin, x0,
                                                         xmax, k0, kmax);
@@ -43,20 +40,11 @@ void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                                         xmax, k0, kmax);
    }
}
void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
                                  size_t k, dt_int32* c_ptr, size_t ldc,
                                  bool is_first_k, const dt_int32*,
                                  dt_int32*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    megdnn_assert(is_first_k == true);
template <typename CType>
static inline void gemm_kern(const dt_int16* pack_a_ptr,
                             const dt_int8* pack_b_ptr, size_t m, size_t n,
                             size_t k, CType* c_ptr, size_t ldc,
                             bool is_first_k) {
    constexpr int m_tile = 4;
    constexpr int n_tile = 8;
    constexpr int k_tile = 2;
@@ -99,4 +87,62 @@ void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
        }
    }
}
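// Both strategies forward to the shared templated helpers above; only the
// dtype assertions and the C write-back type differ.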
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s32_4x8x2);
void gemm_sse_s8s8s32_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_sse_s8s8s32_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_sse_s8s8s32_4x8x2::kern(const dt_int16* pack_a_ptr,
                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
                                  size_t k, dt_int32* c_ptr, size_t ldc,
                                  bool is_first_k, const dt_int32*,
                                  dt_int32*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int32) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS32)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    megdnn_assert(is_first_k == true);
    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_sse_s8s8s16_4x8x2);
void gemm_sse_s8s8s16_4x8x2::pack_A(dt_int16* out, const dt_int8* in, int ldin,
                                    int y0, int ymax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packa(out, in, ldin, y0, ymax, k0, kmax, transpose);
}
void gemm_sse_s8s8s16_4x8x2::pack_B(dt_int8* out, const dt_int8* in, int ldin,
                                    int x0, int xmax, int k0, int kmax,
                                    bool transpose) const {
    gemm_packb(out, in, ldin, x0, xmax, k0, kmax, transpose);
}
void gemm_sse_s8s8s16_4x8x2::kern(const dt_int16* pack_a_ptr,
                                  const dt_int8* pack_b_ptr, size_t m, size_t n,
                                  size_t k, dt_int16* c_ptr, size_t ldc,
                                  bool is_first_k, const dt_int32*,
                                  dt_int32*) const {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
                          ((A_dtype.enumv() == DTypeEnum::Int8 &&
                            C_dtype.enumv() == DTypeEnum::Int16) ||
                           (A_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                            C_dtype.enumv() == DTypeEnum::QuantizedS16)),
                  "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(),
                  C_dtype.name());
    megdnn_assert(is_first_k == true);
    gemm_kern(pack_a_ptr, pack_b_ptr, m, n, k, c_ptr, ldc, is_first_k);
}
// vim: syntax=cpp.doxygen
@@ -38,6 +38,10 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int32, dt_int32,
                                          4, 8, 2, false, false,
                                          gemm_sse_s8s8s32_4x8x2);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int32,
                                          4, 8, 2, false, false,
                                          gemm_sse_s8s8s16_4x8x2);
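// Macro arguments, as suggested by the pack_A/kern signatures above: input
// dtype, packed-A dtype, output dtype, compute dtype, the 4x8x2 (m, n, k)
// tile, the transpose flags, and the strategy name.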
} // namespace matmul
} // namespace x86
} // namespace megdnn
@@ -38,6 +38,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoInt8x8x32AVX2M2N4K16 algoint8x8x32avx2_m2n4k16;
    AlgoInt8x8x32SSEM4N8K2 algoint8x8x32sse_m4n8k2;
    AlgoInt8x8x16AVX2 algoint8x8x16avx2_m4n16k2;
    AlgoInt8x8x16SSE algoint8x8x16sse_m4n8k2;
    AlgoF32MK8_8x8 algof32mk8_8x8;
public:
@@ -51,6 +52,7 @@ public:
        all_algos.emplace_back(&algoint8x8x16avx2_m4n16k2);
        all_algos.emplace_back(&algoint8x8x32avx2_m2n4k16);
        all_algos.emplace_back(&algoint8x8x32sse_m4n8k2);
        all_algos.emplace_back(&algoint8x8x16sse_m4n8k2);
        all_algos.emplace_back(&algof32mk8_8x8);
#if MEGDNN_X86_WITH_MKL_DNN
        all_algos.emplace_back(&algoint8x8x32mkldnn);
@@ -56,6 +56,7 @@ protected:
    class AlgoInt8x8x32AVX2M4N16K2;
    class AlgoInt8x8x32SSEM4N8K2;
    class AlgoInt8x8x16AVX2;
    class AlgoInt8x8x16SSE;
    class AlgoPack;
    class AlgoF32MK8_8x8;
};
@@ -835,6 +835,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X) {
    }
    if (::megdnn::x86::is_supported(::megdnn::x86::SIMDType::SSE4_2)) {
        cb("IM2COLMATMUL:X86_INT8X8X32_SSE_4X8X2");
        cb2("IM2COLMATMUL:X86_INT8X8X16_SSE");
    }
#undef cb
@@ -1002,7 +1003,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
}
#endif
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
    using namespace conv_bias;
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
@@ -1028,10 +1029,16 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
        checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                          dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
                          "CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
        checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                          dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
                          "CONV1x1:X86_INT8X8X16_AVX2");
    }
    checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                      dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
                      "CONV1x1:X86_INT8X8X32_SSE_4X8X2:48");
    checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
                      dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
                      "CONV1x1:X86_INT8X8X16_SSE");
}
/************************* End Conv1x1 PackA ************************/
@@ -403,6 +403,7 @@ TEST_F(X86, BENCHMARK_CONVOLUTION_I8x8x16) {
    benchmark.set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8())
            .set_dtype(2, dtype::Int16());
    benchmark.set_before_exec_callback(AlgoChecker<Convolution>(".*"));
    benchmark.set_display(false);
    benchmark.set_times(RUN);
@@ -52,6 +52,10 @@ TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "X86_INT8X8X16_AVX2");
}
TEST_F(X86, MATRIX_MUL_SSE_8X8X16) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
                                 handle(), "X86_INT8X8X16_SSE");
}
TEST_F(X86, MATRIX_MUL_SSE_8X8X32) {
    matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                 handle(), "X86_INT8X8X32_SSE_4X8X2");
@@ -132,6 +136,17 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
    benchmarker_avx2_4x16x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_AVX2"));
    Benchmarker<MatrixMul> benchmarker_sse_4x8x2_8816(handle());
    benchmarker_sse_4x8x2_8816.set_display(false)
            .set_times(RUNS)
            .set_dtype(0, dtype::Int8{})
            .set_dtype(1, dtype::Int8{})
            .set_dtype(2, dtype::Int16{})
            .set_rng(0, rng.get())
            .set_rng(1, rng.get());
    benchmarker_sse_4x8x2_8816.set_before_exec_callback(
            AlgoChecker<MatrixMul>("X86_INT8X8X16_SSE"));
    Benchmarker<MatrixMul> benchmarker_avx2_2x4x16(handle());
    benchmarker_avx2_2x4x16.set_display(false)
            .set_times(RUNS)
@@ -212,9 +227,15 @@ TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) {
            std::cout << "sse: " << sse_used << " ms, "
                      << computations / sse_used << " Gflops, "
                      << "speed_up " << float_used / sse_used << ", ";
            auto sse_used_8816 =
                    benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) /
                    RUNS;
            std::cout << "sse_8816: " << sse_used_8816 << " ms, "
                      << computations / sse_used_8816 << " Gflops, ";
        }
        std::cout << std::endl;
    };
    run(256, 256, 256);
    for (size_t M : {8, 64, 112, 256, 512}) {
        for (size_t K : {8, 16, 32, 64, 112, 256, 512}) {