From: @xu_anyue Reviewed-by: @hangangqiang,@jpc_chenjianping Signed-off-by: @hangangqiangpull/15814/MERGE
| @@ -31,7 +31,28 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| if (tf_op.op() == "Pad") { | if (tf_op.op() == "Pad") { | ||||
| prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); | prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); | ||||
| prim->set_constant_value(0.0f); | prim->set_constant_value(0.0f); | ||||
| } else if (tf_op.op() == "PadV2") { | |||||
| prim->set_padding_mode(mindspore::PaddingMode::CONSTANT); | |||||
| if (tf_op.input_size() < 3) { | |||||
| MS_LOG(ERROR) << "tf padv2 input size less than 3, which is " << tf_op.input_size(); | |||||
| return nullptr; | |||||
| } | |||||
| auto &const_value_name = tf_op.input(2); | |||||
| if (tf_node_map.find(const_value_name) == tf_node_map.end()) { | |||||
| MS_LOG(ERROR) << "cannot find the input."; | |||||
| return nullptr; | |||||
| } | |||||
| tensorflow::AttrValue attr_value; | |||||
| if (!TensorFlowUtils::FindAttrValue(*tf_node_map.at(const_value_name), "value", &attr_value)) { | |||||
| MS_LOG(ERROR) << "the input may be not const, which is not support now."; | |||||
| return nullptr; | |||||
| } | |||||
| auto &tensor_proto = attr_value.tensor(); | |||||
| if (tensor_proto.dtype() != tensorflow::DT_FLOAT) { | |||||
| MS_LOG(ERROR) << "input data type only support float now."; | |||||
| return nullptr; | |||||
| } | |||||
| prim->set_constant_value(tensor_proto.float_val(0)); | |||||
| } else if (tf_op.op() == "MirrorPad") { | } else if (tf_op.op() == "MirrorPad") { | ||||
| tensorflow::AttrValue attr_value; | tensorflow::AttrValue attr_value; | ||||
| if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { | if (!TensorFlowUtils::FindAttrValue(tf_op, "mode", &attr_value)) { | ||||
| @@ -58,6 +79,7 @@ ops::PrimitiveC *TFPadParser::Parse(const tensorflow::NodeDef &tf_op, | |||||
| return prim.release(); | return prim.release(); | ||||
| } | } | ||||
| TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); | TFNodeRegistrar g_tfPadParser("Pad", new TFPadParser()); | ||||
| TFNodeRegistrar g_tfPadV2Parser("PadV2", new TFPadParser()); | |||||
| TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); | TFNodeRegistrar g_tfMirrorPadParser("MirrorPad", new TFPadParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -807,7 +807,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| ResetSubGraphInput(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -868,7 +867,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| ResetSubGraphInput(); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1029,6 +1027,7 @@ bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { | |||||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | MS_LOG(ERROR) << "run framework transpose unify failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| ResetSubGraphInput(); | |||||
| // delete insert transpose op and update op output shape. | // delete insert transpose op and update op output shape. | ||||
| if (!ResetFuncGraph(func_graph)) { | if (!ResetFuncGraph(func_graph)) { | ||||
| MS_LOG(ERROR) << "reset func_graph failed."; | MS_LOG(ERROR) << "reset func_graph failed."; | ||||
| @@ -1059,11 +1058,13 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | MS_LOG(ERROR) << "run framework transpose unify failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| ResetSubGraphInput(); | |||||
| // if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. | // if input format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. | ||||
| if (!DecreaseTransposeForSingleOp(func_graph)) { | if (!DecreaseTransposeForSingleOp(func_graph)) { | ||||
| MS_LOG(ERROR) << "run local trans insert optimizer failed."; | MS_LOG(ERROR) << "run local trans insert optimizer failed."; | ||||
| return false; | return false; | ||||
| } | } | ||||
| ResetSubGraphInput(); | |||||
| // if input format of several ops surrounded only by transpose op all can be NHWC, | // if input format of several ops surrounded only by transpose op all can be NHWC, | ||||
| // we can delete these transpose ops, and at the same time, transform these middle ops. | // we can delete these transpose ops, and at the same time, transform these middle ops. | ||||
| if (!DecreaseTransposeForMultiOp(func_graph)) { | if (!DecreaseTransposeForMultiOp(func_graph)) { | ||||
| @@ -41,8 +41,8 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (utils::isa<FuncGraphPtr>(node)) { | |||||
| auto sub_graph = utils::cast<FuncGraphPtr>(node); | |||||
| if (utils::isa<ValueNode>(node) && GetValueNode<FuncGraphPtr>(node) != nullptr) { | |||||
| auto sub_graph = GetValueNode<FuncGraphPtr>(node); | |||||
| auto status = ProcessGraph(sub_graph); | auto status = ProcessGraph(sub_graph); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "process sub graph failed"; | MS_LOG(ERROR) << "process sub graph failed"; | ||||
| @@ -51,8 +51,10 @@ STATUS UnusedNodeRemovePass::ProcessGraph(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| } | } | ||||
| auto nodes = func_graph->nodes(); | auto nodes = func_graph->nodes(); | ||||
| auto graph_inputs = func_graph->get_inputs(); | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (vis.find(node) == vis.end()) { | |||||
| if (vis.find(node) == vis.end() && | |||||
| std::find(graph_inputs.begin(), graph_inputs.end(), node) == graph_inputs.end()) { | |||||
| func_graph->DropNode(node); | func_graph->DropNode(node); | ||||
| } | } | ||||
| } | } | ||||