| @@ -25,7 +25,7 @@ | |||
| #include "frontend/optimizer/irpass/inline.h" | |||
| #include "frontend/optimizer/irpass/incorporate_call.h" | |||
| #include "frontend/optimizer/irpass/incorporate_getitem.h" | |||
| #include "frontend/optimizer/irpass/item_tuple_eliminate.h" | |||
| #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" | |||
| #include "frontend/optimizer/irpass/mark_interface_fusion.h" | |||
| #include "frontend/optimizer/irpass/merge_addn.h" | |||
| #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | |||
| @@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| // ops eliminate | |||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | |||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); | |||
| item_tuple_or_list_eliminate_ = MakeSubstitution( | |||
| std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate", | |||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); | |||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); | |||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | |||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | |||
| @@ -39,7 +39,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr adjust_all_reduce_mul_add_; | |||
| // ops eliminate | |||
| SubstitutionPtr item_tuple_eliminate_; | |||
| SubstitutionPtr item_tuple_or_list_eliminate_; | |||
| SubstitutionPtr tile_eliminate_; | |||
| SubstitutionPtr cast_eliminate_; | |||
| SubstitutionPtr reshape_eliminate_; | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| @@ -33,6 +33,7 @@ namespace irpass { | |||
| // (a, b, c, ...)[0] => a | |||
| // (a, b, c, ...)[1] => b | |||
| // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} | |||
| // {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} | |||
| class GetitemEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor { | |||
| void Visit(const ValueNodePtr &vnode) override { | |||
| if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { | |||
| int64_t idx = GetValue<int64_t>(vnode->value()); | |||
| auto idx = GetValue<int64_t>(vnode->value()); | |||
| if (idx < 0) { | |||
| idx = idx + tuple_->size() - 1; | |||
| } | |||
| @@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor { | |||
| // (a, b, c, ...)[0] => a | |||
| // (a, b, c, ...)[1] => b | |||
| // {prim::kPrimTupleGetItem, C1, C} | |||
| // {prim::kPrimListGetItem, C1, C} | |||
| class GetitemConstEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor { | |||
| // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) | |||
| // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) | |||
| // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} | |||
| // {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} | |||
| class SetitemEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); | |||
| AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); | |||
| auto fg = node->func_graph(); | |||
| if (fg != nullptr && z_ != nullptr) { | |||
| @@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor { | |||
| }; | |||
| // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} | |||
| // {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} | |||
| class GetSetitemEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); | |||
| AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); | |||
| auto fg = node->func_graph(); | |||
| if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { | |||
| @@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor { | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) { | |||
| if (cnode->size() < 4) { | |||
| return; | |||
| } | |||
| @@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor { | |||
| // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> | |||
| // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} | |||
| // {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} -> | |||
| // {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y} | |||
| class GetitemDependReorder : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | |||
| }; | |||
| class ItemTupleEliminater : public OptimizerCaller { | |||
| class ItemTupleOrListEliminater : public OptimizerCaller { | |||
| public: | |||
| ItemTupleEliminater() | |||
| ItemTupleOrListEliminater() | |||
| : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | |||
| get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | |||
| set_item_eliminater_(std::make_shared<SetitemEliminater>()), | |||
| @@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller { | |||
| eliminaters_.emplace_back(get_set_item_eliminater_); | |||
| eliminaters_.emplace_back(get_item_depend_reorder_); | |||
| } | |||
| ~ItemTupleEliminater() = default; | |||
| ~ItemTupleOrListEliminater() = default; | |||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | |||
| AnfNodePtr new_node; | |||
| @@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller { | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||
| @@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.specialize_transform_, | |||
| // Miscellaneous | |||
| irpass.item_tuple_eliminate_, | |||
| irpass.item_tuple_or_list_eliminate_, | |||
| irpass.env_get_item_eliminate_, | |||
| irpass.cast_eliminate_, | |||
| irpass.reshape_eliminate_, | |||
| @@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp | |||
| } | |||
| OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining | |||
| irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); | |||
| opt::OptPassConfig d_1 = | |||
| opt::OptPassConfig({// Safe inlining | |||
| irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_}); | |||
| OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); | |||
| @@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib | |||
| OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig b_1 = opt::OptPassConfig( | |||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | |||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, | |||
| irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | |||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | |||
| @@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None): | |||
| equal to the case when `fn` is not None. | |||
| Examples: | |||
| >>> from mindspore.ops import functional as F | |||
| >>> from mindspore.ops import functional as F | |||
| ... | |||
| >>> def tensor_add(x, y): | |||
| ... z = x + y | |||
| @@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) { | |||
| FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); | |||
| after_2->set_output(value_node_2); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||
| ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); | |||
| ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); | |||
| @@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) { | |||
| FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); | |||
| FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||
| ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | |||
| @@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) { | |||
| FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); | |||
| FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||
| ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | |||
| @@ -13,9 +13,14 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test enumerate""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -168,3 +173,60 @@ def test_list_index_3D_parameter(): | |||
| net = Net() | |||
| net(Tensor(0)) | |||
| def test_const_list_index_3D_bprop(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = [[1], [2, 2], [[3, 3], [3, 3]]] | |||
| self.relu = P.ReLU() | |||
| def construct(self, input_x): | |||
| list_x = self.value | |||
| list_x[2][0][1] = input_x | |||
| return self.relu(list_x[2][0][1]) | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) | |||
| def construct(self, x, sens): | |||
| return self.grad_all_with_sens(self.net)(x, sens) | |||
| net = Net() | |||
| grad_net = GradNet(net) | |||
| x = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||
| sens = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||
| grad_net(x, sens) | |||
| def test_parameter_list_index_3D_bprop(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = [[1], [2, 2], [[3, 3], [3, 3]]] | |||
| self.relu = P.ReLU() | |||
| def construct(self, x, value): | |||
| list_value = [[x], [x, x], [[x, x], [x, x]]] | |||
| list_value[2][0][1] = value | |||
| return self.relu(list_value[2][0][1]) | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) | |||
| def construct(self, x, value, sens): | |||
| return self.grad_all_with_sens(self.net)(x, value, sens) | |||
| net = Net() | |||
| grad_net = GradNet(net) | |||
| x = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||
| value = Tensor(np.ones((2, 3), np.int64)) | |||
| sens = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||
| grad_net(x, value, sens) | |||