| @@ -106,7 +106,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| prim::kPrimMirrorMiniStep); | |||
| mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(), | |||
| "mini_step_allgather_replace", prim::kPrimMiniStepAllGather); | |||
| virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd); | |||
| micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(), | |||
| "micro_step_allgather_replace", prim::kPrimMicroStepAllGather); | |||
| virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd); | |||
| check_bprop_eliminate_ = | |||
| MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop); | |||
| reset_defer_inline_ = | |||
| @@ -60,6 +60,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr mirror_mini_step_elim_; | |||
| SubstitutionPtr virtual_add_elim_; | |||
| SubstitutionPtr mini_step_allgather_replace_; | |||
| SubstitutionPtr micro_step_allgather_replace_; | |||
| // Env Item Eliminate | |||
| SubstitutionPtr env_get_item_eliminate_; | |||
| @@ -300,6 +300,39 @@ class MiniStepAllGatherPass : public AnfVisitor { | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X} | |||
| class MicroStepAllGatherPass : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (inputs.size() < 2) { | |||
| return nullptr; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| auto attrs = prim->attrs(); | |||
| std::string group = attrs[parallel::GROUP]->ToString(); | |||
| auto fusion = attrs[parallel::FUSION]; | |||
| parallel::Operator op = parallel::CreateAllGatherOp(group); | |||
| std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER); | |||
| auto prim_anf_node = node_input[0]->cast<ValueNodePtr>(); | |||
| prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| attrs = prim->attrs(); | |||
| attrs[parallel::FUSION] = fusion; | |||
| prim->SetAttrs(attrs); | |||
| auto func_graph = inputs[1]->func_graph(); | |||
| CNodePtr new_node = func_graph->NewCNode(node_input); | |||
| return new_node; | |||
| } | |||
| void Visit(const AnfNodePtr &) override {} | |||
| }; | |||
| // Reset defer_inline flag | |||
| class ResetDeferInline : public AnfVisitor { | |||
| public: | |||
| @@ -36,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0 | |||
| constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; | |||
| constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; | |||
| constexpr char PARAMETER[] = "parameter"; | |||
| const uint64_t MAX_RECURSIVE_CALL_TIMES = 100; | |||
| class AllreduceFusion { | |||
| public: | |||
| @@ -87,14 +87,12 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) { | |||
| void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager, | |||
| const AnfNodePtr &accu_parameter) { | |||
| auto cnode = node_user.first->cast<CNodePtr>(); | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() || | |||
| ((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) && | |||
| ParallelContext::GetInstance()->enable_parallel_optimizer())) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) { | |||
| return; | |||
| } | |||
| auto prim = GetCNodePrimitive(cnode); | |||
| if (prim == nullptr) { | |||
| MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd"; | |||
| MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd."; | |||
| return; | |||
| } | |||
| OperatorAttrs attrs; | |||
| @@ -154,10 +152,12 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> | |||
| auto node_users = node_users_map[node]; | |||
| for (auto &temp_user : node_users) { | |||
| auto temp_node = temp_user.first; | |||
| // Micro virtual operator might be inserted after cast | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { | |||
| temp_node = node_users_map[temp_node].begin()->first; | |||
| } | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) || | |||
| IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) { | |||
| auto node_set = node_users_map[temp_node]; | |||
| for (auto &node_user : node_set) { | |||
| InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); | |||
| @@ -182,10 +182,12 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) { | |||
| auto node_users = node_users_map[parameter]; | |||
| for (auto &temp_user : node_users) { | |||
| auto temp_node = temp_user.first; | |||
| // Micro virtual operator might be inserted after cast | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { | |||
| temp_node = node_users_map[temp_node].begin()->first; | |||
| } | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { | |||
| if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) || | |||
| IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) { | |||
| auto node_set = node_users_map[temp_node]; | |||
| for (auto &node_user : node_set) { | |||
| InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); | |||
| @@ -428,6 +428,26 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) { | |||
| return op; | |||
| } | |||
| Operator CreateMicroStepAllGatherOp(const std::string &group) { | |||
| bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); | |||
| OperatorName operator_name = MICRO_STEP_ALL_GATHER; | |||
| ValuePtr attr0_value = MakeValue(group); // group | |||
| Attr attr0 = std::make_pair(GROUP, attr0_value); | |||
| ValuePtr attr1_value = MakeValue(mean_flag); // mean_flag | |||
| Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value); | |||
| OperatorAttrs operator_attrs; | |||
| operator_attrs.push_back(attr0); | |||
| operator_attrs.push_back(attr1); | |||
| OperatorParams operator_param; | |||
| OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); | |||
| Operator op = std::make_pair(operator_name, operator_arg); | |||
| MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group; | |||
| return op; | |||
| } | |||
| // use for get tensor slice | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { | |||
| Shape tensor_map = tensor_layout.tensor_map().array(); | |||
| @@ -299,6 +299,7 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & | |||
| Operator CreateAllGatherOp(const std::string &group); | |||
| Operator CreateMiniStepAllGatherOp(const std::string &group); | |||
| void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); | |||
| Operator CreateMicroStepAllGatherOp(const std::string &group); | |||
| void AddCommOpMeanFlag(const CNodePtr &comm_node); | |||
| void AddCommOpParamFlag(const CNodePtr &comm_node); | |||
| Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); | |||
| @@ -217,6 +217,7 @@ constexpr char LOCAL_STEP[] = "local_step"; | |||
| constexpr char STRIDED_SLICE[] = "StridedSlice"; | |||
| constexpr char ALL_GATHER[] = "AllGather"; | |||
| constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather"; | |||
| constexpr char MICRO_STEP_ALL_GATHER[] = "_MicroStepAllGather"; | |||
| constexpr char REDUCE_SCATTER[] = "ReduceScatter"; | |||
| constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; | |||
| constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; | |||
| @@ -383,6 +384,7 @@ constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad"; | |||
| constexpr char ACCU_GRAD[] = "accu_grad"; | |||
| constexpr char PARAMETER_START[] = "parameter_start"; | |||
| constexpr char PARAM_INDEX[] = "param_index"; | |||
| constexpr char PARAMETER[] = "parameter"; | |||
| // Parallel don't care | |||
| constexpr char STRING_EQUAL[] = "string_equal"; | |||
| @@ -199,7 +199,8 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR) { | |||
| op_name = MIRROR_OPERATOR; | |||
| arg_forward.first.pop_back(); | |||
| } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) { | |||
| } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR || | |||
| op_name == MICRO_STEP_ALL_GATHER) { | |||
| MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name; | |||
| } | |||
| } | |||
| @@ -211,7 +212,7 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat | |||
| std::vector<AnfNodePtr> new_node_input; | |||
| if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER || | |||
| op_name == MIRROR_MICRO_STEP_OPERATOR) { | |||
| op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) { | |||
| new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; | |||
| MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; | |||
| } else { | |||
| @@ -1117,6 +1118,15 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap | |||
| return std::make_pair(nullptr, false); | |||
| } | |||
| // only used for FindCNode | |||
| CNodePtr SkipTrivialNodesMoveDown(FuncGraphManagerPtr manager, CNodePtr node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) { | |||
| node = manager->node_users()[node].begin()->first->cast<CNodePtr>(); | |||
| } | |||
| return node; | |||
| } | |||
| std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(anode); | |||
| MS_EXCEPTION_IF_NULL(anode->func_graph()); | |||
| @@ -1130,6 +1140,9 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string & | |||
| if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) { | |||
| continue; | |||
| } | |||
| if (ParallelContext::GetInstance()->enable_parallel_optimizer()) { | |||
| use_apply = SkipTrivialNodesMoveDown(manager, use_apply); | |||
| } | |||
| ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prim_anf_node); | |||
| PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| @@ -1202,7 +1215,7 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no | |||
| } | |||
| // only used for InsertMirrorOps | |||
| CNodePtr SkipTrivialNodes(CNodePtr node) { | |||
| CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| while (!IsSomePrimitive(node, LOAD)) { | |||
| if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) { | |||
| @@ -1287,7 +1300,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons | |||
| // assume Load is inserted next to parameter | |||
| // skip Load moving up and insert mirror next to the parameter | |||
| if (pre_node->cast<CNodePtr>()) { | |||
| CNodePtr load_node = SkipTrivialNodes(node->input(index)->cast<CNodePtr>()); | |||
| CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast<CNodePtr>()); | |||
| manager->SetEdge(load_node, 1, next_cnode.second); | |||
| } else { | |||
| manager->SetEdge(node, static_cast<int>(index), next_cnode.second); | |||
| @@ -1306,7 +1319,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons | |||
| if (pre_node->cast<CNodePtr>() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) { | |||
| // assume Load is inserted next to parameter | |||
| // skip Load moving up and insert mirror next to the parameter | |||
| CNodePtr load_node = SkipTrivialNodes(pre_node->cast<CNodePtr>()); | |||
| CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast<CNodePtr>()); | |||
| InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root); | |||
| auto comm_op = load_node->input(1)->cast<CNodePtr>(); | |||
| // add fusion flag | |||
| @@ -1706,6 +1719,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group | |||
| auto param_name = node->cast<ParameterPtr>()->name(); | |||
| if (op_name == MINI_STEP_ALL_GATHER) { | |||
| op = CreateMiniStepAllGatherOp(group); | |||
| } else if (op_name == MICRO_STEP_ALL_GATHER) { | |||
| op = CreateMicroStepAllGatherOp(group); | |||
| } else { | |||
| op = CreateAllGatherOp(group); | |||
| } | |||
| @@ -1733,9 +1748,12 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); | |||
| int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); | |||
| std::string op_name; | |||
| if (grad_accumulation_step > 1) { | |||
| op_name = MINI_STEP_ALL_GATHER; | |||
| } else if (split_stage_num > 1) { | |||
| op_name = MICRO_STEP_ALL_GATHER; | |||
| } else { | |||
| op_name = ALL_GATHER; | |||
| } | |||
| @@ -1744,7 +1762,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & | |||
| for (auto ¶m_pair : param_sub_set) { | |||
| auto cnode = param_pair.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->in_forward_flag()) { | |||
| if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) { | |||
| OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr"; | |||
| @@ -1759,6 +1777,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & | |||
| manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second); | |||
| MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and " | |||
| << GetPrimName(cnode); | |||
| } else { | |||
| MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users."; | |||
| } | |||
| } else { | |||
| // insert allgather operator between shard parameter and cnode | |||
| @@ -2852,12 +2872,14 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt | |||
| OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); | |||
| MS_EXCEPTION_IF_NULL(distribute_operator); | |||
| // insert forward ops | |||
| InsertForwardOps(distribute_operator, cnode); | |||
| // insert redistribution ops | |||
| StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); | |||
| // skip Send Receive | |||
| if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) { | |||
| // insert forward ops | |||
| InsertForwardOps(distribute_operator, cnode); | |||
| // insert redistribution ops | |||
| StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); | |||
| } | |||
| // insert backward ops | |||
| if (has_backward) { | |||
| BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs); | |||
| @@ -2873,7 +2895,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) { | |||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE) || | |||
| IsSomePrimitive(cnode, SEND)) { | |||
| continue; | |||
| } | |||
| @@ -2922,7 +2945,8 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo | |||
| bool IsCohesiveNode(const CNodePtr &cnode) { | |||
| return IsPrimitiveCNode(cnode, prim::kPrimCast) || IsPrimitiveCNode(cnode, prim::kPrimLoad) || | |||
| IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); | |||
| IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather) || | |||
| IsPrimitiveCNode(cnode, prim::kPrimMicroStepAllGather); | |||
| } | |||
| ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { | |||
| @@ -309,6 +309,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| irpass.virtual_add_elim_, | |||
| irpass.row_tensor_add_zeros_like_, | |||
| irpass.mini_step_allgather_replace_, | |||
| irpass.micro_step_allgather_replace_, | |||
| }, | |||
| false, true); | |||
| opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_}); | |||
| @@ -362,6 +362,7 @@ inline const PrimitivePtr kFusedMulAdd = std::make_shared<Primitive>("FusedMulAd | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator"); | |||
| inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather"); | |||
| inline const PrimitivePtr kPrimMicroStepAllGather = std::make_shared<Primitive>("_MicroStepAllGather"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| @@ -31,7 +31,8 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, | |||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | |||
| "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||
| "stop_gradient", "UpdateState", "Load"}; | |||
| static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather}; | |||
| static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather, | |||
| prim::kPrimMicroStepAllGather}; | |||
| static const std::set<PrimitivePtr> TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend}; | |||
| // clang-format on | |||
| @@ -16,7 +16,7 @@ | |||
| from types import FunctionType, MethodType | |||
| from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, | |||
| _get_parallel_mode) | |||
| _get_parallel_mode, _get_enable_parallel_optimizer) | |||
| from mindspore.context import ParallelMode, get_auto_parallel_context | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore import ops, nn | |||
| @@ -538,6 +538,7 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell): | |||
| super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens) | |||
| self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") | |||
| self.hyper_map = ops.HyperMap() | |||
| self.opt_shard = _get_enable_parallel_optimizer() | |||
| def construct(self, *inputs): | |||
| weights = self.weights | |||
| @@ -545,7 +546,10 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell): | |||
| sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(*inputs, sens) | |||
| accu_grads = ops.depend(self.accu_grads, grads) | |||
| succ = self.optimizer(accu_grads) | |||
| if self.opt_shard: | |||
| succ = self.optimizer(grads) | |||
| else: | |||
| succ = self.optimizer(accu_grads) | |||
| clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads) | |||
| loss = ops.depend(loss, succ, clear) | |||
| return loss | |||
| @@ -18,13 +18,14 @@ from mindspore import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import functional as F | |||
| from mindspore.communication import get_rank, get_group_size | |||
| from mindspore.parallel._utils import _get_enable_parallel_optimizer | |||
| from .. import operations as P | |||
| from ...common.tensor import RowTensor | |||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | |||
| from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, | |||
| _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, | |||
| ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap, | |||
| _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator) | |||
| _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather) | |||
| from .grad_base import bprop_getters | |||
| from ..operations._inner_ops import Send, Receive | |||
| @@ -102,10 +103,14 @@ def get_bprop_receive(self): | |||
| depend = P.Depend() | |||
| cast = P.Cast() | |||
| out_tensor = Tensor(0.0, mstype.float16) | |||
| is_opt_shard = _get_enable_parallel_optimizer() | |||
| def bprop(x, out, dout): | |||
| send_out = receive_grad(dout) | |||
| dx = depend(cast(out_tensor, F.dtype(x)), send_out) | |||
| if is_opt_shard: | |||
| dx = depend(F.zeros_like(x), send_out) | |||
| else: | |||
| dx = depend(cast(out_tensor, F.dtype(x)), send_out) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -174,6 +179,7 @@ def get_bprop_mirror_micro_step_operator(self): | |||
| if "parameter_micro" in self.get_attr_dict(): | |||
| assign.add_prim_attr("parameter_micro", 0) | |||
| out_tensor = Tensor(1.0, mstype.float16) | |||
| opt_shard = _get_enable_parallel_optimizer() | |||
| def bprop(x, z, out, dout): | |||
| real_grad = z | |||
| @@ -188,6 +194,8 @@ def get_bprop_mirror_micro_step_operator(self): | |||
| z = F.depend(z, dout) | |||
| real_grad = all_reduce(z) | |||
| assign(z, real_grad) | |||
| if opt_shard: | |||
| return (real_grad, cast(out_tensor, dtype(z))) | |||
| return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))) | |||
| return bprop | |||
| @@ -205,30 +213,17 @@ def get_bprop_broad_cast(self): | |||
| def get_bprop_all_gather(self): | |||
| """Generate bprop for AllGather""" | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| if fusion == 0: | |||
| reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| reduce_scatter.set_prim_instance_name(instance_name) | |||
| else: | |||
| all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| all_reduce.set_prim_instance_name(instance_name) | |||
| rank = get_rank(self.group) | |||
| dev_num = get_group_size(self.group) | |||
| split = P.Split(output_num=dev_num) | |||
| mean_flag = self.get_attr_dict()["mean_flag"] | |||
| scale = 1/self.rank_size | |||
| reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| reduce_scatter.set_prim_instance_name(instance_name) | |||
| mean_flag = self.get_attr_dict()["mean_flag"] | |||
| scale = 1 / self.rank_size | |||
| def bprop(x, out, dout): | |||
| if fusion == 0: | |||
| dx = reduce_scatter(dout) | |||
| else: | |||
| grad = all_reduce(dout) | |||
| dx = split(grad)[rank] | |||
| if mean_flag: | |||
| dx = F.tensor_mul(dx, scale) | |||
| dx = reduce_scatter(dout) | |||
| if mean_flag: | |||
| dx = F.tensor_mul(dx, scale) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -267,6 +262,35 @@ def get_bprop_mini_step_all_gather(self): | |||
| return bprop | |||
| @bprop_getters.register(_MicroStepAllGather) | |||
| def get_bprop_micro_step_all_gather(self): | |||
| """Generate bprop for _MicroStepAllGather""" | |||
| fusion = self.get_attr_dict()["fusion"] | |||
| mean_flag = self.get_attr_dict()["mean_flag"] | |||
| scale = 1 / self.rank_size | |||
| all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) | |||
| rank = get_rank(self.group) | |||
| dev_num = get_group_size(self.group) | |||
| split = P.Split(output_num=dev_num) | |||
| if self.instance_name: | |||
| instance_name = "grad_" + self.instance_name | |||
| all_reduce.set_prim_instance_name(instance_name) | |||
| cast = P.Cast() | |||
| dtype = P.DType() | |||
| out_tensor = Tensor(1.0, mstype.float16) | |||
| # z: accu_grad | |||
| def bprop(x, z, out, dout): | |||
| z = F.depend(z, dout) | |||
| real_grad = all_reduce(z) | |||
| real_grad = split(real_grad)[rank] | |||
| if mean_flag: | |||
| real_grad = F.tensor_mul(real_grad, scale) | |||
| return (real_grad, cast(out_tensor, dtype(z))) | |||
| return bprop | |||
| @bprop_getters.register(_HostAllGather) | |||
| def get_bprop_host_all_gather(self): | |||
| """Generate bprop for _HostAllGather""" | |||
| @@ -6,4 +6,4 @@ | |||
| bprop.10:x* | |||
| bprop.10:out* | |||
| bprop.10:dout2 | |||
| bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -8,4 +8,4 @@ | |||
| bprop.2:x* | |||
| bprop.2:out* | |||
| bprop.2:dout2 | |||
| bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -37,7 +37,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, | |||
| _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, | |||
| _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator) | |||
| _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather) | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| from .control_ops import GeSwitch, Merge | |||
| @@ -191,6 +191,7 @@ class AllGather(PrimitiveWithInfer): | |||
| self.add_prim_attr('rank_size', self.rank_size) | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 0) | |||
| self.add_prim_attr('mean_flag', False) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_positive_int(len(x_shape), "x shape", self.name) | |||
| @@ -239,6 +240,36 @@ class _MiniStepAllGather(PrimitiveWithInfer): | |||
| return x_dtype | |||
| class _MicroStepAllGather(PrimitiveWithInfer): | |||
| """ | |||
| Auto parallel virtual operator. Do nothing in forward, do reduce-scatter in backward in micro-step. It is only for | |||
| internal use of parallel modules and cannot be called by users. | |||
| Args: | |||
| group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None): | |||
| validator.check_value_type('group', _get_group(group), (str,), self.name) | |||
| self.rank = get_rank(_get_group(group)) | |||
| self.rank_size = get_group_size(_get_group(group)) | |||
| validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) | |||
| self.add_prim_attr('rank_size', self.rank_size) | |||
| self.add_prim_attr('group', _get_group(group)) | |||
| self.add_prim_attr('fusion', 1) | |||
| self.mean_flag = mean_flag | |||
| def infer_shape(self, x_shape, z_shape): | |||
| validator.check_positive_int(len(x_shape), "x shape", self.name) | |||
| if x_shape[0] > 0: | |||
| x_shape[0] = x_shape[0] * self.rank_size | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, z_dtype): | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| class _HostAllGather(PrimitiveWithInfer): | |||
| """ | |||
| Gathers tensors from the specified communication group on host. | |||
| @@ -160,6 +160,11 @@ def _get_parameter_broadcast(): | |||
| return parameter_broadcast | |||
| def _get_enable_parallel_optimizer(): | |||
| """Get if using parallel optimizer.""" | |||
| return auto_parallel_context().get_enable_parallel_optimizer() | |||
| def _device_number_check(parallel_mode, device_number): | |||
| """ | |||
| Check device num. | |||
| @@ -173,3 +173,67 @@ def test_pipeline_split_shared_parameter_stage1(): | |||
| optimizer = nn.Lamb(params, learning_rate=0.01) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| def test_pipeline_split_stage0_opt_shard(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| data = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| label = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| strategy1 = ((4, 1), (1, 1)) | |||
| strategy2 = ((2, 1), (1, 1)) | |||
| net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) | |||
| params = net.network.cell.block[0].trainable_params() | |||
| dataset = DatasetLenet(data, label, 3) | |||
| optimizer = nn.Lamb(params, learning_rate=0.01) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| for _, param in model._train_network.parameters_and_names(): | |||
| assert param.name != "cell.block.1.param" | |||
| assert param.name != "cell.block.1.param1" | |||
| def test_pipeline_split_stage1_opt_shard(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| data = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| label = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| strategy1 = ((4, 1), (1, 1)) | |||
| strategy2 = ((2, 1), (1, 1)) | |||
| net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) | |||
| params = net.network.cell.block[1].trainable_params() | |||
| dataset = DatasetLenet(data, label, 3) | |||
| optimizer = nn.Lamb(params, learning_rate=0.01) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| for _, param in model._train_network.parameters_and_names(): | |||
| assert param.name != "cell.block.0.param" | |||
| assert param.name != "cell.block.0.param1" | |||
| def test_pipeline_split_shared_parameter_stage0_opt_shard(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| data = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| label = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| strategy1 = ((4, 1), (1, 1)) | |||
| strategy2 = ((2, 1), (1, 1)) | |||
| net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) | |||
| params = net.network.cell.block[0].trainable_params() | |||
| dataset = DatasetLenet(data, label, 3) | |||
| optimizer = nn.Lamb(params, learning_rate=0.01) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||
| def test_pipeline_split_shared_parameter_stage1_opt_shard(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| data = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||
| label = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| strategy1 = ((4, 1), (1, 1)) | |||
| strategy2 = ((2, 1), (1, 1)) | |||
| net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) | |||
| params = net.network.cell.block[1].trainable_params() | |||
| dataset = DatasetLenet(data, label, 3) | |||
| optimizer = nn.Lamb(params, learning_rate=0.01) | |||
| model = Model(net, optimizer=optimizer) | |||
| model.train(2, dataset, dataset_sink_mode=False) | |||