Browse Source

more stricter armv7 fp16 and armv84 bf16 compiler check, fix #4147 fix #4222 (#4247)

tags/20221128
nihui GitHub 3 years ago
parent
commit
3e2b3fa04d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 153 additions and 86 deletions
  1. +1
    -1
      CMakeLists.txt
  2. +1
    -1
      src/layer/arm/cast_bf16.h
  3. +119
    -52
      src/layer/arm/cast_fp16.h
  4. +14
    -14
      src/layer/arm/innerproduct_fp16s.h
  5. +10
    -10
      src/layer/arm/innerproduct_gemm_fp16s.h
  6. +8
    -8
      src/layer/arm/neon_mathfun_fp16s.h

+ 1
- 1
CMakeLists.txt View File

@@ -171,7 +171,7 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+bf16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vbfmmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+i8mm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)


+ 1
- 1
src/layer/arm/cast_bf16.h View File

@@ -150,7 +150,7 @@ static void cast_fp32_to_bf16_neon(const Mat& bottom_blob, Mat& top_blob, const

static void cast_bf16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
#if NCNN_ARM84BF16 && __aarch64__ && !__ARM_FEATURE_BF16_VECTOR_ARITHMETIC
#if NCNN_RUNTIME_CPU && NCNN_ARM84BF16 && __aarch64__ && !__ARM_FEATURE_BF16_VECTOR_ARITHMETIC
if (ncnn::cpu_support_arm_bf16())
{
cast_bf16_to_fp32_neon_bf16(bottom_blob, top_blob, opt);


+ 119
- 52
src/layer/arm/cast_fp16.h View File

@@ -47,12 +47,12 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
{
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"prfm pldl1keep, [%0, #512] \n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
"fcvtn v0.4h, v0.4s \n"
"fcvtn v1.4h, v1.4s \n"
"fcvtn v2.4h, v2.4s \n"
"fcvtn v3.4h, v3.4s \n"
"fcvtn v0.4h, v0.4s \n"
"fcvtn v1.4h, v1.4s \n"
"fcvtn v2.4h, v2.4s \n"
"fcvtn v3.4h, v3.4s \n"
"st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -61,12 +61,12 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
: "memory", "v0", "v1", "v2", "v3");
#else // __aarch64__
asm volatile(
"pld [%0, #512] \n"
"vldm %0!, {d0-d7} \n"
"vcvt.f16.f32 d0, q0 \n"
"vcvt.f16.f32 d1, q1 \n"
"vcvt.f16.f32 d2, q2 \n"
"vcvt.f16.f32 d3, q3 \n"
"pld [%0, #512] \n"
"vldm %0!, {d0-d7} \n"
"vcvt.f16.f32 d0, q0 \n"
"vcvt.f16.f32 d1, q1 \n"
"vcvt.f16.f32 d2, q2 \n"
"vcvt.f16.f32 d3, q3 \n"
"vst1.u16 {d0-d3}, [%1 :128]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -77,24 +77,61 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
}
for (; i + 7 < size; i += 8)
{
float32x4_t _p0_fp32 = vld1q_f32(ptr);
float32x4_t _p1_fp32 = vld1q_f32(ptr + 4);
float16x4_t _p0_fp16 = vcvt_f16_f32(_p0_fp32);
float16x4_t _p1_fp16 = vcvt_f16_f32(_p1_fp32);
uint16x8_t _p_fp16 = vcombine_u16(vreinterpret_u16_f16(_p0_fp16), vreinterpret_u16_f16(_p1_fp16));
vst1q_u16(outptr, _p_fp16);
ptr += 8;
outptr += 8;
// This is originally implemented with neon fp16 intrinsics.
// In the new version of gcc, __ARM_FP16_FORMAT_IEEE or __ARM_FP16_FORMAT_ALTERNATIVE needs to be defined to use the float16x4_t type.
// That leads to compiler error when compiled with -mfpu=neon-vfpv4 but without -mfp16-format=ieee flag.
// We could add more macro conditions to differentiate between old and new versions, but that's pretty ugly!
// Just use all inline assembly here ~
// --- nihui
#if __aarch64__
asm volatile(
"ld1 {v0.4s, v1.4s}, [%0], #32 \n"
"fcvtn v0.4h, v0.4s \n"
"fcvtn v1.4h, v1.4s \n"
"st1 {v0.4h, v1.4h}, [%1], #16 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "v0", "v1");
#else // __aarch64__
asm volatile(
"vld1.f32 {d0-d3}, [%0]! \n"
"vcvt.f16.f32 d0, q0 \n"
"vcvt.f16.f32 d1, q1 \n"
"vst1.u16 {d0-d1}, [%1]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "q0", "q1");
#endif // __aarch64__
}
for (; i + 3 < size; i += 4)
{
float32x4_t _p_fp32 = vld1q_f32(ptr);
float16x4_t _p_fp16 = vcvt_f16_f32(_p_fp32);
vst1_u16(outptr, vreinterpret_u16_f16(_p_fp16));
ptr += 4;
outptr += 4;
#if __aarch64__
asm volatile(
"ld1 {v0.4s}, [%0], #16 \n"
"fcvtn v0.4h, v0.4s \n"
"st1 {v0.4h}, [%1], #8 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "v0");
#else // __aarch64__
asm volatile(
"vld1.f32 {d0-d1}, [%0]! \n"
"vcvt.f16.f32 d0, q0 \n"
"vst1.u16 {d0}, [%1]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "q0");
#endif // __aarch64__
}
#endif
#endif // (__ARM_FP & 2)
for (; i < size; i++)
{
*outptr++ = float32_to_float16(*ptr++);
@@ -104,7 +141,7 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const

static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
#if NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
if (ncnn::cpu_support_arm_vfpv4())
{
cast_fp16_to_fp32_neon_vfpv4(bottom_blob, top_blob, opt);
@@ -132,12 +169,12 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
{
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%0, #256] \n"
"prfm pldl1keep, [%0, #256] \n"
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0], #32 \n"
"fcvtl v0.4s, v0.4h \n"
"fcvtl v1.4s, v1.4h \n"
"fcvtl v2.4s, v2.4h \n"
"fcvtl v3.4s, v3.4h \n"
"fcvtl v0.4s, v0.4h \n"
"fcvtl v1.4s, v1.4h \n"
"fcvtl v2.4s, v2.4h \n"
"fcvtl v3.4s, v3.4h \n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -146,13 +183,13 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
: "memory", "v0", "v1", "v2", "v3");
#else // __aarch64__
asm volatile(
"pld [%0, #256] \n"
"pld [%0, #256] \n"
"vld1.u16 {d4-d7}, [%0 :128]! \n"
"vcvt.f32.f16 q0, d4 \n"
"vcvt.f32.f16 q1, d5 \n"
"vcvt.f32.f16 q2, d6 \n"
"vcvt.f32.f16 q3, d7 \n"
"vstm %1!, {d0-d7} \n"
"vcvt.f32.f16 q0, d4 \n"
"vcvt.f32.f16 q1, d5 \n"
"vcvt.f32.f16 q2, d6 \n"
"vcvt.f32.f16 q3, d7 \n"
"vstm %1!, {d0-d7} \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
@@ -162,25 +199,55 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
}
for (; i + 7 < size; i += 8)
{
uint16x8_t _p_fp16 = vld1q_u16(ptr);
float16x4_t _p0_fp16 = vreinterpret_f16_u16(vget_low_u16(_p_fp16));
float16x4_t _p1_fp16 = vreinterpret_f16_u16(vget_high_u16(_p_fp16));
float32x4_t _p0_fp32 = vcvt_f32_f16(_p0_fp16);
float32x4_t _p1_fp32 = vcvt_f32_f16(_p1_fp16);
vst1q_f32(outptr, _p0_fp32);
vst1q_f32(outptr + 4, _p1_fp32);
ptr += 8;
outptr += 8;
#if __aarch64__
asm volatile(
"ld1 {v0.4h, v1.4h}, [%0], #16 \n"
"fcvtl v0.4s, v0.4h \n"
"fcvtl v1.4s, v1.4h \n"
"st1 {v0.4s, v1.4s}, [%1], #32 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "v0", "v1");
#else // __aarch64__
asm volatile(
"vld1.u16 {d4-d5}, [%0]! \n"
"vcvt.f32.f16 q0, d4 \n"
"vcvt.f32.f16 q1, d5 \n"
"vst1.f32 {d0-d3}, [%1]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "q0", "q1", "q2");
#endif // __aarch64__
}
for (; i + 3 < size; i += 4)
{
float16x4_t _p_fp16 = vreinterpret_f16_u16(vld1_u16(ptr));
float32x4_t _p_fp32 = vcvt_f32_f16(_p_fp16);
vst1q_f32(outptr, _p_fp32);
ptr += 4;
outptr += 4;
#if __aarch64__
asm volatile(
"ld1 {v0.4h}, [%0], #8 \n"
"fcvtl v0.4s, v0.4h \n"
"st1 {v0.4s}, [%1], #16 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "v0");
#else // __aarch64__
asm volatile(
"vld1.u16 {d2}, [%0]! \n"
"vcvt.f32.f16 q0, d2 \n"
"vst1.f32 {d0-d1}, [%1]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
"1"(outptr)
: "memory", "q0", "q1");
#endif // __aarch64__
}
#endif
#endif // (__ARM_FP & 2)
for (; i < size; i++)
{
*outptr++ = float16_to_float32(*ptr++);


+ 14
- 14
src/layer/arm/innerproduct_fp16s.h View File

@@ -253,10 +253,10 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val = vld1q_f32(sptr);
uint16x8_t _w01 = vld1q_u16(kptr);
uint16x8_t _w23 = vld1q_u16(kptr + 8);
float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w23)));
float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w23)));
float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w23)));
float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w23)));
#endif

