Browse Source

x86 bfloat16 cast functions (#4491)

* simplify cast fp16 avx512 dispatch

* define sse4.1 macro on msvc avx+
tags/20230223
nihui GitHub 3 years ago
parent
commit
d2d012dce5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 232 additions and 156 deletions
  1. +8
    -8
      cmake/ncnn_add_layer.cmake
  2. +4
    -4
      src/CMakeLists.txt
  3. +49
    -63
      src/layer/x86/cast_bf16.h
  4. +14
    -75
      src/layer/x86/cast_fp16.h
  5. +1
    -0
      src/layer/x86/cast_x86.cpp
  6. +7
    -6
      src/layer/x86/cast_x86_avx2.cpp
  7. +1
    -0
      src/layer/x86/cast_x86_avx512bf16.cpp
  8. +1
    -0
      src/layer/x86/cast_x86_f16c.cpp
  9. +147
    -0
      src/layer/x86/x86_usability.h

+ 8
- 8
cmake/ncnn_add_layer.cmake View File

@@ -160,28 +160,28 @@ macro(ncnn_add_layer class)
if(NCNN_TARGET_ARCH STREQUAL "x86")
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
if(NCNN_RUNTIME_CPU AND NCNN_AVX512)
ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__FMA__ /D__F16C__")
ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_FMA)
ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__FMA__ /D__F16C__")
ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX")
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSE4_1__")
endif()
if(NCNN_AVX512VNNI)
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
endif()
if(NCNN_AVXVNNI)
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__FMA__ /D__F16C__ /D__AVXVNNI__")
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
endif()
if(NCNN_AVX2)
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__FMA__ /D__F16C__")
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__")
endif()
if(NCNN_XOP)
ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__XOP__")
ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSE4_1__ /D__XOP__")
endif()
if(NCNN_F16C)
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__F16C__")
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSE4_1__ /D__F16C__")
endif()
else()
if(NCNN_RUNTIME_CPU AND NCNN_AVX512)


+ 4
- 4
src/CMakeLists.txt View File

@@ -342,7 +342,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")

if(NOT NCNN_RUNTIME_CPU AND NCNN_AVX512)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
target_compile_options(ncnn PRIVATE /arch:AVX512 /D__FMA__ /D__F16C__)
target_compile_options(ncnn PRIVATE /arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__)
if(NCNN_AVX512VNNI)
target_compile_options(ncnn PRIVATE /D__AVX512VNNI__)
endif()
@@ -361,9 +361,9 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
elseif(NOT NCNN_RUNTIME_CPU AND NCNN_FMA)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
if(NCNN_AVX2)
target_compile_options(ncnn PRIVATE /arch:AVX2 /D__FMA__)
target_compile_options(ncnn PRIVATE /arch:AVX2 /D__SSE4_1__ /D__FMA__)
else()
target_compile_options(ncnn PRIVATE /arch:AVX /D__FMA__)
target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__ /D__FMA__)
endif()
if(NCNN_AVXVNNI)
target_compile_options(ncnn PRIVATE /D__AVXVNNI__)
@@ -390,7 +390,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86")
endif()
elseif(NOT NCNN_RUNTIME_CPU AND NCNN_AVX)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
target_compile_options(ncnn PRIVATE /arch:AVX)
target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__)
if(NCNN_XOP)
target_compile_options(ncnn PRIVATE /D__XOP__)
endif()


+ 49
- 63
src/layer/x86/cast_bf16.h View File

