@@ -86,6 +86,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
   mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
                                             prim::kPrimMirrorMiniStep);
+  mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
+                                                  "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
   check_bprop_eliminate_ =
     MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
   reset_defer_inline_ =
@@ -52,6 +52,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr depend_value_elim_;
   SubstitutionPtr all_reduce_const_elim_;
   SubstitutionPtr mirror_mini_step_elim_;
+  SubstitutionPtr mini_step_allgather_replace_;
   // Env Item Eliminate
   SubstitutionPtr env_get_item_eliminate_;
@@ -33,6 +33,7 @@
 #include "utils/comm_manager.h"
 #include "frontend/parallel/context.h"
 #include "pipeline/jit/parse/resolve.h"
+#include "frontend/parallel/step_parallel.h"
 namespace mindspore {
 namespace opt {
@@ -155,7 +156,7 @@ class CheckBpropEliminater : public AnfVisitor {
   AnfNodePtr x_{nullptr};
 };
-// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
+// {prim::kPrimMirrorMiniStep, X, Z} -> X
 class MirrorMiniStepEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -163,11 +164,7 @@ class MirrorMiniStepEliminater : public AnfVisitor {
       return nullptr;
     }
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
-      return nullptr;
-    }
-    auto inputs = cnode->inputs();
+    auto &inputs = node->cast<CNodePtr>()->inputs();
     if (inputs.size() < 2) {
       return nullptr;
     }
@@ -178,6 +175,32 @@ class MirrorMiniStepEliminater : public AnfVisitor {
   void Visit(const AnfNodePtr &) override {}
 };
+// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
+class MiniStepAllGatherPass : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+    auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
+    MS_EXCEPTION_IF_NULL(prim);
+    auto attrs = prim->attrs();
+    std::string group = attrs[parallel::GROUP]->ToString();
+    parallel::Operator op = parallel::CreateAllGatherOp(group);
+    std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
+    auto func_graph = inputs[1]->func_graph();
+    CNodePtr new_node = func_graph->NewCNode(node_input);
+    return new_node;
+  }
+  void Visit(const AnfNodePtr &) override {}
+};
 // Reset defer_inline flag
 class ResetDeferInline : public AnfVisitor {
  public:
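Reviewer note: the pass added above rewrites `{prim::kPrimMiniStepAllGather, X, Z}` into `{prim::kPrimAllGather, X}`, i.e. it drops the accumulation input once mini-step bookkeeping is no longer needed. Below is a minimal, self-contained sketch of that rewrite shape on a toy node type; the `Node` struct and names are illustrative stand-ins, not the MindSpore API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy expression node: an op name plus its inputs.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};

// Mirrors the substitution above: a MiniStepAllGather node carrying the
// extra accumulation input is rebuilt as a plain AllGather over input 0.
std::shared_ptr<Node> ReplaceMiniStepAllGather(const std::shared_ptr<Node> &node) {
  if (node->op != "_MiniStepAllGather" || node->inputs.empty()) {
    return nullptr;  // pattern does not match; the pass leaves the node alone
  }
  auto replacement = std::make_shared<Node>();
  replacement->op = "AllGather";
  replacement->inputs = {node->inputs[0]};  // drop the accu_grads input
  return replacement;
}

int main() {
  auto x = std::make_shared<Node>(Node{"param_x", {}});
  auto accu = std::make_shared<Node>(Node{"accu_grads", {}});
  auto gather = std::make_shared<Node>(Node{"_MiniStepAllGather", {x, accu}});
  auto rewritten = ReplaceMiniStepAllGather(gather);
  std::cout << rewritten->op << "(" << rewritten->inputs[0]->op << ")\n";  // AllGather(param_x)
}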
@@ -328,6 +351,80 @@ class PynativeEliminater : public OptimizerCaller {
     return out;
   }
+ private:
+  AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
+    auto rep = (arg).GetNode(node);
+    if (rep != nullptr) {
+      if (rep->isa<ValueNode>()) {
+        auto value_node = rep->cast<ValueNodePtr>();
+        auto new_value_node = NewValueNode(FillZero(value_node->value()));
+        new_value_node->set_has_new_value(value_node->has_new_value());
+        MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
+        return new_value_node;
+      }
+    }
+    return nullptr;
+  }
+  AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
+    auto rep = (arg).GetNode(node);
+    if (rep != nullptr) {
+      if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
+        auto value_node = rep->cast<ValueNodePtr>();
+        auto new_value_node = NewValueNode(FillZero(value_node->value()));
+        new_value_node->set_has_new_value(value_node->has_new_value());
+        MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
+        return new_value_node;
+      }
+    }
+    return nullptr;
+  }
+  void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) {
+    for (size_t i = 0; i < 2; i++) {
+      auto rep = (args[i]).GetNode(node);
+      if (rep != nullptr && rep->isa<ValueNode>()) {
+        auto value_node = rep->cast<ValueNodePtr>();
+        MS_EXCEPTION_IF_NULL(value_node);
+        auto &value = value_node->value();
+        MS_EXCEPTION_IF_NULL(value);
+        // When the use count of the value node equals one, it is only used in the binop_grad_common function.
+        if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
+          auto tensor = value->cast<tensor::TensorPtr>();
+          MS_EXCEPTION_IF_NULL(tensor);
+          auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
+          value_node->set_value(new_tensor);
+        }
+      }
+    }
+  }
+  AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1,
+                             const AnfNodePtr &node) {
+    auto rep = (arg).GetNode(node);
+    if (rep != nullptr) {
+      if (rep->isa<ValueNode>()) {
+        MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
+        ValueNodePtr new_node;
+        auto value_node = rep->cast<ValueNodePtr>();
+        auto rep1 = (arg1).GetNode(node);
+        if (rep1 != nullptr) {
+          if (rep1->isa<ValueNode>()) {
+            auto idx = rep1->cast<ValueNodePtr>();
+            if (!value_node->value()->isa<ValueTuple>()) {
+              return nullptr;
+            }
+            new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
+            new_node->set_has_new_value(value_node->has_new_value());
+          }
+        }
+        MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
+        return new_node;
+      }
+    }
+    return nullptr;
+  }
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
@@ -342,15 +439,9 @@ class PynativeEliminater : public OptimizerCaller {
     if ((pattern).TryCapture(node) &&
         (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
          CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
-      auto rep = (arg).GetNode(node);
-      if (rep != nullptr) {
-        if (rep->isa<ValueNode>()) {
-          auto value_node = rep->cast<ValueNodePtr>();
-          auto new_value_node = NewValueNode(FillZero(value_node->value()));
-          new_value_node->set_has_new_value(value_node->has_new_value());
-          MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
-          return new_value_node;
-        }
+      auto new_value_node = OperatorHandle1(arg, node);
+      if (new_value_node != nullptr) {
+        return new_value_node;
       }
     }
     MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
@@ -360,15 +451,9 @@ class PynativeEliminater : public OptimizerCaller {
     if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                         CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
-      auto rep = (arg).GetNode(node);
-      if (rep != nullptr) {
-        if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
-          auto value_node = rep->cast<ValueNodePtr>();
-          auto new_value_node = NewValueNode(FillZero(value_node->value()));
-          new_value_node->set_has_new_value(value_node->has_new_value());
-          MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
-          return new_value_node;
-        }
+      auto new_value_node = OperatorHandle2(arg, node);
+      if (new_value_node != nullptr) {
+        return new_value_node;
       }
     }
     // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
@@ -379,22 +464,7 @@ class PynativeEliminater : public OptimizerCaller {
     auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
     if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                              CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
-      for (size_t i = 0; i < 2; i++) {
-        auto rep = (args[i]).GetNode(node);
-        if (rep != nullptr && rep->isa<ValueNode>()) {
-          auto value_node = rep->cast<ValueNodePtr>();
-          MS_EXCEPTION_IF_NULL(value_node);
-          auto &value = value_node->value();
-          MS_EXCEPTION_IF_NULL(value);
-          // when the use count of value node equals to one, it only used in binop_grad_common function
-          if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
-            auto tensor = value->cast<tensor::TensorPtr>();
-            MS_EXCEPTION_IF_NULL(tensor);
-            auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
-            value_node->set_value(new_tensor);
-          }
-        }
-      }
+      OperatorHandle3(args, node);
       return nullptr;
     }
     // resolve(CommonOPS, getitem)((tensors), 3)
@@ -403,26 +473,9 @@ class PynativeEliminater : public OptimizerCaller {
     auto pattern2 = PCNode(resolve2, arg, arg1);
     if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
                                         CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
-      auto rep = (arg).GetNode(node);
-      if (rep != nullptr) {
-        if (rep->isa<ValueNode>()) {
-          MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
-          ValueNodePtr new_node;
-          auto value_node = rep->cast<ValueNodePtr>();
-          auto rep1 = (arg1).GetNode(node);
-          if (rep1 != nullptr) {
-            if (rep1->isa<ValueNode>()) {
-              auto idx = rep1->cast<ValueNodePtr>();
-              if (!value_node->value()->isa<ValueTuple>()) {
-                return nullptr;
-              }
-              new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
-              new_node->set_has_new_value(value_node->has_new_value());
-            }
-          }
-          MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
-          return new_node;
-        }
+      auto new_value_node = OperatorHandle4(arg, arg1, node);
+      if (new_value_node != nullptr) {
+        return new_value_node;
       }
     }
@@ -153,25 +153,27 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
 }
 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
-void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
+void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
-    return;
+  if (func_graph->has_flag(AUTO_PARALLEL) &&
+      (!func_graph->has_flag(TRAINING) ||
+       (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
+    init_param_shape_ = false;
+  } else {
+    param_shapes.clear();
+    init_param_shape_ = true;
   }
-  param_shapes.clear();
 }
 // Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
-void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                                 AbstractBasePtr ptr) {
+void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph,
+                                                           const ParameterPtr &param_node, AbstractBasePtr ptr) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(param_node);
   MS_EXCEPTION_IF_NULL(ptr);
-  if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
-      func_graph->has_flag(TRAINING)) {
+  if (init_param_shape_) {
     return;
   }
   auto iter = param_shapes.find(param_node->name());
   if (iter == param_shapes.end()) {
     MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
@@ -183,16 +185,16 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
   MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
 }
-// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
 // Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
-void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                            const AbstractBasePtr &ptr) {
+void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                                        const AbstractBasePtr &ptr) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(param_node);
   MS_EXCEPTION_IF_NULL(ptr);
-  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
+  if (!init_param_shape_) {
     return;
   }
   std::vector<int64_t> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
   auto ret = param_shapes.try_emplace(param_node->name(), shape);
   if (!ret.second) {
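Reviewer note: the refactor above replaces per-call flag checks with a single `init_param_shape_` decision made once per compile: training-style graphs record shapes (checkpoint), while inference graphs and the non-accumulation phase of mini-step training replay them (restore). A minimal self-contained sketch of that lifecycle, using toy types rather than the MindSpore classes:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for ParallelContext: one flag decides whether parameter
// shapes are recorded or replayed during a compilation pass.
class ShapeContext {
 public:
  void InitShape(bool training_like_graph) {
    if (!training_like_graph) {
      init_param_shape_ = false;  // keep recorded shapes, restore them later
    } else {
      param_shapes_.clear();      // fresh training pass: record from scratch
      init_param_shape_ = true;
    }
  }
  void CkptShape(const std::string &name, const std::vector<int64_t> &shape) {
    if (!init_param_shape_) return;
    param_shapes_[name] = shape;
  }
  bool RestoreShape(const std::string &name, std::vector<int64_t> *shape) const {
    if (init_param_shape_) return false;
    auto it = param_shapes_.find(name);
    if (it == param_shapes_.end()) return false;
    *shape = it->second;
    return true;
  }

 private:
  bool init_param_shape_ = true;
  std::map<std::string, std::vector<int64_t>> param_shapes_;
};

int main() {
  ShapeContext ctx;
  ctx.InitShape(true);                  // training graph: record shapes
  ctx.CkptShape("fc.weight", {16, 8});
  ctx.InitShape(false);                 // later inference graph: restore them
  std::vector<int64_t> shape;
  if (ctx.RestoreShape("fc.weight", &shape)) std::cout << shape[0] << "x" << shape[1] << "\n";  // 16x8
}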
@@ -30,6 +30,7 @@
 #include "ir/func_graph.h"
 #include "utils/convert_utils.h"
 #include "utils/info.h"
+#include "pipeline/jit/pipeline.h"
 namespace mindspore {
 namespace parallel {
@@ -43,6 +44,7 @@ constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
 constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
 constexpr char TRAINING[] = "training";
+constexpr char ACCUMULATION[] = "accumulation";
 class ParallelContext {
  public:
@@ -111,6 +113,11 @@ class ParallelContext {
   bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
   void Reset();
+  void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
+  void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                            AbstractBasePtr ptr);
+  void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                         const AbstractBasePtr &ptr);
  private:
   ParallelContext();
@@ -136,13 +143,9 @@ class ParallelContext {
   std::string strategy_ckpt_save_file_;
   std::string group_ckpt_save_file_;
   bool enable_parallel_optimizer_;
+  bool init_param_shape_;
 };
-void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
-void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                                 AbstractBasePtr ptr);
-void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
-                                            const AbstractBasePtr &ptr);
 }  // namespace parallel
 }  // namespace mindspore
@@ -284,6 +284,39 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
   return op;
 }
+void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
+    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
+    return;
+  }
+  auto param = param_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(param);
+  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  auto param_info = param->param_info();
+  if (!param_info) {
+    MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
+    return;
+  }
+  int32_t fusion_type = param_info->comm_fusion();
+  attrs[FUSION] = MakeValue<int64_t>(fusion_type);
+  prim->SetAttrs(attrs);
+  MS_LOG(INFO) << "Set comm fusion: " << param->param_info()->name() << "'s fusion type is " << fusion_type;
+}
+void AddCommOpMeanFlag(const CNodePtr &comm_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
+  auto attrs = prim->attrs();
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
+  prim->SetAttrs(attrs);
+}
 Operator CreateAllGatherOp(const std::string &group) {
   OperatorName operator_name = ALL_GATHER;
   ValuePtr attr0_value = MakeValue(group);  // group
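Reviewer note: both helpers above follow the same read-modify-write shape: copy the primitive's attribute map, set one key, and write the whole map back via `SetAttrs`. A minimal self-contained sketch of that pattern, with a plain map standing in for MindSpore's `Primitive` attrs (names here are illustrative):

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for a primitive carrying a mutable attribute map.
struct Prim {
  std::map<std::string, std::string> attrs;
  void SetAttrs(const std::map<std::string, std::string> &a) { attrs = a; }
};

// Same read-modify-write shape as AddCommOpFusionType / AddCommOpMeanFlag.
void AddFusionAttr(const std::shared_ptr<Prim> &prim, int fusion_type) {
  auto attrs = prim->attrs;                       // copy current attrs
  attrs["fusion"] = std::to_string(fusion_type);  // set/overwrite one key
  prim->SetAttrs(attrs);                          // write the map back
}

int main() {
  auto prim = std::make_shared<Prim>();
  AddFusionAttr(prim, 2);
  std::cout << "fusion=" << prim->attrs["fusion"] << "\n";  // fusion=2
}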
@@ -299,6 +332,30 @@ Operator CreateAllGatherOp(const std::string &group) {
   return op;
 }
+Operator CreateMiniStepAllGatherOp(const std::string &group) {
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  OperatorName operator_name = MINI_STEP_ALL_GATHER;
+  ValuePtr attr0_value = MakeValue(group);  // group
+  Attr attr0 = std::make_pair(GROUP, attr0_value);
+  ValuePtr attr1_value = MakeValue(grad_accumulation_step);  // grad_accumulation_step
+  Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value);
+  ValuePtr attr2_value = MakeValue(mean_flag);  // mean_flag
+  Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
+  OperatorAttrs operator_attrs;
+  operator_attrs.push_back(attr0);
+  operator_attrs.push_back(attr1);
+  operator_attrs.push_back(attr2);
+  OperatorParams operator_param;
+  OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
+  Operator op = std::make_pair(operator_name, operator_arg);
+  MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group;
+  return op;
+}
 // use for get tensor slice
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
   Shape tensor_map = tensor_layout.tensor_map().array();
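Reviewer note: for orientation, `Operator` in this file is a nested pair, (name, (attrs, params)), where each attr is itself a (key, value) pair. The sketch below shows the same construction order as CreateMiniStepAllGatherOp with plain std types standing in for MindSpore's `ValuePtr`-based aliases (illustrative, not the real typedefs):

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Simplified mirrors of the aliases used above (the real ones hold ValuePtr).
using Attr = std::pair<std::string, std::string>;
using OperatorAttrs = std::vector<Attr>;
using OperatorParams = std::vector<Attr>;  // empty here, kept for shape
using OperatorArgs = std::pair<OperatorAttrs, OperatorParams>;
using Operator = std::pair<std::string, OperatorArgs>;

// Name plus three attrs: group, grad_accumulation_step, mean_flag.
Operator MakeMiniStepAllGather(const std::string &group, int64_t accu_step, bool mean_flag) {
  OperatorAttrs attrs = {
      {"group", group},
      {"grad_accumulation_step", std::to_string(accu_step)},
      {"mean_flag", mean_flag ? "true" : "false"},
  };
  return {"_MiniStepAllGather", {attrs, {}}};
}

int main() {
  Operator op = MakeMiniStepAllGather("group_4", 8, true);
  std::cout << op.first << " group=" << op.second.first[0].second << "\n";  // _MiniStepAllGather group=group_4
}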
@@ -771,7 +828,7 @@ void OperatorInfo::ComputeBatchSplitFlagList() {
   ReComputeBatchSplitFlagList();
 }
-// This is a common method for checking whether the generated stragegy has the correct number of devuces.
+// This is a common method for checking whether the generated strategy has the correct number of devices.
 Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
   if (sp == nullptr) {
     MS_LOG(ERROR) << "The strategy is null.";
@@ -886,7 +943,7 @@ Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs
   (void)input0_strategy.erase(input0_strategy.begin(),
                               input0_strategy.begin() + static_cast<different_type>(size_diff));
-  // handel the case likes ([1, c, d], [a, b, c, d])
+  // handle cases like ([1, c, d], [a, b, c, d])
   for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
     if (inputs_shape[0][i] == 1) {
       input0_strategy[i] = 1;
@@ -937,7 +994,7 @@ Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &input
   (void)input1_strategy.erase(input1_strategy.begin(),
                               input1_strategy.begin() + static_cast<different_type>(size_diff));
-  // handel the case likes ([a, b, c, d], [1, c, d])
+  // handle cases like ([a, b, c, d], [1, c, d])
   for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
     if (inputs_shape[1][i] == 1) {
       input1_strategy[i] = 1;
@@ -36,6 +36,7 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_info.h"
 #include "utils/log_adapter.h"
+#include "base/core_ops.h"
 namespace mindspore {
 namespace parallel {
@@ -160,7 +161,7 @@ class OperatorInfo {
   void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
   const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
   // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated
-  // multiple times. This method is to correct this, and makes the cost is calulated only once.
+  // multiple times. This method corrects this, and ensures the cost is calculated only once.
   Status CorrectMemoryCost(size_t input_index);
   int64_t is_output_parameter_involve() const { return is_output_parameter_involve_; }
   int64_t is_output_critical() const { return is_output_critical_; }
@@ -242,7 +243,7 @@ class OperatorInfo {
   bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
   // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected.
   std::vector<size_t> corrected_input_indices_;
-  // Given a parallization strategy, there is a cost.
+  // Given a parallelization strategy, there is a cost.
   std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
   // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter
   std::vector<bool> is_parameter_;
@@ -288,6 +289,9 @@ Operator CreateVirtualDivOp(int64_t div_num);
 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
 Operator CreateAllGatherOp(const std::string &group);
+Operator CreateMiniStepAllGatherOp(const std::string &group);
+void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
+void AddCommOpMeanFlag(const CNodePtr &comm_node);
 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
@@ -109,6 +109,7 @@ constexpr char END[] = "end";
 constexpr char STRIDES[] = "strides";
 constexpr char GROUP[] = "group";
 constexpr char FUSION[] = "fusion";
+constexpr char DO_MIRROR[] = "do_mirror";
 constexpr char NUM_SAMPLED[] = "num_sampled";
 constexpr char NUM_TRUE[] = "num_true";
 constexpr char SEED[] = "seed";
@@ -180,6 +181,7 @@ constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
 constexpr char LOCAL_STEP[] = "local_step";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
+constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
 constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter";
 constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup";
@@ -65,8 +65,8 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
     return;
   }
-  ValueNodePtr prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
   MS_EXCEPTION_IF_NULL(prim);
   auto attrs = prim->attrs();
@@ -83,6 +83,19 @@ void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   }
 }
+void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
+  if (new_node_input.empty()) {
+    return;
+  }
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
+  prim->SetAttrs(attrs);
+}
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
   MS_EXCEPTION_IF_NULL(node);
   OperatorArgs arg_forward = op.second;
@@ -157,7 +170,6 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(root->manager());
-  AnfNodePtr local_step_param = nullptr;
   AnfNodePtr grad_accu = nullptr;
   std::string op_name = op.first;
   OperatorArgs arg_forward = op.second;
@@ -165,25 +177,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
   if (grad_accumulation_step > 1) {
-    bool find_locat_step_node = false;
     auto parameters = root->parameters();
-    for (auto &param : parameters) {
-      auto param_ptr = param->cast<ParameterPtr>();
-      MS_EXCEPTION_IF_NULL(param_ptr);
-      if (param_ptr->name() == LOCAL_STEP) {
-        auto param_users = root->manager()->node_users()[param];
-        for (auto &user : param_users) {
-          if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
-            find_locat_step_node = true;
-            local_step_param = user.first;
-            MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
-            break;
-          }
-        }
-        break;
-      }
-    }
     bool find_grad_accu_node = false;
     for (auto &param : parameters) {
       if (!ParameterIsCloned(param)) {
@@ -201,10 +195,12 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
       }
     }
-    if (op_name == MIRROR_MINI_STEP_OPERATOR) {
-      if (!find_locat_step_node || !find_grad_accu_node) {
+    if (!find_grad_accu_node) {
+      if (op_name == MIRROR_MINI_STEP_OPERATOR) {
         op_name = MIRROR_OPERATOR;
         arg_forward.first.pop_back();
+      } else if (op_name == MINI_STEP_ALL_GATHER) {
+        MS_LOG(EXCEPTION) << "You should define `accu_grads` when enabling gradient accumulation.";
       }
     }
   }
@@ -214,9 +210,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   OperatorParams params = arg_forward.second;
   std::vector<AnfNodePtr> new_node_input;
-  if (op_name == MIRROR_MINI_STEP_OPERATOR) {
-    new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
-    MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
+  if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER) {
+    new_node_input = {NewValueNode(pyop_instance), node, grad_accu};
+    MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
   } else {
     new_node_input = {NewValueNode(pyop_instance), node};
   }
@@ -232,6 +228,10 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
   // if the op have 'group' attr, set the rank list name for the op
   SetCommunicationOpGroupLabel(new_node_input);
+  // gradient accumulation
+  if (grad_accumulation_step > 1) {
+    SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
+  }
   return new_node_input;
 }
@@ -284,6 +284,31 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   return new_node;
 }
+// Replace pre_node with pre_node->op
+static CNodePtr ReplaceMirrorNode(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &pre_node,
+                                  const FuncGraphPtr &func_graph, const std::string &instance_name,
+                                  const std::string &param_name) {
+  // insert new node before the node
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ScopePtr scope = pre_node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
+  CNodePtr new_node = func_graph->NewCNode(node_input);
+  MS_EXCEPTION_IF_NULL(new_node);
+  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
+    new_node->set_in_forward_flag(true);  // mark forward flag
+  }
+  auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
+  new_node_prim->set_instance_name(instance_name);
+  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  new_node->set_scope(scope);
+  node_input[0]->set_scope(scope);
+  manager->Replace(pre_node, new_node);
+  MS_LOG(INFO) << "Insert " << instance_name << " success";
+  return new_node;
+}
 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsValueNode<Primitive>(node->input(0))) {
@@ -1085,29 +1110,6 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
   return (type_id != kNumberTypeFloat32);
 }
-static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
-  MS_EXCEPTION_IF_NULL(comm_node);
-  MS_EXCEPTION_IF_NULL(param_node);
-  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
-    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
-    return;
-  }
-  auto param = param_node->cast<ParameterPtr>();
-  MS_EXCEPTION_IF_NULL(param);
-  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
-  MS_EXCEPTION_IF_NULL(prim);
-  auto attrs = prim->attrs();
-  auto param_info = param->param_info();
-  if (!param_info) {
-    MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
-    return;
-  }
-  int32_t fusion_type = param_info->comm_fusion();
-  attrs[FUSION] = MakeValue<int64_t>(fusion_type);
-  prim->SetAttrs(attrs);
-  MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
-}
 static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
   if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
     MS_LOG(INFO) << "Input is ValueList, skip it.";
@@ -1194,7 +1196,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
         InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
         auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
         // add fusion flag
-        // pipeline mirror would not be set, which should be supported later
         AddCommOpFusionType(comm_op, param_node_pair.first);
       }
       continue;
@@ -1539,33 +1540,40 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
   return std::make_pair(nullptr, 0);
 }
-static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
-                              const AnfNodePtr &parameter) {
-  Operator op = CreateAllGatherOp(group);
+static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group, const std::pair<AnfNodePtr, int> &res,
+                              const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(res.first);
-  MS_EXCEPTION_IF_NULL(parameter);
+  MS_EXCEPTION_IF_NULL(node);
   auto cnode = res.first->cast<CNodePtr>();
   auto graph = cnode->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
   auto cnode_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_EXCEPTION_IF_NULL(cnode_prim);
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  Operator op;
   CNodePtr allgather;
-  if (cnode_prim->name() == CAST) {
-    allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+  if (grad_accumulation_step > 1) {
+    op = CreateMiniStepAllGatherOp(group);
+    auto param_name = node->cast<ParameterPtr>()->name();
+    if (cnode_prim->name() == CAST) {
+      allgather = ReplaceMirrorNode(root, op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
+    } else {
+      InsertMirrorNode(root, op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER, param_name);
+      allgather = cnode->input(res.second)->cast<CNodePtr>();
+    }
   } else {
-    InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
-    allgather = cnode->input(res.second)->cast<CNodePtr>();
+    op = CreateAllGatherOp(group);
+    if (cnode_prim->name() == CAST) {
+      allgather = ReplaceNode(op, cnode, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+    } else {
+      InsertNode(op, cnode, res.second, node, graph, PARALLEL_OPTIMIZER_ALLGATHER);
+      allgather = cnode->input(res.second)->cast<CNodePtr>();
+    }
   }
+  MS_EXCEPTION_IF_NULL(allgather);
   // add fusion flag
-  AddCommOpFusionType(allgather, parameter);
+  AddCommOpFusionType(allgather, node);
   // add gradients mean
-  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-  auto attrs = prim->attrs();
-  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
-  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
-  attrs["mean_flag"] = MakeValue<bool>(mean_flag);
-  prim->SetAttrs(attrs);
+  AddCommOpMeanFlag(allgather);
 }
@@ -1588,7 +1596,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
                     << distribute_operator->inputs_tensor_info().size();
     }
     // insert allgather operator between shard parameter and cnode
-    InsertAllGatherOp(opt_shard_group, param_pair, parameter);
+    InsertAllGatherOp(root, opt_shard_group, param_pair, parameter);
     MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and " << cnode->ToString();
   }
 }
@@ -1733,12 +1741,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
     if (found_be_cloned_parameter) {
       // set the shape and tensor layout for cloned parameter
+      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
       cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
       auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
       MS_EXCEPTION_IF_NULL(cloned_abstract);
-      cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      if (param_name.find(ACCU_GRADS) != std::string::npos) {
+        auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
+        std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
+        MS_EXCEPTION_IF_NULL(parallel_shape);
+        cloned_abstract->set_shape(parallel_shape);
+      } else {
+        cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      }
       cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
@@ -30,6 +30,8 @@
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
 #include "pipeline/jit/pipeline.h"
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
 using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
@@ -258,9 +258,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   abstract::AbstractBasePtrList args_spec = res->args_spec();
-  parallel::ParallelParameterContextInit(func_graph);
+  auto context = parallel::ParallelContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
+  context->ParallelParameterContextInitShape(func_graph);
   // suppose that there is not KeywordArgument for the top graph
   // get the hyper parameter
   for (const auto &param : func_graph->parameters()) {
@@ -271,9 +271,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
       auto ref_key = std::make_shared<RefKey>(param_node->name());
       auto abs_ref_key = ref_key->ToAbstract();
       auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
-      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, abs_ref);
+      context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
       args_spec.push_back(abs_ref);
-      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, abs_ref);
+      context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
     }
   }
   // Analyze
@@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.replace_applicator_,
     irpass.mirror_mini_step_elim_,
     irpass.row_tensor_add_zeros_like_,
+    irpass.mini_step_allgather_replace_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
   opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
@@ -374,7 +374,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                          .def(py::init<TypePtr, const ShapeVector>(), py::arg("dtype"), py::arg("shape"))
                          .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
                          .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
-                         .def_property("_param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
+                         .def_property("param_info", &MetaTensor::param_info, &MetaTensor::set_param_info)
                          .def(py::pickle(
                            [](const MetaTensor &t) {  // __getstate__
                              /* Return a tuple that fully encodes the state of the object */
@@ -134,7 +134,7 @@ class Parameter(Tensor_):
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
     def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
-        self._param_info = ParamInfo()
+        self.param_info = ParamInfo()
         self.init_in_server = False
         self.cache_enable = False
         self.name = name
@@ -230,7 +230,7 @@ class Parameter(Tensor_):
                                "sparse operator support initialization in server.".format(self.name))
         self.is_param_ps = True
         self.init_in_server = init_in_server
-        self._param_info.init_in_server = init_in_server
+        self.param_info.init_in_server = init_in_server
     @property
     def inited_param(self):
@@ -245,7 +245,7 @@ class Parameter(Tensor_):
     @property
     def name(self):
         """Get the name of the parameter."""
-        return self._param_info.name
+        return self.param_info.name
     @name.setter
     def name(self, name_):
@@ -272,9 +272,9 @@ class Parameter(Tensor_):
             if len(self.shape) != 2:
                 raise RuntimeError("The dims of parameter '{}' must be 2, but got {}."
                                    .format(self.name, len(self.shape)))
-            _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1])
-        self._param_info.name = name_
+            _reinsert_hash_table_size(name_, self.param_info.name, self.shape[0], self.shape[1])
+        self.param_info.name = name_
     @property
     def sliced(self):
@@ -288,12 +288,12 @@ class Parameter(Tensor_):
     @property
     def comm_fusion(self):
         """Get the fusion type for communication operators corresponding to this parameter."""
-        return self._param_info.comm_fusion
+        return self.param_info.comm_fusion
     @comm_fusion.setter
     def comm_fusion(self, comm_fusion_):
         """Set the fusion type for communication operators corresponding to this parameter."""
-        self._param_info.comm_fusion = comm_fusion_
+        self.param_info.comm_fusion = comm_fusion_
     @property
     def unique(self):
@@ -339,7 +339,7 @@ class Parameter(Tensor_):
         """
         x = copy(self)
         # pylint: disable=protected-access
-        x._param_info = self._param_info.clone()
+        x.param_info = self.param_info.clone()
         x.is_init = False
         x.init = self.init
         x.is_param_ps = self.is_param_ps
@@ -355,57 +355,57 @@ class Parameter(Tensor_):
     @property
     def layerwise_parallel(self):
-        return self._param_info.layerwise_parallel
+        return self.param_info.layerwise_parallel
     @layerwise_parallel.setter
     def layerwise_parallel(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
-        self._param_info.layerwise_parallel = value
+        self.param_info.layerwise_parallel = value
     @property
     def parallel_optimizer(self):
         """Return whether the parameter requires weight shard for parallel optimizer."""
-        return self._param_info.parallel_optimizer
+        return self.param_info.parallel_optimizer
     @parallel_optimizer.setter
     def parallel_optimizer(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`parallel_optimizer` parameter must be bool type")
-        self._param_info.parallel_optimizer = value
+        self.param_info.parallel_optimizer = value
     @property
     def cache_enable(self):
         """Return whether the parameter is cache enable."""
-        return self._param_info.cache_enable
+        return self.param_info.cache_enable
     @cache_enable.setter
     def cache_enable(self, value=True):
         if not isinstance(value, bool):
             raise TypeError("`cache_enable` parameter must be bool type")
-        self._param_info.cache_enable = value
+        self.param_info.cache_enable = value
     @property
     def cache_shape(self):
         """Return the cache shape corresponding to the parameter if use cache."""
-        return self._param_info.cache_shape
+        return self.param_info.cache_shape
    @cache_shape.setter
     def cache_shape(self, value):
         if not isinstance(value, (tuple, list)):
             raise TypeError("`cache_shape` parameter must be tuple or list type")
-        self._param_info.cache_shape = value
+        self.param_info.cache_shape = value
     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
-        return self._param_info.requires_grad
+        return self.param_info.requires_grad
     @requires_grad.setter
     def requires_grad(self, value=True):
         if not isinstance(value, bool):
| raise TypeError("`requires_grad` parameter must be bool type") | raise TypeError("`requires_grad` parameter must be bool type") | ||||
| self._param_info.requires_grad = value | |||||
| self.param_info.requires_grad = value | |||||
| @property | @property | ||||
| def data(self): | def data(self): | ||||
| @@ -419,7 +419,9 @@ class Parameter(Tensor_): | |||||
| self.init = None | self.init = None | ||||
| return self.assign_value(data) | return self.assign_value(data) | ||||
| # create a new tensor | # create a new tensor | ||||
| return Parameter(data, self.name, self.requires_grad) | |||||
| new_param = Parameter(data, self.name, self.requires_grad) | |||||
| new_param.param_info = self.param_info | |||||
| return new_param | |||||
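With param_info now public, replacing a Parameter's data keeps the metadata object attached instead of dropping it: the setter builds a new Parameter and points its param_info at the existing one. A minimal sketch of that shared-metadata behaviour, using hypothetical stand-ins for Parameter and ParamInfo (the real classes live in MindSpore's C++ bindings):

# Hypothetical stand-in for the real ParamInfo metadata holder.
class ParamInfo:
    def __init__(self, name="", requires_grad=True):
        self.name = name
        self.requires_grad = requires_grad

    def clone(self):
        return ParamInfo(self.name, self.requires_grad)

# Hypothetical stand-in for Parameter, reduced to the behaviour above.
class Parameter:
    def __init__(self, data, name=None, requires_grad=True):
        self.data_ = data
        self.param_info = ParamInfo(name or "", requires_grad)

    def replace_data(self, data):
        # Mirrors the patched setter: build a new Parameter, then share
        # (not clone) the existing param_info so metadata survives.
        new_param = Parameter(data, self.param_info.name,
                              self.param_info.requires_grad)
        new_param.param_info = self.param_info
        return new_param

p = Parameter([1.0], name="w1")
q = p.replace_data([2.0])
assert q.param_info is p.param_info  # metadata is shared, not copied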
     def set_data(self, data, slice_shape=False):
         """

@@ -304,6 +304,7 @@ inline const PrimitivePtr kPrimCustomExtractFeatures = std::make_shared<Primitiv
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
+inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
 inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");

@@ -20,7 +20,7 @@ from mindspore.communication import get_rank, get_group_size
 from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
+from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                    ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters

@@ -150,6 +150,39 @@ def get_bprop_all_gather(self):
     return bprop

+@bprop_getters.register(_MiniStepAllGather)
+def get_bprop_mini_step_all_gather(self):
+    """Generate bprop for _MiniStepAllGather"""
+    fusion = self.get_attr_dict()["fusion"]
+    mean_flag = self.get_attr_dict()["mean_flag"]
+    do_mirror = self.get_attr_dict()["do_mirror"]
+    scale = 1 / self.rank_size
+    all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
+    if self.instance_name:
+        instance_name = "grad_" + self.instance_name
+        all_reduce.set_prim_instance_name(instance_name)
+    rank = get_rank(self.group)
+    dev_num = get_group_size(self.group)
+    split = P.Split(output_num=dev_num)
+
+    def bprop(x, z, out, dout):
+        if do_mirror:
+            if mean_flag:
+                tmp = z + dout
+                grad = all_reduce(tmp)
+                dx = split(grad)[rank]
+                dx = F.tensor_mul(dx, scale)
+            else:
+                tmp = z + dout
+                grad = all_reduce(tmp)
+                dx = split(grad)[rank]
+        else:
+            dx = dout
+        return (dx, zeros_like(z))
+
+    return bprop
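The backward pass registered above is effectively a reduce-scatter assembled from AllReduce and Split: each rank adds its accumulated gradient z to the incoming dout, the sums are all-reduced, and every rank keeps only its own slice, scaled by 1/rank_size when mean_flag is set. A NumPy sketch of that arithmetic for a hypothetical two-rank group:

import numpy as np

# Sketch of the bprop arithmetic above, simulating rank_size = 2.
rank_size = 2
scale = 1 / rank_size
mean_flag = True

# Per-rank accumulated gradients z and incoming gradients dout, both of
# the full all-gathered length (rank_size * local length).
z = [np.array([1., 1., 2., 2.]), np.array([3., 3., 4., 4.])]
dout = [np.array([.5, .5, .5, .5]), np.array([1., 1., 1., 1.])]

# all_reduce(z + dout): element-wise SUM across ranks.
grad = sum(zi + di for zi, di in zip(z, dout))

for rank in range(rank_size):
    dx = np.split(grad, rank_size)[rank]   # split(grad)[rank]
    if mean_flag:
        dx = dx * scale                    # F.tensor_mul(dx, scale)
    print("rank", rank, "dx =", dx)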
 @bprop_getters.register(_HostAllGather)
 def get_bprop_host_all_gather(self):
     """Generate bprop for _HostAllGather"""

@@ -291,18 +324,13 @@ def get_bprop_mirror_mini_step_operator(self):
     group = self.group
     dev_num = self.dev_num
     mean_flag = self.mean_flag
-    grad_accumulation_step = self.grad_accumulation_step

     all_reduce = AllReduce(group=group)
     all_gather = AllGather(group=group)
     mul = P.Mul()
     cast = P.Cast()
-    equal = P.Equal()
-    reshape = P.Reshape()

-    fusion = 1
-    if hasattr(self, 'fusion'):
-        fusion = self.fusion
+    fusion = self.get_attr_dict()["fusion"]
     all_reduce.add_prim_attr("fusion", fusion)
     if hasattr(self, 'parameter'):
         parameter = self.parameter

@@ -311,16 +339,15 @@ def get_bprop_mirror_mini_step_operator(self):
     if self.instance_name:
         instance_name = "grad_mirror" + self.instance_name
         all_reduce.set_prim_instance_name(instance_name)
+    do_mirror = self.get_attr_dict()["do_mirror"]

-    def bprop(x, y, z, out, dout):
-        do_mirror = equal(y, grad_accumulation_step)
-        do_mirror = reshape(do_mirror, (()))
+    def bprop(x, z, out, dout):
         if mean_flag:
             if F.issubclass_(F.typeof(dout), mstype.tensor):
                 if do_mirror:
                     tmp = z + dout
                     real_grad = all_reduce(tmp)
-                    dx = real_grad - z
+                    dx = real_grad
                 else:
                     dx = dout
                 float_one = F.scalar_cast(1.0, F.dtype(dx))

@@ -342,7 +369,7 @@ def get_bprop_mirror_mini_step_operator(self):
                 if do_mirror:
                     tmp = z + dout
                     real_grad = all_reduce(tmp)
-                    dx = real_grad - z
+                    dx = real_grad
                 else:
                     dx = dout
             else:

@@ -354,7 +381,7 @@ def get_bprop_mirror_mini_step_operator(self):
             grad = dout.values
             dx = RowTensor(indices, grad, dout.dense_shape)

-        return (dx, zeros_like(y), zeros_like(z))
+        return (dx, zeros_like(z))

     return bprop
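Two things changed in this bprop: do_mirror is now a compile-time primitive attribute rather than a runtime comparison of the counter input y against grad_accumulation_step, so y drops out of the signature entirely, and the all-reduced gradient is used as-is instead of subtracting z back out. A sketch of the before/after decision logic, with single-process stand-ins for the collective ops:

# Stand-ins for the collective ops used below (single-process placeholders).
def all_reduce(t):
    return t

def zeros_like(t):
    return 0 * t

# Before: mirrored or not was decided on every step inside bprop,
# from a counter input y carried through the graph.
def bprop_old(x, y, z, out, dout, grad_accumulation_step=4):
    do_mirror = (y == grad_accumulation_step)      # runtime check
    dx = all_reduce(z + dout) - z if do_mirror else dout
    return (dx, zeros_like(y), zeros_like(z))

# After: the flag is read once from the primitive's attributes when the
# bprop is built, so the counter input disappears from the graph.
attrs = {"do_mirror": True}                        # self.get_attr_dict()
do_mirror = attrs["do_mirror"]

def bprop_new(x, z, out, dout):
    dx = all_reduce(z + dout) if do_mirror else dout
    return (dx, zeros_like(z))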
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
                         Unique, GatherD, Identity, Range)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
-                       _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset,
+                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
@@ -200,6 +200,38 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError

+class _MiniStepAllGather(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Performs AllGather in the forward pass and a
+    reduce-scatter-style gradient aggregation in the backward pass during mini-step
+    gradient accumulation. It is only for internal use of parallel modules and
+    cannot be called by users.
+
+    Args:
+        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
+        grad_accumulation_step (int): The grad accumulation step. Default: None.
+        mean_flag (bool): Whether to average the gradient across the group in backward. Default: None.
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
+        validator.check_value_type('group', _get_group(group), (str,), self.name)
+        self.rank = get_rank(_get_group(group))
+        self.rank_size = get_group_size(_get_group(group))
+        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
+        self.add_prim_attr('rank_size', self.rank_size)
+        self.add_prim_attr('group', _get_group(group))
+        self.add_prim_attr('fusion', 1)
+        self.grad_accumulation_step = grad_accumulation_step
+        self.mean_flag = mean_flag
+
+    def infer_shape(self, x_shape, z_shape):
+        validator.check_positive_int(len(x_shape), "x shape", self.name)
+        if x_shape[0] > 0:
+            x_shape[0] = x_shape[0] * self.rank_size
+        return x_shape
+
+    def infer_dtype(self, x_dtype, z_dtype):
+        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
+        return x_dtype
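Shape inference for the new primitive matches a plain AllGather: the leading (gathered) dimension grows by the group size, the dtype passes through unchanged, and a non-positive leading dimension (dynamic shape) is left alone. A quick check of that arithmetic, assuming a hypothetical group of 8 ranks:

# Sketch of _MiniStepAllGather.infer_shape for a hypothetical rank_size of 8.
rank_size = 8

def infer_shape(x_shape):
    x_shape = list(x_shape)
    if x_shape[0] > 0:
        x_shape[0] *= rank_size    # gathered dim scales by the group size
    return x_shape

assert infer_shape([16, 32]) == [128, 32]
assert infer_shape([-1, 32]) == [-1, 32]   # dynamic dim passes through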
 class _HostAllGather(PrimitiveWithInfer):
     """
     Gathers tensors from the specified communication group on host.

@@ -590,10 +622,10 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
         self.mean_flag = mean_flag
         self.grad_accumulation_step = grad_accumulation_step

-    def infer_shape(self, x_shape, y_shape, z_shape):
+    def infer_shape(self, x_shape, z_shape):
         return x_shape

-    def infer_dtype(self, x_dtype, y_shape, z_shape):
+    def infer_dtype(self, x_dtype, z_dtype):
         return x_dtype
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,15 +17,14 @@ import numpy as np
 import mindspore as ms
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor, Parameter
-from mindspore.nn import Cell, Momentum, Norm
 from mindspore.train import Model
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.common.initializer import initializer
-from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.context import ParallelMode
+from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
+from mindspore.parallel._utils import _get_device_num
 from tests.dataset_mock import MindData
@@ -142,29 +141,29 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
-        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
-            batch_size * accumulation_steps. Default: 1.
+        The number of accumulation steps is read from the auto parallel context
+        ("grad_accumulation_step"); the global batch size = batch_size * accumulation_steps.
     """

-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
+    def __init__(self, network, optimizer, scale_update_cell=None):
         super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
+        self.accu = False
+        self.is_accu_step = Tensor(np.array([self.accu]))
         self.network = network
         self.network.set_grad()
         self.weights = optimizer.parameters
         self.optimizer = optimizer
-        self.accumulation_steps = accumulation_steps
+        self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
-        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
         self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
         self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
         self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.reducer_flag = False
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
             self.reducer_flag = True
-        self.grad_reducer = F.identity
         self.degree = 1
+        self.grad_reducer = F.identity
         if self.reducer_flag:
-            self.degree = get_group_size()
+            self.degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.overflow_reducer = F.identity
@@ -197,34 +196,27 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
         else:
             scaling_sens = sens

-        # update accumulation parameters
-        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
-        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
-        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
-        mean_loss = self.accu_loss / self.local_step
-        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

         # alloc status and clear should be right before gradoperation
         init = self.alloc_status()
         self.clear_before_grad(init)
         grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
-        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
-        mean_loss = F.depend(mean_loss, accu_succ)
+        if self.is_accu_step and self.accumulation_steps > 1:
+            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
+            loss = F.depend(loss, accu_succ)

         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         overflow = self.less_equal(self.base, flag_sum)
         overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
         accu_overflow = self.select(overflow, self.one, self.zero)
-        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
-        is_accu_step = self.reshape(is_accu_step, (()))
+        self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)

-        if is_accu_step:
+        if self.is_accu_step:
             succ = False
         else:
             # apply grad reducer on grads
-            grads = self.grad_reducer(self.accu_grads)
+            grads = self.grad_reducer(grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
             grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
             grads = ClipByGlobalNorm()(grads)

@@ -241,7 +233,7 @@ class TrainAccumulateStepsWithLossScaleCell(Cell):
         else:
             succ = self.optimizer(grads)

-        ret = (mean_loss, overflow, scaling_sens)
+        ret = (loss, overflow, scaling_sens)
         return F.depend(ret, succ)
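Taken together, construct() now has two statically distinct paths selected by the cell's accu flag: accumulation steps only fold the new gradients into accu_grads and skip the update, while the boundary step reduces the current gradients across devices, undoes the combined loss and accumulation scaling, clips, and applies the optimizer. A schematic of that control flow, with hypothetical stand-ins for the reducer, clipper, and optimizer:

# Schematic of the patched update path; grad_reducer, clip and optimizer
# are hypothetical stand-ins for DistributedGradReducer, ClipByGlobalNorm
# and the wrapped optimizer.
def step(is_accu_step, accumulation_steps, grads, accu_grads,
         scaling_sens, degree, grad_reducer, clip, optimizer):
    if is_accu_step and accumulation_steps > 1:
        # accumulation step: fold this micro-batch in, no weight update
        accu_grads = [a + g for a, g in zip(accu_grads, grads)]
        succ = False
    else:
        # boundary step: reduce across devices, unscale by
        # loss scale * degree * accumulation steps, clip, then apply
        grads = grad_reducer(grads)
        scaling = scaling_sens * degree * accumulation_steps
        grads = [g / scaling for g in grads]
        grads = clip(grads)
        succ = optimizer(grads)
    return accu_grads, succ

# Toy single-device run of both paths.
identity = lambda gs: gs
accu, succ = step(True, 4, [1.0], [0.0], 1024.0, 1,
                  identity, identity, lambda gs: True)
assert (accu, succ) == ([1.0], False)
_, succ = step(False, 4, [1.0], [1.0], 1024.0, 1,
               identity, identity, lambda gs: True)
assert succ is True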
@@ -265,25 +257,51 @@ _b = Tensor(np.ones([16]), dtype=ms.float32)
 _w1 = Tensor(np.ones([16]), dtype=ms.float32)

-def compile_net(net, grad_accumulation_step):
-    context.set_context(save_graphs=True)
+def compile_net(net):
+    context.set_context(enable_sparse=False)
     learning_rate = 0.1
     momentum = 0.9
     epoch_size = 2
     dataset = Dataset(_x, _b)
     opt = Momentum(net.trainable_params(), learning_rate, momentum)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
-    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
-                                                     accumulation_steps=grad_accumulation_step)
+    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
     model = Model(net_wrap)
     model.train(epoch_size, dataset, dataset_sink_mode=False)
     context.reset_auto_parallel_context()
-def test_grad_accumulation():
+def test_grad_accumulation_accu():
     grad_accumulation_step = 4
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                       grad_accumulation_step=grad_accumulation_step)
     strategy = ((2,), (2,))
-    net = Net(_w1, strategy)
-    compile_net(net, grad_accumulation_step)
+    net = Net(_w1, strategy).add_flags_recursive(accu=True)
+    compile_net(net)

+def test_grad_accu_and_opt_shard_accu():
+    grad_accumulation_step = 4
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
+                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
+    strategy = ((2,), (2,))
+    net = Net(_w1, strategy).add_flags_recursive(accu=True)
+    compile_net(net)

+def test_grad_accumulation_not_accu():
+    grad_accumulation_step = 4
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
+                                      grad_accumulation_step=grad_accumulation_step)
+    strategy = ((2,), (2,))
+    net = Net(_w1, strategy).add_flags_recursive(accu=False)
+    compile_net(net)

+def test_grad_accu_and_opt_shard_not_accu():
+    grad_accumulation_step = 4
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
+                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
+    strategy = ((2,), (2,))
+    net = Net(_w1, strategy).add_flags_recursive(accu=False)
+    compile_net(net)