Browse Source

fix load input is depend node when ReplaceUpdateStateForLoad

tags/v1.2.0-rc1
Margaret_wangrui 5 years ago
parent
commit
81259d7339
1 changed files with 11 additions and 3 deletions
  1. +11
    -3
      mindspore/ccsrc/frontend/optimizer/cse.cc

+ 11
- 3
mindspore/ccsrc/frontend/optimizer/cse.cc View File

@@ -133,7 +133,8 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
}
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
for (const auto &input : cnode->inputs()) {
if (input->isa<Parameter>()) {
if (input->isa<Parameter>() ||
(IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) {
unload_users_record.insert(input);
}
}
@@ -317,7 +318,10 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
// To:
// u1 = UpdateState(u, c)
// p1 = Load(para1, u') // u' is first monad in graph or new monad
void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
if (need_replace_loads.size() == 0) {
return false;
}
constexpr size_t second_input_index = 2;
auto monad = GetFirstMonad(fg);
for (const auto &load_node : need_replace_loads) {
@@ -331,6 +335,7 @@ void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
auto mgr = fg->manager();
mgr->SetEdge(load_node, second_input_index, monad);
}
return true;
}

// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
@@ -341,7 +346,10 @@ bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
std::vector<AnfNodePtr> need_replace_loads;
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
ReplaceUpdateStateForLoad(fg, need_replace_loads);
const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
if (update_state_replaced) {
changed = true;
}
// split group if there is no-load node between two load nodes.
std::vector<std::vector<size_t>> need_merge_loads;
for (auto &group : load_groups) {


Loading…
Cancel
Save