Browse Source

!12621 Eliminate all redundant nodes related to UpdateStates.

From: @zh_qh
Reviewed-by: @ginfung,@hwhewei
Signed-off-by: @ginfung
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
b13cabeb10
6 changed files with 58 additions and 17 deletions
  1. +2
    -3
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +2
    -2
      mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h
  3. +47
    -4
      mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
  4. +5
    -6
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  5. +1
    -1
      mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc
  6. +1
    -1
      tests/ut/cpp/ir/manager_test.cc

+ 2
- 3
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1878,10 +1878,9 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {

std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
const AnfNodePtr &front_node) {
auto node_users = front_func_graph_manager->node_users();
auto users = node_users[front_node];
auto &users = front_func_graph_manager->node_users()[front_node];
std::vector<AnfNodePtr> result;
for (auto user : users) {
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
continue;
}


+ 2
- 2
mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h View File

@@ -433,12 +433,12 @@ class IncorporateGetitemSwitch : public AnfVisitor {
MS_EXCEPTION_IF_NULL(switch_call_cnode);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users_map = manager->node_users();
auto &node_users_map = manager->node_users();
auto it = node_users_map.find(switch_call);
if (it == node_users_map.end()) {
return false;
}
auto node_users = it->second;
auto &node_users = it->second;
// If switch was used by more than 1 tuple_getitem nodes, this pass shouldn't be execute.s
auto tuple_getitem_num = std::count_if(node_users.begin(), node_users.end(), [](std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem);


+ 47
- 4
mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc View File

@@ -291,6 +291,50 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
return make_tuple;
}

// Remove all nodes related to UpdateStates, if they're redundant.
void EliminateUselessNodesForUpdateStates(const std::vector<CNodePtr> &update_states) {
if (update_states.empty()) {
return;
}
auto mgr = GetManager(update_states[0]);

// 1. Remove the use of UpdateState nodes, except the last one.
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}

// 2. Remove the Depend users of last UpdateState node.
auto &node_users = mgr->node_users();
auto iter = node_users.find(update_states[0]);
if (iter == node_users.end()) {
return;
}
auto &us_users = iter->second;
if (us_users.size() < 2) {
return;
}
std::vector<AnfNodePtr> depend_nodes;
for (auto &user : us_users) {
if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kAttachIndex) {
depend_nodes.emplace_back(user.first);
}
}
if (depend_nodes.empty()) {
return;
}
ssize_t end = 0;
// If all users are Depend CNode.
if (depend_nodes.size() == us_users.size()) {
end = 1;
}
for (ssize_t i = depend_nodes.size() - 1; i >= end; i--) {
const auto &depend_node = depend_nodes[i];
const auto &depend_cnode = depend_node->cast<CNodePtr>();
mgr->Replace(depend_cnode, depend_cnode->input(kInputIndex));
}
}

// Eliminate UpdateStates for consecutive Loads.
// Convert:
// x1 = Load(x1, u)
@@ -336,10 +380,9 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
mgr->SetEdge(load, kAttachIndex, input_monad);
}
}
for (auto i = update_states.size() - 1; i > 0; i--) {
auto &us = update_states[i];
mgr->Replace(us, us->input(kInputIndex));
}

EliminateUselessNodesForUpdateStates(update_states);

if (make_tuple_inputs.size() == 1) {
// This should not happen.
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);


+ 5
- 6
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -52,25 +52,24 @@ static bool IsInWhiteList(const CNodePtr &cnode) {
return false;
}

static void SetGradTag(const AnfNodePtr &node, NodeUsersMap node_users_map) {
auto node_users = node_users_map[node];
static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
const auto &node_users = manager->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
if (!user_node->grad()) {
user_node->set_grad(true);
SetGradTag(user_node, node_users_map);
SetGradTag(user_node, manager);
}
}
}

void PipelineTransformer::LabelRequiredGradCNode() {
auto parameters = root_->parameters();
auto node_users_map = manager_->node_users();
for (auto parameter : parameters) {
if (!ParameterRequireGrad(parameter)) {
continue;
}
SetGradTag(parameter, node_users_map);
SetGradTag(parameter, manager_);
}
}

@@ -243,7 +242,7 @@ void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
while (need_coloring) {
need_coloring = false;
auto all_nodes = func->nodes();
auto node_users = manager_->node_users();
auto &node_users = manager_->node_users();
for (auto &node : all_nodes) {
if (node->isa<CNode>() || node->stage() == -1) {
continue;


+ 1
- 1
mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc View File

@@ -58,7 +58,7 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
auto &node_users = manager->node_users();
if (node_users[anf_node_ptr].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
}


+ 1
- 1
tests/ut/cpp/ir/manager_test.cc View File

@@ -391,7 +391,7 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(1, g->nodes().size());

auto users = mng->node_users();
auto &users = mng->node_users();
for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size());
}


Loading…
Cancel
Save