@@ -168,3 +168,40 @@ int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val
   }
   return NNACL_OK;
 }
+
+int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) {
+  if (src == NULL || dst == NULL) {
+    return NNACL_ERR;
+  }
+  int i = 0;
+  if (approximate) {
+    // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715 * x^3)))
+#ifdef ENABLE_NEON
+    int C8 = length - length % C8NUM;  // whole 8-lane blocks only; the scalar loop below handles the tail
+    for (; i < C8; i += C8NUM) {
+      float16x8_t in = vld1q_f16(src + i);
+      float16x8_t res =
+        0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in));
+      vst1q_f16(dst + i, res);
+    }
+#endif
+    for (; i < length; i++) {
+      dst[i] =
+        0.5 * src[i] *
+        (1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i]));
+    }
+  } else {
+#ifdef ENABLE_NEON
+    int C8 = length - length % C8NUM;  // stop at the last full block instead of rounding up past the end
+    for (; i < C8; i += C8NUM) {
+      float16x8_t in = vld1q_f16(src + i);
+      float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
+      vst1q_f16(dst + i, res);
+    }
+#endif
+    for (; i < length; i++) {
+      dst[i] = 0.5 * src[i] * (1.0 + erff(src[i] / 1.4142135623730951f));
+    }
+  }
+  return NNACL_OK;
+}
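The two constants in the tanh branch are the usual GELU pair with sqrt(2/pi) folded into both terms:
sqrt(2/pi) ≈ 0.79788456080287 and 0.044715 * sqrt(2/pi) ≈ 0.035677408136, so that

  0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    = 0.5 * x * (1 + tanh((0.79788456080287 + 0.035677408136 * x * x) * x))

which trades the pow() of the textbook formula for two multiplies and vectorizes directly.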
@@ -34,6 +34,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
 int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
 int SwishFp16(const float16_t *src, float16_t *dst, int ele_num);
 int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val);
+int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate);
 #ifdef __cplusplus
 }
 #endif
@@ -134,50 +134,21 @@ float TanhOpt(float src) {
 int Tanh(const float *src, int length, float *dst) {
   int i = 0;
-#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
-  const int cnt = 6;
-  float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
-#endif
 #if defined(ENABLE_AVX)
-  MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
-  MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
-  MS_FLOAT32X8 param256[6];
-  for (int j = 0; j < cnt; ++j) {
-    param256[j] = MS_MOV256_F32(data[j]);
-  }
   for (; i < length - 8; i += 8) {
     MS_FLOAT32X8 input = MS_LD256_F32(src + i);
-    MS_FLOAT32X8 square = input * input;
-    MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
-    MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
-    MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
+    MS_ST256_F32(dst + i, MS_TANHX8_F32(input));
   }
 #endif
 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
-  MS_FLOAT32X4 param[6];
-  MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
-  MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
-  for (int j = 0; j < cnt; ++j) {
-    param[j] = MS_MOVQ_F32(data[j]);
-  }
   for (; i < length - 4; i += 4) {
     MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
-    MS_FLOAT32X4 square = input * input;
-    MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
-    MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
-    MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
+    MS_STQ_F32(dst + i, MS_TANHX4_F32(input));
   }
 #endif
   for (; i < length; ++i) {
-    float input = src[i];
-    float square = input * input;
-    float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input;
-    float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
-    dst[i] = a / b;
-    dst[i] = MSMAX(dst[i], -1);
-    dst[i] = MSMIN(dst[i], 1);
+    dst[i] = TanhOpt(src[i]);
   }
   return NNACL_OK;
 }
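TanhOpt and the new MS_TANHX4_F32 / MS_TANHX8_F32 helpers (see the ms_simd_instructions.h hunk below) all
evaluate the same rational approximation, clamped to [-1, 1] because the ratio overshoots for large |x|:

  tanh(x) ≈ x * (x^6 + 378*x^4 + 17325*x^2 + 135135) / (28*x^6 + 3150*x^4 + 62370*x^2 + 135135)

These are the coefficients of the degree-7/6 truncation of the continued fraction of tanh, so this refactor
only deduplicates code; it does not change the numerics.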
@@ -249,10 +220,44 @@ int HardTanh(const float *src, int length, float *dst, float min_val, float max_
   return NNACL_OK;
 }
 
-int Gelu(const float *src, int length, float *dst) {
-  for (int i = 0; i < length; ++i) {
-    float tanh_res = TanhOpt(sqrt(2 / M_PI) * (src[i] + 0.044715 * pow(src[i], 3)));
-    dst[i] = 0.5f * src[i] * (1 + tanh_res);
-  }
+int Gelu(const float *src, int length, float *dst, bool approximate) {
+  if (src == NULL || dst == NULL) {
+    return NNACL_ERR;
+  }
+  int i = 0;
+  if (approximate) {
+    // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715 * x^3)))
+#if defined(ENABLE_AVX)
+    int C8 = length - length % C8NUM;  // whole 8-lane blocks; the narrower loops below mop up the rest
+    for (; i < C8; i += C8NUM) {
+      MS_FLOAT32X8 in = MS_LD256_F32(src + i);
+      MS_FLOAT32X8 res = 0.5 * in * (1.0 + MS_TANHX8_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
+      MS_ST256_F32(dst + i, res);
+    }
+#endif
+#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+    int C4 = length - length % C4NUM;  // whole 4-lane blocks; scalar tail follows
+    for (; i < C4; i += C4NUM) {
+      MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
+      MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in));
+      MS_STQ_F32(dst + i, res);
+    }
+#endif
+    for (; i < length; i++) {
+      dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i]));
+    }
+  } else {
+#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
+    int C4 = length - length % C4NUM;  // whole 4-lane blocks; scalar tail follows
+    for (; i < C4; i += C4NUM) {
+      MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
+      MS_FLOAT32X4 res = 0.5 * in * (1.0 + MS_ERFX4_F32(in / 1.4142135623730951f));
+      MS_STQ_F32(dst + i, res);
+    }
+#endif
+    for (; i < length; i++) {
+      dst[i] = 0.5 * src[i] * (1.0 + erff(src[i] / 1.4142135623730951f));
+    }
+  }
   return NNACL_OK;
 }
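A minimal caller sketch for the new signature (hypothetical buffer names; assumes the declaration lives in
nnacl/fp32/activation_fp32.h, the header touched in the next hunk):

  #include "nnacl/fp32/activation_fp32.h"

  void GeluBothVariants(const float *in, float *out_erf, float *out_tanh, int length) {
    Gelu(in, length, out_erf, false);   // erf-based GELU
    Gelu(in, length, out_tanh, true);   // tanh approximation; what the CPU kernels pass
  }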
@@ -40,7 +40,7 @@ int HSigmoid(const float *src, int length, float *dst);
 int Swish(const float *src, int length, float *dst);
 int HSwish(const float *src, int length, float *dst);
 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
-int Gelu(const float *src, int length, float *dst);
+int Gelu(const float *src, int length, float *dst, bool approximate);
 float TanhOpt(float src);
 
 #ifdef __cplusplus
@@ -17,11 +17,6 @@
 #include "nnacl/fp32/conv_depthwise_fp32.h"
 #include "nnacl/common_func.h"
 #include "nnacl/fp32/common_func_fp32.h"
-#include "nnacl/fp32/winograd_transform.h"
-#include "nnacl/intrinsics/ms_simd_instructions.h"
-#ifdef ENABLE_ARM64
-#include <arm_neon.h>
-#endif
 
 #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
 void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
@@ -1,39 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "nnacl/fp32/gelu_fp32.h"
-#include "nnacl/gelu_parameter.h"
-#include <string.h>
-#include <math.h>
-#include "nnacl/errorcode.h"
-
-int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param) {
-  if (src == NULL || out == NULL) {
-    return NNACL_ERR;
-  }
-
-  if (param->approximate_) {
-    for (int i = 0; i < real_dst_count; i++) {
-      out[i] = 0.5 * src[i] * (1.0 + tanh(0.7978845608028654 * (src[i] + 0.044715 * pow(src[i], 3))));
-    }
-  } else {
-    for (int i = 0; i < real_dst_count; i++) {
-      out[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951));
-    }
-  }
-  return NNACL_OK;
-}
@@ -1,31 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_NNACL_FP32_GELU_H_
-#define MINDSPORE_LITE_NNACL_FP32_GELU_H_
-
-#include "nnacl/op_base.h"
-#include "nnacl/gelu_parameter.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-int DoGeLU(const float *src, float *out, int64_t real_dst_count, const GeLUParameter *param);
-
-#ifdef __cplusplus
-}
-#endif
-#endif  // MINDSPORE_LITE_NNACL_FP32_GELU_H_
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
 #define MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
+#include <math.h>
 
 #ifdef ENABLE_ARM
 #include <arm_neon.h>
 #endif
@@ -170,4 +171,56 @@ inline static float32x4_t vrecp(float32x4_t v) {
   MS_STQ_F32(output_ptr + 6 * num, dst##7); \
   MS_STQ_F32(output_ptr + 7 * num, dst##8);
+
+static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
+  static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
+  static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
+  static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
+  MS_FLOAT32X4 square = src * src;
+  MS_FLOAT32X4 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
+  MS_FLOAT32X4 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
+  return MS_MINQ_F32(MS_MAXQ_F32(a / b, neg), pos);
+}
+
+#ifdef ENABLE_AVX
+static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
+  static const float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
+  static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
+  static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  MS_FLOAT32X8 square = src * src;
+  MS_FLOAT32X8 a = (((square + data[0]) * square + data[1]) * square + data[2]) * src;
+  MS_FLOAT32X8 b = ((data[3] * square + data[4]) * square + data[5]) * square + data[2];
+  return MS_MIN256_F32(MS_MAX256_F32(a / b, neg), pos);
+}
+#endif
+
+static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
+  MS_FLOAT32X4 dst;
+  dst[0] = erff(src[0]);
+  dst[1] = erff(src[1]);
+  dst[2] = erff(src[2]);
+  dst[3] = erff(src[3]);
+  return dst;
+}
+
+#ifdef ENABLE_ARM64
+static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
+  float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src));
+  float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src));
+  return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high)));
+}
+
+static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
+  float16x8_t dst;
+  dst[0] = erff(src[0]);
+  dst[1] = erff(src[1]);
+  dst[2] = erff(src[2]);
+  dst[3] = erff(src[3]);
+  dst[4] = erff(src[4]);
+  dst[5] = erff(src[5]);
+  dst[6] = erff(src[6]);
+  dst[7] = erff(src[7]);
+  return dst;
+}
+#endif
+
 #endif  // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
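For unit tests, a scalar mirror of the clamped rational tanh above can serve as a reference (a sketch, not
part of this patch):

  #include <math.h>

  // Same polynomial and clamp as MS_TANHX4_F32 / MS_TANHX8_F32, one lane at a time.
  static float TanhRationalRef(float x) {
    float s = x * x;
    float a = (((s + 378.0f) * s + 17325.0f) * s + 135135.0f) * x;
    float b = ((28.0f * s + 3150.0f) * s + 62370.0f) * s + 135135.0f;
    return fminf(fmaxf(a / b, -1.0f), 1.0f);
  }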
@@ -21,7 +21,7 @@
 #include <stdlib.h>
 #include <stdbool.h>
 #include <string.h>
-#if defined(ENBALE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
+#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
 #include "nnacl/intrinsics/ms_simd_instructions.h"
 #endif
@@ -25,6 +25,7 @@ using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
+using mindspore::schema::ActivationType_GELU;
 using mindspore::schema::ActivationType_HSWISH;
 using mindspore::schema::ActivationType_LEAKY_RELU;
 using mindspore::schema::ActivationType_RELU;
@@ -73,6 +74,8 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
   } else if (type_ == schema::ActivationType_HARD_TANH) {
     error_code =
       HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
+  } else if (type_ == schema::ActivationType_GELU) {
+    error_code = GeluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, true);
   } else {
     MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
     return RET_ERROR;
@@ -79,7 +79,7 @@ int ActivationCPUKernel::DoActivation(int task_id) {
   } else if (type_ == schema::ActivationType_HARD_TANH) {
     ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
   } else if (type_ == schema::ActivationType_GELU) {
-    ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id);
+    ret = Gelu(input_addr + stride * task_id, count, output_addr + stride * task_id, true);
   } else {
     MS_LOG(ERROR) << "Activation type error";
     return RET_ERROR;
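Note that both CPU kernels pass approximate = true, so ActivationType_GELU always resolves to the tanh
variant; the erf-based path is reachable only by calling Gelu or GeluFp16 directly.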