| @@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); | ||||
| partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup); | ||||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | ||||
| mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate", | |||||
| prim::kPrimMirrorMiniStep); | |||||
| check_bprop_eliminate_ = | check_bprop_eliminate_ = | ||||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | ||||
| reset_defer_inline_ = | reset_defer_inline_ = | ||||
| @@ -51,6 +51,7 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr reset_defer_inline_; | SubstitutionPtr reset_defer_inline_; | ||||
| SubstitutionPtr depend_value_elim_; | SubstitutionPtr depend_value_elim_; | ||||
| SubstitutionPtr all_reduce_const_elim_; | SubstitutionPtr all_reduce_const_elim_; | ||||
| SubstitutionPtr mirror_mini_step_elim_; | |||||
| // Env Item Eliminate | // Env Item Eliminate | ||||
| SubstitutionPtr env_get_item_eliminate_; | SubstitutionPtr env_get_item_eliminate_; | ||||
| @@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor { | |||||
| AnfNodePtr x_{nullptr}; | AnfNodePtr x_{nullptr}; | ||||
| }; | }; | ||||
| // {prim::kPrimMirrorMiniStep, X, Y, Z} -> X | |||||
| class MirrorMiniStepEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto inputs = cnode->inputs(); | |||||
| if (inputs.size() < 2) { | |||||
| return nullptr; | |||||
| } | |||||
| return inputs[1]; | |||||
| } | |||||
| void Visit(const AnfNodePtr &) override {} | |||||
| }; | |||||
| // Reset defer_inline flag | // Reset defer_inline flag | ||||
| class ResetDeferInline : public AnfVisitor { | class ResetDeferInline : public AnfVisitor { | ||||
| public: | public: | ||||
| @@ -64,6 +64,7 @@ void ParallelContext::Reset() { | |||||
| all_reduce_fusion_split_sizes_.clear(); | all_reduce_fusion_split_sizes_.clear(); | ||||
| strategy_search_mode_ = DYNAMIC_PROGRAMMING; | strategy_search_mode_ = DYNAMIC_PROGRAMMING; | ||||
| pipeline_stage_split_num_ = 1; | pipeline_stage_split_num_ = 1; | ||||
| grad_accumulation_step_ = 1; | |||||
| } | } | ||||
| void ParallelContext::set_device_num(int64_t device_num) { | void ParallelContext::set_device_num(int64_t device_num) { | ||||
| @@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ | |||||
| void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } | void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } | ||||
| void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) { | |||||
| grad_accumulation_step_ = grad_accumulation_step; | |||||
| } | |||||
| void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; } | void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; } | ||||
| void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } | ||||
| @@ -73,6 +73,9 @@ class ParallelContext { | |||||
| void set_global_rank(int64_t global_rank); | void set_global_rank(int64_t global_rank); | ||||
| int64_t global_rank() const { return global_rank_; } | int64_t global_rank() const { return global_rank_; } | ||||
| void set_grad_accumulation_step(int64_t grad_accumulation_step); | |||||
| int64_t grad_accumulation_step() const { return grad_accumulation_step_; } | |||||
| bool set_parallel_mode(const std::string ¶llel_mode); | bool set_parallel_mode(const std::string ¶llel_mode); | ||||
| std::string parallel_mode() const { return parallel_mode_; } | std::string parallel_mode() const { return parallel_mode_; } | ||||
| @@ -116,6 +119,7 @@ class ParallelContext { | |||||
| bool loss_repeated_mean_; | bool loss_repeated_mean_; | ||||
| int64_t device_num_; | int64_t device_num_; | ||||
| int64_t global_rank_; | int64_t global_rank_; | ||||
| int64_t grad_accumulation_step_; | |||||
| std::string parallel_mode_; | std::string parallel_mode_; | ||||
| std::string strategy_search_mode_; | std::string strategy_search_mode_; | ||||
| int64_t pipeline_stage_split_num_; | int64_t pipeline_stage_split_num_; | ||||
| @@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { | |||||
| } | } | ||||
| OperatorVector op_for_weight; | OperatorVector op_for_weight; | ||||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | ||||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||||
| OperatorName operator_name = MIRROR_OPERATOR; | |||||
| ValuePtr attr0_value = MakeValue(group_name); | ValuePtr attr0_value = MakeValue(group_name); | ||||
| ValuePtr attr1_value = MakeValue(SizeToLong(dev_num)); | ValuePtr attr1_value = MakeValue(SizeToLong(dev_num)); | ||||
| ValuePtr attr2_value = MakeValue(mean_flag); | ValuePtr attr2_value = MakeValue(mean_flag); | ||||
| @@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { | |||||
| operator_attrs.push_back(attr1); | operator_attrs.push_back(attr1); | ||||
| operator_attrs.push_back(attr2); | operator_attrs.push_back(attr2); | ||||
| OperatorName operator_name; | |||||
| if (grad_accumulation_step > 1) { | |||||
| operator_name = MIRROR_MINI_STEP_OPERATOR; | |||||
| ValuePtr attr3_value = MakeValue(grad_accumulation_step); | |||||
| Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value); | |||||
| operator_attrs.push_back(attr3); | |||||
| MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror"; | |||||
| } else { | |||||
| operator_name = MIRROR_OPERATOR; | |||||
| } | |||||
| OperatorParams operator_param; | OperatorParams operator_param; | ||||
| OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); | OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); | ||||
| @@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward"; | |||||
| constexpr char DTYPE[] = "DType"; | constexpr char DTYPE[] = "DType"; | ||||
| constexpr char DEV_NUM[] = "dev_num"; | constexpr char DEV_NUM[] = "dev_num"; | ||||
| constexpr char MEAN_FLAG[] = "mean_flag"; | constexpr char MEAN_FLAG[] = "mean_flag"; | ||||
| constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step"; | |||||
| constexpr char TYPES[] = "types"; | constexpr char TYPES[] = "types"; | ||||
| constexpr char SHAPES[] = "shapes"; | constexpr char SHAPES[] = "shapes"; | ||||
| constexpr char ACCU_GRADS[] = "accu_grads"; | |||||
| constexpr char GETNEXT_NUM[] = "output_num"; | constexpr char GETNEXT_NUM[] = "output_num"; | ||||
| constexpr char SHARED_NAME[] = "shared_name"; | constexpr char SHARED_NAME[] = "shared_name"; | ||||
| constexpr char MIRROR_OP[] = "mirror_op"; | constexpr char MIRROR_OP[] = "mirror_op"; | ||||
| @@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; | |||||
| constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; | constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; | ||||
| constexpr char ALL_REDUCE[] = "AllReduce"; | constexpr char ALL_REDUCE[] = "AllReduce"; | ||||
| constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; | constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; | ||||
| constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator"; | |||||
| constexpr char LOCAL_STEP[] = "local_step"; | |||||
| constexpr char STRIDED_SLICE[] = "StridedSlice"; | constexpr char STRIDED_SLICE[] = "StridedSlice"; | ||||
| constexpr char ALL_GATHER[] = "AllGather"; | constexpr char ALL_GATHER[] = "AllGather"; | ||||
| constexpr char REDUCE_SCATTER[] = "ReduceScatter"; | constexpr char REDUCE_SCATTER[] = "ReduceScatter"; | ||||
| @@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An | |||||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | MS_LOG(INFO) << "Insert " << instance_name << " success"; | ||||
| } | } | ||||
| bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||||
| auto cloned_parameter = parameter_node->cast<ParameterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cloned_parameter); | |||||
| // find the clone parameter | |||||
| if (!cloned_parameter->has_default()) { | |||||
| return false; | |||||
| } | |||||
| auto param_value = cloned_parameter->param_info(); | |||||
| if (param_value == nullptr) { | |||||
| return false; | |||||
| } | |||||
| bool cloned = param_value->cloned(); | |||||
| if (!cloned) { | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; | |||||
| return true; | |||||
| } | |||||
| std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node, | |||||
| const std::string &instance_name, const std::string &weight_name) { | |||||
| MS_EXCEPTION_IF_NULL(root); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(root->manager()); | |||||
| AnfNodePtr local_step_param = nullptr; | |||||
| AnfNodePtr grad_accu = nullptr; | |||||
| std::string op_name = op.first; | |||||
| OperatorArgs arg_forward = op.second; | |||||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||||
| if (grad_accumulation_step > 1) { | |||||
| bool find_locat_step_node = false; | |||||
| auto parameters = root->parameters(); | |||||
| for (auto ¶m : parameters) { | |||||
| auto param_ptr = param->cast<ParameterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(param_ptr); | |||||
| if (param_ptr->name() == LOCAL_STEP) { | |||||
| auto param_users = root->manager()->node_users()[param]; | |||||
| for (auto &user : param_users) { | |||||
| if (AnfNodeIsPrimitive(user.first, ASSIGN)) { | |||||
| find_locat_step_node = true; | |||||
| local_step_param = user.first; | |||||
| MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode"; | |||||
| break; | |||||
| } | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| bool find_grad_accu_node = false; | |||||
| for (auto ¶m : parameters) { | |||||
| if (!ParameterIsCloned(param)) { | |||||
| continue; | |||||
| } | |||||
| auto param_ptr = param->cast<ParameterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(param_ptr); | |||||
| if (param_ptr->name().find(weight_name) != std::string::npos && | |||||
| param_ptr->name().find(ACCU_GRADS) != std::string::npos) { | |||||
| find_grad_accu_node = true; | |||||
| grad_accu = param; | |||||
| MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name(); | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||||
| if (!find_locat_step_node || !find_grad_accu_node) { | |||||
| op_name = MIRROR_OPERATOR; | |||||
| arg_forward.first.pop_back(); | |||||
| } | |||||
| } | |||||
| } | |||||
| ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name); | |||||
| MS_EXCEPTION_IF_NULL(pyop_instance); | |||||
| OperatorParams params = arg_forward.second; | |||||
| std::vector<AnfNodePtr> new_node_input; | |||||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||||
| new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu}; | |||||
| MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input"; | |||||
| } else { | |||||
| new_node_input = {NewValueNode(pyop_instance), node}; | |||||
| } | |||||
| if (!params.empty()) { | |||||
| for (auto ¶m : params) { | |||||
| AnfNodePtr val = NewValueNode(param.first.second); | |||||
| MS_EXCEPTION_IF_NULL(val); | |||||
| int64_t position = param.second; | |||||
| (void)new_node_input.insert(new_node_input.begin() + position, val); | |||||
| } | |||||
| } | |||||
| // if the op have 'group' attr, set the rank list name for the op | |||||
| SetCommunicationOpGroupLabel(new_node_input); | |||||
| return new_node_input; | |||||
| } | |||||
| void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index, | |||||
| const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name, | |||||
| const std::string ¶m_name) { | |||||
| // insert new node before the node | |||||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| ScopePtr scope = node->scope(); | |||||
| MS_EXCEPTION_IF_NULL(scope); | |||||
| std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); | |||||
| CNodePtr new_node = func_graph->NewCNode(node_input); | |||||
| MS_EXCEPTION_IF_NULL(new_node); | |||||
| if (instance_name.find(SPLIT_SENS) == std::string::npos) { | |||||
| new_node->set_in_forward_flag(true); // mark forward flag | |||||
| } | |||||
| auto new_node_value = node_input[0]->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(new_node_value); | |||||
| PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>(); | |||||
| new_node_prim->set_instance_name(instance_name); | |||||
| new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); | |||||
| new_node->set_scope(scope); | |||||
| node_input[0]->set_scope(scope); | |||||
| manager->SetEdge(node, SizeToLong(index), new_node); | |||||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | |||||
| } | |||||
| // Replace pre_node with pre_node->op | // Replace pre_node with pre_node->op | ||||
| static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, | static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, | ||||
| const std::string &instance_name) { | const std::string &instance_name) { | ||||
| @@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par | |||||
| MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | ||||
| } | } | ||||
| void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| size_t node_size = node->inputs().size(); | size_t node_size = node->inputs().size(); | ||||
| FuncGraphPtr func_graph = node->func_graph(); | FuncGraphPtr func_graph = node->func_graph(); | ||||
| @@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| if (!param_node_pair.first) { | if (!param_node_pair.first) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto param_ptr = param_node_pair.first->cast<ParameterPtr>(); | |||||
| std::string param_name; | |||||
| if (param_ptr != nullptr) { | |||||
| param_name = param_ptr->name(); | |||||
| } | |||||
| // not a RefKey | // not a RefKey | ||||
| if (!param_node_pair.second) { | if (!param_node_pair.second) { | ||||
| auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); | auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); | ||||
| @@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| CNodePtr cnode = node->input(index)->cast<CNodePtr>(); | CNodePtr cnode = node->input(index)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| AnfNodePtr pre_node = cnode->input(1); | AnfNodePtr pre_node = cnode->input(1); | ||||
| InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); | |||||
| InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); | |||||
| auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>(); | auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>(); | ||||
| // add fusion flag | // add fusion flag | ||||
| // pipeline mirror would not be set, which should be supported later | // pipeline mirror would not be set, which should be supported later | ||||
| @@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| } else { | } else { | ||||
| for (auto &op : backward_op) { | for (auto &op : backward_op) { | ||||
| AnfNodePtr pre_node = node->input(index); | AnfNodePtr pre_node = node->input(index); | ||||
| InsertNode(op, node, index, pre_node, func_graph, instance_name); | |||||
| InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name); | |||||
| auto comm_op = node->input(index)->cast<CNodePtr>(); | auto comm_op = node->input(index)->cast<CNodePtr>(); | ||||
| // add fusion flag | // add fusion flag | ||||
| // pipeline mirror would not be set, which should be supported later | // pipeline mirror would not be set, which should be supported later | ||||
| @@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||||
| void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||||
| const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) { | const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) { | ||||
| MS_EXCEPTION_IF_NULL(distribute_operator); | MS_EXCEPTION_IF_NULL(distribute_operator); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| @@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo | |||||
| // insert mirror op | // insert mirror op | ||||
| if (!mirror_ops.empty()) { | if (!mirror_ops.empty()) { | ||||
| MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); | MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); | ||||
| InsertMirrorOps(mirror_ops, node); | |||||
| InsertMirrorOps(root, mirror_ops, node); | |||||
| } | } | ||||
| // insert virtual div op | // insert virtual div op | ||||
| if (!virtual_div_op.empty() && is_loss_cnode) { | if (!virtual_div_op.empty() && is_loss_cnode) { | ||||
| @@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) { | |||||
| g_RefMap.clear(); | g_RefMap.clear(); | ||||
| } | } | ||||
| bool ParameterIsCloned(const AnfNodePtr ¶meter_node) { | |||||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||||
| auto cloned_parameter = parameter_node->cast<ParameterPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cloned_parameter); | |||||
| // find the clone parameter | |||||
| if (!cloned_parameter->has_default()) { | |||||
| return false; | |||||
| } | |||||
| auto param_value = cloned_parameter->param_info(); | |||||
| if (param_value == nullptr) { | |||||
| return false; | |||||
| } | |||||
| bool cloned = param_value->cloned(); | |||||
| if (!cloned) { | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; | |||||
| return true; | |||||
| } | |||||
| void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | ||||
| MS_EXCEPTION_IF_NULL(root); | MS_EXCEPTION_IF_NULL(root); | ||||
| for (auto &cloned_parameter_node : root->parameters()) { | for (auto &cloned_parameter_node : root->parameters()) { | ||||
| @@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt | |||||
| // insert backward ops | // insert backward ops | ||||
| if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) { | if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) { | ||||
| BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); | |||||
| BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs); | |||||
| } | } | ||||
| HandleSpecialNode(distribute_operator, cnode); | HandleSpecialNode(distribute_operator, cnode); | ||||
| @@ -82,11 +82,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap | |||||
| std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); | std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); | ||||
| void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); | |||||
| void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, | |||||
| const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs); | |||||
| // Generate and init parallel operator | // Generate and init parallel operator | ||||
| OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, | OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, | ||||
| const std::vector<Shapes> &shape_list); | const std::vector<Shapes> &shape_list); | ||||
| @@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") | .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") | ||||
| .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") | .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") | ||||
| .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") | .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") | ||||
| .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.") | |||||
| .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.") | |||||
| .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") | .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") | ||||
| .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") | .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") | ||||
| .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, | .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, | ||||
| @@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| irpass.check_bprop_eliminate_, | irpass.check_bprop_eliminate_, | ||||
| irpass.switch_layer_defer_inline_, | irpass.switch_layer_defer_inline_, | ||||
| irpass.replace_applicator_, | irpass.replace_applicator_, | ||||
| irpass.mirror_mini_step_elim_, | |||||
| }); | }); | ||||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | ||||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | ||||
| @@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorM | |||||
| // Comm ops | // Comm ops | ||||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | ||||
| inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator"); | |||||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | ||||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | ||||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | ||||
| @@ -21,7 +21,7 @@ from .. import operations as P | |||||
| from ...common.tensor import RowTensor | from ...common.tensor import RowTensor | ||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | ||||
| _GetTensorSlice, _MirrorOperator, ReduceOp, | |||||
| _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, | |||||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | ||||
| from .grad_base import bprop_getters | from .grad_base import bprop_getters | ||||
| from ..operations._inner_ops import Send, Receive | from ..operations._inner_ops import Send, Receive | ||||
| @@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self): | |||||
| return bprop | return bprop | ||||
| @bprop_getters.register(_MirrorMiniStepOperator) | |||||
| def get_bprop_mirror_mini_step_operator(self): | |||||
| """ | |||||
| Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group, | |||||
| allgather for sparse feature. | |||||
| """ | |||||
| group = self.group | |||||
| dev_num = self.dev_num | |||||
| mean_flag = self.mean_flag | |||||
| grad_accumulation_step = self.grad_accumulation_step | |||||
| all_reduce = AllReduce(group=group) | |||||
| all_gather = AllGather(group=group) | |||||
| mul = P.Mul() | |||||
| cast = P.Cast() | |||||
| equal = P.Equal() | |||||
| reshape = P.Reshape() | |||||
| fusion = 1 | |||||
| if hasattr(self, 'fusion'): | |||||
| fusion = self.fusion | |||||
| all_reduce.add_prim_attr("fusion", fusion) | |||||
| if hasattr(self, 'parameter'): | |||||
| parameter = self.parameter | |||||
| all_reduce.add_prim_attr("parameter", parameter) | |||||
| if self.instance_name: | |||||
| instance_name = "grad_mirror" + self.instance_name | |||||
| all_reduce.set_prim_instance_name(instance_name) | |||||
| def bprop(x, y, z, out, dout): | |||||
| do_mirror = equal(y, grad_accumulation_step) | |||||
| do_mirror = reshape(do_mirror, (())) | |||||
| if mean_flag: | |||||
| if F.issubclass_(F.typeof(dout), mstype.tensor): | |||||
| if do_mirror: | |||||
| tmp = z + dout | |||||
| real_grad = all_reduce(tmp) | |||||
| dx = real_grad - z | |||||
| else: | |||||
| dx = dout | |||||
| float_one = F.scalar_cast(1.0, F.dtype(dx)) | |||||
| num = F.scalar_cast(dev_num, F.dtype(dx)) | |||||
| dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) | |||||
| else: | |||||
| if do_mirror: | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| else: | |||||
| indices = dout.indices | |||||
| grad = dout.values | |||||
| float_one = F.scalar_cast(1.0, F.dtype(grad)) | |||||
| num = F.scalar_cast(dev_num, F.dtype(grad)) | |||||
| grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| else: | |||||
| if F.issubclass_(F.typeof(dout), mstype.tensor): | |||||
| if do_mirror: | |||||
| tmp = z + dout | |||||
| real_grad = all_reduce(tmp) | |||||
| dx = real_grad - z | |||||
| else: | |||||
| dx = dout | |||||
| else: | |||||
| if do_mirror: | |||||
| indices = all_gather(dout.indices) | |||||
| grad = all_gather(dout.values) | |||||
| else: | |||||
| indices = dout.indices | |||||
| grad = dout.values | |||||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||||
| return (dx, zeros_like(y), zeros_like(z)) | |||||
| return bprop | |||||
| @bprop_getters.register(_VirtualDiv) | @bprop_getters.register(_VirtualDiv) | ||||
| def get_bprop_virtual_div_operator(self): | def get_bprop_virtual_div_operator(self): | ||||
| """Backpropagator for _VirtualDiv, do Div for the divisor.""" | """Backpropagator for _VirtualDiv, do Div for the divisor.""" | ||||
| @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | ||||
| Unique, GatherD, Identity) | Unique, GatherD, Identity) | ||||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | ||||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||||
| _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, | |||||
| _VirtualDiv, _GetTensorSlice, | _VirtualDiv, _GetTensorSlice, | ||||
| _HostAllGather, _HostReduceScatter) | _HostAllGather, _HostReduceScatter) | ||||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | ||||
| @@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer): | |||||
| mirror = _MirrorOperator() | mirror = _MirrorOperator() | ||||
| class _MirrorMiniStepOperator(PrimitiveWithInfer): | |||||
| """ | |||||
| Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for | |||||
| internal use of parallel modules and cannot be called by users. | |||||
| Args: | |||||
| group (str): The communication group to work on. Default: None. | |||||
| dev_num (int): The device number of the group. Default: None. | |||||
| mean_flag (bool): Whether use mean in backward. Default: None. | |||||
| grad_accumulation_step (int): The grad accumulation step. Default: None. | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): | |||||
| self.group = group | |||||
| self.dev_num = dev_num | |||||
| self.mean_flag = mean_flag | |||||
| self.grad_accumulation_step = grad_accumulation_step | |||||
| def infer_shape(self, x_shape, y_shape, z_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, y_shape, z_shape): | |||||
| return x_dtype | |||||
| mirror_mini_step = _MirrorMiniStepOperator() | |||||
| class _VirtualDiv(PrimitiveWithInfer): | class _VirtualDiv(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Auto parallel virtual operator. Do nothing in forward, do Div in backward. | Auto parallel virtual operator. Do nothing in forward, do Div in backward. | ||||
| @@ -249,6 +249,21 @@ class _AutoParallelContext: | |||||
| return False | return False | ||||
| return self._context_handle.get_full_batch() | return self._context_handle.get_full_batch() | ||||
| def set_grad_accumulation_step(self, grad_accumulation_step): | |||||
| """ | |||||
| Set grad accumulation step. | |||||
| Args: | |||||
| grad_accumulation_step (int): The grad accumulation step. | |||||
| """ | |||||
| self.check_context_handle() | |||||
| self._context_handle.set_grad_accumulation_step(grad_accumulation_step) | |||||
| def get_grad_accumulation_step(self): | |||||
| """Get grad accumulation step.""" | |||||
| self.check_context_handle() | |||||
| return self._context_handle.get_grad_accumulation_step() | |||||
| def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): | def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): | ||||
| """ | """ | ||||
| Set strategy checkpoint save path. | Set strategy checkpoint save path. | ||||
| @@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = { | |||||
| "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, | ||||
| "full_batch": auto_parallel_context().set_full_batch, | "full_batch": auto_parallel_context().set_full_batch, | ||||
| "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, | "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, | ||||
| "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, | |||||
| "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} | "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} | ||||
| @@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = { | |||||
| "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, | "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, | ||||
| "full_batch": auto_parallel_context().get_full_batch, | "full_batch": auto_parallel_context().get_full_batch, | ||||
| "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, | "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, | ||||
| "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, | |||||
| "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} | "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} | ||||
| @@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = { | |||||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | ||||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | parameter_broadcast=bool, strategy_ckpt_load_file=str, | ||||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | ||||
| all_reduce_fusion_config=list) | |||||
| grad_accumulation_step=int, all_reduce_fusion_config=list) | |||||
| def _set_auto_parallel_context(**kwargs): | def _set_auto_parallel_context(**kwargs): | ||||
| """ | """ | ||||
| @@ -0,0 +1,289 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore as ms | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context, Tensor, Parameter | |||||
| from mindspore.nn import Cell, Momentum, Norm | |||||
| from mindspore.train import Model | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.common.initializer import initializer | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.context import ParallelMode | |||||
| from tests.dataset_mock import MindData | |||||
| class Dataset(MindData): | |||||
| def __init__(self, predict, label, length=3): | |||||
| super(Dataset, self).__init__(size=length) | |||||
| self.predict = predict | |||||
| self.label = label | |||||
| self.index = 0 | |||||
| self.length = length | |||||
| def __iter__(self): | |||||
| return self | |||||
| def __next__(self): | |||||
| if self.index >= self.length: | |||||
| raise StopIteration | |||||
| self.index += 1 | |||||
| return self.predict, self.label | |||||
| def reset(self): | |||||
| self.index = 0 | |||||
| get_square_sum = C.MultitypeFuncGraph("get_square_sum") | |||||
| @get_square_sum.register("Tensor") | |||||
| def _get_square_sum(grad): | |||||
| norm = P.ReduceSum(False)(F.square(grad), ()) | |||||
| norm = F.expand_dims(F.cast(norm, mstype.float32), 0) | |||||
| return norm | |||||
| apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") | |||||
| @apply_global_norm.register("Tensor", "Tensor", "Tensor") | |||||
| def _apply_global_norm(clip_norm, global_norm, grad): | |||||
| grad = grad * clip_norm / global_norm | |||||
| return grad | |||||
| class GlobalNorm(Cell): | |||||
| """ | |||||
| Calculate the global norm value of given tensors | |||||
| """ | |||||
| def __init__(self): | |||||
| super(GlobalNorm, self).__init__() | |||||
| self.norm = Norm() | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, grads): | |||||
| square_sum = self.hyper_map(get_square_sum, grads) | |||||
| global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum))) | |||||
| return global_norms | |||||
| class ClipByGlobalNorm(Cell): | |||||
| """ | |||||
| Clip grads by global norm | |||||
| """ | |||||
| def __init__(self, clip_norm=1.0): | |||||
| super(ClipByGlobalNorm, self).__init__() | |||||
| self.global_norm = GlobalNorm() | |||||
| self.clip_norm = Tensor([clip_norm], mstype.float32) | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, grads): | |||||
| global_norm = self.global_norm(grads) | |||||
| cond = P.GreaterEqual()(global_norm, self.clip_norm) | |||||
| global_norm = F.select(cond, global_norm, self.clip_norm) | |||||
| grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads) | |||||
| return grads | |||||
| cast = P.Cast() | |||||
| update_accu_grads = C.MultitypeFuncGraph("update_accu_grads") | |||||
| @update_accu_grads.register("Tensor", "Tensor") | |||||
| def _update_accu_grads(accu_grad, grad): | |||||
| succ = True | |||||
| return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32))) | |||||
| zeroslike = P.ZerosLike() | |||||
| reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads") | |||||
| @reset_accu_grads.register("Tensor") | |||||
| def _reset_accu_grads(accu_grad): | |||||
| succ = True | |||||
| return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad))) | |||||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
| reciprocal = P.Reciprocal() | |||||
| @grad_scale.register("Tensor", "Tensor") | |||||
| def tensor_grad_scale(scale, grad): | |||||
| return grad * reciprocal(scale) | |||||
| class TrainAccumulateStepsWithLossScaleCell(Cell): | |||||
| """ | |||||
| Encapsulation class of bert network training. | |||||
| Append an optimizer to the training network after that the construct | |||||
| function can be called to create the backward graph. To mimic higher batch size, gradients are | |||||
| accumulated N times before weight update. | |||||
| Args: | |||||
| network (Cell): The training network. Note that loss function should have been added. | |||||
| optimizer (Optimizer): Optimizer for updating the weights. | |||||
| scale_update_cell (Cell): Cell to do the loss scale. Default: None. | |||||
| accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = | |||||
| batch_size * accumulation_steps. Default: 1. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4): | |||||
| super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_grad() | |||||
| self.weights = optimizer.parameters | |||||
| self.optimizer = optimizer | |||||
| self.accumulation_steps = accumulation_steps | |||||
| self.one = Tensor(np.array([1]).astype(np.int32)) | |||||
| self.zero = Tensor(np.array([0]).astype(np.int32)) | |||||
| self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") | |||||
| self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') | |||||
| self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) | |||||
| self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) | |||||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = F.identity | |||||
| self.degree = 1 | |||||
| if self.reducer_flag: | |||||
| self.degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.overflow_reducer = F.identity | |||||
| if self.is_distributed: | |||||
| self.overflow_reducer = P.AllReduce() | |||||
| self.cast = P.Cast() | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.logical_or = P.LogicalOr() | |||||
| self.not_equal = P.NotEqual() | |||||
| self.select = P.Select() | |||||
| self.reshape = P.Reshape() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, x, b, sens=None): | |||||
| """Defines the computation performed.""" | |||||
| weights = self.weights | |||||
| loss = self.network(x, b) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| # update accumulation parameters | |||||
| is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) | |||||
| self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) | |||||
| self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) | |||||
| mean_loss = self.accu_loss / self.local_step | |||||
| is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) | |||||
| # alloc status and clear should be right before gradoperation | |||||
| init = self.alloc_status() | |||||
| self.clear_before_grad(init) | |||||
| grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32)) | |||||
| accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) | |||||
| mean_loss = F.depend(mean_loss, accu_succ) | |||||
| self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| overflow = self.less_equal(self.base, flag_sum) | |||||
| overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) | |||||
| accu_overflow = self.select(overflow, self.one, self.zero) | |||||
| self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) | |||||
| is_accu_step = self.reshape(is_accu_step, (())) | |||||
| if is_accu_step: | |||||
| succ = False | |||||
| else: | |||||
| # apply grad reducer on grads | |||||
| grads = self.grad_reducer(self.accu_grads) | |||||
| scaling = scaling_sens * self.degree * self.accumulation_steps | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling), grads) | |||||
| grads = ClipByGlobalNorm()(grads) | |||||
| accu_overflow = self.overflow_reducer(accu_overflow) | |||||
| F.control_depend(grads, accu_overflow) | |||||
| overflow = self.less_equal(self.base, accu_overflow) | |||||
| accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) | |||||
| overflow = F.depend(overflow, accu_succ) | |||||
| overflow = self.reshape(overflow, (())) | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, overflow) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (mean_loss, overflow, scaling_sens) | |||||
| return F.depend(ret, succ) | |||||
| class Net(Cell): | |||||
| def __init__(self, weight, strategy=None): | |||||
| super().__init__() | |||||
| self.mul = P.Mul().shard(strategy) | |||||
| self.weight = Parameter(weight, "w1") | |||||
| self.relu = P.ReLU() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=True) | |||||
| def construct(self, x, b): | |||||
| out = self.mul(x, self.weight) | |||||
| out = self.relu(out) | |||||
| out = self.reduce_sum(out) | |||||
| return out | |||||
| _x = Tensor(np.ones([2]), dtype=ms.float32) | |||||
| _b = Tensor(np.ones([16]), dtype=ms.float32) | |||||
| _w1 = Tensor(np.ones([16]), dtype=ms.float32) | |||||
| def compile_net(net, grad_accumulation_step): | |||||
| context.set_context(save_graphs=True) | |||||
| learning_rate = 0.1 | |||||
| momentum = 0.9 | |||||
| epoch_size = 2 | |||||
| dataset = Dataset(_x, _b) | |||||
| opt = Momentum(net.trainable_params(), learning_rate, momentum) | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000) | |||||
| net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell, | |||||
| accumulation_steps=grad_accumulation_step) | |||||
| model = Model(net_wrap) | |||||
| model.train(epoch_size, dataset, dataset_sink_mode=False) | |||||
| context.reset_auto_parallel_context() | |||||
| def test_grad_accumulation(): | |||||
| grad_accumulation_step = 4 | |||||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, | |||||
| grad_accumulation_step=grad_accumulation_step) | |||||
| strategy = ((2,), (2,)) | |||||
| net = Net(_w1, strategy) | |||||
| compile_net(net, grad_accumulation_step) | |||||