diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index a0217d52..86acc260 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -333,7 +333,7 @@ TensorValue *HybridModel::GetConstant(const NodePtr &node) const { return nullptr; } - auto it = constant_tensors_.find(node); + auto it = constant_tensors_.find(node->GetName()); if (it == constant_tensors_.end()) { GELOGD("constant not found, node name = [%s]", node->GetName().c_str()); return nullptr; diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index fae53679..e8c005f6 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -138,7 +138,7 @@ class HybridModel { std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; - std::map> constant_tensors_; + std::map> constant_tensors_; std::map> task_defs_; std::map known_shape_sub_models_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index a3b1da20..def32766 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); - for (const auto &node : root_graph.GetDirectNode()) { + for (const auto &node : root_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), "[%s] Failed to merge subgraph.", subgraph->GetName().c_str()); } @@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return a_level < b_level; }); - for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", @@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, - ComputeGraph &parent_graph, +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, + ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph) { auto parent_node = sub_graph.GetParentNode(); GE_CHECK_NOTNULL(parent_node); @@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, } } - parent_graph.AddNode(sub_node); + if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { + for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); + GE_CHECK_NOTNULL(sub_sub_graph); + sub_sub_graph->SetParentGraph(root_graph); + } + } + + parent_graph->AddNode(sub_node); GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), - parent_graph.GetName().c_str()); + parent_graph->GetName().c_str()); + sub_node->SetOwnerComputeGraph(root_graph); } GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); - root_graph.RemoveSubgraph(sub_graph.GetName()); + root_graph->RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } @@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); root_graph = std::move(merged_graph); GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), @@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); + + std::map name_to_node; + for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) { + for (const auto &node : subgraph->GetAllNodes()) { + name_to_node.insert(make_pair(node->GetName(), node)); + } + } + for (auto &node : root_graph->GetDirectNode()) { if (node->GetType() != CONSTANT) { continue; @@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() { auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); GE_CHECK_NOTNULL(tensor_buffer); + + if (tensor_size > 0) { + auto tensor = std::shared_ptr( + new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size)); + OpDescPtr op_desc = nullptr; + auto iter = name_to_node.find(node->GetName()); + if (iter != name_to_node.end()) { + op_desc = iter->second->GetOpDesc(); + if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) { + GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed."); + return FAILED; + } + } + } + std::unique_ptr constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); GE_CHECK_NOTNULL(constant_tensor); constant_tensor->SetName("Constant_" + op_desc->GetName()); - hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor)); + hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor)); GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); } } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 313d5ca6..a6b2b25a 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -47,8 +47,8 @@ class HybridModelBuilder { static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); + static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph);