|
|
|
@@ -124,7 +124,29 @@ Status Conv2DInfo::GetAttrsBase() { |
|
|
|
|
|
|
|
Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); } |
|
|
|
|
|
|
|
Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
if (outputs_shape_[0][2] % h_strategy != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy " |
|
|
|
"of h dimension"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (outputs_shape_[0][3] % w_strategy != 0) { |
|
|
|
MS_LOG(ERROR) << name_ |
|
|
|
<< ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy " |
|
|
|
"of w dimension"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (pad_mode_ == 0) { // 'pad' mode |
|
|
|
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W"; |
|
|
|
return FAILED; |
|
|
|
@@ -642,6 +664,10 @@ Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
} |
|
|
|
|
|
|
|
Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) { |
|
|
|
if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (pad_mode_ != 1) { // only support same mode |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension"; |
|
|
|
return FAILED; |
|
|
|
@@ -649,18 +675,18 @@ Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_st |
|
|
|
|
|
|
|
if (h_strategy > 1) { |
|
|
|
if (inputs_shape_[0][2] * stride_[2] != outputs_shape_[0][2]) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when in_shape * stride != out_shape"; |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when in_shape * stride != out_shape"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (kernel_size_[0] > stride_[2]) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support split h dimension when kernel size larger than stride"; |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support to split h dimension when kernel size larger than stride"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) { |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support split w dimension when in_shape * stride != out_shape"; |
|
|
|
MS_LOG(ERROR) << name_ << ": Do not support to split w dimension when in_shape * stride != out_shape"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
|