From a70d61684136b256bb9e00b231f98ce00ecaa258 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Mon, 28 Dec 2020 13:58:34 +0800 Subject: [PATCH] mini step grad accumulation --- mindspore/ccsrc/frontend/optimizer/irpass.cc | 2 + mindspore/ccsrc/frontend/optimizer/irpass.h | 1 + .../optimizer/irpass/special_op_eliminate.h | 173 ++++++++++++------ mindspore/ccsrc/frontend/parallel/context.cc | 28 +-- mindspore/ccsrc/frontend/parallel/context.h | 13 +- .../parallel/ops_info/operator_info.cc | 63 ++++++- .../parallel/ops_info/operator_info.h | 8 +- .../frontend/parallel/ops_info/ops_utils.h | 2 + .../ccsrc/frontend/parallel/step_parallel.cc | 152 ++++++++------- .../ccsrc/frontend/parallel/step_parallel.h | 2 + mindspore/ccsrc/pipeline/jit/action.cc | 10 +- mindspore/ccsrc/pipeline/jit/pass.cc | 1 + mindspore/ccsrc/pybind_api/ir/tensor_py.cc | 2 +- mindspore/common/parameter.py | 40 ++-- mindspore/core/base/core_ops.h | 1 + mindspore/ops/_grad/grad_comm_ops.py | 53 ++++-- mindspore/ops/operations/__init__.py | 2 +- mindspore/ops/operations/comm_ops.py | 36 +++- .../python/parallel/test_grad_accumulation.py | 82 +++++---- 19 files changed, 447 insertions(+), 224 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 077dfe38f8..af1f6a5ca4 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -86,6 +86,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); mirror_mini_step_elim_ = MakeSubstitution(std::make_shared(), "mirror_mini_step_eliminate", prim::kPrimMirrorMiniStep); + mini_step_allgather_replace_ = MakeSubstitution(std::make_shared(), + "mini_step_allgather_replace", prim::kPrimMiniStepAllGather); check_bprop_eliminate_ = MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 26139fa938..1da28f1681 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -52,6 +52,7 @@ class OptimizeIRPassLib { SubstitutionPtr depend_value_elim_; SubstitutionPtr all_reduce_const_elim_; SubstitutionPtr mirror_mini_step_elim_; + SubstitutionPtr mini_step_allgather_replace_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index eed6b04565..41fb26c840 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -33,6 +33,7 @@ #include "utils/comm_manager.h" #include "frontend/parallel/context.h" #include "pipeline/jit/parse/resolve.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace opt { @@ -155,7 +156,7 @@ class CheckBpropEliminater : public AnfVisitor { AnfNodePtr x_{nullptr}; }; -// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X +// {prim::kPrimMirrorMiniStep, X, Z} -> X class MirrorMiniStepEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -163,11 +164,7 @@ class MirrorMiniStepEliminater : public AnfVisitor { return nullptr; } - auto cnode = node->cast(); - if (cnode == nullptr) { - return nullptr; - } - auto inputs = cnode->inputs(); + 
auto &inputs = node->cast()->inputs(); if (inputs.size() < 2) { return nullptr; } @@ -178,6 +175,32 @@ class MirrorMiniStepEliminater : public AnfVisitor { void Visit(const AnfNodePtr &) override {} }; +// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X} +class MiniStepAllGatherPass : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + auto prim = GetValueNode(node->cast()->input(0)); + MS_EXCEPTION_IF_NULL(prim); + auto attrs = prim->attrs(); + std::string group = attrs[parallel::GROUP]->ToString(); + parallel::Operator op = parallel::CreateAllGatherOp(group); + std::vector node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER); + auto func_graph = inputs[1]->func_graph(); + CNodePtr new_node = func_graph->NewCNode(node_input); + return new_node; + } + + void Visit(const AnfNodePtr &) override {} +}; + // Reset defer_inline flag class ResetDeferInline : public AnfVisitor { public: @@ -328,6 +351,80 @@ class PynativeEliminater : public OptimizerCaller { return out; } + private: + AnfNodePtr OperatorHandle1(const PatternNode &arg, const AnfNodePtr &node) { + auto rep = (arg).GetNode(node); + if (rep != nullptr) { + if (rep->isa()) { + auto value_node = rep->cast(); + auto new_value_node = NewValueNode(FillZero(value_node->value())); + new_value_node->set_has_new_value(value_node->has_new_value()); + MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4); + return new_value_node; + } + } + return nullptr; + } + + AnfNodePtr OperatorHandle2(const PatternNode &arg, const AnfNodePtr &node) { + auto rep = (arg).GetNode(node); + if (rep != nullptr) { + if (rep->isa() && !HasAbstractMonad(rep)) { + auto value_node = rep->cast(); + auto new_value_node = NewValueNode(FillZero(value_node->value())); + new_value_node->set_has_new_value(value_node->has_new_value()); + MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4); + return new_value_node; + } + } + return nullptr; + } + + void OperatorHandle3(const std::vector> &args, const AnfNodePtr &node) { + for (size_t i = 0; i < 2; i++) { + auto rep = (args[i]).GetNode(node); + if (rep != nullptr && rep->isa()) { + auto value_node = rep->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto &value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + // when the use count of value node equals to one, it only used in binop_grad_common function + if (value->isa() && value_node->used_graph_count() == 1) { + auto tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + auto new_tensor = std::make_shared(tensor->Dtype()->type_id(), tensor->shape()); + value_node->set_value(new_tensor); + } + } + } + } + + AnfNodePtr OperatorHandle4(const PatternNode &arg, const PatternNode &arg1, + const AnfNodePtr &node) { + auto rep = (arg).GetNode(node); + if (rep != nullptr) { + if (rep->isa()) { + MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4); + ValueNodePtr new_node; + auto value_node = rep->cast(); + auto rep1 = (arg1).GetNode(node); + if (rep1 != nullptr) { + if (rep1->isa()) { + auto idx = rep1->cast(); + if (!value_node->value()->isa()) { + return nullptr; + } + new_node = NewValueNode(FillGetItem(value_node->value(), idx->value())); + new_node->set_has_new_value(value_node->has_new_value()); + } + } + MS_LOG(DEBUG) << 
"Fill getitem replace ok " << new_node->DebugString(4); + return new_node; + } + } + return nullptr; + } + public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); @@ -342,15 +439,9 @@ class PynativeEliminater : public OptimizerCaller { if ((pattern).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { - auto rep = (arg).GetNode(node); - if (rep != nullptr) { - if (rep->isa()) { - auto value_node = rep->cast(); - auto new_value_node = NewValueNode(FillZero(value_node->value())); - new_value_node->set_has_new_value(value_node->has_new_value()); - MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4); - return new_value_node; - } + auto new_value_node = OperatorHandle1(arg, node); + if (new_value_node != nullptr) { + return new_value_node; } } MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4); @@ -360,15 +451,9 @@ class PynativeEliminater : public OptimizerCaller { if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { - auto rep = (arg).GetNode(node); - if (rep != nullptr) { - if (rep->isa() && !HasAbstractMonad(rep)) { - auto value_node = rep->cast(); - auto new_value_node = NewValueNode(FillZero(value_node->value())); - new_value_node->set_has_new_value(value_node->has_new_value()); - MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4); - return new_value_node; - } + auto new_value_node = OperatorHandle2(arg, node); + if (new_value_node != nullptr) { + return new_value_node; } } // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout} @@ -379,22 +464,7 @@ class PynativeEliminater : public OptimizerCaller { auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]); if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) { - for (size_t i = 0; i < 2; i++) { - auto rep = (args[i]).GetNode(node); - if (rep != nullptr && rep->isa()) { - auto value_node = rep->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto &value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - // when the use count of value node equals to one, it only used in binop_grad_common function - if (value->isa() && value_node->used_graph_count() == 1) { - auto tensor = value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - auto new_tensor = std::make_shared(tensor->Dtype()->type_id(), tensor->shape()); - value_node->set_value(new_tensor); - } - } - } + OperatorHandle3(args, node); return nullptr; } // resolve(CommonOPS, getitem)((tensors), 3) @@ -403,26 +473,9 @@ class PynativeEliminater : public OptimizerCaller { auto pattern2 = PCNode(resolve2, arg, arg1); if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") && CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) { - auto rep = (arg).GetNode(node); - if (rep != nullptr) { - if (rep->isa()) { - MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4); - ValueNodePtr new_node; - auto value_node = rep->cast(); - auto rep1 = (arg1).GetNode(node); - if (rep1 != nullptr) { - if (rep1->isa()) { - auto idx = rep1->cast(); - if 
(!value_node->value()->isa()) { - return nullptr; - } - new_node = NewValueNode(FillGetItem(value_node->value(), idx->value())); - new_node->set_has_new_value(value_node->has_new_value()); - } - } - MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4); - return new_node; - } + auto new_value_node = OperatorHandle4(arg, arg1, node); + if (new_value_node != nullptr) { + return new_value_node; } } diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index c03c64ce30..9e832b09be 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -153,25 +153,27 @@ const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const } // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { +void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { - return; + if (func_graph->has_flag(AUTO_PARALLEL) && + (!func_graph->has_flag(TRAINING) || + (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) { + init_param_shape_ = false; + } else { + param_shapes.clear(); + init_param_shape_ = true; } - param_shapes.clear(); } // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr) { +void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, + const ParameterPtr ¶m_node, AbstractBasePtr ptr) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(param_node); MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || - func_graph->has_flag(TRAINING)) { + if (init_param_shape_) { return; } - auto iter = param_shapes.find(param_node->name()); if (iter == param_shapes.end()) { MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); @@ -183,16 +185,16 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; } +// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr) { +void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(param_node); MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { + if (!init_param_shape_) { return; } - std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); auto ret = param_shapes.try_emplace(param_node->name(), shape); if (!ret.second) { diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index d4212dde42..910c371266 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -30,6 +30,7 @@ #include "ir/func_graph.h" 
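How the new flags fit together (a minimal sketch, mirroring the tests at the end of this patch; the network, strategy and device topology are illustrative): the same cell is compiled twice, once with the accumulation flag set (the ACCUMULATION graph, which checkpoints the parameter shapes and skips the mirror communication) and once without it (the update graph, which restores the recorded shapes and performs the communication).

    import numpy as np
    import mindspore as ms
    from mindspore import context, Tensor, Parameter
    from mindspore.nn import Cell
    from mindspore.ops import operations as P

    class Net(Cell):
        def __init__(self, weight):
            super(Net, self).__init__()
            self.mul = P.Mul().shard(((2,), (2,)))
            self.weight = Parameter(weight, "w1")

        def construct(self, x):
            return self.mul(x, self.weight)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      device_num=8, global_rank=0,
                                      grad_accumulation_step=4)
    w = Tensor(np.ones([16]), dtype=ms.float32)
    # accumulation graph: grads are folded into accu_grads, shapes checkpointed
    accu_net = Net(w).add_flags_recursive(accu=True)
    # update graph: init_param_shape_ is false, recorded shapes are restored
    update_net = Net(w).add_flags_recursive(accu=False)
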
#include "utils/convert_utils.h" #include "utils/info.h" +#include "pipeline/jit/pipeline.h" namespace mindspore { namespace parallel { @@ -43,6 +44,7 @@ constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; constexpr char TRAINING[] = "training"; +constexpr char ACCUMULATION[] = "accumulation"; class ParallelContext { public: @@ -111,6 +113,11 @@ class ParallelContext { bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } void Reset(); + void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph); + void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + AbstractBasePtr ptr); + void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr); private: ParallelContext(); @@ -136,13 +143,9 @@ class ParallelContext { std::string strategy_ckpt_save_file_; std::string group_ckpt_save_file_; bool enable_parallel_optimizer_; + bool init_param_shape_; }; -void ParallelParameterContextInit(const FuncGraphPtr &func_graph); -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr); -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index d7161020ab..d4cce6f6be 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -284,6 +284,39 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & return op; } +void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { + MS_EXCEPTION_IF_NULL(comm_node); + MS_EXCEPTION_IF_NULL(param_node); + if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { + MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; + return; + } + auto param = param_node->cast(); + MS_EXCEPTION_IF_NULL(param); + auto prim = GetValueNode(comm_node->input(0)); + MS_EXCEPTION_IF_NULL(prim); + auto attrs = prim->attrs(); + auto param_info = param->param_info(); + if (!param_info) { + MS_LOG(WARNING) << param->ToString() << "does not have parameter info."; + return; + } + int32_t fusion_type = param_info->comm_fusion(); + attrs[FUSION] = MakeValue(fusion_type); + prim->SetAttrs(attrs); + MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; +} + +void AddCommOpMeanFlag(const CNodePtr &comm_node) { + MS_EXCEPTION_IF_NULL(comm_node); + auto prim = GetValueNode(comm_node->input(0)); + auto attrs = prim->attrs(); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); + attrs[MEAN_FLAG] = MakeValue(mean_flag); + prim->SetAttrs(attrs); +} + Operator CreateAllGatherOp(const std::string &group) { OperatorName operator_name = ALL_GATHER; ValuePtr attr0_value = MakeValue(group); // group @@ -299,6 +332,30 @@ Operator CreateAllGatherOp(const std::string &group) { return op; } +Operator CreateMiniStepAllGatherOp(const std::string &group) { + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + bool mean_flag = 
ParallelContext::GetInstance()->gradients_mean(); + + OperatorName operator_name = MINI_STEP_ALL_GATHER; + ValuePtr attr0_value = MakeValue(group); // group + Attr attr0 = std::make_pair(GROUP, attr0_value); + ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step + Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value); + ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag + Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + operator_attrs.push_back(attr2); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group; + return op; +} + // use for get tensor slice Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { Shape tensor_map = tensor_layout.tensor_map().array(); @@ -771,7 +828,7 @@ void OperatorInfo::ComputeBatchSplitFlagList() { ReComputeBatchSplitFlagList(); } -// This is a common method for checking whether the generated stragegy has the correct number of devuces. +// This is a common method for checking whether the generated strategy has the correct number of devuces. Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { if (sp == nullptr) { MS_LOG(ERROR) << "The strategy is null."; @@ -886,7 +943,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs (void)input0_strategy.erase(input0_strategy.begin(), input0_strategy.begin() + static_cast(size_diff)); - // handel the case likes ([1, c, d], [a, b, c, d]) + // handle the case likes ([1, c, d], [a, b, c, d]) for (size_t i = 0; i < inputs_shape[0].size(); ++i) { if (inputs_shape[0][i] == 1) { input0_strategy[i] = 1; @@ -937,7 +994,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input (void)input1_strategy.erase(input1_strategy.begin(), input1_strategy.begin() + static_cast(size_diff)); - // handel the case likes ([a, b, c, d], [1, c, d]) + // handle the case likes ([a, b, c, d], [1, c, d]) for (size_t i = 0; i < inputs_shape[1].size(); ++i) { if (inputs_shape[1][i] == 1) { input1_strategy[i] = 1; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index b49ad00d0f..f5866f6b6b 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -36,6 +36,7 @@ #include "frontend/parallel/strategy.h" #include "frontend/parallel/tensor_layout/tensor_info.h" #include "utils/log_adapter.h" +#include "base/core_ops.h" namespace mindspore { namespace parallel { @@ -160,7 +161,7 @@ class OperatorInfo { void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated - // multiple times. This method is to correct this, and makes the cost is calulated only once. + // multiple times. This method is to correct this, and makes the cost is calculated only once. 
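For reference, the attribute triple assembled here is the same one the Python primitive added to comm_ops.py below accepts directly. A hedged construction sketch, assuming the distributed backend and a communication group are already initialized (the group name is illustrative):

    from mindspore.ops.operations.comm_ops import _MiniStepAllGather

    # Mirrors the attrs set by CreateMiniStepAllGatherOp: GROUP,
    # GRAD_ACCUMULATION_STEP and MEAN_FLAG. __init__ additionally records
    # rank, rank_size and a default fusion of 1.
    gather = _MiniStepAllGather(group="hccl_world_group",
                                grad_accumulation_step=4,
                                mean_flag=True)
    # The forward infer_shape concatenates along axis 0, so an input shard
    # of shape [s, ...] yields [s * rank_size, ...].
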
Status CorrectMemoryCost(size_t input_index); int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; } int64_t is_output_critical() const { return is_output_critical_; } @@ -242,7 +243,7 @@ class OperatorInfo { bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected. std::vector corrected_input_indices_; - // Given a parallization strategy, there is a cost. + // Given a parallelization strategy, there is a cost. std::vector> strategy_cost_; // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter std::vector is_parameter_; @@ -288,6 +289,9 @@ Operator CreateVirtualDivOp(int64_t div_num); Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); Operator CreateAllGatherOp(const std::string &group); +Operator CreateMiniStepAllGatherOp(const std::string &group); +void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); +void AddCommOpMeanFlag(const CNodePtr &comm_node); Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 8f6b4e80e0..9a97e611db 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -109,6 +109,7 @@ constexpr char END[] = "end"; constexpr char STRIDES[] = "strides"; constexpr char GROUP[] = "group"; constexpr char FUSION[] = "fusion"; +constexpr char DO_MIRROR[] = "do_mirror"; constexpr char NUM_SAMPLED[] = "num_sampled"; constexpr char NUM_TRUE[] = "num_true"; constexpr char SEED[] = "seed"; @@ -180,6 +181,7 @@ constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator"; constexpr char LOCAL_STEP[] = "local_step"; constexpr char STRIDED_SLICE[] = "StridedSlice"; constexpr char ALL_GATHER[] = "AllGather"; +constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather"; constexpr char REDUCE_SCATTER[] = "ReduceScatter"; constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 3395115c6b..62246f3cb3 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -65,8 +65,8 @@ void SetCommunicationOpGroupLabel(std::vector new_node_input) { return; } - ValueNodePtr prim_anf_node = new_node_input[0]->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); + auto prim_anf_node = new_node_input[0]->cast(); + auto prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); auto attrs = prim->attrs(); @@ -83,6 +83,19 @@ void SetCommunicationOpGroupLabel(std::vector new_node_input) { } } +void SetMiniStepOpDoMirrorLabel(std::vector new_node_input, bool accu_flag) { + if (new_node_input.empty()) { + return; + } + auto prim_anf_node = new_node_input[0]->cast(); + auto prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + + auto attrs = prim->attrs(); + attrs[DO_MIRROR] = 
MakeValue(!accu_flag); + prim->SetAttrs(attrs); +} + std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { MS_EXCEPTION_IF_NULL(node); OperatorArgs arg_forward = op.second; @@ -157,7 +170,6 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(root->manager()); - AnfNodePtr local_step_param = nullptr; AnfNodePtr grad_accu = nullptr; std::string op_name = op.first; OperatorArgs arg_forward = op.second; @@ -165,25 +177,7 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); if (grad_accumulation_step > 1) { - bool find_locat_step_node = false; auto parameters = root->parameters(); - for (auto ¶m : parameters) { - auto param_ptr = param->cast(); - MS_EXCEPTION_IF_NULL(param_ptr); - if (param_ptr->name() == LOCAL_STEP) { - auto param_users = root->manager()->node_users()[param]; - for (auto &user : param_users) { - if (AnfNodeIsPrimitive(user.first, ASSIGN)) { - find_locat_step_node = true; - local_step_param = user.first; - MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode"; - break; - } - } - break; - } - } - bool find_grad_accu_node = false; for (auto ¶m : parameters) { if (!ParameterIsCloned(param)) { @@ -201,10 +195,12 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat } } - if (op_name == MIRROR_MINI_STEP_OPERATOR) { - if (!find_locat_step_node || !find_grad_accu_node) { + if (!find_grad_accu_node) { + if (op_name == MIRROR_MINI_STEP_OPERATOR) { op_name = MIRROR_OPERATOR; arg_forward.first.pop_back(); + } else if (op_name == MINI_STEP_ALL_GATHER) { + MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation."; } } } @@ -214,9 +210,9 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat OperatorParams params = arg_forward.second; std::vector new_node_input; - if (op_name == MIRROR_MINI_STEP_OPERATOR) { - new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu}; - MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input"; + if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) { + new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; + MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; } else { new_node_input = {NewValueNode(pyop_instance), node}; } @@ -232,6 +228,10 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat // if the op have 'group' attr, set the rank list name for the op SetCommunicationOpGroupLabel(new_node_input); + // gradient accumulation + if (grad_accumulation_step > 1) { + SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION)); + } return new_node_input; } @@ -284,6 +284,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons return new_node; } +// Replace pre_node with pre_node->op +static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name, + const std::string ¶m_name) { + // insert new node before the node + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ScopePtr scope = pre_node->scope(); + MS_EXCEPTION_IF_NULL(scope); + std::vector node_input = CreateMirrorInput(root, op, 
pre_node, instance_name, param_name); + CNodePtr new_node = func_graph->NewCNode(node_input); + MS_EXCEPTION_IF_NULL(new_node); + if (instance_name.find(SPLIT_SENS) == std::string::npos) { + new_node->set_in_forward_flag(true); // mark forward flag + } + auto new_node_prim = GetValueNode(node_input[0]); + new_node_prim->set_instance_name(instance_name); + new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); + new_node->set_scope(scope); + node_input[0]->set_scope(scope); + manager->Replace(pre_node, new_node); + MS_LOG(INFO) << "Insert " << instance_name << " success"; + return new_node; +} + std::string CreateInstanceName(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); if (!IsValueNode(node->input(0))) { @@ -1085,29 +1110,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) { return (type_id != kNumberTypeFloat32); } -static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { - MS_EXCEPTION_IF_NULL(comm_node); - MS_EXCEPTION_IF_NULL(param_node); - if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { - MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; - return; - } - auto param = param_node->cast(); - MS_EXCEPTION_IF_NULL(param); - auto prim = GetValueNode(comm_node->input(0)); - MS_EXCEPTION_IF_NULL(prim); - auto attrs = prim->attrs(); - auto param_info = param->param_info(); - if (!param_info) { - MS_LOG(WARNING) << param->ToString() << "does not have parameter info."; - return; - } - int32_t fusion_type = param_info->comm_fusion(); - attrs[FUSION] = MakeValue(fusion_type); - prim->SetAttrs(attrs); - MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; -} - static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) { if ((node->inputs().size() == 2) && (IsValueNode(node->input(1)))) { MS_LOG(INFO) << "Input is ValueList, skip it."; @@ -1194,7 +1196,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); auto comm_op = cnode->input(size_t(1))->cast(); // add fusion flag - // pipeline mirror would not be set, which should be supported later AddCommOpFusionType(comm_op, param_node_pair.first); } continue; @@ -1539,33 +1540,40 @@ std::pair FindSubGraph(const FuncGraphPtr &graph, const Anf return std::make_pair(nullptr, 0); } -static void InsertAllGatherOp(const std::string &group, const std::pair &res, - const AnfNodePtr ¶meter) { - Operator op = CreateAllGatherOp(group); +static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair &res, + const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(res.first); - MS_EXCEPTION_IF_NULL(parameter); + MS_EXCEPTION_IF_NULL(node); auto cnode = res.first->cast(); auto graph = cnode->func_graph(); MS_EXCEPTION_IF_NULL(graph); auto cnode_prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(cnode_prim); + int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + Operator op; CNodePtr allgather; - if (cnode_prim->name() == CAST) { - allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); + if (grad_accumulation_step > 1) { + op = CreateMiniStepAllGatherOp(group); + auto param_name = node->cast()->name(); + if (cnode_prim->name() == CAST) { + allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name); + } else { + 
InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name); + allgather = cnode->input(res.second)->cast(); + } } else { - InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); - allgather = cnode->input(res.second)->cast(); + op = CreateAllGatherOp(group); + if (cnode_prim->name() == CAST) { + allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); + } else { + InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER); + allgather = cnode->input(res.second)->cast(); + } } - MS_EXCEPTION_IF_NULL(allgather); // add fusion flag - AddCommOpFusionType(allgather, parameter); + AddCommOpFusionType(allgather, node); // add gradients mean - auto prim = GetValueNode(allgather->input(0)); - auto attrs = prim->attrs(); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); - attrs["mean_flag"] = MakeValue(mean_flag); - prim->SetAttrs(attrs); + AddCommOpMeanFlag(allgather); } static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, @@ -1588,7 +1596,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & << distribute_operator->inputs_tensor_info().size(); } // insert allgather operator between shard parameter and cnode - InsertAllGatherOp(opt_shard_group, param_pair, parameter); + InsertAllGatherOp(root, opt_shard_group, param_pair, parameter); MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString(); } } @@ -1733,12 +1741,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { if (found_be_cloned_parameter) { // set the shape and tensor layout for cloned parameter + std::string param_name = cloned_parameter_node->cast()->name(); cloned_parameter->set_user_data(cloned_from_parameter->user_data()); MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); + if (param_name.find(ACCU_GRADS) != std::string::npos) { + auto slice_shape = cloned_from_parameter->user_data()->slice_shape().array(); + std::shared_ptr parallel_shape = std::make_shared(slice_shape); + MS_EXCEPTION_IF_NULL(parallel_shape); + cloned_abstract->set_shape(parallel_shape); + } else { + cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); + } cloned_parameter_node->set_abstract(cloned_abstract); MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 2926bba2dd..0d80760d0b 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -30,6 +30,8 @@ #include "frontend/parallel/strategy.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" #include "pipeline/jit/pipeline.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" using OperatorInfoPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index ce8e692d6b..3bc6297e4a 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ 
b/mindspore/ccsrc/pipeline/jit/action.cc @@ -258,9 +258,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); abstract::AbstractBasePtrList args_spec = res->args_spec(); - - parallel::ParallelParameterContextInit(func_graph); - + auto context = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); + context->ParallelParameterContextInitShape(func_graph); // suppose that there is not KeywordArgument for the top graph // get the hyper parameter for (const auto ¶m : func_graph->parameters()) { @@ -271,9 +271,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { auto ref_key = std::make_shared(param_node->name()); auto abs_ref_key = ref_key->ToAbstract(); auto abs_ref = std::make_shared(abs_ref_key, abs_value); - parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref); + context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref); args_spec.push_back(abs_ref); - parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref); + context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref); } } // Analyze diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 69361d6671..7e37806c22 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.replace_applicator_, irpass.mirror_mini_step_elim_, irpass.row_tensor_add_zeros_like_, + irpass.mini_step_allgather_replace_, }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 204db5479f..81c592926a 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -374,7 +374,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { .def(py::init(), py::arg("dtype"), py::arg("shape")) .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") - .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) + .def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) .def(py::pickle( [](const MetaTensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 1cd31597da..57fb005aa8 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -134,7 +134,7 @@ class Parameter(Tensor_): Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True): - self._param_info = ParamInfo() + self.param_info = ParamInfo() self.init_in_server = False self.cache_enable = False self.name = name @@ -230,7 +230,7 @@ class Parameter(Tensor_): "sparse operator support initialization in server.".format(self.name)) self.is_param_ps = True self.init_in_server = init_in_server - self._param_info.init_in_server = init_in_server + self.param_info.init_in_server = init_in_server @property def inited_param(self): @@ -245,7 +245,7 @@ class 
Parameter(Tensor_): @property def name(self): """Get the name of the parameter.""" - return self._param_info.name + return self.param_info.name @name.setter def name(self, name_): @@ -272,9 +272,9 @@ class Parameter(Tensor_): if len(self.shape) != 2: raise RuntimeError("The dims of parameter '{}' must be 2, but got {}." .format(self.name, len(self.shape))) - _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1]) + _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1]) - self._param_info.name = name_ + self.param_info.name = name_ @property def sliced(self): @@ -288,12 +288,12 @@ class Parameter(Tensor_): @property def comm_fusion(self): """Get the fusion type for communication operators corresponding to this parameter.""" - return self._param_info.comm_fusion + return self.param_info.comm_fusion @comm_fusion.setter def comm_fusion(self, comm_fusion_): """Set the fusion type for communication operators corresponding to this parameter.""" - self._param_info.comm_fusion = comm_fusion_ + self.param_info.comm_fusion = comm_fusion_ @property def unique(self): @@ -339,7 +339,7 @@ class Parameter(Tensor_): """ x = copy(self) # pylint: disable=protected-access - x._param_info = self._param_info.clone() + x.param_info = self.param_info.clone() x.is_init = False x.init = self.init x.is_param_ps = self.is_param_ps @@ -355,57 +355,57 @@ class Parameter(Tensor_): @property def layerwise_parallel(self): - return self._param_info.layerwise_parallel + return self.param_info.layerwise_parallel @layerwise_parallel.setter def layerwise_parallel(self, value=True): if not isinstance(value, bool): raise TypeError("`layerwise_parallel` parameter must be bool type") - self._param_info.layerwise_parallel = value + self.param_info.layerwise_parallel = value @property def parallel_optimizer(self): """Return whether the parameter requires weight shard for parallel optimizer.""" - return self._param_info.parallel_optimizer + return self.param_info.parallel_optimizer @parallel_optimizer.setter def parallel_optimizer(self, value=True): if not isinstance(value, bool): raise TypeError("`parallel_optimizer` parameter must be bool type") - self._param_info.parallel_optimizer = value + self.param_info.parallel_optimizer = value @property def cache_enable(self): """Return whether the parameter is cache enable.""" - return self._param_info.cache_enable + return self.param_info.cache_enable @cache_enable.setter def cache_enable(self, value=True): if not isinstance(value, bool): raise TypeError("`cache_enable` parameter must be bool type") - self._param_info.cache_enable = value + self.param_info.cache_enable = value @property def cache_shape(self): """Return the cache shape corresponding to the parameter if use cache.""" - return self._param_info.cache_shape + return self.param_info.cache_shape @cache_shape.setter def cache_shape(self, value): if not isinstance(value, (tuple, list)): raise TypeError("`cache_shape` parameter must be tuple or list type") - self._param_info.cache_shape = value + self.param_info.cache_shape = value @property def requires_grad(self): """Return whether the parameter requires gradient.""" - return self._param_info.requires_grad + return self.param_info.requires_grad @requires_grad.setter def requires_grad(self, value=True): if not isinstance(value, bool): raise TypeError("`requires_grad` parameter must be bool type") - self._param_info.requires_grad = value + self.param_info.requires_grad = value @property def data(self): @@ -419,7 +419,9 @@ 
class Parameter(Tensor_): self.init = None return self.assign_value(data) # create a new tensor - return Parameter(data, self.name, self.requires_grad) + new_param = Parameter(data, self.name, self.requires_grad) + new_param.param_info = self.param_info + return new_param def set_data(self, data, slice_shape=False): """ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 39c1bf92f4..37c4a4f5ad 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -304,6 +304,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared("_MirrorOperator"); inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared("_MirrorMiniStepOperator"); +inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared("_MiniStepAllGather"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); inline const PrimitivePtr kPrimSend = std::make_shared("Send"); diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 3c12a1c294..425764f8d6 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -20,7 +20,7 @@ from mindspore.communication import get_rank, get_group_size from .. import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like -from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, +from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) from .grad_base import bprop_getters @@ -150,6 +150,39 @@ def get_bprop_all_gather(self): return bprop +@bprop_getters.register(_MiniStepAllGather) +def get_bprop_mini_step_all_gather(self): + """Generate bprop for _MiniStepAllGather""" + fusion = self.get_attr_dict()["fusion"] + mean_flag = self.get_attr_dict()["mean_flag"] + do_mirror = self.get_attr_dict()["do_mirror"] + scale = 1 / self.rank_size + all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) + if self.instance_name: + instance_name = "grad_" + self.instance_name + all_reduce.set_prim_instance_name(instance_name) + rank = get_rank(self.group) + dev_num = get_group_size(self.group) + split = P.Split(output_num=dev_num) + + def bprop(x, z, out, dout): + if do_mirror: + if mean_flag: + tmp = z + dout + grad = all_reduce(tmp) + dx = split(grad)[rank] + dx = F.tensor_mul(dx, scale) + else: + tmp = z + dout + grad = all_reduce(tmp) + dx = split(grad)[rank] + else: + dx = dout + return (dx, zeros_like(z)) + + return bprop + + @bprop_getters.register(_HostAllGather) def get_bprop_host_all_gather(self): """Generate bprop for _HostAllGather""" @@ -291,18 +324,13 @@ def get_bprop_mirror_mini_step_operator(self): group = self.group dev_num = self.dev_num mean_flag = self.mean_flag - grad_accumulation_step = self.grad_accumulation_step all_reduce = AllReduce(group=group) all_gather = AllGather(group=group) mul = P.Mul() cast = P.Cast() - equal = P.Equal() - reshape = P.Reshape() - fusion = 1 - if hasattr(self, 'fusion'): - fusion = self.fusion + fusion = self.get_attr_dict()["fusion"] all_reduce.add_prim_attr("fusion", fusion) if hasattr(self, 'parameter'): parameter = self.parameter @@ -311,16 +339,15 @@ def get_bprop_mirror_mini_step_operator(self): 
if self.instance_name: instance_name = "grad_mirror" + self.instance_name all_reduce.set_prim_instance_name(instance_name) + do_mirror = self.get_attr_dict()["do_mirror"] - def bprop(x, y, z, out, dout): - do_mirror = equal(y, grad_accumulation_step) - do_mirror = reshape(do_mirror, (())) + def bprop(x, z, out, dout): if mean_flag: if F.issubclass_(F.typeof(dout), mstype.tensor): if do_mirror: tmp = z + dout real_grad = all_reduce(tmp) - dx = real_grad - z + dx = real_grad else: dx = dout float_one = F.scalar_cast(1.0, F.dtype(dx)) @@ -342,7 +369,7 @@ def get_bprop_mirror_mini_step_operator(self): if do_mirror: tmp = z + dout real_grad = all_reduce(tmp) - dx = real_grad - z + dx = real_grad else: dx = dout else: @@ -354,7 +381,7 @@ def get_bprop_mirror_mini_step_operator(self): grad = dout.values dx = RowTensor(indices, grad, dout.dense_shape) - return (dx, zeros_like(y), zeros_like(z)) + return (dx, zeros_like(z)) return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index ae3134358f..5622650043 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, Unique, GatherD, Identity, Range) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, - _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, + _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, _HostAllGather, _HostReduceScatter) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 25d6772510..13ee156730 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -200,6 +200,38 @@ class AllGather(PrimitiveWithInfer): raise NotImplementedError +class _MiniStepAllGather(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for + internal use of parallel modules and cannot be called by users. + + Args: + group (str): The communication group to work on. Default: None. + grad_accumulation_step (int): The grad accumulation step. Default: None. + """ + @prim_attr_register + def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): + validator.check_value_type('group', _get_group(group), (str,), self.name) + self.rank = get_rank(_get_group(group)) + self.rank_size = get_group_size(_get_group(group)) + validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) + self.add_prim_attr('rank_size', self.rank_size) + self.add_prim_attr('group', _get_group(group)) + self.add_prim_attr('fusion', 1) + self.grad_accumulation_step = grad_accumulation_step + self.mean_flag = mean_flag + + def infer_shape(self, x_shape, z_shape): + validator.check_positive_int(len(x_shape), "x shape", self.name) + if x_shape[0] > 0: + x_shape[0] = x_shape[0] * self.rank_size + return x_shape + + def infer_dtype(self, x_dtype, z_shape): + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) + return x_dtype + + class _HostAllGather(PrimitiveWithInfer): """ Gathers tensors from the specified communication group on host. 
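The two bprops above share one scheme: while do_mirror is false (the ACCUMULATION graph; SetMiniStepOpDoMirrorLabel sets do_mirror = !accu), the backward just passes dout through so it can be folded into accu_grads locally; on the boundary step it reduces accu_grads plus the last dout across the group. A self-contained NumPy emulation of the _MiniStepAllGather backward, assuming 4 ranks, with AllReduce emulated by summing per-rank buffers and Split by slicing the reduced gradient:

    import numpy as np

    dev_num, shard = 4, 2                        # 4 ranks, each owns a length-2 shard
    douts = [np.full(dev_num * shard, r + 1.0)   # this step's dout on rank r
             for r in range(dev_num)]
    accu = [np.zeros(dev_num * shard)            # accu_grads buffer (z) on rank r
            for r in range(dev_num)]

    def mini_step_bprop(rank, do_mirror, mean_flag=True):
        if not do_mirror:                        # still accumulating: no communication,
            return douts[rank]                   # dx = dout, folded into accu_grads
        # boundary step: AllReduce(z + dout), then keep this rank's slice
        grad = sum(a + d for a, d in zip(accu, douts))
        dx = np.split(grad, dev_num)[rank]       # Split(output_num=dev_num)[rank]
        return dx / dev_num if mean_flag else dx  # scale = 1 / rank_size

    print(mini_step_bprop(0, do_mirror=False))   # raw dout: [1. 1. ... 1.]
    print(mini_step_bprop(0, do_mirror=True))    # [2.5 2.5]: (1+2+3+4)/4, rank 0's slice
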
@@ -590,10 +622,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer): self.mean_flag = mean_flag self.grad_accumulation_step = grad_accumulation_step - def infer_shape(self, x_shape, y_shape, z_shape): + def infer_shape(self, x_shape, z_shape): return x_shape - def infer_dtype(self, x_dtype, y_shape, z_shape): + def infer_dtype(self, x_dtype, z_shape): return x_dtype diff --git a/tests/ut/python/parallel/test_grad_accumulation.py b/tests/ut/python/parallel/test_grad_accumulation.py index 7bee3b8fe9..3b1bb48b73 100644 --- a/tests/ut/python/parallel/test_grad_accumulation.py +++ b/tests/ut/python/parallel/test_grad_accumulation.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,15 +17,14 @@ import numpy as np import mindspore as ms import mindspore.common.dtype as mstype from mindspore import context, Tensor, Parameter -from mindspore.nn import Cell, Momentum, Norm from mindspore.train import Model from mindspore.ops import operations as P from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.initializer import initializer -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.context import ParallelMode - +from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm +from mindspore.parallel._utils import _get_device_num from tests.dataset_mock import MindData @@ -142,29 +141,29 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = batch_size * accumulation_steps. Default: 1. 
""" - def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4): + def __init__(self, network, optimizer, scale_update_cell=None): super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) + self.accu = False + self.is_accu_step = Tensor(np.array([self.accu])) self.network = network self.network.set_grad() self.weights = optimizer.parameters self.optimizer = optimizer - self.accumulation_steps = accumulation_steps + self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step") self.one = Tensor(np.array([1]).astype(np.int32)) self.zero = Tensor(np.array([0]).astype(np.int32)) - self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) - - self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.reducer_flag = False + self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.parallel_mode = context.get_auto_parallel_context("parallel_mode") if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: self.reducer_flag = True - self.grad_reducer = F.identity self.degree = 1 + self.grad_reducer = F.identity if self.reducer_flag: - self.degree = get_group_size() + self.degree = _get_device_num() self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.overflow_reducer = F.identity @@ -197,34 +196,27 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): else: scaling_sens = sens - # update accumulation parameters - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) - self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) - mean_loss = self.accu_loss / self.local_step - is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - # alloc status and clear should be right before gradoperation init = self.alloc_status() self.clear_before_grad(init) grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32)) - accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) - mean_loss = F.depend(mean_loss, accu_succ) + if self.is_accu_step and self.accumulation_steps > 1: + accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) + loss = F.depend(loss, accu_succ) self.get_status(init) flag_sum = self.reduce_sum(init, (0,)) overflow = self.less_equal(self.base, flag_sum) overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) accu_overflow = self.select(overflow, self.one, self.zero) - self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) - is_accu_step = self.reshape(is_accu_step, (())) + self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero) - if is_accu_step: + if self.is_accu_step: succ = False else: # apply grad reducer on grads - grads = self.grad_reducer(self.accu_grads) + grads = self.grad_reducer(grads) scaling = scaling_sens * self.degree * self.accumulation_steps grads = self.hyper_map(F.partial(grad_scale, scaling), grads) grads = ClipByGlobalNorm()(grads) @@ -241,7 +233,7 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): else: succ = self.optimizer(grads) - ret = (mean_loss, overflow, 
scaling_sens) + ret = (loss, overflow, scaling_sens) return F.depend(ret, succ) @@ -265,25 +257,51 @@ _b = Tensor(np.ones([16]), dtype=ms.float32) _w1 = Tensor(np.ones([16]), dtype=ms.float32) -def compile_net(net, grad_accumulation_step): - context.set_context(save_graphs=True) +def compile_net(net): + context.set_context(enable_sparse=False) learning_rate = 0.1 momentum = 0.9 epoch_size = 2 dataset = Dataset(_x, _b) opt = Momentum(net.trainable_params(), learning_rate, momentum) update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000) - net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell, - accumulation_steps=grad_accumulation_step) + net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell) model = Model(net_wrap) model.train(epoch_size, dataset, dataset_sink_mode=False) context.reset_auto_parallel_context() -def test_grad_accumulation(): +def test_grad_accumulation_accu(): grad_accumulation_step = 4 context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, grad_accumulation_step=grad_accumulation_step) strategy = ((2,), (2,)) - net = Net(_w1, strategy) - compile_net(net, grad_accumulation_step) + net = Net(_w1, strategy).add_flags_recursive(accu=True) + compile_net(net) + + +def test_grad_accu_and_opt_shard_accu(): + grad_accumulation_step = 4 + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, + grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True) + strategy = ((2,), (2,)) + net = Net(_w1, strategy).add_flags_recursive(accu=True) + compile_net(net) + + +def test_grad_accumulation_not_accu(): + grad_accumulation_step = 4 + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, + grad_accumulation_step=grad_accumulation_step) + strategy = ((2,), (2,)) + net = Net(_w1, strategy).add_flags_recursive(accu=False) + compile_net(net) + + +def test_grad_accu_and_opt_shard_not_accu(): + grad_accumulation_step = 4 + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, + grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True) + strategy = ((2,), (2,)) + net = Net(_w1, strategy).add_flags_recursive(accu=False) + compile_net(net)
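
Putting the pieces together, the runtime pattern these tests exercise looks roughly as follows; this is a sketch under the assumption that the caller drives the two compiled graphs explicitly, and the helper names are hypothetical, not MindSpore API:

    # Two compiled graphs share the weights and the accu_grads buffers.
    def train_one_global_batch(micro_batches, run_accu_graph, run_update_graph):
        # accu=True graph: do_mirror is False, so backward only folds local
        # gradients into accu_grads -- no AllReduce per micro-batch.
        *head, last = micro_batches
        for data in head:
            run_accu_graph(data)
        # accu=False graph: backward reduces accu_grads + the last dout,
        # the optimizer applies the result, and the buffers are cleared.
        return run_update_graph(last)

The payoff is one collective per global batch instead of one per micro-batch, which is exactly what the do_mirror label and the mini_step_allgather_replace pass arrange.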