|
|
|
@@ -30,6 +30,11 @@ using mindspore::abstract::Shape; |
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
namespace { |
|
|
|
constexpr size_t top_padding = 0; |
|
|
|
constexpr size_t bottom_padding = 1; |
|
|
|
constexpr size_t left_padding = 2; |
|
|
|
constexpr size_t right_padding = 3; |
|
|
|
|
|
|
|
// check functions |
|
|
|
void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) { |
|
|
|
for (size_t i = 0; i < shape.size(); ++i) { |
|
|
|
@@ -148,12 +153,30 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckConv2dShape(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args, |
|
|
|
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &w_shape, |
|
|
|
const std::vector<int64_t> &padding, int64_t pad_mode, uint64_t w_axis, uint64_t h_axis) { |
|
|
|
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0); |
|
|
|
auto w_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1); |
|
|
|
if (x_shape_ptr->IsDynamic() || w_shape_ptr->IsDynamic()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (w_shape[w_axis] != Shape::SHP_ANY && pad_mode != PadMode::SAME) { |
|
|
|
int64_t input_height = x_shape[h_axis]; |
|
|
|
int64_t input_width = x_shape[w_axis]; |
|
|
|
if (pad_mode == PadMode::PAD) { |
|
|
|
input_height += padding[left_padding] + padding[right_padding]; |
|
|
|
input_width += padding[top_padding] + padding[bottom_padding]; |
|
|
|
} |
|
|
|
if (input_height < w_shape[h_axis] || input_width < w_shape[w_axis]) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
auto prim_name = primitive->name(); |
|
|
|
for (const auto &item : input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
} |
|
|
|
auto x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); |
|
|
|
auto w_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape()); |
|
|
|
auto x_shape = x_shape_map[kShape]; |
|
|
|
@@ -216,6 +239,10 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve |
|
|
|
std::vector<int64_t> padding = CheckAttrIntOrTuple(primitive->GetAttr("pad"), 0, padding_num); |
|
|
|
int64_t pad_mode; |
|
|
|
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode); |
|
|
|
if (!CheckConv2dShape(prim_name, input_args, x_shape, w_shape, padding, pad_mode, w_axis, h_axis)) { |
|
|
|
MS_LOG(EXCEPTION) |
|
|
|
<< "Shape error for Conv2d, input shape's h and w after padding is less than kernel_size's h and w dims."; |
|
|
|
} |
|
|
|
std::vector<int64_t> output_hw; |
|
|
|
std::vector<int64_t> pad_list; |
|
|
|
std::vector<int64_t> output_hw_min; |
|
|
|
@@ -371,6 +398,7 @@ Format Conv2D::get_format() const { |
|
|
|
|
|
|
|
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
for (auto item : input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
} |
|
|
|
|