From f595f5b133ee7e1416f27ec2b3e0ad714c29b43a Mon Sep 17 00:00:00 2001
From: sunsuodong
Date: Tue, 26 Jan 2021 21:44:21 +0800
Subject: [PATCH] optimize sigmoid and tanh

---
 mindspore/lite/nnacl/fp16/activation_fp16.c | 59 ++++++++++++++++++---
 1 file changed, 53 insertions(+), 6 deletions(-)

diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.c b/mindspore/lite/nnacl/fp16/activation_fp16.c
index 6f9b878fc3..cb09c4c742 100644
--- a/mindspore/lite/nnacl/fp16/activation_fp16.c
+++ b/mindspore/lite/nnacl/fp16/activation_fp16.c
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp32/exp_fp32.h"
 #include "nnacl/errorcode.h"
 
 int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
@@ -60,8 +60,21 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha
 }
 
+// Element-wise sigmoid: dst[i] = 1 / (1 + exp(-src[i])) for ele_num fp16 values; returns NNACL_OK.
+// On ARM64, C4NUM elements per step go through simd_exp in fp32; the remainder uses single_exp.
 int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
-  for (int i = 0; i < ele_num; ++i) {
-    dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i]));
+  int i = 0;
+#ifdef ENABLE_ARM64
+  int count = (ele_num / C4NUM) * C4NUM;
+  for (; i < count; i += C4NUM) {
+    float32x4_t tmp;
+    simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp);
+    vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp))));
+  }
+#endif
+  for (; i < ele_num; ++i) {
+    float temp;
+    single_exp(-src[i], &temp);
+    dst[i] = (float16_t)1.0f / ((float16_t)1.0f + temp);
   }
   return NNACL_OK;
 }
@@ -80,8 +93,42 @@ float16_t TanhOptFp16(float16_t src) {
 }
 
+// Element-wise tanh for ele_num fp16 values, evaluated in fp32; returns NNACL_OK.
+// Uses x * (x^6 + 378x^4 + 17325x^2 + 135135) / (28x^6 + 3150x^4 + 62370x^2 + 135135).
+// That quotient grows like x / 28 for large |x|, so the result is clamped to [-1, 1].
 int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
-  for (int i = 0; i < ele_num; ++i) {
-    dst[i] = TanhOptFp16(src[i]);
+  int i = 0;
+#ifdef ENABLE_ARM64
+  static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
+                                 {17325.0f, 17325.0f, 17325.0f, 17325.0f},
+                                 {135135.0f, 135135.0f, 135135.0f, 135135.0f},
+                                 {28.0f, 28.0f, 28.0f, 28.0f},
+                                 {3150.0f, 3150.0f, 3150.0f, 3150.0f},
+                                 {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
+  const float32x4_t neg_one = vdupq_n_f32(-1.0f);
+  const float32x4_t pos_one = vdupq_n_f32(1.0f);
+  int count = (ele_num / C4NUM) * C4NUM;
+  for (; i < count; i += C4NUM) {
+    float32x4_t input = vcvt_f32_f16(vld1_f16(src + i));
+    float32x4_t square = vmulq_f32(input, input);
+    float32x4_t a = vmulq_f32(
+      vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]),
+      input);
+    float32x4_t b = vaddq_f32(
+      vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
+      paramv[2]);
+    // Keep the approximation inside tanh's range before narrowing back to fp16.
+    vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)));
+  }
+#endif
+  for (; i < ele_num; ++i) {
+    float input = src[i];
+    float square = input * input;
+    float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
+    float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
+    float res = a / b;
+    res = res > 1.0f ? 1.0f : res;
+    res = res < -1.0f ? -1.0f : res;
+    dst[i] = res;
   }
   return NNACL_OK;
 }