diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc
index 69f3420..829b576 100644
--- a/parser/tensorflow/graph_optimizer.cc
+++ b/parser/tensorflow/graph_optimizer.cc
@@ -35,7 +35,8 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul");
 namespace ge {
 namespace {
 const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
-const char *const kShapeNodeName = "Shape";
+const char *const kShapeNodeType = "Shape";
+const char *const kShapeNodeNamePrefix = "getnext_shape_";
 }  // namespace
 
 Status ParserGraphOptimizer::FusionFmkop() {
@@ -62,19 +63,25 @@ Status ParserGraphOptimizer::FusionFmkop() {
   return SUCCESS;
 }
 
-Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluser_Map) {
+Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) {
   GE_CHECK_NOTNULL(graph_);
-  bool hasGetNext = false;
+  bool has_get_next = false;
   for (auto node : graph_->GetDirectNode()) {
     GE_CHECK_NOTNULL(node);
     GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue);
     string type = "";
     GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
     if (type == "IteratorGetNext") {
-      hasGetNext = true;
+      has_get_next = true;
       break;
     }
   }
+  return GetFusionCluster(has_get_next, node_cluster_map);
+}
+
+Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next,
+                                              unordered_map<string, vector<NodePtr>> &node_cluster_map) {
+  GE_CHECK_NOTNULL(graph_);
   for (auto node : graph_->GetDirectNode()) {
     GE_CHECK_NOTNULL(node);
     GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
@@ -97,7 +104,8 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map
           NodePtr dst_node = in_anchor->GetOwnerNode();
           GE_CHECK_NOTNULL(dst_node);
           GE_CHECK_NOTNULL(dst_node->GetOpDesc());
-          if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) {
+          if ((dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) &&
+              (dst_node->GetOpDesc()->GetType() == kShapeNodeType)) {
             temp_node_cluser.emplace_back(dst_node);
           }
         }
@@ -105,14 +113,14 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map
       if (temp_node_cluser.size() > 1) {
         vector<NodePtr> node_cluser;
         node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());
-        node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser;
+        node_cluster_map[temp_node_cluser[0]->GetName()] = node_cluser;
       }
       temp_node_cluser.clear();
       GELOGI("MarkForFusion, IteratorGetNext graph mark success.");
     }
 
-    if (!hasGetNext && (type == "Iterator" || type == "IteratorV2")) {
-      GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluser_Map), "find framework node to be fused fail.");
+    if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) {
+      GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail.");
       GELOGI("MarkForFusion, Iterator init graph mark success.");
     }
   }
diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h
index d004595..420c2b5 100644
--- a/parser/tensorflow/graph_optimizer.h
+++ b/parser/tensorflow/graph_optimizer.h
@@ -41,7 +41,9 @@ class ParserGraphOptimizer {
   domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const;
 
-  domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map);
+  domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map);
+
+  domi::Status GetFusionCluster(const bool has_get_next, unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map);
 
   domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes);
diff --git a/tests/depends/ops_stub/ops_stub.h b/tests/depends/ops_stub/ops_stub.h
index c3341da..d876ed4 100644
--- a/tests/depends/ops_stub/ops_stub.h
+++ b/tests/depends/ops_stub/ops_stub.h
@@ -316,6 +316,13 @@ REG_OP(Softmax)
     .ATTR(beta, Float, 0)
     .OP_END_FACTORY_REG(Softmax)
 
+REG_OP(Shape)
+    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
+                          DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
+    .ATTR(dtype, Int, DT_INT32)
+    .OP_END_FACTORY_REG(Shape)
+
 // for plugin
 static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
   return SUCCESS;
