You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

conv2d.cc 9.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "ops/conv2d.h"
#include <algorithm>
#include <cmath>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ir/dtype/tensor_type.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "ops/control_depend.h"
  26. namespace mindspore {
  27. namespace ops {
  28. namespace {
  29. abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  30. MS_EXCEPTION_IF_NULL(primitive);
  31. auto conv_prim = primitive->cast<PrimConv2dPtr>();
  32. MS_EXCEPTION_IF_NULL(conv_prim);
  33. auto prim_name = conv_prim->name();
  34. CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
  35. auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
  36. auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
  37. if (conv_prim->get_format() == NHWC) {
  38. x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
  39. w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
  40. }
  41. CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
  42. CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
  43. CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->get_group(), kEqual, "w_shape[1]",
  44. w_shape[1], conv_prim->name());
  45. auto out_channel = conv_prim->get_out_channel();
  46. CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
  47. std::vector<int64_t> temp_w;
  48. std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
  49. CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w,
  50. conv_prim->name());
  51. auto kernel_size_h = w_shape[2];
  52. auto kernel_size_w = w_shape[3];
  53. auto stride = conv_prim->get_stride();
  54. auto dilation = conv_prim->get_dilation();
  55. auto stride_h = stride[2];
  56. auto stride_w = stride[3];
  57. auto dilation_h = dilation[2];
  58. auto dilation_w = dilation[3];
  59. int64_t h_out = -1;
  60. int64_t w_out = -1;
  61. std::vector<int64_t> pad_list(4, 0);
  62. auto pad_mode = conv_prim->get_pad_mode();
  63. if (pad_mode == VALID) {
  64. h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
  65. w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
  66. } else if (pad_mode == SAME) {
  67. h_out = ceil(x_shape[2] / stride_h);
  68. w_out = ceil(x_shape[3] / stride_w);
  69. auto pad_needed_h =
  70. std::max(static_cast<int64_t>(0), (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
  71. pad_list.emplace_back(floor(pad_needed_h / 2));
  72. pad_list.emplace_back(pad_needed_h / 2);
  73. auto pad_needed_w =
  74. std::max(static_cast<int64_t>(0), (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
  75. auto pad_left = floor(pad_needed_w / 2);
  76. pad_list.emplace_back(pad_left);
  77. pad_list.emplace_back(pad_needed_h - pad_left);
  78. } else if (pad_mode == PAD) {
  79. std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
  80. auto pad_top = conv_prim->get_pad()[0];
  81. auto pad_bottom = conv_prim->get_pad()[1];
  82. auto pad_right = conv_prim->get_pad()[2];
  83. auto pad_left = conv_prim->get_pad()[3];
  84. h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
  85. w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
  86. h_out = floor(h_out);
  87. w_out = floor(w_out);
  88. }
  89. conv_prim->set_pad(pad_list);
  90. std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
  91. if (conv_prim->get_format() == NHWC) {
  92. out_shape = {x_shape[0], h_out, w_out, out_channel};
  93. }
  94. return std::make_shared<abstract::Shape>(out_shape);
  95. }
  96. TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  97. CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
  98. for (const auto &item : input_args) {
  99. MS_EXCEPTION_IF_NULL(item);
  100. }
  101. const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeFloat16,
  102. kNumberTypeFloat32};
  103. std::map<std::string, TypePtr> types;
  104. types.emplace("x", input_args[0]->BuildType());
  105. types.emplace("w", input_args[1]->BuildType());
  106. auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  107. if (infer_type == kNumberTypeInt8) {
  108. return TypeIdToType(kNumberTypeInt32);
  109. }
  110. return TypeIdToType(infer_type);
  111. }
  112. } // namespace
// Validates and stores every Conv2D attribute via the individual setters.
// NOTE: the call order matters — set_pad must run before set_pad_mode,
// because set_pad_mode reads the stored pad attribute (get_pad) to
// cross-check it against the chosen mode.
void Conv2D::Init(int64_t out_channel, const std::vector<int64_t> &kernel_size, int64_t mode, const PadMode &pad_mode,
                  const std::vector<int64_t> &pad, const std::vector<int64_t> &stride,
                  const std::vector<int64_t> &dilation, int64_t group, const Format &format) {
  set_kernel_size(kernel_size);
  set_stride(stride);
  set_dilation(dilation);
  set_pad(pad);
  set_pad_mode(pad_mode);
  set_mode(mode);
  set_out_channel(out_channel);
  set_group(group);
  set_format(format);
}
  126. void Conv2D::set_out_channel(int64_t out_channel) {
  127. AddAttr(kOutChannel,
  128. MakeValue(CheckAndConvertUtils::CheckInteger(kOutChannel, out_channel, kGreaterThan, 0, name())));
  129. }
  130. void Conv2D::set_kernel_size(const std::vector<int64_t> &kernel_size) {
  131. AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, name())));
  132. }
  133. void Conv2D::set_stride(const std::vector<int64_t> &stride) {
  134. AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true)));
  135. }
  136. void Conv2D::set_dilation(const std::vector<int64_t> &dilation) {
  137. AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true)));
  138. }
  139. void Conv2D::set_pad_mode(const PadMode &pad_mode) {
  140. std::vector<int64_t> pad = get_pad();
  141. if (pad_mode == PAD) {
  142. for (auto item : pad) {
  143. CheckAndConvertUtils::Check(kPadItem, item, kGreaterEqual, "zeros_list", 0, name());
  144. }
  145. } else {
  146. CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, name());
  147. }
  148. int64_t swi = pad_mode;
  149. AddAttr(kPadMode, MakeValue(swi));
  150. }
  151. void Conv2D::set_pad(const std::vector<int64_t> &pad) {
  152. CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
  153. AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true)));
  154. }
  155. void Conv2D::set_mode(int64_t mode) {
  156. AddAttr(kMode, MakeValue(CheckAndConvertUtils::CheckInteger(kMode, mode, kEqual, 1, name())));
  157. }
  158. void Conv2D::set_group(int64_t group) {
  159. AddAttr(kGroup, MakeValue(CheckAndConvertUtils::CheckInteger(kGroup, group, kGreaterThan, 0, name())));
  160. }
  161. void Conv2D::set_format(const Format &format) {
  162. int64_t f = format;
  163. AddAttr(kFormat, MakeValue(f));
  164. }
  165. int64_t Conv2D::get_out_channel() const {
  166. auto value_ptr = GetAttr(kOutChannel);
  167. return GetValue<int64_t>(value_ptr);
  168. }
  169. std::vector<int64_t> Conv2D::get_kernel_size() const {
  170. auto value_ptr = GetAttr(kKernelSize);
  171. return GetValue<std::vector<int64_t>>(value_ptr);
  172. }
  173. std::vector<int64_t> Conv2D::get_stride() const {
  174. auto value_ptr = GetAttr(kStride);
  175. return GetValue<std::vector<int64_t>>(value_ptr);
  176. }
  177. std::vector<int64_t> Conv2D::get_dilation() const {
  178. auto value_ptr = GetAttr(kDilation);
  179. return GetValue<std::vector<int64_t>>(value_ptr);
  180. }
  181. PadMode Conv2D::get_pad_mode() const {
  182. auto value_ptr = GetAttr(kPadMode);
  183. return PadMode(GetValue<int64_t>(value_ptr));
  184. }
  185. std::vector<int64_t> Conv2D::get_pad() const {
  186. auto value_ptr = GetAttr(kPad);
  187. return GetValue<std::vector<int64_t>>(value_ptr);
  188. }
  189. int64_t Conv2D::get_mode() const {
  190. auto value_ptr = GetAttr(kMode);
  191. return GetValue<int64_t>(value_ptr);
  192. }
  193. int64_t Conv2D::get_group() const {
  194. auto value_ptr = GetAttr(kGroup);
  195. return GetValue<int64_t>(value_ptr);
  196. }
  197. Format Conv2D::get_format() const {
  198. auto value_ptr = GetAttr(kFormat);
  199. return Format(GetValue<int64_t>(value_ptr));
  200. }
  201. AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
  202. const std::vector<AbstractBasePtr> &input_args) {
  203. return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
  204. Conv2dInferShape(primitive, input_args)->shape());
  205. }
// Register the Conv2D infer implementation and the primitive itself in the
// global primitive registries so the frontend can resolve them by name.
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
REGISTER_PRIMITIVE_C(kNameConv2D, Conv2D);
  208. } // namespace ops
  209. } // namespace mindspore