From 9a71b32dd6f24465cb0234de22104493ff3b418d Mon Sep 17 00:00:00 2001
From: yefeng
Date: Tue, 30 Mar 2021 09:36:30 +0800
Subject: [PATCH] fix_leakyRelu

Parse the TF LeakyRelu op directly in TFActivationParser by setting the
LEAKY_RELU activation type and the required alpha attribute, and remove
the now-redundant TFLeakyReluParser class and its registration.
---
 .../parser/tf/tf_activation_parser.cc | 31 ++++++-------------
 .../parser/tf/tf_activation_parser.h  | 10 ------
 2 files changed, 9 insertions(+), 32 deletions(-)

diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
index 3e000499a2..df1acd9f6f 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
@@ -41,6 +41,14 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
     prim->set_activation_type(mindspore::ActivationType::SELU);
   } else if (tf_op.op() == "Softplus") {
     prim->set_activation_type(mindspore::ActivationType::SOFTPLUS);
+  } else if (tf_op.op() == "LeakyRelu") {
+    prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
+    tensorflow::AttrValue attr_value;
+    if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
+      MS_LOG(ERROR) << "The attribute alpha should be specified.";
+      return nullptr;
+    }
+    prim->set_alpha(attr_value.f());
   } else {
     MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
     return nullptr;
@@ -55,33 +63,12 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
   return prim.release();
 }
 
-ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op,
-                                          const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
-                                          std::vector<std::string> *inputs, int *output_size) {
-  auto prim = std::make_unique<ops::LeakyRelu>();
-
-  tensorflow::AttrValue attr_value;
-  if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
-    MS_LOG(ERROR) << "The attribute alpha should be specified.";
-    return nullptr;
-  }
-  prim->set_negative_slope(attr_value.f());
-
-  *output_size = 1;
-  if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
-    MS_LOG(ERROR) << "add op input failed";
-    return nullptr;
-  }
-
-  return prim.release();
-}
-
 TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
 TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
 TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
 TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
 TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
-TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser());
+TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
 TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser());
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h
index 87dce5e6de..ece8aee5fd 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h
+++ b/mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h
@@ -33,16 +33,6 @@ class TFActivationParser : public TFNodeParser {
                           const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                           std::vector<std::string> *inputs, int *output_size) override;
 };
-
-class TFLeakyReluParser : public TFNodeParser {
- public:
-  TFLeakyReluParser() = default;
-  ~TFLeakyReluParser() override = default;
-
-  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
-                         const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
-                         std::vector<std::string> *inputs, int *output_size) override;
-};
 }  // namespace lite
 }  // namespace mindspore