Browse Source

fix bug of conv2d cpp infer

tags/v1.2.0-rc1
LianLiguang 4 years ago
parent
commit
a39b312191
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/core/ops/conv2d.cc

+ 4
- 4
mindspore/core/ops/conv2d.cc View File

@@ -72,13 +72,13 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve


auto pad_needed_h = auto pad_needed_h =
std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
pad_list.emplace_back(floor(pad_needed_h / 2));
pad_list.emplace_back(pad_needed_h / 2);
pad_list[0] = floor(pad_needed_h / 2);
pad_list[1] = pad_needed_h / 2;
auto pad_needed_w = auto pad_needed_w =
std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
auto pad_left = floor(pad_needed_w / 2); auto pad_left = floor(pad_needed_w / 2);
pad_list.emplace_back(pad_left);
pad_list.emplace_back(pad_needed_h - pad_left);
pad_list[2] = pad_left;
pad_list[3] = pad_needed_h - pad_left;
} else if (pad_mode == PAD) { } else if (pad_mode == PAD) {
auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad)); auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list)); std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));


Loading…
Cancel
Save