Browse Source

!13208 [GraphKernel] Through pass monad depend in parallel fusion.

From: @tronzhang
Reviewed-by: @gaoxiong1,@anyrenwei
Signed-off-by: @anyrenwei
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
281bf9aa3a
2 changed files with 12 additions and 2 deletions
  1. +8
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc
  2. +4
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc

+ 8
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc View File

@@ -135,7 +135,14 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) {
} }


old_depends.push_back(node); old_depends.push_back(node);
free_nodes.push_back(node->cast<CNodePtr>()->input(kDependAttachNodeIndex));
auto cnode = node->cast<CNodePtr>();
for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
auto attach_node = cnode->input(id);
if (!IsPrimitiveCNode(attach_node, prim::kPrimDepend)) {
continue;
}
free_nodes.push_back(attach_node);
}
} }


if (old_depends.empty()) { if (old_depends.empty()) {


+ 4
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc View File

@@ -96,6 +96,7 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
} }


void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
OrderedSet<AnfNodePtr> to_be_through_pass;
for (auto &[node, node_rel] : (*node_rels)) { for (auto &[node, node_rel] : (*node_rels)) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend) || if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) { HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
@@ -113,10 +114,12 @@ void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
cnode_pres.erase(attach_node); cnode_pres.erase(attach_node);
} }
} }
to_be_through_pass.insert(node);
} }


// Eliminate depend node of node relations. // Eliminate depend node of node relations.
ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels);
ProcessThroughPassCNode([&to_be_through_pass](const AnfNodePtr &node) { return to_be_through_pass.count(node) > 0; },
node_rels);
} }


void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) { void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {


Loading…
Cancel
Save