From 9da3f9bec9374595ff60fdced7ddf6e2d7a29aa7 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Mon, 28 Dec 2020 13:58:34 +0800 Subject: [PATCH] mini step grad accumulation --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 2 + mindspore/ccsrc/frontend/optimizer/irpass.h | 1 + .../optimizer/irpass/special_op_eliminate.h | 23 ++ mindspore/ccsrc/frontend/parallel/context.cc | 5 + mindspore/ccsrc/frontend/parallel/context.h | 4 + .../parallel/ops_info/operator_info.cc | 13 +- .../frontend/parallel/ops_info/ops_utils.h | 4 + .../ccsrc/frontend/parallel/step_parallel.cc | 172 +++++++++-- .../ccsrc/frontend/parallel/step_parallel.h | 5 - mindspore/ccsrc/pipeline/jit/init.cc | 2 + mindspore/ccsrc/pipeline/jit/pass.cc | 1 + mindspore/core/base/core_ops.h | 1 + mindspore/ops/_grad/grad_comm_ops.py | 78 ++++- mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/comm_ops.py | 29 ++ mindspore/parallel/_auto_parallel_context.py | 19 +- .../python/parallel/test_grad_accumulation.py | 289 ++++++++++++++++++ 17 files changed, 613 insertions(+), 37 deletions(-) create mode 100644 tests/ut/python/parallel/test_grad_accumulation.py diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 4f01b8d415..a9fccf618c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + mirror_mini_step_elim_ = MakeSubstitution(std::make_shared(), "mirror_mini_step_eliminate", + prim::kPrimMirrorMiniStep); check_bprop_eliminate_ = MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 7bbfbc63f0..752f844057 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -51,6 +51,7 @@ class OptimizeIRPassLib { SubstitutionPtr reset_defer_inline_; SubstitutionPtr depend_value_elim_; SubstitutionPtr all_reduce_const_elim_; + SubstitutionPtr mirror_mini_step_elim_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 9ce5fa335a..a4e84e9cbf 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}; }; +// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X +class MirrorMiniStepEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + auto inputs = cnode->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + + return inputs[1]; + } + + void Visit(const AnfNodePtr &) override {} +}; + // Reset defer_inline flag class ResetDeferInline : public AnfVisitor { public: diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 7d82feb6ce..d24e10b0e9 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -64,6 +64,7 @@ void ParallelContext::Reset() { all_reduce_fusion_split_sizes_.clear(); strategy_search_mode_ = DYNAMIC_PROGRAMMING; pipeline_stage_split_num_ = 1; + grad_accumulation_step_ = 1; } void ParallelContext::set_device_num(int64_t device_num) { @@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } +void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) { + grad_accumulation_step_ = grad_accumulation_step; +} + void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; } void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index b58b9b2371..4f964bc479 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -73,6 +73,9 @@ class ParallelContext { void set_global_rank(int64_t global_rank); int64_t global_rank() const { return global_rank_; } + void set_grad_accumulation_step(int64_t grad_accumulation_step); + int64_t grad_accumulation_step() const { return grad_accumulation_step_; } + bool set_parallel_mode(const std::string ¶llel_mode); std::string parallel_mode() const { return parallel_mode_; } @@ -116,6 +119,7 @@ class ParallelContext { bool loss_repeated_mean_; int64_t device_num_; int64_t global_rank_; + int64_t grad_accumulation_step_; std::string parallel_mode_; std::string strategy_search_mode_; int64_t pipeline_stage_split_num_; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index f910be859d..6ad3b32f82 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { } OperatorVector op_for_weight; bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); - OperatorName operator_name = MIRROR_OPERATOR; ValuePtr attr0_value = MakeValue(group_name); ValuePtr attr1_value = MakeValue(SizeToLong(dev_num)); ValuePtr attr2_value = MakeValue(mean_flag); @@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { operator_attrs.push_back(attr1); operator_attrs.push_back(attr2); + OperatorName operator_name; + if (grad_accumulation_step > 1) { + operator_name = MIRROR_MINI_STEP_OPERATOR; + ValuePtr attr3_value = MakeValue(grad_accumulation_step); + Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value); + operator_attrs.push_back(attr3); + MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror"; + } else { + operator_name = MIRROR_OPERATOR; + } + OperatorParams operator_param; OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 201d5d6d38..0a1b2563e8 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward"; constexpr char DTYPE[] = "DType"; constexpr char DEV_NUM[] = "dev_num"; constexpr char MEAN_FLAG[] = "mean_flag"; +constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step"; constexpr char TYPES[] = "types"; constexpr char SHAPES[] = "shapes"; +constexpr char ACCU_GRADS[] = "accu_grads"; constexpr char GETNEXT_NUM[] = "output_num"; constexpr char SHARED_NAME[] = "shared_name"; constexpr char MIRROR_OP[] = "mirror_op"; @@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; constexpr char ALL_REDUCE[] = "AllReduce"; constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; +constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator"; +constexpr char LOCAL_STEP[] = "local_step"; constexpr char STRIDED_SLICE[] = "StridedSlice"; constexpr char ALL_GATHER[] = "AllGather"; constexpr char REDUCE_SCATTER[] = "ReduceScatter"; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index ce97cd3763..910e9f9725 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An MS_LOG(INFO) << "Insert " << instance_name << " success"; } +bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { + MS_EXCEPTION_IF_NULL(parameter_node); + auto cloned_parameter = parameter_node->cast(); + MS_EXCEPTION_IF_NULL(cloned_parameter); + + // find the clone parameter + if (!cloned_parameter->has_default()) { + return false; + } + auto param_value = cloned_parameter->param_info(); + if (param_value == nullptr) { + return false; + } + bool cloned = param_value->cloned(); + if (!cloned) { + return false; + } + + MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; + return true; +} + +std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node, + const std::string &instance_name, const std::string &weight_name) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(root->manager()); + + AnfNodePtr local_step_param = nullptr; + AnfNodePtr grad_accu = nullptr; + std::string op_name = op.first; + OperatorArgs arg_forward = op.second; + + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + + if (grad_accumulation_step > 1) { + bool find_locat_step_node = false; + auto parameters = root->parameters(); + for (auto ¶m : parameters) { + auto param_ptr = param->cast(); + MS_EXCEPTION_IF_NULL(param_ptr); + if (param_ptr->name() == LOCAL_STEP) { + auto param_users = root->manager()->node_users()[param]; + for (auto &user : param_users) { + if (AnfNodeIsPrimitive(user.first, ASSIGN)) { + find_locat_step_node = true; + local_step_param = user.first; + MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode"; + break; + } + } + break; + } + } + + bool find_grad_accu_node = false; + for (auto ¶m : parameters) { + if (!ParameterIsCloned(param)) { + continue; + } + + auto param_ptr = param->cast(); + MS_EXCEPTION_IF_NULL(param_ptr); + if (param_ptr->name().find(weight_name) != std::string::npos && + param_ptr->name().find(ACCU_GRADS) != std::string::npos) { + find_grad_accu_node = true; + grad_accu = param; + MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name(); + break; + } + } + + if (op_name == MIRROR_MINI_STEP_OPERATOR) { + if (!find_locat_step_node || !find_grad_accu_node) { + op_name = MIRROR_OPERATOR; + arg_forward.first.pop_back(); + } + } + } + + ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name); + MS_EXCEPTION_IF_NULL(pyop_instance); + OperatorParams params = arg_forward.second; + + std::vector new_node_input; + if (op_name == MIRROR_MINI_STEP_OPERATOR) { + new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu}; + MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input"; + } else { + new_node_input = {NewValueNode(pyop_instance), node}; + } + + if (!params.empty()) { + for (auto ¶m : params) { + AnfNodePtr val = NewValueNode(param.first.second); + MS_EXCEPTION_IF_NULL(val); + int64_t position = param.second; + (void)new_node_input.insert(new_node_input.begin() + position, val); + } + } + + // if the op have 'group' attr, set the rank list name for the op + SetCommunicationOpGroupLabel(new_node_input); + return new_node_input; +} + +void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index, + const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name, + const std::string ¶m_name) { + // insert new node before the node + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + std::vector node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); + CNodePtr new_node = func_graph->NewCNode(node_input); + MS_EXCEPTION_IF_NULL(new_node); + if (instance_name.find(SPLIT_SENS) == std::string::npos) { + new_node->set_in_forward_flag(true); // mark forward flag + } + auto new_node_value = node_input[0]->cast(); + MS_EXCEPTION_IF_NULL(new_node_value); + PrimitivePtr new_node_prim = new_node_value->value()->cast(); + new_node_prim->set_instance_name(instance_name); + new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); + new_node->set_scope(scope); + node_input[0]->set_scope(scope); + manager->SetEdge(node, SizeToLong(index), new_node); + MS_LOG(INFO) << "Insert " << instance_name << " success"; +} + // Replace pre_node with pre_node->op static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name) { @@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; } -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { +void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { if (!param_node_pair.first) { continue; } + + auto param_ptr = param_node_pair.first->cast(); + std::string param_name; + if (param_ptr != nullptr) { + param_name = param_ptr->name(); + } + // not a RefKey if (!param_node_pair.second) { auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); @@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { CNodePtr cnode = node->input(index)->cast(); MS_EXCEPTION_IF_NULL(cnode); AnfNodePtr pre_node = cnode->input(1); - InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); + InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); auto comm_op = cnode->input(size_t(1))->cast(); // add fusion flag // pipeline mirror would not be set, which should be supported later @@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { } else { for (auto &op : backward_op) { AnfNodePtr pre_node = node->input(index); - InsertNode(op, node, index, pre_node, func_graph, instance_name); + InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name); auto comm_op = node->input(index)->cast(); // add fusion flag // pipeline mirror would not be set, which should be supported later @@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { } } -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, +void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node, const std::vector> &sens_loss_pairs) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(node); @@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo // insert mirror op if (!mirror_ops.empty()) { MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); - InsertMirrorOps(mirror_ops, node); + InsertMirrorOps(root, mirror_ops, node); } // insert virtual div op if (!virtual_div_op.empty() && is_loss_cnode) { @@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) { g_RefMap.clear(); } -bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { - MS_EXCEPTION_IF_NULL(parameter_node); - auto cloned_parameter = parameter_node->cast(); - MS_EXCEPTION_IF_NULL(cloned_parameter); - - // find the clone parameter - if (!cloned_parameter->has_default()) { - return false; - } - auto param_value = cloned_parameter->param_info(); - if (param_value == nullptr) { - return false; - } - bool cloned = param_value->cloned(); - if (!cloned) { - return false; - } - - MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; - return true; -} - void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); for (auto &cloned_parameter_node : root->parameters()) { @@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector FindParameter(const AnfNodePtr &node, const FuncGrap std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); - -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, - const std::vector> &sens_loss_pairs); - // Generate and init parallel operator OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, const std::vector &shape_list); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 77a54bd954..da92fc3258 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") + .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.") + .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.") .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 31321491c2..4a9cd00d5d 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.check_bprop_eliminate_, irpass.switch_layer_defer_inline_, irpass.replace_applicator_, + irpass.mirror_mini_step_elim_, }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 8ab89abffc..b9b2c632c6 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared("TensorM // Comm ops inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); +inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared("_MirrorMiniStepOperator"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); inline const PrimitivePtr kPrimSend = std::make_shared("Send"); diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 17655bf321..3c12a1c294 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -21,7 +21,7 @@ from .. import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, - _GetTensorSlice, _MirrorOperator, ReduceOp, + _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) from .grad_base import bprop_getters from ..operations._inner_ops import Send, Receive @@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self): return bprop +@bprop_getters.register(_MirrorMiniStepOperator) +def get_bprop_mirror_mini_step_operator(self): + """ + Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group, + allgather for sparse feature. + """ + group = self.group + dev_num = self.dev_num + mean_flag = self.mean_flag + grad_accumulation_step = self.grad_accumulation_step + + all_reduce = AllReduce(group=group) + all_gather = AllGather(group=group) + mul = P.Mul() + cast = P.Cast() + equal = P.Equal() + reshape = P.Reshape() + + fusion = 1 + if hasattr(self, 'fusion'): + fusion = self.fusion + all_reduce.add_prim_attr("fusion", fusion) + if hasattr(self, 'parameter'): + parameter = self.parameter + all_reduce.add_prim_attr("parameter", parameter) + + if self.instance_name: + instance_name = "grad_mirror" + self.instance_name + all_reduce.set_prim_instance_name(instance_name) + + def bprop(x, y, z, out, dout): + do_mirror = equal(y, grad_accumulation_step) + do_mirror = reshape(do_mirror, (())) + if mean_flag: + if F.issubclass_(F.typeof(dout), mstype.tensor): + if do_mirror: + tmp = z + dout + real_grad = all_reduce(tmp) + dx = real_grad - z + else: + dx = dout + float_one = F.scalar_cast(1.0, F.dtype(dx)) + num = F.scalar_cast(dev_num, F.dtype(dx)) + dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) + else: + if do_mirror: + indices = all_gather(dout.indices) + grad = all_gather(dout.values) + else: + indices = dout.indices + grad = dout.values + float_one = F.scalar_cast(1.0, F.dtype(grad)) + num = F.scalar_cast(dev_num, F.dtype(grad)) + grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) + dx = RowTensor(indices, grad, dout.dense_shape) + else: + if F.issubclass_(F.typeof(dout), mstype.tensor): + if do_mirror: + tmp = z + dout + real_grad = all_reduce(tmp) + dx = real_grad - z + else: + dx = dout + else: + if do_mirror: + indices = all_gather(dout.indices) + grad = all_gather(dout.values) + else: + indices = dout.indices + grad = dout.values + dx = RowTensor(indices, grad, dout.dense_shape) + + return (dx, zeros_like(y), zeros_like(z)) + return bprop + + @bprop_getters.register(_VirtualDiv) def get_bprop_virtual_div_operator(self): """Backpropagator for _VirtualDiv, do Div for the divisor.""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 00d37aefe5..747192be07 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, - _MirrorOperator, ReduceOp, _VirtualDataset, + _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, _HostAllGather, _HostReduceScatter) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 54319c99aa..906e02d93d 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer): mirror = _MirrorOperator() +class _MirrorMiniStepOperator(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for + internal use of parallel modules and cannot be called by users. + + Args: + group (str): The communication group to work on. Default: None. + dev_num (int): The device number of the group. Default: None. + mean_flag (bool): Whether use mean in backward. Default: None. + grad_accumulation_step (int): The grad accumulation step. Default: None. + """ + + @prim_attr_register + def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): + self.group = group + self.dev_num = dev_num + self.mean_flag = mean_flag + self.grad_accumulation_step = grad_accumulation_step + + def infer_shape(self, x_shape, y_shape, z_shape): + return x_shape + + def infer_dtype(self, x_dtype, y_shape, z_shape): + return x_dtype + + +mirror_mini_step = _MirrorMiniStepOperator() + + class _VirtualDiv(PrimitiveWithInfer): """ Auto parallel virtual operator. Do nothing in forward, do Div in backward. diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index ed68fdf76b..d97daca266 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -249,6 +249,21 @@ class _AutoParallelContext: return False return self._context_handle.get_full_batch() + def set_grad_accumulation_step(self, grad_accumulation_step): + """ + Set grad accumulation step. + + Args: + grad_accumulation_step (int): The grad accumulation step. + """ + self.check_context_handle() + self._context_handle.set_grad_accumulation_step(grad_accumulation_step) + + def get_grad_accumulation_step(self): + """Get grad accumulation step.""" + self.check_context_handle() + return self._context_handle.get_grad_accumulation_step() + def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): """ Set strategy checkpoint save path. @@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = { "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, "full_batch": auto_parallel_context().set_full_batch, "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, + "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} @@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = { "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, "full_batch": auto_parallel_context().get_full_batch, "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, + "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} @@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = { loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, - all_reduce_fusion_config=list) + grad_accumulation_step=int, all_reduce_fusion_config=list) def _set_auto_parallel_context(**kwargs): """ diff --git a/tests/ut/python/parallel/test_grad_accumulation.py b/tests/ut/python/parallel/test_grad_accumulation.py new file mode 100644 index 0000000000..7bee3b8fe9 --- /dev/null +++ b/tests/ut/python/parallel/test_grad_accumulation.py @@ -0,0 +1,289 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore as ms +import mindspore.common.dtype as mstype +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, Momentum, Norm +from mindspore.train import Model +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.common.initializer import initializer +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.context import ParallelMode + +from tests.dataset_mock import MindData + + +class Dataset(MindData): + def __init__(self, predict, label, length=3): + super(Dataset, self).__init__(size=length) + self.predict = predict + self.label = label + self.index = 0 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.predict, self.label + + def reset(self): + self.index = 0 + + +get_square_sum = C.MultitypeFuncGraph("get_square_sum") +@get_square_sum.register("Tensor") +def _get_square_sum(grad): + norm = P.ReduceSum(False)(F.square(grad), ()) + norm = F.expand_dims(F.cast(norm, mstype.float32), 0) + return norm + + +apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") +@apply_global_norm.register("Tensor", "Tensor", "Tensor") +def _apply_global_norm(clip_norm, global_norm, grad): + grad = grad * clip_norm / global_norm + return grad + + +class GlobalNorm(Cell): + """ + Calculate the global norm value of given tensors + """ + def __init__(self): + super(GlobalNorm, self).__init__() + self.norm = Norm() + self.hyper_map = C.HyperMap() + + def construct(self, grads): + square_sum = self.hyper_map(get_square_sum, grads) + global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum))) + return global_norms + + +class ClipByGlobalNorm(Cell): + """ + Clip grads by global norm + """ + def __init__(self, clip_norm=1.0): + super(ClipByGlobalNorm, self).__init__() + self.global_norm = GlobalNorm() + self.clip_norm = Tensor([clip_norm], mstype.float32) + self.hyper_map = C.HyperMap() + + def construct(self, grads): + global_norm = self.global_norm(grads) + cond = P.GreaterEqual()(global_norm, self.clip_norm) + global_norm = F.select(cond, global_norm, self.clip_norm) + grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads) + return grads + + +cast = P.Cast() +update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") + + +@update_accu_grads.register("Tensor", "Tensor") +def _update_accu_grads(accu_grad, grad): + succ = True + return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32))) + + +zeroslike = P.ZerosLike() +reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") + + +@reset_accu_grads.register("Tensor") +def _reset_accu_grads(accu_grad): + succ = True + return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad))) + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + + +class TrainAccumulateStepsWithLossScaleCell(Cell): + """ + Encapsulation class of bert network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. To mimic higher batch size, gradients are + accumulated N times before weight update. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. + accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = + batch_size * accumulation_steps. Default: 1. + """ + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4): + super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = optimizer.parameters + self.optimizer = optimizer + self.accumulation_steps = accumulation_steps + self.one = Tensor(np.array([1]).astype(np.int32)) + self.zero = Tensor(np.array([0]).astype(np.int32)) + self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") + self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') + self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) + self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) + + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.overflow_reducer = F.identity + if self.is_distributed: + self.overflow_reducer = P.AllReduce() + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.logical_or = P.LogicalOr() + self.not_equal = P.NotEqual() + self.select = P.Select() + self.reshape = P.Reshape() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) + + @C.add_flags(has_effect=True) + def construct(self, x, b, sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(x, b) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + # update accumulation parameters + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) + self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) + mean_loss = self.accu_loss / self.local_step + is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) + + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + self.clear_before_grad(init) + grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32)) + + accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) + mean_loss = F.depend(mean_loss, accu_succ) + + self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + overflow = self.less_equal(self.base, flag_sum) + overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) + accu_overflow = self.select(overflow, self.one, self.zero) + self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) + is_accu_step = self.reshape(is_accu_step, (())) + + if is_accu_step: + succ = False + else: + # apply grad reducer on grads + grads = self.grad_reducer(self.accu_grads) + scaling = scaling_sens * self.degree * self.accumulation_steps + grads = self.hyper_map(F.partial(grad_scale, scaling), grads) + grads = ClipByGlobalNorm()(grads) + accu_overflow = self.overflow_reducer(accu_overflow) + F.control_depend(grads, accu_overflow) + overflow = self.less_equal(self.base, accu_overflow) + accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) + overflow = F.depend(overflow, accu_succ) + overflow = self.reshape(overflow, (())) + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, overflow) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + + ret = (mean_loss, overflow, scaling_sens) + return F.depend(ret, succ) + + +class Net(Cell): + def __init__(self, weight, strategy=None): + super().__init__() + self.mul = P.Mul().shard(strategy) + self.weight = Parameter(weight, "w1") + self.relu = P.ReLU() + self.reduce_sum = P.ReduceSum(keep_dims=True) + + def construct(self, x, b): + out = self.mul(x, self.weight) + out = self.relu(out) + out = self.reduce_sum(out) + return out + + +_x = Tensor(np.ones([2]), dtype=ms.float32) +_b = Tensor(np.ones([16]), dtype=ms.float32) +_w1 = Tensor(np.ones([16]), dtype=ms.float32) + + +def compile_net(net, grad_accumulation_step): + context.set_context(save_graphs=True) + learning_rate = 0.1 + momentum = 0.9 + epoch_size = 2 + dataset = Dataset(_x, _b) + opt = Momentum(net.trainable_params(), learning_rate, momentum) + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000) + net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell, + accumulation_steps=grad_accumulation_step) + model = Model(net_wrap) + model.train(epoch_size, dataset, dataset_sink_mode=False) + context.reset_auto_parallel_context() + + +def test_grad_accumulation(): + grad_accumulation_step = 4 + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, + grad_accumulation_step=grad_accumulation_step) + strategy = ((2,), (2,)) + net = Net(_w1, strategy) + compile_net(net, grad_accumulation_step)