|
|
|
@@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
|
if (tf_op.op() == "Pad") { |
|
|
|
prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); |
|
|
|
prim->set_constant_value(0.0f); |
|
|
|
|
|
|
|
} else if (tf_op.op() == "PadV2") { |
|
|
|
prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); |
|
|
|
if (tf_op.input_size() < 3) { |
|
|
|
MS_LOG(ERROR) << "tf padv2 input size less than 3, which is " << tf_op.input_size(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto &const_value_name = tf_op.input(2); |
|
|
|
if (tf_node_map.find(const_value_name) == tf_node_map.end()) { |
|
|
|
MS_LOG(ERROR) << "cannot find the input."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
tensorflow::AttrValue attr_value; |
|
|
|
if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) { |
|
|
|
MS_LOG(ERROR) << "the input may be not const, which is not support now."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto &tensor_proto = attr_value.tensor(); |
|
|
|
if (tensor_proto.dtype() != tensorflow::DT_FLOAT) { |
|
|
|
MS_LOG(ERROR) << "input data type only support float now."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
prim->set_constant_value(tensor_proto.float_val(0)); |
|
|
|
} else if (tf_op.op() == "MirrorPad") { |
|
|
|
tensorflow::AttrValue attr_value; |
|
|
|
if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { |
|
|
|
@@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, |
|
|
|
return prim.release(); |
|
|
|
} |
|
|
|
TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); |
|
|
|
TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser()); |
|
|
|
TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |