| @@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | |||
| partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | |||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate", | |||
| prim::kPrimMirrorMiniStep); | |||
| check_bprop_eliminate_ = | |||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = | |||
| @@ -51,6 +51,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr reset_defer_inline_; | |||
| SubstitutionPtr depend_value_elim_; | |||
| SubstitutionPtr all_reduce_const_elim_; | |||
| SubstitutionPtr mirror_mini_step_elim_; | |||
| // Env Item Eliminate | |||
| SubstitutionPtr env_get_item_eliminate_; | |||
| @@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // {prim::kPrimMirrorMiniStep, X, Y, Z} -> X | |||
| class MirrorMiniStepEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs.size() < 2) { | |||
| return nullptr; | |||
| } | |||
| return inputs[1]; | |||
| } | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // Reset defer_inline flag | |||
| class ResetDeferInline : public AnfVisitor { | |||
| public: | |||
| @@ -64,6 +64,7 @@ void ParallelContext::Reset() { | |||
| all_reduce_fusion_split_sizes_.clear(); | |||
| strategy_search_mode_ = DYNAMIC_PROGRAMMING; | |||
| pipeline_stage_split_num_ = 1; | |||
| grad_accumulation_step_ = 1; | |||
| } | |||
| void ParallelContext::set_device_num(int64_t device_num) { | |||
| @@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ | |||
| void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } | |||
| void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) { | |||
| grad_accumulation_step_ = grad_accumulation_step; | |||
| } | |||
| void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; } | |||
| void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | |||
| @@ -73,6 +73,9 @@ class ParallelContext { | |||
| void set_global_rank(int64_t global_rank); | |||
| int64_t global_rank() const { return global_rank_; } | |||
| void set_grad_accumulation_step(int64_t grad_accumulation_step); | |||
| int64_t grad_accumulation_step() const { return grad_accumulation_step_; } | |||
| bool set_parallel_mode(const std::string ¶llel_mode); | |||
| std::string parallel_mode() const { return parallel_mode_; } | |||
| @@ -116,6 +119,7 @@ class ParallelContext { | |||
| bool loss_repeated_mean_; | |||
| int64_t device_num_; | |||
| int64_t global_rank_; | |||
| int64_t grad_accumulation_step_; | |||
| std::string parallel_mode_; | |||
| std::string strategy_search_mode_; | |||
| int64_t pipeline_stage_split_num_; | |||
| @@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { | |||
| } | |||
| OperatorVector op_for_weight; | |||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| OperatorName operator_name = MIRROR_OPERATOR; | |||
| ValuePtr attr0_value = MakeValue(group_name); | |||
| ValuePtr attr1_value = MakeValue(SizeToLong(dev_num)); | |||
| ValuePtr attr2_value = MakeValue(mean_flag); | |||
| @@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { | |||
| operator_attrs.push_back(attr1); | |||
| operator_attrs.push_back(attr2); | |||
| OperatorName operator_name; | |||
| if (grad_accumulation_step > 1) { | |||
| operator_name = MIRROR_MINI_STEP_OPERATOR; | |||
| ValuePtr attr3_value = MakeValue(grad_accumulation_step); | |||
| Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value); | |||
| operator_attrs.push_back(attr3); | |||
| MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror"; | |||
| } else { | |||
| operator_name = MIRROR_OPERATOR; | |||
| } | |||
| OperatorParams operator_param; | |||
| OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); | |||
| @@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward"; | |||
| constexpr char DTYPE[] = "DType"; | |||
| constexpr char DEV_NUM[] = "dev_num"; | |||
| constexpr char MEAN_FLAG[] = "mean_flag"; | |||
| constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step"; | |||
| constexpr char TYPES[] = "types"; | |||
| constexpr char SHAPES[] = "shapes"; | |||
| constexpr char ACCU_GRADS[] = "accu_grads"; | |||
| constexpr char GETNEXT_NUM[] = "output_num"; | |||
| constexpr char SHARED_NAME[] = "shared_name"; | |||
| constexpr char MIRROR_OP[] = "mirror_op"; | |||
| @@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; | |||
| constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; | |||
| constexpr char ALL_REDUCE[] = "AllReduce"; | |||
| constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; | |||
| constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator"; | |||
| constexpr char LOCAL_STEP[] = "local_step"; | |||
| constexpr char STRIDED_SLICE[] = "StridedSlice"; | |||
| constexpr char ALL_GATHER[] = "AllGather"; | |||
| constexpr char REDUCE_SCATTER[] = "ReduceScatter"; | |||
| @@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An | |||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | |||
| } | |||
| bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||
| auto cloned_parameter = parameter_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(cloned_parameter); | |||
| // find the clone parameter | |||
| if (!cloned_parameter->has_default()) { | |||
| return false; | |||
| } | |||
| auto param_value = cloned_parameter->param_info(); | |||
| if (param_value == nullptr) { | |||
| return false; | |||
| } | |||
| bool cloned = param_value->cloned(); | |||
| if (!cloned) { | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; | |||
| return true; | |||
| } | |||
| std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node, | |||
| const std::string &instance_name, const std::string &weight_name) { | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(root->manager()); | |||
| AnfNodePtr local_step_param = nullptr; | |||
| AnfNodePtr grad_accu = nullptr; | |||
| std::string op_name = op.first; | |||
| OperatorArgs arg_forward = op.second; | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| if (grad_accumulation_step > 1) { | |||
| bool find_locat_step_node = false; | |||
| auto parameters = root->parameters(); | |||
| for (auto ¶m : parameters) { | |||
| auto param_ptr = param->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_ptr); | |||
| if (param_ptr->name() == LOCAL_STEP) { | |||
| auto param_users = root->manager()->node_users()[param]; | |||
| for (auto &user : param_users) { | |||
| if (AnfNodeIsPrimitive(user.first, ASSIGN)) { | |||
| find_locat_step_node = true; | |||
| local_step_param = user.first; | |||
| MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode"; | |||
| break; | |||
| } | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| bool find_grad_accu_node = false; | |||
| for (auto ¶m : parameters) { | |||
| if (!ParameterIsCloned(param)) { | |||
| continue; | |||
| } | |||
| auto param_ptr = param->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_ptr); | |||
| if (param_ptr->name().find(weight_name) != std::string::npos && | |||
| param_ptr->name().find(ACCU_GRADS) != std::string::npos) { | |||
| find_grad_accu_node = true; | |||
| grad_accu = param; | |||
| MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name(); | |||
| break; | |||
| } | |||
| } | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| if (!find_locat_step_node || !find_grad_accu_node) { | |||
| op_name = MIRROR_OPERATOR; | |||
| arg_forward.first.pop_back(); | |||
| } | |||
| } | |||
| } | |||
| ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name); | |||
| MS_EXCEPTION_IF_NULL(pyop_instance); | |||
| OperatorParams params = arg_forward.second; | |||
| std::vector<AnfNodePtr> new_node_input; | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu}; | |||
| MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input"; | |||
| } else { | |||
| new_node_input = {NewValueNode(pyop_instance), node}; | |||
| } | |||
| if (!params.empty()) { | |||
| for (auto ¶m : params) { | |||
| AnfNodePtr val = NewValueNode(param.first.second); | |||
| MS_EXCEPTION_IF_NULL(val); | |||
| int64_t position = param.second; | |||
| (void)new_node_input.insert(new_node_input.begin() + position, val); | |||
| } | |||
| } | |||
| // if the op have 'group' attr, set the rank list name for the op | |||
| SetCommunicationOpGroupLabel(new_node_input); | |||
| return new_node_input; | |||
| } | |||
| void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index, | |||
| const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name, | |||
| const std::string ¶m_name) { | |||
| // insert new node before the node | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| ScopePtr scope = node->scope(); | |||
| MS_EXCEPTION_IF_NULL(scope); | |||
| std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); | |||
| CNodePtr new_node = func_graph->NewCNode(node_input); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| if (instance_name.find(SPLIT_SENS) == std::string::npos) { | |||
| new_node->set_in_forward_flag(true); // mark forward flag | |||
| } | |||
| auto new_node_value = node_input[0]->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(new_node_value); | |||
| PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>(); | |||
| new_node_prim->set_instance_name(instance_name); | |||
| new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); | |||
| new_node->set_scope(scope); | |||
| node_input[0]->set_scope(scope); | |||
| manager->SetEdge(node, SizeToLong(index), new_node); | |||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | |||
| } | |||
| // Replace pre_node with pre_node->op | |||
| static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, | |||
| const std::string &instance_name) { | |||
| @@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par | |||
| MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | |||
| } | |||
| void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| size_t node_size = node->inputs().size(); | |||
| FuncGraphPtr func_graph = node->func_graph(); | |||
| @@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| if (!param_node_pair.first) { | |||
| continue; | |||
| } | |||
| auto param_ptr = param_node_pair.first->cast<ParameterPtr>(); | |||
| std::string param_name; | |||
| if (param_ptr != nullptr) { | |||
| param_name = param_ptr->name(); | |||
| } | |||
| // not a RefKey | |||
| if (!param_node_pair.second) { | |||
| auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); | |||
| @@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| CNodePtr cnode = node->input(index)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| AnfNodePtr pre_node = cnode->input(1); | |||
| InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); | |||
| InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); | |||
| auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| // pipeline mirror would not be set, which should be supported later | |||
| @@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| } else { | |||
| for (auto &op : backward_op) { | |||
| AnfNodePtr pre_node = node->input(index); | |||
| InsertNode(op, node, index, pre_node, func_graph, instance_name); | |||
| InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name); | |||
| auto comm_op = node->input(index)->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| // pipeline mirror would not be set, which should be supported later | |||
| @@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||
| } | |||
| } | |||
| void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||
| void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||
| const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) { | |||
| MS_EXCEPTION_IF_NULL(distribute_operator); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| @@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo | |||
| // insert mirror op | |||
| if (!mirror_ops.empty()) { | |||
| MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); | |||
| InsertMirrorOps(mirror_ops, node); | |||
| InsertMirrorOps(root, mirror_ops, node); | |||
| } | |||
| // insert virtual div op | |||
| if (!virtual_div_op.empty() && is_loss_cnode) { | |||
| @@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) { | |||
| g_RefMap.clear(); | |||
| } | |||
| bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||
| auto cloned_parameter = parameter_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(cloned_parameter); | |||
| // find the clone parameter | |||
| if (!cloned_parameter->has_default()) { | |||
| return false; | |||
| } | |||
| auto param_value = cloned_parameter->param_info(); | |||
| if (param_value == nullptr) { | |||
| return false; | |||
| } | |||
| bool cloned = param_value->cloned(); | |||
| if (!cloned) { | |||
| return false; | |||
| } | |||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; | |||
| return true; | |||
| } | |||
| void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| for (auto &cloned_parameter_node : root->parameters()) { | |||
| @@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt | |||
| // insert backward ops | |||
| if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) { | |||
| BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); | |||
| BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs); | |||
| } | |||
| HandleSpecialNode(distribute_operator, cnode); | |||
| @@ -82,11 +82,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap | |||
| std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); | |||
| void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); | |||
| void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||
| const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs); | |||
| // Generate and init parallel operator | |||
| OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, | |||
| const std::vector<Shapes> &shape_list); | |||
| @@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") | |||
| .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") | |||
| .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") | |||
| .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.") | |||
| .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.") | |||
| .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") | |||
| .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") | |||
| .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, | |||
| @@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.check_bprop_eliminate_, | |||
| irpass.switch_layer_defer_inline_, | |||
| irpass.replace_applicator_, | |||
| irpass.mirror_mini_step_elim_, | |||
| }); | |||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | |||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | |||
| @@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorM | |||
| // Comm ops | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| @@ -21,7 +21,7 @@ from .. import operations as P | |||
| from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, ReduceOp, | |||
| _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | |||
| from .grad_base import bprop_getters | |||
| from ..operations._inner_ops import Send, Receive | |||
| @@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self): | |||
| return bprop | |||
@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
    """
    Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group,
    allgather for sparse feature.

    Communication only happens on the step where gradient accumulation
    completes (local_step == grad_accumulation_step); on intermediate mini
    steps the incoming gradient is passed through unchanged.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    grad_accumulation_step = self.grad_accumulation_step
    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
    equal = P.Equal()
    reshape = P.Reshape()
    # Propagate fusion id / parameter name to the backward AllReduce so it can
    # be fused and attributed like the forward mirror op.
    fusion = 1
    if hasattr(self, 'fusion'):
        fusion = self.fusion
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)

    def bprop(x, y, z, out, dout):
        # x: mirrored weight, y: local_step parameter, z: accumulated gradient.
        do_mirror = equal(y, grad_accumulation_step)
        # Reshape the comparison result to a scalar so it can drive `if`.
        do_mirror = reshape(do_mirror, (()))
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    # allreduce(z + dout), then subtract the local accumulation
                    # z back out of the reduced result.
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
                # Mean across devices: scale by 1 / dev_num.
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                # Sparse (RowTensor) gradient: allgather indices/values instead
                # of an allreduce.
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            # Same structure without the 1/dev_num mean scaling.
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                dx = RowTensor(indices, grad, dout.dense_shape)
        # The local_step and accu-grad inputs receive zero gradients.
        return (dx, zeros_like(y), zeros_like(z))

    return bprop
| @bprop_getters.register(_VirtualDiv) | |||
| def get_bprop_virtual_div_operator(self): | |||
| """Backpropagator for _VirtualDiv, do Div for the divisor.""" | |||
| @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| Unique, GatherD, Identity) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, | |||
| _HostAllGather, _HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| @@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer): | |||
| mirror = _MirrorOperator() | |||
class _MirrorMiniStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
        grad_accumulation_step (int): The grad accumulation step. Default: None.

    Inputs take three tensors: the mirrored weight, the local step counter and
    the accumulated gradient; the forward output is the first input unchanged.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.grad_accumulation_step = grad_accumulation_step

    def infer_shape(self, x_shape, y_shape, z_shape):
        # Forward identity: output shape follows the first input.
        return x_shape

    def infer_dtype(self, x_dtype, y_shape, z_shape):
        # Forward identity: output dtype follows the first input.
        return x_dtype


mirror_mini_step = _MirrorMiniStepOperator()
| class _VirtualDiv(PrimitiveWithInfer): | |||
| """ | |||
| Auto parallel virtual operator. Do nothing in forward, do Div in backward. | |||
| @@ -249,6 +249,21 @@ class _AutoParallelContext: | |||
| return False | |||
| return self._context_handle.get_full_batch() | |||
    def set_grad_accumulation_step(self, grad_accumulation_step):
        """
        Set grad accumulation step.

        Args:
            grad_accumulation_step (int): The grad accumulation step.
        """
        self.check_context_handle()
        # Forwarded straight to the C++ parallel context; no validation here.
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
    def get_grad_accumulation_step(self):
        """Get grad accumulation step.

        Returns:
            int, the grad accumulation step stored in the C++ parallel context.
        """
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_step()
| def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): | |||
| """ | |||
| Set strategy checkpoint save path. | |||
| @@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = { | |||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().set_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, | |||
| "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, | |||
| "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} | |||
| @@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = { | |||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, | |||
| "full_batch": auto_parallel_context().get_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, | |||
| "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, | |||
| "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} | |||
| @@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = { | |||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | |||
| all_reduce_fusion_config=list) | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| @@ -0,0 +1,289 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.nn import Cell, Momentum, Norm | |||
| from mindspore.train import Model | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.context import ParallelMode | |||
| from tests.dataset_mock import MindData | |||
class Dataset(MindData):
    """Mock dataset yielding the same (predict, label) pair `length` times."""

    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0       # batches emitted so far
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        # Rewind so the dataset can be iterated again (e.g. for a new epoch).
        self.index = 0
get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    # Sum of squared elements of one gradient, returned as a 1-element float32
    # tensor so per-parameter norms can later be combined with addn.
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    # Scale one gradient by clip_norm / global_norm.
    grad = grad * clip_norm / global_norm
    return grad
class GlobalNorm(Cell):
    """
    Calculate the global norm value of given tensors
    """

    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        # Per-tensor squared sums, then sqrt of their mean over the tensors.
        # NOTE(review): dividing by len(square_sum) makes this a normalized
        # variant rather than the plain L2 global norm — confirm intended.
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms
class ClipByGlobalNorm(Cell):
    """
    Clip grads by global norm
    """

    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        # When the norm is below the threshold, substitute clip_norm itself so
        # the scaling factor clip_norm / global_norm becomes 1 (a no-op).
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    # accu_grad += grad (cast to float32); F.depend ties the in-place
    # assign_add into the graph so it is not optimized away.
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    # Zero the accumulation buffer in place; F.depend keeps the assign alive.
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    # Undo loss scaling: grad * (1 / scale).
    return grad * reciprocal(scale)
class TrainAccumulateStepsWithLossScaleCell(Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph. To mimic higher batch size, gradients are
    accumulated N times before weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
            batch_size * accumulation_steps. Default: 4.
    """

    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        # Scalar int32 constants used for step bookkeeping in construct().
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        # Counts micro-steps within one accumulation cycle (wraps back to one
        # in construct() once accumulation_steps is reached).
        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
        # Persistent per-weight buffers that accumulate gradients across micro-steps.
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        # NOTE(review): unlike local_step above, these two Parameters are created
        # without an explicit name — presumably auto-named by the framework;
        # confirm no duplicate-name issue.
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            # Gradients are averaged across the whole group in data/hybrid parallel.
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        # Overflow flags are AllReduce'd across devices when distributed,
        # identity otherwise.
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        # NPU float-status ops used to detect numeric overflow around backprop.
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self, x, b, sens=None):
        """Defines the computation performed.

        Returns a tuple (mean_loss, overflow, scaling_sens); the optimizer is
        applied only on the last micro-step of an accumulation cycle and only
        when no overflow occurred during the cycle.
        """
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        # Re-evaluate after local_step was advanced above: False exactly on the
        # final micro-step of the cycle.
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
        # Fold this micro-step's gradients into the persistent buffers; depend
        # ties the side effect into the graph so it is not pruned.
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        # Sticky overflow: once any micro-step overflowed, the whole cycle is tainted.
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        # Carry the flag between micro-steps; clear it at the end of the cycle.
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        is_accu_step = self.reshape(is_accu_step, (()))
        if is_accu_step:
            # Mid-cycle: skip the optimizer and keep accumulating.
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            # Undo loss scaling, data-parallel degree and accumulation factor.
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            # Reset the accumulation buffers for the next cycle; depend orders
            # the reset before overflow is consumed.
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                # Overflow detected during the cycle: drop this update entirely.
                succ = False
            else:
                succ = self.optimizer(grads)
        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