#if __aarch64__
@@ -281,7 +281,7 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum0 = vfmaq_f32(_sum0, _val, _w);

@@ -410,10 +410,10 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const
float32x4_t _w3 = vcvt_f32_f16(vld1_f16(kptr3));
#else
float32x4_t _val = vld1q_f32(sptr);
float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr0)));
float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr1)));
float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr2)));
float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr3)));
float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr0)));
float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr1)));
float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr2)));
float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr3)));
#endif

_sum0 = vfmaq_f32(_sum0, _val, _w0);
@@ -507,7 +507,7 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(sptr);
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum = vfmaq_f32(_sum, _val, _w);

@@ -713,10 +713,10 @@ static void innerproduct_transform_kernel_fp16s_neon(const Mat& weight_data, Mat
{
// transpose 4x4
uint16x4x4_t _p;
_p.val[0] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k0)));
_p.val[1] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k1)));
_p.val[2] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k2)));
_p.val[3] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k3)));
_p.val[0] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k0)));
_p.val[1] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k1)));
_p.val[2] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k2)));
_p.val[3] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k3)));
vst4_u16(g0, _p);

k0 += 4;


+ 10
- 10
src/layer/arm/innerproduct_gemm_fp16s.h View File

@@ -120,7 +120,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(m);
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

