|
|
|
@@ -255,6 +255,38 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices |
|
|
|
// otherwise return false |
|
|
|
bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) { |
|
|
|
if (axis_ != 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (strategy.size() != 2) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
Dimensions param_strategy = strategy[0]; |
|
|
|
Dimensions indices_strategy = strategy[1]; |
|
|
|
if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (param_strategy[0] * indices_strategy[0] != stage_device_size_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
@@ -286,6 +318,9 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) { |
|
|
|
MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward"; |
|
|
|
axis_split_forward_allreduce_ = true; |
|
|
|
} else if (is_auto_parallel_) { |
|
|
|
// in auto parallel mode, this function will be called many times, so need to reset the flags |
|
|
|
axis_split_forward_allreduce_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
if (manual_split_) { |
|
|
|
@@ -296,6 +331,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (ShardBatchAndAxis(strategy->GetInputDim())) { |
|
|
|
shard_batch_and_axis_ = true; |
|
|
|
axis_split_forward_allreduce_ = true; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce"; |
|
|
|
return SUCCESS; |
|
|
|
} else if (is_auto_parallel_) { |
|
|
|
// in auto parallel mode, this function will be called many times, so need to reset the flags |
|
|
|
shard_batch_and_axis_ = false; |
|
|
|
axis_split_forward_allreduce_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 |
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; |
|
|
|
@@ -357,6 +403,15 @@ Status GatherPInfo::InferDevMatrixShape() { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_batch_and_axis_) { |
|
|
|
dev_matrix_shape_ = {index_strategy[0], param_strategy[0]}; |
|
|
|
// if forward use reducescatter, the dev matrix is {index_strategy[0] * param_strategy[0]} |
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_ |
|
|
|
<< ", out dev matrix is " << out_dev_matrix_shape_; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
dev_matrix_shape_ = param_strategy; |
|
|
|
|
|
|
|
// param_strategy(axis)==1, |
|
|
|
@@ -473,6 +528,13 @@ Status GatherPInfo::InferTensorMap() { |
|
|
|
outputs_tensor_map_.push_back({-1, 1, 0}); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_batch_and_axis_) { |
|
|
|
inputs_tensor_map_.push_back({0, -1}); // param |
|
|
|
inputs_tensor_map_.push_back({1, -1}); // indices |
|
|
|
outputs_tensor_map_.push_back({1, -1, -1}); // output, if forward use reducescatter, tensormap is {0, -1, -1} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
InferInputsTensorMap(); |
|
|
|
InferOutputsTensorMap(); |
|
|
|
return SUCCESS; |
|
|
|
@@ -516,6 +578,15 @@ Status GatherPInfo::InferBias() { |
|
|
|
int64_t rank = g_device_manager->rank_index_in_stage(); |
|
|
|
auto input_shape = inputs_shape_.at(0); |
|
|
|
auto params_strategy = strategy_->GetInputDim().at(0); |
|
|
|
|
|
|
|
if (shard_batch_and_axis_) { |
|
|
|
slice_size_ = input_shape[0] / params_strategy[0]; |
|
|
|
bias_ = rank % params_strategy[0] * slice_size_; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_ |
|
|
|
<< ", bias is " << bias_; |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// axis don't split |
|
|
|
if (params_strategy.at(axis_) == 1) { |
|
|
|
bias_ = 0; |
|
|
|
@@ -598,6 +669,11 @@ Status GatherPInfo::InferGroup() { |
|
|
|
dim = dim + 1; |
|
|
|
} |
|
|
|
|
|
|
|
if (shard_batch_and_axis_) { |
|
|
|
dim = 1; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim; |
|
|
|
} |
|
|
|
|
|
|
|
if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Create group failed."; |
|
|
|
return FAILED; |
|
|
|
|