Browse Source

optimize list getitem in bprop

tags/v1.1.0
buxue 5 years ago
parent
commit
2530943a7f
7 changed files with 92 additions and 20 deletions
  1. +4
    -3
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/optimizer/irpass.h
  3. +16
    -8
      mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h
  4. +5
    -4
      mindspore/ccsrc/pipeline/jit/pass.cc
  5. +1
    -1
      mindspore/common/api.py
  6. +3
    -3
      tests/ut/cpp/optimizer/lib_test.cc
  7. +62
    -0
      tests/ut/python/pipeline/parse/test_sequence_assign.py

+ 4
- 3
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -25,7 +25,7 @@
#include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/irpass/incorporate_call.h" #include "frontend/optimizer/irpass/incorporate_call.h"
#include "frontend/optimizer/irpass/incorporate_getitem.h" #include "frontend/optimizer/irpass/incorporate_getitem.h"
#include "frontend/optimizer/irpass/item_tuple_eliminate.h"
#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
#include "frontend/optimizer/irpass/mark_interface_fusion.h" #include "frontend/optimizer/irpass/mark_interface_fusion.h"
#include "frontend/optimizer/irpass/merge_addn.h" #include "frontend/optimizer/irpass/merge_addn.h"
#include "frontend/optimizer/irpass/accumulaten_eliminate.h" #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
@@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN); MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);


// ops eliminate // ops eliminate
item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem});
item_tuple_or_list_eliminate_ = MakeSubstitution(
std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile); tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast); cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape); reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);


+ 1
- 1
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -39,7 +39,7 @@ class OptimizeIRPassLib {
SubstitutionPtr adjust_all_reduce_mul_add_; SubstitutionPtr adjust_all_reduce_mul_add_;


// ops eliminate // ops eliminate
SubstitutionPtr item_tuple_eliminate_;
SubstitutionPtr item_tuple_or_list_eliminate_;
SubstitutionPtr tile_eliminate_; SubstitutionPtr tile_eliminate_;
SubstitutionPtr cast_eliminate_; SubstitutionPtr cast_eliminate_;
SubstitutionPtr reshape_eliminate_; SubstitutionPtr reshape_eliminate_;


mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h → mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h View File

@@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_


#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
@@ -33,6 +33,7 @@ namespace irpass {
// (a, b, c, ...)[0] => a // (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminater : public AnfVisitor { class GetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor {


void Visit(const ValueNodePtr &vnode) override { void Visit(const ValueNodePtr &vnode) override {
if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) { if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
int64_t idx = GetValue<int64_t>(vnode->value());
auto idx = GetValue<int64_t>(vnode->value());
if (idx < 0) { if (idx < 0) {
idx = idx + tuple_->size() - 1; idx = idx + tuple_->size() - 1;
} }
@@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor {
// (a, b, c, ...)[0] => a // (a, b, c, ...)[0] => a
// (a, b, c, ...)[1] => b // (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C} // {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminater : public AnfVisitor { class GetitemConstEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor {
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...) // setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) // setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
class SetitemEliminater : public AnfVisitor { class SetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);


auto fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr && z_ != nullptr) { if (fg != nullptr && z_ != nullptr) {
@@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor {
}; };


// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminater : public AnfVisitor { class GetSetitemEliminater : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);


auto fg = node->func_graph(); auto fg = node->func_graph();
if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { if (fg != nullptr && key1_ >= 0 && key2_ >= 0) {
@@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor {
} }


void Visit(const CNodePtr &cnode) override { void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) {
if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) {
if (cnode->size() < 4) { if (cnode->size() < 4) {
return; return;
} }
@@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor {


// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
class GetitemDependReorder : public AnfVisitor { class GetitemDependReorder : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
}; };


class ItemTupleEliminater : public OptimizerCaller {
class ItemTupleOrListEliminater : public OptimizerCaller {
public: public:
ItemTupleEliminater()
ItemTupleOrListEliminater()
: get_item_eliminater_(std::make_shared<GetitemEliminater>()), : get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()), get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(std::make_shared<SetitemEliminater>()), set_item_eliminater_(std::make_shared<SetitemEliminater>()),
@@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller {
eliminaters_.emplace_back(get_set_item_eliminater_); eliminaters_.emplace_back(get_set_item_eliminater_);
eliminaters_.emplace_back(get_item_depend_reorder_); eliminaters_.emplace_back(get_item_depend_reorder_);
} }
~ItemTupleEliminater() = default;
~ItemTupleOrListEliminater() = default;


AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
@@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller {
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_

+ 5
- 4
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.specialize_transform_, irpass.specialize_transform_,


// Miscellaneous // Miscellaneous
irpass.item_tuple_eliminate_,
irpass.item_tuple_or_list_eliminate_,
irpass.env_get_item_eliminate_, irpass.env_get_item_eliminate_,
irpass.cast_eliminate_, irpass.cast_eliminate_,
irpass.reshape_eliminate_, irpass.reshape_eliminate_,
@@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
} }


OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_});
opt::OptPassConfig d_1 =
opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_});


OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});


@@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib


OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig( opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});


+ 1
- 1
mindspore/common/api.py View File

@@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
equal to the case when `fn` is not None. equal to the case when `fn` is not None.


Examples: Examples:
>>> from mindspore.ops import functional as F
>>> from mindspore.ops import functional as F
... ...
>>> def tensor_add(x, y): >>> def tensor_add(x, y):
... z = x + y ... z = x + y


+ 3
- 3
tests/ut/cpp/optimizer/lib_test.cc View File

@@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_2 = std::make_shared<FuncGraph>(); FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2); after_2->set_output(value_node_2);


auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");


auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});


ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");


auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});


ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));


+ 62
- 0
tests/ut/python/pipeline/parse/test_sequence_assign.py View File

@@ -13,9 +13,14 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" test enumerate""" """ test enumerate"""

import numpy as np

import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C


context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)


@@ -168,3 +173,60 @@ def test_list_index_3D_parameter():


net = Net() net = Net()
net(Tensor(0)) net(Tensor(0))


def test_const_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()

def construct(self, input_x):
list_x = self.value
list_x[2][0][1] = input_x
return self.relu(list_x[2][0][1])

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

def construct(self, x, sens):
return self.grad_all_with_sens(self.net)(x, sens)

net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, sens)


def test_parameter_list_index_3D_bprop():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
self.relu = P.ReLU()

def construct(self, x, value):
list_value = [[x], [x, x], [[x, x], [x, x]]]
list_value[2][0][1] = value
return self.relu(list_value[2][0][1])

class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

def construct(self, x, value, sens):
return self.grad_all_with_sens(self.net)(x, value, sens)

net = Net()
grad_net = GradNet(net)
x = Tensor(np.arange(2 * 3).reshape(2, 3))
value = Tensor(np.ones((2, 3), np.int64))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, value, sens)

Loading…
Cancel
Save