Browse Source

add swish kernel

tags/v1.1.0
sunsuodong 5 years ago
parent
commit
ce47d7ee49
8 changed files with 60 additions and 13 deletions
  1. +20
    -0
      mindspore/lite/nnacl/fp32/activation.c
  2. +1
    -0
      mindspore/lite/nnacl/fp32/activation.h
  3. +2
    -1
      mindspore/lite/schema/ops.fbs
  4. +2
    -0
      mindspore/lite/src/ops/activation.cc
  5. +15
    -12
      mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc
  6. +11
    -0
      mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc
  7. +4
    -0
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc
  8. +5
    -0
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h

+ 20
- 0
mindspore/lite/nnacl/fp32/activation.c View File

@@ -108,6 +108,26 @@ int Tanh(const float *src, int length, float *dst) {
return NNACL_OK; return NNACL_OK;
} }


int Swish(const float *src, int length, float *dst) {
int ret = Sigmoid(src, length, dst);
if (ret != NNACL_OK) {
return NNACL_ERR;
}
int index = 0;
#ifdef ENABLE_NEON
for (; index <= length - C4NUM; index += C4NUM) {
float32x4_t src_value = vld1q_f32(src + index);
float32x4_t sigmoid_value = vld1q_f32(dst + index);
float32x4_t result = vmulq_f32(src_value, sigmoid_value);
vst1q_f32(dst + index, result);
}
#endif
for (; index < length; index++) {
dst[index] = src[index] * dst[index];
}
return NNACL_OK;
}

int HSwish(const float *src, int length, float *dst) { int HSwish(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) { for (int i = 0; i < length; ++i) {
float in = src[i]; float in = src[i];


+ 1
- 0
mindspore/lite/nnacl/fp32/activation.h View File

@@ -37,6 +37,7 @@ int LRelu(const float *src, int length, float *dst, float alpha);
int Sigmoid(const float *src, int length, float *dst); int Sigmoid(const float *src, int length, float *dst);
int Tanh(const float *src, int length, float *dst); int Tanh(const float *src, int length, float *dst);
int HSigmoid(const float *src, int length, float *dst); int HSigmoid(const float *src, int length, float *dst);
int Swish(const float *src, int length, float *dst);
int HSwish(const float *src, int length, float *dst); int HSwish(const float *src, int length, float *dst);
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
#ifdef __cplusplus #ifdef __cplusplus


+ 2
- 1
mindspore/lite/schema/ops.fbs View File

@@ -79,7 +79,8 @@ enum ActivationType : byte {
LINEAR = 15, LINEAR = 15,
HARD_TANH = 16, HARD_TANH = 16,
SIGN = 17, SIGN = 17,
UNKNOW = 18
SWISH = 18,
UNKNOW = 19
} }
enum ActivationGradType : byte { enum ActivationGradType : byte {
NO_ACTIVATION = 0, NO_ACTIVATION = 0,


+ 2
- 0
mindspore/lite/src/ops/activation.cc View File

@@ -53,6 +53,8 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->type = schema::ActivationType_SIGMOID; attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "ReLU6") { } else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6; attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "Swish") {
attr->type = schema::ActivationType_SWISH;
} else if (prim.name() == "HSwish") { } else if (prim.name() == "HSwish") {
attr->type = schema::ActivationType_HSWISH; attr->type = schema::ActivationType_HSWISH;
} else if (prim.name() == "HSigmoid") { } else if (prim.name() == "HSigmoid") {


+ 15
- 12
mindspore/lite/src/runtime/kernel/arm/fp32/activation.cc View File

@@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH;
using mindspore::schema::ActivationType_LEAKY_RELU; using mindspore::schema::ActivationType_LEAKY_RELU;
using mindspore::schema::ActivationType_RELU; using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6; using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::ActivationType_SWISH;
using mindspore::schema::PrimitiveType_Activation; using mindspore::schema::PrimitiveType_Activation;


namespace mindspore::kernel { namespace mindspore::kernel {
@@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) {
int stride = UP_DIV(length, thread_count_); int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id); int count = MSMIN(stride, length - stride * task_id);


auto error_code = RET_OK;
auto ret = RET_OK;


if (type_ == schema::ActivationType_RELU) { if (type_ == schema::ActivationType_RELU) {
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_RELU6) { } else if (type_ == schema::ActivationType_RELU6) {
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_LEAKY_RELU) { } else if (type_ == schema::ActivationType_LEAKY_RELU) {
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_);
} else if (type_ == schema::ActivationType_SIGMOID) { } else if (type_ == schema::ActivationType_SIGMOID) {
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_TANH) { } else if (type_ == schema::ActivationType_TANH) {
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_SWISH) {
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSWISH) { } else if (type_ == schema::ActivationType_HSWISH) {
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HSIGMOID) { } else if (type_ == schema::ActivationType_HSIGMOID) {
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id);
} else if (type_ == schema::ActivationType_HARD_TANH) { } else if (type_ == schema::ActivationType_HARD_TANH) {
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_);
} else { } else {
MS_LOG(ERROR) << "Activation type error"; MS_LOG(ERROR) << "Activation type error";
return RET_ERROR; return RET_ERROR;
} }
if (error_code != RET_OK) {
return RET_ERROR;
if (ret != RET_OK) {
MS_LOG(ERROR) << "Activation error, ret: " << ret;
} }
return RET_OK;
return ret;
} }


