This reverts commit ea6958c50a.
tags/v0.3.0-alpha
| @@ -230,7 +230,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | ||||
| const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | ||||
| const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | ||||
| const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||||
| // Debug ops | // Debug ops | ||||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | ||||
| @@ -234,7 +234,6 @@ extern const PrimitivePtr kPrimInDict; | |||||
| extern const PrimitivePtr kPrimNotInDict; | extern const PrimitivePtr kPrimNotInDict; | ||||
| // Comm ops | // Comm ops | ||||
| extern const PrimitivePtr kPrimAllReduce; | |||||
| extern const PrimitivePtr kPrimMirror; | extern const PrimitivePtr kPrimMirror; | ||||
| extern const PrimitivePtr kPrimVirtualDiv; | extern const PrimitivePtr kPrimVirtualDiv; | ||||
| extern const PrimitivePtr kPrimVirtualDataset; | extern const PrimitivePtr kPrimVirtualDataset; | ||||
| @@ -48,7 +48,7 @@ namespace irpass { | |||||
| OptimizeIRPassLib::OptimizeIRPassLib() { | OptimizeIRPassLib::OptimizeIRPassLib() { | ||||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | ||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | ||||
| prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||||
| special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | ||||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | ||||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | ||||
| @@ -228,82 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor { | |||||
| CNodePtr cnode_; | CNodePtr cnode_; | ||||
| }; | }; | ||||
| // grad = AllReduce(grad) / worker_number | |||||
| // grad = grad + weight * decy | |||||
| // -> | |||||
| // grad = grad + weight * decy | |||||
| // grad = AllReduce(grad) / worker_number | |||||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | |||||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | |||||
| class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| // {prim::kPrimAddN, Zs} | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto addn = node->cast<CNodePtr>(); | |||||
| if (addn->size() != 2) { | |||||
| return nullptr; | |||||
| } | |||||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | |||||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto fg = node->func_graph(); | |||||
| AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); | |||||
| AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); | |||||
| AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); | |||||
| return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { | |||||
| if (level_ == 0) { | |||||
| level_ = 1; | |||||
| is_reduce_match_ = false; | |||||
| // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||||
| level_ = 0; | |||||
| if (is_reduce_match_) { | |||||
| y_ = tmp_; | |||||
| } else { | |||||
| z_ = node; | |||||
| } | |||||
| } | |||||
| if (level_ == 1) { | |||||
| // {prim::kPrimAllReduce, X} | |||||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode->size() > 1) { | |||||
| x_ = cnode->input(1); | |||||
| is_reduce_match_ = true; | |||||
| } | |||||
| } else { | |||||
| tmp_ = node; | |||||
| } | |||||
| } | |||||
| } | |||||
| void Reset() { | |||||
| level_ = 0; | |||||
| is_reduce_match_ = false; | |||||
| x_ = nullptr; | |||||
| y_ = nullptr; | |||||
| z_ = nullptr; | |||||
| tmp_ = nullptr; | |||||
| } | |||||
| private: | |||||
| int level_{0}; | |||||
| bool is_reduce_match_{false}; | |||||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | |||||
| }; | |||||
| class ArithmeticSimplify { | class ArithmeticSimplify { | ||||
| public: | public: | ||||
| ArithmeticSimplify() | ArithmeticSimplify() | ||||
| @@ -319,7 +243,6 @@ class ArithmeticSimplify { | |||||
| eliminaters_.emplace_back(identity_); | eliminaters_.emplace_back(identity_); | ||||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | eliminaters_.emplace_back(opt_update_zero_tensor_); | ||||
| eliminaters_.emplace_back(constant_duplicate_mul_); | eliminaters_.emplace_back(constant_duplicate_mul_); | ||||
| eliminaters_.emplace_back(adjust_allreduce_mul_add_); | |||||
| } | } | ||||
| ~ArithmeticSimplify() = default; | ~ArithmeticSimplify() = default; | ||||
| @@ -341,7 +264,6 @@ class ArithmeticSimplify { | |||||
| PrimEliminater identity_; | PrimEliminater identity_; | ||||
| OptUpdateZeroTensor opt_update_zero_tensor_; | OptUpdateZeroTensor opt_update_zero_tensor_; | ||||
| ConstantDuplicateMul constant_duplicate_mul_; | ConstantDuplicateMul constant_duplicate_mul_; | ||||
| AdjustAllReduceMulAdd adjust_allreduce_mul_add_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | std::vector<TransformFuncType> eliminaters_{}; | ||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| @@ -1229,7 +1229,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. | Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. | ||||
| Examples: | Examples: | ||||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) | |||||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float) | |||||
| >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) | >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) | ||||
| >>> num_segments = 4 | >>> num_segments = 4 | ||||
| >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | ||||
| @@ -1630,7 +1630,7 @@ class LayerNorm(Primitive): | |||||
| `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. | `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. | ||||
| .. math:: | .. math:: | ||||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||||
| y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | ||||
| @@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) { | |||||
| ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | ||||
| } | } | ||||
| TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||||
| FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); | |||||
| FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); | |||||
| FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); | |||||
| FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); | |||||
| FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); | |||||
| FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); | |||||
| FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); | |||||
| FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); | |||||
| ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag): | |||||
| def test_constant_duplicate_mul(tag): | def test_constant_duplicate_mul(tag): | ||||
| fns = FnDict() | fns = FnDict() | ||||
| Mul = Primitive('Mul') | |||||
| Sqrt = Primitive('Sqrt') | |||||
| Mul = Primitive('Mul'); | |||||
| Sqrt = Primitive('Sqrt'); | |||||
| x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | ||||
| tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | ||||
| @@ -936,44 +936,3 @@ def test_constant_duplicate_mul(tag): | |||||
| return Mul(Sqrt(x), Mul(tensor1, tensor2)) | return Mul(Sqrt(x), Mul(tensor1, tensor2)) | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_adjust_allreduce_mul_add(tag): | |||||
| fns = FnDict() | |||||
| Mul = Primitive('Mul') | |||||
| AddN = Primitive('AddN') | |||||
| AllReduce = Primitive('AllReduce') | |||||
| @fns | |||||
| def beforell(x, y, z): | |||||
| return AddN((z, Mul(y, AllReduce(x)))) | |||||
| @fns | |||||
| def beforelr(x, y, z): | |||||
| return AddN((z, Mul(AllReduce(x), y))) | |||||
| @fns | |||||
| def beforerl(x, y, z): | |||||
| return AddN((Mul(y, AllReduce(x)), z)) | |||||
| @fns | |||||
| def beforerr(x, y, z): | |||||
| return AddN((Mul(AllReduce(x), y), z)) | |||||
| @fns | |||||
| def after1(x, y, z): | |||||
| return Mul(AllReduce(AddN((z, x))), y) | |||||
| @fns | |||||
| def before2r(x, y, z): | |||||
| return AddN((Mul(AllReduce(x), y), Mul(z, z))) | |||||
| @fns | |||||
| def before2l(x, y, z): | |||||
| return AddN((Mul(z, z), Mul(AllReduce(x), y))) | |||||
| @fns | |||||
| def after2(x, y, z): | |||||
| return Mul(AllReduce(AddN((Mul(z, z), x))), y) | |||||
| return fns[tag] | |||||