| @@ -24,27 +24,16 @@ | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| namespace mindspore::opt::irpass { | namespace mindspore::opt::irpass { | ||||
| namespace { | |||||
| // Return true if the node has Ref abstract. | |||||
| bool HasAbstractRef(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto &abs = node->abstract(); | |||||
| return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| auto load_node = dyn_cast<CNode>(node); | auto load_node = dyn_cast<CNode>(node); | ||||
| if (load_node == nullptr || load_node->inputs().empty()) { | if (load_node == nullptr || load_node->inputs().empty()) { | ||||
| MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); | MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto load_cnode = load_node->cast<CNodePtr>(); | |||||
| constexpr size_t kFirstInputIndex = 1; | constexpr size_t kFirstInputIndex = 1; | ||||
| auto ¶m = load_node->inputs().at(kFirstInputIndex); | |||||
| if (!HasAbstractRef(param)) { | |||||
| return param; | |||||
| if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) { | |||||
| return load_cnode->input(kFirstInputIndex); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||