Browse Source

Compute top/bottom overlap for Conv2D

tags/v1.6.0
yangzhenzhang 4 years ago
parent
commit
e5df74e9e4
6 changed files with 645 additions and 208 deletions
  1. +445
    -190
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
  2. +50
    -8
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
  3. +2
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
  4. +1
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc
  5. +144
    -2
      tests/ut/python/parallel/test_conv2d.py
  6. +3
    -5
      tests/ut/python/parallel/test_conv2d_transpose.py

+ 445
- 190
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc View File

@@ -140,52 +140,56 @@ Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) c
return SUCCESS;
}

Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;

// H dimension
if (kernel_size_[0] > stride_[2] && h_strategy > 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride";
return FAILED;
}

if (h_strategy > 1 && (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0)) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split H when kernel_size <= stride but slice shape "
"is not divisible by stride ";
return FAILED;
Status Conv2DInfo::CheckHWStrategySameModeByDimension(int64_t strategy, const std::string &dimension) {
int64_t h_or_w_input_shape = 0, h_or_w_slice_shape = 0, h_or_w_kernel_size = 0, h_or_w_stride = 0;
if (dimension == H_DIMENSION) {
h_or_w_input_shape = inputs_shape_[0][2];
h_or_w_slice_shape = h_or_w_input_shape / strategy;
h_or_w_kernel_size = kernel_size_[0];
h_or_w_stride = stride_[2];
} else {
h_or_w_input_shape = inputs_shape_[0][3];
h_or_w_slice_shape = h_or_w_input_shape / strategy;
h_or_w_kernel_size = kernel_size_[1];
h_or_w_stride = stride_[3];
}

// W dimension
if (w_strategy > 1 && (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0)) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split W when kernel_size <= stride but slice shape "
"is not divisible by stride ";
if (strategy > 1 && (h_or_w_kernel_size <= h_or_w_stride && h_or_w_slice_shape % h_or_w_stride != 0)) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
<< " when kernel_size <= stride but slice shape is not divisible by stride ";
return FAILED;
}

if (w_strategy > 1 && (kernel_size_[1] > stride_[3])) {
if (inputs_shape_[0][3] % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split W when kernel_size > stride but w shape is not "
"divisible by stride";
if (strategy > 1 && (h_or_w_kernel_size > h_or_w_stride)) {
if (h_or_w_input_shape % h_or_w_stride != 0) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
<< " when kernel_size > stride but input shape is not divisible by stride";
return FAILED;
}

if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
"smaller than or equal to (k - s + 1) / 2";
if (h_or_w_slice_shape <= ((h_or_w_kernel_size - h_or_w_stride + 1) / 2)) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
<< " when kernel_size > stride but slice shape is smaller than or equal to (k - s + 1) / 2";
return FAILED;
}

if (kernel_size_[1] - stride_[3] == 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split W when kernel_size > stride but k - s == 1";
if (h_or_w_kernel_size - h_or_w_stride == 1) {
MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split " << dimension
<< " when kernel_size > stride but k - s == 1";
return FAILED;
}
}
return SUCCESS;
}

Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
if (CheckHWStrategySameModeByDimension(h_strategy, H_DIMENSION) != SUCCESS) {
return FAILED;
}

if (CheckHWStrategySameModeByDimension(w_strategy, W_DIMENSION) != SUCCESS) {
return FAILED;
}
return SUCCESS;
}

@@ -293,7 +297,8 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
}

Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
need_exchange_overlap_ = false;
h_dim_need_exchange_overlap_ = false;
w_dim_need_exchange_overlap_ = false;
if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED;
}
@@ -313,11 +318,14 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
}
}

// kernel size larger than stride and the w dimension is split, need to exchange overlap
if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
need_exchange_overlap_ = true;
// kernel size larger than stride and the h/w dimension is split, need to exchange overlap
if ((kernel_size_[0] > stride_[2]) && (input_strategy[2] > 1)) {
h_dim_need_exchange_overlap_ = true;
}

if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
w_dim_need_exchange_overlap_ = true;
}
return SUCCESS;
}

@@ -333,85 +341,154 @@ Status Conv2DInfo::InferDevMatrixShape() {

dev_matrix_shape_ = stra[0];
dev_matrix_shape_.push_back(stra[1][0]);
h_dimension_shard_num_ = stra[0][2];
w_dimension_shard_num_ = stra[0][3];
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
return SUCCESS;
}

Status Conv2DInfo::InferRankBias() {
// the Conv2D operator:
// the origin dev_matrix is [n, i, h, w, o]
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
// repeated_num]
//
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
// Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
// dimension)
if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer rank bias";
return SUCCESS;
std::vector<int64_t> Conv2DInfo::GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension) {
std::vector<int64_t> ret;
if (rank_id < 0) {
ret = {-1, -1, -1, -1, -1};
return ret;
}

MS_LOG(INFO) << name_ << ": The rank id is " << rank_id << ", the dimension is " << dimension;

uint64_t index_in_dev_matrix = -1;
int64_t dimension_shard_num = 1;
if (dimension == H_DIMENSION) {
index_in_dev_matrix = 2;
dimension_shard_num = h_dimension_shard_num_;
} else {
index_in_dev_matrix = 3;
dimension_shard_num = w_dimension_shard_num_;
}

uint64_t w_index_in_dev_matrix = 3;
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
w_index_in_dev_matrix += 1;
index_in_dev_matrix += 1;
}

CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
DeviceMatrix dev_matrix(rank_id, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
return FAILED;
if (dev_matrix.GetDevicesAlongDim(index_in_dev_matrix, &group_devices) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get device along dim failed";
}

if (group_devices.size() <= 1) {
MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
MS_LOG(INFO) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size()
<< ", no need to infer rank bias";
return SUCCESS;
ret = {-1, -1, -1, -1, -1};
return ret;
}

if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
<< ", but the shard num of w dimension is " << w_dimension_shard_num_;
return FAILED;
if (group_devices.size() != LongToSize(dimension_shard_num)) {
MS_LOG(EXCEPTION) << name_ << ": The devices' size of " << dimension << " is " << group_devices.size()
<< ", but the shard num of w dimension is " << dimension_shard_num;
}

std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank_id);
if (it == group_devices.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
<< rank << ", the device list is " << group_devices;
return FAILED;
MS_LOG(EXCEPTION) << name_ << ": Can not find the current rank in device list of " << dimension
<< ", the current rank is " << rank_id << ", the device list is " << group_devices;
}

rank_bias_ = std::distance(group_devices.begin(), it);
int64_t left_or_top_rank_id = -1;
int64_t right_or_bottom_rank_id = -1;
int64_t left_or_top_rank_bias = -1;
int64_t right_or_bottom_rank_bias = -1;
int64_t current_rank_bias = -1;
current_rank_bias = std::distance(group_devices.begin(), it);
if (it == group_devices.begin()) {
left_rank_bias_ = -1;
right_rank_bias_ = rank_bias_ + 1;
left_or_top_rank_bias = -1;
right_or_bottom_rank_bias = current_rank_bias + 1;

left_rank_id_ = -1;
right_rank_id_ = *(it + 1);
left_or_top_rank_id = -1;
right_or_bottom_rank_id = *(it + 1);
} else if (it == group_devices.end() - 1) {
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = -1;
left_or_top_rank_bias = current_rank_bias - 1;
right_or_bottom_rank_bias = -1;

left_rank_id_ = *(it - 1);
right_rank_id_ = -1;
left_or_top_rank_id = *(it - 1);
right_or_bottom_rank_id = -1;
} else {
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = rank_bias_ + 1;
left_or_top_rank_bias = current_rank_bias - 1;
right_or_bottom_rank_bias = current_rank_bias + 1;

left_rank_id_ = *(it - 1);
right_rank_id_ = *(it + 1);
left_or_top_rank_id = *(it - 1);
right_or_bottom_rank_id = *(it + 1);
}

ret = {left_or_top_rank_id, right_or_bottom_rank_id, left_or_top_rank_bias, right_or_bottom_rank_bias,
current_rank_bias};
return ret;
}

void Conv2DInfo::InferAdjacentRankInfo() {
// the Conv2D operator:
// the origin dev_matrix is [n, i, h, w, o]
// if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
// if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
// repeated_num]
//
// the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
// Conv2D, the w_rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
// dimension)

CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
std::vector<int64_t> h_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, H_DIMENSION);
top_rank_id_ = h_dim_rank_info[0];
bottom_rank_id_ = h_dim_rank_info[1];
top_rank_bias_ = h_dim_rank_info[2];
bottom_rank_bias_ = h_dim_rank_info[3];
h_rank_bias_ = h_dim_rank_info[4];

std::vector<int64_t> w_dim_rank_info = GetAdjacentRankIdsAndBiases(rank, W_DIMENSION);
left_rank_id_ = w_dim_rank_info[0];
right_rank_id_ = w_dim_rank_info[1];
left_rank_bias_ = w_dim_rank_info[2];
right_rank_bias_ = w_dim_rank_info[3];
w_rank_bias_ = w_dim_rank_info[4];

std::vector<int64_t> top_w_dim_rank_info = GetAdjacentRankIdsAndBiases(top_rank_id_, W_DIMENSION);
top_left_rank_id_ = top_w_dim_rank_info[0];
top_right_rank_id_ = top_w_dim_rank_info[1];

std::vector<int64_t> bottom_w_dim_rank_info = GetAdjacentRankIdsAndBiases(bottom_rank_id_, W_DIMENSION);
bottom_left_rank_id_ = bottom_w_dim_rank_info[0];
bottom_right_rank_id_ = bottom_w_dim_rank_info[1];

all_to_all_group_ = g_device_manager->world_group(); // use world group temporarily
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
<< ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
return SUCCESS;
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the top rank id is " << top_rank_id_
<< ", the top right rank id is " << top_right_rank_id_ << ", the right rank id is " << right_rank_id_
<< ", the bottom right rank id is " << bottom_right_rank_id_ << ", the bottom rank id is "
<< bottom_rank_id_ << ", the bottom left rank id is " << bottom_left_rank_id_ << ", the left rank id is "
<< left_rank_id_ << ", the top left rank id is " << top_left_rank_id_ << ", the top rank bias is "
<< top_rank_bias_ << ", the bottom rank bias is " << bottom_rank_bias_ << ", the left rank bias is "
<< left_rank_bias_ << ", the right rank bias is " << right_rank_bias_ << ", the h dim rank bias is "
<< h_rank_bias_ << ", the w dim rank bias is " << w_rank_bias_;
}

int64_t Conv2DInfo::ComputeOverlapTopSizeByRankBias(int64_t rank_bias) {
int64_t top_pad = pad_list_[0];
int64_t h_dimension_input_shape = inputs_shape_[0][2];
int64_t h_dimension_output_shape = outputs_shape_[0][2];
int64_t h_stride = stride_[2];

return top_pad + (h_dimension_input_shape - h_dimension_output_shape * h_stride) * rank_bias / h_dimension_shard_num_;
}

