|
|
|
@@ -24,27 +24,16 @@ |
|
|
|
#include "frontend/operator/ops.h" |
|
|
|
|
|
|
|
namespace mindspore::opt::irpass { |
|
|
|
namespace { |
|
|
|
// Return true if the node has Ref abstract. |
|
|
|
bool HasAbstractRef(const AnfNodePtr &node) { |
|
|
|
if (node == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto &abs = node->abstract(); |
|
|
|
return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { |
|
|
|
auto load_node = dyn_cast<CNode>(node); |
|
|
|
if (load_node == nullptr || load_node->inputs().empty()) { |
|
|
|
MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto load_cnode = load_node->cast<CNodePtr>(); |
|
|
|
constexpr size_t kFirstInputIndex = 1; |
|
|
|
auto ¶m = load_node->inputs().at(kFirstInputIndex); |
|
|
|
if (!HasAbstractRef(param)) { |
|
|
|
return param; |
|
|
|
if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) { |
|
|
|
return load_cnode->input(kFirstInputIndex); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|