|
|
|
@@ -32,6 +32,13 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
Status Conv2DInfo::GetAttrsBase() { |
|
|
|
// format |
|
|
|
format_ = GetStringAttr(FORMAT); |
|
|
|
if (format_ != NCHW) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
// out_channel |
|
|
|
out_channel_ = GetIntAttr(OUT_CHANNEL); |
|
|
|
if (out_channel_ <= 0) { |
|
|
|
@@ -105,13 +112,6 @@ Status Conv2DInfo::GetAttrsBase() { |
|
|
|
// group |
|
|
|
group_ = GetIntAttr(GROUP); |
|
|
|
|
|
|
|
// format |
|
|
|
format_ = GetStringAttr(FORMAT); |
|
|
|
if (format_ != NCHW) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_ |
|
|
|
<< ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_ |
|
|
|
<< ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_ |
|
|
|
|