|
|
@@ -317,17 +317,35 @@ bool IsAllNopNode(const session::KernelGraph *const graph) { |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
|
|
// if node is not a nop node, keep it in execution order |
|
|
|
|
|
if (!IsNopNode(node)) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
// if node is nop node and the graph is dynamic graph, check if the nop node is graph's output. |
|
|
|
|
|
if (is_dynamic_graph) { |
|
|
|
|
|
auto iter = find(outputs.begin(), outputs.end(), node); |
|
|
|
|
|
if (iter != outputs.end()) { |
|
|
|
|
|
return true; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
void HideNopNode(session::KernelGraph *const graph) { |
|
|
void HideNopNode(session::KernelGraph *const graph) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
if (IsAllNopNode(graph) == true) { |
|
|
if (IsAllNopNode(graph) == true) { |
|
|
return; |
|
|
return; |
|
|
} |
|
|
} |
|
|
auto execution_order = graph->execution_order(); |
|
|
auto execution_order = graph->execution_order(); |
|
|
|
|
|
auto outputs = graph->outputs(); |
|
|
|
|
|
bool is_dynamic_graph = graph->is_dynamic_shape(); |
|
|
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); |
|
|
MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); |
|
|
std::vector<CNodePtr> new_nodes; |
|
|
std::vector<CNodePtr> new_nodes; |
|
|
for (auto &cnode : execution_order) { |
|
|
for (auto &cnode : execution_order) { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
if (!IsNopNode(cnode)) { |
|
|
|
|
|
|
|
|
if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) { |
|
|
new_nodes.push_back(cnode); |
|
|
new_nodes.push_back(cnode); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
@@ -344,10 +362,12 @@ void RemoveNopNode(session::KernelGraph *const graph) { |
|
|
while (changed) { |
|
|
while (changed) { |
|
|
changed = false; |
|
|
changed = false; |
|
|
std::vector<CNodePtr> new_nodes; |
|
|
std::vector<CNodePtr> new_nodes; |
|
|
|
|
|
auto outputs = graph->outputs(); |
|
|
|
|
|
bool is_dynamic_graph = graph->is_dynamic_shape(); |
|
|
for (auto &cnode : graph->execution_order()) { |
|
|
for (auto &cnode : graph->execution_order()) { |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
// ignore nop node itself |
|
|
// ignore nop node itself |
|
|
if (IsNopNode(cnode)) { |
|
|
|
|
|
|
|
|
if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) { |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
// Replace the input which is nop node |
|
|
// Replace the input which is nop node |
|
|
|