@@ -17,36 +17,10 @@ void cast_fp32_to_bf16_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, con
void cast_bf16_to_fp32_sse_avx512bf16(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
#endif

#if __AVX__
static inline __m256 bfloat2float_avx(__m128i v0)
{
__m128i zero = _mm_set1_epi32(0);
__m128i a = _mm_slli_epi32(_mm_unpacklo_epi16(v0, zero), 16);
__m128i b = _mm_slli_epi32(_mm_unpackhi_epi16(v0, zero), 16);
__m256i ab = _mm256_set1_epi32(0);
ab = _mm256_insertf128_si256(ab, a, 0); // insert in low 128-bit lane
ab = _mm256_insertf128_si256(ab, b, 1); // insert in high 128-bit lane
return _mm256_castsi256_ps(ab);
}
#if __AVX2__
static inline __m256i float2bfloat_avx(__m256 v0, __m256 v1)
{
__m256i a = _mm256_castps_si256(v0);
a = _mm256_srli_epi32(a, 16);
__m256i b = _mm256_castps_si256(v1);
b = _mm256_srli_epi32(b, 16);
__m256i abab = _mm256_packus_epi32(a, b);
return _mm256_permutevar8x32_epi32(abab, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
}
static inline __m128i float2bfloat_avx(__m256 v0)
{
__m256i a = _mm256_castps_si256(v0);
a = _mm256_srli_epi32(a, 16);
__m256i aaaa = _mm256_packus_epi32(a, a);
return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(aaaa, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7)));
}
#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__
void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
#endif
#endif // __AVX__

static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
@@ -58,6 +32,14 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O
}
#endif

#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__
if (ncnn::cpu_support_x86_avx2())
{
cast_fp32_to_bf16_sse_avx2(bottom_blob, top_blob, opt);
return;
}
#endif

const int w = bottom_blob.w;
const int h = bottom_blob.h;
const int d = bottom_blob.d;
@@ -73,39 +55,38 @@ static void cast_fp32_to_bf16_sse(const Mat& bottom_blob, Mat& top_blob, const O
unsigned short* outptr = top_blob.channel(q);

int i = 0;
#if __AVX512BF16__
for (; i + 15 < size; i += 16)
{
__m512 _v_fp32 = _mm512_loadu_ps(ptr);
__m256bh _v_bf16 = _mm512_cvtneps_pbh(_v_fp32);
_mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_bf16);

ptr += 16;
outptr += 16;
}
for (; i + 7 < size; i += 8)
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; i + 31 < size; i += 32)
{
__m256 _v_fp32 = _mm256_loadu_ps(ptr);
__m128bh _v_bf16 = _mm256_cvtneps_pbh(_v_fp32);
_mm_storeu_si128((__m128i*)outptr, (__m128i)_v_bf16);

ptr += 8;
outptr += 8;
_mm512_storeu_si512((__m512i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr), _mm512_loadu_ps(ptr + 16)));
ptr += 32;
outptr += 32;
}
#elif __AVX2__
#endif // __AVX512F__
for (; i + 15 < size; i += 16)
{
#if __AVX512F__
_mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx512(_mm512_loadu_ps(ptr)));
#else
_mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr + 8)));
#endif
ptr += 16;
outptr += 16;
}
#endif // __AVX__
for (; i + 7 < size; i += 8)
{
#if __AVX__
_mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr)));
#else
_mm_store_si128((__m128i*)outptr, float2bfloat_sse(_mm_loadu_ps(ptr), _mm_loadu_ps(ptr + 4)));
#endif
ptr += 8;
outptr += 8;
}
#endif
#endif // __SSE2__
for (; i < size; i++)
{
*outptr++ = float32_to_bfloat16(*ptr++);
@@ -123,6 +104,14 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O
}
#endif

#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__
if (ncnn::cpu_support_x86_avx2())
{
cast_bf16_to_fp32_sse_avx2(bottom_blob, top_blob, opt);
return;
}
#endif

const int w = bottom_blob.w;
const int h = bottom_blob.h;
const int d = bottom_blob.d;
@@ -138,33 +127,30 @@ static void cast_bf16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O
float* outptr = top_blob.channel(q);

int i = 0;
#if __AVX512BF16__
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; i + 15 < size; i += 16)
{
__m256bh _v_bf16 = (__m256bh)_mm256_loadu_si256((const __m256i*)ptr);
__m512 _v_fp32 = _mm512_cvtpbh_ps(_v_bf16);
_mm512_storeu_ps(outptr, _v_fp32);

_mm512_storeu_ps(outptr, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)ptr)));
ptr += 16;
outptr += 16;
}
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m128bh _v_bf16 = (__m128bh)_mm_loadu_si128((const __m128i*)ptr);
__m256 _v_fp32 = _mm256_cvtpbh_ps(_v_bf16);
_mm256_storeu_ps(outptr, _v_fp32);

_mm256_storeu_ps(outptr, bfloat2float_avx(_mm_loadu_si128((const __m128i*)ptr)));
ptr += 8;
outptr += 8;
}
#elif __AVX__
for (; i + 7 < size; i += 8)
#endif // __AVX__
for (; i + 3 < size; i += 4)
{
_mm256_storeu_ps(outptr, bfloat2float_avx(_mm_lddqu_si128((__m128i const*)ptr)));
ptr += 8;
outptr += 8;
_mm_storeu_ps(outptr, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)ptr)));
ptr += 4;
outptr += 4;
}
#endif
#endif // __SSE2__
for (; i < size; i++)
{
*outptr++ = bfloat16_to_float32(*ptr++);


+ 14
- 75
src/layer/x86/cast_fp16.h View File

@@ -12,11 +12,6 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__
void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
#endif

#if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__
void cast_fp32_to_fp16_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Option& opt);
@@ -24,14 +19,6 @@ void cast_fp16_to_fp32_sse_f16c(const Mat& bottom_blob, Mat& top_blob, const Opt

static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__
if (ncnn::cpu_support_x86_avx512_fp16())
{
cast_fp32_to_fp16_sse_avx512fp16(bottom_blob, top_blob, opt);
return;
}
#endif

#if NCNN_RUNTIME_CPU && NCNN_F16C && __AVX__ && !__F16C__
if (ncnn::cpu_support_x86_f16c())
{
@@ -55,54 +42,34 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O
unsigned short* outptr = top_blob.channel(q);

int i = 0;
#if __AVX512FP16__
#if __F16C__
#if __AVX512F__
for (; i + 15 < size; i += 16)
{
__m512 _v_fp32 = _mm512_loadu_ps(ptr);
__m256h _v_fp16 = _mm512_cvtxps_ph(_v_fp32);
_mm256_storeu_si256((__m256i*)outptr, (__m256i)_v_fp16);

__m256i _v_fp16 = _mm512_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC);
_mm256_storeu_si256((__m256i*)outptr, _v_fp16);
ptr += 16;
outptr += 16;
}
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m256 _v_fp32 = _mm256_loadu_ps(ptr);
__m128h _v_fp16 = _mm256_cvtxps_ph(_v_fp32);
_mm_storeu_si128((__m128i*)outptr, (__m128i)_v_fp16);

ptr += 8;
outptr += 8;
}
for (; i + 3 < size; i += 4)
{
__m128 _v_fp32 = _mm_loadu_ps(ptr);
__m128h _v_fp16 = _mm_cvtxps_ph(_v_fp32);
_mm_storel_epi64((__m128i*)outptr, (__m128i)_v_fp16);

ptr += 4;
outptr += 4;
}
#elif __F16C__
for (; i + 7 < size; i += 8)
{
__m256 _v_fp32 = _mm256_loadu_ps(ptr);
__m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC);
__m128i _v_fp16 = _mm256_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC);
_mm_storeu_si128((__m128i*)outptr, _v_fp16);

ptr += 8;
outptr += 8;
}
for (; i + 3 < size; i += 4)
{
__m128 _v_fp32 = _mm_loadu_ps(ptr);
__m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_FROUND_TRUNC);
__m128i _v_fp16 = _mm_cvtps_ph(_v_fp32, _MM_ROUND_NEAREST | _MM_FROUND_NO_EXC);
_mm_storel_epi64((__m128i*)outptr, _v_fp16);

ptr += 4;
outptr += 4;
}
#endif
#endif // __F16C__
for (; i < size; i++)
{
*outptr++ = float32_to_float16(*ptr++);
@@ -112,14 +79,6 @@ static void cast_fp32_to_fp16_sse(const Mat& bottom_blob, Mat& top_blob, const O

static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
#if NCNN_AVX512FP16 && __AVX512F__ && !__AVX512FP16__
if (ncnn::cpu_support_x86_avx512_fp16())
{
cast_fp16_to_fp32_sse_avx512fp16(bottom_blob, top_blob, opt);
return;
}
#endif

#if NCNN_F16C && __AVX__ && !__F16C__
if (ncnn::cpu_support_x86_f16c())
{
@@ -143,41 +102,22 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O
float* outptr = top_blob.channel(q);

int i = 0;
#if __AVX512FP16__
#if __F16C__
#if __AVX512F__
for (; i + 15 < size; i += 16)
{
__m256h _v_fp16 = (__m256h)_mm256_loadu_si256((const __m256i*)ptr);
__m512 _v_fp32 = _mm512_cvtxph_ps(_v_fp16);
__m256i _v_fp16 = _mm256_loadu_si256((const __m256i*)ptr);
__m512 _v_fp32 = _mm512_cvtph_ps(_v_fp16);
_mm512_storeu_ps(outptr, _v_fp32);

ptr += 16;
outptr += 16;
}
for (; i + 7 < size; i += 8)
{
__m128h _v_fp16 = (__m128h)_mm_loadu_si128((const __m128i*)ptr);
__m256 _v_fp32 = _mm256_cvtxph_ps(_v_fp16);
_mm256_storeu_ps(outptr, _v_fp32);

ptr += 8;
outptr += 8;
}
for (; i + 3 < size; i += 4)
{
__m128h _v_fp16 = (__m128h)_mm_loadl_epi64((const __m128i*)ptr);
__m128 _v_fp32 = _mm_cvtxph_ps(_v_fp16);
_mm_storeu_ps(outptr, _v_fp32);

ptr += 4;
outptr += 4;
}
#elif __F16C__
#endif // __AVX512F__
for (; i + 7 < size; i += 8)
{
__m128i _v_fp16 = _mm_loadu_si128((const __m128i*)ptr);
__m256 _v_fp32 = _mm256_cvtph_ps(_v_fp16);
_mm256_storeu_ps(outptr, _v_fp32);

ptr += 8;
outptr += 8;
}
@@ -186,11 +126,10 @@ static void cast_fp16_to_fp32_sse(const Mat& bottom_blob, Mat& top_blob, const O
__m128i _v_fp16 = _mm_loadl_epi64((const __m128i*)ptr);
__m128 _v_fp32 = _mm_cvtph_ps(_v_fp16);
_mm_storeu_ps(outptr, _v_fp32);

ptr += 4;
outptr += 4;
}
#endif
#endif // __F16C__
for (; i < size; i++)
{
*outptr++ = float16_to_float32(*ptr++);


+ 1
- 0
src/layer/x86/cast_x86.cpp View File

@@ -20,6 +20,7 @@
#include <immintrin.h>
#endif // __AVX__
#endif // __SSE2__
#include "x86_usability.h"

#include "cpu.h"



src/layer/x86/cast_x86_avx512fp16.cpp → src/layer/x86/cast_x86_avx2.cpp View File

@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
@@ -14,19 +14,20 @@

#include "cpu.h"
#include "mat.h"
#include "x86_usability.h"

namespace ncnn {

#include "cast_fp16.h"
#include "cast_bf16.h"

void cast_fp32_to_fp16_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
void cast_fp32_to_bf16_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
cast_fp32_to_fp16_sse(bottom_blob, top_blob, opt);
cast_fp32_to_bf16_sse(bottom_blob, top_blob, opt);
}

void cast_fp16_to_fp32_sse_avx512fp16(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
void cast_bf16_to_fp32_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
cast_fp16_to_fp32_sse(bottom_blob, top_blob, opt);
cast_bf16_to_fp32_sse(bottom_blob, top_blob, opt);
}

} // namespace ncnn

+ 1
- 0
src/layer/x86/cast_x86_avx512bf16.cpp View File

@@ -14,6 +14,7 @@

#include "cpu.h"
#include "mat.h"
#include "x86_usability.h"

namespace ncnn {



+ 1
- 0
src/layer/x86/cast_x86_f16c.cpp View File

@@ -14,6 +14,7 @@

#include "cpu.h"
#include "mat.h"
#include "x86_usability.h"

namespace ncnn {



+ 147
- 0
src/layer/x86/x86_usability.h View File

@@ -18,6 +18,8 @@
#include <math.h>
#if __SSE2__
#include <emmintrin.h>
#if __SSE4_1__
#include <smmintrin.h>
#if __AVX__
#include <immintrin.h>
#if __XOP__
@@ -28,6 +30,7 @@
#endif
#endif
#endif
#endif
#endif // __SSE2__

static NCNN_FORCEINLINE signed char float2int8(float v)
@@ -154,6 +157,36 @@ static NCNN_FORCEINLINE __m128i float2int8_sse(const __m128& _v0, const __m128&
return _v8;
}

static NCNN_FORCEINLINE __m128 bfloat2float_sse(const __m128i& v0)
{
__m128i _zero = _mm_setzero_si128();
__m128i _a = _mm_unpacklo_epi16(_zero, v0);
__m128 _v = _mm_castsi128_ps(_a);
return _v;
}

static NCNN_FORCEINLINE __m128i float2bfloat_sse(const __m128& v0, const __m128& v1)
{
#if __AVX512BF16__
__m128i _v = (__m128i)_mm256_cvtneps_pbh(_mm256_insertf128_ps(_mm256_castps128_ps256(v0), v1, 1));
#else
__m128i _a = _mm_castps_si128(v0);
__m128i _b = _mm_castps_si128(v1);
#if __SSE4_1__
_a = _mm_srli_epi32(_a, 16);
_b = _mm_srli_epi32(_b, 16);
__m128i _v = _mm_packus_epi32(_a, _b);
#else
_a = _mm_shufflelo_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1));
_b = _mm_shufflelo_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1));
_a = _mm_shufflehi_epi16(_a, _MM_SHUFFLE(2, 0, 3, 1));
_b = _mm_shufflehi_epi16(_b, _MM_SHUFFLE(2, 0, 3, 1));
__m128i _v = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_a), _mm_castsi128_ps(_b), _MM_SHUFFLE(2, 0, 2, 0)));
#endif
#endif
return _v;
}

