| @@ -241,6 +241,7 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); | |||
| const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| // Debug ops | |||
| const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary"); | |||
| @@ -245,6 +245,7 @@ extern const PrimitivePtr kPrimInDict; | |||
| extern const PrimitivePtr kPrimNotInDict; | |||
| // Comm ops | |||
| extern const PrimitivePtr kPrimAllReduce; | |||
| extern const PrimitivePtr kPrimMirror; | |||
| extern const PrimitivePtr kPrimVirtualDiv; | |||
| extern const PrimitivePtr kPrimVirtualDataset; | |||
| @@ -53,6 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | |||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); | |||
| adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); | |||
| // ops eliminate | |||
| item_tuple_eliminate_ = | |||
| @@ -35,6 +35,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr arithmetic_simplify_; | |||
| SubstitutionPtr special_op_eliminate_; | |||
| SubstitutionPtr zero_like_fill_zero_; | |||
| SubstitutionPtr adjust_all_reduce_mul_add_; | |||
| // ops eliminate | |||
| SubstitutionPtr item_tuple_eliminate_; | |||
| @@ -228,6 +228,116 @@ class ConstantDuplicateMul : public AnfVisitor { | |||
| CNodePtr cnode_; | |||
| }; | |||
| // grad = AllReduce(grad) / worker_number | |||
| // grad = grad + weight * decy | |||
| // -> | |||
| // grad = grad + weight * decy | |||
| // grad = AllReduce(grad) / worker_number | |||
| // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} | |||
| class AdjustAllReduceMulAdd : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| // {prim::kPrimAddN, Zs} | |||
| if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { | |||
| return nullptr; | |||
| } | |||
| auto addn = node->cast<CNodePtr>(); | |||
| if (addn->size() != 2) { | |||
| return nullptr; | |||
| } | |||
| AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); | |||
| if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto addn_maketuple = addn->input(1); | |||
| auto fg = all_reduce_fg_; | |||
| // addn inputs cross the graph, make the inputs same as allreduce node. | |||
| if (z_->isa<CNode>() && fg != z_->func_graph()) { | |||
| auto cnode_z = z_->cast<CNodePtr>(); | |||
| z_ = NewCNode(cnode_z->inputs(), fg); | |||
| } | |||
| auto addn_op_node = addn->input(0); | |||
| auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0); | |||
| AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); | |||
| AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); | |||
| AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); | |||
| AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); | |||
| ProcessDependEdge(fg, addn_maketuple, all_reduce); | |||
| return mul; | |||
| } | |||
| void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) { | |||
| // If has dynamic loss scale. | |||
| auto &users_map = fg->manager()->node_users(); | |||
| auto it = users_map.find(mul_cnode_); | |||
| if (it != users_map.end()) { | |||
| auto users = it->second; | |||
| for (auto &user_pair : users) { | |||
| auto node = user_pair.first; | |||
| if (node != addn_maketuple) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | |||
| fg->manager()->SetEdge(node, user_pair.second, new_node); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (level_ == 0) { | |||
| level_ = 1; | |||
| is_reduce_match_ = false; | |||
| // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} | |||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||
| level_ = 0; | |||
| if (is_reduce_match_) { | |||
| mul_ = node->cast<CNodePtr>()->input(0); | |||
| mul_cnode_ = node->cast<CNodePtr>(); | |||
| y_ = tmp_; | |||
| } else { | |||
| z_ = node; | |||
| } | |||
| } | |||
| if (level_ == 1) { | |||
| // {prim::kPrimAllReduce, X} | |||
| if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode->size() > 1) { | |||
| all_reduce_ = cnode->input(0); | |||
| x_ = cnode->input(1); | |||
| is_reduce_match_ = true; | |||
| all_reduce_fg_ = cnode->func_graph(); | |||
| } | |||
| } else { | |||
| tmp_ = node; | |||
| } | |||
| } | |||
| } | |||
| void Reset() { | |||
| level_ = 0; | |||
| is_reduce_match_ = false; | |||
| x_ = nullptr; | |||
| y_ = nullptr; | |||
| z_ = nullptr; | |||
| tmp_ = nullptr; | |||
| all_reduce_fg_ = nullptr; | |||
| } | |||
| private: | |||
| int level_{0}; | |||
| bool is_reduce_match_{false}; | |||
| AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; | |||
| AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; | |||
| FuncGraphPtr all_reduce_fg_{nullptr}; | |||
| }; | |||
| class ArithmeticSimplify { | |||
| public: | |||
| ArithmeticSimplify() | |||
| @@ -28,6 +28,7 @@ | |||
| #include <utility> | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ordered_map.h" | |||
| namespace mindspore { | |||
| namespace parse { | |||
| @@ -99,7 +100,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||
| std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_; | |||
| // set state nodes need to insert before function return nodes. | |||
| std::unordered_map<AnfNodePtr, std::string> state_assign_; | |||
| OrderedMap<AnfNodePtr, std::string> state_assign_; | |||
| // hold declared global variables in function | |||
| std::set<std::string> global_vars_; | |||
| @@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| // Arithmetic simplifications | |||
| irpass.arithmetic_simplify_, | |||
| irpass.addn_zero_filter_, | |||
| irpass.adjust_all_reduce_mul_add_, | |||
| // Miscellaneous | |||
| irpass.item_tuple_eliminate_, | |||
| @@ -1213,7 +1213,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`. | |||
| Examples: | |||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float) | |||
| >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32) | |||
| >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32) | |||
| >>> num_segments = 4 | |||
| >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments) | |||
| @@ -1765,7 +1765,7 @@ class LayerNorm(Primitive): | |||
| `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. | |||
| .. math:: | |||
| y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta | |||
| where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. | |||
| @@ -284,7 +284,8 @@ def prim_attr_register(fn): | |||
| def constexpr(fn=None, get_instance=True, name=None): | |||
| """ | |||
| Makes a PrimitiveWithInfer operator, which infer the value while compiling. | |||
| Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function | |||
| to compute between constant variable and used in constructß. | |||
| Args: | |||
| fn (function): A `fn` use as the infer_value of the output operator. | |||
| @@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) { | |||
| ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | |||
| } | |||
| TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { | |||
| FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell"); | |||
| FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr"); | |||
| FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl"); | |||
| FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr"); | |||
| FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1"); | |||
| FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r"); | |||
| FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l"); | |||
| FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_}); | |||
| ASSERT_TRUE(CheckOpt(beforell, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforelr, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerl, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerr, after1, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before2l, after2, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag): | |||
| def test_constant_duplicate_mul(tag): | |||
| fns = FnDict() | |||
| Mul = Primitive('Mul'); | |||
| Sqrt = Primitive('Sqrt'); | |||
| Mul = Primitive('Mul') | |||
| Sqrt = Primitive('Sqrt') | |||
| x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | |||
| tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | |||
| @@ -1073,3 +1073,44 @@ def test_constant_duplicate_mul(tag): | |||
| return Mul(Sqrt(x), Mul(tensor1, tensor2)) | |||
| return fns[tag] | |||
| def test_adjust_allreduce_mul_add(tag): | |||
| fns = FnDict() | |||
| Mul = Primitive('Mul') | |||
| AddN = Primitive('AddN') | |||
| AllReduce = Primitive('AllReduce') | |||
| @fns | |||
| def beforell(x, y, z): | |||
| return AddN((z, Mul(y, AllReduce(x)))) | |||
| @fns | |||
| def beforelr(x, y, z): | |||
| return AddN((z, Mul(AllReduce(x), y))) | |||
| @fns | |||
| def beforerl(x, y, z): | |||
| return AddN((Mul(y, AllReduce(x)), z)) | |||
| @fns | |||
| def beforerr(x, y, z): | |||
| return AddN((Mul(AllReduce(x), y), z)) | |||
| @fns | |||
| def after1(x, y, z): | |||
| return Mul(AllReduce(AddN((z, x))), y) | |||
| @fns | |||
| def before2r(x, y, z): | |||
| return AddN((Mul(AllReduce(x), y), Mul(z, z))) | |||
| @fns | |||
| def before2l(x, y, z): | |||
| return AddN((Mul(z, z), Mul(AllReduce(x), y))) | |||
| @fns | |||
| def after2(x, y, z): | |||
| return Mul(AllReduce(AddN((Mul(z, z), x))), y) | |||
| return fns[tag] | |||
| @@ -20,9 +20,14 @@ import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore import amp | |||
| from mindspore import nn | |||
| from mindspore.train import Model | |||
| from mindspore.train import Model, ParallelMode | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| import mindspore.context as context | |||
| from mindspore.model_zoo.resnet import resnet50 | |||
| from ....dataset_mock import MindData | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.communication.management import init | |||
| def setup_module(module): | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| @@ -138,3 +143,22 @@ def test_compile_model_train_O2(): | |||
| with pytest.raises(ValueError): | |||
| # not actual run, the metrics step will fail, check if compile ok. | |||
| model.eval(dataset) | |||
| def test_compile_model_train_O2_parallel(): | |||
| dataset_types = (np.float32, np.float32) | |||
| dataset_shapes = ((16, 16), (16, 16)) | |||
| dataset = MindDataSet(dataset_types, dataset_shapes) | |||
| net = NetNoLoss(16, 16) | |||
| loss = nn.MSELoss() | |||
| optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0) | |||
| context.set_auto_parallel_context( | |||
| global_rank=0, device_num=8, | |||
| mirror_mean=True, parameter_broadcast=True, | |||
| parallel_mode=ParallelMode.DATA_PARALLEL) | |||
| init() | |||
| model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2") | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||