Browse Source

add opt for list

tags/v0.6.0-beta
Wei Luning 5 years ago
parent
commit
0d2495c5ce
2 changed files with 11 additions and 4 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +10
    -3
      mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h

+ 1
- 1
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -64,7 +64,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {

// ops eliminate
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);


+ 10
- 3
mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h View File

@@ -38,6 +38,7 @@ class GetitemEliminater : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);

if (is_match_) {
return tuple_->input(id_);
@@ -46,14 +47,18 @@ class GetitemEliminater : public AnfVisitor {
}

void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
tuple_ = cnode;
}
}

void Visit(const ValueNodePtr &vnode) override {
if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
id_ = IntToSize(GetValue<int>(vnode->value()) + 1);
int idx = GetValue<int>(vnode->value());
if (idx < 0) {
idx = idx + tuple_->size() - 1;
}
id_ = IntToSize(idx + 1);
if (tuple_->size() > id_) {
is_match_ = true;
}
@@ -80,6 +85,7 @@ class GetitemConstEliminater : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsVNode, IsVNode})(node);

if (is_match_) {
return NewValueNode((*tuple_)[id_]);
@@ -138,7 +144,7 @@ class SetitemEliminater : public AnfVisitor {
}

void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
auto &inputs = cnode->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_));
}
@@ -234,6 +240,7 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
if (x_ == nullptr) {
return nullptr;
}


Loading…
Cancel
Save