Browse Source

!4397 leaky relu parser to activation

Merge pull request !4397 from sunsuodong/leaky_relu
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a8ddcbdb04
2 changed files with 13 additions and 41 deletions
  1. +8
    -30
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc
  2. +5
    -11
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h

+ 8
- 30
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc View File

@@ -54,6 +54,14 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
} else if (std::strcmp(node_name, "Logistic") == 0) {
MS_LOG(DEBUG) << "parse TfliteLogisticParser";
attr->type = schema::ActivationType_SIGMOID;
} else if (std::strcmp(node_name, "LeakyRelu") == 0) {
const auto &option = tfliteOp->builtin_options.AsLeakyReluOptions();
if (option == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_LEAKY_RELU;
attr->alpha = option->alpha;
} else {
MS_LOG(ERROR) << "wrong activation type";
return RET_ERROR;
@@ -92,36 +100,6 @@ STATUS TflitePreluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite
return RET_OK;
}

STATUS TfliteLeakyReluParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet,
schema::CNodeT *op, TensorCache *tensor_cache, bool quantizedModel) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}

MS_LOG(DEBUG) << "parse TfliteLeakyReluParser";
std::unique_ptr<schema::LeakyReLUT> attr(new schema::LeakyReLUT());

const auto &tflite_attr = tfliteOp->builtin_options.AsLeakyReluOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->negativeSlope = tflite_attr->alpha;

op->primitive->value.type = schema::PrimitiveType_LeakyReLU;
op->primitive->value.value = attr.release();
return RET_OK;
}

TfliteNodeRegister g_TfliteReluParser("Relu", new TfliteReluParser());
TfliteNodeRegister g_TfliteRelu6Parser("Relu6", new TfliteRelu6Parser());
TfliteNodeRegister g_TfliteTanhParser("Tanh", new TfliteTanhParser());


+ 5
- 11
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.h View File

@@ -56,6 +56,11 @@ class TfliteLogisticParser : public TfliteActivationParser {
TfliteLogisticParser() : TfliteActivationParser() {}
};

class TfliteLeakyReluParser : public TfliteActivationParser {
public:
TfliteLeakyReluParser() : TfliteActivationParser() {}
};

class TflitePreluParser : public TfliteNodeParser {
public:
TflitePreluParser() : TfliteNodeParser("Prelu") {}
@@ -67,17 +72,6 @@ class TflitePreluParser : public TfliteNodeParser {
TensorCache *tensor_cache, bool quantized_model) override;
};

class TfliteLeakyReluParser : public TfliteNodeParser {
public:
TfliteLeakyReluParser() : TfliteNodeParser("LeakyRelu") {}

STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tfliteOp,
const std::vector<std::unique_ptr<tflite::TensorT>> &tfliteTensors,
const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer,
const std::vector<std::unique_ptr<tflite::OperatorCodeT>> &tfliteOpSet, schema::CNodeT *op,
TensorCache *tensor_cache, bool quantizedModel) override;
};

} // namespace lite
} // namespace mindspore



Loading…
Cancel
Save