int64_t Conv2DInfo::ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) {
int64_t top_pad = pad_list_[0];
int64_t h_dimension_input_shape = inputs_shape_[0][2];
int64_t h_dimension_output_shape = outputs_shape_[0][2];
int64_t h_kernel_size = kernel_size_[0];
int64_t h_stride = stride_[2];

return (rank_bias + 1) * (h_dimension_output_shape * h_stride - h_dimension_input_shape) / h_dimension_shard_num_ +
h_kernel_size - h_stride - top_pad;
}

int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
@@ -435,38 +512,78 @@ int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
w_kernel_size - w_stride - left_pad;
}

void Conv2DInfo::InferOverlapSize() {
if (!need_exchange_overlap_) {
MS_LOG(INFO) << name_ << ": No need to infer overlap size";
void Conv2DInfo::InferOverlapSizeForHDim() {
if (!h_dim_need_exchange_overlap_) {
overlap_top_size_ = 0;
overlap_bottom_size_ = 0;
bottom_rank_overlap_top_size_ = 0;
top_rank_overlap_bottom_size_ = 0;
return;
}

overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_);
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_);
if (h_rank_bias_ == 0) {
// it has not top rank
overlap_top_size_ = 0;
overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(h_rank_bias_);
top_rank_overlap_bottom_size_ = 0;
bottom_rank_overlap_top_size_ = ComputeOverlapTopSizeByRankBias(bottom_rank_bias_);
} else if (h_rank_bias_ == h_dimension_shard_num_ - 1) {
// it has not bottom rank
overlap_top_size_ = ComputeOverlapTopSizeByRankBias(h_rank_bias_);
overlap_bottom_size_ = 0;
top_rank_overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(top_rank_bias_);
bottom_rank_overlap_top_size_ = 0;
} else {
// it has left rank and right rank
overlap_top_size_ = ComputeOverlapTopSizeByRankBias(h_rank_bias_);
overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(h_rank_bias_);
top_rank_overlap_bottom_size_ = ComputeOverlapBottomSizeByRankBias(top_rank_bias_);
bottom_rank_overlap_top_size_ = ComputeOverlapTopSizeByRankBias(bottom_rank_bias_);
}
}

