|
|
|
@@ -30,11 +30,70 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
// (a, b, c, ...)[-1] => (a, b, c, ...)[length-1] |
|
|
|
// [a, b, c, ...][-1] => [a, b, c, ...][length-1] |
|
|
|
// {prim::kPrimTupleGetItem, T, N} |
|
|
|
// {prim::kPrimListGetItem, L, N} |
|
|
|
// setitem((a, b, c, ...), -1, z) => setitem((a, b, c, ...), length - 1, z) |
|
|
|
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z) |
|
|
|
// {prim::kPrimTupleSetItem, T, N, Z} |
|
|
|
// {prim::kPrimListSetItem, L, N, Z} |
|
|
|
class ConvertItemIndexToPositive : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); |
|
|
|
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); |
|
|
|
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); |
|
|
|
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); |
|
|
|
|
|
|
|
if (is_match_) { |
|
|
|
node->cast<CNodePtr>()->set_input(2, NewValueNode(id_)); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_match_) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
AnfVisitor::Visit(node); |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const CNodePtr &cnode) override { sequeue_ = cnode; } |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
if (sequeue_ != nullptr && IsValueNode<Int64Imm>(vnode)) { |
|
|
|
auto idx = GetValue<int64_t>(vnode->value()); |
|
|
|
if (idx < 0) { |
|
|
|
auto sequeue_abstract = sequeue_->abstract()->cast<abstract::AbstractSequeuePtr>(); |
|
|
|
if (sequeue_abstract == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
id_ = idx + sequeue_abstract->size(); |
|
|
|
is_match_ = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Reset() { |
|
|
|
id_ = 0; |
|
|
|
sequeue_ = nullptr; |
|
|
|
is_match_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
bool is_match_{false}; |
|
|
|
int64_t id_{0}; |
|
|
|
CNodePtr sequeue_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// (a, b, c, ...)[0] => a |
|
|
|
// (a, b, c, ...)[1] => b |
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} |
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} |
|
|
|
class GetitemEliminater : public AnfVisitor { |
|
|
|
class GetitemEliminator : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
@@ -82,7 +141,7 @@ class GetitemEliminater : public AnfVisitor { |
|
|
|
// (a, b, c, ...)[1] => b |
|
|
|
// {prim::kPrimTupleGetItem, C1, C} |
|
|
|
// {prim::kPrimListGetItem, C1, C} |
|
|
|
class GetitemConstEliminater : public AnfVisitor { |
|
|
|
class GetitemConstEliminator : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
@@ -103,8 +162,12 @@ class GetitemConstEliminater : public AnfVisitor { |
|
|
|
has_new_value_ = vnode->has_new_value(); |
|
|
|
} |
|
|
|
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { |
|
|
|
id_ = LongToSize(GetValue<int64_t>(vnode->value())); |
|
|
|
if (tuple_->size() > id_) { |
|
|
|
auto idx = GetValue<int64_t>(vnode->value()); |
|
|
|
if (idx < 0) { |
|
|
|
idx = idx + tuple_->size(); |
|
|
|
} |
|
|
|
id_ = LongToSize(idx); |
|
|
|
if (id_ < tuple_->size()) { |
|
|
|
is_match_ = true; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -127,7 +190,7 @@ class GetitemConstEliminater : public AnfVisitor { |
|
|
|
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) |
|
|
|
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} |
|
|
|
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} |
|
|
|
class SetitemEliminater : public AnfVisitor { |
|
|
|
class SetitemEliminator : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
@@ -159,8 +222,12 @@ class SetitemEliminater : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
if (args_.size() > 0 && IsValueNode<Int64Imm>(vnode)) { |
|
|
|
id_ = LongToSize(GetValue<int64_t>(vnode->value()) + 1); |
|
|
|
if (!args_.empty() && IsValueNode<Int64Imm>(vnode)) { |
|
|
|
auto idx = GetValue<int64_t>(vnode->value()); |
|
|
|
if (idx < 0) { |
|
|
|
idx = idx + args_.size() - 1; |
|
|
|
} |
|
|
|
id_ = LongToSize(idx + 1); |
|
|
|
if (id_ < args_.size()) { |
|
|
|
is_match_ = true; |
|
|
|
} |
|
|
|
@@ -183,7 +250,7 @@ class SetitemEliminater : public AnfVisitor { |
|
|
|
|
|
|
|
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} |
|
|
|
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} |
|
|
|
class GetSetitemEliminater : public AnfVisitor { |
|
|
|
class GetSetitemEliminator : public AnfVisitor { |
|
|
|
public: |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
@@ -217,8 +284,15 @@ class GetSetitemEliminater : public AnfVisitor { |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
if (IsValueNode<Int64Imm>(vnode)) { |
|
|
|
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { |
|
|
|
auto key = GetValue<int64_t>(vnode->value()); |
|
|
|
if (key < 0) { |
|
|
|
auto sequeue_abstract = tuple_->abstract()->cast<abstract::AbstractSequeuePtr>(); |
|
|
|
if (sequeue_abstract == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
key = key + sequeue_abstract->size(); |
|
|
|
} |
|
|
|
if (is_in_set_) { |
|
|
|
key1_ = key; |
|
|
|
} else { |
|
|
|
@@ -282,26 +356,28 @@ class GetitemDependReorder : public AnfVisitor { |
|
|
|
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
class ItemTupleOrListEliminater : public OptimizerCaller { |
|
|
|
class ItemTupleOrListEliminator : public OptimizerCaller { |
|
|
|
public: |
|
|
|
ItemTupleOrListEliminater() |
|
|
|
: get_item_eliminater_(std::make_shared<GetitemEliminater>()), |
|
|
|
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), |
|
|
|
set_item_eliminater_(std::make_shared<SetitemEliminater>()), |
|
|
|
get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()), |
|
|
|
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) { |
|
|
|
eliminaters_.emplace_back(get_item_eliminater_); |
|
|
|
eliminaters_.emplace_back(get_item_const_eliminater_); |
|
|
|
eliminaters_.emplace_back(set_item_eliminater_); |
|
|
|
eliminaters_.emplace_back(get_set_item_eliminater_); |
|
|
|
eliminaters_.emplace_back(get_item_depend_reorder_); |
|
|
|
ItemTupleOrListEliminator() |
|
|
|
: get_item_eliminator_(std::make_shared<GetitemEliminator>()), |
|
|
|
get_item_const_eliminator_(std::make_shared<GetitemConstEliminator>()), |
|
|
|
set_item_eliminator_(std::make_shared<SetitemEliminator>()), |
|
|
|
get_set_item_eliminator_(std::make_shared<GetSetitemEliminator>()), |
|
|
|
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()), |
|
|
|
convert_item_index_to_positive_(std::make_shared<ConvertItemIndexToPositive>()) { |
|
|
|
eliminators_.emplace_back(get_item_eliminator_); |
|
|
|
eliminators_.emplace_back(get_item_const_eliminator_); |
|
|
|
eliminators_.emplace_back(set_item_eliminator_); |
|
|
|
eliminators_.emplace_back(get_set_item_eliminator_); |
|
|
|
eliminators_.emplace_back(get_item_depend_reorder_); |
|
|
|
eliminators_.emplace_back(convert_item_index_to_positive_); |
|
|
|
} |
|
|
|
~ItemTupleOrListEliminater() = default; |
|
|
|
~ItemTupleOrListEliminator() = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { |
|
|
|
AnfNodePtr new_node; |
|
|
|
for (auto &eliminater : eliminaters_) { |
|
|
|
new_node = (*eliminater)(optimizer, node); |
|
|
|
for (auto &eliminator : eliminators_) { |
|
|
|
new_node = (*eliminator)(optimizer, node); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
@@ -310,10 +386,11 @@ class ItemTupleOrListEliminater : public OptimizerCaller { |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, |
|
|
|
get_item_depend_reorder_; |
|
|
|
std::vector<OptimizerCallerPtr> eliminaters_{}; |
|
|
|
OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_, |
|
|
|
get_item_depend_reorder_, convert_item_index_to_positive_; |
|
|
|
std::vector<OptimizerCallerPtr> eliminators_{}; |
|
|
|
}; |
|
|
|
|
|
|
|
} // namespace irpass |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |
|
|
|
|