Browse Source

!32768 restore conv2d bug of infershape

Merge pull request !32768 from wangyanling/r1.7
r1.7
i-robot Gitee 4 years ago
parent
commit
cd344ae665
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 32 additions and 4 deletions
  1. +32
    -4
      mindspore/core/ops/conv2d.cc

+ 32
- 4
mindspore/core/ops/conv2d.cc View File

@@ -30,6 +30,11 @@ using mindspore::abstract::Shape;
namespace mindspore {
namespace ops {
namespace {
constexpr size_t top_padding = 0;
constexpr size_t bottom_padding = 1;
constexpr size_t left_padding = 2;
constexpr size_t right_padding = 3;

// check functions
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
for (size_t i = 0; i < shape.size(); ++i) {
@@ -148,12 +153,30 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
}
}

bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape,
const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) {
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) {
return true;
}
if (w_shape[w_axis] != Shape::SHP_ANY && pad_mode != PadMode::SAME) {
int64_t input_height = x_shape[h_axis];
int64_t input_width = x_shape[w_axis];
if (pad_mode == PadMode::PAD) {
input_height += padding[left_padding] + padding[right_padding];
input_width += padding[top_padding] + padding[bottom_padding];
}
if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) {
return false;
}
}
return true;
}

abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
auto x_shape = x_shape_map[kShape];
@@ -216,6 +239,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num);
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) {
MS_LOG(EXCEPTION)
<< "Shape error for Conv2d, input shape's h and w after padding is less than kernel_size's h and w dims.";
}
std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min;
@@ -371,6 +398,7 @@ Format Conv2D::get_format() const {

AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}


Loading…
Cancel
Save