void Conv2DInfo::InferOverlapSizeForWDim() {
if (!w_dim_need_exchange_overlap_) {
overlap_left_size_ = 0;
overlap_right_size_ = 0;
left_rank_overlap_right_size_ = 0;
right_rank_overlap_left_size_ = 0;
return;
}

if (rank_bias_ == 0) { // it has not left rank
left_rank_overlap_left_size_ = 0;
if (w_rank_bias_ == 0) {
// it has not left rank
overlap_left_size_ = 0;
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(w_rank_bias_);
left_rank_overlap_right_size_ = 0;
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // it has not right rank
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
} else if (w_rank_bias_ == w_dimension_shard_num_ - 1) {
// it has not right rank
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(w_rank_bias_);
overlap_right_size_ = 0;
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = 0;
right_rank_overlap_right_size_ = 0;
} else { // it has left rank and right rank
left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
} else {
// it has left rank and right rank
overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(w_rank_bias_);
overlap_right_size_ = ComputeOverlapRightSizeByRankBias(w_rank_bias_);
left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
}
}

void Conv2DInfo::InferOverlapSize() {
InferOverlapSizeForHDim();
InferOverlapSizeForWDim();

MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
<< ", the right overlap size of current rank is " << overlap_right_size_
<< ", the left overlap size of left rank is " << left_rank_overlap_left_size_
<< ", the right overlap size of left rank is " << left_rank_overlap_right_size_
<< ", the left overlap size of right rank is " << right_rank_overlap_left_size_
<< ", the right overlap size of right rank is " << right_rank_overlap_right_size_;
<< ", the top overlap size of current rank is " << overlap_top_size_
<< ", the bottom overlap size of current rank is " << overlap_bottom_size_
<< ", the bottom overlap size of top rank is " << top_rank_overlap_bottom_size_
<< ", the top overlap size of bottom rank is " << bottom_rank_overlap_top_size_;
}

Status Conv2DInfo::InferTensorMap() {
@@ -518,60 +635,96 @@ Status Conv2DInfo::InferForwardCommunication() {

void Conv2DInfo::InferNewPadList() {
new_pad_list_ = pad_list_;
if (rank_bias_ == 0) { // the first rank
new_pad_list_[3] = 0; // no need the right pad
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
new_pad_list_[2] = 0; // no need the left pad
} else { // the middle rank
new_pad_list_[2] = 0; // no need the left pad
new_pad_list_[3] = 0; // no need the right pad
if (h_dim_need_exchange_overlap_) {
if (h_rank_bias_ == 0) { // the first rank
new_pad_list_[1] = 0; // no need the bottom pad
} else if (h_rank_bias_ == h_dimension_shard_num_ - 1) { // the last rank
new_pad_list_[0] = 0; // no need the top pad
} else { // the middle rank
new_pad_list_[0] = 0; // no need the top pad
new_pad_list_[1] = 0; // no need the bottom pad
}
}

if (w_dim_need_exchange_overlap_) {
if (w_rank_bias_ == 0) { // the first rank
new_pad_list_[3] = 0; // no need the right pad
} else if (w_rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
new_pad_list_[2] = 0; // no need the left pad
} else { // the middle rank
new_pad_list_[2] = 0; // no need the left pad
new_pad_list_[3] = 0; // no need the right pad
}
}

MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
}

// Determines which of the 8 neighbouring ranks this rank must SEND overlap
// data to. A straight neighbour is a target only when the overlap it requires
// from us is non-empty; a diagonal neighbour is a target only when both of its
// adjacent straight neighbours are targets. -1 marks "no send".
void Conv2DInfo::InferSendRankIds() {
  const int64_t top = top_rank_overlap_bottom_size_ > 0 ? top_rank_id_ : -1;
  const int64_t bottom = bottom_rank_overlap_top_size_ > 0 ? bottom_rank_id_ : -1;
  const int64_t left = left_rank_overlap_right_size_ > 0 ? left_rank_id_ : -1;
  const int64_t right = right_rank_overlap_left_size_ > 0 ? right_rank_id_ : -1;

  const int64_t top_left = (top != -1 && left != -1) ? top_left_rank_id_ : -1;
  const int64_t top_right = (top != -1 && right != -1) ? top_right_rank_id_ : -1;
  const int64_t bottom_left = (bottom != -1 && left != -1) ? bottom_left_rank_id_ : -1;
  const int64_t bottom_right = (bottom != -1 && right != -1) ? bottom_right_rank_id_ : -1;

  // the order of send or recv rank ids in the array is organized in the following format:
  // [top, top_right, right, bottom_right, bottom, bottom_left, left, top_left]
  send_rank_ids_ = {top, top_right, right, bottom_right, bottom, bottom_left, left, top_left};
}

// Determines which of the 8 neighbouring ranks this rank must RECEIVE overlap
// data from. A straight neighbour is a source only when the overlap this rank
// requires from it is non-empty; a diagonal neighbour is a source only when
// both of its adjacent straight neighbours are sources. -1 marks "no recv".
void Conv2DInfo::InferRecvRankIds() {
  const int64_t top = overlap_top_size_ > 0 ? top_rank_id_ : -1;
  const int64_t bottom = overlap_bottom_size_ > 0 ? bottom_rank_id_ : -1;
  const int64_t left = overlap_left_size_ > 0 ? left_rank_id_ : -1;
  const int64_t right = overlap_right_size_ > 0 ? right_rank_id_ : -1;

  const int64_t top_left = (top != -1 && left != -1) ? top_left_rank_id_ : -1;
  const int64_t top_right = (top != -1 && right != -1) ? top_right_rank_id_ : -1;
  const int64_t bottom_left = (bottom != -1 && left != -1) ? bottom_left_rank_id_ : -1;
  const int64_t bottom_right = (bottom != -1 && right != -1) ? bottom_right_rank_id_ : -1;

  // the order of send or recv rank ids in the array is organized in the following format:
  // [top, top_right, right, bottom_right, bottom, bottom_left, left, top_left]
  recv_rank_ids_ = {top, top_right, right, bottom_right, bottom, bottom_left, left, top_left};
}

void Conv2DInfo::InferCommunicationAttrs() {
// send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
// recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
// send lens: [0, 0, send_left_len, send_right_len]
// recv lens: [0, 0, recv_left_len, recv_right_len]
int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;

if (rank_bias_ == 0) {
// the first rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;

recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
} else if (rank_bias_ == w_dimension_shard_num_ - 1) {
// the last rank
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;

recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
} else {
// the middle rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
// send ranks
InferSendRankIds();

recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
// recv ranks
InferRecvRankIds();

recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
}
// send lens
int64_t send_top_len = top_rank_overlap_bottom_size_;
int64_t send_bottom_len = bottom_rank_overlap_top_size_;
int64_t send_left_len = left_rank_overlap_right_size_;
int64_t send_right_len = right_rank_overlap_left_size_;

// recv lens
int64_t recv_top_len = overlap_top_size_;
int64_t recv_bottom_len = overlap_bottom_size_;
int64_t recv_left_len = overlap_left_size_;
int64_t recv_right_len = overlap_right_size_;

// the order of send or recv lens in the array is organized in the following format:
// [top, bottom, left, right]
send_lens_ = {send_top_len, send_bottom_len, send_left_len, send_right_len};
recv_lens_ = {recv_top_len, recv_bottom_len, recv_left_len, recv_right_len};

send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
send_lens_ = {0, 0, send_left_len, send_right_len};
recv_lens_ = {0, 0, recv_left_len, recv_right_len};
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;

int64_t h_slice_shape = input_slice_shape_[2];
if (send_top_len > h_slice_shape || send_bottom_len > h_slice_shape || recv_top_len > h_slice_shape ||
recv_bottom_len > h_slice_shape) {
MS_LOG(EXCEPTION) << name_ << ": The send or recv len larger than slice shape of h dimension " << h_slice_shape;
}

int64_t w_slice_shape = input_slice_shape_[3];
if (send_left_len > w_slice_shape || send_right_len > w_slice_shape || recv_left_len > w_slice_shape ||
recv_right_len > w_slice_shape) {
@@ -675,7 +828,7 @@ void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
}

ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
if (!need_exchange_overlap_) {
if (!w_dim_need_exchange_overlap_ && !h_dim_need_exchange_overlap_) {
if (!out_channel_shard_) {
return nullptr;
}
@@ -684,9 +837,7 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
return nullptr;
}

if (InferRankBias() != SUCCESS) {
return nullptr;
}
InferAdjacentRankInfo();

InferOverlapSize();

@@ -757,7 +908,8 @@ Status Conv2DBackpropInputInfo::GetAttrs() {
}

Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
need_exchange_overlap_ = false;
w_dim_need_exchange_overlap_ = false;
h_dim_need_exchange_overlap_ = false;
if (CheckStrategyBase(strategy) != SUCCESS) {
return FAILED;
}
@@ -777,9 +929,13 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
}
}

// kernel size larger than stride and the w dimension is split, need to exchange overlap
// kernel size larger than stride and the h/w dimension is split, need to exchange overlap
if ((kernel_size_[0] > stride_[2]) && (input_strategy[2] > 1)) {
h_dim_need_exchange_overlap_ = true;
}

if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
need_exchange_overlap_ = true;
w_dim_need_exchange_overlap_ = true;
}
return SUCCESS;
}
@@ -794,8 +950,8 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st
return FAILED;
}

if (h_strategy > 1) {
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension";
if (h_strategy > 1 && inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) {
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape";
return FAILED;
}

@@ -839,6 +995,7 @@ Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
}

h_dimension_shard_num_ = stra[0][2];
w_dimension_shard_num_ = stra[0][3];
input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
@@ -925,6 +1082,63 @@ void Conv2DBackpropInputInfo::UpdateOutShape() {
MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
}

// Computes the number of H-dimension rows the current rank must obtain from its
// top neighbour for Conv2DBackpropInput, given its position `rank_bias` along
// the H dimension of the device matrix.
// Legend for the formulas below: o = output H, h = input H, n = h shard num,
// k = H kernel size, s = H stride, x = top pad, r = rank_bias.
// NOTE(review): relies on o/n and h/n being exact integer divisions — this is
// presumably guaranteed by the earlier strategy checks; confirm.
int64_t Conv2DBackpropInputInfo::ComputeOverlapTopSizeByRankBias(int64_t rank_bias) {
// 1. the first rank: 0
// 2. the last rank:
// size of origin data required by current rank: a = ceil((o/n + k - o + h*s - s - x)/s)
// data size of the current rank: b = h/n
// return a - b = ceil((o/n + k - o + h*s - s - x)/s) - h/n
// 3. the middle rank: (the x is top pad)
// r*h/n - ceil((r*o/n - k + x + 1)/s)
if (rank_bias == 0) { // the first rank
return 0;
}

int64_t h_output_shape = outputs_shape_[0][2];
int64_t h_input_shape = inputs_shape_[0][2];
int64_t h_kernel_size = kernel_size_[0];
int64_t h_stride = stride_[2];
int64_t top_pad = pad_list_[0];
if (rank_bias == h_dimension_shard_num_ - 1) { // the last rank
return DoubleToLong(std::ceil(LongToDouble(h_output_shape / h_dimension_shard_num_ + h_kernel_size -
h_output_shape + h_input_shape * h_stride - h_stride - top_pad) /
LongToDouble(h_stride))) -
h_input_shape / h_dimension_shard_num_;
}

// the middle rank
return rank_bias * h_input_shape / h_dimension_shard_num_ -
DoubleToLong(
std::ceil(LongToDouble(rank_bias * h_output_shape / h_dimension_shard_num_ - h_kernel_size + top_pad + 1) /
LongToDouble(h_stride)));
}

// Computes the number of H-dimension rows the current rank must obtain from its
// bottom neighbour for Conv2DBackpropInput, given its position `rank_bias`
// along the H dimension of the device matrix. The last rank has no bottom
// neighbour and returns 0.
// Legend for the formulas below: o = output H, h = input H, n = h shard num,
// s = H stride, x = top pad, r = rank_bias.
int64_t Conv2DBackpropInputInfo::ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) {
// 1. the first rank: ceil((o/n + x)/s) - h/n
// 2. the last rank: 0
// 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*h/n - h/n
int64_t h_output_shape = outputs_shape_[0][2];
int64_t h_input_shape = inputs_shape_[0][2];
int64_t h_stride = stride_[2];
int64_t top_pad = pad_list_[0];

if (rank_bias == 0) { // the first rank
return DoubleToLong(
std::ceil(LongToDouble(h_output_shape / h_dimension_shard_num_ + top_pad) / LongToDouble(h_stride))) -
h_input_shape / h_dimension_shard_num_;
}

if (rank_bias == h_dimension_shard_num_ - 1) { // the last rank
return 0;
}

// the middle rank
return DoubleToLong(std::ceil(LongToDouble(rank_bias * h_output_shape / h_dimension_shard_num_ +
h_output_shape / h_dimension_shard_num_ + top_pad) /
LongToDouble(h_stride))) -
(rank_bias + 1) * h_input_shape / h_dimension_shard_num_;
}

