| @@ -25,7 +25,7 @@ | |||||
| #include "frontend/optimizer/irpass/inline.h" | #include "frontend/optimizer/irpass/inline.h" | ||||
| #include "frontend/optimizer/irpass/incorporate_call.h" | #include "frontend/optimizer/irpass/incorporate_call.h" | ||||
| #include "frontend/optimizer/irpass/incorporate_getitem.h" | #include "frontend/optimizer/irpass/incorporate_getitem.h" | ||||
| #include "frontend/optimizer/irpass/item_tuple_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/mark_interface_fusion.h" | #include "frontend/optimizer/irpass/mark_interface_fusion.h" | ||||
| #include "frontend/optimizer/irpass/merge_addn.h" | #include "frontend/optimizer/irpass/merge_addn.h" | ||||
| #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | #include "frontend/optimizer/irpass/accumulaten_eliminate.h" | ||||
| @@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | ||||
| // ops eliminate | // ops eliminate | ||||
| item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate", | |||||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem}); | |||||
| item_tuple_or_list_eliminate_ = MakeSubstitution( | |||||
| std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate", | |||||
| {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); | |||||
| tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); | tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); | ||||
| cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); | ||||
| reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); | ||||
| @@ -39,7 +39,7 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr adjust_all_reduce_mul_add_; | SubstitutionPtr adjust_all_reduce_mul_add_; | ||||
| // ops eliminate | // ops eliminate | ||||
| SubstitutionPtr item_tuple_eliminate_; | |||||
| SubstitutionPtr item_tuple_or_list_eliminate_; | |||||
| SubstitutionPtr tile_eliminate_; | SubstitutionPtr tile_eliminate_; | ||||
| SubstitutionPtr cast_eliminate_; | SubstitutionPtr cast_eliminate_; | ||||
| SubstitutionPtr reshape_eliminate_; | SubstitutionPtr reshape_eliminate_; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -33,6 +33,7 @@ namespace irpass { | |||||
| // (a, b, c, ...)[0] => a | // (a, b, c, ...)[0] => a | ||||
| // (a, b, c, ...)[1] => b | // (a, b, c, ...)[1] => b | ||||
| // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} | // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} | ||||
| // {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} | |||||
| class GetitemEliminater : public AnfVisitor { | class GetitemEliminater : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| @@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor { | |||||
| void Visit(const ValueNodePtr &vnode) override { | void Visit(const ValueNodePtr &vnode) override { | ||||
| if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { | if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { | ||||
| int64_t idx = GetValue<int64_t>(vnode->value()); | |||||
| auto idx = GetValue<int64_t>(vnode->value()); | |||||
| if (idx < 0) { | if (idx < 0) { | ||||
| idx = idx + tuple_->size() - 1; | idx = idx + tuple_->size() - 1; | ||||
| } | } | ||||
| @@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor { | |||||
| // (a, b, c, ...)[0] => a | // (a, b, c, ...)[0] => a | ||||
| // (a, b, c, ...)[1] => b | // (a, b, c, ...)[1] => b | ||||
| // {prim::kPrimTupleGetItem, C1, C} | // {prim::kPrimTupleGetItem, C1, C} | ||||
| // {prim::kPrimListGetItem, C1, C} | |||||
| class GetitemConstEliminater : public AnfVisitor { | class GetitemConstEliminater : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| @@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor { | |||||
| // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) | // setitem((a, b, c, ...), 0, z) => (z, b, c, ...) | ||||
| // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) | // setitem((a, b, c, ...), 1, z) => (a, z, c, ...) | ||||
| // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} | // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} | ||||
| // {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z} | |||||
| class SetitemEliminater : public AnfVisitor { | class SetitemEliminater : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); | AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); | ||||
| AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node); | |||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| if (fg != nullptr && z_ != nullptr) { | if (fg != nullptr && z_ != nullptr) { | ||||
| @@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor { | |||||
| }; | }; | ||||
| // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} | // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} | ||||
| // {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} | |||||
| class GetSetitemEliminater : public AnfVisitor { | class GetSetitemEliminater : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| Reset(); | Reset(); | ||||
| AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); | AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); | ||||
| AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node); | |||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { | if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { | ||||
| @@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor { | |||||
| } | } | ||||
| void Visit(const CNodePtr &cnode) override { | void Visit(const CNodePtr &cnode) override { | ||||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) { | |||||
| if (cnode->size() < 4) { | if (cnode->size() < 4) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor { | |||||
| // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> | // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> | ||||
| // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} | // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} | ||||
| // {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} -> | |||||
| // {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y} | |||||
| class GetitemDependReorder : public AnfVisitor { | class GetitemDependReorder : public AnfVisitor { | ||||
| public: | public: | ||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | ||||
| @@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; | ||||
| }; | }; | ||||
| class ItemTupleEliminater : public OptimizerCaller { | |||||
| class ItemTupleOrListEliminater : public OptimizerCaller { | |||||
| public: | public: | ||||
| ItemTupleEliminater() | |||||
| ItemTupleOrListEliminater() | |||||
| : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | : get_item_eliminater_(std::make_shared<GetitemEliminater>()), | ||||
| get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), | ||||
| set_item_eliminater_(std::make_shared<SetitemEliminater>()), | set_item_eliminater_(std::make_shared<SetitemEliminater>()), | ||||
| @@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller { | |||||
| eliminaters_.emplace_back(get_set_item_eliminater_); | eliminaters_.emplace_back(get_set_item_eliminater_); | ||||
| eliminaters_.emplace_back(get_item_depend_reorder_); | eliminaters_.emplace_back(get_item_depend_reorder_); | ||||
| } | } | ||||
| ~ItemTupleEliminater() = default; | |||||
| ~ItemTupleOrListEliminater() = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { | ||||
| AnfNodePtr new_node; | AnfNodePtr new_node; | ||||
| @@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller { | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_ | |||||
| @@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.specialize_transform_, | irpass.specialize_transform_, | ||||
| // Miscellaneous | // Miscellaneous | ||||
| irpass.item_tuple_eliminate_, | |||||
| irpass.item_tuple_or_list_eliminate_, | |||||
| irpass.env_get_item_eliminate_, | irpass.env_get_item_eliminate_, | ||||
| irpass.cast_eliminate_, | irpass.cast_eliminate_, | ||||
| irpass.reshape_eliminate_, | irpass.reshape_eliminate_, | ||||
| @@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp | |||||
| } | } | ||||
| OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining | |||||
| irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_}); | |||||
| opt::OptPassConfig d_1 = | |||||
| opt::OptPassConfig({// Safe inlining | |||||
| irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_}); | |||||
| OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); | OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); | ||||
| @@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib | |||||
| OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig b_1 = opt::OptPassConfig( | opt::OptPassConfig b_1 = opt::OptPassConfig( | ||||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, | |||||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, | |||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | ||||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | ||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | ||||
| @@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None): | |||||
| equal to the case when `fn` is not None. | equal to the case when `fn` is not None. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore.ops import functional as F | |||||
| >>> from mindspore.ops import functional as F | |||||
| ... | ... | ||||
| >>> def tensor_add(x, y): | >>> def tensor_add(x, y): | ||||
| ... z = x + y | ... z = x + y | ||||
| @@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) { | |||||
| FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); | FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); | ||||
| after_2->set_output(value_node_2); | after_2->set_output(value_node_2); | ||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||||
| ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); | ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); | ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); | ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); | ||||
| @@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) { | |||||
| FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); | FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); | ||||
| FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); | FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); | ||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||||
| ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | ||||
| @@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) { | |||||
| FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); | FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); | ||||
| FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); | FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); | ||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_}); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_}); | |||||
| ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); | ||||
| @@ -13,9 +13,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test enumerate""" | """ test enumerate""" | ||||
| import numpy as np | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||
| @@ -168,3 +173,60 @@ def test_list_index_3D_parameter(): | |||||
| net = Net() | net = Net() | ||||
| net(Tensor(0)) | net(Tensor(0)) | ||||
| def test_const_list_index_3D_bprop(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.value = [[1], [2, 2], [[3, 3], [3, 3]]] | |||||
| self.relu = P.ReLU() | |||||
| def construct(self, input_x): | |||||
| list_x = self.value | |||||
| list_x[2][0][1] = input_x | |||||
| return self.relu(list_x[2][0][1]) | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) | |||||
| def construct(self, x, sens): | |||||
| return self.grad_all_with_sens(self.net)(x, sens) | |||||
| net = Net() | |||||
| grad_net = GradNet(net) | |||||
| x = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||||
| sens = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||||
| grad_net(x, sens) | |||||
| def test_parameter_list_index_3D_bprop(): | |||||
| class Net(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.value = [[1], [2, 2], [[3, 3], [3, 3]]] | |||||
| self.relu = P.ReLU() | |||||
| def construct(self, x, value): | |||||
| list_value = [[x], [x, x], [[x, x], [x, x]]] | |||||
| list_value[2][0][1] = value | |||||
| return self.relu(list_value[2][0][1]) | |||||
| class GradNet(nn.Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.net = net | |||||
| self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True) | |||||
| def construct(self, x, value, sens): | |||||
| return self.grad_all_with_sens(self.net)(x, value, sens) | |||||
| net = Net() | |||||
| grad_net = GradNet(net) | |||||
| x = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||||
| value = Tensor(np.ones((2, 3), np.int64)) | |||||
| sens = Tensor(np.arange(2 * 3).reshape(2, 3)) | |||||
| grad_net(x, value, sens) | |||||