Browse Source

Fix AdjustAllReduceMulAdd pass

Revive tests in lib_test

Code cleaning
tags/v1.0.0
Hoai Linh Tran 5 years ago
parent
commit
46f07efc31
3 changed files with 53 additions and 17 deletions
  1. +39
    -1
      mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
  2. +2
    -8
      tests/ut/cpp/optimizer/lib_test.cc
  3. +12
    -8
      tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py

+ 39
- 1
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc View File

@@ -107,6 +107,8 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode
auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
auto fg = all_reduce_pat.GetFuncGraph(); auto fg = all_reduce_pat.GetFuncGraph();
auto z_ = z.GetNode(node); auto z_ = z.GetNode(node);
auto x_ = x.GetNode(node);

// If addn inputs cross the graph, make the inputs same as allreduce node. // If addn inputs cross the graph, make the inputs same as allreduce node.
if (z_->isa<CNode>() && fg != z_->func_graph()) { if (z_->isa<CNode>() && fg != z_->func_graph()) {
auto cnode_z = z_->cast<CNodePtr>(); auto cnode_z = z_->cast<CNodePtr>();
@@ -121,7 +123,43 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode
auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0); auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
auto addn_maketuple = admktup_pat.GetOriginalNode(); auto addn_maketuple = admktup_pat.GetOriginalNode();


AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg);
ShapeVector x_shape, z_shape;
if (!x_->isa<ValueNode>()) {
if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>();
x_shape = x_abstract->shape()->shape();
} else {
ValuePtr x_value = x_->cast<ValueNodePtr>()->value();
if (!x_value->isa<tensor::Tensor>()) {
return nullptr;
}
auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>());
x_shape = x_tensor->shape();
}
if (!z_->isa<ValueNode>()) {
if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>();
z_shape = z_abstract->shape()->shape();
} else {
ValuePtr z_value = z_->cast<ValueNodePtr>()->value();
if (!z_value->isa<tensor::Tensor>()) {
return nullptr;
}
auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>());
z_shape = z_tensor->shape();
}

if (x_shape != z_shape) {
// AddN requires x_ and z_ have the same shape.
// If broadcasting TensorAdd is supported then can use this
// AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimTensorAdd), z_, x_}, fg);
return nullptr;
}
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);


+ 2
- 8
tests/ut/cpp/optimizer/lib_test.cc View File

@@ -353,11 +353,7 @@ TEST_F(TestOptLib, test_tuple_getitem) {
auto value_node_2 = NewValueNode(2); auto value_node_2 = NewValueNode(2);
std::vector<int> vec{1, 2}; std::vector<int> vec{1, 2};
auto value_node_tuple = NewValueNode(MakeValue(vec)); auto value_node_tuple = NewValueNode(MakeValue(vec));
std::vector<AnfNodePtr> node_list{
NewValueNode(prim::kPrimTupleGetItem),
value_node_tuple,
value_node_1
};
std::vector<AnfNodePtr> node_list{NewValueNode(prim::kPrimTupleGetItem), value_node_tuple, value_node_1};
auto get_item = make_get_const->NewCNode(node_list); auto get_item = make_get_const->NewCNode(node_list);
make_get_const->set_output(get_item); make_get_const->set_output(get_item);


@@ -598,12 +594,10 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_}); auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
ASSERT_TRUE(CheckOpt(beforell, after1, patterns, true));
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
} }


TEST_F(TestOptLib, test_row_tensor) { TEST_F(TestOptLib, test_row_tensor) {


+ 12
- 8
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py View File

@@ -1095,36 +1095,40 @@ def test_adjust_allreduce_mul_add(tag):
AddN = Primitive('AddN') AddN = Primitive('AddN')
AllReduce = Primitive('AllReduce') AllReduce = Primitive('AllReduce')


x = Tensor(np.ones(shape=(64, 32)).astype(np.float32))
y = Tensor(np.ones(shape=(64, 32)).astype(np.float32))
z = Tensor(np.ones(shape=(64, 32)).astype(np.float32))

@fns @fns
def beforell(x, y, z):
def beforell():
return AddN((z, Mul(y, AllReduce(x)))) return AddN((z, Mul(y, AllReduce(x))))


@fns @fns
def beforelr(x, y, z):
def beforelr():
return AddN((z, Mul(AllReduce(x), y))) return AddN((z, Mul(AllReduce(x), y)))


@fns @fns
def beforerl(x, y, z):
def beforerl():
return AddN((Mul(y, AllReduce(x)), z)) return AddN((Mul(y, AllReduce(x)), z))


@fns @fns
def beforerr(x, y, z):
def beforerr():
return AddN((Mul(AllReduce(x), y), z)) return AddN((Mul(AllReduce(x), y), z))


@fns @fns
def after1(x, y, z):
def after1():
return Mul(AllReduce(AddN((z, x))), y) return Mul(AllReduce(AddN((z, x))), y)


@fns @fns
def before2r(x, y, z):
def before2r():
return AddN((Mul(AllReduce(x), y), Mul(z, z))) return AddN((Mul(AllReduce(x), y), Mul(z, z)))


@fns @fns
def before2l(x, y, z):
def before2l():
return AddN((Mul(z, z), Mul(AllReduce(x), y))) return AddN((Mul(z, z), Mul(AllReduce(x), y)))


@fns @fns
def after2(x, y, z):
def after2():
return Mul(AllReduce(AddN((Mul(z, z), x))), y) return Mul(AllReduce(AddN((Mul(z, z), x))), y)


return fns[tag] return fns[tag]


Loading…
Cancel
Save