int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
// 1. the first rank: 0
// 2. the last rank:
@@ -982,7 +1196,7 @@ int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_
(rank_bias + 1) * w_input_shape / w_dimension_shard_num_;
}

void Conv2DBackpropInputInfo::InferNewPadList() {
void Conv2DBackpropInputInfo::InferNewPadListByDimension(const std::string &dimension) {
// 1. compute the size of origin data required by current rank:
// 1) the first rank: ceil((o/n + x) / s)
// 2) the last rank: ceil((o/n + k - o + ws - s - x) / s)
@@ -996,64 +1210,105 @@ void Conv2DBackpropInputInfo::InferNewPadList() {
// 3) the middle rank:
// if (r*on - k + x + 1) is divisible by s, real_left_pad = 0.
// otherwise, real_left_pad = s - (r*on - k + x + 1) % s
int64_t w_output_shape = outputs_shape_[0][3];
int64_t w_input_shape = inputs_shape_[0][3];
int64_t w_kernel_size = kernel_size_[1];
int64_t w_stride = stride_[3];
int64_t left_pad = pad_list_[2];
int64_t current_rank_required_size = 0;
int64_t real_left_pad = 0;
int64_t real_top_or_left_pad = 0;
int64_t h_or_w_output_shape = -1;
int64_t h_or_w_input_shape = -1;
int64_t h_or_w_kernel_size = -1;
int64_t h_or_w_stride = -1;
int64_t top_or_left_pad = -1;
int64_t h_or_w_rank_bias = -1;
int64_t h_or_w_dim_shard_num = -1;

if (dimension == H_DIMENSION) {
h_or_w_output_shape = outputs_shape_[0][2];
h_or_w_input_shape = inputs_shape_[0][2];
h_or_w_kernel_size = kernel_size_[0];
h_or_w_stride = stride_[2];
top_or_left_pad = pad_list_[0];
h_or_w_rank_bias = h_rank_bias_;
h_or_w_dim_shard_num = h_dimension_shard_num_;
} else {
h_or_w_output_shape = outputs_shape_[0][3];
h_or_w_input_shape = inputs_shape_[0][3];
h_or_w_kernel_size = kernel_size_[1];
h_or_w_stride = stride_[3];
top_or_left_pad = pad_list_[2];
h_or_w_rank_bias = w_rank_bias_;
h_or_w_dim_shard_num = w_dimension_shard_num_;
}

if (rank_bias_ == 0) { // the first rank
current_rank_required_size = DoubleToLong(
std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride)));
if (h_or_w_rank_bias == 0) { // the first rank
current_rank_required_size = DoubleToLong(std::ceil(
LongToDouble(h_or_w_output_shape / h_or_w_dim_shard_num + top_or_left_pad) / LongToDouble(h_or_w_stride)));

real_left_pad = w_kernel_size - left_pad - 1;
} else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
current_rank_required_size =
DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape +
w_input_shape * w_stride - w_stride - left_pad) /
LongToDouble(w_stride)));

