diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index a5e25e963..b2b026401 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -160,28 +160,28 @@ macro(ncnn_add_layer class) if(NCNN_TARGET_ARCH STREQUAL "x86") if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_RUNTIME_CPU AND NCNN_AVX512) - ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_RUNTIME_CPU AND NCNN_FMA) - ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX) - ncnn_add_arch_opt_layer(${class} avx "/arch:AVX") + ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSE4_1__") endif() if(NCNN_AVX512VNNI) - ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__FMA__ /D__F16C__ /D__AVX512VNNI__") + ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() if(NCNN_AVXVNNI) - ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__FMA__ /D__F16C__ /D__AVXVNNI__") + ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() if(NCNN_AVX2) - ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_XOP) - ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__XOP__") + ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSE4_1__ /D__XOP__") endif() if(NCNN_F16C) - ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__F16C__") + ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSE4_1__ /D__F16C__") endif() else() if(NCNN_RUNTIME_CPU AND NCNN_AVX512) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 09eaacf22..61b495b81 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -342,7 +342,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") if(NOT NCNN_RUNTIME_CPU AND NCNN_AVX512) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - target_compile_options(ncnn PRIVATE /arch:AVX512 /D__FMA__ /D__F16C__) + target_compile_options(ncnn PRIVATE /arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__) if(NCNN_AVX512VNNI) target_compile_options(ncnn PRIVATE /D__AVX512VNNI__) endif() @@ -361,9 +361,9 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") elseif(NOT NCNN_RUNTIME_CPU AND NCNN_FMA) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_AVX2) - target_compile_options(ncnn PRIVATE /arch:AVX2 /D__FMA__) + target_compile_options(ncnn PRIVATE /arch:AVX2 /D__SSE4_1__ /D__FMA__) else() - target_compile_options(ncnn PRIVATE /arch:AVX /D__FMA__) + target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__ /D__FMA__) endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE /D__AVXVNNI__) @@ -390,7 +390,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") endif() elseif(NOT NCNN_RUNTIME_CPU AND NCNN_AVX) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - target_compile_options(ncnn PRIVATE /arch:AVX) + target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__) if(NCNN_XOP) target_compile_options(ncnn PRIVATE /D__XOP__) endif() diff --git a/src/layer/x86/cast_bf16.h b/src/layer/x86/cast_bf16.h index ea533054b..15939b926 100644 --- a/src/layer/x86/cast_bf16.h +++ b/src/layer/x86/cast_bf16.h @@ -17,36 +17,10 @@ void cast_fp32_to_bf16_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, con void cast_bf16_to_fp32_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); #endif -#if __AVX__ -static inline __m256 bfloat2float_avx(__m128i v0) -{ - __m128i zero = _mm_set1_epi32(0); - __m128i a = _mm_slli_epi32(_mm_unpacklo_epi16(v0, zero), 16); - __m128i b = _mm_slli_epi32(_mm_unpackhi_epi16(v0, zero), 16); - __m256i ab = _mm256_set1_epi32(0); - ab = _mm256_insertf128_si256(ab, a, 0); // insert in low 128-bit lane - ab = _mm256_insertf128_si256(ab, b, 1); // insert in high 128-bit lane - return _mm256_castsi256_ps(ab); -} -#if __AVX2__ -static inline __m256i float2bfloat_avx(__m256 v0, __m256 v1) -{ - __m256i a = _mm256_castps_si256(v0); - a = _mm256_srli_epi32(a, 16); - __m256i b = _mm256_castps_si256(v1); - b = _mm256_srli_epi32(b, 16); - __m256i abab = _mm256_packus_epi32(a, b); - return _mm256_permutevar8x32_epi32(abab, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7)); -} -static inline __m128i float2bfloat_avx(__m256 v0) -{ - __m256i a = _mm256_castps_si256(v0); - a = _mm256_srli_epi32(a, 16); - __m256i aaaa = _mm256_packus_epi32(a, a); - return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(aaaa, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7))); -} +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); +void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt); #endif -#endif // __AVX__ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { @@ -58,6 +32,14 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + cast_fp32_to_bf16_sse_avx2(bottom_blob, top_blob, opt); + return; + } +#endif + const int w = bottom_blob.w; const int h = bottom_blob.h; const int d = bottom_blob.d; @@ -73,39 +55,38 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O unsigned short* outptr = top_blob.channel(q); int i = 0; -#if __AVX512BF16__ - for (; i + 15 < size; i += 16) - { - __m512 _v_fp32 = _mm512_loadu_ps(ptr); - __m256bh _v_bf16 = _mm512_cvtneps_pbh(_v_fp32); - _mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_bf16); - - ptr += 16; - outptr += 16; - } - for (; i + 7 < size; i += 8) +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 31 < size; i += 32) { - __m256 _v_fp32 = _mm256_loadu_ps(ptr); - __m128bh _v_bf16 = _mm256_cvtneps_pbh(_v_fp32); - _mm_storeu_si128((__m128i*)outptr, (__m128i)_v_bf16); - - ptr += 8; - outptr += 8; + _mm512_storeu_si512((__m512i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr), _mm512_loadu_ps(ptr + 16))); + ptr += 32; + outptr += 32; } -#elif __AVX2__ +#endif // __AVX512F__ for (; i + 15 < size; i += 16) { +#if __AVX512F__ + _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr))); +#else _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr + 8))); +#endif ptr += 16; outptr += 16; } +#endif // __AVX__ for (; i + 7 < size; i += 8) { +#if __AVX__ _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr))); +#else + _mm_store_si128((__m128i*)outptr, float2bfloat_sse(_mm_loadu_ps(ptr), _mm_loadu_ps(ptr + 4))); +#endif ptr += 8; outptr += 8; } -#endif +#endif // __SSE2__ for (; i < size; i++) { *outptr++ = float32_to_bfloat16(*ptr++); @@ -123,6 +104,14 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O } #endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + cast_bf16_to_fp32_sse_avx2(bottom_blob, top_blob, opt); + return; + } +#endif + const int w = bottom_blob.w; const int h = bottom_blob.h; const int d = bottom_blob.d; @@ -138,33 +127,30 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O float* outptr = top_blob.channel(q); int i = 0; -#if __AVX512BF16__ +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ for (; i + 15 < size; i += 16) { - __m256bh _v_bf16 = (__m256bh)_mm256_loadu_si256((const __m256i*)ptr); - __m512 _v_fp32 = _mm512_cvtpbh_ps(_v_bf16); - _mm512_storeu_ps(outptr, _v_fp32); - + _mm512_storeu_ps(outptr, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr))); ptr += 16; outptr += 16; } +#endif // __AVX512F__ for (; i + 7 < size; i += 8) { - __m128bh _v_bf16 = (__m128bh)_mm_loadu_si128((const __m128i*)ptr); - __m256 _v_fp32 = _mm256_cvtpbh_ps(_v_bf16); - _mm256_storeu_ps(outptr, _v_fp32); - + _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr))); ptr += 8; outptr += 8; } -#elif __AVX__ - for (; i + 7 < size; i += 8) +#endif // __AVX__ + for (; i + 3 < size; i += 4) { - _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_lddqu_si128((__m128i const*)ptr))); - ptr += 8; - outptr += 8; + _mm_storeu_ps(outptr, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr))); + ptr += 4; + outptr += 4; } -#endif +#endif // __SSE2__ for (; i < size; i++) { *outptr++ = bfloat16_to_float32(*ptr++); diff --git a/src/layer/x86/cast_fp16.h b/src/layer/x86/cast_fp16.h index 8fba9748c..fc1902654 100644 --- a/src/layer/x86/cast_fp16.h +++ b/src/layer/x86/cast_fp16.h @@ -12,11 +12,6 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -#if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ -void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); -void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt); -#endif - #if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__ void cast_fp32_to_fp16_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt); void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt); @@ -24,14 +19,6 @@ void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Opt static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { -#if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ - if (ncnn::cpu_support_x86_avx512_fp16()) - { - cast_fp32_to_fp16_sse_avx512fp16(bottom_blob, top_blob, opt); - return; - } -#endif - #if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__ if (ncnn::cpu_support_x86_f16c()) { @@ -55,54 +42,34 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O unsigned short* outptr = top_blob.channel(q); int i = 0; -#if __AVX512FP16__ +#if __F16C__ +#if __AVX512F__ for (; i + 15 < size; i += 16) { __m512 _v_fp32 = _mm512_loadu_ps(ptr); - __m256h _v_fp16 = _mm512_cvtxps_ph(_v_fp32); - _mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_fp16); - + __m256i _v_fp16 = _mm512_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); + _mm256_storeu_si256((__m256i*)outptr, _v_fp16); ptr += 16; outptr += 16; } +#endif // __AVX512F__ for (; i + 7 < size; i += 8) { __m256 _v_fp32 = _mm256_loadu_ps(ptr); - __m128h _v_fp16 = _mm256_cvtxps_ph(_v_fp32); - _mm_storeu_si128((__m128i*)outptr, (__m128i)_v_fp16); - - ptr += 8; - outptr += 8; - } - for (; i + 3 < size; i += 4) - { - __m128 _v_fp32 = _mm_loadu_ps(ptr); - __m128h _v_fp16 = _mm_cvtxps_ph(_v_fp32); - _mm_storel_epi64((__m128i*)outptr, (__m128i)_v_fp16); - - ptr += 4; - outptr += 4; - } -#elif __F16C__ - for (; i + 7 < size; i += 8) - { - __m256 _v_fp32 = _mm256_loadu_ps(ptr); - __m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC); + __m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); _mm_storeu_si128((__m128i*)outptr, _v_fp16); - ptr += 8; outptr += 8; } for (; i + 3 < size; i += 4) { __m128 _v_fp32 = _mm_loadu_ps(ptr); - __m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC); + __m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC); _mm_storel_epi64((__m128i*)outptr, _v_fp16); - ptr += 4; outptr += 4; } -#endif +#endif // __F16C__ for (; i < size; i++) { *outptr++ = float32_to_float16(*ptr++); @@ -112,14 +79,6 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { -#if NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__ - if (ncnn::cpu_support_x86_avx512_fp16()) - { - cast_fp16_to_fp32_sse_avx512fp16(bottom_blob, top_blob, opt); - return; - } -#endif - #if NCNN_F16C && __AVX__ && !__F16C__ if (ncnn::cpu_support_x86_f16c()) { @@ -143,41 +102,22 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O float* outptr = top_blob.channel(q); int i = 0; -#if __AVX512FP16__ +#if __F16C__ +#if __AVX512F__ for (; i + 15 < size; i += 16) { - __m256h _v_fp16 = (__m256h)_mm256_loadu_si256((const __m256i*)ptr); - __m512 _v_fp32 = _mm512_cvtxph_ps(_v_fp16); + __m256i _v_fp16 = _mm256_loadu_si256((const __m256i*)ptr); + __m512 _v_fp32 = _mm512_cvtph_ps(_v_fp16); _mm512_storeu_ps(outptr, _v_fp32); - ptr += 16; outptr += 16; } - for (; i + 7 < size; i += 8) - { - __m128h _v_fp16 = (__m128h)_mm_loadu_si128((const __m128i*)ptr); - __m256 _v_fp32 = _mm256_cvtxph_ps(_v_fp16); - _mm256_storeu_ps(outptr, _v_fp32); - - ptr += 8; - outptr += 8; - } - for (; i + 3 < size; i += 4) - { - __m128h _v_fp16 = (__m128h)_mm_loadl_epi64((const __m128i*)ptr); - __m128 _v_fp32 = _mm_cvtxph_ps(_v_fp16); - _mm_storeu_ps(outptr, _v_fp32); - - ptr += 4; - outptr += 4; - } -#elif __F16C__ +#endif // __AVX512F__ for (; i + 7 < size; i += 8) { __m128i _v_fp16 = _mm_loadu_si128((const __m128i*)ptr); __m256 _v_fp32 = _mm256_cvtph_ps(_v_fp16); _mm256_storeu_ps(outptr, _v_fp32); - ptr += 8; outptr += 8; } @@ -186,11 +126,10 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O __m128i _v_fp16 = _mm_loadl_epi64((const __m128i*)ptr); __m128 _v_fp32 = _mm_cvtph_ps(_v_fp16); _mm_storeu_ps(outptr, _v_fp32); - ptr += 4; outptr += 4; } -#endif +#endif // __F16C__ for (; i < size; i++) { *outptr++ = float16_to_float32(*ptr++); diff --git a/src/layer/x86/cast_x86.cpp b/src/layer/x86/cast_x86.cpp index 845d204a8..032d4307b 100644 --- a/src/layer/x86/cast_x86.cpp +++ b/src/layer/x86/cast_x86.cpp @@ -20,6 +20,7 @@ #include #endif // __AVX__ #endif // __SSE2__ +#include "x86_usability.h" #include "cpu.h" diff --git a/src/layer/x86/cast_x86_avx512fp16.cpp b/src/layer/x86/cast_x86_avx2.cpp similarity index 64% rename from src/layer/x86/cast_x86_avx512fp16.cpp rename to src/layer/x86/cast_x86_avx2.cpp index 9ae31288e..a4cb82050 100644 --- a/src/layer/x86/cast_x86_avx512fp16.cpp +++ b/src/layer/x86/cast_x86_avx2.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -14,19 +14,20 @@ #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { -#include "cast_fp16.h" +#include "cast_bf16.h" -void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) +void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { - cast_fp32_to_fp16_sse(bottom_blob, top_blob, opt); + cast_fp32_to_bf16_sse(bottom_blob, top_blob, opt); } -void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt) +void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt) { - cast_fp16_to_fp32_sse(bottom_blob, top_blob, opt); + cast_bf16_to_fp32_sse(bottom_blob, top_blob, opt); } } // namespace ncnn diff --git a/src/layer/x86/cast_x86_avx512bf16.cpp b/src/layer/x86/cast_x86_avx512bf16.cpp index f3bb65506..d5a9ff35c 100644 --- a/src/layer/x86/cast_x86_avx512bf16.cpp +++ b/src/layer/x86/cast_x86_avx512bf16.cpp @@ -14,6 +14,7 @@ #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/cast_x86_f16c.cpp b/src/layer/x86/cast_x86_f16c.cpp index 160855da2..41e923105 100644 --- a/src/layer/x86/cast_x86_f16c.cpp +++ b/src/layer/x86/cast_x86_f16c.cpp @@ -14,6 +14,7 @@ #include "cpu.h" #include "mat.h" +#include "x86_usability.h" namespace ncnn { diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index 669cec0a7..860b063e6 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -18,6 +18,8 @@ #include #if __SSE2__ #include +#if __SSE4_1__ +#include #if __AVX__ #include #if __XOP__ @@ -28,6 +30,7 @@ #endif #endif #endif +#endif #endif // __SSE2__ static NCNN_FORCEINLINE signed char float2int8(float v) @@ -154,6 +157,36 @@ static NCNN_FORCEINLINE __m128i float2int8_sse(const __m128& _v0, const __m128& return _v8; } +static NCNN_FORCEINLINE __m128 bfloat2float_sse(const __m128i& v0) +{ + __m128i _zero = _mm_setzero_si128(); + __m128i _a = _mm_unpacklo_epi16(_zero, v0); + __m128 _v = _mm_castsi128_ps(_a); + return _v; +} + +static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128& v1) +{ +#if __AVX512BF16__ + __m128i _v = (__m128i)_mm256_cvtneps_pbh(_mm256_insertf128_ps(_mm256_castps128_ps256(v0), v1, 1)); +#else + __m128i _a = _mm_castps_si128(v0); + __m128i _b = _mm_castps_si128(v1); +#if __SSE4_1__ + _a = _mm_srli_epi32(_a, 16); + _b = _mm_srli_epi32(_b, 16); + __m128i _v = _mm_packus_epi32(_a, _b); +#else + _a = _mm_shufflelo_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + _b = _mm_shufflelo_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); + _a = _mm_shufflehi_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1)); + _b = _mm_shufflehi_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1)); + __m128i _v = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_a), _mm_castsi128_ps(_b), _MM_SHUFFLE(2, 0, 2, 0))); +#endif +#endif + return _v; +} + #ifndef __FMA__ static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { @@ -500,6 +533,71 @@ static NCNN_FORCEINLINE void _mm256_comp_fmadd_ps8(__m256& _sum, _mm256_comp_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7); } +static NCNN_FORCEINLINE __m256 bfloat2float_avx(const __m128i& v0) +{ +#if __AVX512BF16__ + __m256 _v = _mm256_cvtpbh_ps((__m128bh)v0); +#else + __m128i _zero = _mm_setzero_si128(); + __m128i _a = _mm_unpacklo_epi16(_zero, v0); + __m128i _b = _mm_unpackhi_epi16(_zero, v0); + __m256 _v = _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_a), _b, 1)); +#endif + return _v; +} + +static NCNN_FORCEINLINE __m128i float2bfloat_avx(const __m256& v0) +{ +#if __AVX512BF16__ + __m128i _v = (__m128i)_mm256_cvtneps_pbh(v0); +#else + __m256i _ab = _mm256_castps_si256(v0); +#if __AVX2__ + _ab = _mm256_srli_epi32(_ab, 16); + __m128i _a = _mm256_extractf128_si256(_ab, 0); + __m128i _b = _mm256_extractf128_si256(_ab, 1); +#else + __m128i _a = _mm256_extractf128_si256(_ab, 0); + __m128i _b = _mm256_extractf128_si256(_ab, 1); + _a = _mm_srli_epi32(_a, 16); + _b = _mm_srli_epi32(_b, 16); +#endif + __m128i _v = _mm_packus_epi32(_a, _b); +#endif + return _v; +} + +static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& v1) +{ +#if __AVX512BF16__ + __m128i _v0 = (__m128i)_mm256_cvtneps_pbh(v0); + __m128i _v1 = (__m128i)_mm256_cvtneps_pbh(v1); + __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); +#else + __m256i _a = _mm256_castps_si256(v0); + __m256i _b = _mm256_castps_si256(v1); +#if __AVX2__ + _a = _mm256_srli_epi32(_a, 16); + _b = _mm256_srli_epi32(_b, 16); + __m256i _v = _mm256_packus_epi32(_a, _b); + _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); +#else + __m128i _a0 = _mm256_extractf128_si256(_a, 0); + __m128i _a1 = _mm256_extractf128_si256(_a, 1); + __m128i _b0 = _mm256_extractf128_si256(_b, 0); + __m128i _b1 = _mm256_extractf128_si256(_b, 1); + _a0 = _mm_srli_epi32(_a0, 16); + _a1 = _mm_srli_epi32(_a1, 16); + _b0 = _mm_srli_epi32(_b0, 16); + _b1 = _mm_srli_epi32(_b1, 16); + __m128i _v0 = _mm_packus_epi32(_a0, _a1); + __m128i _v1 = _mm_packus_epi32(_b0, _b1); + __m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1); +#endif +#endif + return _v; +} + #if __AVX512F__ static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7, __m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf) @@ -886,6 +984,55 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x) const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); return _mm_cvtss_f32(x32); } + +static NCNN_FORCEINLINE __m512 bfloat2float_avx512(const __m256i& v0) +{ +#if __AVX512BF16__ + __m512 _v = _mm512_cvtpbh_ps((__m256bh)v0); +#else + __m256i _zero = _mm256_setzero_si256(); + __m256i _a = _mm256_unpacklo_epi16(_zero, v0); + __m256i _b = _mm256_unpackhi_epi16(_zero, v0); + __m256i _c = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _d = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 3, 0, 1)); + __m512 _v = _mm512_castsi512_ps(_mm512_inserti32x8(_mm512_castsi256_si512(_c), _d, 1)); +#endif + return _v; +} + +static NCNN_FORCEINLINE __m256i float2bfloat_avx512(const __m512& v0) +{ +#if __AVX512BF16__ + __m256i _v = (__m256i)_mm512_cvtneps_pbh(v0); +#else + __m512i _ab = _mm512_castps_si512(v0); + _ab = _mm512_srli_epi32(_ab, 16); + __m256i _a = _mm512_extracti32x8_epi32(_ab, 0); + __m256i _b = _mm512_extracti32x8_epi32(_ab, 1); + __m256i _v = _mm256_packus_epi32(_a, _b); + _v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); +#endif + return _v; +} + +static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m512& v1) +{ +#if __AVX512BF16__ + __m256bh _v0 = _mm512_cvtneps_pbh(v0); + __m256bh _v1 = _mm512_cvtneps_pbh(v1); + __m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0), (__m256i)_v1, 1); +#else + __m512i _a = _mm512_castps_si512(v0); + __m512i _b = _mm512_castps_si512(v1); + _a = _mm512_srli_epi32(_a, 16); + _b = _mm512_srli_epi32(_b, 16); + __m512i _v = _mm512_packus_epi32(_a, _b); + _v = _mm512_permutex_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0)); + _v = _mm512_shuffle_i32x4(_v, _v, _MM_SHUFFLE(3, 1, 2, 0)); +#endif + return _v; +} + #endif // __AVX512F__ #endif // __AVX__ #endif // __SSE2__