diff --git a/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc index 489a8c3603..04bc128f5c 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_pad_parser.cc @@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, if (tf_op.op() == "Pad") { prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); prim->set_constant_value(0.0f); - + } else if (tf_op.op() == "PadV2") { + prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); + if (tf_op.input_size() < 3) { + MS_LOG(ERROR) << "tf padv2 input size less than 3, which is " << tf_op.input_size(); + return nullptr; + } + auto &const_value_name = tf_op.input(2); + if (tf_node_map.find(const_value_name) == tf_node_map.end()) { + MS_LOG(ERROR) << "cannot find the input."; + return nullptr; + } + tensorflow::AttrValue attr_value; + if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) { + MS_LOG(ERROR) << "the input may be not const, which is not support now."; + return nullptr; + } + auto &tensor_proto = attr_value.tensor(); + if (tensor_proto.dtype() != tensorflow::DT_FLOAT) { + MS_LOG(ERROR) << "input data type only support float now."; + return nullptr; + } + prim->set_constant_value(tensor_proto.float_val(0)); } else if (tf_op.op() == "MirrorPad") { tensorflow::AttrValue attr_value; if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { @@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, return prim.release(); } TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); +TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser()); TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc 
b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc index 5d1dfafe7c..617e2c620f 100644 --- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc @@ -807,7 +807,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra return false; } } - ResetSubGraphInput(); return true; } @@ -868,7 +867,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap return false; } } - ResetSubGraphInput(); return true; } @@ -1029,6 +1027,7 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "run framework transpose unify failed."; return false; } + ResetSubGraphInput(); // delete insert transpose op and update op output shape. if (!ResetFuncGraph(func_graph)) { MS_LOG(ERROR) << "reset func_graph failed."; @@ -1059,11 +1058,13 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "run framework transpose unify failed."; return false; } + ResetSubGraphInput(); // if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. if (!DecreaseTransposeForSingleOp(func_graph)) { MS_LOG(ERROR) << "run local trans insert optimizer failed."; return false; } + ResetSubGraphInput(); // if input format of several ops surrounded only by transpose op all can be NHWC, // we can delete these transpose ops, and at the same time, transform these middle ops. 
if (!DecreaseTransposeForMultiOp(func_graph)) { diff --git a/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc index 8f0466bc03..3481597cc1 100644 --- a/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unused_node_remove_pass.cc @@ -41,8 +41,8 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { } } } - if (utils::isa<ValueNodePtr>(node)) { + auto sub_graph = utils::cast<FuncGraphPtr>(node); + if (utils::isa<ValueNodePtr>(node) && GetValueNode<FuncGraphPtr>(node) != nullptr) { + auto sub_graph = GetValueNode<FuncGraphPtr>(node); auto status = ProcessGraph(sub_graph); if (status != RET_OK) { MS_LOG(ERROR) << "process sub graph failed"; @@ -51,8 +51,10 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { auto nodes = func_graph->nodes(); + auto graph_inputs = func_graph->get_inputs(); for (auto &node : nodes) { - if (vis.find(node) == vis.end()) { + if (vis.find(node) == vis.end() && + std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()) { func_graph->DropNode(node); } }