#if __aarch64__
@@ -214,10 +214,10 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val = vld1q_f32(m);
uint16x8_t _w01 = vld1q_u16(kptr);
uint16x8_t _w23 = vld1q_u16(kptr + 8);
float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w23)));
float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w23)));
float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w23)));
float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w23)));
#endif

#if __aarch64__
@@ -242,7 +242,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum0 = vfmaq_f32(_sum0, _val, _w);

@@ -317,7 +317,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val1 = vld1q_f32(m + 4);
float32x4_t _val2 = vld1q_f32(m + 8);
float32x4_t _val3 = vld1q_f32(m + 12);
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

#if __aarch64__
@@ -414,8 +414,8 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val0 = vld1q_f32(m);
float32x4_t _val1 = vld1q_f32(m + 4);
uint16x8_t _w01 = vld1q_u16(kptr);
float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
#endif

_sum0 = vfmaq_f32(_sum0, _val0, _w0);
@@ -433,7 +433,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(m);
float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

_sum0 = vfmaq_f32(_sum0, _val, _w);


+ 8
- 8
src/layer/arm/neon_mathfun_fp16s.h View File

@@ -89,9 +89,9 @@ static inline float16x4_t log_ps(float16x4_t x)
* } else { x = x - 1.0; }
*/
uint16x4_t mask = vclt_f16(x, vdup_n_f16(c_cephes_SQRTHF));
float16x4_t tmp = vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(x), mask));
float16x4_t tmp = (float16x4_t)(vand_u16((uint16x4_t)(x), mask));
x = vsub_f16(x, one);
e = vsub_f16(e, vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(one), mask)));
e = vsub_f16(e, (float16x4_t)(vand_u16((uint16x4_t)(one), mask)));
x = vadd_f16(x, tmp);

