Browse Source

!14354 [MS_LITE] fix leakrelu

From: @YeFeng_24
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/14354/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
9bc02b1e69
2 changed files with 9 additions and 32 deletions
  1. +9
    -22
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc
  2. +0
    -10
      mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h

+ 9
- 22
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.cc View File

@@ -41,6 +41,14 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
prim->set_activation_type(mindspore::ActivationType::SELU); prim->set_activation_type(mindspore::ActivationType::SELU);
} else if (tf_op.op() == "Softplus") { } else if (tf_op.op() == "Softplus") {
prim->set_activation_type(mindspore::ActivationType::SOFTPLUS); prim->set_activation_type(mindspore::ActivationType::SOFTPLUS);
} else if (tf_op.op() == "LeakyRelu") {
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU);
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return nullptr;
}
prim->set_alpha(attr_value.f());
} else { } else {
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op();
return nullptr; return nullptr;
@@ -55,33 +63,12 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op,
return prim.release(); return prim.release();
} }


ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::LeakyRelu>();

tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) {
MS_LOG(ERROR) << "The attribute alpha should be specified.";
return nullptr;
}
prim->set_negative_slope(attr_value.f());

*output_size = 1;
if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
MS_LOG(ERROR) << "add op input failed";
return nullptr;
}

return prim.release();
}

TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser());
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser());
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser());
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser());
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser()); TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser());
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser());
TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser()); TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

+ 0
- 10
mindspore/lite/tools/converter/parser/tf/tf_activation_parser.h View File

@@ -33,16 +33,6 @@ class TFActivationParser : public TFNodeParser {
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override; std::vector<std::string> *inputs, int *output_size) override;
}; };

class TFLeakyReluParser : public TFNodeParser {
public:
TFLeakyReluParser() = default;
~TFLeakyReluParser() override = default;

ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore




Loading…
Cancel
Save