diff --git a/tests/st/testcase/origin_models/test_getnext_dynamic_fusion.pbtxt b/tests/st/testcase/origin_models/test_getnext_dynamic_fusion.pbtxt
new file mode 100644
index 0000000..80e0df9
--- /dev/null
+++ b/tests/st/testcase/origin_models/test_getnext_dynamic_fusion.pbtxt
@@ -0,0 +1,174 @@
+node {
+  name: "IteratorV2"
+  op: "IteratorV2"
+  attr {
+    key: "op_def"
+    value {
+      s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
+    }
+  }
+  attr {
+    key: "output_types"
+    value {
+      list {
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    key: "output_tensor_desc"
+    value {
+      list {
+        func {
+          name: "0"
+          attr {
+            key: "serialize_datatype"
+            value: {
+              i: 9
+            }
+          }
+          attr {
+            key: "serialize_format"
+            value: {
+              i: 1
+            }
+          }
+          attr {
+            key: "serialize_shape"
+            value {
+              type: DT_INT32
+            }
+          }
+        }
+      }
+    }
+  }
+}
+node {
+  name: "IteratorGetNext"
+  op: "IteratorGetNext"
+  input: "IteratorV2"
+  attr {
+    key: "output_types"
+    value {
+      list {
+        type: DT_INT64
+      }
+    }
+  }
+  attr {
+    key: "op_def"
+    value {
+      s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001"
+    }
+  }
+  attr {
+    key: "input_tensor_desc"
+    value {
+      list {
+        func {
+          name: "0"
+          attr {
+            key: "serialize_datatype"
+            value: {
+              i: 9
+            }
+          }
+          attr {
+            key: "serialize_format"
+            value: {
+              i: 1
+            }
+          }
+          attr {
+            key: "serialize_shape"
+            value {
+              type: DT_INT32
+            }
+          }
+        }
+      }
+    }
+  }
+  attr {
+    key: "output_tensor_desc"
+    value {
+      list {
+        func {
+          name: "0"
+          attr {
+            key: "serialize_datatype"
+            value: {
+              i: 9
+            }
+          }
+          attr {
+            key: "serialize_format"
+            value: {
+              i: 1
+            }
+          }
+          attr {
+            key: "serialize_shape"
+            value {
+              list {
+                i: -1
+                i: -1
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+node {
+  name: "getnext_shape_0"
+  op: "Shape"
+  input: "IteratorGetNext"
+  attr {
+    key: "op_def"
+    value {
+      s: "\n\005Shape\022\n\n\005input\"\001T\032\022\n\006output\"\010out_type\"\t\n\001T\022\004type\"\034\n\010out_type\022\004type\032\0020\003:\006\n\0042\002\003\t"
+    }
+  }
+}
+node {
+  name: "retval_GetNext_0_0"
+  op: "_Retval"
+  input: "IteratorGetNext"
+  attr {
+    key: "index"
+    value {
+      i: 0
+    }
+  }
+  attr {
+    key: "op_def"
+    value {
+      s: ""
+    }
+  }
+}
+node {
+  name: "retval_GetNext_0_1"
+  op: "_Retval"
+  input: "getnext_shape_0"
+  attr {
+    key: "index"
+    value {
+      i: 1
+    }
+  }
+  attr {
+    key: "op_def"
+    value {
+      s: ""
+    }
+  }
+}
+library {
+}
+versions {
+  producer: 134
+}
diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc
index 8a2ccdd..07346f9 100644
--- a/tests/st/testcase/test_tensorflow_parser.cc
+++ b/tests/st/testcase/test_tensorflow_parser.cc
@@ -4225,4 +4225,21 @@ TEST_F(STestTensorflowParser, parser_UppdateInputMap_test)
   delete graph;
 }
 
+TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) {
+  std::string caseDir = __FILE__;
+  std::size_t idx = caseDir.find_last_of("/");
+  caseDir = caseDir.substr(0, idx);
+  const std::string root_proto = caseDir + "/origin_models/test_getnext_dynamic_fusion.pbtxt";
+  domi::tensorflow::GraphDef graphDef;
+
+  bool protoRet = parser::ReadProtoFromText(root_proto.c_str(), &graphDef);
+  ASSERT_EQ(protoRet, true);
+
+  TensorFlowModelParser tensorflow_parser;
+  ge::ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
+  Status ret = tensorflow_parser.ParseProto(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
+  EXPECT_EQ(ret, SUCCESS);
+  EXPECT_EQ(root_graph->GetDirectNode().size(), 3);
+}
+
 }  // namespace ge