Browse Source

eliminate consecutive loads

tags/v1.2.0-rc1
Margaret_wangrui 4 years ago
parent
commit
aeb43e5167
1 changed files with 3 additions and 14 deletions
  1. +3
    -14
      mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc

+ 3
- 14
mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc View File

@@ -24,27 +24,16 @@
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"


namespace mindspore::opt::irpass { namespace mindspore::opt::irpass {
namespace {
// Return true if the node has Ref abstract.
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto &abs = node->abstract();
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
}
} // namespace

AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
auto load_node = dyn_cast<CNode>(node); auto load_node = dyn_cast<CNode>(node);
if (load_node == nullptr || load_node->inputs().empty()) { if (load_node == nullptr || load_node->inputs().empty()) {
MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString();
return nullptr; return nullptr;
} }
auto load_cnode = load_node->cast<CNodePtr>();
constexpr size_t kFirstInputIndex = 1; constexpr size_t kFirstInputIndex = 1;
auto &param = load_node->inputs().at(kFirstInputIndex);
if (!HasAbstractRef(param)) {
return param;
if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) {
return load_cnode->input(kFirstInputIndex);
} }
return nullptr; return nullptr;
} }


Loading…
Cancel
Save