float16x4_t z = vmul_f16(x, x);
@@ -115,7 +115,7 @@ static inline float16x4_t log_ps(float16x4_t x)

x = vadd_f16(x, y);
x = vfma_f16(x, e, vdup_n_f16(c_cephes_log_q2));
x = vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(x), invalid_mask)); // negative arg will be NAN
x = (float16x4_t)(vorr_u16((uint16x4_t)(x), invalid_mask)); // negative arg will be NAN
return x;
}

@@ -208,9 +208,9 @@ static inline float16x4_t exp_ps(float16x4_t x)

/* if greater, substract 1 */
uint16x4_t mask = vcgt_f16(tmp, fx);
mask = vand_u16(mask, vreinterpret_u16_f16(one));
mask = vand_u16(mask, (uint16x4_t)(one));

fx = vsub_f16(tmp, vreinterpret_f16_u16(mask));
fx = vsub_f16(tmp, (float16x4_t)(mask));

tmp = vmul_f16(fx, vdup_n_f16(c_cephes_exp_C1));
float16x4_t z = vmul_f16(fx, vdup_n_f16(c_cephes_exp_C2));
@@ -489,7 +489,7 @@ static inline float16x4_t tanh_ps(float16x4_t x)

// clamp the inputs to the range [-9, 9] since anything outside
// this range is -/+1.0f in single-precision.
x2 = vreinterpret_f16_u16(vbsl_u16(vcge_f16(vdup_n_f16(c_tanh_hi), x2), vreinterpret_u16_f16(x2), vreinterpret_u16_f16(vdup_n_f16(c_tanh_hi))));
x2 = (float16x4_t)(vbsl_u16(vcge_f16(vdup_n_f16(c_tanh_hi), x2), (uint16x4_t)(x2), (uint16x4_t)(vdup_n_f16(c_tanh_hi))));

// since the polynomials are odd/even, we need x**2.
float16x4_t z = vmul_f16(x2, x2);
@@ -514,10 +514,10 @@ static inline float16x4_t tanh_ps(float16x4_t x)
y = vdiv_f16(y, w);

// reinstate the sign.
y = vreinterpret_f16_u16(vbsl_u16(vdup_n_u16(1u << 15), vreinterpret_u16_f16(x), vreinterpret_u16_f16(y)));
y = (float16x4_t)(vbsl_u16(vdup_n_u16(1u << 15), (uint16x4_t)(x), (uint16x4_t)(y)));

// when the argument is very small in magnitude it's more accurate to just return it.
y = vreinterpret_f16_u16(vbsl_u16(tiny_mask, vreinterpret_u16_f16(y), vreinterpret_u16_f16(x)));
y = (float16x4_t)(vbsl_u16(tiny_mask, (uint16x4_t)(y), (uint16x4_t)(x)));

return y;
}


Loading…
Cancel
Save