|
|
|
@@ -140,52 +140,56 @@ Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) c |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// Checks whether one spatial dimension (H or W) of the input may be split under the 'same' pad mode.
//
// strategy:  the shard number requested for that dimension.
// dimension: H_DIMENSION selects the H axis (inputs_shape_[0][2], kernel_size_[0], stride_[2]); any other value
//            selects the W axis (inputs_shape_[0][3], kernel_size_[1], stride_[3]).
// Returns SUCCESS if the dimension is not split or the split is supported; otherwise logs the reason and returns
// FAILED.
Status Conv2DInfo::CheckHWStrategySameModeByDimension(int64_t strategy, const std::string &dimension) {
  // Every restriction below only applies to a split dimension. Returning early also keeps the division that
  // computes the slice shape away from a zero shard number.
  if (strategy <= 1) {
    return SUCCESS;
  }

  const bool is_h_dimension = (dimension == H_DIMENSION);
  const int64_t h_or_w_input_shape = is_h_dimension ? inputs_shape_[0][2] : inputs_shape_[0][3];
  const int64_t h_or_w_kernel_size = is_h_dimension ? kernel_size_[0] : kernel_size_[1];
  const int64_t h_or_w_stride = is_h_dimension ? stride_[2] : stride_[3];
  const int64_t h_or_w_slice_shape = h_or_w_input_shape / strategy;

  // kernel_size <= stride: the only requirement is that each slice is a multiple of the stride.
  if (h_or_w_kernel_size <= h_or_w_stride) {
    if (h_or_w_slice_shape % h_or_w_stride != 0) {
      MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
                    << " when kernel_size <= stride but slice shape is not divisible by stride ";
      return FAILED;
    }
    return SUCCESS;
  }

  // kernel_size > stride: three further restrictions apply, checked in this order.
  if (h_or_w_input_shape % h_or_w_stride != 0) {
    MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
                  << " when kernel_size > stride but input shape is not divisible by stride";
    return FAILED;
  }

  if (h_or_w_slice_shape <= ((h_or_w_kernel_size - h_or_w_stride + 1) / 2)) {
    MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
                  << " when kernel_size > stride but slice shape is smaller than or equal to (k - s + 1) / 2";
    return FAILED;
  }

  if (h_or_w_kernel_size - h_or_w_stride == 1) {
    MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
                  << " when kernel_size > stride but k - s == 1";
    return FAILED;
  }
  return SUCCESS;
}
|
|
|
|
|
|
|
// Validates the H and W split strategies for the 'same' pad mode.
// The H dimension is checked first; the W dimension is only checked when H passed.
// Returns SUCCESS only if both per-dimension checks succeed, FAILED otherwise.
Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
  // The short-circuit || keeps the W check from running (and logging) after an H failure.
  const bool rejected = CheckHWStrategySameModeByDimension(h_strategy, H_DIMENSION) != SUCCESS ||
                        CheckHWStrategySameModeByDimension(w_strategy, W_DIMENSION) != SUCCESS;
  return rejected ? FAILED : SUCCESS;
}
|
|
|
|
|
|
|
@@ -293,7 +297,8 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
need_exchange_overlap_ = false; |
|
|
|
h_dim_need_exchange_overlap_ = false; |
|
|
|
w_dim_need_exchange_overlap_ = false; |
|
|
|
if (CheckStrategyBase(strategy) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -313,11 +318,14 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// kernel size larger than stride and the w dimension is split, need to exchange overlap |
|
|
|
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) { |
|
|
|
need_exchange_overlap_ = true; |
|
|
|
// kernel size larger than stride and the h/w dimension is split, need to exchange overlap |
|
|
|
if ((kernel_size_[0] > stride_[2]) && (input_strategy[2] > 1)) { |
|
|
|
h_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
|
|
|
|
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) { |
|
|
|
w_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -333,85 +341,154 @@ Status Conv2DInfo::InferDevMatrixShape() { |
|
|
|
|
|
|
|
dev_matrix_shape_ = stra[0]; |
|
|
|
dev_matrix_shape_.push_back(stra[1][0]); |
|
|
|
h_dimension_shard_num_ = stra[0][2]; |
|
|
|
w_dimension_shard_num_ = stra[0][3]; |
|
|
|
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DInfo::InferRankBias() { |
|
|
|
// the Conv2D operator: |
|
|
|
// the origin dev_matrix is [n, i, h, w, o] |
|
|
|
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o] |
|
|
|
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o, |
|
|
|
// repeated_num] |
|
|
|
// |
|
|
|
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as |
|
|
|
// Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h |
|
|
|
// dimension) |
|
|
|
if (!need_exchange_overlap_) { |
|
|
|
MS_LOG(INFO) << name_ << ": No need to infer rank bias"; |
|
|
|
return SUCCESS; |
|
|
|
std::vector<int64_t> Conv2DInfo::GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension) { |
|
|
|
std::vector<int64_t> ret; |
|
|
|
if (rank_id < 0) { |
|
|
|
ret = {-1, -1, -1, -1, -1}; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The rank id is " << rank_id << ", the dimension is " << dimension; |
|
|
|
|
|
|
|
uint64_t index_in_dev_matrix = -1; |
|
|
|
int64_t dimension_shard_num = 1; |
|
|
|
if (dimension == H_DIMENSION) { |
|
|
|
index_in_dev_matrix = 2; |
|
|
|
dimension_shard_num = h_dimension_shard_num_; |
|
|
|
} else { |
|
|
|
index_in_dev_matrix = 3; |
|
|
|
dimension_shard_num = w_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t w_index_in_dev_matrix = 3; |
|
|
|
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) { |
|
|
|
w_index_in_dev_matrix += 1; |
|
|
|
index_in_dev_matrix += 1; |
|
|
|
} |
|
|
|
|
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
int64_t rank = g_device_manager->global_rank(); |
|
|
|
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_); |
|
|
|
DeviceMatrix dev_matrix(rank_id, stage_device_list_, dev_matrix_shape_); |
|
|
|
RankList group_devices; |
|
|
|
if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
if (dev_matrix.GetDevicesAlongDim(index_in_dev_matrix, &group_devices) != SUCCESS) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Get device along dim failed"; |
|
|
|
} |
|
|
|
|
|
|
|
if (group_devices.size() <= 1) { |
|
|
|
MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size() |
|
|
|
MS_LOG(INFO) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size() |
|
|
|
<< ", no need to infer rank bias"; |
|
|
|
return SUCCESS; |
|
|
|
ret = {-1, -1, -1, -1, -1}; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
if (group_devices.size() != LongToSize(w_dimension_shard_num_)) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size() |
|
|
|
<< ", but the shard num of w dimension is " << w_dimension_shard_num_; |
|
|
|
return FAILED; |
|
|
|
if (group_devices.size() != LongToSize(dimension_shard_num)) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size() |
|
|
|
<< ", but the shard num of w dimension is " << dimension_shard_num; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank); |
|
|
|
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank_id); |
|
|
|
if (it == group_devices.end()) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is " |
|
|
|
<< rank << ", the device list is " << group_devices; |
|
|
|
return FAILED; |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Can not find the current rank in device list of " << dimension |
|
|
|
<< ", the current rank is " << rank_id << ", the device list is " << group_devices; |
|
|
|
} |
|
|
|
|
|
|
|
rank_bias_ = std::distance(group_devices.begin(), it); |
|
|
|
int64_t left_or_top_rank_id = -1; |
|
|
|
int64_t right_or_bottom_rank_id = -1; |
|
|
|
int64_t left_or_top_rank_bias = -1; |
|
|
|
int64_t right_or_bottom_rank_bias = -1; |
|
|
|
int64_t current_rank_bias = -1; |
|
|
|
current_rank_bias = std::distance(group_devices.begin(), it); |
|
|
|
if (it == group_devices.begin()) { |
|
|
|
left_rank_bias_ = -1; |
|
|
|
right_rank_bias_ = rank_bias_ + 1; |
|
|
|
left_or_top_rank_bias = -1; |
|
|
|
right_or_bottom_rank_bias = current_rank_bias + 1; |
|
|
|
|
|
|
|
left_rank_id_ = -1; |
|
|
|
right_rank_id_ = *(it + 1); |
|
|
|
left_or_top_rank_id = -1; |
|
|
|
right_or_bottom_rank_id = *(it + 1); |
|
|
|
} else if (it == group_devices.end() - 1) { |
|
|
|
left_rank_bias_ = rank_bias_ - 1; |
|
|
|
right_rank_bias_ = -1; |
|
|
|
left_or_top_rank_bias = current_rank_bias - 1; |
|
|
|
right_or_bottom_rank_bias = -1; |
|
|
|
|
|
|
|
left_rank_id_ = *(it - 1); |
|
|
|
right_rank_id_ = -1; |
|
|
|
left_or_top_rank_id = *(it - 1); |
|
|
|
right_or_bottom_rank_id = -1; |
|
|
|
} else { |
|
|
|
left_rank_bias_ = rank_bias_ - 1; |
|
|
|
right_rank_bias_ = rank_bias_ + 1; |
|
|
|
left_or_top_rank_bias = current_rank_bias - 1; |
|
|
|
right_or_bottom_rank_bias = current_rank_bias + 1; |
|
|
|
|
|
|
|
left_rank_id_ = *(it - 1); |
|
|
|
right_rank_id_ = *(it + 1); |
|
|
|
left_or_top_rank_id = *(it - 1); |
|
|
|
right_or_bottom_rank_id = *(it + 1); |
|
|
|
} |
|
|
|
|
|
|
|
ret = {left_or_top_rank_id, right_or_bottom_rank_id, left_or_top_rank_bias, right_or_bottom_rank_bias, |
|
|
|
current_rank_bias}; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferAdjacentRankInfo() { |
|
|
|
// the Conv2D operator: |
|
|
|
// the origin dev_matrix is [n, i, h, w, o] |
|
|
|
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o] |
|
|
|
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o, |
|
|
|
// repeated_num] |
|
|
|
// |
|
|
|
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as |
|
|
|
// Conv2D, the w_rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h |
|
|
|
// dimension) |
|
|
|
|
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
int64_t rank = g_device_manager->global_rank(); |
|
|
|
std::vector<int64_t> h_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, H_DIMENSION); |
|
|
|
top_rank_id_ = h_dim_rank_info[0]; |
|
|
|
bottom_rank_id_ = h_dim_rank_info[1]; |
|
|
|
top_rank_bias_ = h_dim_rank_info[2]; |
|
|
|
bottom_rank_bias_ = h_dim_rank_info[3]; |
|
|
|
h_rank_bias_ = h_dim_rank_info[4]; |
|
|
|
|
|
|
|
std::vector<int64_t> w_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, W_DIMENSION); |
|
|
|
left_rank_id_ = w_dim_rank_info[0]; |
|
|
|
right_rank_id_ = w_dim_rank_info[1]; |
|
|
|
left_rank_bias_ = w_dim_rank_info[2]; |
|
|
|
right_rank_bias_ = w_dim_rank_info[3]; |
|
|
|
w_rank_bias_ = w_dim_rank_info[4]; |
|
|
|
|
|
|
|
std::vector<int64_t> top_w_dim_rank_info = GetAdjacentRankIdsAndBiases(top_rank_id_, W_DIMENSION); |
|
|
|
top_left_rank_id_ = top_w_dim_rank_info[0]; |
|
|
|
top_right_rank_id_ = top_w_dim_rank_info[1]; |
|
|
|
|
|
|
|
std::vector<int64_t> bottom_w_dim_rank_info = GetAdjacentRankIdsAndBiases(bottom_rank_id_, W_DIMENSION); |
|
|
|
bottom_left_rank_id_ = bottom_w_dim_rank_info[0]; |
|
|
|
bottom_right_rank_id_ = bottom_w_dim_rank_info[1]; |
|
|
|
|
|
|
|
all_to_all_group_ = g_device_manager->world_group(); // use world group temporarily |
|
|
|
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices |
|
|
|
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_ |
|
|
|
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_ |
|
|
|
<< ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_; |
|
|
|
return SUCCESS; |
|
|
|
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the top rank id is " << top_rank_id_ |
|
|
|
<< ", the top right rank id is " << top_right_rank_id_ << ", the right rank id is " << right_rank_id_ |
|
|
|
<< ", the bottom right rank id is " << bottom_right_rank_id_ << ", the bottom rank id is " |
|
|
|
<< bottom_rank_id_ << ", the bottom left rank id is " << bottom_left_rank_id_ << ", the left rank id is " |
|
|
|
<< left_rank_id_ << ", the top left rank id is " << top_left_rank_id_ << ", the top rank bias is " |
|
|
|
<< top_rank_bias_ << ", the bottom rank bias is " << bottom_rank_bias_ << ", the left rank bias is " |
|
|
|
<< left_rank_bias_ << ", the right rank bias is " << right_rank_bias_ << ", the h dim rank bias is " |
|
|
|
<< h_rank_bias_ << ", the w dim rank bias is " << w_rank_bias_; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DInfo::ComputeOverlapTopSizeByRankBias(int64_t rank_bias) { |
|
|
|
int64_t top_pad = pad_list_[0]; |
|
|
|
int64_t h_dimension_input_shape = inputs_shape_[0][2]; |
|
|
|
int64_t h_dimension_output_shape = outputs_shape_[0][2]; |
|
|
|
int64_t h_stride = stride_[2]; |
|
|
|
|
|
|
|
return top_pad + (h_dimension_input_shape - h_dimension_output_shape * h_stride) * rank_bias / h_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DInfo::ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) { |
|
|
|
int64_t top_pad = pad_list_[0]; |
|
|
|
int64_t h_dimension_input_shape = inputs_shape_[0][2]; |
|
|
|
int64_t h_dimension_output_shape = outputs_shape_[0][2]; |
|
|
|
int64_t h_kernel_size = kernel_size_[0]; |
|
|
|
int64_t h_stride = stride_[2]; |
|
|
|
|
|
|
|
return (rank_bias + 1) * (h_dimension_output_shape * h_stride - h_dimension_input_shape) / h_dimension_shard_num_ + |
|
|
|
h_kernel_size - h_stride - top_pad; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) { |
|
|
|
@@ -435,38 +512,78 @@ int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) { |
|
|
|
w_kernel_size - w_stride - left_pad; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferOverlapSize() { |
|
|
|
if (!need_exchange_overlap_) { |
|
|
|
MS_LOG(INFO) << name_ << ": No need to infer overlap size"; |
|
|
|
void Conv2DInfo::InferOverlapSizeForHDim() { |
|
|
|
if (!h_dim_need_exchange_overlap_) { |
|
|
|
overlap_top_size_ = 0; |
|
|
|
overlap_bottom_size_ = 0; |
|
|
|
bottom_rank_overlap_top_size_ = 0; |
|
|
|
top_rank_overlap_bottom_size_ = 0; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_); |
|
|
|
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_); |
|
|
|
if (h_rank_bias_ == 0) { |
|
|
|
// it has not top rank |
|
|
|
overlap_top_size_ = 0; |
|
|
|
overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(h_rank_bias_); |
|
|
|
top_rank_overlap_bottom_size_ = 0; |
|
|
|
bottom_rank_overlap_top_size_ = ComputeOverlapTopSizeByRankBias(bottom_rank_bias_); |
|
|
|
} else if (h_rank_bias_ == h_dimension_shard_num_ - 1) { |
|
|
|
// it has not bottom rank |
|
|
|
overlap_top_size_ = ComputeOverlapTopSizeByRankBias(h_rank_bias_); |
|
|
|
overlap_bottom_size_ = 0; |
|
|
|
top_rank_overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(top_rank_bias_); |
|
|
|
bottom_rank_overlap_top_size_ = 0; |
|
|
|
} else { |
|
|
|
// it has left rank and right rank |
|
|
|
overlap_top_size_ = ComputeOverlapTopSizeByRankBias(h_rank_bias_); |
|
|
|
overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(h_rank_bias_); |
|
|
|
top_rank_overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(top_rank_bias_); |
|
|
|
bottom_rank_overlap_top_size_ = ComputeOverlapTopSizeByRankBias(bottom_rank_bias_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferOverlapSizeForWDim() { |
|
|
|
if (!w_dim_need_exchange_overlap_) { |
|
|
|
overlap_left_size_ = 0; |
|
|
|
overlap_right_size_ = 0; |
|
|
|
left_rank_overlap_right_size_ = 0; |
|
|
|
right_rank_overlap_left_size_ = 0; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (rank_bias_ == 0) { // it has not left rank |
|
|
|
left_rank_overlap_left_size_ = 0; |
|
|
|
if (w_rank_bias_ == 0) { |
|
|
|
// it has not left rank |
|
|
|
overlap_left_size_ = 0; |
|
|
|
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(w_rank_bias_); |
|
|
|
left_rank_overlap_right_size_ = 0; |
|
|
|
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_); |
|
|
|
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_); |
|
|
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // it has not right rank |
|
|
|
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_); |
|
|
|
} else if (w_rank_bias_ == w_dimension_shard_num_ - 1) { |
|
|
|
// it has not right rank |
|
|
|
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(w_rank_bias_); |
|
|
|
overlap_right_size_ = 0; |
|
|
|
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_); |
|
|
|
right_rank_overlap_left_size_ = 0; |
|
|
|
right_rank_overlap_right_size_ = 0; |
|
|
|
} else { // it has left rank and right rank |
|
|
|
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_); |
|
|
|
} else { |
|
|
|
// it has left rank and right rank |
|
|
|
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(w_rank_bias_); |
|
|
|
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(w_rank_bias_); |
|
|
|
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_); |
|
|
|
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_); |
|
|
|
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferOverlapSize() { |
|
|
|
InferOverlapSizeForHDim(); |
|
|
|
InferOverlapSizeForWDim(); |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_ |
|
|
|
<< ", the right overlap size of current rank is " << overlap_right_size_ |
|
|
|
<< ", the left overlap size of left rank is " << left_rank_overlap_left_size_ |
|
|
|
<< ", the right overlap size of left rank is " << left_rank_overlap_right_size_ |
|
|
|
<< ", the left overlap size of right rank is " << right_rank_overlap_left_size_ |
|
|
|
<< ", the right overlap size of right rank is " << right_rank_overlap_right_size_; |
|
|
|
<< ", the top overlap size of current rank is " << overlap_top_size_ |
|
|
|
<< ", the bottom overlap size of current rank is " << overlap_bottom_size_ |
|
|
|
<< ", the bottom overlap size of top rank is " << top_rank_overlap_bottom_size_ |
|
|
|
<< ", the top overlap size of bottom rank is " << bottom_rank_overlap_top_size_; |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DInfo::InferTensorMap() { |
|
|
|
@@ -518,60 +635,96 @@ Status Conv2DInfo::InferForwardCommunication() { |
|
|
|
|
|
|
|
void Conv2DInfo::InferNewPadList() { |
|
|
|
new_pad_list_ = pad_list_; |
|
|
|
if (rank_bias_ == 0) { // the first rank |
|
|
|
new_pad_list_[3] = 0; // no need the right pad |
|
|
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank |
|
|
|
new_pad_list_[2] = 0; // no need the left pad |
|
|
|
} else { // the middle rank |
|
|
|
new_pad_list_[2] = 0; // no need the left pad |
|
|
|
new_pad_list_[3] = 0; // no need the right pad |
|
|
|
if (h_dim_need_exchange_overlap_) { |
|
|
|
if (h_rank_bias_ == 0) { // the first rank |
|
|
|
new_pad_list_[1] = 0; // no need the bottom pad |
|
|
|
} else if (h_rank_bias_ == h_dimension_shard_num_ - 1) { // the last rank |
|
|
|
new_pad_list_[0] = 0; // no need the top pad |
|
|
|
} else { // the middle rank |
|
|
|
new_pad_list_[0] = 0; // no need the top pad |
|
|
|
new_pad_list_[1] = 0; // no need the bottom pad |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (w_dim_need_exchange_overlap_) { |
|
|
|
if (w_rank_bias_ == 0) { // the first rank |
|
|
|
new_pad_list_[3] = 0; // no need the right pad |
|
|
|
} else if (w_rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank |
|
|
|
new_pad_list_[2] = 0; // no need the left pad |
|
|
|
} else { // the middle rank |
|
|
|
new_pad_list_[2] = 0; // no need the left pad |
|
|
|
new_pad_list_[3] = 0; // no need the right pad |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferSendRankIds() { |
|
|
|
int64_t send_top_rank = top_rank_overlap_bottom_size_ > 0 ? top_rank_id_ : -1; |
|
|
|
int64_t send_bottom_rank = bottom_rank_overlap_top_size_ > 0 ? bottom_rank_id_ : -1; |
|
|
|
int64_t send_left_rank = left_rank_overlap_right_size_ > 0 ? left_rank_id_ : -1; |
|
|
|
int64_t send_right_rank = right_rank_overlap_left_size_ > 0 ? right_rank_id_ : -1; |
|
|
|
int64_t send_top_left_rank = (send_top_rank != -1 && send_left_rank != -1) ? top_left_rank_id_ : -1; |
|
|
|
int64_t send_top_right_rank = (send_top_rank != -1 && send_right_rank != -1) ? top_right_rank_id_ : -1; |
|
|
|
int64_t send_bottom_left_rank = (send_bottom_rank != -1 && send_left_rank != -1) ? bottom_left_rank_id_ : -1; |
|
|
|
int64_t send_bottom_right_rank = (send_bottom_rank != -1 && send_right_rank != -1) ? bottom_right_rank_id_ : -1; |
|
|
|
|
|
|
|
// the order of send or recv rank ids in the array is organized in the following format: |
|
|
|
// [top, top_right, right, bottom_right, bottom, bottom_left, left, top_left] |
|
|
|
send_rank_ids_ = {send_top_rank, send_top_right_rank, send_right_rank, send_bottom_right_rank, |
|
|
|
send_bottom_rank, send_bottom_left_rank, send_left_rank, send_top_left_rank}; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferRecvRankIds() { |
|
|
|
int64_t recv_top_rank = overlap_top_size_ > 0 ? top_rank_id_ : -1; |
|
|
|
int64_t recv_bottom_rank = overlap_bottom_size_ > 0 ? bottom_rank_id_ : -1; |
|
|
|
int64_t recv_left_rank = overlap_left_size_ > 0 ? left_rank_id_ : -1; |
|
|
|
int64_t recv_right_rank = overlap_right_size_ > 0 ? right_rank_id_ : -1; |
|
|
|
int64_t recv_top_left_rank = (recv_top_rank != -1 && recv_left_rank != -1) ? top_left_rank_id_ : -1; |
|
|
|
int64_t recv_top_right_rank = (recv_top_rank != -1 && recv_right_rank != -1) ? top_right_rank_id_ : -1; |
|
|
|
int64_t recv_bottom_left_rank = (recv_bottom_rank != -1 && recv_left_rank != -1) ? bottom_left_rank_id_ : -1; |
|
|
|
int64_t recv_bottom_right_rank = (recv_bottom_rank != -1 && recv_right_rank != -1) ? bottom_right_rank_id_ : -1; |
|
|
|
|
|
|
|
// the order of send or recv rank ids in the array is organized in the following format: |
|
|
|
// [top, top_right, right, bottom_right, bottom, bottom_left, left, top_left] |
|
|
|
recv_rank_ids_ = {recv_top_rank, recv_top_right_rank, recv_right_rank, recv_bottom_right_rank, |
|
|
|
recv_bottom_rank, recv_bottom_left_rank, recv_left_rank, recv_top_left_rank}; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferCommunicationAttrs() { |
|
|
|
// send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1] |
|
|
|
// recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1] |
|
|
|
// send lens: [0, 0, send_left_len, send_right_len] |
|
|
|
// recv lens: [0, 0, recv_left_len, recv_right_len] |
|
|
|
int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1; |
|
|
|
int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0; |
|
|
|
|
|
|
|
if (rank_bias_ == 0) { |
|
|
|
// the first rank |
|
|
|
send_right_len = right_rank_overlap_left_size_; |
|
|
|
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1; |
|
|
|
|
|
|
|
recv_right_len = overlap_right_size_; |
|
|
|
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1; |
|
|
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { |
|
|
|
// the last rank |
|
|
|
send_left_len = left_rank_overlap_right_size_; |
|
|
|
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1; |
|
|
|
|
|
|
|
recv_left_len = overlap_left_size_; |
|
|
|
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1; |
|
|
|
} else { |
|
|
|
// the middle rank |
|
|
|
send_right_len = right_rank_overlap_left_size_; |
|
|
|
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1; |
|
|
|
// send ranks |
|
|
|
InferSendRankIds(); |
|
|
|
|
|
|
|
recv_right_len = overlap_right_size_; |
|
|
|
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1; |
|
|
|
send_left_len = left_rank_overlap_right_size_; |
|
|
|
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1; |
|
|
|
// recv ranks |
|
|
|
InferRecvRankIds(); |
|
|
|
|
|
|
|
recv_left_len = overlap_left_size_; |
|
|
|
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1; |
|
|
|
} |
|
|
|
// send lens |
|
|
|
int64_t send_top_len = top_rank_overlap_bottom_size_; |
|
|
|
int64_t send_bottom_len = bottom_rank_overlap_top_size_; |
|
|
|
int64_t send_left_len = left_rank_overlap_right_size_; |
|
|
|
int64_t send_right_len = right_rank_overlap_left_size_; |
|
|
|
|
|
|
|
// recv lens |
|
|
|
int64_t recv_top_len = overlap_top_size_; |
|
|
|
int64_t recv_bottom_len = overlap_bottom_size_; |
|
|
|
int64_t recv_left_len = overlap_left_size_; |
|
|
|
int64_t recv_right_len = overlap_right_size_; |
|
|
|
|
|
|
|
// the order of send or recv lens in the array is organized in the following format: |
|
|
|
// [top, bottom, left, right] |
|
|
|
send_lens_ = {send_top_len, send_bottom_len, send_left_len, send_right_len}; |
|
|
|
recv_lens_ = {recv_top_len, recv_bottom_len, recv_left_len, recv_right_len}; |
|
|
|
|
|
|
|
send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1}; |
|
|
|
recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1}; |
|
|
|
send_lens_ = {0, 0, send_left_len, send_right_len}; |
|
|
|
recv_lens_ = {0, 0, recv_left_len, recv_right_len}; |
|
|
|
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_ |
|
|
|
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_; |
|
|
|
|
|
|
|
int64_t h_slice_shape = input_slice_shape_[2]; |
|
|
|
if (send_top_len > h_slice_shape || send_bottom_len > h_slice_shape || recv_top_len > h_slice_shape || |
|
|
|
recv_bottom_len > h_slice_shape) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of h dimension " << h_slice_shape; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t w_slice_shape = input_slice_shape_[3]; |
|
|
|
if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape || |
|
|
|
recv_right_len > w_slice_shape) { |
|
|
|
@@ -675,7 +828,7 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
|
|
|
|
ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
if (!need_exchange_overlap_) { |
|
|
|
if (!w_dim_need_exchange_overlap_ && !h_dim_need_exchange_overlap_) { |
|
|
|
if (!out_channel_shard_) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -684,9 +837,7 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (InferRankBias() != SUCCESS) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
InferAdjacentRankInfo(); |
|
|
|
|
|
|
|
InferOverlapSize(); |
|
|
|
|
|
|
|
@@ -757,7 +908,8 @@ Status Conv2DBackpropInputInfo::GetAttrs() { |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
need_exchange_overlap_ = false; |
|
|
|
w_dim_need_exchange_overlap_ = false; |
|
|
|
h_dim_need_exchange_overlap_ = false; |
|
|
|
if (CheckStrategyBase(strategy) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -777,9 +929,13 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// kernel size larger than stride and the w dimension is split, need to exchange overlap |
|
|
|
// kernel size larger than stride and the h/w dimension is split, need to exchange overlap |
|
|
|
if ((kernel_size_[0] > stride_[2]) && (input_strategy[2] > 1)) { |
|
|
|
h_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
|
|
|
|
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) { |
|
|
|
need_exchange_overlap_ = true; |
|
|
|
w_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
@@ -794,8 +950,8 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (h_strategy > 1) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension"; |
|
|
|
if (h_strategy > 1 && inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -839,6 +995,7 @@ Status Conv2DBackpropInputInfo::InferDevMatrixShape() { |
|
|
|
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i]; |
|
|
|
} |
|
|
|
|
|
|
|
h_dimension_shard_num_ = stra[0][2]; |
|
|
|
w_dimension_shard_num_ = stra[0][3]; |
|
|
|
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]); |
|
|
|
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_; |
|
|
|
@@ -925,6 +1082,63 @@ void Conv2DBackpropInputInfo::UpdateOutShape() { |
|
|
|
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DBackpropInputInfo::ComputeOverlapTopSizeByRankBias(int64_t rank_bias) { |
|
|
|
// 1. the first rank: 0 |
|
|
|
// 2. the last rank: |
|
|
|
// size of origin data required by current rank: a = ceil((o/n + k - o + h*s - s - x)/s) |
|
|
|
// data size of the current rank: b = h/n |
|
|
|
// return a - b = ceil((o/n + k - o + h*s - s - x)/s) - h/n |
|
|
|
// 3. the middle rank: (the x is top pad) |
|
|
|
// r*h/n - ceil((r*o/n - k + x + 1)/s) |
|
|
|
if (rank_bias == 0) { // the first rank |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t h_output_shape = outputs_shape_[0][2]; |
|
|
|
int64_t h_input_shape = inputs_shape_[0][2]; |
|
|
|
int64_t h_kernel_size = kernel_size_[0]; |
|
|
|
int64_t h_stride = stride_[2]; |
|
|
|
int64_t top_pad = pad_list_[0]; |
|
|
|
if (rank_bias == h_dimension_shard_num_ - 1) { // the last rank |
|
|
|
return DoubleToLong(std::ceil(LongToDouble(h_output_shape / h_dimension_shard_num_ + h_kernel_size - |
|
|
|
h_output_shape + h_input_shape * h_stride - h_stride - top_pad) / |
|
|
|
LongToDouble(h_stride))) - |
|
|
|
h_input_shape / h_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
// the middle rank |
|
|
|
return rank_bias * h_input_shape / h_dimension_shard_num_ - |
|
|
|
DoubleToLong( |
|
|
|
std::ceil(LongToDouble(rank_bias * h_output_shape / h_dimension_shard_num_ - h_kernel_size + top_pad + 1) / |
|
|
|
LongToDouble(h_stride))); |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DBackpropInputInfo::ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) { |
|
|
|
// 1. the first rank: ceil((o/n + x)/s) - h/n |
|
|
|
// 2. the last rank: 0 |
|
|
|
// 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*h/n - h/n |
|
|
|
int64_t h_output_shape = outputs_shape_[0][2]; |
|
|
|
int64_t h_input_shape = inputs_shape_[0][2]; |
|
|
|
int64_t h_stride = stride_[2]; |
|
|
|
int64_t top_pad = pad_list_[0]; |
|
|
|
|
|
|
|
if (rank_bias == 0) { // the first rank |
|
|
|
return DoubleToLong( |
|
|
|
std::ceil(LongToDouble(h_output_shape / h_dimension_shard_num_ + top_pad) / LongToDouble(h_stride))) - |
|
|
|
h_input_shape / h_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
if (rank_bias == h_dimension_shard_num_ - 1) { // the last rank |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
// the middle rank |
|
|
|
return DoubleToLong(std::ceil(LongToDouble(rank_bias * h_output_shape / h_dimension_shard_num_ + |
|
|
|
h_output_shape / h_dimension_shard_num_ + top_pad) / |
|
|
|
LongToDouble(h_stride))) - |
|
|
|
(rank_bias + 1) * h_input_shape / h_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) { |
|
|
|
// 1. the first rank: 0 |
|
|
|
// 2. the last rank: |
|
|
|
@@ -982,7 +1196,7 @@ int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_ |
|
|
|
(rank_bias + 1) * w_input_shape / w_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DBackpropInputInfo::InferNewPadList() { |
|
|
|
void Conv2DBackpropInputInfo::InferNewPadListByDimension(const std::string &dimension) { |
|
|
|
// 1. compute the size of origin data required by current rank: |
|
|
|
// 1) the first rank: ceil((o/n + x) / s) |
|
|
|
// 2) the last rank: ceil((o/n + k - o + ws - s - x) / s) |
|
|
|
@@ -996,64 +1210,105 @@ void Conv2DBackpropInputInfo::InferNewPadList() { |
|
|
|
// 3) the middle rank: |
|
|
|
// if (r*on - k + x + 1) is divisible by s, real_left_pad = 0. |
|
|
|
// otherwise, real_left_pad = s - (r*on - k + x + 1) % s |
|
|
|
int64_t w_output_shape = outputs_shape_[0][3]; |
|
|
|
int64_t w_input_shape = inputs_shape_[0][3]; |
|
|
|
int64_t w_kernel_size = kernel_size_[1]; |
|
|
|
int64_t w_stride = stride_[3]; |
|
|
|
int64_t left_pad = pad_list_[2]; |
|
|
|
int64_t current_rank_required_size = 0; |
|
|
|
int64_t real_left_pad = 0; |
|
|
|
int64_t real_top_or_left_pad = 0; |
|
|
|
int64_t h_or_w_output_shape = -1; |
|
|
|
int64_t h_or_w_input_shape = -1; |
|
|
|
int64_t h_or_w_kernel_size = -1; |
|
|
|
int64_t h_or_w_stride = -1; |
|
|
|
int64_t top_or_left_pad = -1; |
|
|
|
int64_t h_or_w_rank_bias = -1; |
|
|
|
int64_t h_or_w_dim_shard_num = -1; |
|
|
|
|
|
|
|
if (dimension == H_DIMENSION) { |
|
|
|
h_or_w_output_shape = outputs_shape_[0][2]; |
|
|
|
h_or_w_input_shape = inputs_shape_[0][2]; |
|
|
|
h_or_w_kernel_size = kernel_size_[0]; |
|
|
|
h_or_w_stride = stride_[2]; |
|
|
|
top_or_left_pad = pad_list_[0]; |
|
|
|
h_or_w_rank_bias = h_rank_bias_; |
|
|
|
h_or_w_dim_shard_num = h_dimension_shard_num_; |
|
|
|
} else { |
|
|
|
h_or_w_output_shape = outputs_shape_[0][3]; |
|
|
|
h_or_w_input_shape = inputs_shape_[0][3]; |
|
|
|
h_or_w_kernel_size = kernel_size_[1]; |
|
|
|
h_or_w_stride = stride_[3]; |
|
|
|
top_or_left_pad = pad_list_[2]; |
|
|
|
h_or_w_rank_bias = w_rank_bias_; |
|
|
|
h_or_w_dim_shard_num = w_dimension_shard_num_; |
|
|
|
} |
|
|
|
|
|
|
|
if (rank_bias_ == 0) { // the first rank |
|
|
|
current_rank_required_size = DoubleToLong( |
|
|
|
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride))); |
|
|
|
if (h_or_w_rank_bias == 0) { // the first rank |
|
|
|
current_rank_required_size = DoubleToLong(std::ceil( |
|
|
|
LongToDouble(h_or_w_output_shape / h_or_w_dim_shard_num + top_or_left_pad) / LongToDouble(h_or_w_stride))); |
|
|
|
|
|
|
|
real_left_pad = w_kernel_size - left_pad - 1; |
|
|
|
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank |
|
|
|
current_rank_required_size = |
|
|
|
DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + |
|
|
|
w_input_shape * w_stride - w_stride - left_pad) / |
|
|
|
LongToDouble(w_stride))); |
|
|
|
|
|
|
|
int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride - |
|
|
|
w_stride - left_pad; |
|
|
|
if (tmp % w_stride == 0) { |
|
|
|
real_left_pad = w_stride - 1; |
|
|
|
real_top_or_left_pad = h_or_w_kernel_size - top_or_left_pad - 1; |
|
|
|
} else if (h_or_w_rank_bias == h_or_w_dim_shard_num - 1) { // the last rank |
|
|
|
current_rank_required_size = DoubleToLong( |
|
|
|
std::ceil(LongToDouble(h_or_w_output_shape / h_or_w_dim_shard_num + h_or_w_kernel_size - h_or_w_output_shape + |
|
|
|
h_or_w_input_shape * h_or_w_stride - h_or_w_stride - top_or_left_pad) / |
|
|
|
LongToDouble(h_or_w_stride))); |
|
|
|
|
|
|
|
int64_t tmp = h_or_w_output_shape / h_or_w_dim_shard_num + h_or_w_kernel_size - h_or_w_output_shape + |
|
|
|
h_or_w_input_shape * h_or_w_stride - h_or_w_stride - top_or_left_pad; |
|
|
|
if (tmp % h_or_w_stride == 0) { |
|
|
|
real_top_or_left_pad = h_or_w_stride - 1; |
|
|
|
} else { |
|
|
|
real_left_pad = tmp % w_stride - 1; |
|
|
|
real_top_or_left_pad = tmp % h_or_w_stride - 1; |
|
|
|
} |
|
|
|
} else { // the middle rank |
|
|
|
current_rank_required_size = |
|
|
|
DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ + |
|
|
|
w_output_shape / w_dimension_shard_num_ + left_pad) / |
|
|
|
LongToDouble(w_stride))) - |
|
|
|
DoubleToLong( |
|
|
|
std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) / |
|
|
|
LongToDouble(w_stride))); |
|
|
|
|
|
|
|
int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1; |
|
|
|
if (tmp % w_stride == 0) { |
|
|
|
real_left_pad = 0; |
|
|
|
DoubleToLong(std::ceil(LongToDouble(h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num + |
|
|
|
h_or_w_output_shape / h_or_w_dim_shard_num + top_or_left_pad) / |
|
|
|
LongToDouble(h_or_w_stride))) - |
|
|
|
DoubleToLong(std::ceil(LongToDouble(h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num - |
|
|
|
h_or_w_kernel_size + top_or_left_pad + 1) / |
|
|
|
LongToDouble(h_or_w_stride))); |
|
|
|
|
|
|
|
int64_t tmp = |
|
|
|
h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num - h_or_w_kernel_size + top_or_left_pad + 1; |
|
|
|
if (tmp % h_or_w_stride == 0) { |
|
|
|
real_top_or_left_pad = 0; |
|
|
|
} else { |
|
|
|
real_left_pad = w_stride - tmp % w_stride; |
|
|
|
real_top_or_left_pad = h_or_w_stride - tmp % h_or_w_stride; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// 3. compute the pad_add: (current_rank_required_size - 1) * s + k - o/n |
|
|
|
int64_t pad_all = |
|
|
|
(current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_; |
|
|
|
(current_rank_required_size - 1) * h_or_w_stride + h_or_w_kernel_size - h_or_w_output_shape / h_or_w_dim_shard_num; |
|
|
|
|
|
|
|
// 4. compute new left pad: k - real_left_pad - 1 |
|
|
|
new_pad_list_ = pad_list_; |
|
|
|
new_pad_list_[2] = w_kernel_size - real_left_pad - 1; |
|
|
|
|
|
|
|
// 5. compute new right pad: pad_all - new_left_pad |
|
|
|
new_pad_list_[3] = pad_all - new_pad_list_[2]; |
|
|
|
if (dimension == H_DIMENSION) { |
|
|
|
new_pad_list_[0] = h_or_w_kernel_size - real_top_or_left_pad - 1; |
|
|
|
new_pad_list_[1] = pad_all - new_pad_list_[2]; |
|
|
|
} else { |
|
|
|
new_pad_list_[2] = h_or_w_kernel_size - real_top_or_left_pad - 1; |
|
|
|
new_pad_list_[3] = pad_all - new_pad_list_[2]; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is " |
|
|
|
MS_LOG(INFO) << name_ << ": The dimension is " << dimension << ", the required size of current rank is " |
|
|
|
<< current_rank_required_size << ", new pad all is " << pad_all; |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DBackpropInputInfo::InferNewPadList() { |
|
|
|
// init new pad list |
|
|
|
new_pad_list_ = pad_list_; |
|
|
|
|
|
|
|
// infer h dimension's new pad |
|
|
|
if (h_dim_need_exchange_overlap_) { |
|
|
|
InferNewPadListByDimension(H_DIMENSION); |
|
|
|
} |
|
|
|
|
|
|
|
// infer w dimension's new pad |
|
|
|
if (w_dim_need_exchange_overlap_) { |
|
|
|
InferNewPadListByDimension(W_DIMENSION); |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The new pad list is " << new_pad_list_; |
|
|
|
} |
|
|
|
|
|
|
|
// Node replacement hook for this operator: the only work done here is delegating to UpdateOutShape().
void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() {
  UpdateOutShape();
}
|
|
|
} // namespace parallel |
|
|
|
} // namespace mindspore |