| @@ -17,13 +17,21 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/x86/matrix_mul/common/common.h" | #include "src/x86/matrix_mul/common/common.h" | ||||
| #define DNN_AVX2_TARGET | |||||
| #if !defined(__clang__) | |||||
| //! bypass gcc bug https://bugs.launchpad.net/ubuntu/+source/gcc-5/+bug/1642109 | |||||
| #pragma GCC target("avx2") | |||||
| #else | |||||
| #undef DNN_AVX2_TARGET | |||||
| #define DNN_AVX2_TARGET MEGDNN_ATTRIBUTE_TARGET("avx2") | |||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace x86 { | namespace x86 { | ||||
| namespace matmul_avx2_4x16x2 { | namespace matmul_avx2_4x16x2 { | ||||
| template <typename CType> | template <typename CType> | ||||
| MEGDNN_ATTRIBUTE_TARGET("avx2") | |||||
| void store_overflow(void* ptr, __m256i a); | |||||
| DNN_AVX2_TARGET void store_overflow(void* ptr, __m256i a); | |||||
| template <> | template <> | ||||
| void store_overflow<int16_t>(void* ptr, __m256i a) { | void store_overflow<int16_t>(void* ptr, __m256i a) { | ||||
| @@ -33,13 +41,14 @@ void store_overflow<int16_t>(void* ptr, __m256i a) { | |||||
| a = _mm256_permutevar8x32_epi32(a, idx); | a = _mm256_permutevar8x32_epi32(a, idx); | ||||
| _mm_storeu_si128((__m128i*)ptr, _mm256_extractf128_si256(a, 0)); | _mm_storeu_si128((__m128i*)ptr, _mm256_extractf128_si256(a, 0)); | ||||
| } | } | ||||
| template <> | template <> | ||||
| void store_overflow<int32_t>(void* ptr, __m256i a) { | void store_overflow<int32_t>(void* ptr, __m256i a) { | ||||
| _mm256_storeu_si256((__m256i*)(ptr), a); | _mm256_storeu_si256((__m256i*)(ptr), a); | ||||
| } | } | ||||
| template <typename CType> | template <typename CType> | ||||
| MEGDNN_ATTRIBUTE_TARGET("avx2") | |||||
| void store_overflow(void* ptr, __m256i a, int remain); | |||||
| DNN_AVX2_TARGET void store_overflow(void* ptr, __m256i a, int remain); | |||||
| template <> | template <> | ||||
| void store_overflow<int16_t>(void* ptr, __m256i a, int remain) { | void store_overflow<int16_t>(void* ptr, __m256i a, int remain) { | ||||
| @@ -51,6 +60,7 @@ void store_overflow<int16_t>(void* ptr, __m256i a, int remain) { | |||||
| _mm_maskmoveu_si128(_mm256_extractf128_si256(a, 0), mask, | _mm_maskmoveu_si128(_mm256_extractf128_si256(a, 0), mask, | ||||
| reinterpret_cast<char*>(ptr)); | reinterpret_cast<char*>(ptr)); | ||||
| } | } | ||||
| template <> | template <> | ||||
| void store_overflow<int32_t>(void* ptr, __m256i a, int remain) { | void store_overflow<int32_t>(void* ptr, __m256i a, int remain) { | ||||
| __m256i mask = _m256_continue_mask(remain); | __m256i mask = _m256_continue_mask(remain); | ||||
| @@ -870,4 +880,9 @@ static inline void gemm_s8s8s32_avx2_4x16x2_pack_at(dt_int16* out, | |||||
| } // namespace x86 | } // namespace x86 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #if !defined(__clang__) | |||||
| #pragma GCC reset_options | |||||
| #endif | |||||
| #undef DNN_AVX2_TARGET | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||