| @@ -86,6 +86,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape); | |||
| mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate", | |||
| prim::kPrimMirrorMiniStep); | |||
| mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(), | |||
| "mini_step_allgather_replace", prim::kPrimMiniStepAllGather); | |||
| check_bprop_eliminate_ = | |||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = | |||
| @@ -52,6 +52,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr depend_value_elim_; | |||
| SubstitutionPtr all_reduce_const_elim_; | |||
| SubstitutionPtr mirror_mini_step_elim_; | |||
| SubstitutionPtr mini_step_allgather_replace_; | |||
| // Env Item Eliminate | |||
| SubstitutionPtr env_get_item_eliminate_; | |||
| @@ -33,6 +33,7 @@ | |||
| #include "utils/comm_manager.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "pipeline/jit/parse/resolve.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -155,7 +156,7 @@ class CheckBpropEliminater : public AnfVisitor { | |||
| AnfNodePtr x_{nullptr}; | |||
| }; | |||
| // {prim::kPrimMirrorMiniStep, X, Y, Z} -> X | |||
| // {prim::kPrimMirrorMiniStep, X, Z} -> X | |||
| class MirrorMiniStepEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| @@ -163,11 +164,7 @@ class MirrorMiniStepEliminater : public AnfVisitor { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto inputs = cnode->inputs(); | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (inputs.size() < 2) { | |||
| return nullptr; | |||
| } | |||
| @@ -178,6 +175,32 @@ class MirrorMiniStepEliminater : public AnfVisitor { | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X} | |||
| class MiniStepAllGatherPass : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (inputs.size() < 2) { | |||
| return nullptr; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| std::string group = attrs[parallel::GROUP]->ToString(); | |||
| parallel::Operator op = parallel::CreateAllGatherOp(group); | |||
| std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER); | |||
| auto func_graph = inputs[1]->func_graph(); | |||
| CNodePtr new_node = func_graph->NewCNode(node_input); | |||
| return new_node; | |||
| } | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // Reset defer_inline flag | |||
| class ResetDeferInline : public AnfVisitor { | |||
| public: | |||
| @@ -328,6 +351,80 @@ class PynativeEliminater : public OptimizerCaller { | |||
| return out; | |||
| } | |||
| private: | |||
| AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto new_value_node = NewValueNode(FillZero(value_node->value())); | |||
| new_value_node->set_has_new_value(value_node->has_new_value()); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4); | |||
| return new_value_node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto new_value_node = NewValueNode(FillZero(value_node->value())); | |||
| new_value_node->set_has_new_value(value_node->has_new_value()); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4); | |||
| return new_value_node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) { | |||
| for (size_t i = 0; i < 2; i++) { | |||
| auto rep = (args[i]).GetNode(node); | |||
| if (rep != nullptr && rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
// When the use count of the value node equals one, it is only used in the binop_grad_common function
| if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) { | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape()); | |||
| value_node->set_value(new_tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1, | |||
| const AnfNodePtr &node) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4); | |||
| ValueNodePtr new_node; | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto rep1 = (arg1).GetNode(node); | |||
| if (rep1 != nullptr) { | |||
| if (rep1->isa<ValueNode>()) { | |||
| auto idx = rep1->cast<ValueNodePtr>(); | |||
| if (!value_node->value()->isa<ValueTuple>()) { | |||
| return nullptr; | |||
| } | |||
| new_node = NewValueNode(FillGetItem(value_node->value(), idx->value())); | |||
| new_node->set_has_new_value(value_node->has_new_value()); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4); | |||
| return new_node; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); | |||
| @@ -342,15 +439,9 @@ class PynativeEliminater : public OptimizerCaller { | |||
| if ((pattern).TryCapture(node) && | |||
| (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto new_value_node = NewValueNode(FillZero(value_node->value())); | |||
| new_value_node->set_has_new_value(value_node->has_new_value()); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4); | |||
| return new_value_node; | |||
| } | |||
| auto new_value_node = OperatorHandle1(arg, node); | |||
| if (new_value_node != nullptr) { | |||
| return new_value_node; | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4); | |||
| @@ -360,15 +451,9 @@ class PynativeEliminater : public OptimizerCaller { | |||
| if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto new_value_node = NewValueNode(FillZero(value_node->value())); | |||
| new_value_node->set_has_new_value(value_node->has_new_value()); | |||
| MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4); | |||
| return new_value_node; | |||
| } | |||
| auto new_value_node = OperatorHandle2(arg, node); | |||
| if (new_value_node != nullptr) { | |||
| return new_value_node; | |||
| } | |||
| } | |||
| // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout} | |||
| @@ -379,22 +464,7 @@ class PynativeEliminater : public OptimizerCaller { | |||
| auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]); | |||
| if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && | |||
| CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) { | |||
| for (size_t i = 0; i < 2; i++) { | |||
| auto rep = (args[i]).GetNode(node); | |||
| if (rep != nullptr && rep->isa<ValueNode>()) { | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| // when the use count of value node equals to one, it only used in binop_grad_common function | |||
| if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) { | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape()); | |||
| value_node->set_value(new_tensor); | |||
| } | |||
| } | |||
| } | |||
| OperatorHandle3(args, node); | |||
| return nullptr; | |||
| } | |||
| // resolve(CommonOPS, getitem)((tensors), 3) | |||
| @@ -403,26 +473,9 @@ class PynativeEliminater : public OptimizerCaller { | |||
| auto pattern2 = PCNode(resolve2, arg, arg1); | |||
| if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") && | |||
| CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) { | |||
| auto rep = (arg).GetNode(node); | |||
| if (rep != nullptr) { | |||
| if (rep->isa<ValueNode>()) { | |||
| MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4); | |||
| ValueNodePtr new_node; | |||
| auto value_node = rep->cast<ValueNodePtr>(); | |||
| auto rep1 = (arg1).GetNode(node); | |||
| if (rep1 != nullptr) { | |||
| if (rep1->isa<ValueNode>()) { | |||
| auto idx = rep1->cast<ValueNodePtr>(); | |||
| if (!value_node->value()->isa<ValueTuple>()) { | |||
| return nullptr; | |||
| } | |||
| new_node = NewValueNode(FillGetItem(value_node->value(), idx->value())); | |||
| new_node->set_has_new_value(value_node->has_new_value()); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4); | |||
| return new_node; | |||
| } | |||
| auto new_value_node = OperatorHandle4(arg, arg1, node); | |||
| if (new_value_node != nullptr) { | |||
| return new_value_node; | |||
| } | |||
| } | |||
| @@ -153,25 +153,27 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const | |||
| } | |||
| // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { | |||
| void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { | |||
| return; | |||
| if (func_graph->has_flag(AUTO_PARALLEL) && | |||
| (!func_graph->has_flag(TRAINING) || | |||
| (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) { | |||
| init_param_shape_ = false; | |||
| } else { | |||
| param_shapes.clear(); | |||
| init_param_shape_ = true; | |||
| } | |||
| param_shapes.clear(); | |||
| } | |||
| // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| AbstractBasePtr ptr) { | |||
| void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, | |||
| const ParameterPtr ¶m_node, AbstractBasePtr ptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || | |||
| func_graph->has_flag(TRAINING)) { | |||
| if (init_param_shape_) { | |||
| return; | |||
| } | |||
| auto iter = param_shapes.find(param_node->name()); | |||
| if (iter == param_shapes.end()) { | |||
| MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); | |||
| @@ -183,16 +185,16 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, | |||
| MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; | |||
| } | |||
| // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode | |||
| // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode | |||
| void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| const AbstractBasePtr &ptr) { | |||
| void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| const AbstractBasePtr &ptr) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { | |||
| if (!init_param_shape_) { | |||
| return; | |||
| } | |||
| std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape(); | |||
| auto ret = param_shapes.try_emplace(param_node->name(), shape); | |||
| if (!ret.second) { | |||
| @@ -30,6 +30,7 @@ | |||
| #include "ir/func_graph.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/info.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -43,6 +44,7 @@ constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; | |||
| constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; | |||
| constexpr char TRAINING[] = "training"; | |||
| constexpr char ACCUMULATION[] = "accumulation"; | |||
| class ParallelContext { | |||
| public: | |||
| @@ -111,6 +113,11 @@ class ParallelContext { | |||
| bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } | |||
| void Reset(); | |||
| void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph); | |||
| void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| AbstractBasePtr ptr); | |||
| void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| const AbstractBasePtr &ptr); | |||
| private: | |||
| ParallelContext(); | |||
| @@ -136,13 +143,9 @@ class ParallelContext { | |||
| std::string strategy_ckpt_save_file_; | |||
| std::string group_ckpt_save_file_; | |||
| bool enable_parallel_optimizer_; | |||
| bool init_param_shape_; | |||
| }; | |||
| void ParallelParameterContextInit(const FuncGraphPtr &func_graph); | |||
| void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| AbstractBasePtr ptr); | |||
| void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| const AbstractBasePtr &ptr); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -284,6 +284,39 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & | |||
| return op; | |||
| } | |||
| void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { | |||
| MS_EXCEPTION_IF_NULL(comm_node); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { | |||
| MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; | |||
| return; | |||
| } | |||
| auto param = param_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| auto param_info = param->param_info(); | |||
| if (!param_info) { | |||
MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
| return; | |||
| } | |||
| int32_t fusion_type = param_info->comm_fusion(); | |||
| attrs[FUSION] = MakeValue<int64_t>(fusion_type); | |||
| prim->SetAttrs(attrs); | |||
| MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | |||
| } | |||
| void AddCommOpMeanFlag(const CNodePtr &comm_node) { | |||
| MS_EXCEPTION_IF_NULL(comm_node); | |||
| auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0)); | |||
| auto attrs = prim->attrs(); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | |||
| attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag); | |||
| prim->SetAttrs(attrs); | |||
| } | |||
| Operator CreateAllGatherOp(const std::string &group) { | |||
| OperatorName operator_name = ALL_GATHER; | |||
| ValuePtr attr0_value = MakeValue(group); // group | |||
| @@ -299,6 +332,30 @@ Operator CreateAllGatherOp(const std::string &group) { | |||
| return op; | |||
| } | |||
| Operator CreateMiniStepAllGatherOp(const std::string &group) { | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | |||
| OperatorName operator_name = MINI_STEP_ALL_GATHER; | |||
| ValuePtr attr0_value = MakeValue(group); // group | |||
| Attr attr0 = std::make_pair(GROUP, attr0_value); | |||
| ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step | |||
| Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value); | |||
| ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag | |||
| Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); | |||
| OperatorAttrs operator_attrs; | |||
| operator_attrs.push_back(attr0); | |||
| operator_attrs.push_back(attr1); | |||
| operator_attrs.push_back(attr2); | |||
| OperatorParams operator_param; | |||
| OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_arg); | |||
| MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group; | |||
| return op; | |||
| } | |||
| // use for get tensor slice | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { | |||
| Shape tensor_map = tensor_layout.tensor_map().array(); | |||
| @@ -771,7 +828,7 @@ void OperatorInfo::ComputeBatchSplitFlagList() { | |||
| ReComputeBatchSplitFlagList(); | |||
| } | |||
| // This is a common method for checking whether the generated stragegy has the correct number of devuces. | |||
// This is a common method for checking whether the generated strategy has the correct number of devices.
| Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { | |||
| if (sp == nullptr) { | |||
| MS_LOG(ERROR) << "The strategy is null."; | |||
| @@ -886,7 +943,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs | |||
| (void)input0_strategy.erase(input0_strategy.begin(), | |||
| input0_strategy.begin() + static_cast<different_type>(size_diff)); | |||
| // handel the case likes ([1, c, d], [a, b, c, d]) | |||
// handle the case like ([1, c, d], [a, b, c, d])
| for (size_t i = 0; i < inputs_shape[0].size(); ++i) { | |||
| if (inputs_shape[0][i] == 1) { | |||
| input0_strategy[i] = 1; | |||
| @@ -937,7 +994,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input | |||
| (void)input1_strategy.erase(input1_strategy.begin(), | |||
| input1_strategy.begin() + static_cast<different_type>(size_diff)); | |||
| // handel the case likes ([a, b, c, d], [1, c, d]) | |||
// handle the case like ([a, b, c, d], [1, c, d])
| for (size_t i = 0; i < inputs_shape[1].size(); ++i) { | |||
| if (inputs_shape[1][i] == 1) { | |||
| input1_strategy[i] = 1; | |||
| @@ -36,6 +36,7 @@ | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -160,7 +161,7 @@ class OperatorInfo { | |||
| void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } | |||
| const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } | |||
// When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is calculated
| // multiple times. This method is to correct this, and makes the cost is calulated only once. | |||
// multiple times. This method corrects that and ensures the cost is calculated only once.
| Status CorrectMemoryCost(size_t input_index); | |||
| int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; } | |||
| int64_t is_output_critical() const { return is_output_critical_; } | |||
| @@ -242,7 +243,7 @@ class OperatorInfo { | |||
| bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel | |||
| // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected. | |||
| std::vector<size_t> corrected_input_indices_; | |||
| // Given a parallization strategy, there is a cost. | |||
| // Given a parallelization strategy, there is a cost. | |||
| std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_; | |||
// For each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
| std::vector<bool> is_parameter_; | |||
| @@ -288,6 +289,9 @@ Operator CreateVirtualDivOp(int64_t div_num); | |||
| Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); | |||
| Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); | |||
| Operator CreateAllGatherOp(const std::string &group); | |||
| Operator CreateMiniStepAllGatherOp(const std::string &group); | |||
| void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); | |||
| void AddCommOpMeanFlag(const CNodePtr &comm_node); | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); | |||
| OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); | |||
| int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); | |||
| @@ -109,6 +109,7 @@ constexpr char END[] = "end"; | |||
| constexpr char STRIDES[] = "strides"; | |||
| constexpr char GROUP[] = "group"; | |||
| constexpr char FUSION[] = "fusion"; | |||
| constexpr char DO_MIRROR[] = "do_mirror"; | |||
| constexpr char NUM_SAMPLED[] = "num_sampled"; | |||
| constexpr char NUM_TRUE[] = "num_true"; | |||
| constexpr char SEED[] = "seed"; | |||
| @@ -180,6 +181,7 @@ constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator"; | |||
| constexpr char LOCAL_STEP[] = "local_step"; | |||
| constexpr char STRIDED_SLICE[] = "StridedSlice"; | |||
| constexpr char ALL_GATHER[] = "AllGather"; | |||
| constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather"; | |||
| constexpr char REDUCE_SCATTER[] = "ReduceScatter"; | |||
| constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; | |||
| constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; | |||
| @@ -65,8 +65,8 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { | |||
| return; | |||
| } | |||
| ValueNodePtr prim_anf_node = new_node_input[0]->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| @@ -83,6 +83,19 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { | |||
| } | |||
| } | |||
| void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) { | |||
| if (new_node_input.empty()) { | |||
| return; | |||
| } | |||
| auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag); | |||
| prim->SetAttrs(attrs); | |||
| } | |||
| std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| OperatorArgs arg_forward = op.second; | |||
| @@ -157,7 +170,6 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(root->manager()); | |||
| AnfNodePtr local_step_param = nullptr; | |||
| AnfNodePtr grad_accu = nullptr; | |||
| std::string op_name = op.first; | |||
| OperatorArgs arg_forward = op.second; | |||
| @@ -165,25 +177,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| if (grad_accumulation_step > 1) { | |||
| bool find_locat_step_node = false; | |||
| auto parameters = root->parameters(); | |||
| for (auto ¶m : parameters) { | |||
| auto param_ptr = param->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param_ptr); | |||
| if (param_ptr->name() == LOCAL_STEP) { | |||
| auto param_users = root->manager()->node_users()[param]; | |||
| for (auto &user : param_users) { | |||
| if (AnfNodeIsPrimitive(user.first, ASSIGN)) { | |||
| find_locat_step_node = true; | |||
| local_step_param = user.first; | |||
| MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode"; | |||
| break; | |||
| } | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| bool find_grad_accu_node = false; | |||
| for (auto ¶m : parameters) { | |||
| if (!ParameterIsCloned(param)) { | |||
| @@ -201,10 +195,12 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| } | |||
| } | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| if (!find_locat_step_node || !find_grad_accu_node) { | |||
| if (!find_grad_accu_node) { | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| op_name = MIRROR_OPERATOR; | |||
| arg_forward.first.pop_back(); | |||
| } else if (op_name == MINI_STEP_ALL_GATHER) { | |||
| MS_LOG(EXCEPTION) << "You should define `accu_grads` when enable gradient accumulation."; | |||
| } | |||
| } | |||
| } | |||
| @@ -214,9 +210,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| OperatorParams params = arg_forward.second; | |||
| std::vector<AnfNodePtr> new_node_input; | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu}; | |||
| MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input"; | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) { | |||
| new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; | |||
| MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; | |||
| } else { | |||
| new_node_input = {NewValueNode(pyop_instance), node}; | |||
| } | |||
| @@ -232,6 +228,10 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| // if the op have 'group' attr, set the rank list name for the op | |||
| SetCommunicationOpGroupLabel(new_node_input); | |||
| // gradient accumulation | |||
| if (grad_accumulation_step > 1) { | |||
| SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION)); | |||
| } | |||
| return new_node_input; | |||
| } | |||
| @@ -284,6 +284,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons | |||
| return new_node; | |||
| } | |||
| // Replace pre_node with pre_node->op | |||
| static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node, | |||
| const FuncGraphPtr &func_graph, const std::string &instance_name, | |||
| const std::string ¶m_name) { | |||
| // insert new node before the node | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| ScopePtr scope = pre_node->scope(); | |||
| MS_EXCEPTION_IF_NULL(scope); | |||
| std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); | |||
| CNodePtr new_node = func_graph->NewCNode(node_input); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| if (instance_name.find(SPLIT_SENS) == std::string::npos) { | |||
| new_node->set_in_forward_flag(true); // mark forward flag | |||
| } | |||
| auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]); | |||
| new_node_prim->set_instance_name(instance_name); | |||
| new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); | |||
| new_node->set_scope(scope); | |||
| node_input[0]->set_scope(scope); | |||
| manager->Replace(pre_node, new_node); | |||
| MS_LOG(INFO) << "Insert " << instance_name << " success"; | |||
| return new_node; | |||
| } | |||
| std::string CreateInstanceName(const CNodePtr &node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!IsValueNode<Primitive>(node->input(0))) { | |||
| @@ -1085,29 +1110,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) { | |||
| return (type_id != kNumberTypeFloat32); | |||
| } | |||
| static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) { | |||
| MS_EXCEPTION_IF_NULL(comm_node); | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) { | |||
| MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now."; | |||
| return; | |||
| } | |||
| auto param = param_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(param); | |||
| auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| auto param_info = param->param_info(); | |||
| if (!param_info) { | |||
| MS_LOG(WARNING) << param->ToString() << "does not have parameter info."; | |||
| return; | |||
| } | |||
| int32_t fusion_type = param_info->comm_fusion(); | |||
| attrs[FUSION] = MakeValue<int64_t>(fusion_type); | |||
| prim->SetAttrs(attrs); | |||
| MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type; | |||
| } | |||
| static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) { | |||
| if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) { | |||
| MS_LOG(INFO) << "Input is ValueList, skip it."; | |||
| @@ -1194,7 +1196,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons | |||
| InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name); | |||
| auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| // pipeline mirror would not be set, which should be supported later | |||
| AddCommOpFusionType(comm_op, param_node_pair.first); | |||
| } | |||
| continue; | |||
| @@ -1539,33 +1540,40 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf | |||
| return std::make_pair(nullptr, 0); | |||
| } | |||
| static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, | |||
| const AnfNodePtr ¶meter) { | |||
| Operator op = CreateAllGatherOp(group); | |||
| static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res, | |||
| const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(res.first); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = res.first->cast<CNodePtr>(); | |||
| auto graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(cnode_prim); | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| Operator op; | |||
| CNodePtr allgather; | |||
| if (cnode_prim->name() == CAST) { | |||
| allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); | |||
| if (grad_accumulation_step > 1) { | |||
| op = CreateMiniStepAllGatherOp(group); | |||
| auto param_name = node->cast<ParameterPtr>()->name(); | |||
| if (cnode_prim->name() == CAST) { | |||
| allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name); | |||
| } else { | |||
| InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name); | |||
| allgather = cnode->input(res.second)->cast<CNodePtr>(); | |||
| } | |||
| } else { | |||
| InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER); | |||
| allgather = cnode->input(res.second)->cast<CNodePtr>(); | |||
| op = CreateAllGatherOp(group); | |||
| if (cnode_prim->name() == CAST) { | |||
| allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER); | |||
| } else { | |||
| InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER); | |||
| allgather = cnode->input(res.second)->cast<CNodePtr>(); | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(allgather); | |||
| // add fusion flag | |||
| AddCommOpFusionType(allgather, parameter); | |||
| AddCommOpFusionType(allgather, node); | |||
| // add gradients mean | |||
| auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); | |||
| auto attrs = prim->attrs(); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | |||
| attrs["mean_flag"] = MakeValue<bool>(mean_flag); | |||
| prim->SetAttrs(attrs); | |||
| AddCommOpMeanFlag(allgather); | |||
| } | |||
| static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, | |||
| @@ -1588,7 +1596,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & | |||
| << distribute_operator->inputs_tensor_info().size(); | |||
| } | |||
| // insert allgather operator between shard parameter and cnode | |||
| InsertAllGatherOp(opt_shard_group, param_pair, parameter); | |||
| InsertAllGatherOp(root, opt_shard_group, param_pair, parameter); | |||
| MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString(); | |||
| } | |||
| } | |||
| @@ -1733,12 +1741,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| if (found_be_cloned_parameter) { | |||
| // set the shape and tensor layout for cloned parameter | |||
| std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name(); | |||
| cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>()); | |||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | |||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | |||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | |||
| MS_EXCEPTION_IF_NULL(cloned_abstract); | |||
| cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); | |||
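// accu_grads mirror the gradient accumulated *before* the parallel-optimizer AllGather,
// so their cloned abstract takes the sliced layout rather than the full parameter shape.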
| if (param_name.find(ACCU_GRADS) != std::string::npos) { | |||
| auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array(); | |||
| std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape); | |||
| MS_EXCEPTION_IF_NULL(parallel_shape); | |||
| cloned_abstract->set_shape(parallel_shape); | |||
| } else { | |||
| cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); | |||
| } | |||
| cloned_parameter_node->set_abstract(cloned_abstract); | |||
| MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() | |||
| << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() | |||
| @@ -30,6 +30,8 @@ | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||
| #include "frontend/parallel/auto_parallel/operator_costmodel.h" | |||
| using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>; | |||
| @@ -258,9 +258,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| abstract::AbstractBasePtrList args_spec = res->args_spec(); | |||
| parallel::ParallelParameterContextInit(func_graph); | |||
| auto context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); | |||
| context->ParallelParameterContextInitShape(func_graph); | |||
| // suppose that there is not KeywordArgument for the top graph | |||
| // get the hyper parameter | |||
| for (const auto ¶m : func_graph->parameters()) { | |||
| @@ -271,9 +271,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| auto ref_key = std::make_shared<RefKey>(param_node->name()); | |||
| auto abs_ref_key = ref_key->ToAbstract(); | |||
| auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value); | |||
| parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref); | |||
| context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref); | |||
| args_spec.push_back(abs_ref); | |||
| parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref); | |||
| context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref); | |||
| } | |||
| } | |||
| // Analyze | |||
| @@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.replace_applicator_, | |||
| irpass.mirror_mini_step_elim_, | |||
| irpass.row_tensor_add_zeros_like_, | |||
| irpass.mini_step_allgather_replace_, | |||
| }); | |||
| opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); | |||
| opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); | |||
| @@ -374,7 +374,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||
| .def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape")) | |||
| .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | |||
| .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") | |||
| .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) | |||
| .def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info) | |||
| .def(py::pickle( | |||
| [](const MetaTensor &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| @@ -134,7 +134,7 @@ class Parameter(Tensor_): | |||
| Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel)) | |||
| def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True): | |||
| self._param_info = ParamInfo() | |||
| self.param_info = ParamInfo() | |||
| self.init_in_server = False | |||
| self.cache_enable = False | |||
| self.name = name | |||
| @@ -230,7 +230,7 @@ class Parameter(Tensor_): | |||
| "sparse operator support initialization in server.".format(self.name)) | |||
| self.is_param_ps = True | |||
| self.init_in_server = init_in_server | |||
| self._param_info.init_in_server = init_in_server | |||
| self.param_info.init_in_server = init_in_server | |||
| @property | |||
| def inited_param(self): | |||
| @@ -245,7 +245,7 @@ class Parameter(Tensor_): | |||
| @property | |||
| def name(self): | |||
| """Get the name of the parameter.""" | |||
| return self._param_info.name | |||
| return self.param_info.name | |||
| @name.setter | |||
| def name(self, name_): | |||
| @@ -272,9 +272,9 @@ class Parameter(Tensor_): | |||
| if len(self.shape) != 2: | |||
| raise RuntimeError("The dims of parameter '{}' must be 2, but got {}." | |||
| .format(self.name, len(self.shape))) | |||
| _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1]) | |||
| _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1]) | |||
| self._param_info.name = name_ | |||
| self.param_info.name = name_ | |||
| @property | |||
| def sliced(self): | |||
| @@ -288,12 +288,12 @@ class Parameter(Tensor_): | |||
| @property | |||
| def comm_fusion(self): | |||
| """Get the fusion type for communication operators corresponding to this parameter.""" | |||
| return self._param_info.comm_fusion | |||
| return self.param_info.comm_fusion | |||
| @comm_fusion.setter | |||
| def comm_fusion(self, comm_fusion_): | |||
| """Set the fusion type for communication operators corresponding to this parameter.""" | |||
| self._param_info.comm_fusion = comm_fusion_ | |||
| self.param_info.comm_fusion = comm_fusion_ | |||
| @property | |||
| def unique(self): | |||
| @@ -339,7 +339,7 @@ class Parameter(Tensor_): | |||
| """ | |||
| x = copy(self) | |||
| # pylint: disable=protected-access | |||
| x._param_info = self._param_info.clone() | |||
| x.param_info = self.param_info.clone() | |||
| x.is_init = False | |||
| x.init = self.init | |||
| x.is_param_ps = self.is_param_ps | |||
| @@ -355,57 +355,57 @@ class Parameter(Tensor_): | |||
| @property | |||
| def layerwise_parallel(self): | |||
| return self._param_info.layerwise_parallel | |||
| return self.param_info.layerwise_parallel | |||
| @layerwise_parallel.setter | |||
| def layerwise_parallel(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`layerwise_parallel` parameter must be bool type") | |||
| self._param_info.layerwise_parallel = value | |||
| self.param_info.layerwise_parallel = value | |||
| @property | |||
| def parallel_optimizer(self): | |||
| """Return whether the parameter requires weight shard for parallel optimizer.""" | |||
| return self._param_info.parallel_optimizer | |||
| return self.param_info.parallel_optimizer | |||
| @parallel_optimizer.setter | |||
| def parallel_optimizer(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`parallel_optimizer` parameter must be bool type") | |||
| self._param_info.parallel_optimizer = value | |||
| self.param_info.parallel_optimizer = value | |||
| @property | |||
| def cache_enable(self): | |||
| """Return whether the parameter is cache enable.""" | |||
| return self._param_info.cache_enable | |||
| return self.param_info.cache_enable | |||
| @cache_enable.setter | |||
| def cache_enable(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`cache_enable` parameter must be bool type") | |||
| self._param_info.cache_enable = value | |||
| self.param_info.cache_enable = value | |||
| @property | |||
| def cache_shape(self): | |||
| """Return the cache shape corresponding to the parameter if use cache.""" | |||
| return self._param_info.cache_shape | |||
| return self.param_info.cache_shape | |||
| @cache_shape.setter | |||
| def cache_shape(self, value): | |||
| if not isinstance(value, (tuple, list)): | |||
| raise TypeError("`cache_shape` parameter must be tuple or list type") | |||
| self._param_info.cache_shape = value | |||
| self.param_info.cache_shape = value | |||
| @property | |||
| def requires_grad(self): | |||
| """Return whether the parameter requires gradient.""" | |||
| return self._param_info.requires_grad | |||
| return self.param_info.requires_grad | |||
| @requires_grad.setter | |||
| def requires_grad(self, value=True): | |||
| if not isinstance(value, bool): | |||
| raise TypeError("`requires_grad` parameter must be bool type") | |||
| self._param_info.requires_grad = value | |||
| self.param_info.requires_grad = value | |||
| @property | |||
| def data(self): | |||
| @@ -419,7 +419,9 @@ class Parameter(Tensor_): | |||
| self.init = None | |||
| return self.assign_value(data) | |||
| # create a new tensor | |||
| return Parameter(data, self.name, self.requires_grad) | |||
| new_param = Parameter(data, self.name, self.requires_grad) | |||
| new_param.param_info = self.param_info | |||
| return new_param | |||
| def set_data(self, data, slice_shape=False): | |||
| """ | |||
| @@ -304,6 +304,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitiv | |||
| // Comm ops | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator"); | |||
| inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| @@ -20,7 +20,7 @@ from mindspore.communication import get_rank, get_group_size | |||
| from .. import operations as P | |||
| from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) | |||
| from .grad_base import bprop_getters | |||
| @@ -150,6 +150,39 @@ def get_bprop_all_gather(self): | |||
| return bprop | |||
| @bprop_getters.register(_MiniStepAllGather) | |||
| def get_bprop_mini_step_all_gather(self): | |||
| """Generate bprop for _MiniStepAllGather""" | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| mean_flag = self.get_attr_dict()["mean_flag"] | |||
| do_mirror = self.get_attr_dict()["do_mirror"] | |||
| scale = 1 / self.rank_size | |||
| all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| all_reduce.set_prim_instance_name(instance_name) | |||
| rank = get_rank(self.group) | |||
| dev_num = get_group_size(self.group) | |||
| split = P.Split(output_num=dev_num) | |||
| def bprop(x, z, out, dout): | |||
| if do_mirror: | |||
| if mean_flag: | |||
| tmp = z + dout | |||
| grad = all_reduce(tmp) | |||
| dx = split(grad)[rank] | |||
| dx = F.tensor_mul(dx, scale) | |||
| else: | |||
| tmp = z + dout | |||
| grad = all_reduce(tmp) | |||
| dx = split(grad)[rank] | |||
| else: | |||
| dx = dout | |||
| return (dx, zeros_like(z)) | |||
| return bprop | |||
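# Hedged numpy sketch (not part of the patch; shapes and values invented, and the
# accu_grads term z is folded into douts for brevity): the backward rule above emulates
# ReduceScatter -- AllGather's gradient -- as AllReduce(SUM) followed by Split, with an
# extra 1/rank_size scaling when mean_flag is set.
import numpy as np
rank_size = 2
douts = [np.arange(8.0), 10 + np.arange(8.0)]  # per-rank gradient w.r.t. the gathered output
reduced = np.sum(douts, axis=0)  # emulated AllReduce(ReduceOp.SUM) over the group
dx = [np.split(reduced, rank_size)[r] for r in range(rank_size)]  # emulated split(grad)[rank]
dx_mean = [d / rank_size for d in dx]  # applied when mean_flag is True
assert dx[0].shape == (4,)  # each rank recovers the shape of its own input slice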
| @bprop_getters.register(_HostAllGather) | |||
| def get_bprop_host_all_gather(self): | |||
| """Generate bprop for _HostAllGather""" | |||
| @@ -291,18 +324,13 @@ def get_bprop_mirror_mini_step_operator(self): | |||
| group = self.group | |||
| dev_num = self.dev_num | |||
| mean_flag = self.mean_flag | |||
| grad_accumulation_step = self.grad_accumulation_step | |||
| all_reduce = AllReduce(group=group) | |||
| all_gather = AllGather(group=group) | |||
| mul = P.Mul() | |||
| cast = P.Cast() | |||
| equal = P.Equal() | |||
| reshape = P.Reshape() | |||
| fusion = 1 | |||
| if hasattr(self, 'fusion'): | |||
| fusion = self.fusion | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| all_reduce.add_prim_attr("fusion", fusion) | |||
| if hasattr(self, 'parameter'): | |||
| parameter = self.parameter | |||
| @@ -311,16 +339,15 @@ def get_bprop_mirror_mini_step_operator(self): | |||
| if self.instance_name: | |||
| instance_name = "grad_mirror" + self.instance_name | |||
| all_reduce.set_prim_instance_name(instance_name) | |||
| do_mirror = self.get_attr_dict()["do_mirror"] | |||
| def bprop(x, y, z, out, dout): | |||
| do_mirror = equal(y, grad_accumulation_step) | |||
| do_mirror = reshape(do_mirror, (())) | |||
| def bprop(x, z, out, dout): | |||
| if mean_flag: | |||
| if F.issubclass_(F.typeof(dout), mstype.tensor): | |||
| if do_mirror: | |||
| tmp = z + dout | |||
| real_grad = all_reduce(tmp) | |||
| dx = real_grad - z | |||
| dx = real_grad | |||
| else: | |||
| dx = dout | |||
| float_one = F.scalar_cast(1.0, F.dtype(dx)) | |||
| @@ -342,7 +369,7 @@ def get_bprop_mirror_mini_step_operator(self): | |||
| if do_mirror: | |||
| tmp = z + dout | |||
| real_grad = all_reduce(tmp) | |||
| dx = real_grad - z | |||
| dx = real_grad | |||
| else: | |||
| dx = dout | |||
| else: | |||
| @@ -354,7 +381,7 @@ def get_bprop_mirror_mini_step_operator(self): | |||
| grad = dout.values | |||
| dx = RowTensor(indices, grad, dout.dense_shape) | |||
| return (dx, zeros_like(y), zeros_like(z)) | |||
| return (dx, zeros_like(z)) | |||
| return bprop | |||
| @@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| Unique, GatherD, Identity, Range) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset, | |||
| _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, | |||
| _HostAllGather, _HostReduceScatter) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| @@ -200,6 +200,38 @@ class AllGather(PrimitiveWithInfer): | |||
| raise NotImplementedError | |||
| class _MiniStepAllGather(PrimitiveWithInfer): | |||
| """ | |||
Auto parallel virtual operator. Performs an AllGather in the forward pass and a
ReduceScatter in the backward pass during mini-step gradient accumulation. It is only
for internal use of parallel modules and cannot be called by users.
Args:
group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
grad_accumulation_step (int): The grad accumulation step. Default: None.
mean_flag (bool): Whether to average the gradient across the group in the backward pass. Default: None.
| """ | |||
| @prim_attr_register | |||
| def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): | |||
| validator.check_value_type('group', _get_group(group), (str,), self.name) | |||
| self.rank = get_rank(_get_group(group)) | |||
| self.rank_size = get_group_size(_get_group(group)) | |||
| validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) | |||
| self.add_prim_attr('rank_size', self.rank_size) | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 1) | |||
| self.grad_accumulation_step = grad_accumulation_step | |||
| self.mean_flag = mean_flag | |||
| def infer_shape(self, x_shape, z_shape): | |||
| validator.check_positive_int(len(x_shape), "x shape", self.name) | |||
| if x_shape[0] > 0: | |||
| x_shape[0] = x_shape[0] * self.rank_size | |||
| return x_shape | |||
def infer_dtype(self, x_dtype, z_dtype):
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
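# Hedged usage sketch (requires an initialized communication group; values invented):
#   allgather = _MiniStepAllGather(group=GlobalComm.WORLD_COMM_GROUP,
#                                  grad_accumulation_step=4, mean_flag=True)
# The framework inserts this op itself (see InsertAllGatherOp); users never call it.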
| class _HostAllGather(PrimitiveWithInfer): | |||
| """ | |||
| Gathers tensors from the specified communication group on host. | |||
| @@ -590,10 +622,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer): | |||
| self.mean_flag = mean_flag | |||
| self.grad_accumulation_step = grad_accumulation_step | |||
| def infer_shape(self, x_shape, y_shape, z_shape): | |||
| def infer_shape(self, x_shape, z_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, y_shape, z_shape): | |||
def infer_dtype(self, x_dtype, z_dtype):
| return x_dtype | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,15 +17,14 @@ import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.nn import Cell, Momentum, Norm | |||
| from mindspore.train import Model | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm | |||
| from mindspore.parallel._utils import _get_device_num | |||
| from tests.dataset_mock import MindData | |||
| @@ -142,29 +141,29 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): | |||
| accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = | |||
| batch_size * accumulation_steps. Default: 1. | |||
| """ | |||
| def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4): | |||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||
| super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) | |||
| self.accu = False | |||
| self.is_accu_step = Tensor(np.array([self.accu])) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.accumulation_steps = accumulation_steps | |||
| self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step") | |||
| self.one = Tensor(np.array([1]).astype(np.int32)) | |||
| self.zero = Tensor(np.array([0]).astype(np.int32)) | |||
| self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") | |||
| self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros') | |||
| self.accu_overflow = Parameter(initializer(0, [1], mstype.int32)) | |||
| self.accu_loss = Parameter(initializer(0, [1], mstype.float32)) | |||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||
| self.reducer_flag = False | |||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| self.grad_reducer = F.identity | |||
| self.degree = 1 | |||
| self.grad_reducer = F.identity | |||
| if self.reducer_flag: | |||
| self.degree = get_group_size() | |||
| self.degree = _get_device_num() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) | |||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||
| self.overflow_reducer = F.identity | |||
| @@ -197,34 +196,27 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): | |||
| else: | |||
| scaling_sens = sens | |||
| # update accumulation parameters | |||
| is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) | |||
| self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) | |||
| self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss) | |||
| mean_loss = self.accu_loss / self.local_step | |||
| is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) | |||
| # alloc status and clear should be right before gradoperation | |||
| init = self.alloc_status() | |||
| self.clear_before_grad(init) | |||
| grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32)) | |||
| accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) | |||
| mean_loss = F.depend(mean_loss, accu_succ) | |||
| if self.is_accu_step and self.accumulation_steps > 1: | |||
| accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads) | |||
| loss = F.depend(loss, accu_succ) | |||
| self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| overflow = self.less_equal(self.base, flag_sum) | |||
| overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) | |||
| accu_overflow = self.select(overflow, self.one, self.zero) | |||
| self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero) | |||
| is_accu_step = self.reshape(is_accu_step, (())) | |||
| self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero) | |||
| if is_accu_step: | |||
| if self.is_accu_step: | |||
| succ = False | |||
| else: | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(self.accu_grads) | |||
| grads = self.grad_reducer(grads) | |||
| scaling = scaling_sens * self.degree * self.accumulation_steps | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling), grads) | |||
| grads = ClipByGlobalNorm()(grads) | |||
| @@ -241,7 +233,7 @@ class TrainAccumulateStepsWithLossScaleCell(Cell): | |||
| else: | |||
| succ = self.optimizer(grads) | |||
| ret = (mean_loss, overflow, scaling_sens) | |||
| ret = (loss, overflow, scaling_sens) | |||
| return F.depend(ret, succ) | |||
| @@ -265,25 +257,51 @@ _b = Tensor(np.ones([16]), dtype=ms.float32) | |||
| _w1 = Tensor(np.ones([16]), dtype=ms.float32) | |||
| def compile_net(net, grad_accumulation_step): | |||
| context.set_context(save_graphs=True) | |||
| def compile_net(net): | |||
| context.set_context(enable_sparse=False) | |||
| learning_rate = 0.1 | |||
| momentum = 0.9 | |||
| epoch_size = 2 | |||
| dataset = Dataset(_x, _b) | |||
| opt = Momentum(net.trainable_params(), learning_rate, momentum) | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000) | |||
| net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell, | |||
| accumulation_steps=grad_accumulation_step) | |||
| net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell) | |||
| model = Model(net_wrap) | |||
| model.train(epoch_size, dataset, dataset_sink_mode=False) | |||
| context.reset_auto_parallel_context() | |||
| def test_grad_accumulation(): | |||
| def test_grad_accumulation_accu(): | |||
| grad_accumulation_step = 4 | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, | |||
| grad_accumulation_step=grad_accumulation_step) | |||
| strategy = ((2,), (2,)) | |||
| net = Net(_w1, strategy) | |||
| compile_net(net, grad_accumulation_step) | |||
| net = Net(_w1, strategy).add_flags_recursive(accu=True) | |||
| compile_net(net) | |||
| def test_grad_accu_and_opt_shard_accu(): | |||
| grad_accumulation_step = 4 | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, | |||
| grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True) | |||
| strategy = ((2,), (2,)) | |||
| net = Net(_w1, strategy).add_flags_recursive(accu=True) | |||
| compile_net(net) | |||
| def test_grad_accumulation_not_accu(): | |||
| grad_accumulation_step = 4 | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, | |||
| grad_accumulation_step=grad_accumulation_step) | |||
| strategy = ((2,), (2,)) | |||
| net = Net(_w1, strategy).add_flags_recursive(accu=False) | |||
| compile_net(net) | |||
| def test_grad_accu_and_opt_shard_not_accu(): | |||
| grad_accumulation_step = 4 | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0, | |||
| grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True) | |||
| strategy = ((2,), (2,)) | |||
| net = Net(_w1, strategy).add_flags_recursive(accu=False) | |||
| compile_net(net) | |||
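# Together these four cases exercise grad accumulation in both the ACCUMULATION and
# non-ACCUMULATION graphs, each with and without parallel-optimizer weight sharding.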