| @@ -24,6 +24,18 @@ | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| namespace mindspore::opt::irpass { | namespace mindspore::opt::irpass { | ||||
| // Covert: | |||||
| // load1 = load(para1, u1) | |||||
| // u2 = UpdateState(u1, load1) | |||||
| // ... | |||||
| // load2 = load(load1, u3) | |||||
| // u4 = UpdateState(u3, load2) | |||||
| // To: | |||||
| // load1 = load(para1, u1) | |||||
| // u2 = UpdateState(u1, load1) | |||||
| // ... | |||||
| // load2 = load(para1, u3) # load1 replaced by para1 | |||||
| // u4 = UpdateState(u3, load2) | |||||
| AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| auto load_node = dyn_cast<CNode>(node); | auto load_node = dyn_cast<CNode>(node); | ||||
| if (load_node == nullptr || load_node->inputs().empty()) { | if (load_node == nullptr || load_node->inputs().empty()) { | ||||
| @@ -32,8 +44,20 @@ AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &no | |||||
| } | } | ||||
| auto load_cnode = load_node->cast<CNodePtr>(); | auto load_cnode = load_node->cast<CNodePtr>(); | ||||
| constexpr size_t kFirstInputIndex = 1; | constexpr size_t kFirstInputIndex = 1; | ||||
| if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) { | |||||
| return load_cnode->input(kFirstInputIndex); | |||||
| constexpr size_t kSecondInputIndex = 2; | |||||
| auto &input_load = load_cnode->input(kFirstInputIndex); | |||||
| if (IsPrimitiveCNode(input_load, prim::kPrimLoad)) { | |||||
| auto load_prim = NewValueNode(prim::kPrimLoad); | |||||
| auto input_load_cnode = input_load->cast<CNodePtr>(); | |||||
| auto replace_input = input_load_cnode->input(kFirstInputIndex); | |||||
| auto monad = load_cnode->input(kSecondInputIndex); | |||||
| std::vector<AnfNodePtr> new_load_inputs = {load_prim, replace_input, monad}; | |||||
| auto fg = load_cnode->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto new_load = fg->NewCNode(new_load_inputs); | |||||
| new_load->set_abstract(load_cnode->abstract()); | |||||
| new_load->set_scope(load_cnode->scope()); | |||||
| return new_load; | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||