int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride -
w_stride - left_pad;
if (tmp % w_stride == 0) {
real_left_pad = w_stride - 1;
real_top_or_left_pad = h_or_w_kernel_size - top_or_left_pad - 1;
} else if (h_or_w_rank_bias == h_or_w_dim_shard_num - 1) { // the last rank
current_rank_required_size = DoubleToLong(
std::ceil(LongToDouble(h_or_w_output_shape / h_or_w_dim_shard_num + h_or_w_kernel_size - h_or_w_output_shape +
h_or_w_input_shape * h_or_w_stride - h_or_w_stride - top_or_left_pad) /
LongToDouble(h_or_w_stride)));
int64_t tmp = h_or_w_output_shape / h_or_w_dim_shard_num + h_or_w_kernel_size - h_or_w_output_shape +
h_or_w_input_shape * h_or_w_stride - h_or_w_stride - top_or_left_pad;
if (tmp % h_or_w_stride == 0) {
real_top_or_left_pad = h_or_w_stride - 1;
} else {
real_left_pad = tmp % w_stride - 1;
real_top_or_left_pad = tmp % h_or_w_stride - 1;
}
} else { // the middle rank
current_rank_required_size =
DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ +
w_output_shape / w_dimension_shard_num_ + left_pad) /
LongToDouble(w_stride))) -
DoubleToLong(
std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
LongToDouble(w_stride)));

int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1;
if (tmp % w_stride == 0) {
real_left_pad = 0;
DoubleToLong(std::ceil(LongToDouble(h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num +
h_or_w_output_shape / h_or_w_dim_shard_num + top_or_left_pad) /
LongToDouble(h_or_w_stride))) -
DoubleToLong(std::ceil(LongToDouble(h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num -
h_or_w_kernel_size + top_or_left_pad + 1) /
LongToDouble(h_or_w_stride)));

int64_t tmp =
h_or_w_rank_bias * h_or_w_output_shape / h_or_w_dim_shard_num - h_or_w_kernel_size + top_or_left_pad + 1;
if (tmp % h_or_w_stride == 0) {
real_top_or_left_pad = 0;
} else {
real_left_pad = w_stride - tmp % w_stride;
real_top_or_left_pad = h_or_w_stride - tmp % h_or_w_stride;
}
}

// 3. compute the pad_add: (current_rank_required_size - 1) * s + k - o/n
int64_t pad_all =
(current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_;
(current_rank_required_size - 1) * h_or_w_stride + h_or_w_kernel_size - h_or_w_output_shape / h_or_w_dim_shard_num;

// 4. compute new left pad: k - real_left_pad - 1
new_pad_list_ = pad_list_;
new_pad_list_[2] = w_kernel_size - real_left_pad - 1;

// 5. compute new right pad: pad_all - new_left_pad
new_pad_list_[3] = pad_all - new_pad_list_[2];
if (dimension == H_DIMENSION) {
new_pad_list_[0] = h_or_w_kernel_size - real_top_or_left_pad - 1;
new_pad_list_[1] = pad_all - new_pad_list_[2];
} else {
new_pad_list_[2] = h_or_w_kernel_size - real_top_or_left_pad - 1;
new_pad_list_[3] = pad_all - new_pad_list_[2];
}

MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
MS_LOG(INFO) << name_ << ": The dimension is " << dimension << ", the required size of current rank is "
<< current_rank_required_size << ", new pad all is " << pad_all;
}

void Conv2DBackpropInputInfo::InferNewPadList() {
  // Start from the user-specified pad list; the per-dimension routine below
  // overwrites only the entries of dimensions that exchange overlap regions.
  new_pad_list_ = pad_list_;

  // H dimension: recompute the top/bottom pads when adjacent ranks must
  // exchange overlap data along h.
  if (h_dim_need_exchange_overlap_) {
    InferNewPadListByDimension(H_DIMENSION);
  }

  // W dimension: recompute the left/right pads likewise.
  if (w_dim_need_exchange_overlap_) {
    InferNewPadListByDimension(W_DIMENSION);
  }

  MS_LOG(INFO) << name_ << ": The new pad list is " << new_pad_list_;
}

void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
} // namespace parallel
} // namespace mindspore

