!15180 enable not fully use opt shard

From: @gong_zi_yan
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
pull/15180/MERGE
mindspore-ci-bot (Gitee), 4 years ago
commit ee885b4e58
19 changed files with 305 additions and 87 deletions
 1. mindspore/ccsrc/frontend/parallel/context.cc  (+10 -0)
 2. mindspore/ccsrc/frontend/parallel/context.h  (+7 -0)
 3. mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc  (+70 -0)
 4. mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h  (+1 -0)
 5. mindspore/ccsrc/frontend/parallel/step_parallel.cc  (+77 -62)
 6. mindspore/ccsrc/frontend/parallel/step_parallel.h  (+2 -1)
 7. mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc  (+6 -6)
 8. mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h  (+4 -1)
 9. mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc  (+5 -0)
10. mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h  (+16 -1)
11. mindspore/ccsrc/pipeline/jit/init.cc  (+8 -0)
12. mindspore/ccsrc/utils/node_strategy.proto  (+2 -0)
13. mindspore/core/utils/parallel_node_check.cc  (+12 -6)
14. mindspore/core/utils/parallel_node_check.h  (+1 -0)
15. mindspore/parallel/_auto_parallel_context.py  (+57 -4)
16. mindspore/parallel/_utils.py  (+6 -4)
17. mindspore/train/serialization.py  (+4 -2)
18. tests/ut/python/parallel/test_parallel_optimizer.py  (+11 -0)
19. tests/ut/python/parallel/test_set_auto_parallel_context.py  (+6 -0)

mindspore/ccsrc/frontend/parallel/context.cc  (+10 -0)

@@ -69,6 +69,8 @@ void ParallelContext::Reset() {
   pipeline_stage_split_num_ = 1;
   grad_accumulation_step_ = 1;
   communi_parallel_mode_ = ALL_GROUP_PARALLEL;
+  optimizer_weight_shard_size_ = -1;
+  optimizer_weight_shard_integrated_save_ = false;
 }
 
 void ParallelContext::set_device_num(int64_t device_num) {
@@ -132,6 +134,14 @@ void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_sav
   group_ckpt_save_file_ = group_ckpt_save_file;
 }
 
+void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) {
+  optimizer_weight_shard_size_ = optimizer_weight_shard_size;
+}
+
+void ParallelContext::set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save) {
+  optimizer_weight_shard_integrated_save_ = optimizer_weight_shard_integrated_save;
+}
+
 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
   all_reduce_fusion_split_indices_[group] = indices;
 }


mindspore/ccsrc/frontend/parallel/context.h  (+7 -0)

@@ -95,6 +95,11 @@ class ParallelContext {
   bool global_rank_is_set() const { return global_rank_is_set_; }
   bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
 
+  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
+  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
+  void set_optimizer_weight_shard_integrated_save(bool optimizer_weight_shard_integrated_save);
+  bool optimizer_weight_shard_integrated_save() const { return optimizer_weight_shard_integrated_save_; }
+
   void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
   const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
   void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
@@ -152,6 +157,8 @@ class ParallelContext {
   bool enable_parallel_optimizer_;
   bool init_param_shape_;
   std::string communi_parallel_mode_;
+  int64_t optimizer_weight_shard_size_;
+  bool optimizer_weight_shard_integrated_save_;
 };
 
 }  // namespace parallel


mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc  (+70 -0)

@@ -473,6 +473,76 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector
   return SUCCESS;
 }
 
+Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
+  if (groups == nullptr) {
+    MS_LOG(ERROR) << "The group is null. Operator is " << name_;
+    return FAILED;
+  }
+  CheckGlobalDeviceManager();
+  int64_t rank = g_device_manager->global_rank();
+  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
+  RankList group_devices;
+  Shape tensor_map = tensor_layout->origin_tensor_map().array();
+  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
+    return FAILED;
+  }
+
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "The dev size is 1, no need to create group.";
+    return SUCCESS;
+  }
+  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
+  if (optimizer_weight_shard_size != -1) {
+    // not fully use opt shard
+    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
+    int64_t repeated_size = group_devices.size();
+    if (repeated_size % optimizer_weight_shard_size != 0) {
+      MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
+                      << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
+      return FAILED;
+    }
+    repeated_size = repeated_size / optimizer_weight_shard_size;
+    // create allgather group
+    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
+    RankList new_group_devices(
+      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
+      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
+    Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
+    groups->push_back(allgather_group);
+    tensor_layout->set_opt_shard_group(allgather_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
+    // create mirror group
+    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
+    int64_t device_num = g_device_manager->stage_device_num();
+    Shape dev_mat = {repeated_size, device_num / repeated_size};
+    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
+    RankList mirror_group_devices;
+    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
+      return FAILED;
+    }
+    Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
+    groups->push_back(mirror_group);
+    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
+  } else {
+    // fully use opt shard
+    // create allgather group
+    Group allgather_group = g_device_manager->CreateGroup(group_devices);
+    groups->push_back(allgather_group);
+    tensor_layout->set_opt_shard_group(allgather_group.name());
+    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
+  }
+  // save in tensor_layout for strategy ckpt
+  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_integrated_save();
+  if (!integrated_save) {
+    tensor_layout->set_opt_weight_shard_size(optimizer_weight_shard_size);
+    int32_t opt_weight_shard_step = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
+    tensor_layout->set_opt_weight_shard_step(opt_weight_shard_step);
+    MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
+  }
+  return SUCCESS;
+}
+
 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
   if (group == nullptr) {
     MS_LOG(ERROR) << "The group is null.";
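The index arithmetic in CreateGroupForOptShard is easiest to check in isolation. A minimal Python sketch of the same partitioning (illustrative only, not MindSpore API): each rank joins a contiguous allgather group of optimizer_weight_shard_size ranks, plus a mirror group made of the ranks at the same offset in every chunk.

    def split_opt_shard_groups(group_devices, rank, shard_size):
        """Partition repeated ranks into an allgather group and a mirror group."""
        if len(group_devices) % shard_size != 0:
            raise ValueError("repeated size must be divisible by optimizer_weight_shard_size")
        index = group_devices.index(rank)
        start = index // shard_size * shard_size
        allgather = group_devices[start:start + shard_size]     # holds this rank's chunk
        mirror = group_devices[index % shard_size::shard_size]  # same slice across chunks
        return allgather, mirror

    # optimizer_weight_shard_size = 2 over ranks [0, 8, 16, 24], as in the comments above:
    assert split_opt_shard_groups([0, 8, 16, 24], 16, 2) == ([16, 24], [0, 16])
    assert split_opt_shard_groups([0, 8, 16, 24], 8, 2) == ([0, 8], [8, 24])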


mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h  (+1 -0)

@@ -177,6 +177,7 @@ class OperatorInfo {
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
+  Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
 
   // Key for user data.
   constexpr static char key[] = "OpInfo";


mindspore/ccsrc/frontend/parallel/step_parallel.cc  (+77 -62)

@@ -39,7 +39,6 @@
 #include "frontend/parallel/graph_util/node_info.h"
 #include "frontend/parallel/node_check.h"
 #include "frontend/parallel/ops_info/matmul_info.h"
-#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "ir/param_info.h"
 #include "ir/tensor.h"
 #include "utils/comm_manager.h"
@@ -1069,7 +1068,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
                     << param_v.size();
     }
     auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
-    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+    if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
      return std::make_pair(nullptr, true);
    }
    return std::make_pair(node, true);
@@ -1077,6 +1076,14 @@ static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &no
   return std::make_pair(nullptr, false);
 }
 
+static std::pair<AnfNodePtr, bool> FindParameterByParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  auto param_ptr = node->user_data<parallel::TensorLayout>();
+  if (param_ptr && !param_ptr->opt_shard_group().empty() && param_ptr->opt_shard_mirror_group().empty()) {
+    return std::make_pair(nullptr, false);
+  }
+  return std::make_pair(node, false);
+}
+
 // Only used for InsertMirrorOps
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
@@ -1084,11 +1091,7 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
   }
 
   if (node->isa<Parameter>()) {
-    auto param_ptr = node->user_data<parallel::TensorLayout>();
-    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
-      return std::make_pair(nullptr, false);
-    }
-    return std::make_pair(node, false);
+    return FindParameterByParameter(node, func_graph);
   }
 
   if (node->isa<ValueNode>()) {
@@ -1109,8 +1112,9 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
     if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
       return std::make_pair(node, false);
     }
-
-    if (IsParallelCareNode(cnode)) {
+    // When not fully use opt shard, allgather and mirror would be both inserted.
+    // Skip allgather here and find parameter recursively.
+    if (IsParallelCareNode(cnode) && !IsInAllGatherNodeList(cnode)) {
      return std::make_pair(nullptr, false);
    }
 
@@ -1238,10 +1242,17 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
 
     auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
     std::string param_name;
-    if (param_ptr != nullptr) {
+    if (param_ptr) {
       param_name = param_ptr->name();
+      std::string opt_shard_mirror_group;
+      if (param_ptr->user_data<TensorLayout>()) {
+        opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
+      }
+      if (!opt_shard_mirror_group.empty()) {
+        // mirror ops is covered in not fully use opt shard case
+        backward_op = CreateMirrorOps(opt_shard_mirror_group, static_cast<size_t>(opt_shard_mirror_group[0]));
+      }
     }
 
     // not a RefKey
     if (!param_node_pair.second) {
       int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
@@ ... @@
     }
     std::string instance_name = MIRROR_OP;
     CNodePtr cnode = node->input(index)->cast<CNodePtr>();
-    if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
-      for (auto &op : backward_op) {
-        // insert new node before the node
-        MS_EXCEPTION_IF_NULL(cnode);
-        AnfNodePtr pre_node = cnode->input(1);
-        InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
-        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
-        // add fusion flag
-        AddCommOpFusionType(comm_op, param_node_pair.first);
-      }
-      continue;
-    }
-    for (auto &op : backward_op) {
-      AnfNodePtr pre_node = node->input(index);
-      InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
-      auto comm_op = node->input(index)->cast<CNodePtr>();
-      // add fusion flag
-      AddCommOpFusionType(comm_op, param_node_pair.first);
-    }
+    auto op = backward_op[0];
+    if (IsCastBeforMirror(node, index) || (cnode != nullptr && IsSomePrimitive(cnode, LOAD))) {
+      // insert new node before the node
+      MS_EXCEPTION_IF_NULL(cnode);
+      AnfNodePtr pre_node = cnode->input(1);
+      InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
+      auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
+      // add fusion flag
+      // pipeline mirror would not be set, which should be supported later
+      AddCommOpFusionType(comm_op, param_node_pair.first);
+      continue;
+    }
+    AnfNodePtr pre_node = node->input(index);
+    InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
+    auto comm_op = node->input(index)->cast<CNodePtr>();
+    // add fusion flag
+    // pipeline mirror would not be set, which should be supported later
+    AddCommOpFusionType(comm_op, param_node_pair.first);
   }
 }
@@ -1695,7 +1703,11 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
         manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second);
         MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
                      << GetPrimName(cnode);
-        continue;
+      } else {
+        // insert allgather operator between shard parameter and cnode
+        InsertAllGatherOp(root, opt_shard_group, param_pair, parameter, op_name);
+        MS_LOG(INFO) << "Parallel optimizer is applied between " << parameter->ToString() << " and "
+                     << GetPrimName(cnode);
       }
     } else {
       // insert allgather operator between shard parameter and cnode
@@ -1708,6 +1720,35 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &
   }
 }
 
+static std::string GetOptShardGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
+                                    const OperatorInfoPtr &distribute_operator) {
+  std::string opt_shard_group;
+  if (!ParameterRequireGrad(parameter)) {
+    // only trainable parameters need parallel optimizer
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
+  } else if (parameter->cast<ParameterPtr>()->param_info() &&
+             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
+  } else if (tensor_layout->GenerateOptShardSliceShape() == Status::SUCCESS) {
+    // get the shard tensor slice shape if the weight is repeated on devices
+    // and the shape of the first dimension could be divided
+    // apply parallel optimizer on parameters
+    // create communication group for allgather operator
+    std::vector<Group> dev_group;
+    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
+        !dev_group.empty()) {
+      opt_shard_group = dev_group[0].name();
+      MS_LOG(INFO) << "Parallel optimizer: create group for " << parameter->ToString() << " success.";
+    } else {
+      MS_LOG(ERROR) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
+    }
+  } else {
+    MS_LOG(WARNING) << "Parallel optimizer: " << parameter->ToString() << "'s distributed shape "
+                    << tensor_layout->slice_shape().ToString() << " does not satisfy the conditions.";
+  }
+  return opt_shard_group;
+}
+
 // When this function returns non-empty string, that means parallel optimizer is applied on this parameter.
 std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int64_t> &res) {
   MS_EXCEPTION_IF_NULL(parameter);
@@ -1731,33 +1772,10 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
   if (enable_parallel_optimizer) {
-    if (!ParameterRequireGrad(parameter)) {
-      // only trainable parameters need parallel optimizer
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
-    } else if (parameter->cast<ParameterPtr>()->param_info() &&
-               !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
-    } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
-      // get a totally shard tensor slice shape if the weight is repeated on devices
-      // and the shape of the first dimension could be divided
-      // apply parallel optimizer on parameters
-      // create communication group for allgather operator
-      slice_shape = tensor_layout.opt_shard_slice_shape();
-      std::vector<Group> dev_group;
-      if (distribute_operator->CreateGroupByTensorMap(tensor_layout.origin_tensor_map().array(), &dev_group) ==
-            Status::SUCCESS &&
-          !dev_group.empty()) {
-        opt_shard_group = dev_group[0].name();
-        // set communication group in tensor layout for checkpoint saving
-        tensor_layout.set_opt_shard_group(opt_shard_group);
-        MS_LOG(INFO) << "Parallel optimizer: create group " << opt_shard_group << " for " << parameter->ToString()
-                     << " success.";
-      } else {
-        MS_LOG(WARNING) << "Parallel optimizer: create group for " << parameter->ToString() << " failed.";
-      }
-    } else {
-      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << "'s shape does not satisfy the conditions.";
-    }
+    opt_shard_group = GetOptShardGroup(parameter, &tensor_layout, distribute_operator);
+  }
+  if (!opt_shard_group.empty()) {
+    slice_shape = tensor_layout.opt_shard_slice_shape();
   }
   MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
                << MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
@@ -2812,21 +2830,21 @@ bool IsCohesiveNode(const CNodePtr &cnode) {
          IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather);
 }
 
-std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
+ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) {
   if (curr_depth > MAX_RECURSIVE_DEPTH) {
     MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: "
                     << MAX_RECURSIVE_DEPTH;
     return {};
   }
   std::vector<AnfNodePtr> node_inputs{node->inputs()};
-  std::vector<std::pair<std::string, int64_t>> param_names;
+  ParameterMap param_names;
   for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) {
     int64_t idx = index > i ? index : i;
     auto input = node_inputs[i];
     if (input->isa<Parameter>()) {
       auto input_parameter = input->cast<ParameterPtr>();
       if (input_parameter->has_default() && ParameterRequireGrad(input_parameter)) {
-        param_names.push_back({input_parameter->name(), idx});
+        param_names.push_back({input_parameter->name(), input_parameter});
       }
     } else if (input->isa<CNode>()) {
       CNodePtr cnode = input->cast<CNodePtr>();
@@ -2878,10 +2896,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
       std::string stratey_key_name = prim->name() + "_" + param_name;
       stra_map[stratey_key_name] = operator_info->strategy();
       for (auto param_name_pair : param_names) {
-        if (param_name_pair.second - 1 >= UlongToLong(input_tensor_info.size())) {
-          continue;
-        }
-        tensor_info_map[param_name_pair.first] = input_tensor_info[param_name_pair.second - 1];
+        tensor_info_map[param_name_pair.first] = param_name_pair.second->user_data<TensorLayout>();
       }
       if (IsGatherPInfo(operator_info->name())) {
         auto gatherv2_info = std::dynamic_pointer_cast<GatherPInfo>(operator_info);
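Taken together, the FindParameter and InsertMirrorOps changes give a trainable weight repeated on n devices one of three communication patterns. A rough summary as a hedged Python sketch (illustrative only, not MindSpore API; the middle branch follows the warning path in CreateGroupForOptShard):

    def comm_ops_for(repeated_ranks, optimizer_weight_shard_size=-1):
        """Which collectives a repeated weight ends up with after this change."""
        n = len(repeated_ranks)
        if optimizer_weight_shard_size == -1:
            # fully use opt shard: allgather over all repeats, no mirror needed
            return {"allgather_size": n, "mirror_size": 0}
        if n % optimizer_weight_shard_size != 0:
            # shard size not applicable: fall back to plain mirror, no opt shard
            return {"allgather_size": 0, "mirror_size": n}
        # partial opt shard: allgather inside each chunk, mirror across chunks
        return {"allgather_size": optimizer_weight_shard_size,
                "mirror_size": n // optimizer_weight_shard_size}

    assert comm_ops_for([0, 8, 16, 24], 2) == {"allgather_size": 2, "mirror_size": 2}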


mindspore/ccsrc/frontend/parallel/step_parallel.h  (+2 -1)

@@ -32,6 +32,7 @@
 #include "pipeline/jit/pipeline.h"
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 
 using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
 
@@ -139,7 +140,7 @@ bool IsLastStage();
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager);
 
-std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
+ParameterMap NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth);
 
 void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes);



mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc  (+6 -6)

@@ -17,7 +17,6 @@
 #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 
 #include <fstream>
-#include <memory>
 #include <vector>
 
 #include "utils/ms_utils.h"
@@ -141,20 +140,19 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
     }
   }
   for (auto &node_tensor_info : tensor_info_map) {
-    TensorInfo tensor_info = node_tensor_info.second;
-    TensorLayout tensor_layout = tensor_info.tensor_layout();
+    TensorLayoutPtr tensor_layout = node_tensor_info.second;
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
     straspb::ParallelLayouts *parallel_layouts = parallel_layout_item->mutable_parallel_layouts();
     straspb::DevMatrix *dev_matrix = parallel_layouts->add_dev_matrix();
     MS_EXCEPTION_IF_NULL(dev_matrix);
-    for (auto dim : tensor_layout.device_arrangement().array()) {
+    for (auto dim : tensor_layout->device_arrangement().array()) {
       dev_matrix->add_dim(LongToUlong(dim));
     }
     straspb::TensorMap *tensor_map = parallel_layouts->add_tensor_map();
     MS_EXCEPTION_IF_NULL(tensor_map);
-    for (auto dim : tensor_layout.tensor_map().array()) {
+    for (auto dim : tensor_layout->tensor_map().array()) {
       tensor_map->add_dim(dim);
     }
     straspb::ParamSplitShape *param_split_shape = parallel_layouts->add_param_split_shape();
@@ -165,7 +163,9 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
       param_split_shape->add_dim(dim_pair.first);
       indices_offset->add_dim(dim_pair.second);
     }
-    parallel_layouts->set_field(tensor_layout.get_field_size());
+    parallel_layouts->set_field(tensor_layout->get_field_size());
+    parallel_layouts->set_opt_weight_shard_step(tensor_layout->opt_weight_shard_step());
+    parallel_layouts->set_opt_weight_shard_size(tensor_layout->opt_weight_shard_size());
   }
 
   std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);


mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h  (+4 -1)

@@ -21,6 +21,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <memory>
 #include "frontend/parallel/ops_info/ops_utils.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/context.h"
@@ -30,7 +31,9 @@
 namespace mindspore {
 namespace parallel {
 using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
-using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
+using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
+using TensorInfoMap = std::unordered_map<std::string, TensorLayoutPtr>;
+using ParameterMap = std::vector<std::pair<std::string, ParameterPtr>>;
 using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
 using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
 class StrategyCheckpoint {


mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc  (+5 -0)

@@ -21,6 +21,7 @@
 #include "ir/value.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/status.h"
+#include "frontend/parallel/context.h"
 #include "frontend/parallel/tensor_layout/shape_util.h"
 #include "utils/log_adapter.h"
 
@@ -431,6 +432,10 @@ Status TensorLayout::GenerateOptShardSliceShape() {
   int64_t repeated_num =
     std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
   int64_t split_num;
+  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
+  if (optimizer_weight_shard_size != -1) {
+    repeated_num = optimizer_weight_shard_size;
+  }
   if (tensor_map[0] == MAP_NONE) {
     split_num = repeated_num;
   } else {
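The override above caps how many ways the weight's first dimension is split. A small Python sketch of the resulting slice shape (names are illustrative; the real logic is the rest of GenerateOptShardSliceShape):

    def opt_shard_slice_shape(shape, repeated_num, optimizer_weight_shard_size=-1):
        split_num = repeated_num if optimizer_weight_shard_size == -1 else optimizer_weight_shard_size
        if shape[0] % split_num != 0:
            return []  # first dimension not divisible: opt shard is not applied
        return [shape[0] // split_num] + list(shape[1:])

    # weight repeated on 8 devices: fully sharded 8-ways, or only 2-ways when capped
    assert opt_shard_slice_shape([16, 32], repeated_num=8) == [2, 32]
    assert opt_shard_slice_shape([16, 32], repeated_num=8, optimizer_weight_shard_size=2) == [8, 32]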


mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h  (+16 -1)

@@ -104,6 +104,18 @@ class TensorLayout {
 
   std::string opt_shard_group() { return opt_shard_group_; }
 
+  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }
+
+  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }
+
+  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }
+
+  int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }
+
+  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }
+
+  int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }
+
   // Key for user data.
   constexpr static char key[] = "TLayout";
 
@@ -129,7 +141,10 @@ class TensorLayout {
   bool layout_transfer_ = false;
   int32_t field_size_ = 0;
   Shape opt_shard_slice_shape_;
-  std::string opt_shard_group_ = "";
+  std::string opt_shard_group_ = "";         // for allgather
+  std::string opt_shard_mirror_group_ = "";  // for mirror ops
+  int32_t opt_weight_shard_step_ = 0;
+  int32_t opt_weight_shard_size_ = 0;
 };
 }  // namespace parallel
 }  // namespace mindspore


mindspore/ccsrc/pipeline/jit/init.cc  (+8 -0)

@@ -173,6 +173,14 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get enable/disable parallel optimizer.")
     .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
     .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
+    .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
+         "Set opt shard group size when not fully use parallel optimizer.")
+    .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
+         "Get opt shard group size when not fully use parallel optimizer.")
+    .def("set_optimizer_weight_shard_integrated_save", &ParallelContext::set_optimizer_weight_shard_integrated_save,
+         "Set whether to integrated save weight shard when enable parallel optimizer.")
+    .def("get_optimizer_weight_shard_integrated_save", &ParallelContext::optimizer_weight_shard_integrated_save,
+         "Get whether to integrated save weight shard when enable parallel optimizer.")
     .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
 
 (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")


mindspore/ccsrc/utils/node_strategy.proto  (+2 -0)

@@ -54,6 +54,8 @@ message ParallelLayouts {
   repeated ParamSplitShape param_split_shape = 3;
   repeated IndicesOffset indices_offset = 4;
   required int32 field = 5;
+  required int32 opt_weight_shard_step = 6;
+  required int32 opt_weight_shard_size = 7;
 }
 
 message ParallelLayoutItem {


mindspore/core/utils/parallel_node_check.cc  (+12 -6)

@@ -14,11 +14,10 @@
  * limitations under the License.
  */
 
-#include "utils/parallel_node_check.h"
-
 #include <set>
 #include <string>
 
+#include "utils/parallel_node_check.h"
 #include "base/core_ops.h"
 
 namespace mindspore {
@@ -32,6 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
   "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
   "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
   "stop_gradient", "Send", "UpdateState", "Load"};
+static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather};
 // clang-format on
 
 bool IsInParallelBlackList(const PrimitivePtr &prim) {
@@ -39,6 +39,15 @@ bool IsInParallelBlackList(const PrimitivePtr &prim) {
   return (PARALLEL_BLACK_LIST_.find(prim->name()) != PARALLEL_BLACK_LIST_.end());
 }
 
+bool IsInAllGatherNodeList(const CNodePtr &cnode) {
+  for (auto &value : ALLGATHER_NODE_LIST_) {
+    if (IsPrimitiveCNode(cnode, value)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool IsParallelConsiderCNode(const CNodePtr &cnode) {
   if (cnode == nullptr || cnode->size() == 0) {
     return false;
@@ -51,9 +60,6 @@ bool IsParallelConsiderCNode(const CNodePtr &cnode) {
   if (prim == nullptr) {
     return false;
   }
-  if (IsInParallelBlackList(prim)) {
-    return false;
-  }
-  return true;
+  return !IsInParallelBlackList(prim);
 }
 }  // namespace mindspore

mindspore/core/utils/parallel_node_check.h  (+1 -0)

@@ -21,6 +21,7 @@
 
 namespace mindspore {
 bool IsInParallelBlackList(const PrimitivePtr &);
+bool IsInAllGatherNodeList(const CNodePtr &);
 bool IsParallelConsiderCNode(const CNodePtr &);
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_UTILS_PARALLEL_NODE_CHECK_H_

mindspore/parallel/_auto_parallel_context.py  (+57 -4)

@@ -15,6 +15,7 @@
 """Context of auto parallel"""
 import threading
 import mindspore.context as context
+import mindspore.log as logger
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
 from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
@@ -501,6 +502,48 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_communi_parallel_mode()
 
+    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
+        """
+        Set optimizer_weight_shard_size.
+
+        Args:
+            optimizer_weight_shard_size (int): Opt shard group size when not fully using the parallel
+                optimizer across all data-parallel devices.
+        """
+        self.check_context_handle()
+        if not isinstance(optimizer_weight_shard_size, int):
+            raise TypeError("optimizer_weight_shard_size must be an int")
+        if optimizer_weight_shard_size <= 1:
+            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
+                           "Please use an integer larger than 1.")
+            return
+        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
+
+    def get_optimizer_weight_shard_size(self):
+        """Get optimizer_weight_shard_size."""
+        self.check_context_handle()
+        return self._context_handle.get_optimizer_weight_shard_size()
+
+    def set_optimizer_weight_shard_integrated_save(self, optimizer_weight_shard_integrated_save):
+        """
+        Set optimizer_weight_shard_integrated_save.
+
+        Args:
+            optimizer_weight_shard_integrated_save (bool): Whether to save weight shards as one
+                integrated tensor when the parallel optimizer is enabled.
+        """
+        self.check_context_handle()
+        if not isinstance(optimizer_weight_shard_integrated_save, bool):
+            raise TypeError("optimizer_weight_shard_integrated_save must be a bool")
+        self._context_handle.set_optimizer_weight_shard_integrated_save(optimizer_weight_shard_integrated_save)
+
+    def get_optimizer_weight_shard_integrated_save(self):
+        """Get optimizer_weight_shard_integrated_save."""
+        self.check_context_handle()
+        return self._context_handle.get_optimizer_weight_shard_integrated_save()
+
     def reset(self):
         """Reset all settings."""
         self.check_context_handle()
@@ -540,7 +583,9 @@ _set_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
     "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
-    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode}
+    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
+    "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
+    "optimizer_weight_shard_integrated_save": auto_parallel_context().set_optimizer_weight_shard_integrated_save}
 
@@ -559,7 +604,9 @@ _get_auto_parallel_context_func_map = {
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
     "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
-    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode}
+    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
+    "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
+    "optimizer_weight_shard_integrated_save": auto_parallel_context().get_optimizer_weight_shard_integrated_save}
 
@@ -567,7 +614,8 @@ _get_auto_parallel_context_func_map = {
 @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                  grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
-                 communi_parallel_mode=str)
+                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
+                 optimizer_weight_shard_integrated_save=bool)
 def _set_auto_parallel_context(**kwargs):
     """
@@ -615,7 +663,7 @@ def _set_auto_parallel_context(**kwargs):
         pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
             the devices are distributed alone the pipeline. The total devices will be divided into
             'pipeline_stags' stages. This currently could only be used when
-            parall mode semi_auto_parallel is enabled. Default: 0
+            parallel mode semi_auto_parallel is enabled. Default: 0
         communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel",
             "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
@@ -624,6 +672,11 @@ def _set_auto_parallel_context(**kwargs):
             - same_server_group_parallel: Only the communication groups within the same server are parallel.
 
             - no_group_parallel: All communication groups are not parallel.
+        optimizer_weight_shard_size (int): Set the optimizer shard group size when not fully using the parallel
+            optimizer. It should be larger than one and less than or equal to the data parallel size.
+            Default: -1, which means fully use the parallel optimizer in the data parallel dimension.
+        optimizer_weight_shard_integrated_save (bool): Whether to save weight shards as one integrated tensor
+            when the parallel optimizer is enabled. Default: False.
 
     Raises:
         ValueError: If input key is not attribute in auto parallel context.
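A minimal usage sketch of the two new keys documented above (device numbers are illustrative):

    import mindspore.context as context

    # 32 data-parallel devices: shard each eligible weight across groups of 4
    # instead of across all 32, and merge the shards back on checkpoint save.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32,
                                      enable_parallel_optimizer=True,
                                      optimizer_weight_shard_size=4,
                                      optimizer_weight_shard_integrated_save=True)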


mindspore/parallel/_utils.py  (+6 -4)

@@ -248,6 +248,10 @@ def _remove_repeated_slices(tensor_layout):
 def _infer_rank_list(train_map, predict_map=None):
     """infer checkpoint slices to be loaded"""
     ret = {}
+    if _get_pipeline_stages() > 1:
+        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
+    else:
+        local_rank = _get_global_rank()
     for param_name in train_map:
         train_layout = train_map[param_name]
         train_dev_mat = train_layout[0]
@@ -271,15 +275,13 @@ def _infer_rank_list(train_map, predict_map=None):
             dev_num = np.array(predict_layout[0]).prod()
             # optimization pass
             if _check_same_layout(train_layout, predict_layout):
-                dev_rank = _get_global_rank()
-                ret[param_name] = ([dev_rank], True)
+                ret[param_name] = ([local_rank], True)
                 continue
             if _check_similar_layout(train_layout, predict_layout):
                 if len(rank_list) == 1:
                     ret[param_name] = (rank_list, True)
                 elif len(rank_list) == dev_num:
-                    dev_rank = _get_global_rank()
-                    ret[param_name] = ([rank_list[dev_rank]], True)
+                    ret[param_name] = ([rank_list[local_rank]], True)
                 else:
                     ret[param_name] = (rank_list, False)
             else:
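The stage-local rank introduced above simply folds the global rank into its pipeline stage; the arithmetic with illustrative numbers:

    # 32 devices split into 4 pipeline stages -> 8 devices per stage
    device_num, pipeline_stages = 32, 4
    for global_rank in (0, 7, 8, 31):
        local_rank = int(global_rank % (device_num / pipeline_stages))
        print(global_rank, "->", local_rank)  # 0 -> 0, 7 -> 7, 8 -> 0, 31 -> 7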


mindspore/train/serialization.py  (+4 -2)

@@ -597,7 +597,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
             allgather_net = get_allgather_cell(opt_shard_group, False)
             net.parallel_parameter_merge_net_dict[param_name] = allgather_net
         param_data = allgather_net(param_data)
-    elif opt_shard_group:
+    elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_integrated_save"):
         if allgather_net is None:
             allgather_net = get_allgather_cell(opt_shard_group, False)
             net.parallel_parameter_merge_net_dict[param_name] = allgather_net
@@ -1247,7 +1247,9 @@ def _convert_to_list(strategy):
             tensor_map = list(layout.tensor_map[0].dim)
             param_split_shape = list(layout.param_split_shape[0].dim)
             field_size = int(layout.field)
-            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
+            shard_stride = int(layout.opt_weight_shard_step)
+            shard_size = int(layout.opt_weight_shard_size)
+            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
         except BaseException as e:
             raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
     return train_map
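Each train_map entry now carries six fields instead of four. A hypothetical entry for a weight sharded with optimizer_weight_shard_size=2 across ranks [0, 8, 16, 24] (values are illustrative):

    # [dev_mat, tensor_map, param_split_shape, field_size, shard_stride, shard_size]
    train_map["weight1"] = [[4, 8], [1, -1], [16, 32], 0, 8, 2]

Here shard_stride (opt_weight_shard_step) is the rank distance between adjacent holders of the same weight (8 for the rank list above) and shard_size is the opt shard group size; both stay 0 on the integrated-save path, since CreateGroupForOptShard only records them when optimizer_weight_shard_integrated_save is disabled.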


tests/ut/python/parallel/test_parallel_optimizer.py  (+11 -0)

@@ -131,6 +131,17 @@ def test_auto_parallel_momentum_5():
     assert not param_dict["weight2"][5]
 
 
+def test_auto_parallel_momentum_6():
+    # test not fully use parallel optimizer with optimizer_weight_shard_size
+    # weight1 could not be shard and weight2 is repeated
+    context.set_auto_parallel_context(optimizer_weight_shard_size=2)
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert param_dict["weight1"][5].startswith("2")
+    assert param_dict["weight2"][5].startswith("2")
+
+
 def test_AdamWeightDecay():
     """ test_AdamWeightDecay """
     context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)


tests/ut/python/parallel/test_set_auto_parallel_context.py  (+6 -0)

@@ -59,6 +59,10 @@ def test_set_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     assert parameter_broadcast_is_set
 
+    auto_parallel_context().set_optimizer_weight_shard_integrated_save(True)
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
+    assert integrated_save
+
     with pytest.raises(ValueError):
         context.set_auto_parallel_context(device_num=0)
@@ -105,6 +109,7 @@ def test_reset_auto_parallel_context():
     parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
     stage = auto_parallel_context().get_pipeline_stages()
     communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode")
+    integrated_save = auto_parallel_context().get_optimizer_weight_shard_integrated_save()
 
     assert device_num == 1
     assert global_rank == 0
@@ -116,3 +121,4 @@ def test_reset_auto_parallel_context():
     assert not parameter_broadcast_is_set
     assert stage == 1
     assert communi_parallel_mode == "all_group_parallel"
+    assert not integrated_save
