|
|
@@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { |
|
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { |
|
|
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); |
|
|
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); |
|
|
for (const auto &node : root_graph.GetDirectNode()) { |
|
|
|
|
|
|
|
|
for (const auto &node : root_graph->GetDirectNode()) { |
|
|
GE_CHECK_NOTNULL(node); |
|
|
GE_CHECK_NOTNULL(node); |
|
|
auto op_desc = node->GetOpDesc(); |
|
|
auto op_desc = node->GetOpDesc(); |
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
@@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), |
|
|
|
|
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), |
|
|
"[%s] Failed to merge subgraph.", |
|
|
"[%s] Failed to merge subgraph.", |
|
|
subgraph->GetName().c_str()); |
|
|
subgraph->GetName().c_str()); |
|
|
} |
|
|
} |
|
|
@@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap |
|
|
return a_level < b_level; |
|
|
return a_level < b_level; |
|
|
}); |
|
|
}); |
|
|
|
|
|
|
|
|
for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { |
|
|
|
|
|
|
|
|
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { |
|
|
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); |
|
|
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); |
|
|
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), |
|
|
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), |
|
|
"Failed to add subgraph [%s]", |
|
|
"Failed to add subgraph [%s]", |
|
|
@@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, |
|
|
|
|
|
ComputeGraph &parent_graph, |
|
|
|
|
|
|
|
|
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, |
|
|
|
|
|
ComputeGraphPtr &parent_graph, |
|
|
ComputeGraph &sub_graph) { |
|
|
ComputeGraph &sub_graph) { |
|
|
auto parent_node = sub_graph.GetParentNode(); |
|
|
auto parent_node = sub_graph.GetParentNode(); |
|
|
GE_CHECK_NOTNULL(parent_node); |
|
|
GE_CHECK_NOTNULL(parent_node); |
|
|
@@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
parent_graph.AddNode(sub_node); |
|
|
|
|
|
|
|
|
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { |
|
|
|
|
|
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { |
|
|
|
|
|
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); |
|
|
|
|
|
GE_CHECK_NOTNULL(sub_sub_graph); |
|
|
|
|
|
sub_sub_graph->SetParentGraph(root_graph); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
parent_graph->AddNode(sub_node); |
|
|
GELOGD("[%s::%s] added to parent graph: [%s].", |
|
|
GELOGD("[%s::%s] added to parent graph: [%s].", |
|
|
sub_graph.GetName().c_str(), |
|
|
sub_graph.GetName().c_str(), |
|
|
sub_node->GetName().c_str(), |
|
|
sub_node->GetName().c_str(), |
|
|
parent_graph.GetName().c_str()); |
|
|
|
|
|
|
|
|
parent_graph->GetName().c_str()); |
|
|
|
|
|
sub_node->SetOwnerComputeGraph(root_graph); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); |
|
|
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); |
|
|
root_graph.RemoveSubgraph(sub_graph.GetName()); |
|
|
|
|
|
|
|
|
root_graph->RemoveSubgraph(sub_graph.GetName()); |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() { |
|
|
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
root_graph->GetDirectNodesSize(), |
|
|
root_graph->GetDirectNodesSize(), |
|
|
root_graph->GetAllNodesSize()); |
|
|
root_graph->GetAllNodesSize()); |
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); |
|
|
|
|
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); |
|
|
root_graph = std::move(merged_graph); |
|
|
root_graph = std::move(merged_graph); |
|
|
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", |
|
|
root_graph->GetDirectNodesSize(), |
|
|
root_graph->GetDirectNodesSize(), |
|
|
@@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() { |
|
|
sub_weight_buffer->GetSize()); |
|
|
sub_weight_buffer->GetSize()); |
|
|
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); |
|
|
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); |
|
|
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); |
|
|
|
|
|
|
|
|
|
|
|
std::map<std::string, NodePtr> name_to_node; |
|
|
|
|
|
for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) { |
|
|
|
|
|
for (const auto &node : subgraph->GetAllNodes()) { |
|
|
|
|
|
name_to_node.insert(make_pair(node->GetName(), node)); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
for (auto &node : root_graph->GetDirectNode()) { |
|
|
for (auto &node : root_graph->GetDirectNode()) { |
|
|
if (node->GetType() != CONSTANT) { |
|
|
if (node->GetType() != CONSTANT) { |
|
|
continue; |
|
|
continue; |
|
|
@@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() { |
|
|
|
|
|
|
|
|
auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); |
|
|
auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); |
|
|
GE_CHECK_NOTNULL(tensor_buffer); |
|
|
GE_CHECK_NOTNULL(tensor_buffer); |
|
|
|
|
|
|
|
|
|
|
|
if (tensor_size > 0) { |
|
|
|
|
|
auto tensor = std::shared_ptr<GeTensor>( |
|
|
|
|
|
new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size)); |
|
|
|
|
|
OpDescPtr op_desc = nullptr; |
|
|
|
|
|
auto iter = name_to_node.find(node->GetName()); |
|
|
|
|
|
if (iter != name_to_node.end()) { |
|
|
|
|
|
op_desc = iter->second->GetOpDesc(); |
|
|
|
|
|
if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) { |
|
|
|
|
|
GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed."); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
std::unique_ptr<TensorValue> constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); |
|
|
std::unique_ptr<TensorValue> constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); |
|
|
GE_CHECK_NOTNULL(constant_tensor); |
|
|
GE_CHECK_NOTNULL(constant_tensor); |
|
|
constant_tensor->SetName("Constant_" + op_desc->GetName()); |
|
|
constant_tensor->SetName("Constant_" + op_desc->GetName()); |
|
|
hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor)); |
|
|
|
|
|
|
|
|
hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor)); |
|
|
GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); |
|
|
GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|