@@ -69,6 +69,8 @@ void ParallelContext::Reset() {
  pipeline_stage_split_num_ = 1;
  grad_accumulation_step_ = 1;
  communi_parallel_mode_ = ALL_GROUP_PARALLEL;
  optimizer_weight_shard_size_ = -1;
  optimizer_weight_shard_integrated_save_ = false;
}
void ParallelContext::set_device_num(int64_t device_num) {

@@ -132,6 +134,14 @@ void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_sav
  group_ckpt_save_file_ = group_ckpt_save_file;
}
void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
}
void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
  optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
}
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
  all_reduce_fusion_split_indices_[group] = indices;
}

@@ -95,6 +95,11 @@ class ParallelContext {
  bool global_rank_is_set() const { return global_rank_is_set_; }
  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
  const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);

@@ -152,6 +157,8 @@ class ParallelContext {
  bool enable_parallel_optimizer_;
  bool init_param_shape_;
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_integrated_save_;
};
}  // namespace parallel

@@ -473,6 +473,76 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
  return SUCCESS;
}
Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
  if (groups == nullptr) {
    MS_LOG(ERROR) << "The group is null. Operator is " << name_;
    return FAILED;
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  Shape tensor_map = tensor_layout->origin_tensor_map().array();
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    return FAILED;
  }
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1) {
    // not fully use opt shard
    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
    int64_t repeated_size = group_devices.size();
    if (repeated_size % optimizer_weight_shard_size != 0) {
      MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
                      << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
      return FAILED;
    }
    repeated_size = repeated_size / optimizer_weight_shard_size;
    // create allgather group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
    RankList new_group_devices(
      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
    Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
    // create mirror group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
    int64_t device_num = g_device_manager->stage_device_num();
    Shape dev_mat = {repeated_size, device_num / repeated_size};
    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
    RankList mirror_group_devices;
    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
      return FAILED;
    }
    Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
    groups->push_back(mirror_group);
    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
  } else {
    // fully use opt shard
    // create allgather group
    Group allgather_group = g_device_manager->CreateGroup(group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
  }
  // save in tensor_layout for strategy ckpt
  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
  if (!integrated_save) {
    tensor_layout->set_opt_weight_shard_size(optimizer_weight_shard_size);
    int32_t opt_weight_shard_step = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
    tensor_layout->set_opt_weight_shard_step(opt_weight_shard_step);
    MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
  }
  return SUCCESS;
}
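To make the grouping arithmetic above concrete, here is a minimal Python sketch (the standalone function is hypothetical, not MindSpore API) of how CreateGroupForOptShard derives the allgather subgroup, the mirror group, and the opt_weight_shard_step recorded in the strategy checkpoint:

    # Hypothetical sketch of the grouping arithmetic in CreateGroupForOptShard.
    def split_opt_shard_groups(group_devices, rank, shard_size):
        # e.g. group_devices = [0, 8, 16, 24], shard_size = 2
        assert len(group_devices) % shard_size == 0
        index = group_devices.index(rank)
        start = index // shard_size * shard_size
        # allgather subgroup containing this rank: [0, 8] or [16, 24]
        allgather_group = group_devices[start:start + shard_size]
        # mirror group pairs ranks holding the same slice: [0, 16] or [8, 24]
        # (the C++ code derives it from a reshaped device matrix instead)
        mirror_group = group_devices[index % shard_size::shard_size]
        # step between adjacent ranks of the repeated group, saved in the ckpt
        step = (group_devices[-1] - group_devices[0]) // (len(group_devices) - 1)
        return allgather_group, mirror_group, step

    assert split_opt_shard_groups([0, 8, 16, 24], 16, 2) == ([16, 24], [0, 16], 8)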
Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
  if (group == nullptr) {
    MS_LOG(ERROR) << "The group is null.";

@@ -177,6 +177,7 @@ class OperatorInfo {
  void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
  int32_t stage_id() const { return stage_id_; }
  Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
  Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
  // Key for user data.
  constexpr static char key[] = "OpInfo";

@@ -39,7 +39,6 @@
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/node_check.h"
#include "frontend/parallel/ops_info/matmul_info.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "utils/comm_manager.h"

@@ -1069,7 +1068,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
                  << param_v.size();
  }
  auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
  if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
    return std::make_pair(nullptr, true);
  }
  return std::make_pair(node, true);

@@ -1077,6 +1076,14 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
  return std::make_pair(nullptr, false);
}
static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  auto param_ptr = node->user_data<parallel::TensorLayout>();
  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
    return std::make_pair(nullptr, false);
  }
  return std::make_pair(node, false);
}
// Only used for InsertMirrorOps
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {

@@ -1084,11 +1091,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
  }
  if (node->isa<Parameter>()) {
    auto param_ptr = node->user_data<parallel::TensorLayout>();
    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
      return std::make_pair(nullptr, false);
    }
    return std::make_pair(node, false);
    return FindParameterByParameter(node, func_graph);
  }
  if (node->isa<ValueNode>()) {

@@ -1109,8 +1112,9 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
  if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
    return std::make_pair(node, false);
  }
  if (IsParallelCareNode(cnode)) {
  // When opt shard is not fully used, both allgather and mirror ops are inserted.
  // Skip allgather here and find the parameter recursively.
  if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
    return std::make_pair(nullptr, false);
  }

@@ -1238,10 +1242,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name;
    if (param_ptr != nullptr) {
    if (param_ptr) {
      param_name = param_ptr->name();
      std::string opt_shard_mirror_group;
      if (param_ptr->user_data<TensorLayout>()) {
        opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
      }
      if (!opt_shard_mirror_group.empty()) {
        // mirror ops are covered in the not-fully-used opt shard case
        backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
      }
    }
    // not a RefKey
    if (!param_node_pair.second) {
      int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();

@@ -1275,26 +1286,23 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
      }
      std::string instance_name = MIRROR_OP;
      CNodePtr cnode = node->input(index)->cast<CNodePtr>();
      auto op = backward_op[0];
      if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
        for (auto &op : backward_op) {
          // insert new node before the node
          MS_EXCEPTION_IF_NULL(cnode);
          AnfNodePtr pre_node = cnode->input(1);
          InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
          auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
          // add fusion flag
          AddCommOpFusionType(comm_op, param_node_pair.first);
        }
        continue;
      }
      for (auto &op : backward_op) {
        AnfNodePtr pre_node = node->input(index);
        InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
        auto comm_op = node->input(index)->cast<CNodePtr>();
        // insert new node before the node
        MS_EXCEPTION_IF_NULL(cnode);
        AnfNodePtr pre_node = cnode->input(1);
        InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
        // add fusion flag
        // pipeline mirror would not be set, which should be supported later
        AddCommOpFusionType(comm_op, param_node_pair.first);
        continue;
      }
      AnfNodePtr pre_node = node->input(index);
      InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
      auto comm_op = node->input(index)->cast<CNodePtr>();
      // add fusion flag
      // pipeline mirror would not be set, which should be supported later
      AddCommOpFusionType(comm_op, param_node_pair.first);
    }
  }

@@ -1695,7 +1703,11 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
        manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                     << GetPrimName(cnode);
        continue;
      } else {
        // insert allgather operator between shard parameter and cnode
        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                     << GetPrimName(cnode);
      }
    } else {
      // insert allgather operator between shard parameter and cnode

@@ -1708,6 +1720,35 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
  }
}
static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
                                    const OperatorInfoPtr &distribute_operator) {
  std::string opt_shard_group;
  if (!ParameterRequireGrad(parameter)) {
    // only trainable parameters need parallel optimizer
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
  } else if (parameter->cast<ParameterPtr>()->param_info() &&
             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
  } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
    // get the shard tensor slice shape if the weight is repeated on devices
    // and the shape of the first dimension could be divided
    // apply parallel optimizer on parameters
    // create communication group for allgather operator
    std::vector<Group> dev_group;
    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
        !dev_group.empty()) {
      opt_shard_group = dev_group[0].name();
      MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
    } else {
      MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
    }
  } else {
    MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
                    << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
  }
  return opt_shard_group;
}
// When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
  MS_EXCEPTION_IF_NULL(parameter);

@@ -1731,33 +1772,10 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
  if (enable_parallel_optimizer) {
    if (!ParameterRequireGrad(parameter)) {
      // only trainable parameters need parallel optimizer
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
    } else if (parameter->cast<ParameterPtr>()->param_info() &&
               !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
    } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
      // get a totally shard tensor slice shape if the weight is repeated on devices
      // and the shape of the first dimension could be divided
      // apply parallel optimizer on parameters
      // create communication group for allgather operator
      slice_shape = tensor_layout.opt_shard_slice_shape();
      std::vector<Group> dev_group;
      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
            Status::SUCCESS &&
          !dev_group.empty()) {
        opt_shard_group = dev_group[0].name();
        // set communication group in tensor layout for checkpoint saving
        tensor_layout.set_opt_shard_group(opt_shard_group);
        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
                     << " success.";
      } else {
        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
      }
    } else {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
    }
    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
  }
  if (!opt_shard_group.empty()) {
    slice_shape = tensor_layout.opt_shard_slice_shape();
  }
  MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
               << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();

@@ -2812,21 +2830,21 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
         IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
}
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
| MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: " | |||
                    << MAX_RECURSIVE_DEPTH;
    return {};
  }
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  std::vector<std::pair<std::string, int64_t>> param_names;
  ParameterMap param_names;
  for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
    int64_t idx = index > i ? index : i;
    auto input = node_inputs[i];
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
        param_names.push_back({input_parameter->name(), idx});
        param_names.push_back({input_parameter->name(), input_parameter});
      }
    } else if (input->isa<CNode>()) {
      CNodePtr cnode = input->cast<CNodePtr>();

@@ -2878,10 +2896,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
      std::string stratey_key_name = prim->name() + "_" + param_name;
      stra_map[stratey_key_name] = operator_info->strategy();
      for (auto param_name_pair : param_names) {
        if (param_name_pair.second - 1 >= UlongToLong(input_tensor_info.size())) {
          continue;
        }
        tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
        tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
      }
      if (IsGatherPInfo(operator_info->name())) {
        auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);

@@ -32,6 +32,7 @@
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;

@@ -139,7 +140,7 @@ bool IsLastStage();
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                           const FuncGraphManagerPtr &manager);
std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);

@@ -17,7 +17,6 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include <fstream>
#include <memory>
#include <vector>
#include "utils/ms_utils.h"

@@ -141,20 +140,19 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
    }
  }
  for (auto &node_tensor_info : tensor_info_map) {
    TensorInfo tensor_info = node_tensor_info.second;
    TensorLayout tensor_layout = tensor_info.tensor_layout();
    TensorLayoutPtr tensor_layout = node_tensor_info.second;
    straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
    MS_EXCEPTION_IF_NULL(parallel_layout_item);
    parallel_layout_item->set_param_name(node_tensor_info.first);
    straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
    straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
    MS_EXCEPTION_IF_NULL(dev_matrix);
    for (auto dim : tensor_layout.device_arrangement().array()) {
    for (auto dim : tensor_layout->device_arrangement().array()) {
      dev_matrix->add_dim(LongToUlong(dim));
    }
    straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
    MS_EXCEPTION_IF_NULL(tensor_map);
    for (auto dim : tensor_layout.tensor_map().array()) {
    for (auto dim : tensor_layout->tensor_map().array()) {
      tensor_map->add_dim(dim);
    }
    straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();

@@ -165,7 +163,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
      param_split_shape->add_dim(dim_pair.first);
      indices_offset->add_dim(dim_pair.second);
    }
    parallel_layouts->set_field(tensor_layout.get_field_size());
    parallel_layouts->set_field(tensor_layout->get_field_size());
    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
  }
  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);

@@ -21,6 +21,7 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <memory>
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/context.h"

@@ -30,7 +31,9 @@
namespace mindspore {
namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {

@@ -21,6 +21,7 @@
#include "ir/value.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/tensor_layout/shape_util.h"
#include "utils/log_adapter.h"

@@ -431,6 +432,10 @@ Status TensorLayout::GenerateOptShardSliceShape() {
  int64_t repeated_num =
    std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
  int64_t split_num;
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1) {
    repeated_num = optimizer_weight_shard_size;
  }
  if (tensor_map[0] == MAP_NONE) {
    split_num = repeated_num;
  } else {
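Roughly, the override amounts to the following (a sketch assuming the first tensor dimension is not otherwise sharded, i.e. tensor_map[0] == MAP_NONE; the helper name is made up):

    # Hypothetical sketch: optimizer_weight_shard_size caps the opt shard split.
    def opt_shard_split_num(repeated_num, optimizer_weight_shard_size=-1):
        # -1 means fully use the parallel optimizer over the repeated dimension
        if optimizer_weight_shard_size != -1:
            repeated_num = optimizer_weight_shard_size
        return repeated_num  # the first weight dimension is divided by this

    assert opt_shard_split_num(8) == 8     # fully sharded across 8 repeats
    assert opt_shard_split_num(8, 2) == 2  # only groups of 2 shard the weight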
@@ -104,6 +104,18 @@ class TensorLayout {
  std::string opt_shard_group() { return opt_shard_group_; }
  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }
  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }
  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }
  int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }
  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }
  int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }
  // Key for user data.
  constexpr static char key[] = "TLayout";

@@ -129,7 +141,10 @@ class TensorLayout {
  bool layout_transfer_ = false;
  int32_t field_size_ = 0;
  Shape opt_shard_slice_shape_;
  std::string opt_shard_group_ = "";
  std::string opt_shard_group_ = "";         // for allgather
  std::string opt_shard_mirror_group_ = "";  // for mirror ops
  int32_t opt_weight_shard_step_ = 0;
  int32_t opt_weight_shard_size_ = 0;
};
}  // namespace parallel
}  // namespace mindspore

@@ -173,6 +173,14 @@ PYBIND11_MODULE(_c_expression, m) {
         "Get enable/disable parallel optimizer.")
    .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
    .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
| .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size, | |||
| "Set opt shard group size when not fully use parallel optimizer.") | |||
| .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size, | |||
| "Get opt shard group size when not fully use parallel optimizer.") | |||
| .def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save, | |||
| "Set whether to integrated save weight shard when enable parallel optimizer.") | |||
| .def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save, | |||
| "Get whether to integrated save weight shard when enable parallel optimizer.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -54,6 +54,8 @@ message ParallelLayouts { | |||
| repeated ParamSplitShape param_split_shape = 3; | |||
| repeated IndicesOffset indices_offset = 4; | |||
| required int32 field = 5; | |||
| required int32 opt_weight_shard_step = 6; | |||
| required int32 opt_weight_shard_size = 7; | |||
| } | |||
| message ParallelLayoutItem { | |||
| @@ -14,11 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/parallel_node_check.h" | |||
| #include <set> | |||
| #include <string> | |||
| #include "utils/parallel_node_check.h" | |||
| #include "base/core_ops.h" | |||
| namespace mindspore { | |||
| @@ -32,6 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, | |||
| "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", | |||
| "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", | |||
| "stop_gradient", "Send", "UpdateState", "Load"}; | |||
| static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather}; | |||
| // clang-format on | |||
| bool IsInParallelBlackList(const PrimitivePtr &prim) { | |||
| @@ -39,6 +39,15 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) { | |||
| return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end()); | |||
| } | |||
| bool IsInAllGatherNodeList(const CNodePtr &cnode) { | |||
| for (auto &value : ALLGATHER_NODE_LIST_) { | |||
| if (IsPrimitiveCNode(cnode, value)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool IsParallelConsiderCNode(const CNodePtr &cnode) { | |||
| if (cnode == nullptr || cnode->size() == 0) { | |||
| return false; | |||
| @@ -51,9 +60,6 @@ bool IsParallelConsiderCNode(const CNodePtr &cnode) { | |||
| if (prim == nullptr) { | |||
| return false; | |||
| } | |||
| if (IsInParallelBlackList(prim)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| return !IsInParallelBlackList(prim); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| namespace mindspore { | |||
| bool IsInParallelBlackList(const PrimitivePtr &); | |||
| bool IsInAllGatherNodeList(const CNodePtr &); | |||
| bool IsParallelConsiderCNode(const CNodePtr &); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_ | |||
| @@ -15,6 +15,7 @@ | |||
| """Context of auto parallel""" | |||
| import threading | |||
| import mindspore.context as context | |||
| import mindspore.log as logger | |||
| from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size | |||
| from mindspore.parallel._ps_context import _is_role_pserver | |||
| from mindspore._c_expression import AutoParallelContext | |||
| @@ -501,6 +502,48 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_communi_parallel_mode() | |||
| def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): | |||
| """ | |||
| Set optimizer_weight_shard_size. | |||
| Args: | |||
            optimizer_weight_shard_size (int): Opt shard group size when the parallel optimizer
                is not fully used across devices.
| """ | |||
| self.check_context_handle() | |||
| if not isinstance(optimizer_weight_shard_size, int): | |||
| raise TypeError('optimizer_weight_shard_size is invalid type') | |||
| if optimizer_weight_shard_size <= 1: | |||
| logger.warning("The setting 'optimizer_weight_shard_size' is invalid. " | |||
| "Please use the integer larger than 1.") | |||
            return
        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)

    def get_optimizer_weight_shard_size(self):
        """Get optimizer_weight_shard_size."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_size()

    def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
        """
        Set optimizer_weight_shard_integrated_save.

        Args:
            optimizer_weight_shard_integrated_save (bool): Whether to save integrated weight shards
                when the parallel optimizer is enabled.
| """ | |||
| self.check_context_handle() | |||
| if not isinstance(optimizer_weight_shard_integrated_save, bool): | |||
| raise TypeError('optimizer_weight_shard_integrated_save is invalid type') | |||
| self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save) | |||
| def get_optimizer_weight_shard_integrated_save(self): | |||
| """Get optimizer_weight_shard_size.""" | |||
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_integrated_save()

    def reset(self):
        """Reset all settings."""
        self.check_context_handle()

@@ -540,7 +583,9 @@ _set_auto_parallel_context_func_map = {
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
    "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}

_get_auto_parallel_context_func_map = {

@@ -559,7 +604,9 @@ _get_auto_parallel_context_func_map = {
    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
    "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}

@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,

@@ -567,7 +614,8 @@ _get_auto_parallel_context_func_map = {
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                 communi_parallel_mode=str)
                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
                 optimizer_weight_shard_integrated_save=bool)
def _set_auto_parallel_context(**kwargs):
    """

@@ -615,7 +663,7 @@ def _set_auto_parallel_context(**kwargs):
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
            the devices are distributed along the pipeline. The total devices will be divided into
            'pipeline_stages' stages. This currently could only be used when
            parall mode semi_auto_parallel is enabled. Default: 0
            parallel mode semi_auto_parallel is enabled. Default: 0
        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
| "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel". | |||
| @@ -624,6 +672,11 @@ def _set_auto_parallel_context(**kwargs): | |||
| - same_server_group_parallel: Only the communication groups within the same server are parallel. | |||
| - no_group_parallel: All communication groups are not parallel. | |||
        optimizer_weight_shard_size (int): Set the optimizer shard group size when not fully using the parallel
            optimizer. It should be larger than one and less than or equal to the data parallel size.
            Default: -1, which means the parallel optimizer is fully used in the data parallel dimension.
        optimizer_weight_shard_integrated_save (bool): Whether to save integrated weight shards when the parallel
            optimizer is enabled. Default: False.
    Raises:
        ValueError: If an input key is not an attribute of the auto parallel context.
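Putting the two new kwargs together, a typical configuration looks like this (a sketch; network construction and device setup are elided, and the flag combination is based on the tests further below):

    from mindspore import context

    # Shard each repeated weight across groups of 2 devices instead of the whole
    # data-parallel dimension, and merge the shards when saving checkpoints.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                      enable_parallel_optimizer=True,
                                      optimizer_weight_shard_size=2,
                                      optimizer_weight_shard_integrated_save=True)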
@@ -248,6 +248,10 @@ def _remove_repeated_slices(tensor_layout):
def _infer_rank_list(train_map, predict_map=None):
    """infer checkpoint slices to be loaded"""
    ret = {}
    if _get_pipeline_stages() > 1:
        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
    else:
        local_rank = _get_global_rank()
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]
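A quick worked example of the local rank arithmetic (hypothetical numbers):

    # 32 devices split into 4 pipeline stages of 8 devices each;
    # global rank 21 sits in the third stage and is rank 5 within it.
    device_num, pipeline_stages, global_rank = 32, 4, 21
    local_rank = int(global_rank % (device_num / pipeline_stages))
    assert local_rank == 5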
@@ -271,15 +275,13 @@ def _infer_rank_list(train_map, predict_map=None):
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            dev_rank = _get_global_rank()
            ret[param_name] = ([dev_rank], True)
            ret[param_name] = ([local_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                dev_rank = _get_global_rank()
                ret[param_name] = ([rank_list[dev_rank]], True)
                ret[param_name] = ([rank_list[local_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:

@@ -597,7 +597,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        param_data = allgather_net(param_data)
    elif opt_shard_group:
    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
        if allgather_net is None:
            allgather_net = get_allgather_cell(opt_shard_group, False)
            net.parallel_parameter_merge_net_dict[param_name] = allgather_net

@@ -1247,7 +1247,9 @@ def _convert_to_list(strategy):
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            field_size = int(layout.field)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
            shard_stride = int(layout.opt_weight_shard_step)
            shard_size = int(layout.opt_weight_shard_size)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
        except BaseException as e:
            raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
    return train_map
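Illustratively (hypothetical values), an entry of the returned train_map now carries six fields instead of four:

    train_map = {}
    # [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
    train_map["fc1.weight"] = [[4, 8], [1, -1], [32], 0, 8, 2]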
@@ -131,6 +131,17 @@ def test_auto_parallel_momentum_5():
    assert not param_dict["weight2"][5]

def test_auto_parallel_momentum_6():
    # test partial use of the parallel optimizer with optimizer_weight_shard_size
    # weight1 could not be sharded and weight2 is repeated
    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
    param_dict = train_network.parameter_layout_dict
    # validate opt_shard_group
    assert param_dict["weight1"][5].startswith("2")
    assert param_dict["weight2"][5].startswith("2")

def test_AdamWeightDecay():
    """ test_AdamWeightDecay """
    context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)

@@ -59,6 +59,10 @@ def test_set_auto_parallel_context():
    parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
    assert parameter_broadcast_is_set

    auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
    assert integrated_save

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=0)

@@ -105,6 +109,7 @@ def test_reset_auto_parallel_context():
    parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
    stage = auto_parallel_context().get_pipeline_stages()
    communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()

    assert device_num == 1
    assert global_rank == 0

@@ -116,3 +121,4 @@ def test_reset_auto_parallel_context():
    assert not parameter_broadcast_is_set
    assert stage == 1
    assert communi_parallel_mode == "all_group_parallel"
    assert not integrated_save