diff --git a/mindspore/lite/nnacl/fp32/activation.c b/mindspore/lite/nnacl/fp32/activation.c index 1124fd67cc..08839f439c 100644 --- a/mindspore/lite/nnacl/fp32/activation.c +++ b/mindspore/lite/nnacl/fp32/activation.c @@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) { return NNACL_OK; } +int Swish(const float *src, int length, float *dst) { + int ret = Sigmoid(src, length, dst); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + int index = 0; +#ifdef ENABLE_NEON + for (; index <= length - C4NUM; index += C4NUM) { + float32x4_t src_value = vld1q_f32(src + index); + float32x4_t sigmoid_value = vld1q_f32(dst + index); + float32x4_t result = vmulq_f32(src_value, sigmoid_value); + vst1q_f32(dst + index, result); + } +#endif + for (; index < length; index++) { + dst[index] = src[index] * dst[index]; + } + return NNACL_OK; +} + int HSwish(const float *src, int length, float *dst) { for (int i = 0; i < length; ++i) { float in = src[i]; diff --git a/mindspore/lite/nnacl/fp32/activation.h b/mindspore/lite/nnacl/fp32/activation.h index bd85832342..24f32a54c2 100644 --- a/mindspore/lite/nnacl/fp32/activation.h +++ b/mindspore/lite/nnacl/fp32/activation.h @@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha); int Sigmoid(const float *src, int length, float *dst); int Tanh(const float *src, int length, float *dst); int HSigmoid(const float *src, int length, float *dst); +int Swish(const float *src, int length, float *dst); int HSwish(const float *src, int length, float *dst); int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); #ifdef __cplusplus diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 27656e5b22..4ad406898e 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -79,7 +79,8 @@ enum ActivationType : byte { LINEAR = 15, HARD_TANH = 16, SIGN = 17, - UNKNOW = 18 + SWISH = 18, + UNKNOW = 19 } enum ActivationGradType : byte { NO_ACTIVATION = 0, diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index 9021392dfa..24b7d51a95 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector attr->type = schema::ActivationType_SIGMOID; } else if (prim.name() == "ReLU6") { attr->type = schema::ActivationType_RELU6; + } else if (prim.name() == "Swish") { + attr->type = schema::ActivationType_SWISH; } else if (prim.name() == "HSwish") { attr->type = schema::ActivationType_HSWISH; } else if (prim.name() == "HSigmoid") { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc index 093374d585..b28091433a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc @@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH; using mindspore::schema::ActivationType_LEAKY_RELU; using mindspore::schema::ActivationType_RELU; using mindspore::schema::ActivationType_RELU6; +using mindspore::schema::ActivationType_SWISH; using mindspore::schema::PrimitiveType_Activation; namespace mindspore::kernel { @@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) { int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); - auto error_code = RET_OK; + auto ret = RET_OK; if (type_ == schema::ActivationType_RELU) { - error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_RELU6) { - error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_LEAKY_RELU) { - error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); + ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); } else if (type_ == schema::ActivationType_SIGMOID) { - error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_TANH) { - error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); + } else if (type_ == schema::ActivationType_SWISH) { + ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_HSWISH) { - error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_HSIGMOID) { - error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); + ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); } else if (type_ == schema::ActivationType_HARD_TANH) { - error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); + ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); } else { MS_LOG(ERROR) << "Activation type error"; return RET_ERROR; } - if (error_code != RET_OK) { - return RET_ERROR; + if (ret != RET_OK) { + MS_LOG(ERROR) << "Activation error, ret: " << ret; } - return RET_OK; + return ret; } int ActivationRun(void *cdata, int task_id) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc index 2461880f17..b5f8f87b3a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc @@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) { MS_LOG(INFO) << "TestSigmoidFp32 passed"; } +TEST_F(TestActivationFp32, SwishFp32) { + float input[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + float output[8] = {0}; + Swish(input, 8, output); + + float expect[8] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623}; + for (int i = 0; i < 8; ++i) { + EXPECT_NEAR(output[i], expect[i], 0.00001); + } +} + TEST_F(TestActivationFp32, TanhFp32) { float input[7] = {-3, -2, -1, 0, 1, 2, 3}; float output[7] = {0}; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc index b4c43fc635..456e609346 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc @@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, } else if (std::strcmp(node_name, "Logistic") == 0) { MS_LOG(DEBUG) << "parse TfliteLogisticParser"; attr->type = schema::ActivationType_SIGMOID; + } else if (std::strcmp(node_name, "Swish") == 0) { + MS_LOG(DEBUG) << "parse TfliteSwishParser"; + attr->type = schema::ActivationType_SWISH; } else if (std::strcmp(node_name, "HardSwish") == 0) { MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; attr->type = schema::ActivationType_HSWISH; @@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); +TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser()); TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h index d6d8a13fb6..cc849ed330 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h @@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser { TfliteLogisticParser() : TfliteActivationParser() {} }; +class TfliteSwishParser : public TfliteActivationParser { + public: + TfliteSwishParser() : TfliteActivationParser() {} +}; + class TfliteHardSwishParser : public TfliteActivationParser { public: TfliteHardSwishParser() : TfliteActivationParser() {}