From 6090bf0849ec0f3086a634aa8d160a43eb3ff2e2 Mon Sep 17 00:00:00 2001 From: lzk Date: Sat, 20 Mar 2021 20:10:58 -0700 Subject: [PATCH] gelu optimize --- mindspore/lite/nnacl/fp16/activation_fp16.c | 37 +++++++++ mindspore/lite/nnacl/fp16/activation_fp16.h | 1 + mindspore/lite/nnacl/fp32/activation_fp32.c | 77 ++++++++++--------- mindspore/lite/nnacl/fp32/activation_fp32.h | 2 +- .../lite/nnacl/fp32/conv_depthwise_fp32.c | 5 -- mindspore/lite/nnacl/fp32/gelu_fp32.c | 39 ---------- mindspore/lite/nnacl/fp32/gelu_fp32.h | 31 -------- .../nnacl/intrinsics/ms_simd_instructions.h | 53 +++++++++++++ mindspore/lite/nnacl/op_base.h | 2 +- .../kernel/arm/fp16/activation_fp16.cc | 3 + .../kernel/arm/fp32/activation_fp32.cc | 2 +- 11 files changed, 138 insertions(+), 114 deletions(-) delete mode 100644 mindspore/lite/nnacl/fp32/gelu_fp32.c delete mode 100644 mindspore/lite/nnacl/fp32/gelu_fp32.h diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.c b/mindspore/lite/nnacl/fp16/activation_fp16.c index d281d59d64..5528fc087a 100644 --- a/mindspore/lite/nnacl/fp16/activation_fp16.c +++ b/mindspore/lite/nnacl/fp16/activation_fp16.c @@ -168,3 +168,40 @@ int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val } return NNACL_OK; } + +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) +#ifdef ENABLE_NEON + int C8 = UP_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + float16x8_t in = vld1q_f16(src + i); + float16x8_t res = + 0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in)); + vst1q_f16(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = + 0.5 * src[i] * + (1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i])); + } + } else { +#ifdef ENABLE_NEON + int C8 = UP_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + float16x8_t in = vld1q_f16(src + i); + float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); + vst1q_f16(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp16/activation_fp16.h b/mindspore/lite/nnacl/fp16/activation_fp16.h index 6fdb3220a1..6463490490 100644 --- a/mindspore/lite/nnacl/fp16/activation_fp16.h +++ b/mindspore/lite/nnacl/fp16/activation_fp16.h @@ -34,6 +34,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num); int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num); int SwishFp16(const float16_t *src, float16_t *dst, int ele_num); int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val); +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/activation_fp32.c b/mindspore/lite/nnacl/fp32/activation_fp32.c index 7fc4a6012a..6df6935e22 100644 --- a/mindspore/lite/nnacl/fp32/activation_fp32.c +++ b/mindspore/lite/nnacl/fp32/activation_fp32.c @@ -134,50 +134,21 @@ float TanhOpt(float src) { int Tanh(const float *src, int length, float *dst) { int i = 0; -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX) - const int cnt = 6; - float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f}; -#endif - #if defined(ENABLE_AVX) - MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; - MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - MS_FLOAT32X8 param256[6]; - for (int j = 0; j < cnt; ++j) { - param256[j] = MS_MOV256_F32(data[j]); - } for (; i < length - 8; i += 8) { MS_FLOAT32X8 input = MS_LD256_F32(src + i); - MS_FLOAT32X8 square = input * input; - MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input; - MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2]; - MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8)); + MS_ST256_F32(dst + i, MS_TANHX8_F32(input)); } #endif #if defined(ENABLE_ARM) || defined(ENABLE_SSE) - MS_FLOAT32X4 param[6]; - MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; - MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; - for (int j = 0; j < cnt; ++j) { - param[j] = MS_MOVQ_F32(data[j]); - } for (; i < length - 4; i += 4) { MS_FLOAT32X4 input = MS_LDQ_F32(src + i); - MS_FLOAT32X4 square = input * input; - MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input; - MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2]; - MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one)); + MS_STQ_F32(dst + i, MS_TANHX4_F32(input)); } #endif for (; i < length; ++i) { - float input = src[i]; - float square = input * input; - float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input; - float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; - dst[i] = a / b; - dst[i] = MSMAX(dst[i], -1); - dst[i] = MSMIN(dst[i], 1); + dst[i] = TanhOpt(src[i]); } return NNACL_OK; } @@ -249,10 +220,44 @@ int HardTanh(const float *src, int length, float *dst, float min_val, float max_ return NNACL_OK; } -int Gelu(const float *src, int length, float *dst) { - for (int i = 0; i < length; ++i) { - float tanh_res = TanhOpt(sqrt(2 / M_PI) * (src[i] + 0.044715 * pow(src[i], 3))); - dst[i] = 0.5f * src[i] * (1 + tanh_res); +int Gelu(const float *src, int length, float *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) +#if defined(ENABLE_AVX) + int C8 = UP_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + MS_FLOAT32X8 in = MS_LD256_F32(src + i); + MS_FLOAT32X8 res = 0.5 * in * (1.0 + MS_TANHX8_F32((0.79788456080287f + 0.035677408136f * in * in) * in)); + MS_ST256_F32(dst + i, res); + } +#endif +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + int C4 = UP_ROUND(length, C4NUM); + for (; i < C4; i += C4NUM) { + MS_FLOAT32X4 in = MS_LDQ_F32(src + i); + MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in)); + MS_STQ_F32(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM) + int C4 = UP_ROUND(length, C4NUM); + for (; i < C4; i += C4NUM) { + MS_FLOAT32X4 in = MS_LDQ_F32(src + i); + MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_ERFX4_F32(in / 1.4142135623730951f)); + MS_STQ_F32(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); + } } return NNACL_OK; } diff --git a/mindspore/lite/nnacl/fp32/activation_fp32.h b/mindspore/lite/nnacl/fp32/activation_fp32.h index 8ce27b7df1..1f741c50b5 100644 --- a/mindspore/lite/nnacl/fp32/activation_fp32.h +++ b/mindspore/lite/nnacl/fp32/activation_fp32.h @@ -40,7 +40,7 @@ int HSigmoid(const float *src, int length, float *dst); int Swish(const float *src, int length, float *dst); int HSwish(const float *src, int length, float *dst); int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); -int Gelu(const float *src, int length, float *dst); +int Gelu(const float *src, int length, float *dst, bool approximate); float TanhOpt(float src); #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 7c2ff0ca63..9614274b14 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -17,11 +17,6 @@ #include "nnacl/fp32/conv_depthwise_fp32.h" #include "nnacl/common_func.h" #include "nnacl/fp32/common_func_fp32.h" -#include "nnacl/fp32/winograd_transform.h" -#include "nnacl/intrinsics/ms_simd_instructions.h" -#ifdef ENABLE_ARM64 -#include -#endif #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, diff --git a/mindspore/lite/nnacl/fp32/gelu_fp32.c b/mindspore/lite/nnacl/fp32/gelu_fp32.c deleted file mode 100644 index 3340208dfd..0000000000 --- a/mindspore/lite/nnacl/fp32/gelu_fp32.c +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "nnacl/fp32/gelu_fp32.h" -#include "nnacl/gelu_parameter.h" -#include -#include -#include "nnacl/errorcode.h" - -int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param) { - if (src == NULL || out == NULL) { - return NNACL_ERR; - } - - if (param->approximate_) { - for (int i = 0; i < real_dst_count; i++) { - out[i] = 0.5 * src[i] * (1.0 + tanh(0.7978845608028654 * (src[i] + 0.044715 * pow(src[i], 3)))); - } - } else { - for (int i = 0; i < real_dst_count; i++) { - out[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951)); - } - } - - return NNACL_OK; -} diff --git a/mindspore/lite/nnacl/fp32/gelu_fp32.h b/mindspore/lite/nnacl/fp32/gelu_fp32.h deleted file mode 100644 index e37f059382..0000000000 --- a/mindspore/lite/nnacl/fp32/gelu_fp32.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_NNACL_FP32_GELU_H_ -#define MINDSPORE_LITE_NNACL_FP32_GELU_H_ - -#include "nnacl/op_base.h" -#include "nnacl/gelu_parameter.h" - -#ifdef __cplusplus -extern "C" { -#endif - -int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param); -#ifdef __cplusplus -} -#endif - -#endif // MINDSPORE_LITE_NNACL_FP32_GELU_H_ diff --git a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h index a101a62f5c..6c1153d6d3 100644 --- a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h +++ b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ #define MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include #ifdef ENABLE_ARM #include #endif @@ -170,4 +171,56 @@ inline static float32x4_t vrecp(float32x4_t v) { MS_STQ_F32(output_ptr + 6 * num, dst##7); \ MS_STQ_F32(output_ptr + 7 * num, dst##8); +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = src * src; + MS_FLOAT32X4 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src; + MS_FLOAT32X4 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2]; + return MS_MINQ_F32(MS_MAXQ_F32(a / b, neg), pos); +} + +#ifdef ENABLE_AVX +static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) { + static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X8 square = src * src; + MS_FLOAT32X8 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src; + MS_FLOAT32X8 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2]; + return MS_MIN256_F32(MS_MAX256_F32(a / b, neg), pos); +} +#endif + +static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + return dst; +} + +#ifdef ENABLE_ARM64 +static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { + float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src)); + float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src)); + return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high))); +} + +static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + dst[4] = erff(src[4]); + dst[5] = erff(src[5]); + dst[6] = erff(src[6]); + dst[7] = erff(src[7]); + return dst; +} +#endif + #endif // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h index d7551c3f30..b11f0a74d2 100644 --- a/mindspore/lite/nnacl/op_base.h +++ b/mindspore/lite/nnacl/op_base.h @@ -21,7 +21,7 @@ #include #include #include -#if defined(ENBALE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM) +#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM) #include "nnacl/intrinsics/ms_simd_instructions.h" #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc index 09c49fc64a..ce2f2bcce5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/activation_fp16.cc @@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_GELU; using mindspore::schema::ActivationType_HSWISH; using mindspore::schema::ActivationType_LEAKY_RELU; using mindspore::schema::ActivationType_RELU; @@ -73,6 +74,8 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) { } else if (type_ == schema::ActivationType_HARD_TANH) { error_code = HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_); + } else if (type_ == schema::ActivationType_GELU) { + error_code = GeluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, true); } else { MS_LOG(ERROR) << "Activation fp16 not support type: " << type_; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc index 5193153bb3..677f679c52 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation_fp32.cc @@ -79,7 +79,7 @@ int ActivationCPUKernel::DoActivation(int task_id) { } else if (type_ == schema::ActivationType_HARD_TANH) { ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); } else if (type_ == schema::ActivationType_GELU) { - ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id, true); } else { MS_LOG(ERROR) << "Activation type error"; return RET_ERROR;