| @@ -35,7 +35,8 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul"); | |||
| namespace ge { | |||
| namespace { | |||
| // Name suffix constant; no use of it is visible in this excerpt. | |||
| const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; | |||
| // NOTE(review): kShapeNodeName and kShapeNodeType hold the same "Shape" literal; they look like the | |||
| // before/after sides of a rename in this diff -- confirm that only one of them survives. | |||
| const char *const kShapeNodeName = "Shape"; | |||
| // Op type string compared against dst_node->GetOpDesc()->GetType() when building a cluster. | |||
| const char *const kShapeNodeType = "Shape"; | |||
| // Substring searched for in dst_node->GetName(). The visible check is find() != npos, so it matches | |||
| // anywhere in the name, not only at the start -- NOTE(review): confirm that is intended for a "prefix". | |||
| const char *const kShapeNodeNamePrefix = "getnext_shape_"; | |||
| } // namespace | |||
| Status ParserGraphOptimizer::FusionFmkop() { | |||
| @@ -62,19 +63,25 @@ Status ParserGraphOptimizer::FusionFmkop() { | |||
| return SUCCESS; | |||
| } | |||
| // Determines whether graph_ contains a framework op whose original type is "IteratorGetNext", then | |||
| // passes that flag together with the caller's cluster map to GetFusionCluster and returns its status. | |||
| // NOTE(review): the adjacent near-duplicate lines below (node_cluser_Map / node_cluster_map and | |||
| // hasGetNext / has_get_next) appear to be the removed and added sides of a rename in this diff. | |||
| Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluser_Map) { | |||
| Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) { | |||
| GE_CHECK_NOTNULL(graph_); | |||
| bool hasGetNext = false; | |||
| bool has_get_next = false; | |||
| // Only direct nodes whose op type is FRAMEWORK_OP_TYPE are inspected; all others are skipped. | |||
| for (auto node : graph_->GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); | |||
| string type = ""; | |||
| GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | |||
| // The first IteratorGetNext found is enough, so the scan stops there. | |||
| if (type == "IteratorGetNext") { | |||
| hasGetNext = true; | |||
| has_get_next = true; | |||
| break; | |||
| } | |||
| } | |||
| return GetFusionCluster(has_get_next, node_cluster_map); | |||
| } | |||
| Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, | |||
| unordered_map<string, vector<NodePtr>> &node_cluster_map) { | |||
| GE_CHECK_NOTNULL(graph_); | |||
| for (auto node : graph_->GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | |||
| @@ -97,7 +104,8 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||
| NodePtr dst_node = in_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(dst_node); | |||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||
| if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) { | |||
| if ((dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) && | |||
| (dst_node->GetOpDesc()->GetType() == kShapeNodeType)) { | |||
| temp_node_cluser.emplace_back(dst_node); | |||
| } | |||
| } | |||
| @@ -105,14 +113,14 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||
| if (temp_node_cluser.size() > 1) { | |||
| vector<NodePtr> node_cluser; | |||
| node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); | |||
| node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser; | |||
| node_cluster_map[temp_node_cluser[0]->GetName()] = node_cluser; | |||
| } | |||
| temp_node_cluser.clear(); | |||
| GELOGI("MarkForFusion, IteratorGetNext graph mark success."); | |||
| } | |||
| if (!hasGetNext && (type == "Iterator" || type == "IteratorV2")) { | |||
| GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluser_Map), "find framework node to be fused fail."); | |||
| if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) { | |||
| GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); | |||
| GELOGI("MarkForFusion, Iterator init graph mark success."); | |||
| } | |||
| } | |||
| @@ -41,7 +41,9 @@ class ParserGraphOptimizer { | |||
| domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const; | |||
| domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map); | |||
| domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map); | |||
| // Fills node_cluster_map with the node clusters to fuse; has_get_next is the flag MarkForFusion computes | |||
| // (true when the graph contains an IteratorGetNext framework op). | |||
| // Types are fully qualified so this header does not rely on using-declarations at the include site. | |||
| domi::Status GetFusionCluster(const bool has_get_next, | |||
| std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map); | |||
| domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes); | |||
| @@ -316,6 +316,13 @@ REG_OP(Softmax) | |||
| .ATTR(beta, Float, 0) | |||
| .OP_END_FACTORY_REG(Softmax) | |||
| // Op definition for Shape: a single input x accepting the float, integer, bool and double tensor | |||
| // types listed below, a single output y of DT_INT32 or DT_INT64, and an Int attribute dtype whose | |||
| // default value is DT_INT32. | |||
| REG_OP(Shape) | |||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
| .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||
| .ATTR(dtype, Int, DT_INT32) | |||
| .OP_END_FACTORY_REG(Shape) | |||
| // for plugin | |||
| static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||
| return SUCCESS; | |||
| @@ -0,0 +1,174 @@ | |||
| node { | |||
| name: "IteratorV2" | |||
| op: "IteratorV2" | |||
| attr { | |||
| key: "op_def" | |||
| value { | |||
| s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001" | |||
| } | |||
| } | |||
| attr { | |||
| key: "output_types" | |||
| value { | |||
| list { | |||
| type: DT_INT64 | |||
| } | |||
| } | |||
| } | |||
| attr { | |||
| key: "output_tensor_desc" | |||
| value { | |||
| list { | |||
| func { | |||
| name: "0" | |||
| attr { | |||
| key: "serialize_datatype" | |||
| value: { | |||
| i: 9 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_format" | |||
| value: { | |||
| i: 1 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_shape" | |||
| value { | |||
| type: DT_INT32 | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| node { | |||
| name: "IteratorGetNext" | |||
| op: "IteratorGetNext" | |||
| input: "IteratorV2" | |||
| attr { | |||
| key: "output_types" | |||
| value { | |||
| list { | |||
| type: DT_INT64 | |||
| } | |||
| } | |||
| } | |||
| attr { | |||
| key: "op_def" | |||
| value { | |||
| s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001" | |||
| } | |||
| } | |||
| attr { | |||
| key: "input_tensor_desc" | |||
| value { | |||
| list { | |||
| func { | |||
| name: "0" | |||
| attr { | |||
| key: "serialize_datatype" | |||
| value: { | |||
| i: 9 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_format" | |||
| value: { | |||
| i: 1 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_shape" | |||
| value { | |||
| type: DT_INT32 | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| attr { | |||
| key: "output_tensor_desc" | |||
| value { | |||
| list { | |||
| func { | |||
| name: "0" | |||
| attr { | |||
| key: "serialize_datatype" | |||
| value: { | |||
| i: 9 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_format" | |||
| value: { | |||
| i: 1 | |||
| } | |||
| } | |||
| attr { | |||
| key: "serialize_shape" | |||
| value { | |||
| list { | |||
| i: -1 | |||
| i: -1 | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| node { | |||
| name: "getnext_shape_0" | |||
| op: "Shape" | |||
| input: "IteratorGetNext" | |||
| attr { | |||
| key: "op_def" | |||
| value { | |||
| s: "\n\005Shape\022\n\n\005input\"\001T\032\022\n\006output\"\010out_type\"\t\n\001T\022\004type\"\034\n\010out_type\022\004type\032\0020\003:\006\n\0042\002\003\t" | |||
| } | |||
| } | |||
| } | |||
| node { | |||
| name: "retval_GetNext_0_0" | |||
| op: "_Retval" | |||
| input: "IteratorGetNext" | |||
| attr { | |||
| key: "index" | |||
| value { | |||
| i: 0 | |||
| } | |||
| } | |||
| attr { | |||
| key: "op_def" | |||
| value { | |||
| s: "" | |||
| } | |||
| } | |||
| } | |||
| node { | |||
| name: "retval_GetNext_0_1" | |||
| op: "_Retval" | |||
| input: "getnext_shape_0" | |||
| attr { | |||
| key: "index" | |||
| value { | |||
| i: 1 | |||
| } | |||
| } | |||
| attr { | |||
| key: "op_def" | |||
| value { | |||
| s: "" | |||
| } | |||
| } | |||
| } | |||
| library { | |||
| } | |||
| versions { | |||
| producer: 134 | |||
| } | |||
| @@ -4225,4 +4225,21 @@ TEST_F(STestTensorflowParser, parser_UppdateInputMap_test) | |||
| delete graph; | |||
| } | |||
| // Parses the IteratorV2 -> IteratorGetNext -> getnext_shape_0 fixture graph and checks that the root | |||
| // graph ends up with three direct nodes after ParseProto. | |||
| TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) { | |||
| // Locate the fixture relative to the directory of this source file. | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of('/')); | |||
| const std::string root_proto = case_dir + "/origin_models/test_getnext_dynamic_fusion.pbtxt"; | |||
| domi::tensorflow::GraphDef graph_def; | |||
| // Abort early: nothing below is meaningful if the fixture could not be read. | |||
| ASSERT_TRUE(parser::ReadProtoFromText(root_proto.c_str(), &graph_def)); | |||
| TensorFlowModelParser tensorflow_parser; | |||
| ge::ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph"); | |||
| // Guard the later root_graph-> dereference against a failed allocation. | |||
| ASSERT_TRUE(root_graph != nullptr); | |||
| // Upcast to the protobuf base class; static_cast keeps the conversion type-checked, whereas | |||
| // reinterpret_cast would compile even for unrelated types. | |||
| const Status ret = | |||
| tensorflow_parser.ParseProto(static_cast<google::protobuf::Message *>(&graph_def), root_graph); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| // Unsigned literal avoids a signed/unsigned comparison against size(). | |||
| EXPECT_EQ(root_graph->GetDirectNode().size(), 3U); | |||
| } | |||
| } // namespace ge | |||