| @@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) { | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int Swish(const float *src, int length, float *dst) { | |||||
| int ret = Sigmoid(src, length, dst); | |||||
| if (ret != NNACL_OK) { | |||||
| return NNACL_ERR; | |||||
| } | |||||
| int index = 0; | |||||
| #ifdef ENABLE_NEON | |||||
| for (; index <= length - C4NUM; index += C4NUM) { | |||||
| float32x4_t src_value = vld1q_f32(src + index); | |||||
| float32x4_t sigmoid_value = vld1q_f32(dst + index); | |||||
| float32x4_t result = vmulq_f32(src_value, sigmoid_value); | |||||
| vst1q_f32(dst + index, result); | |||||
| } | |||||
| #endif | |||||
| for (; index < length; index++) { | |||||
| dst[index] = src[index] * dst[index]; | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int HSwish(const float *src, int length, float *dst) { | int HSwish(const float *src, int length, float *dst) { | ||||
| for (int i = 0; i < length; ++i) { | for (int i = 0; i < length; ++i) { | ||||
| float in = src[i]; | float in = src[i]; | ||||
| @@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha); | |||||
| int Sigmoid(const float *src, int length, float *dst); | int Sigmoid(const float *src, int length, float *dst); | ||||
| int Tanh(const float *src, int length, float *dst); | int Tanh(const float *src, int length, float *dst); | ||||
| int HSigmoid(const float *src, int length, float *dst); | int HSigmoid(const float *src, int length, float *dst); | ||||
| int Swish(const float *src, int length, float *dst); | |||||
| int HSwish(const float *src, int length, float *dst); | int HSwish(const float *src, int length, float *dst); | ||||
| int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); | int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -79,7 +79,8 @@ enum ActivationType : byte { | |||||
| LINEAR = 15, | LINEAR = 15, | ||||
| HARD_TANH = 16, | HARD_TANH = 16, | ||||
| SIGN = 17, | SIGN = 17, | ||||
| UNKNOW = 18 | |||||
| SWISH = 18, | |||||
| UNKNOW = 19 | |||||
| } | } | ||||
| enum ActivationGradType : byte { | enum ActivationGradType : byte { | ||||
| NO_ACTIVATION = 0, | NO_ACTIVATION = 0, | ||||
| @@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (prim.name() == "ReLU6") { | } else if (prim.name() == "ReLU6") { | ||||
| attr->type = schema::ActivationType_RELU6; | attr->type = schema::ActivationType_RELU6; | ||||
| } else if (prim.name() == "Swish") { | |||||
| attr->type = schema::ActivationType_SWISH; | |||||
| } else if (prim.name() == "HSwish") { | } else if (prim.name() == "HSwish") { | ||||
| attr->type = schema::ActivationType_HSWISH; | attr->type = schema::ActivationType_HSWISH; | ||||
| } else if (prim.name() == "HSigmoid") { | } else if (prim.name() == "HSigmoid") { | ||||
| @@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH; | |||||
| using mindspore::schema::ActivationType_LEAKY_RELU; | using mindspore::schema::ActivationType_LEAKY_RELU; | ||||
| using mindspore::schema::ActivationType_RELU; | using mindspore::schema::ActivationType_RELU; | ||||
| using mindspore::schema::ActivationType_RELU6; | using mindspore::schema::ActivationType_RELU6; | ||||
| using mindspore::schema::ActivationType_SWISH; | |||||
| using mindspore::schema::PrimitiveType_Activation; | using mindspore::schema::PrimitiveType_Activation; | ||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| @@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) { | |||||
| int stride = UP_DIV(length, thread_count_); | int stride = UP_DIV(length, thread_count_); | ||||
| int count = MSMIN(stride, length - stride * task_id); | int count = MSMIN(stride, length - stride * task_id); | ||||
| auto error_code = RET_OK; | |||||
| auto ret = RET_OK; | |||||
| if (type_ == schema::ActivationType_RELU) { | if (type_ == schema::ActivationType_RELU) { | ||||
| error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_RELU6) { | } else if (type_ == schema::ActivationType_RELU6) { | ||||
| error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_LEAKY_RELU) { | } else if (type_ == schema::ActivationType_LEAKY_RELU) { | ||||
| error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); | |||||
| ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); | |||||
| } else if (type_ == schema::ActivationType_SIGMOID) { | } else if (type_ == schema::ActivationType_SIGMOID) { | ||||
| error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_TANH) { | } else if (type_ == schema::ActivationType_TANH) { | ||||
| error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_SWISH) { | |||||
| ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_HSWISH) { | } else if (type_ == schema::ActivationType_HSWISH) { | ||||
| error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_HSIGMOID) { | } else if (type_ == schema::ActivationType_HSIGMOID) { | ||||
| error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); | |||||
| } else if (type_ == schema::ActivationType_HARD_TANH) { | } else if (type_ == schema::ActivationType_HARD_TANH) { | ||||
| error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); | |||||
| ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Activation type error"; | MS_LOG(ERROR) << "Activation type error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (error_code != RET_OK) { | |||||
| return RET_ERROR; | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Activation error, ret: " << ret; | |||||
| } | } | ||||
| return RET_OK; | |||||
| return ret; | |||||
| } | } | ||||
| int ActivationRun(void *cdata, int task_id) { | int ActivationRun(void *cdata, int task_id) { | ||||
| @@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) { | |||||
| MS_LOG(INFO) << "TestSigmoidFp32 passed"; | MS_LOG(INFO) << "TestSigmoidFp32 passed"; | ||||
| } | } | ||||
| TEST_F(TestActivationFp32, SwishFp32) { | |||||
| float input[8] = {0, 1, 2, 3, 4, 5, 6, 7}; | |||||
| float output[8] = {0}; | |||||
| Swish(input, 8, output); | |||||
| float expect[8] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623}; | |||||
| for (int i = 0; i < 8; ++i) { | |||||
| EXPECT_NEAR(output[i], expect[i], 0.00001); | |||||
| } | |||||
| } | |||||
| TEST_F(TestActivationFp32, TanhFp32) { | TEST_F(TestActivationFp32, TanhFp32) { | ||||
| float input[7] = {-3, -2, -1, 0, 1, 2, 3}; | float input[7] = {-3, -2, -1, 0, 1, 2, 3}; | ||||
| float output[7] = {0}; | float output[7] = {0}; | ||||
| @@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| } else if (std::strcmp(node_name, "Logistic") == 0) { | } else if (std::strcmp(node_name, "Logistic") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | MS_LOG(DEBUG) << "parse TfliteLogisticParser"; | ||||
| attr->type = schema::ActivationType_SIGMOID; | attr->type = schema::ActivationType_SIGMOID; | ||||
| } else if (std::strcmp(node_name, "Swish") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSwishParser"; | |||||
| attr->type = schema::ActivationType_SWISH; | |||||
| } else if (std::strcmp(node_name, "HardSwish") == 0) { | } else if (std::strcmp(node_name, "HardSwish") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; | MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; | ||||
| attr->type = schema::ActivationType_HSWISH; | attr->type = schema::ActivationType_HSWISH; | ||||
| @@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); | ||||
| TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); | ||||
| TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); | ||||
| TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser()); | |||||
| TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); | TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); | ||||
| TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); | ||||
| TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); | ||||
| @@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser { | |||||
| TfliteLogisticParser() : TfliteActivationParser() {} | TfliteLogisticParser() : TfliteActivationParser() {} | ||||
| }; | }; | ||||
| class TfliteSwishParser : public TfliteActivationParser { | |||||
| public: | |||||
| TfliteSwishParser() : TfliteActivationParser() {} | |||||
| }; | |||||
| class TfliteHardSwishParser : public TfliteActivationParser { | class TfliteHardSwishParser : public TfliteActivationParser { | ||||
| public: | public: | ||||
| TfliteHardSwishParser() : TfliteActivationParser() {} | TfliteHardSwishParser() : TfliteActivationParser() {} | ||||