| @@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| if (tf_op.op() == "Pad") { | |||
| prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); | |||
| prim->set_constant_value(0.0f); | |||
| } else if (tf_op.op() == "PadV2") { | |||
| prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); | |||
| if (tf_op.input_size() < 3) { | |||
| MS_LOG(ERROR) << "tf padv2 input size less than 3, which is " << tf_op.input_size(); | |||
| return nullptr; | |||
| } | |||
| auto &const_value_name = tf_op.input(2); | |||
| if (tf_node_map.find(const_value_name) == tf_node_map.end()) { | |||
| MS_LOG(ERROR) << "cannot find the input."; | |||
| return nullptr; | |||
| } | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) { | |||
| MS_LOG(ERROR) << "the input may be not const, which is not support now."; | |||
| return nullptr; | |||
| } | |||
| auto &tensor_proto = attr_value.tensor(); | |||
| if (tensor_proto.dtype() != tensorflow::DT_FLOAT) { | |||
| MS_LOG(ERROR) << "input data type only support float now."; | |||
| return nullptr; | |||
| } | |||
| prim->set_constant_value(tensor_proto.float_val(0)); | |||
| } else if (tf_op.op() == "MirrorPad") { | |||
| tensorflow::AttrValue attr_value; | |||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { | |||
| @@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, | |||
| return prim.release(); | |||
| } | |||
| TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); | |||
| TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser()); | |||
| TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -807,7 +807,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra | |||
| return false; | |||
| } | |||
| } | |||
| ResetSubGraphInput(); | |||
| return true; | |||
| } | |||
| @@ -868,7 +867,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap | |||
| return false; | |||
| } | |||
| } | |||
| ResetSubGraphInput(); | |||
| return true; | |||
| } | |||
| @@ -1029,6 +1027,7 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| // delete insert transpose op and update op output shape. | |||
| if (!ResetFuncGraph(func_graph)) { | |||
| MS_LOG(ERROR) << "reset func_graph failed."; | |||
| @@ -1059,11 +1058,13 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| // if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. | |||
| if (!DecreaseTransposeForSingleOp(func_graph)) { | |||
| MS_LOG(ERROR) << "run local trans insert optimizer failed."; | |||
| return false; | |||
| } | |||
| ResetSubGraphInput(); | |||
| // if input format of several ops surrounded only by transpose op all can be NHWC, | |||
| // we can delete these transpose ops, and at the same time, transform these middle ops. | |||
| if (!DecreaseTransposeForMultiOp(func_graph)) { | |||
| @@ -41,8 +41,8 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { | |||
| } | |||
| } | |||
| } | |||
| if (utils::isa<FuncGraphPtr>(node)) { | |||
| auto sub_graph = utils::cast<FuncGraphPtr>(node); | |||
| if (utils::isa<ValueNode>(node) && GetValueNode<FuncGraphPtr>(node) != nullptr) { | |||
| auto sub_graph = GetValueNode<FuncGraphPtr>(node); | |||
| auto status = ProcessGraph(sub_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "process sub graph failed"; | |||
| @@ -51,8 +51,10 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { | |||
| } | |||
| } | |||
| auto nodes = func_graph->nodes(); | |||
| auto graph_inputs = func_graph->get_inputs(); | |||
| for (auto &node : nodes) { | |||
| if (vis.find(node) == vis.end()) { | |||
| if (vis.find(node) == vis.end() && | |||
| std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()) { | |||
| func_graph->DropNode(node); | |||
| } | |||
| } | |||