This reverts commit ea6958c50a.
tags/v0.3.0-alpha
| @@ -230,7 +230,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| // Debug ops | |||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| @@ -234,7 +234,6 @@ extern const PrimitivePtr kPrimInDict; | |||
| extern const PrimitivePtr kPrimNotInDict; | |||
| // Comm ops | |||
| extern const PrimitivePtr kPrimAllReduce; | |||
| extern const PrimitivePtr kPrimMirror; | |||
| extern const PrimitivePtr kPrimVirtualDiv; | |||
| extern const PrimitivePtr kPrimVirtualDataset; | |||
| @@ -48,7 +48,7 @@ namespace irpass { | |||
| OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||
| prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||
| special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | |||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| @@ -228,82 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor { | |||
| CNodePtr cnode_; | |||
| }; | |||
| // grad = AllReduce(grad) / worker_number | |||
| // grad = grad + weight * decy | |||
| // -> | |||
| // grad = grad + weight * decy | |||
| // grad = AllReduce(grad) / worker_number | |||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | |||
| class AdjustAllReduceMulAdd : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| // {prim::kPrimAddN, Zs} | |||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||
| return nullptr; | |||
| } | |||
| auto addn = node->cast<CNodePtr>(); | |||
| if (addn->size() != 2) { | |||
| return nullptr; | |||
| } | |||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | |||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto fg = node->func_graph(); | |||
| AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); | |||
| AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); | |||
| AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); | |||
| return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (level_ == 0) { | |||
| level_ = 1; | |||
| is_reduce_match_ = false; | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} | |||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||
| level_ = 0; | |||
| if (is_reduce_match_) { | |||
| y_ = tmp_; | |||
| } else { | |||
| z_ = node; | |||
| } | |||
| } | |||
| if (level_ == 1) { | |||
| // {prim::kPrimAllReduce, X} | |||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->size() > 1) { | |||
| x_ = cnode->input(1); | |||
| is_reduce_match_ = true; | |||
| } | |||
| } else { | |||
| tmp_ = node; | |||
| } | |||
| } | |||
| } | |||
| void Reset() { | |||
| level_ = 0; | |||
| is_reduce_match_ = false; | |||
| x_ = nullptr; | |||
| y_ = nullptr; | |||
| z_ = nullptr; | |||
| tmp_ = nullptr; | |||
| } | |||
| private: | |||
| int level_{0}; | |||
| bool is_reduce_match_{false}; | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | |||
| }; | |||
| class ArithmeticSimplify { | |||
| public: | |||
| ArithmeticSimplify() | |||
| @@ -319,7 +243,6 @@ class ArithmeticSimplify { | |||
| eliminaters_.emplace_back(identity_); | |||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | |||
| eliminaters_.emplace_back(constant_duplicate_mul_); | |||
| eliminaters_.emplace_back(adjust_allreduce_mul_add_); | |||
| } | |||
| ~ArithmeticSimplify() = default; | |||
| @@ -341,7 +264,6 @@ class ArithmeticSimplify { | |||
| PrimEliminater identity_; | |||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||
| ConstantDuplicateMul constant_duplicate_mul_; | |||
| AdjustAllReduceMulAdd adjust_allreduce_mul_add_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| @@ -1229,7 +1229,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. | |||
| Examples: | |||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) | |||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float) | |||
| >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) | |||
| >>> num_segments = 4 | |||
| >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | |||
| @@ -1630,7 +1630,7 @@ class LayerNorm(Primitive): | |||
| `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. | |||
| .. math:: | |||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | |||
| @@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) { | |||
| ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | |||
| } | |||
| TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||
| FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); | |||
| FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); | |||
| FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); | |||
| FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); | |||
| FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); | |||
| FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); | |||
| FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); | |||
| FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); | |||
| ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag): | |||
| def test_constant_duplicate_mul(tag): | |||
| fns = FnDict() | |||
| Mul = Primitive('Mul') | |||
| Sqrt = Primitive('Sqrt') | |||
| Mul = Primitive('Mul'); | |||
| Sqrt = Primitive('Sqrt'); | |||
| x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | |||
| tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | |||
| @@ -936,44 +936,3 @@ def test_constant_duplicate_mul(tag): | |||
| return Mul(Sqrt(x), Mul(tensor1, tensor2)) | |||
| return fns[tag] | |||
| def test_adjust_allreduce_mul_add(tag): | |||
| fns = FnDict() | |||
| Mul = Primitive('Mul') | |||
| AddN = Primitive('AddN') | |||
| AllReduce = Primitive('AllReduce') | |||
| @fns | |||
| def beforell(x, y, z): | |||
| return AddN((z, Mul(y, AllReduce(x)))) | |||
| @fns | |||
| def beforelr(x, y, z): | |||
| return AddN((z, Mul(AllReduce(x), y))) | |||
| @fns | |||
| def beforerl(x, y, z): | |||
| return AddN((Mul(y, AllReduce(x)), z)) | |||
| @fns | |||
| def beforerr(x, y, z): | |||
| return AddN((Mul(AllReduce(x), y), z)) | |||
| @fns | |||
| def after1(x, y, z): | |||
| return Mul(AllReduce(AddN((z, x))), y) | |||
| @fns | |||
| def before2r(x, y, z): | |||
| return AddN((Mul(AllReduce(x), y), Mul(z, z))) | |||
| @fns | |||
| def before2l(x, y, z): | |||
| return AddN((Mul(z, z), Mul(AllReduce(x), y))) | |||
| @fns | |||
| def after2(x, y, z): | |||
| return Mul(AllReduce(AddN((Mul(z, z), x))), y) | |||
| return fns[tag] | |||