Browse Source

Bugfix: support unknown control op subgraph

pull/1322/head
lichun 4 years ago
parent
commit
889b82354d
4 changed files with 47 additions and 15 deletions
  1. +1
    -1
      ge/hybrid/model/hybrid_model.cc
  2. +1
    -1
      ge/hybrid/model/hybrid_model.h
  3. +43
    -11
      ge/hybrid/model/hybrid_model_builder.cc
  4. +2
    -2
      ge/hybrid/model/hybrid_model_builder.h

+ 1
- 1
ge/hybrid/model/hybrid_model.cc View File

@@ -333,7 +333,7 @@ TensorValue *HybridModel::GetConstant(const NodePtr &node) const {
return nullptr; return nullptr;
} }


auto it = constant_tensors_.find(node);
auto it = constant_tensors_.find(node->GetName());
if (it == constant_tensors_.end()) { if (it == constant_tensors_.end()) {
GELOGD("constant not found, node name = [%s]", node->GetName().c_str()); GELOGD("constant not found, node name = [%s]", node->GetName().c_str());
return nullptr; return nullptr;


+ 1
- 1
ge/hybrid/model/hybrid_model.h View File

@@ -138,7 +138,7 @@ class HybridModel {
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;
std::map<NodePtr, std::unique_ptr<TensorValue>> constant_tensors_;
std::map<std::string, std::unique_ptr<TensorValue>> constant_tensors_;
std::map<NodePtr, std::vector<domi::TaskDef>> task_defs_; std::map<NodePtr, std::vector<domi::TaskDef>> task_defs_;
std::map<NodePtr, GeModelPtr> known_shape_sub_models_; std::map<NodePtr, GeModelPtr> known_shape_sub_models_;




+ 43
- 11
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) {
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) {
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); merged_graph = MakeShared<ComputeGraph>("MergedGraph");
for (const auto &node : root_graph.GetDirectNode()) {
for (const auto &node : root_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
@@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
} }
} }
} }
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph),
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph),
"[%s] Failed to merge subgraph.", "[%s] Failed to merge subgraph.",
subgraph->GetName().c_str()); subgraph->GetName().c_str());
} }
@@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
return a_level < b_level; return a_level < b_level;
}); });


for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) {
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) {
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph),
"Failed to add subgraph [%s]", "Failed to add subgraph [%s]",
@@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
ComputeGraph &parent_graph,
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
ComputeGraphPtr &parent_graph,
ComputeGraph &sub_graph) { ComputeGraph &sub_graph) {
auto parent_node = sub_graph.GetParentNode(); auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node); GE_CHECK_NOTNULL(parent_node);
@@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
} }
} }


parent_graph.AddNode(sub_node);
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
GE_CHECK_NOTNULL(sub_sub_graph);
sub_sub_graph->SetParentGraph(root_graph);
}
}

parent_graph->AddNode(sub_node);
GELOGD("[%s::%s] added to parent graph: [%s].", GELOGD("[%s::%s] added to parent graph: [%s].",
sub_graph.GetName().c_str(), sub_graph.GetName().c_str(),
sub_node->GetName().c_str(), sub_node->GetName().c_str(),
parent_graph.GetName().c_str());
parent_graph->GetName().c_str());
sub_node->SetOwnerComputeGraph(root_graph);
} }


GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str());
root_graph.RemoveSubgraph(sub_graph.GetName());
root_graph->RemoveSubgraph(sub_graph.GetName());
return SUCCESS; return SUCCESS;
} }


@@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(), root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize()); root_graph->GetAllNodesSize());
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs.");
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs.");
root_graph = std::move(merged_graph); root_graph = std::move(merged_graph);
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(), root_graph->GetDirectNodesSize(),
@@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() {
sub_weight_buffer->GetSize()); sub_weight_buffer->GetSize());
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer));

std::map<std::string, NodePtr> name_to_node;
for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) {
for (const auto &node : subgraph->GetAllNodes()) {
name_to_node.insert(make_pair(node->GetName(), node));
}
}

for (auto &node : root_graph->GetDirectNode()) { for (auto &node : root_graph->GetDirectNode()) {
if (node->GetType() != CONSTANT) { if (node->GetType() != CONSTANT) {
continue; continue;
@@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() {


auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size);
GE_CHECK_NOTNULL(tensor_buffer); GE_CHECK_NOTNULL(tensor_buffer);

if (tensor_size > 0) {
auto tensor = std::shared_ptr<GeTensor>(
new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size));
OpDescPtr op_desc = nullptr;
auto iter = name_to_node.find(node->GetName());
if (iter != name_to_node.end()) {
op_desc = iter->second->GetOpDesc();
if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) {
GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed.");
return FAILED;
}
}
}

std::unique_ptr<TensorValue> constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); std::unique_ptr<TensorValue> constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer)));
GE_CHECK_NOTNULL(constant_tensor); GE_CHECK_NOTNULL(constant_tensor);
constant_tensor->SetName("Constant_" + op_desc->GetName()); constant_tensor->SetName("Constant_" + op_desc->GetName());
hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor));
hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor));
GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size);
} }
} }


+ 2
- 2
ge/hybrid/model/hybrid_model_builder.h View File

@@ -47,8 +47,8 @@ class HybridModelBuilder {
static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status HandleDtString(const GeTensor &tensor, void *var_addr);
static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeInputNodes(ComputeGraph &compute_graph);
static Status MergeNetOutputNode(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph);
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph);
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph);
static Status BuildInputMapping(GraphItem &graph_item, static Status BuildInputMapping(GraphItem &graph_item,
std::vector<NodeItem *> &data_nodes, std::vector<NodeItem *> &data_nodes,
bool is_root_graph); bool is_root_graph);


Loading…
Cancel
Save