@@ -241,6 +241,7 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
+const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
@@ -245,6 +245,7 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
+extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
@@ -53,6 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                      {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                       prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
+  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
 
   // ops eliminate
   item_tuple_eliminate_ =
@@ -35,6 +35,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr arithmetic_simplify_;
   SubstitutionPtr special_op_eliminate_;
   SubstitutionPtr zero_like_fill_zero_;
+  SubstitutionPtr adjust_all_reduce_mul_add_;
 
   // ops eliminate
   SubstitutionPtr item_tuple_eliminate_;
@@ -228,6 +228,116 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
+// grad = AllReduce(grad) / worker_number
+// grad = grad + weight * decay
+// ->
+// grad = grad + weight * decay
+// grad = AllReduce(grad) / worker_number
+//
+// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
+// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
+class AdjustAllReduceMulAdd : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    // {prim::kPrimAddN, Zs}
+    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
+      return nullptr;
+    }
+
+    auto addn = node->cast<CNodePtr>();
+    if (addn->size() != 2) {
+      return nullptr;
+    }
+
+    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
+    if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
+      return nullptr;
+    }
+
+    auto addn_maketuple = addn->input(1);
+    auto fg = all_reduce_fg_;
+    // The addn inputs cross graphs; rebuild them in the same graph as the AllReduce node.
+    if (z_->isa<CNode>() && fg != z_->func_graph()) {
+      auto cnode_z = z_->cast<CNodePtr>();
+      z_ = NewCNode(cnode_z->inputs(), fg);
+    }
+
+    auto addn_op_node = addn->input(0);
+    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
+
+    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
+    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
+    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
+    AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
+    ProcessDependEdge(fg, addn_maketuple, all_reduce);
+    return mul;
+  }
+
+  void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
+    // Handles the case where dynamic loss scale is enabled.
+    auto &users_map = fg->manager()->node_users();
+    auto it = users_map.find(mul_cnode_);
+    if (it != users_map.end()) {
+      auto users = it->second;
+      for (auto &user_pair : users) {
+        auto node = user_pair.first;
+        if (node != addn_maketuple) {
+          if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+            fg->manager()->SetEdge(node, user_pair.second, new_node);
+          }
+        }
+      }
+    }
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (level_ == 0) {
+      level_ = 1;
+      is_reduce_match_ = false;
+      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
+      AnfVisitor::Match(prim::kPrimMul)(node);
+      level_ = 0;
+      if (is_reduce_match_) {
+        mul_ = node->cast<CNodePtr>()->input(0);
+        mul_cnode_ = node->cast<CNodePtr>();
+        y_ = tmp_;
+      } else {
+        z_ = node;
+      }
+    }
+
+    if (level_ == 1) {
+      // {prim::kPrimAllReduce, X}
+      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
+        auto cnode = node->cast<CNodePtr>();
+        if (cnode->size() > 1) {
+          all_reduce_ = cnode->input(0);
+          x_ = cnode->input(1);
+          is_reduce_match_ = true;
+          all_reduce_fg_ = cnode->func_graph();
+        }
+      } else {
+        tmp_ = node;
+      }
+    }
+  }
+
+  void Reset() {
+    level_ = 0;
+    is_reduce_match_ = false;
+    x_ = nullptr;
+    y_ = nullptr;
+    z_ = nullptr;
+    tmp_ = nullptr;
+    all_reduce_fg_ = nullptr;
+  }
+
+ private:
+  int level_{0};
+  bool is_reduce_match_{false};
+  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
+  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
+  FuncGraphPtr all_reduce_fg_{nullptr};
+};
+
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
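Why is this reordering sound? In the data-parallel setting this pass targets, Y is the scalar 1/worker_number, AllReduce sums gradients across workers, and the weight-decay term Z is computed identically on every worker, so summing n copies of Z and then multiplying by 1/n yields Z back. A minimal numpy sketch of that identity (not part of the patch; AllReduce is modeled as a plain sum over simulated workers):

```python
import numpy as np

n = 8                                # simulated worker count
rng = np.random.default_rng(0)
x = rng.standard_normal((n, 4))      # per-worker gradients X
z = rng.standard_normal(4)           # weight-decay term Z, identical on every worker
y = 1.0 / n                          # averaging factor Y

# before: AddN((Z, Mul(AllReduce(X), Y)))
before = z + x.sum(axis=0) * y
# after:  Mul(AllReduce(AddN((Z, X))), Y) -- the AllReduce now also sums the n copies of Z
after = (n * z + x.sum(axis=0)) * y

assert np.allclose(before, after)
```

The payoff is that the AddN is folded into the tensor that is communicated, so the elementwise add happens once before the collective instead of once after it on every worker.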
@@ -28,6 +28,7 @@
 #include <utility>
 
 #include "pipeline/parse/parse_base.h"
 #include "utils/log_adapter.h"
+#include "utils/ordered_map.h"
 
 namespace mindspore {
 namespace parse {
@@ -99,7 +100,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
   // state nodes that need to be inserted before the function's return nodes
-  std::unordered_map<AnfNodePtr, std::string> state_assign_;
+  OrderedMap<AnfNodePtr, std::string> state_assign_;
 
   // global variables declared in the function
   std::set<std::string> global_vars_;
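Presumably the switch from std::unordered_map to OrderedMap is about determinism: iterating state_assign_ in insertion order means the state-setting nodes are attached before the return node in the same order on every compilation. A toy Python illustration of the property being bought (node names hypothetical):

```python
# A hash map may yield any iteration order; an insertion-ordered map
# replays keys in the order they were added, like Python's dict.
state_assign = {}
for node in ("assign_w0", "assign_w1", "assign_w2"):  # hypothetical state nodes
    state_assign[node] = "state"
assert list(state_assign) == ["assign_w0", "assign_w1", "assign_w2"]
```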
@@ -82,6 +82,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     // Arithmetic simplifications
     irpass.arithmetic_simplify_,
     irpass.addn_zero_filter_,
+    irpass.adjust_all_reduce_mul_add_,
 
     // Miscellaneous
     irpass.item_tuple_eliminate_,
@@ -1213,7 +1213,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
         >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
         >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
@@ -1765,7 +1765,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
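The corrected formula reads directly as code. A minimal numpy sketch of the documented math (normalizing over the last axis is an assumption here; the operator's begin_norm_axis argument controls this in practice):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-7):
    # y = (x - mean) / sqrt(variance + eps) * gamma + beta
    mean = x.mean(axis=-1, keepdims=True)
    variance = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps) * gamma + beta

x = np.random.randn(2, 4).astype(np.float32)
y = layer_norm(x, gamma=np.ones(4, np.float32), beta=np.zeros(4, np.float32))
```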
@@ -284,7 +284,8 @@ def prim_attr_register(fn):
 def constexpr(fn=None, get_instance=True, name=None):
     """
-    Makes a PrimitiveWithInfer operator, which infer the value while compiling.
+    Makes a PrimitiveWithInfer operator that infers the value at compile time. We can define a function
+    that computes over constant arguments, for use in construct().
 
     Args:
         fn (function): A function used as the infer_value of the output operator.
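For context, a plausible use of the decorator described in that docstring (the import path and helper name are assumptions, not part of the patch):

```python
from mindspore.ops import constexpr

@constexpr
def compute_padding(kernel_size):
    # Evaluated while the graph is compiled: kernel_size must already be a constant.
    return (kernel_size - 1) // 2
```

The decorated function then behaves like an operator whose output value is folded at compile time when it is called from construct().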
@@ -556,5 +556,24 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
+
+TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
+  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
+  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
+  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
+  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
+  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
+  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
+  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
+  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
+  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
+  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
+  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
+  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore
@@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul');
-    Sqrt = Primitive('Sqrt');
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -1073,3 +1073,44 @@
         return Mul(Sqrt(x), Mul(tensor1, tensor2))
 
     return fns[tag]
+
+
+def test_adjust_allreduce_mul_add(tag):
+    fns = FnDict()
+    Mul = Primitive('Mul')
+    AddN = Primitive('AddN')
+    AllReduce = Primitive('AllReduce')
+
+    @fns
+    def beforell(x, y, z):
+        return AddN((z, Mul(y, AllReduce(x))))
+
+    @fns
+    def beforelr(x, y, z):
+        return AddN((z, Mul(AllReduce(x), y)))
+
+    @fns
+    def beforerl(x, y, z):
+        return AddN((Mul(y, AllReduce(x)), z))
+
+    @fns
+    def beforerr(x, y, z):
+        return AddN((Mul(AllReduce(x), y), z))
+
+    @fns
+    def after1(x, y, z):
+        return Mul(AllReduce(AddN((z, x))), y)
+
+    @fns
+    def before2r(x, y, z):
+        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
+
+    @fns
+    def before2l(x, y, z):
+        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
+
+    @fns
+    def after2(x, y, z):
+        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
+
+    return fns[tag]
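The before2* cases exercise the fallback branch in Visit: the operand Mul(z, z) contains no AllReduce, so it is captured as Z and moved inside the new AddN unchanged. Tracing the algebra the same way as before (a sketch, with AllReduce read as a cross-worker sum and z identical on every worker):

```python
import numpy as np

n, d = 8, 4
x = np.random.randn(n, d)   # per-worker gradients
z = np.random.randn(d)      # identical on every worker
y = 1.0 / n

before2 = x.sum(axis=0) * y + z * z         # AddN((Mul(AllReduce(x), y), Mul(z, z)))
after2 = (n * (z * z) + x.sum(axis=0)) * y  # Mul(AllReduce(AddN((Mul(z, z), x))), y)
assert np.allclose(before2, after2)
```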
@@ -20,9 +20,14 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
-from mindspore.train import Model
+from mindspore.train import Model, ParallelMode
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+import mindspore.context as context
+from mindspore.model_zoo.resnet import resnet50
 from ....dataset_mock import MindData
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.communication.management import init
 
 def setup_module(module):
     context.set_context(mode=context.GRAPH_MODE)
@@ -138,3 +143,22 @@ def test_compile_model_train_O2():
     with pytest.raises(ValueError):
         # not an actual run; the metrics step will fail, just check that compilation is ok
         model.eval(dataset)
+
+
+def test_compile_model_train_O2_parallel():
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+    net = NetNoLoss(16, 16)
+    loss = nn.MSELoss()
+    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
+
+    context.set_auto_parallel_context(
+        global_rank=0, device_num=8,
+        mirror_mean=True, parameter_broadcast=True,
+        parallel_mode=ParallelMode.DATA_PARALLEL)
+    init()
+
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
+    model.train(2, dataset, dataset_sink_mode=False)