diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 8142b305df..4f01b8d415 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -25,7 +25,7 @@ #include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/irpass/incorporate_call.h" #include "frontend/optimizer/irpass/incorporate_getitem.h" -#include "frontend/optimizer/irpass/item_tuple_eliminate.h" +#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" #include "frontend/optimizer/irpass/mark_interface_fusion.h" #include "frontend/optimizer/irpass/merge_addn.h" #include "frontend/optimizer/irpass/accumulaten_eliminate.h" @@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); // ops eliminate - item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", - {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); + item_tuple_or_list_eliminate_ = MakeSubstitution( + std::make_shared(), "item_tuple_or_list_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index e5b2371f92..7bbfbc63f0 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -39,7 +39,7 @@ class OptimizeIRPassLib { SubstitutionPtr adjust_all_reduce_mul_add_; // ops eliminate - SubstitutionPtr item_tuple_eliminate_; + SubstitutionPtr item_tuple_or_list_eliminate_; SubstitutionPtr tile_eliminate_; SubstitutionPtr cast_eliminate_; SubstitutionPtr reshape_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h similarity index 89% rename from mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h rename to mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index d943184a13..d09cf1cf00 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ #include #include @@ -33,6 +33,7 @@ namespace irpass { // (a, b, c, ...)[0] => a // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} +// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} class GetitemEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor { void Visit(const ValueNodePtr &vnode) override { if (tuple_ != nullptr && IsValueNode(vnode)) { - int64_t idx = GetValue(vnode->value()); + auto idx = GetValue(vnode->value()); if (idx < 0) { idx = idx + tuple_->size() - 1; } @@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor { // (a, b, c, ...)[0] => a // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, C1, C} +// {prim::kPrimListGetItem, C1, C} class GetitemConstEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor { // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} +// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} class SetitemEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); + AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); auto fg = node->func_graph(); if (fg != nullptr && z_ != nullptr) { @@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor { }; // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} +// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} class GetSetitemEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); auto fg = node->func_graph(); if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { @@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor { } void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) { if (cnode->size() < 4) { return; } @@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor { // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} +// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} -> +// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y} class GetitemDependReorder : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleEliminater : public OptimizerCaller { +class ItemTupleOrListEliminater : public OptimizerCaller { public: - ItemTupleEliminater() + ItemTupleOrListEliminater() : get_item_eliminater_(std::make_shared()), get_item_const_eliminater_(std::make_shared()), set_item_eliminater_(std::make_shared()), @@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller { eliminaters_.emplace_back(get_set_item_eliminater_); eliminaters_.emplace_back(get_item_depend_reorder_); } - ~ItemTupleEliminater() = default; + ~ItemTupleOrListEliminater() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr new_node; @@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller { } // namespace irpass } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index fe3d7844e4..817d363470 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.specialize_transform_, // Miscellaneous - irpass.item_tuple_eliminate_, + irpass.item_tuple_or_list_eliminate_, irpass.env_get_item_eliminate_, irpass.cast_eliminate_, irpass.reshape_eliminate_, @@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp } OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining - irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); + opt::OptPassConfig d_1 = + opt::OptPassConfig({// Safe inlining + irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_}); OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); @@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig( - {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, + {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); diff --git a/mindspore/common/api.py b/mindspore/common/api.py index ba77828f5e..ea315d7658 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None): equal to the case when `fn` is not None. Examples: - >>> from mindspore.ops import functional as F + >>> from mindspore.ops import functional as F ... >>> def tensor_add(x, y): ... z = x + y diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index a363a056d4..3c794f97a8 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) { FuncGraphPtr after_2 = std::make_shared(); after_2->set_output(value_node_2); - auto patterns = std::vector({irpass.item_tuple_eliminate_}); + auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); @@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) { FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); - auto patterns = std::vector({irpass.item_tuple_eliminate_}); + auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); @@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) { FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); - auto patterns = std::vector({irpass.item_tuple_eliminate_}); + auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); diff --git a/tests/ut/python/pipeline/parse/test_sequence_assign.py b/tests/ut/python/pipeline/parse/test_sequence_assign.py index 29b7bc11a5..255f40a4aa 100644 --- a/tests/ut/python/pipeline/parse/test_sequence_assign.py +++ b/tests/ut/python/pipeline/parse/test_sequence_assign.py @@ -13,9 +13,14 @@ # limitations under the License. # ============================================================================ """ test enumerate""" + +import numpy as np + import mindspore.nn as nn from mindspore import Tensor from mindspore import context +from mindspore.ops import operations as P +from mindspore.ops import composite as C context.set_context(mode=context.GRAPH_MODE) @@ -168,3 +173,60 @@ def test_list_index_3D_parameter(): net = Net() net(Tensor(0)) + + +def test_const_list_index_3D_bprop(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = [[1], [2, 2], [[3, 3], [3, 3]]] + self.relu = P.ReLU() + + def construct(self, input_x): + list_x = self.value + list_x[2][0][1] = input_x + return self.relu(list_x[2][0][1]) + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, x, sens): + return self.grad_all_with_sens(self.net)(x, sens) + + net = Net() + grad_net = GradNet(net) + x = Tensor(np.arange(2 * 3).reshape(2, 3)) + sens = Tensor(np.arange(2 * 3).reshape(2, 3)) + grad_net(x, sens) + + +def test_parameter_list_index_3D_bprop(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = [[1], [2, 2], [[3, 3], [3, 3]]] + self.relu = P.ReLU() + + def construct(self, x, value): + list_value = [[x], [x, x], [[x, x], [x, x]]] + list_value[2][0][1] = value + return self.relu(list_value[2][0][1]) + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, x, value, sens): + return self.grad_all_with_sens(self.net)(x, value, sens) + + net = Net() + grad_net = GradNet(net) + x = Tensor(np.arange(2 * 3).reshape(2, 3)) + value = Tensor(np.ones((2, 3), np.int64)) + sens = Tensor(np.arange(2 * 3).reshape(2, 3)) + grad_net(x, value, sens)