You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

deformable_conv.cpp 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. /**
  2. * \file dnn/src/common/deformable_conv.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/nn.h"
  12. #include "src/common/utils.h"
  13. using namespace megdnn;
  14. using CanonizedFilterMeta = DeformableConvBase::CanonizedFilterMeta;
  15. namespace {
  16. template <typename Param>
  17. std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter,
  18. const TensorLayout& offset, const TensorLayout& mask,
  19. const TensorLayout& dst, const Param& param) {
  20. MEGDNN_MARK_USED_VAR(src);
  21. MEGDNN_MARK_USED_VAR(filter);
  22. MEGDNN_MARK_USED_VAR(dst);
  23. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  24. megdnn_layout_msg(offset) + ", " + megdnn_layout_msg(mask) + ", " +
  25. megdnn_layout_msg(dst) + ", " + megdnn_mangle("only support nchw") +
  26. ", " + megdnn_mangle("group=") + std::to_string(param.group) + ", " +
  27. megdnn_mangle("deformable_group=") +
  28. std::to_string(param.deformable_group) + ", " +
  29. megdnn_mangle("pad_h=") + std::to_string(param.pad_h) + ", " +
  30. megdnn_mangle("pad_w=") + std::to_string(param.pad_w) + ", " +
  31. megdnn_mangle("stride_h=") + std::to_string(param.stride_h) + ", " +
  32. megdnn_mangle("stride_w=") + std::to_string(param.stride_w) + ", " +
  33. megdnn_mangle("dilate_h=") + std::to_string(param.dilate_h) + ", " +
  34. megdnn_mangle("dilate_w=") + std::to_string(param.dilate_w);
  35. }
  36. template <typename Param>
  37. void make_canonized_filter_meta_nchw(size_t src_ndim,
  38. const TensorLayout& filter,
  39. const Param& param,
  40. CanonizedFilterMeta& ret) {
  41. megdnn_assert(param.mode == Param::Mode::CROSS_CORRELATION,
  42. "only support CROSS_CORRELATION mode");
  43. megdnn_assert(param.format == Param::Format::NCHW,
  44. "only support nchw input layout");
  45. size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
  46. flt_start = 0, flt_spatial_start = 2;
  47. ocpg_pos = 0, icpg_pos = 1;
  48. if (param.sparse == Param::Sparse::GROUP)
  49. flt_start = 1;
  50. ret.spatial_ndim = src_ndim - 2;
  51. megdnn_assert(
  52. ret.spatial_ndim == 2,
  53. "only 2D convolution is supported, and imput should be 4-dim; "
  54. "got input dim = %zu",
  55. src_ndim);
  56. ret.ocpg = filter[flt_start + ocpg_pos];
  57. ret.icpg = filter[flt_start + icpg_pos];
  58. auto dilation = ret.dilation;
  59. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  60. megdnn_assert(dilation[i] > 0,
  61. "invalid dilation on spatial dim %zu, %u", i,
  62. dilation[i]);
  63. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  64. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  65. }
  66. }
  67. } // namespace
  68. namespace megdnn {
  69. CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta(
  70. size_t src_ndim, const TensorLayout& filter,
  71. const TensorLayout& offset) const {
  72. megdnn_assert_contiguous(filter);
  73. CanonizedFilterMeta ret;
  74. ret.group = 1;
  75. ret.dtype = filter.dtype;
  76. ret.stride[0] = param().stride_h;
  77. ret.stride[1] = param().stride_w;
  78. ret.padding[0] = param().pad_h;
  79. ret.padding[1] = param().pad_w;
  80. ret.dilation[0] = param().dilate_h;
  81. ret.dilation[1] = param().dilate_w;
  82. if (param().sparse == Param::Sparse::GROUP) {
  83. megdnn_assert(filter.ndim == 5,
  84. "filter dim should be 5 for group conv");
  85. ret.group = filter[0];
  86. }
  87. make_canonized_filter_meta_nchw(src_ndim, filter, param(), ret);
  88. auto fh = ret.spatial[0];
  89. auto fw = ret.spatial[1];
  90. ret.deformable_group = offset[1] / (2 * fh * fw);
  91. return ret;
  92. }
  93. void DeformableConvBase::deduce_layout_fwd(const TensorLayout& im,
  94. const TensorLayout& filter,
  95. const TensorLayout& offset,
  96. const TensorLayout& mask,
  97. TensorLayout& dst) {
  98. // im shape: (n, IC, IH, IW)
  99. megdnn_assert(im.ndim == 4, "invalid src layout: %s",
  100. megdnn_layout_msg(im).c_str());
  101. // filter shape: (OC, IC, FH, FW) or (g, OC/g, IC/g, FH, FW)
  102. megdnn_assert(filter.ndim == 4 || filter.ndim == 5,
  103. "invalid filter layout: %s",
  104. megdnn_layout_msg(filter).c_str());
  105. // offset shape: (N, 2*dg*FH*FW, OH, OW)
  106. megdnn_assert(offset.ndim == 4, "invalid offset layout: %s",
  107. megdnn_layout_msg(offset).c_str());
  108. // mask shape: (N, dg*FH*FW, OH, OW)
  109. megdnn_assert(mask.ndim == 4, "invalid mask layout: %s",
  110. megdnn_layout_msg(mask).c_str());
  111. size_t n = im.shape[0], ic = im.shape[1];
  112. size_t ih = im.shape[2], iw = im.shape[3];
  113. size_t dh = param().dilate_h, dw = param().dilate_w;
  114. size_t ph = param().pad_h, pw = param().pad_w;
  115. size_t sh = param().stride_h, sw = param().stride_w;
  116. auto&& fm = make_canonized_filter_meta(im.ndim, filter, offset);
  117. size_t fh = fm.spatial[0], fw = fm.spatial[1];
  118. size_t kh = 1 + (fh - 1) * dh;
  119. size_t kw = 1 + (fw - 1) * dw;
  120. size_t group = fm.group;
  121. size_t deformable_group = fm.deformable_group;
  122. size_t icpg = fm.icpg, ocpg = fm.ocpg;
  123. size_t oc = group * ocpg;
  124. size_t oh = (ih + ph * 2 - kh) / sh + 1;
  125. size_t ow = (iw + pw * 2 - kw) / sw + 1;
  126. megdnn_assert(group > 0 && deformable_group > 0,
  127. "group and deformable group should > 0");
  128. megdnn_assert(ic == icpg * group, "im ic != group * icpg of filter");
  129. megdnn_assert(ic % deformable_group == 0, "ic %% deformable_group != 0");
  130. megdnn_assert(oc % deformable_group == 0, "oc %% deformable_group != 0");
  131. megdnn_assert(
  132. (offset[1] % (2 * fh * fw) == 0) && (mask[1] % (fh * fw) == 0),
  133. "invalid deformable group deduced from offset(%s) or mask(%s)",
  134. megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());
  135. megdnn_assert((offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)),
  136. "offset(%s) and mask(%s) should have same deformable group",
  137. megdnn_layout_msg(offset).c_str(),
  138. megdnn_layout_msg(mask).c_str());
  139. megdnn_assert((offset[2] == mask[2]) && (offset[3] == mask[3]),
  140. "offset(%s) and mask(%s) should have same spatial dim",
  141. megdnn_layout_msg(offset).c_str(),
  142. megdnn_layout_msg(mask).c_str());
  143. megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh,
  144. offset[2]);
  145. megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow,
  146. offset[3]);
  147. dst.ndim = 4;
  148. dst = {{n, oc, oh, ow}, im.dtype};
  149. }
  150. void DeformableConvBase::check_layout_fwd(const TensorLayout& im,
  151. const TensorLayout& filter,
  152. const TensorLayout& offset,
  153. const TensorLayout& mask,
  154. const TensorLayout& dst) {
  155. auto& im_dtype = im.dtype;
  156. TensorLayout dst_expected;
  157. megdnn_assert(im_dtype.enumv() == DTypeEnum::Float32,
  158. "DeformableConv only support float32 input");
  159. megdnn_assert_eq_dtype(im, dst);
  160. megdnn_assert_eq_dtype(im, filter);
  161. megdnn_assert_eq_dtype(im, dst);
  162. megdnn_assert_eq_dtype(im, offset);
  163. megdnn_assert_eq_dtype(im, mask);
  164. deduce_layout_fwd(im, filter, offset, mask, dst_expected);
  165. megdnn_assert_eq_layout(dst_expected, dst);
  166. }
  167. void DeformableConvForward::deduce_layout(const TensorLayout& im,
  168. const TensorLayout& filter,
  169. const TensorLayout& offset,
  170. const TensorLayout& mask,
  171. TensorLayout& dst) {
  172. deduce_layout_fwd(im, filter, offset, mask, dst);
  173. return;
  174. }
  175. CanonizedFilterMeta DeformableConvForward::check_exec(
  176. const TensorLayout& im, const TensorLayout& filter,
  177. const TensorLayout& offset, const TensorLayout& mask,
  178. const TensorLayout& dst, size_t workspace_in_bytes) {
  179. auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
  180. auto required_workspace_in_bytes =
  181. get_workspace_in_bytes(im, filter, offset, mask, dst);
  182. check_layout_fwd(im, filter, offset, mask, dst);
  183. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  184. return ret;
  185. }
  186. CanonizedFilterMeta DeformableConvBackwardFilter::check_exec(
  187. const TensorLayout& im, const TensorLayout& offset,
  188. const TensorLayout& mask, const TensorLayout& out_grad,
  189. const TensorLayout& filter_grad, size_t workspace_in_bytes) {
  190. check_layout_fwd(im, filter_grad, offset, mask, out_grad);
  191. // check dtype
  192. megdnn_assert_eq_dtype(im, filter_grad);
  193. auto ret = make_canonized_filter_meta(im.ndim, filter_grad, offset);
  194. auto required_workspace_in_bytes =
  195. get_workspace_in_bytes(im, offset, mask, out_grad, filter_grad);
  196. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  197. return ret;
  198. }
  199. CanonizedFilterMeta DeformableConvBackwardData::check_exec(
  200. const TensorLayout& im, const TensorLayout& filter,
  201. const TensorLayout& offset, const TensorLayout& mask,
  202. const TensorLayout& out_grad, const TensorLayout& im_grad,
  203. const TensorLayout& offset_grad, const TensorLayout& mask_grad,
  204. size_t workspace_in_bytes) {
  205. check_layout_fwd(im, filter, offset, mask, out_grad);
  206. // check dtype
  207. megdnn_assert_eq_dtype(im, im_grad);
  208. megdnn_assert_eq_dtype(im, offset_grad);
  209. megdnn_assert_eq_dtype(im, mask_grad);
  210. // check layout
  211. megdnn_assert(im.shape == im_grad.shape, "invalid im_grad shape: %s",
  212. megdnn_layout_msg(im_grad).c_str());
  213. megdnn_assert(offset.shape == offset_grad.shape,
  214. "invalid offset_grad shape: %s",
  215. megdnn_layout_msg(offset_grad).c_str());
  216. megdnn_assert(mask.shape == mask_grad.shape, "invalid mask_grad shape: %s",
  217. megdnn_layout_msg(mask_grad).c_str());
  218. auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
  219. auto required_workspace_in_bytes =
  220. get_workspace_in_bytes(im, filter, offset, mask, out_grad, im_grad,
  221. offset_grad, mask_grad);
  222. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  223. return ret;
  224. }
  225. } // namespace megdnn
  226. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台