| @@ -35,7 +35,8 @@ REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul"); | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
// File-local string constants.
// Suffix string "_RetVal"; its use is not visible in this part of the file.
| const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; | const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; | ||||
// Op type string compared against a destination node's op type when collecting Shape nodes.
// NOTE(review): kShapeNodeName and kShapeNodeType hold the same value; this looks like the old and new
// side of a diff shown together - confirm that only one of them is meant to remain.
| const char *const kShapeNodeName = "Shape"; | |||||
// Op type a destination node must have to be added to a fusion cluster.
| const char *const kShapeNodeType = "Shape"; | |||||
// Text that must occur in a destination node's name for it to be added to a fusion cluster.
// The check uses find(), so the text may match anywhere in the name, not only at the start.
| const char *const kShapeNodeNamePrefix = "getnext_shape_"; | |||||
| } // namespace | } // namespace | ||||
| Status ParserGraphOptimizer::FusionFmkop() { | Status ParserGraphOptimizer::FusionFmkop() { | ||||
| @@ -62,19 +63,25 @@ Status ParserGraphOptimizer::FusionFmkop() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
// Scans the direct nodes of graph_ for a framework op whose original type is "IteratorGetNext",
// then hands the result of that scan, together with the output map, to GetFusionCluster and
// returns its status.
//
// Nodes whose op type is not ge::parser::FRAMEWORK_OP_TYPE are skipped. The scan stops at the
// first "IteratorGetNext" found.
//
// NOTE(review): several statements appear twice with different spellings (node_cluser_Map /
// node_cluster_map, hasGetNext / has_get_next); this looks like the old and new side of a diff
// shown together - confirm which side is current.
| Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluser_Map) { | |||||
| Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) { | |||||
| GE_CHECK_NOTNULL(graph_); | GE_CHECK_NOTNULL(graph_); | ||||
| bool hasGetNext = false; | |||||
| bool has_get_next = false; | |||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
// NOTE(review): node->GetOpDesc() is dereferenced without a null check here - confirm it cannot be null.
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); | ||||
| string type = ""; | string type = ""; | ||||
// Look up the original type of the framework op; a failing status is returned to the caller.
| GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | ||||
| if (type == "IteratorGetNext") { | if (type == "IteratorGetNext") { | ||||
| hasGetNext = true; | |||||
| has_get_next = true; | |||||
// One match is enough; the flag is all that is needed from this loop.
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return GetFusionCluster(has_get_next, node_cluster_map); | |||||
| } | |||||
| Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, | |||||
| unordered_map<string, vector<NodePtr>> &node_cluster_map) { | |||||
| GE_CHECK_NOTNULL(graph_); | |||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | ||||
| @@ -97,7 +104,8 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||||
| NodePtr dst_node = in_anchor->GetOwnerNode(); | NodePtr dst_node = in_anchor->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | ||||
| if (dst_node->GetOpDesc()->GetType() == kShapeNodeName) { | |||||
| if ((dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) && | |||||
| (dst_node->GetOpDesc()->GetType() == kShapeNodeType)) { | |||||
| temp_node_cluser.emplace_back(dst_node); | temp_node_cluser.emplace_back(dst_node); | ||||
| } | } | ||||
| } | } | ||||
| @@ -105,14 +113,14 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||||
| if (temp_node_cluser.size() > 1) { | if (temp_node_cluser.size() > 1) { | ||||
| vector<NodePtr> node_cluser; | vector<NodePtr> node_cluser; | ||||
| node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); | node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); | ||||
| node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser; | |||||
| node_cluster_map[temp_node_cluser[0]->GetName()] = node_cluser; | |||||
| } | } | ||||
| temp_node_cluser.clear(); | temp_node_cluser.clear(); | ||||
| GELOGI("MarkForFusion, IteratorGetNext graph mark success."); | GELOGI("MarkForFusion, IteratorGetNext graph mark success."); | ||||
| } | } | ||||
| if (!hasGetNext && (type == "Iterator" || type == "IteratorV2")) { | |||||
| GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluser_Map), "find framework node to be fused fail."); | |||||
| if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) { | |||||
| GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); | |||||
| GELOGI("MarkForFusion, Iterator init graph mark success."); | GELOGI("MarkForFusion, Iterator init graph mark success."); | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,7 +41,9 @@ class ParserGraphOptimizer { | |||||
| domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const; | domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const; | ||||
// Determines whether the graph contains an "IteratorGetNext" framework op and fills the given map
// with the node clusters to be fused.
// NOTE(review): the declaration appears twice, differing only in the parameter name
// (node_cluser_Map vs node_cluster_map); this looks like the old and new side of a diff shown
// together - confirm that only one declaration is intended.
| domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map); | |||||
| domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map); | |||||
// Collects the node clusters to be fused and stores them in node_cluster_map.
// has_get_next: true when the graph contains a framework op whose original type is "IteratorGetNext".
// node_cluster_map: output map that receives the clusters found.
// Names are fully qualified like the neighbouring declarations, so the header does not depend on
// using-declarations being in scope.
| domi::Status GetFusionCluster(const bool has_get_next, | |||||
| std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map); | |||||
| domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes); | domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes); | ||||
| @@ -316,6 +316,13 @@ REG_OP(Softmax) | |||||
| .ATTR(beta, Float, 0) | .ATTR(beta, Float, 0) | ||||
| .OP_END_FACTORY_REG(Softmax) | .OP_END_FACTORY_REG(Softmax) | ||||
// Registers an operator named Shape with one input `x`, one output `y` and an Int attribute `dtype`.
// `x` accepts the tensor data types listed below; `y` is limited to DT_INT32 or DT_INT64.
// NOTE(review): the third ATTR argument (DT_INT32) is presumably the attribute's default value - confirm
// against the REG_OP macro definition.
| REG_OP(Shape) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||||
| .ATTR(dtype, Int, DT_INT32) | |||||
| .OP_END_FACTORY_REG(Shape) | |||||
| // for plugin | // for plugin | ||||
| static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | static Status ParseParamsStub(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -0,0 +1,174 @@ | |||||
| node { | |||||
| name: "IteratorV2" | |||||
| op: "IteratorV2" | |||||
| attr { | |||||
| key: "op_def" | |||||
| value { | |||||
| s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001" | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "output_types" | |||||
| value { | |||||
| list { | |||||
| type: DT_INT64 | |||||
| } | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "output_tensor_desc" | |||||
| value { | |||||
| list { | |||||
| func { | |||||
| name: "0" | |||||
| attr { | |||||
| key: "serialize_datatype" | |||||
| value: { | |||||
| i: 9 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_format" | |||||
| value: { | |||||
| i: 1 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_shape" | |||||
| value { | |||||
| type: DT_INT32 | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| node { | |||||
| name: "IteratorGetNext" | |||||
| op: "IteratorGetNext" | |||||
| input: "IteratorV2" | |||||
| attr { | |||||
| key: "output_types" | |||||
| value { | |||||
| list { | |||||
| type: DT_INT64 | |||||
| } | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "op_def" | |||||
| value { | |||||
| s: "\n\007GetNext\032\032\n\ncomponents2\014output_types\"\036\n\014output_types\022\nlist(type)(\0010\001\" \n\routput_shapes\022\013list(shape)(\0010\001\"\026\n\014channel_name\022\006string\210\001\001" | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "input_tensor_desc" | |||||
| value { | |||||
| list { | |||||
| func { | |||||
| name: "0" | |||||
| attr { | |||||
| key: "serialize_datatype" | |||||
| value: { | |||||
| i: 9 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_format" | |||||
| value: { | |||||
| i: 1 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_shape" | |||||
| value { | |||||
| type: DT_INT32 | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "output_tensor_desc" | |||||
| value { | |||||
| list { | |||||
| func { | |||||
| name: "0" | |||||
| attr { | |||||
| key: "serialize_datatype" | |||||
| value: { | |||||
| i: 9 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_format" | |||||
| value: { | |||||
| i: 1 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "serialize_shape" | |||||
| value { | |||||
| list { | |||||
| i: -1 | |||||
| i: -1 | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| node { | |||||
| name: "getnext_shape_0" | |||||
| op: "Shape" | |||||
| input: "IteratorGetNext" | |||||
| attr { | |||||
| key: "op_def" | |||||
| value { | |||||
| s: "\n\005Shape\022\n\n\005input\"\001T\032\022\n\006output\"\010out_type\"\t\n\001T\022\004type\"\034\n\010out_type\022\004type\032\0020\003:\006\n\0042\002\003\t" | |||||
| } | |||||
| } | |||||
| } | |||||
| node { | |||||
| name: "retval_GetNext_0_0" | |||||
| op: "_Retval" | |||||
| input: "IteratorGetNext" | |||||
| attr { | |||||
| key: "index" | |||||
| value { | |||||
| i: 0 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "op_def" | |||||
| value { | |||||
| s: "" | |||||
| } | |||||
| } | |||||
| } | |||||
| node { | |||||
| name: "retval_GetNext_0_1" | |||||
| op: "_Retval" | |||||
| input: "getnext_shape_0" | |||||
| attr { | |||||
| key: "index" | |||||
| value { | |||||
| i: 1 | |||||
| } | |||||
| } | |||||
| attr { | |||||
| key: "op_def" | |||||
| value { | |||||
| s: "" | |||||
| } | |||||
| } | |||||
| } | |||||
| library { | |||||
| } | |||||
| versions { | |||||
| producer: 134 | |||||
| } | |||||
| @@ -4225,4 +4225,21 @@ TEST_F(STestTensorflowParser, parser_UppdateInputMap_test) | |||||
| delete graph; | delete graph; | ||||
| } | } | ||||
// Loads origin_models/test_getnext_dynamic_fusion.pbtxt from the directory of this source file,
// parses it with TensorFlowModelParser::ParseProto and expects the resulting graph to contain
// exactly three direct nodes.
| TEST_F(STestTensorflowParser, tensorflow_optimizer_fmk_fusion_op) { | |||||
// Derive the directory of this test file; the model fixture lives below it.
| std::string caseDir = __FILE__; | |||||
| const std::size_t idx = caseDir.find_last_of('/'); | |||||
| caseDir = caseDir.substr(0, idx); | |||||
| const std::string root_proto = caseDir + "/origin_models/test_getnext_dynamic_fusion.pbtxt"; | |||||
| domi::tensorflow::GraphDef graphDef; | |||||
| const bool protoRet = parser::ReadProtoFromText(root_proto.c_str(), &graphDef); | |||||
// Nothing below is meaningful if the fixture could not be read, so stop here on failure.
| ASSERT_TRUE(protoRet); | |||||
| TensorFlowModelParser tensorflow_parser; | |||||
| ge::ComputeGraphPtr root_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph"); | |||||
// root_graph is dereferenced below; abort the test instead of crashing if allocation failed.
| ASSERT_NE(root_graph, nullptr); | |||||
// GraphDef is passed as its protobuf base class; an upcast needs static_cast, not reinterpret_cast.
| Status ret = tensorflow_parser.ParseProto(static_cast<google::protobuf::Message *>(&graphDef), root_graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
// Unsigned literal avoids a signed/unsigned comparison against size().
| EXPECT_EQ(root_graph->GetDirectNode().size(), 3U); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||