#ifndef __FMA__
static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c)
{
@@ -500,6 +533,71 @@ static NCNN_FORCEINLINE void _mm256_comp_fmadd_ps8(__m256& _sum,
_mm256_comp_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7);
}

static NCNN_FORCEINLINE __m256 bfloat2float_avx(const __m128i& v0)
{
#if __AVX512BF16__
__m256 _v = _mm256_cvtpbh_ps((__m128bh)v0);
#else
__m128i _zero = _mm_setzero_si128();
__m128i _a = _mm_unpacklo_epi16(_zero, v0);
__m128i _b = _mm_unpackhi_epi16(_zero, v0);
__m256 _v = _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_a), _b, 1));
#endif
return _v;
}

static NCNN_FORCEINLINE __m128i float2bfloat_avx(const __m256& v0)
{
#if __AVX512BF16__
__m128i _v = (__m128i)_mm256_cvtneps_pbh(v0);
#else
__m256i _ab = _mm256_castps_si256(v0);
#if __AVX2__
_ab = _mm256_srli_epi32(_ab, 16);
__m128i _a = _mm256_extractf128_si256(_ab, 0);
__m128i _b = _mm256_extractf128_si256(_ab, 1);
#else
__m128i _a = _mm256_extractf128_si256(_ab, 0);
__m128i _b = _mm256_extractf128_si256(_ab, 1);
_a = _mm_srli_epi32(_a, 16);
_b = _mm_srli_epi32(_b, 16);
#endif
__m128i _v = _mm_packus_epi32(_a, _b);
#endif
return _v;
}

