| @@ -107,6 +107,8 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode | |||||
| auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { | auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { | ||||
| auto fg = all_reduce_pat.GetFuncGraph(); | auto fg = all_reduce_pat.GetFuncGraph(); | ||||
| auto z_ = z.GetNode(node); | auto z_ = z.GetNode(node); | ||||
| auto x_ = x.GetNode(node); | |||||
| // If addn inputs cross the graph, make the inputs same as allreduce node. | // If addn inputs cross the graph, make the inputs same as allreduce node. | ||||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | if (z_->isa<CNode>() && fg != z_->func_graph()) { | ||||
| auto cnode_z = z_->cast<CNodePtr>(); | auto cnode_z = z_->cast<CNodePtr>(); | ||||
| @@ -121,7 +123,43 @@ AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNode | |||||
| auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0); | auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0); | ||||
| auto addn_maketuple = admktup_pat.GetOriginalNode(); | auto addn_maketuple = admktup_pat.GetOriginalNode(); | ||||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg); | |||||
| ShapeVector x_shape, z_shape; | |||||
| if (!x_->isa<ValueNode>()) { | |||||
| if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| x_shape = x_abstract->shape()->shape(); | |||||
| } else { | |||||
| ValuePtr x_value = x_->cast<ValueNodePtr>()->value(); | |||||
| if (!x_value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>()); | |||||
| x_shape = x_tensor->shape(); | |||||
| } | |||||
| if (!z_->isa<ValueNode>()) { | |||||
| if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| z_shape = z_abstract->shape()->shape(); | |||||
| } else { | |||||
| ValuePtr z_value = z_->cast<ValueNodePtr>()->value(); | |||||
| if (!z_value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>()); | |||||
| z_shape = z_tensor->shape(); | |||||
| } | |||||
| if (x_shape != z_shape) { | |||||
| // AddN requires x_ and z_ have the same shape. | |||||
| // If broadcasting TensorAdd is supported then can use this | |||||
| // AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimTensorAdd), z_, x_}, fg); | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); | |||||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | ||||
| AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); | AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); | ||||
| AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); | AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); | ||||
| @@ -353,11 +353,7 @@ TEST_F(TestOptLib, test_tuple_getitem) { | |||||
| auto value_node_2 = NewValueNode(2); | auto value_node_2 = NewValueNode(2); | ||||
| std::vector<int> vec{1, 2}; | std::vector<int> vec{1, 2}; | ||||
| auto value_node_tuple = NewValueNode(MakeValue(vec)); | auto value_node_tuple = NewValueNode(MakeValue(vec)); | ||||
| std::vector<AnfNodePtr> node_list{ | |||||
| NewValueNode(prim::kPrimTupleGetItem), | |||||
| value_node_tuple, | |||||
| value_node_1 | |||||
| }; | |||||
| std::vector<AnfNodePtr> node_list{NewValueNode(prim::kPrimTupleGetItem), value_node_tuple, value_node_1}; | |||||
| auto get_item = make_get_const->NewCNode(node_list); | auto get_item = make_get_const->NewCNode(node_list); | ||||
| make_get_const->set_output(get_item); | make_get_const->set_output(get_item); | ||||
| @@ -598,12 +594,10 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||||
| FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); | FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); | ||||
| FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); | FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); | ||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_}); | auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_}); | ||||
| ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforell, after1, patterns, true)); | |||||
| ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); | ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); | ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); | ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | |||||
| } | } | ||||
| TEST_F(TestOptLib, test_row_tensor) { | TEST_F(TestOptLib, test_row_tensor) { | ||||
| @@ -1095,36 +1095,40 @@ def test_adjust_allreduce_mul_add(tag): | |||||
| AddN = Primitive('AddN') | AddN = Primitive('AddN') | ||||
| AllReduce = Primitive('AllReduce') | AllReduce = Primitive('AllReduce') | ||||
| x = Tensor(np.ones(shape=(64, 32)).astype(np.float32)) | |||||
| y = Tensor(np.ones(shape=(64, 32)).astype(np.float32)) | |||||
| z = Tensor(np.ones(shape=(64, 32)).astype(np.float32)) | |||||
| @fns | @fns | ||||
| def beforell(x, y, z): | |||||
| def beforell(): | |||||
| return AddN((z, Mul(y, AllReduce(x)))) | return AddN((z, Mul(y, AllReduce(x)))) | ||||
| @fns | @fns | ||||
| def beforelr(x, y, z): | |||||
| def beforelr(): | |||||
| return AddN((z, Mul(AllReduce(x), y))) | return AddN((z, Mul(AllReduce(x), y))) | ||||
| @fns | @fns | ||||
| def beforerl(x, y, z): | |||||
| def beforerl(): | |||||
| return AddN((Mul(y, AllReduce(x)), z)) | return AddN((Mul(y, AllReduce(x)), z)) | ||||
| @fns | @fns | ||||
| def beforerr(x, y, z): | |||||
| def beforerr(): | |||||
| return AddN((Mul(AllReduce(x), y), z)) | return AddN((Mul(AllReduce(x), y), z)) | ||||
| @fns | @fns | ||||
| def after1(x, y, z): | |||||
| def after1(): | |||||
| return Mul(AllReduce(AddN((z, x))), y) | return Mul(AllReduce(AddN((z, x))), y) | ||||
| @fns | @fns | ||||
| def before2r(x, y, z): | |||||
| def before2r(): | |||||
| return AddN((Mul(AllReduce(x), y), Mul(z, z))) | return AddN((Mul(AllReduce(x), y), Mul(z, z))) | ||||
| @fns | @fns | ||||
| def before2l(x, y, z): | |||||
| def before2l(): | |||||
| return AddN((Mul(z, z), Mul(AllReduce(x), y))) | return AddN((Mul(z, z), Mul(AllReduce(x), y))) | ||||
| @fns | @fns | ||||
| def after2(x, y, z): | |||||
| def after2(): | |||||
| return Mul(AllReduce(AddN((Mul(z, z), x))), y) | return Mul(AllReduce(AddN((Mul(z, z), x))), y) | ||||
| return fns[tag] | return fns[tag] | ||||