+ 50
- 8
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h View File

@@ -50,9 +50,14 @@ class Conv2DInfo : public OperatorInfo {
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferRankBias();
void InferAdjacentRankInfo();
std::vector<int64_t> GetAdjacentRankIdsAndBiases(int64_t rank_id, const std::string &dimension);
void InferOverlapSize();
void InferOverlapSizeForHDim();
void InferOverlapSizeForWDim();
void InferNewOperatorAttrs();
void InferSendRankIds();
void InferRecvRankIds();
void InferCommunicationAttrs();
std::string ReplaceNodeName() const;
AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
@@ -74,23 +79,54 @@ class Conv2DInfo : public OperatorInfo {
int64_t new_out_channel_ = 1;
std::vector<int64_t> new_pad_list_;

bool need_exchange_overlap_ = false;
int64_t rank_bias_ = 0;
int64_t left_rank_bias_ = -1;
int64_t right_rank_bias_ = -1;
int64_t left_rank_id_ = -1;
bool w_dim_need_exchange_overlap_ = false;
bool h_dim_need_exchange_overlap_ = false;
int64_t h_rank_bias_ = 0; // the bias of current rank in h dimension of device matrix
int64_t w_rank_bias_ = 0; // the bias of current rank in w dimension of device matrix
int64_t top_rank_bias_ = -1; // the bias of top rank in h dimension of device matrix
int64_t bottom_rank_bias_ = -1; // the bias of bottom rank in h dimension of device matrix
int64_t left_rank_bias_ = -1; // the bias of left rank in w dimension of device matrix
int64_t right_rank_bias_ = -1; // the bias of right rank in w dimension of device matrix

// 8 adjacent ranks
int64_t top_rank_id_ = -1;
int64_t top_right_rank_id_ = -1;
int64_t right_rank_id_ = -1;
int64_t bottom_right_rank_id_ = -1;
int64_t bottom_rank_id_ = -1;
int64_t bottom_left_rank_id_ = -1;
int64_t left_rank_id_ = -1;
int64_t top_left_rank_id_ = -1;

// overlap sizes for h dimension
int64_t overlap_top_size_ = 0;
int64_t overlap_bottom_size_ = 0;
int64_t top_rank_overlap_bottom_size_ = 0;
int64_t bottom_rank_overlap_top_size_ = 0;

// overlap sizes for w dimension
int64_t overlap_left_size_ = 0;
int64_t overlap_right_size_ = 0;
int64_t left_rank_overlap_left_size_ = 0;
int64_t left_rank_overlap_right_size_ = 0;
int64_t right_rank_overlap_left_size_ = 0;
int64_t right_rank_overlap_right_size_ = 0;

int64_t h_dimension_shard_num_ = 1;
int64_t w_dimension_shard_num_ = 1;
Shape input_slice_shape_;

// the send_rank_ids_ or recv_rank_ids is an array with 8 rank ids, the order of index in the array is organized in
// the following format(the 'R' is current rank), the invalid rank fill -1
// +++++++++++++
// | 7 | 0 | 1 |
// +++++++++++++
// | 6 | R | 2 |
// +++++++++++++
// | 5 | 4 | 3 |
// +++++++++++++
std::vector<int64_t> send_rank_ids_;
std::vector<int64_t> recv_rank_ids_;

// the send_lens_ or recv_lens_ is an array with 4 lens, the order in the array represents top, bottom, left, right
std::vector<int64_t> send_lens_;
std::vector<int64_t> recv_lens_;
std::string all_to_all_group_;
@@ -99,10 +135,13 @@ class Conv2DInfo : public OperatorInfo {

virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
virtual void InferNewPadList();
virtual int64_t ComputeOverlapTopSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
virtual int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias);

private:
Status CheckHWStrategySameModeByDimension(int64_t strategy, const std::string &dimension);
Status CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy);
Status CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy);
};
@@ -126,6 +165,9 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {

Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) override;
void InferNewPadList() override;
void InferNewPadListByDimension(const std::string &dimension);
int64_t ComputeOverlapTopSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapBottomSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) override;
int64_t ComputeOverlapRightSizeByRankBias(int64_t rank_bias) override;



+ 2
- 0
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h View File

@@ -231,6 +231,8 @@ constexpr char STRIDE[] = "stride";
constexpr char DILATION[] = "dilation";
constexpr char FORMAT[] = "format";
constexpr char NCHW[] = "NCHW";
constexpr char H_DIMENSION[] = "h_dimension";
constexpr char W_DIMENSION[] = "w_dimension";
constexpr char IS_TRAINING[] = "is_training";
constexpr char EPSILON[] = "epsilon";
constexpr char MOMENTUM[] = "momentum";


+ 1
- 3
mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc View File

@@ -204,9 +204,7 @@ Status ResizeBilinearInfo::InferRankBias() {
right_rank_id_ = *(it + 1);
}

Group group = g_device_manager->CreateGroup(group_devices);
all_to_all_group_ = group.name();

all_to_all_group_ = g_device_manager->world_group(); // use world group temporarily
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_


+ 144
- 2
tests/ut/python/parallel/test_conv2d.py View File

@@ -57,6 +57,11 @@ def compile_net(net, input_x=_x):


def test_conv2d_data_parallel():
"""
Feature: test conv2d data parallel
Description: shard n dimension
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -65,6 +70,11 @@ def test_conv2d_data_parallel():


def test_conv2d_data_parallel_invalid_stride():
"""
Feature: test conv2d invalid stride
Description: the first two elements of stride must be 1, but set 2
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -75,6 +85,11 @@ def test_conv2d_data_parallel_invalid_stride():


