* simplify cast fp16 avx512 dispatch * define sse4.1 macro on msvc avx+tags/20230223
| @@ -160,28 +160,28 @@ macro(ncnn_add_layer class) | |||
| if(NCNN_TARGET_ARCH STREQUAL "x86") | |||
| if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) | |||
| if(NCNN_RUNTIME_CPU AND NCNN_AVX512) | |||
| ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__FMA__ /D__F16C__") | |||
| ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__") | |||
| endif() | |||
| if(NCNN_RUNTIME_CPU AND NCNN_FMA) | |||
| ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__FMA__ /D__F16C__") | |||
| ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSE4_1__ /D__FMA__ /D__F16C__") | |||
| endif() | |||
| if(NCNN_RUNTIME_CPU AND NCNN_AVX) | |||
| ncnn_add_arch_opt_layer(${class} avx "/arch:AVX") | |||
| ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSE4_1__") | |||
| endif() | |||
| if(NCNN_AVX512VNNI) | |||
| ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__FMA__ /D__F16C__ /D__AVX512VNNI__") | |||
| ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") | |||
| endif() | |||
| if(NCNN_AVXVNNI) | |||
| ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__FMA__ /D__F16C__ /D__AVXVNNI__") | |||
| ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") | |||
| endif() | |||
| if(NCNN_AVX2) | |||
| ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__FMA__ /D__F16C__") | |||
| ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__") | |||
| endif() | |||
| if(NCNN_XOP) | |||
| ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__XOP__") | |||
| ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSE4_1__ /D__XOP__") | |||
| endif() | |||
| if(NCNN_F16C) | |||
| ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__F16C__") | |||
| ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSE4_1__ /D__F16C__") | |||
| endif() | |||
| else() | |||
| if(NCNN_RUNTIME_CPU AND NCNN_AVX512) | |||
| @@ -342,7 +342,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") | |||
| if(NOT NCNN_RUNTIME_CPU AND NCNN_AVX512) | |||
| if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX512 /D__FMA__ /D__F16C__) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__) | |||
| if(NCNN_AVX512VNNI) | |||
| target_compile_options(ncnn PRIVATE /D__AVX512VNNI__) | |||
| endif() | |||
| @@ -361,9 +361,9 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") | |||
| elseif(NOT NCNN_RUNTIME_CPU AND NCNN_FMA) | |||
| if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) | |||
| if(NCNN_AVX2) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX2 /D__FMA__) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX2 /D__SSE4_1__ /D__FMA__) | |||
| else() | |||
| target_compile_options(ncnn PRIVATE /arch:AVX /D__FMA__) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__ /D__FMA__) | |||
| endif() | |||
| if(NCNN_AVXVNNI) | |||
| target_compile_options(ncnn PRIVATE /D__AVXVNNI__) | |||
| @@ -390,7 +390,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") | |||
| endif() | |||
| elseif(NOT NCNN_RUNTIME_CPU AND NCNN_AVX) | |||
| if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX) | |||
| target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__) | |||
| if(NCNN_XOP) | |||
| target_compile_options(ncnn PRIVATE /D__XOP__) | |||
| endif() | |||
| @@ -17,36 +17,10 @@ void cast_fp32_to_bf16_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, con | |||
| void cast_bf16_to_fp32_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| #endif | |||
| #if __AVX__ | |||
// Expand eight bf16 values (one per 16-bit word of v0) into eight fp32 lanes.
// Every bf16 word is zero-extended to 32 bits and then shifted into the
// upper half of its lane, leaving the lower 16 bits zero.
static inline __m256 bfloat2float_avx(__m128i v0)
{
    const __m128i _zero = _mm_set1_epi32(0);

    // words 0..3 and words 4..7 of v0, each widened to a 32-bit lane
    const __m128i _lo4 = _mm_slli_epi32(_mm_unpacklo_epi16(v0, _zero), 16);
    const __m128i _hi4 = _mm_slli_epi32(_mm_unpackhi_epi16(v0, _zero), 16);

    // assemble the 256-bit result: _lo4 in the low 128 bits, _hi4 in the high 128 bits
    __m256i _bits = _mm256_insertf128_si256(_mm256_set1_epi32(0), _lo4, 0);
    _bits = _mm256_insertf128_si256(_bits, _hi4, 1);

    return _mm256_castsi256_ps(_bits);
}
| #if __AVX2__ | |||
// Convert sixteen fp32 values (v0 followed by v1) to sixteen bf16 values by
// keeping only the upper 16 bits of each 32-bit lane.
// The result holds the eight values from v0 first, then the eight from v1.
static inline __m256i float2bfloat_avx(__m256 v0, __m256 v1)
{
    const __m256i _top0 = _mm256_srli_epi32(_mm256_castps_si256(v0), 16);
    const __m256i _top1 = _mm256_srli_epi32(_mm256_castps_si256(v1), 16);

    // the pack leaves v0/v1 data interleaved in 64-bit chunks,
    // the dword permutation below puts them back in order
    const __m256i _mixed = _mm256_packus_epi32(_top0, _top1);
    const __m256i _order = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);

    return _mm256_permutevar8x32_epi32(_mixed, _order);
}
// Convert eight fp32 values to eight bf16 values by keeping only the upper
// 16 bits of each 32-bit lane; the result is returned as one 128-bit vector.
static inline __m128i float2bfloat_avx(__m256 v0)
{
    const __m256i _top = _mm256_srli_epi32(_mm256_castps_si256(v0), 16);

    // pack the vector with itself, then gather the wanted dwords into the low 128 bits
    const __m256i _dup = _mm256_packus_epi32(_top, _top);
    const __m256i _order = _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7);
    const __m256i _sorted = _mm256_permutevar8x32_epi32(_dup, _order);

    return _mm256_castsi256_si128(_sorted);
}
| #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ | |||
| void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| #endif | |||
| #endif // __AVX__ | |||
| static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| { | |||
| @@ -58,6 +32,14 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| } | |||
| #endif | |||
| #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ | |||
| if (ncnn::cpu_support_x86_avx2()) | |||
| { | |||
| cast_fp32_to_bf16_sse_avx2(bottom_blob, top_blob, opt); | |||
| return; | |||
| } | |||
| #endif | |||
| const int w = bottom_blob.w; | |||
| const int h = bottom_blob.h; | |||
| const int d = bottom_blob.d; | |||
| @@ -73,39 +55,38 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| unsigned short* outptr = top_blob.channel(q); | |||
| int i = 0; | |||
| #if __AVX512BF16__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| __m512 _v_fp32 = _mm512_loadu_ps(ptr); | |||
| __m256bh _v_bf16 = _mm512_cvtneps_pbh(_v_fp32); | |||
| _mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_bf16); | |||
| ptr += 16; | |||
| outptr += 16; | |||
| } | |||
| for (; i + 7 < size; i += 8) | |||
| #if __SSE2__ | |||
| #if __AVX__ | |||
| #if __AVX512F__ | |||
| for (; i + 31 < size; i += 32) | |||
| { | |||
| __m256 _v_fp32 = _mm256_loadu_ps(ptr); | |||
| __m128bh _v_bf16 = _mm256_cvtneps_pbh(_v_fp32); | |||
| _mm_storeu_si128((__m128i*)outptr, (__m128i)_v_bf16); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| _mm512_storeu_si512((__m512i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr), _mm512_loadu_ps(ptr + 16))); | |||
| ptr += 32; | |||
| outptr += 32; | |||
| } | |||
| #elif __AVX2__ | |||
| #endif // __AVX512F__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| #if __AVX512F__ | |||
| _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr))); | |||
| #else | |||
| _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr + 8))); | |||
| #endif | |||
| ptr += 16; | |||
| outptr += 16; | |||
| } | |||
| #endif // __AVX__ | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| #if __AVX__ | |||
| _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr))); | |||
| #else | |||
| _mm_store_si128((__m128i*)outptr, float2bfloat_sse(_mm_loadu_ps(ptr), _mm_loadu_ps(ptr + 4))); | |||
| #endif | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| #endif | |||
| #endif // __SSE2__ | |||
| for (; i < size; i++) | |||
| { | |||
| *outptr++ = float32_to_bfloat16(*ptr++); | |||
| @@ -123,6 +104,14 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| } | |||
| #endif | |||
| #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ | |||
| if (ncnn::cpu_support_x86_avx2()) | |||
| { | |||
| cast_bf16_to_fp32_sse_avx2(bottom_blob, top_blob, opt); | |||
| return; | |||
| } | |||
| #endif | |||
| const int w = bottom_blob.w; | |||
| const int h = bottom_blob.h; | |||
| const int d = bottom_blob.d; | |||
| @@ -138,33 +127,30 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| float* outptr = top_blob.channel(q); | |||
| int i = 0; | |||
| #if __AVX512BF16__ | |||
| #if __SSE2__ | |||
| #if __AVX__ | |||
| #if __AVX512F__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| __m256bh _v_bf16 = (__m256bh)_mm256_loadu_si256((const __m256i*)ptr); | |||
| __m512 _v_fp32 = _mm512_cvtpbh_ps(_v_bf16); | |||
| _mm512_storeu_ps(outptr, _v_fp32); | |||
| _mm512_storeu_ps(outptr, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr))); | |||
| ptr += 16; | |||
| outptr += 16; | |||
| } | |||
| #endif // __AVX512F__ | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m128bh _v_bf16 = (__m128bh)_mm_loadu_si128((const __m128i*)ptr); | |||
| __m256 _v_fp32 = _mm256_cvtpbh_ps(_v_bf16); | |||
| _mm256_storeu_ps(outptr, _v_fp32); | |||
| _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr))); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| #elif __AVX__ | |||
| for (; i + 7 < size; i += 8) | |||
| #endif // __AVX__ | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_lddqu_si128((__m128i const*)ptr))); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| _mm_storeu_ps(outptr, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr))); | |||
| ptr += 4; | |||
| outptr += 4; | |||
| } | |||
| #endif | |||
| #endif // __SSE2__ | |||
| for (; i < size; i++) | |||
| { | |||
| *outptr++ = bfloat16_to_float32(*ptr++); | |||
| @@ -12,11 +12,6 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ | |||
| void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| #endif | |||
| #if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__ | |||
| void cast_fp32_to_fp16_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt); | |||
| @@ -24,14 +19,6 @@ void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Opt | |||
| static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| { | |||
| #if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ | |||
| if (ncnn::cpu_support_x86_avx512_fp16()) | |||
| { | |||
| cast_fp32_to_fp16_sse_avx512fp16(bottom_blob, top_blob, opt); | |||
| return; | |||
| } | |||
| #endif | |||
| #if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__ | |||
| if (ncnn::cpu_support_x86_f16c()) | |||
| { | |||
| @@ -55,54 +42,34 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| unsigned short* outptr = top_blob.channel(q); | |||
| int i = 0; | |||
| #if __AVX512FP16__ | |||
| #if __F16C__ | |||
| #if __AVX512F__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| __m512 _v_fp32 = _mm512_loadu_ps(ptr); | |||
| __m256h _v_fp16 = _mm512_cvtxps_ph(_v_fp32); | |||
| _mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_fp16); | |||
| __m256i _v_fp16 = _mm512_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); | |||
| _mm256_storeu_si256((__m256i*)outptr, _v_fp16); | |||
| ptr += 16; | |||
| outptr += 16; | |||
| } | |||
| #endif // __AVX512F__ | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m256 _v_fp32 = _mm256_loadu_ps(ptr); | |||
| __m128h _v_fp16 = _mm256_cvtxps_ph(_v_fp32); | |||
| _mm_storeu_si128((__m128i*)outptr, (__m128i)_v_fp16); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| __m128 _v_fp32 = _mm_loadu_ps(ptr); | |||
| __m128h _v_fp16 = _mm_cvtxps_ph(_v_fp32); | |||
| _mm_storel_epi64((__m128i*)outptr, (__m128i)_v_fp16); | |||
| ptr += 4; | |||
| outptr += 4; | |||
| } | |||
| #elif __F16C__ | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m256 _v_fp32 = _mm256_loadu_ps(ptr); | |||
| __m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC); | |||
| __m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); | |||
| _mm_storeu_si128((__m128i*)outptr, _v_fp16); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| __m128 _v_fp32 = _mm_loadu_ps(ptr); | |||
| __m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC); | |||
| __m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); | |||
| _mm_storel_epi64((__m128i*)outptr, _v_fp16); | |||
| ptr += 4; | |||
| outptr += 4; | |||
| } | |||
| #endif | |||
| #endif // __F16C__ | |||
| for (; i < size; i++) | |||
| { | |||
| *outptr++ = float32_to_float16(*ptr++); | |||
| @@ -112,14 +79,6 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| { | |||
| #if NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ | |||
| if (ncnn::cpu_support_x86_avx512_fp16()) | |||
| { | |||
| cast_fp16_to_fp32_sse_avx512fp16(bottom_blob, top_blob, opt); | |||
| return; | |||
| } | |||
| #endif | |||
| #if NCNN_F16C && __AVX__ && !__F16C__ | |||
| if (ncnn::cpu_support_x86_f16c()) | |||
| { | |||
| @@ -143,41 +102,22 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| float* outptr = top_blob.channel(q); | |||
| int i = 0; | |||
| #if __AVX512FP16__ | |||
| #if __F16C__ | |||
| #if __AVX512F__ | |||
| for (; i + 15 < size; i += 16) | |||
| { | |||
| __m256h _v_fp16 = (__m256h)_mm256_loadu_si256((const __m256i*)ptr); | |||
| __m512 _v_fp32 = _mm512_cvtxph_ps(_v_fp16); | |||
| __m256i _v_fp16 = _mm256_loadu_si256((const __m256i*)ptr); | |||
| __m512 _v_fp32 = _mm512_cvtph_ps(_v_fp16); | |||
| _mm512_storeu_ps(outptr, _v_fp32); | |||
| ptr += 16; | |||
| outptr += 16; | |||
| } | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m128h _v_fp16 = (__m128h)_mm_loadu_si128((const __m128i*)ptr); | |||
| __m256 _v_fp32 = _mm256_cvtxph_ps(_v_fp16); | |||
| _mm256_storeu_ps(outptr, _v_fp32); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| for (; i + 3 < size; i += 4) | |||
| { | |||
| __m128h _v_fp16 = (__m128h)_mm_loadl_epi64((const __m128i*)ptr); | |||
| __m128 _v_fp32 = _mm_cvtxph_ps(_v_fp16); | |||
| _mm_storeu_ps(outptr, _v_fp32); | |||
| ptr += 4; | |||
| outptr += 4; | |||
| } | |||
| #elif __F16C__ | |||
| #endif // __AVX512F__ | |||
| for (; i + 7 < size; i += 8) | |||
| { | |||
| __m128i _v_fp16 = _mm_loadu_si128((const __m128i*)ptr); | |||
| __m256 _v_fp32 = _mm256_cvtph_ps(_v_fp16); | |||
| _mm256_storeu_ps(outptr, _v_fp32); | |||
| ptr += 8; | |||
| outptr += 8; | |||
| } | |||
| @@ -186,11 +126,10 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O | |||
| __m128i _v_fp16 = _mm_loadl_epi64((const __m128i*)ptr); | |||
| __m128 _v_fp32 = _mm_cvtph_ps(_v_fp16); | |||
| _mm_storeu_ps(outptr, _v_fp32); | |||
| ptr += 4; | |||
| outptr += 4; | |||
| } | |||
| #endif | |||
| #endif // __F16C__ | |||
| for (; i < size; i++) | |||
| { | |||
| *outptr++ = float16_to_float32(*ptr++); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <immintrin.h> | |||
| #endif // __AVX__ | |||
| #endif // __SSE2__ | |||
| #include "x86_usability.h" | |||
| #include "cpu.h" | |||
| @@ -1,6 +1,6 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| @@ -14,19 +14,20 @@ | |||
| #include "cpu.h" | |||
| #include "mat.h" | |||
| #include "x86_usability.h" | |||
| namespace ncnn { | |||
| #include "cast_fp16.h" | |||
| #include "cast_bf16.h" | |||
| void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| { | |||
| cast_fp32_to_fp16_sse(bottom_blob, top_blob, opt); | |||
| cast_fp32_to_bf16_sse(bottom_blob, top_blob, opt); | |||
| } | |||
| void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt) | |||
| { | |||
| cast_fp16_to_fp32_sse(bottom_blob, top_blob, opt); | |||
| cast_bf16_to_fp32_sse(bottom_blob, top_blob, opt); | |||
| } | |||
| } // namespace ncnn | |||
| @@ -14,6 +14,7 @@ | |||
| #include "cpu.h" | |||
| #include "mat.h" | |||
| #include "x86_usability.h" | |||
| namespace ncnn { | |||
| @@ -14,6 +14,7 @@ | |||
| #include "cpu.h" | |||
| #include "mat.h" | |||
| #include "x86_usability.h" | |||
| namespace ncnn { | |||
| @@ -18,6 +18,8 @@ | |||
| #include <math.h> | |||
| #if __SSE2__ | |||
| #include <emmintrin.h> | |||
| #if __SSE4_1__ | |||
| #include <smmintrin.h> | |||
| #if __AVX__ | |||
| #include <immintrin.h> | |||
| #if __XOP__ | |||
| @@ -28,6 +30,7 @@ | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #endif // __SSE2__ | |||
| static NCNN_FORCEINLINE signed char float2int8(float v) | |||
| @@ -154,6 +157,36 @@ static NCNN_FORCEINLINE __m128i float2int8_sse(const __m128& _v0, const __m128& | |||
| return _v8; | |||
| } | |||
| static NCNN_FORCEINLINE __m128 bfloat2float_sse(const __m128i& v0) | |||
| { | |||
| __m128i _zero = _mm_setzero_si128(); | |||
| __m128i _a = _mm_unpacklo_epi16(_zero, v0); | |||
| __m128 _v = _mm_castsi128_ps(_a); | |||
| return _v; | |||
| } | |||
| static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128& v1) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m128i _v = (__m128i)_mm256_cvtneps_pbh(_mm256_insertf128_ps(_mm256_castps128_ps256(v0), v1, 1)); | |||
| #else | |||
| __m128i _a = _mm_castps_si128(v0); | |||
| __m128i _b = _mm_castps_si128(v1); | |||
| #if __SSE4_1__ | |||
| _a = _mm_srli_epi32(_a, 16); | |||
| _b = _mm_srli_epi32(_b, 16); | |||
| __m128i _v = _mm_packus_epi32(_a, _b); | |||
| #else | |||
| _a = _mm_shufflelo_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); | |||
| _b = _mm_shufflelo_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); | |||
| _a = _mm_shufflehi_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); | |||
| _b = _mm_shufflehi_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); | |||
| __m128i _v = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_a), _mm_castsi128_ps(_b), _MM_SHUFFLE(2, 0, 2, 0))); | |||
| #endif | |||
| #endif | |||
| return _v; | |||
| } | |||
| #ifndef __FMA__ | |||
| static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) | |||
| { | |||
| @@ -500,6 +533,71 @@ static NCNN_FORCEINLINE void _mm256_comp_fmadd_ps8(__m256& _sum, | |||
| _mm256_comp_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7); | |||
| } | |||
| static NCNN_FORCEINLINE __m256 bfloat2float_avx(const __m128i& v0) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m256 _v = _mm256_cvtpbh_ps((__m128bh)v0); | |||
| #else | |||
| __m128i _zero = _mm_setzero_si128(); | |||
| __m128i _a = _mm_unpacklo_epi16(_zero, v0); | |||
| __m128i _b = _mm_unpackhi_epi16(_zero, v0); | |||
| __m256 _v = _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_a), _b, 1)); | |||
| #endif | |||
| return _v; | |||
| } | |||
| static NCNN_FORCEINLINE __m128i float2bfloat_avx(const __m256& v0) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m128i _v = (__m128i)_mm256_cvtneps_pbh(v0); | |||
| #else | |||
| __m256i _ab = _mm256_castps_si256(v0); | |||
| #if __AVX2__ | |||
| _ab = _mm256_srli_epi32(_ab, 16); | |||
| __m128i _a = _mm256_extractf128_si256(_ab, 0); | |||
| __m128i _b = _mm256_extractf128_si256(_ab, 1); | |||
| #else | |||
| __m128i _a = _mm256_extractf128_si256(_ab, 0); | |||
| __m128i _b = _mm256_extractf128_si256(_ab, 1); | |||
| _a = _mm_srli_epi32(_a, 16); | |||
| _b = _mm_srli_epi32(_b, 16); | |||
| #endif | |||
| __m128i _v = _mm_packus_epi32(_a, _b); | |||
| #endif | |||
| return _v; | |||
| } | |||
| static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& v1) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m128i _v0 = (__m128i)_mm256_cvtneps_pbh(v0); | |||
| __m128i _v1 = (__m128i)_mm256_cvtneps_pbh(v1); | |||
| __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); | |||
| #else | |||
| __m256i _a = _mm256_castps_si256(v0); | |||
| __m256i _b = _mm256_castps_si256(v1); | |||
| #if __AVX2__ | |||
| _a = _mm256_srli_epi32(_a, 16); | |||
| _b = _mm256_srli_epi32(_b, 16); | |||
| __m256i _v = _mm256_packus_epi32(_a, _b); | |||
| _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); | |||
| #else | |||
| __m128i _a0 = _mm256_extractf128_si256(_a, 0); | |||
| __m128i _a1 = _mm256_extractf128_si256(_a, 1); | |||
| __m128i _b0 = _mm256_extractf128_si256(_b, 0); | |||
| __m128i _b1 = _mm256_extractf128_si256(_b, 1); | |||
| _a0 = _mm_srli_epi32(_a0, 16); | |||
| _a1 = _mm_srli_epi32(_a1, 16); | |||
| _b0 = _mm_srli_epi32(_b0, 16); | |||
| _b1 = _mm_srli_epi32(_b1, 16); | |||
| __m128i _v0 = _mm_packus_epi32(_a0, _a1); | |||
| __m128i _v1 = _mm_packus_epi32(_b0, _b1); | |||
| __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); | |||
| #endif | |||
| #endif | |||
| return _v; | |||
| } | |||
| #if __AVX512F__ | |||
| static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7, | |||
| __m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf) | |||
| @@ -886,6 +984,55 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x) | |||
| const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); | |||
| return _mm_cvtss_f32(x32); | |||
| } | |||
| static NCNN_FORCEINLINE __m512 bfloat2float_avx512(const __m256i& v0) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m512 _v = _mm512_cvtpbh_ps((__m256bh)v0); | |||
| #else | |||
| __m256i _zero = _mm256_setzero_si256(); | |||
| __m256i _a = _mm256_unpacklo_epi16(_zero, v0); | |||
| __m256i _b = _mm256_unpackhi_epi16(_zero, v0); | |||
| __m256i _c = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 2, 0, 0)); | |||
| __m256i _d = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 3, 0, 1)); | |||
| __m512 _v = _mm512_castsi512_ps(_mm512_inserti32x8(_mm512_castsi256_si512(_c), _d, 1)); | |||
| #endif | |||
| return _v; | |||
| } | |||
| static NCNN_FORCEINLINE __m256i float2bfloat_avx512(const __m512& v0) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m256i _v = (__m256i)_mm512_cvtneps_pbh(v0); | |||
| #else | |||
| __m512i _ab = _mm512_castps_si512(v0); | |||
| _ab = _mm512_srli_epi32(_ab, 16); | |||
| __m256i _a = _mm512_extracti32x8_epi32(_ab, 0); | |||
| __m256i _b = _mm512_extracti32x8_epi32(_ab, 1); | |||
| __m256i _v = _mm256_packus_epi32(_a, _b); | |||
| _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); | |||
| #endif | |||
| return _v; | |||
| } | |||
| static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m512& v1) | |||
| { | |||
| #if __AVX512BF16__ | |||
| __m256bh _v0 = _mm512_cvtneps_pbh(v0); | |||
| __m256bh _v1 = _mm512_cvtneps_pbh(v1); | |||
| __m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0), (__m256i)_v1, 1); | |||
| #else | |||
| __m512i _a = _mm512_castps_si512(v0); | |||
| __m512i _b = _mm512_castps_si512(v1); | |||
| _a = _mm512_srli_epi32(_a, 16); | |||
| _b = _mm512_srli_epi32(_b, 16); | |||
| __m512i _v = _mm512_packus_epi32(_a, _b); | |||
| _v = _mm512_permutex_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); | |||
| _v = _mm512_shuffle_i32x4(_v, _v, _MM_SHUFFLE(3, 1, 2, 0)); | |||
| #endif | |||
| return _v; | |||
| } | |||
| #endif // __AVX512F__ | |||
| #endif // __AVX__ | |||
| #endif // __SSE2__ | |||