static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& v1)
{
#if __AVX512BF16__
__m128i _v0 = (__m128i)_mm256_cvtneps_pbh(v0);
__m128i _v1 = (__m128i)_mm256_cvtneps_pbh(v1);
__m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1);
#else
__m256i _a = _mm256_castps_si256(v0);
__m256i _b = _mm256_castps_si256(v1);
#if __AVX2__
_a = _mm256_srli_epi32(_a, 16);
_b = _mm256_srli_epi32(_b, 16);
__m256i _v = _mm256_packus_epi32(_a, _b);
_v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0));
#else
__m128i _a0 = _mm256_extractf128_si256(_a, 0);
__m128i _a1 = _mm256_extractf128_si256(_a, 1);
__m128i _b0 = _mm256_extractf128_si256(_b, 0);
__m128i _b1 = _mm256_extractf128_si256(_b, 1);
_a0 = _mm_srli_epi32(_a0, 16);
_a1 = _mm_srli_epi32(_a1, 16);
_b0 = _mm_srli_epi32(_b0, 16);
_b1 = _mm_srli_epi32(_b1, 16);
__m128i _v0 = _mm_packus_epi32(_a0, _a1);
__m128i _v1 = _mm_packus_epi32(_b0, _b1);
__m256i _v = _mm256_insertf128_si256(_mm256_castsi128_si256(_v0), _v1, 1);
#endif
#endif
return _v;
}

#if __AVX512F__
static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7,
__m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf)
@@ -886,6 +984,55 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x)
const __m128 x32 = _mm_max_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
return _mm_cvtss_f32(x32);
}

static NCNN_FORCEINLINE __m512 bfloat2float_avx512(const __m256i& v0)
{
#if __AVX512BF16__
__m512 _v = _mm512_cvtpbh_ps((__m256bh)v0);
#else
__m256i _zero = _mm256_setzero_si256();
__m256i _a = _mm256_unpacklo_epi16(_zero, v0);
__m256i _b = _mm256_unpackhi_epi16(_zero, v0);
__m256i _c = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 2, 0, 0));
__m256i _d = _mm256_permute2x128_si256(_a, _b, _MM_SHUFFLE(0, 3, 0, 1));
__m512 _v = _mm512_castsi512_ps(_mm512_inserti32x8(_mm512_castsi256_si512(_c), _d, 1));
#endif
return _v;
}

static NCNN_FORCEINLINE __m256i float2bfloat_avx512(const __m512& v0)
{
#if __AVX512BF16__
__m256i _v = (__m256i)_mm512_cvtneps_pbh(v0);
#else
__m512i _ab = _mm512_castps_si512(v0);
_ab = _mm512_srli_epi32(_ab, 16);
__m256i _a = _mm512_extracti32x8_epi32(_ab, 0);
__m256i _b = _mm512_extracti32x8_epi32(_ab, 1);
__m256i _v = _mm256_packus_epi32(_a, _b);
_v = _mm256_permute4x64_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0));
#endif
return _v;
}

static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m512& v1)
{
#if __AVX512BF16__
__m256bh _v0 = _mm512_cvtneps_pbh(v0);
__m256bh _v1 = _mm512_cvtneps_pbh(v1);
__m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0), (__m256i)_v1, 1);
#else
__m512i _a = _mm512_castps_si512(v0);
__m512i _b = _mm512_castps_si512(v1);
_a = _mm512_srli_epi32(_a, 16);
_b = _mm512_srli_epi32(_b, 16);
__m512i _v = _mm512_packus_epi32(_a, _b);
_v = _mm512_permutex_epi64(_v, _MM_SHUFFLE(3, 1, 2, 0));
_v = _mm512_shuffle_i32x4(_v, _v, _MM_SHUFFLE(3, 1, 2, 0));
#endif
return _v;
}

#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__


Loading…
Cancel
Save