Browse Source

!10432 keep nop node in execution order if it's graph's output

From: @liubuyu
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
7e0c727ace
1 changed files with 22 additions and 2 deletions
  1. +22
    -2
      mindspore/ccsrc/backend/optimizer/common/helper.cc

+ 22
- 2
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -317,17 +317,35 @@ bool IsAllNopNode(const session::KernelGraph *const graph) {
return true; return true;
} }


bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
MS_EXCEPTION_IF_NULL(node);
// if node is not a nop node, keep it in execution order
if (!IsNopNode(node)) {
return true;
}
// if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
if (is_dynamic_graph) {
auto iter = find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
return true;
}
}
return false;
}

void HideNopNode(session::KernelGraph *const graph) { void HideNopNode(session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
if (IsAllNopNode(graph) == true) { if (IsAllNopNode(graph) == true) {
return; return;
} }
auto execution_order = graph->execution_order(); auto execution_order = graph->execution_order();
auto outputs = graph->outputs();
bool is_dynamic_graph = graph->is_dynamic_shape();
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size();
std::vector<CNodePtr> new_nodes; std::vector<CNodePtr> new_nodes;
for (auto &cnode : execution_order) { for (auto &cnode : execution_order) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!IsNopNode(cnode)) {
if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
new_nodes.push_back(cnode); new_nodes.push_back(cnode);
} }
} }
@@ -344,10 +362,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
while (changed) { while (changed) {
changed = false; changed = false;
std::vector<CNodePtr> new_nodes; std::vector<CNodePtr> new_nodes;
auto outputs = graph->outputs();
bool is_dynamic_graph = graph->is_dynamic_shape();
for (auto &cnode : graph->execution_order()) { for (auto &cnode : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// ignore nop node itself // ignore nop node itself
if (IsNopNode(cnode)) {
if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
continue; continue;
} }
// Replace the input which is nop node // Replace the input which is nop node


Loading…
Cancel
Save