Browse Source

check data format for conv2d

tags/v1.5.0-rc1
yangzhenzhang 4 years ago
parent
commit
caa429d67c
2 changed files with 13 additions and 7 deletions
  1. +7
    -7
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
  2. +6
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/virtual_output_info.cc

+ 7
- 7
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc View File

@@ -32,6 +32,13 @@
namespace mindspore {
namespace parallel {
Status Conv2DInfo::GetAttrsBase() {
// format
format_ = GetStringAttr(FORMAT);
if (format_ != NCHW) {
MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
return FAILED;
}

// out_channel
out_channel_ = GetIntAttr(OUT_CHANNEL);
if (out_channel_ <= 0) {
@@ -105,13 +112,6 @@ Status Conv2DInfo::GetAttrsBase() {
// group
group_ = GetIntAttr(GROUP);

// format
format_ = GetStringAttr(FORMAT);
if (format_ != NCHW) {
MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
return FAILED;
}

MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
<< ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
<< ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_


+ 6
- 0
mindspore/ccsrc/frontend/parallel/ops_info/virtual_output_info.cc View File

@@ -62,6 +62,12 @@ Status VirtualOutputInfo::GenerateStrategies(int64_t stage_id) {
} else {
total_dev_num = LongToSize(stage_device_size_);
}

if (total_dev_num == 0) {
MS_LOG(ERROR) << name_ << ": The total devices num is 0";
return FAILED;
}

for (auto &shape : inputs_shape_) {
Shape temp;
if (!shape.empty()) {


Loading…
Cancel
Save