From 889b82354dcf74e88525db954c5a240dae0bd2bb Mon Sep 17 00:00:00 2001 From: lichun Date: Mon, 22 Mar 2021 21:09:54 +0800 Subject: [PATCH 01/13] Bugfix: support unknown control op subgraph --- ge/hybrid/model/hybrid_model.cc | 2 +- ge/hybrid/model/hybrid_model.h | 2 +- ge/hybrid/model/hybrid_model_builder.cc | 54 ++++++++++++++++++++----- ge/hybrid/model/hybrid_model_builder.h | 4 +- 4 files changed, 47 insertions(+), 15 deletions(-) diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index a0217d52..86acc260 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -333,7 +333,7 @@ TensorValue *HybridModel::GetConstant(const NodePtr &node) const { return nullptr; } - auto it = constant_tensors_.find(node); + auto it = constant_tensors_.find(node->GetName()); if (it == constant_tensors_.end()) { GELOGD("constant not found, node name = [%s]", node->GetName().c_str()); return nullptr; diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index fae53679..e8c005f6 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -138,7 +138,7 @@ class HybridModel { std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; - std::map> constant_tensors_; + std::map> constant_tensors_; std::map> task_defs_; std::map known_shape_sub_models_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index a3b1da20..def32766 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); - for (const auto &node : root_graph.GetDirectNode()) { + for (const auto &node : root_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), "[%s] Failed to merge subgraph.", subgraph->GetName().c_str()); } @@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return a_level < b_level; }); - for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", @@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, - ComputeGraph &parent_graph, +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, + ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph) { auto parent_node = sub_graph.GetParentNode(); GE_CHECK_NOTNULL(parent_node); @@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, } } - parent_graph.AddNode(sub_node); + if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { + for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); + GE_CHECK_NOTNULL(sub_sub_graph); + sub_sub_graph->SetParentGraph(root_graph); + } + } + + parent_graph->AddNode(sub_node); GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), - parent_graph.GetName().c_str()); + parent_graph->GetName().c_str()); + sub_node->SetOwnerComputeGraph(root_graph); } GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); - root_graph.RemoveSubgraph(sub_graph.GetName()); + root_graph->RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } @@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); root_graph = std::move(merged_graph); GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), @@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); + + std::map name_to_node; + for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) { + for (const auto &node : subgraph->GetAllNodes()) { + name_to_node.insert(make_pair(node->GetName(), node)); + } + } + for (auto &node : root_graph->GetDirectNode()) { if (node->GetType() != CONSTANT) { continue; @@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() { auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); GE_CHECK_NOTNULL(tensor_buffer); + + if (tensor_size > 0) { + auto tensor = std::shared_ptr( + new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size)); + OpDescPtr op_desc = nullptr; + auto iter = name_to_node.find(node->GetName()); + if (iter != name_to_node.end()) { + op_desc = iter->second->GetOpDesc(); + if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) { + GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed."); + return FAILED; + } + } + } + std::unique_ptr constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); GE_CHECK_NOTNULL(constant_tensor); constant_tensor->SetName("Constant_" + op_desc->GetName()); - hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor)); + hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor)); GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); } } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 313d5ca6..a6b2b25a 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -47,8 +47,8 @@ class HybridModelBuilder { static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); + static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); From 5d0af86a07533a6f5edfad47949bbd0266bac59e Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 10:42:14 +0800 Subject: [PATCH 02/13] Bugfix: support unknown control op subgraph --- .../format_transfers/format_transfer_fractal_nz.cc | 2 +- .../format_transfers/format_transfer_fractal_zz.cc | 2 +- .../format_transfer_nhwc_nc1hwc0.cc | 3 ++- .../format_transfers/format_transfer_transpose.cc | 14 +++++++------- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index fccdb57b..01c7de95 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_NZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index c36bffb5..36bea872 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_ZZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index b09fd168..6817713a 100755 --- a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { - GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, + "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); return ACL_ERROR_GE_MEMORY_ALLOCATION; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc index 694777f3..49bb5cd6 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -50,21 +50,21 @@ std::map>> perm_args{ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &perm_arg) { if (src_shape.empty()) { std::string error = "Failed to transpose, empty src shape"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); - GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape"); return false; } for (auto dim : src_shape) { if (dim < 0) { std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } } if (perm_arg.size() != src_shape.size()) { std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } @@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector(perm) >= perm_arg.size() || ++exists[perm] > 1) { std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + ", perm arg " + FmtToStr(JoinToString(perm_arg)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); return false; } } @@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &src_shape, DataType src_data_type, const std::vector &perm_arg) { if (src == nullptr) { - GELOGE(PARAM_INVALID, "Failed to transpose, the src is null"); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null"); return false; } if (GetSizeByDataType(src_data_type) < 0) { - GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support", + GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support", TypeUtils::DataTypeToSerialString(src_data_type).c_str()); return false; } From bb81990dd3dc5d4db6e6e168a1c9e4d22d60d9b0 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 14:38:30 +0800 Subject: [PATCH 03/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 39 +++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 3b5d19e6..c13a9003 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -245,7 +245,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { ASSERT_EQ(ret,PARAM_INVALID); } - TEST_F(UtestGeHybrid, hybrid_model_executor) { +TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared("abc"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); @@ -256,3 +256,40 @@ TEST_F(UtestGeHybrid, init_weight_success) { HybridModelExecutor executor(model_ptr, device_id, stream); executor.Init(); } + +TEST_F(UtestGeHybrid, unfold_subgraphs_success) { + ComputeGraphPtr merged_graph = nullptr; + + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); + NodePtr sub_graph_while_node = sub_graph->AddNode(op_desc); + + ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); + OpDescPtr sub_sub_graph_while_cond_const_op_desc = CreateOpDesc("cond_const", CONST); + NodePtr sub_sub_graph_while_cond_const_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_const_op_desc); + + ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); + OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONST); + NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc); + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); + NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node); + + sub_sub_graph1->AddSubGraph(sub_sub_graph1); + sub_sub_graph2->AddSubGraph(sub_sub_graph2); + + root_graph->AddSubGraph(sub_graph); + sub_graph->SetParentNode(partitioned_call_node); + + GeRootModelPtr root_model = MakeShared(root_graph); + HybridModel hybrid_model(root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); +} From 1e9a360a8cfb673d0b9dac9f69850da5cdf2f8e1 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 14:41:07 +0800 Subject: [PATCH 04/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index c13a9003..242ca91d 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -266,11 +266,11 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); - NodePtr sub_graph_while_node = sub_graph->AddNode(op_desc); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); - OpDescPtr sub_sub_graph_while_cond_const_op_desc = CreateOpDesc("cond_const", CONST); - NodePtr sub_sub_graph_while_cond_const_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_const_op_desc); + OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); + NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONST); From 8a83fbf1a672ed63ccfa2c0a0388d97dafa695a8 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 14:45:00 +0800 Subject: [PATCH 05/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 242ca91d..1b2223b3 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -273,7 +273,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); - OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONST); + OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc); OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); From 60cb35477423c9e2ef842029c4443b919aca1ae0 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 15:01:43 +0800 Subject: [PATCH 06/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 1b2223b3..eac2da33 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -260,13 +260,10 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr merged_graph = nullptr; - ComputeGraphPtr root_graph = std::make_shared("root_graph"); - auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); - auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); - ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_graph->SetGraphUnknownFlag(True); ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); @@ -282,6 +279,11 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node); + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); + sub_sub_graph1->AddSubGraph(sub_sub_graph1); sub_sub_graph2->AddSubGraph(sub_sub_graph2); From bb64fd45a2b25d82706115fc5314e3b25a754860 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 15:02:24 +0800 Subject: [PATCH 07/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index eac2da33..8d8f0a78 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -263,7 +263,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(True); + sub_graph->SetGraphUnknownFlag(true); ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); From 6426aab7d328b6c4937ecf4b5b478dc219fa7b3b Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 15:33:00 +0800 Subject: [PATCH 08/13] Bugfix: support unknown control op subgraph --- ge/CMakeLists.txt | 104 +++++++++++------------ ge/common/CMakeLists.txt | 4 +- ge/executor/CMakeLists.txt | 4 +- ge/ge_local_engine/CMakeLists.txt | 6 +- ge/host_cpu_engine/CMakeLists.txt | 4 +- ge/offline/CMakeLists.txt | 2 +- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 1 + 7 files changed, 63 insertions(+), 62 deletions(-) diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index bd9edd86..9f2b21d7 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -28,59 +28,59 @@ set(PROTO_HEADER_LIST "${METADEF_DIR}/proto/op_mapping_info.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) -protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) -protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) +#protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) +#protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) -if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) -############ libge_proto_common.a ############ -add_library(ge_proto_common STATIC - ${PROTO_HEADER_HDRS} - ${PROTO_SRCS} -) - -target_compile_definitions(ge_proto_common PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - google=ascend_private -) - -target_compile_options(ge_proto_common PRIVATE - -O2 - -fno-common -) - -target_link_libraries(ge_proto_common PRIVATE - $ - ascend_protobuf -) - -############ libge_proto_client.a ############ -add_library(ge_proto_client STATIC - ${PROTO_CLIENT_HEADER_HDRS} - ${PROTO_CLIENT_SRCS} -) - -target_compile_definitions(ge_proto_client PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - google=ascend_private -) - -target_include_directories(ge_proto_client PRIVATE - ${CMAKE_BINARY_DIR}/proto/ge_client - ${CMAKE_BINARY_DIR}/proto/ge_client/proto -) - -target_compile_options(ge_proto_client PRIVATE - -O2 - -fno-common -) - -target_link_libraries(ge_proto_client PRIVATE - $ - ascend_protobuf -) -endif () +#if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) +############# libge_proto_common.a ############ +#add_library(ge_proto_common STATIC +# ${PROTO_HEADER_HDRS} +# ${PROTO_SRCS} +#) +# +#target_compile_definitions(ge_proto_common PRIVATE +# PROTOBUF_INLINE_NOT_IN_HEADERS=0 +# google=ascend_private +#) +# +#target_compile_options(ge_proto_common PRIVATE +# -O2 +# -fno-common +#) +# +#target_link_libraries(ge_proto_common PRIVATE +# $ +# ascend_protobuf +#) +# +############# libge_proto_client.a ############ +#add_library(ge_proto_client STATIC +# ${PROTO_CLIENT_HEADER_HDRS} +# ${PROTO_CLIENT_SRCS} +#) +# +#target_compile_definitions(ge_proto_client PRIVATE +# PROTOBUF_INLINE_NOT_IN_HEADERS=0 +# google=ascend_private +#) +# +#target_include_directories(ge_proto_client PRIVATE +# ${CMAKE_BINARY_DIR}/proto/ge_client +# ${CMAKE_BINARY_DIR}/proto/ge_client/proto +#) +# +#target_compile_options(ge_proto_client PRIVATE +# -O2 +# -fno-common +#) +# +#target_link_libraries(ge_proto_client PRIVATE +# $ +# ascend_protobuf +#) +#endif () ################################################################## set(TRAIN_SRC_LIST diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 75cb8ad1..585a42cb 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -15,8 +15,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/tensorflow/versions.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "context/ctx.cc" diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index 363900d0..31f8dc4b 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -7,8 +7,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/dump_task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "ge_executor.cc" diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt index ab767ccb..affd8f5a 100755 --- a/ge/ge_local_engine/CMakeLists.txt +++ b/ge/ge_local_engine/CMakeLists.txt @@ -19,9 +19,9 @@ set(OPS_KERNEL_SRC_LIST "ops_kernel_store/op/no_op.cc" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) -protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) ############ libge_local_engine.so ############ add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt index 8d84ee28..950a1e5c 100644 --- a/ge/host_cpu_engine/CMakeLists.txt +++ b/ge/host_cpu_engine/CMakeLists.txt @@ -2,8 +2,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) set(SRC_LIST "engine/host_cpu_engine.cc" diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index 87589859..1e8a6cc5 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -5,7 +5,7 @@ set(PROTO_LIST "${METADEF_DIR}/proto/task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) set(SRC_LIST "main.cc" diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 8d8f0a78..e244eb2f 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -283,6 +283,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); + partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); sub_sub_graph1->AddSubGraph(sub_sub_graph1); sub_sub_graph2->AddSubGraph(sub_sub_graph2); From a77ecde682a000d9ba8a5cfa25b161e2ca118bea Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 16:09:38 +0800 Subject: [PATCH 09/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index e244eb2f..7e6da9e0 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -260,11 +260,6 @@ TEST_F(UtestGeHybrid, hybrid_model_executor) { TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr merged_graph = nullptr; - ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); - NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(true); - ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); @@ -279,6 +274,13 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node); + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_graph->SetGraphUnknownFlag(true); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + ComputeGraphPtr root_graph = std::make_shared("root_graph"); auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); From 54edf24170fcdbe765af9d5d75af73a3c93b8c2f Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 17:14:21 +0800 Subject: [PATCH 10/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 7e6da9e0..26aec7fe 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -265,19 +265,21 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); - OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); - NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc); + /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); + NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); - OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); + /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); - sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); sub_graph->SetGraphUnknownFlag(true); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); From 6ae1b36143ba502136fdd2c6903d666539b86c25 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 17:28:12 +0800 Subject: [PATCH 11/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 26aec7fe..88141d97 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -289,8 +289,10 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - sub_sub_graph1->AddSubGraph(sub_sub_graph1); - sub_sub_graph2->AddSubGraph(sub_sub_graph2); + root_graph->AddSubGraph(sub_sub_graph1); + root_graph->AddSubGraph(sub_sub_graph2); + sub_sub_graph1->SetParentGraph(root_graph); + sub_sub_graph2->SetParentGraph(root_graph); root_graph->AddSubGraph(sub_graph); sub_graph->SetParentNode(partitioned_call_node); From b7f4f5c8cb885fd010409dd3c6be2a425b09f207 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 19:35:38 +0800 Subject: [PATCH 12/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 88141d97..e8e8e196 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -269,6 +269,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + sub_sub_graph2->SetGraphUnknownFlag(true); /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); @@ -277,7 +278,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(true); + sub_graph->SetGraphUnknownFlag(false); sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); @@ -289,6 +290,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); + root_graph->AddSubGraph(sub_sub_graph1); root_graph->AddSubGraph(sub_sub_graph2); sub_sub_graph1->SetParentGraph(root_graph); From 51b3a1acedd3d0294a4588b107d2e7ca65d05b7b Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 22:14:32 +0800 Subject: [PATCH 13/13] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index e8e8e196..de84342c 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -264,7 +264,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); - ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); + ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); @@ -278,7 +278,7 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_graph->SetGraphUnknownFlag(false); + sub_graph->SetGraphUnknownFlag(true); sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); @@ -290,14 +290,16 @@ TEST_F(UtestGeHybrid, unfold_subgraphs_success) { partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - root_graph->AddSubGraph(sub_sub_graph1); root_graph->AddSubGraph(sub_sub_graph2); sub_sub_graph1->SetParentGraph(root_graph); sub_sub_graph2->SetParentGraph(root_graph); + sub_sub_graph1->SetParentNode(sub_graph_while_node); + sub_sub_graph2->SetParentNode(sub_graph_while_node); root_graph->AddSubGraph(sub_graph); sub_graph->SetParentNode(partitioned_call_node); + sub_graph->SetParentGraph(root_graph); GeRootModelPtr root_model = MakeShared(root_graph); HybridModel hybrid_model(root_model);