class Net(Cell):
    """Tiny test network: (x * w) -> ReLU -> ReduceSum(keep_dims=True)."""

    def __init__(self, weight, strategy=None):
        super().__init__()
        self.weight = Parameter(weight, "w1")
        self.relu = P.ReLU()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.mul = P.Mul().shard(strategy)

    def construct(self, x, b):
        """Forward pass; `b` is unused but kept for dataset-signature compatibility."""
        product = self.mul(x, self.weight)
        activated = self.relu(product)
        return self.reduce_sum(activated)
# Shared test fixtures: input (_x), label (_b) and initial weight (_w1) tensors.
_x = Tensor(np.ones([2]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)
def compile_net(net, grad_accumulation_step):
    """Wrap `net` in the loss-scaled accumulation cell and train it for two epochs.

    Resets the auto-parallel context afterwards so later tests start clean.
    """
    context.set_context(save_graphs=True)
    dataset = Dataset(_x, _b)
    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    scale_manager = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    train_cell = TrainAccumulateStepsWithLossScaleCell(net, optimizer, scale_update_cell=scale_manager,
                                                       accumulation_steps=grad_accumulation_step)
    Model(train_cell).train(2, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()
def test_grad_accumulation():
    """Compile a semi-auto-parallel net with 4-step gradient accumulation."""
    accumulation_steps = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=accumulation_steps)
    net = Net(_w1, strategy=((2,), (2,)))
    compile_net(net, accumulation_steps)