|
|
|
@@ -302,12 +302,12 @@ Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// if the h/w dimension is split, need to exchange overlap |
|
|
|
if (input_strategy[2] > 1) { |
|
|
|
// if the h/w dimension is split, and the pad mode is not "valid", need to exchange overlap |
|
|
|
if (input_strategy[2] > 1 && pad_mode_ != 2) { |
|
|
|
h_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
|
|
|
|
if (input_strategy[3] > 1) { |
|
|
|
if (input_strategy[3] > 1 && pad_mode_ != 2) { |
|
|
|
w_dim_need_exchange_overlap_ = true; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
@@ -556,8 +556,12 @@ void Conv2DInfo::InferOverlapSizeForWDim() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::CheckOverlapSizeNonNegative() { |
|
|
|
// check h dimension |
|
|
|
void Conv2DInfo::CheckHDimensionOverlapSizeNonNegative() { |
|
|
|
if (h_dimension_shard_num_ == 1) { |
|
|
|
MS_LOG(INFO) << name_ << ": The h dimension is not shard"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t h_first_rank_bottom_size = ComputeOverlapBottomSizeByRankBias(0); |
|
|
|
if (h_first_rank_bottom_size < 0) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The bottom overlap size of h dimension rank bias 0 must be positive, but it is " |
|
|
|
@@ -579,8 +583,13 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The top overlap size of h dimension last rank bias must be positive, but it is " |
|
|
|
<< h_last_rank_top_size; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// check w dimension |
|
|
|
void Conv2DInfo::CheckWDimensionOverlapSizeNonNegative() { |
|
|
|
if (w_dimension_shard_num_ == 1) { |
|
|
|
MS_LOG(INFO) << name_ << ": The w dimension is not shard"; |
|
|
|
return; |
|
|
|
} |
|
|
|
int64_t w_first_rank_right_size = ComputeOverlapRightSizeByRankBias(0); |
|
|
|
if (w_first_rank_right_size < 0) { |
|
|
|
MS_LOG(EXCEPTION) << name_ << ": The right overlap size of w dimension rank bias 0 must be positive, but it is " |
|
|
|
@@ -604,6 +613,11 @@ void Conv2DInfo::CheckOverlapSizeNonNegative() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::CheckOverlapSizeNonNegative() { |
|
|
|
CheckHDimensionOverlapSizeNonNegative(); |
|
|
|
CheckWDimensionOverlapSizeNonNegative(); |
|
|
|
} |
|
|
|
|
|
|
|
void Conv2DInfo::InferOverlapSize() { |
|
|
|
InferOverlapSizeForHDim(); |
|
|
|
InferOverlapSizeForWDim(); |
|
|
|
|