def test_conv2d_data_parallel_dilation():
"""
Feature: test conv2d data parallel and dilation is not 1
Description: data parallel and dilation is not 1
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -84,6 +99,11 @@ def test_conv2d_data_parallel_dilation():


def test_conv2d_data_parallel_group():
"""
Feature: test conv2d data parallel and group is not 1
Description: data parallel and group is not 1
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -93,6 +113,11 @@ def test_conv2d_data_parallel_group():


def test_conv2d_model_parallel1():
"""
Feature: test conv2d model parallel
Description: split n/c-in/c-out
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -101,6 +126,11 @@ def test_conv2d_model_parallel1():


def test_conv2d_model_parallel_dilation():
"""
Feature: test conv2d model parallel and dilation is not 1
Description: model parallel and dilation is not 1
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -111,6 +141,11 @@ def test_conv2d_model_parallel_dilation():


def test_conv2d_model_parallel_group():
"""
Feature: test conv2d model parallel and group is not 1
Description: model parallel and group is not 1
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
strategy2 = ((8, 1, 1, 1),)
@@ -121,6 +156,11 @@ def test_conv2d_model_parallel_group():


def test_conv2d_model_parallel2():
"""
Feature: same mode, stride = kernel_size, no need exchange
Description: split n/c-in/c-out/h/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
strategy2 = ((32, 1, 1, 1),)
@@ -129,6 +169,11 @@ def test_conv2d_model_parallel2():


def test_conv2d_model_parallel3():
"""
Feature: same mode, stride < kernel_size, need exchange
Description: split n/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
@@ -137,12 +182,22 @@ def test_conv2d_model_parallel3():


def test_conv2d_auto_parallel():
"""
Feature: same mode, auto parallel
Description: generate data parallel strategy
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
compile_net(net)


def test_conv2d_model_parallel4():
"""
Feature: same mode, stride < kernel_size, need exchange
Description: split n/c-in/c-out/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
strategy2 = ((2, 2, 1, 4),)
@@ -151,6 +206,11 @@ def test_conv2d_model_parallel4():


def test_conv2d_left_and_right_no_need_to_send():
"""
Feature: same mode, k - s = 1, left pad is 0
Description: the case where the left rank does not need to send is not supported
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
@@ -160,15 +220,24 @@ def test_conv2d_left_and_right_no_need_to_send():


def test_conv2d_kernel_size_larger_than_stride_and_split_h():
"""
Feature: same mode, stride < kernel_size, need exchange
Description: split n/c-in/c-out/h
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
strategy2 = ((2, 2, 4, 1),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net)


def test_conv2d_valid_mode_kernel_size_larger_than_stride():
"""
Feature: valid mode, stride < kernel_size, need exchange
Description: do not support to split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
@@ -178,6 +247,11 @@ def test_conv2d_valid_mode_kernel_size_larger_than_stride():


def test_conv2d_output_can_not_divisible_by_strategy():
"""
Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
Description: split w dimension
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -187,6 +261,11 @@ def test_conv2d_output_can_not_divisible_by_strategy():


def test_conv2d_output_can_not_divisible_by_strategy2():
"""
Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
Description: split h dimension
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -196,6 +275,11 @@ def test_conv2d_output_can_not_divisible_by_strategy2():


def test_split_kernel():
"""
Feature: split kernel size
Description: do not support to split kernel size
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
strategy2 = ((1, 1, 1, 8),)
@@ -205,6 +289,11 @@ def test_split_kernel():


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
"""
Feature: same mode, slice shape can not be divided by stride
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -214,6 +303,11 @@ def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_s


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
"""
Feature: valid mode, slice shape can not be divided by stride
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -223,6 +317,11 @@ def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_v


def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
"""
Feature: same mode, slice shape can not be divided by stride
Description: split h
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -232,6 +331,11 @@ def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_


def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
"""
Feature: valid mode, slice shape can not be divided by stride
Description: split h
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -241,6 +345,11 @@ def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible


def test_split_h_dimension_and_pad_mode_is_pad():
"""
Feature: pad mode
Description: split h
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -250,6 +359,11 @@ def test_split_h_dimension_and_pad_mode_is_pad():


def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
"""
Feature: same mode, input shape can not be divided by stride
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -259,6 +373,11 @@ def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():


def test_kernel_size_larger_than_stride_and_slice_too_small():
"""
Feature: same mode, slice shape is smaller than overlap shape
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
@@ -268,6 +387,11 @@ def test_kernel_size_larger_than_stride_and_slice_too_small():


def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
"""
Feature: same mode, slice shape is equal to overlap shape
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
@@ -277,9 +401,27 @@ def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():


def test_kernel_size_larger_than_stride_and_left_pad_is_0():
"""
Feature: same mode, kernel_size > stride and left pad is 0
Description: split w
Expectation: compile failed
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((1, 1, 1, 8),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)


def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
"""
Feature: same mode, stride < kernel_size, need exchange
Description: split n/c-in/c-out/h/w
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
strategy2 = ((2, 2, 2, 2),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)

+ 3
- 5
tests/ut/python/parallel/test_conv2d_transpose.py View File

@@ -156,15 +156,14 @@ def test_conv2d_transpose_split_h_in_same_mode():
"""
Feature: test split h dimension
Description: shard h dimension in same mode
Expectation: compile failed
Expectation: compile success
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
strategy2 = ((2, 2, 4, 1),)
net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net)


def test_conv2d_transpose_overlap_size_too_large():
@@ -180,4 +179,3 @@ def test_conv2d_transpose_overlap_size_too_large():
strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)

Loading…
Cancel
Save