int ActivationRun(void *cdata, int task_id) { int ActivationRun(void *cdata, int task_id) {


+ 11
- 0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/activation_fp32_test.cc View File

@@ -73,6 +73,17 @@ TEST_F(TestActivationFp32, SigmoidFp32) {
MS_LOG(INFO) << "TestSigmoidFp32 passed"; MS_LOG(INFO) << "TestSigmoidFp32 passed";
} }


TEST_F(TestActivationFp32, SwishFp32) {
float input[8] = {0, 1, 2, 3, 4, 5, 6, 7};
float output[8] = {0};
Swish(input, 8, output);

float expect[8] = {0, 0.731059, 1.761594, 2.857722, 3.928056, 4.966535, 5.985162, 6.993623};
for (int i = 0; i < 8; ++i) {
EXPECT_NEAR(output[i], expect[i], 0.00001);
}
}

TEST_F(TestActivationFp32, TanhFp32) { TEST_F(TestActivationFp32, TanhFp32) {
float input[7] = {-3, -2, -1, 0, 1, 2, 3}; float input[7] = {-3, -2, -1, 0, 1, 2, 3};
float output[7] = {0}; float output[7] = {0};


+ 4
- 0
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc View File

@@ -56,6 +56,9 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
} else if (std::strcmp(node_name, "Logistic") == 0) { } else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser"; MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID; attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "Swish") == 0) {
MS_LOG(DEBUG) << "parse TfliteSwishParser";
attr->type = schema::ActivationType_SWISH;
} else if (std::strcmp(node_name, "HardSwish") == 0) { } else if (std::strcmp(node_name, "HardSwish") == 0) {
MS_LOG(DEBUG) << "parse TfliteHardSwishParser"; MS_LOG(DEBUG) << "parse TfliteHardSwishParser";
attr->type = schema::ActivationType_HSWISH; attr->type = schema::ActivationType_HSWISH;
@@ -82,6 +85,7 @@ STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info,
TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser()); TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser()); TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser()); TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());
TfliteNodeRegister g_TfliteSwishParser("Swish", new TfliteSwishParser());
TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser()); TfliteNodeRegister g_TfliteHardSwishParser("HardSwish", new TfliteHardSwishParser());
TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser()); TfliteNodeRegister g_tfliteLogisticParser("Logistic", new TfliteLogisticParser());
TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser()); TfliteNodeRegister g_TfliteLeakyReluParser("LeakyRelu", new TfliteLeakyReluParser());


+ 5
- 0
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h View File

@@ -53,6 +53,11 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {} TfliteLogisticParser() : TfliteActivationParser() {}
}; };


class TfliteSwishParser : public TfliteActivationParser {
public:
TfliteSwishParser() : TfliteActivationParser() {}
};

class TfliteHardSwishParser : public TfliteActivationParser { class TfliteHardSwishParser : public TfliteActivationParser {
public: public:
TfliteHardSwishParser() : TfliteActivationParser() {} TfliteHardSwishParser() : TfliteActivationParser() {}


Loading…
Cancel
Save