|
|
|
@@ -24,6 +24,18 @@ |
|
|
|
#include "frontend/operator/ops.h" |
|
|
|
|
|
|
|
namespace mindspore::opt::irpass { |
|
|
|
// Covert: |
|
|
|
// load1 = load(para1, u1) |
|
|
|
// u2 = UpdateState(u1, load1) |
|
|
|
// ... |
|
|
|
// load2 = load(load1, u3) |
|
|
|
// u4 = UpdateState(u3, load2) |
|
|
|
// To: |
|
|
|
// load1 = load(para1, u1) |
|
|
|
// u2 = UpdateState(u1, load1) |
|
|
|
// ... |
|
|
|
// load2 = load(para1, u3) # load1 replaced by para1 |
|
|
|
// u4 = UpdateState(u3, load2) |
|
|
|
AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
auto load_node = dyn_cast<CNode>(node); |
|
|
|
if (load_node == nullptr || load_node->inputs().empty()) { |
|
|
|
@@ -32,8 +44,20 @@ AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &no |
|
|
|
} |
|
|
|
auto load_cnode = load_node->cast<CNodePtr>(); |
|
|
|
constexpr size_t kFirstInputIndex = 1; |
|
|
|
if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) { |
|
|
|
return load_cnode->input(kFirstInputIndex); |
|
|
|
constexpr size_t kSecondInputIndex = 2; |
|
|
|
auto &input_load = load_cnode->input(kFirstInputIndex); |
|
|
|
if (IsPrimitiveCNode(input_load, prim::kPrimLoad)) { |
|
|
|
auto load_prim = NewValueNode(prim::kPrimLoad); |
|
|
|
auto input_load_cnode = input_load->cast<CNodePtr>(); |
|
|
|
auto replace_input = input_load_cnode->input(kFirstInputIndex); |
|
|
|
auto monad = load_cnode->input(kSecondInputIndex); |
|
|
|
std::vector<AnfNodePtr> new_load_inputs = {load_prim, replace_input, monad}; |
|
|
|
auto fg = load_cnode->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(fg); |
|
|
|
auto new_load = fg->NewCNode(new_load_inputs); |
|
|
|
new_load->set_abstract(load_cnode->abstract()); |
|
|
|
new_load->set_scope(load_cnode->scope()); |
|
|
|
return new_load; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|