|
|
|
@@ -133,7 +133,8 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons |
|
|
|
} |
|
|
|
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { |
|
|
|
for (const auto &input : cnode->inputs()) { |
|
|
|
if (input->isa<Parameter>()) { |
|
|
|
if (input->isa<Parameter>() || |
|
|
|
(IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) { |
|
|
|
unload_users_record.insert(input); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -317,7 +318,10 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { |
|
|
|
// To: |
|
|
|
// u1 = UpdateState(u, c) |
|
|
|
// p1 = Load(para1, u') // u' is first monad in graph or new monad |
|
|
|
void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) { |
|
|
|
bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) { |
|
|
|
if (need_replace_loads.size() == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
constexpr size_t second_input_index = 2; |
|
|
|
auto monad = GetFirstMonad(fg); |
|
|
|
for (const auto &load_node : need_replace_loads) { |
|
|
|
@@ -331,6 +335,7 @@ void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode |
|
|
|
auto mgr = fg->manager(); |
|
|
|
mgr->SetEdge(load_node, second_input_index, monad); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => |
|
|
|
@@ -341,7 +346,10 @@ bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const { |
|
|
|
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); |
|
|
|
std::vector<AnfNodePtr> need_replace_loads; |
|
|
|
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); |
|
|
|
ReplaceUpdateStateForLoad(fg, need_replace_loads); |
|
|
|
const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); |
|
|
|
if (update_state_replaced) { |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
// split group if there is no-load node between two load nodes. |
|
|
|
std::vector<std::vector<size_t>> need_merge_loads; |
|
|
|
for (auto &group : load_groups) { |
|
|
|
|