|
- /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "ops/conv2d.h"
- #include <string>
- #include <algorithm>
- #include <memory>
- #include <set>
- #include <vector>
- #include "ir/dtype/tensor_type.h"
- #include "utils/check_convert_utils.h"
- #include "abstract/primitive_infer_map.h"
-
- namespace mindspore {
- namespace ops {
- namespace {
- std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape,
- const std::vector<int64_t> &x_shape, const int64_t &out_channel) {
- auto kernel_size_h = w_shape[2];
- auto kernel_size_w = w_shape[3];
- auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
- auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
- auto stride_h = stride[2];
- auto stride_w = stride[3];
- auto dilation_h = dilation[2];
- auto dilation_w = dilation[3];
- int64_t h_out = -1;
- int64_t w_out = -1;
- std::vector<int64_t> pad_list(4, 0);
- auto pad_mode = PadMode(GetValue<int64_t>(primitive->GetAttr(kPadMode)));
- if (pad_mode == VALID) {
- h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
- w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
- } else if (pad_mode == SAME) {
- h_out = ceil(x_shape[2] / stride_h);
- w_out = ceil(x_shape[3] / stride_w);
-
- auto pad_needed_h =
- std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
- pad_list[0] = floor(pad_needed_h / 2);
- pad_list[1] = pad_needed_h / 2;
- auto pad_needed_w =
- std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
- auto pad_left = floor(pad_needed_w / 2);
- pad_list[2] = pad_left;
- pad_list[3] = pad_needed_h - pad_left;
- } else if (pad_mode == PAD) {
- auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
- (void)std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
- auto pad_top = pad[0];
- auto pad_bottom = pad[1];
- auto pad_right = pad[2];
- auto pad_left = pad[3];
-
- h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
- w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
- h_out = floor(h_out);
- w_out = floor(w_out);
- }
- (void)CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
- (void)primitive->AddAttr(kPadList,
- MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
- std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
- return out_shape;
- }
- abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
- MS_EXCEPTION_IF_NULL(primitive);
- auto prim_name = primitive->name();
- (void)CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
- auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
- auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
- auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
- if (format == NHWC) {
- x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
- w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
- }
- (void)CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
- (void)CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
- (void)CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)),
- kEqual, "w_shape[1]", w_shape[1], prim_name);
- auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
- (void)CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
- std::vector<int64_t> temp_w;
- (void)std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
- (void)CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)),
- kEqual, "w_shape[2:4]", temp_w, prim_name);
- auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
- if (format == NHWC) {
- out_shape = {out_shape[0], out_shape[3], out_shape[1], out_shape[2]};
- }
- return std::make_shared<abstract::Shape>(out_shape);
- }
-
- TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
- (void)CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
- for (const auto &item : input_args) {
- MS_EXCEPTION_IF_NULL(item);
- }
- const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
- std::map<std::string, TypePtr> types;
- (void)types.emplace("x", input_args[0]->BuildType());
- (void)types.emplace("w", input_args[1]->BuildType());
- return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
- }
- } // namespace
- void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
- const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
- const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
- set_kernel_size(kernel_size);
- set_stride(stride);
- set_dilation(dilation);
- set_pad(pad);
- set_pad_mode(pad_mode);
- set_mode(mode);
- set_out_channel(out_channel);
- set_group(group);
- set_format(format);
- }
-
- void Conv2D::set_out_channel(int64_t out_channel) {
- AddAttr(kOutChannel,
- MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
- }
-
- void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
- AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
- }
-
- void Conv2D::set_stride(const std::vector<int64_t> &stride) {
- AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name())));
- }
-
- void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
- AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name())));
- }
-
- void Conv2D::set_pad_mode(const PadMode &pad_mode) {
- std::vector<int64_t> pad = get_pad();
- if (pad_mode == PAD) {
- for (auto item : pad) {
- CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
- }
- } else {
- CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
- }
- int64_t swi = pad_mode;
- AddAttr(kPadMode, MakeValue(swi));
- }
-
- void Conv2D::set_pad(const std::vector<int64_t> &pad) {
- (void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
- AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
- }
-
- void Conv2D::set_mode(int64_t mode) {
- AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
- }
-
- void Conv2D::set_group(int64_t group) {
- AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
- }
-
- void Conv2D::set_format(const Format &format) {
- int64_t f = format;
- AddAttr(kFormat, MakeValue(f));
- }
-
- int64_t Conv2D::get_out_channel() const {
- auto value_ptr = GetAttr(kOutChannel);
- return GetValue<int64_t>(value_ptr);
- }
-
- std::vector<int64_t> Conv2D::get_kernel_size() const {
- auto value_ptr = GetAttr(kKernelSize);
- return GetValue<std::vector<int64_t>>(value_ptr);
- }
-
- std::vector<int64_t> Conv2D::get_stride() const {
- auto value_ptr = GetAttr(kStride);
- return GetValue<std::vector<int64_t>>(value_ptr);
- }
-
- std::vector<int64_t> Conv2D::get_dilation() const {
- auto value_ptr = GetAttr(kDilation);
- return GetValue<std::vector<int64_t>>(value_ptr);
- }
-
- PadMode Conv2D::get_pad_mode() const {
- auto value_ptr = GetAttr(kPadMode);
- return PadMode(GetValue<int64_t>(value_ptr));
- }
-
- std::vector<int64_t> Conv2D::get_pad() const {
- auto value_ptr = GetAttr(kPad);
- return GetValue<std::vector<int64_t>>(value_ptr);
- }
-
- int64_t Conv2D::get_mode() const {
- auto value_ptr = GetAttr(kMode);
- return GetValue<int64_t>(value_ptr);
- }
-
- int64_t Conv2D::get_group() const {
- auto value_ptr = GetAttr(kGroup);
- return GetValue<int64_t>(value_ptr);
- }
-
- Format Conv2D::get_format() const {
- auto value_ptr = GetAttr(kFormat);
- return Format(GetValue<int64_t>(value_ptr));
- }
-
- AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
- const std::vector<AbstractBasePtr> &input_args) {
- return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
- Conv2dInferShape(primitive, input_args)->shape());
- }
- REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D);
- } // namespace ops
- } // namespace mindspore
|