|
|
|
@@ -41,6 +41,14 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
|
prim->set_activation_type(mindspore::ActivationType::SELU); |
|
|
|
} else if (tf_op.op() == "Softplus") { |
|
|
|
prim->set_activation_type(mindspore::ActivationType::SOFTPLUS); |
|
|
|
} else if (tf_op.op() == "LeakyRelu") { |
|
|
|
prim->set_activation_type(mindspore::ActivationType::LEAKY_RELU); |
|
|
|
tensorflow::AttrValue attr_value; |
|
|
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) { |
|
|
|
MS_LOG(ERROR) << "The attribute alpha should be specified."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
prim->set_alpha(attr_value.f()); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "unsupported activation type:" << tf_op.op(); |
|
|
|
return nullptr; |
|
|
|
@@ -55,33 +63,12 @@ ops::PrimitiveC *TFActivationParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
|
return prim.release(); |
|
|
|
} |
|
|
|
|
|
|
|
ops::PrimitiveC *TFLeakyReluParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
|
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, |
|
|
|
std::vector<std::string> *inputs, int *output_size) { |
|
|
|
auto prim = std::make_unique<ops::LeakyRelu>(); |
|
|
|
|
|
|
|
tensorflow::AttrValue attr_value; |
|
|
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "alpha", &attr_value)) { |
|
|
|
MS_LOG(ERROR) << "The attribute alpha should be specified."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
prim->set_negative_slope(attr_value.f()); |
|
|
|
|
|
|
|
*output_size = 1; |
|
|
|
if (AddOpInput(tf_op, 0, inputs) != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "add op input failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
return prim.release(); |
|
|
|
} |
|
|
|
|
|
|
|
TFNodeRegistrar g_tfReluParser("Relu", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfRelu6Parser("Relu6", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfSigmoidParser("Sigmoid", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfTanhParser("Tanh", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfSeLUParser("Selu", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFLeakyReluParser()); |
|
|
|
TFNodeRegistrar g_tfLeakyReluParser("LeakyRelu", new TFActivationParser()); |
|
|
|
TFNodeRegistrar g_tfSoftplusParser("Softplus", new TFActivationParser()); |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |