You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

warp_perspective.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. /**
  2. * \file dnn/src/common/warp_perspective.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs.h"
  12. #include "src/common/utils.h"
  13. namespace megdnn {
  14. void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
  15. const TensorLayout &mat,
  16. const TensorLayout &mat_idx,
  17. const TensorLayout &dst)
  18. {
  19. megdnn_assert_contiguous(mat);
  20. megdnn_assert_contiguous(src);
  21. megdnn_assert_contiguous(dst);
  22. auto errmsg = [&]() {
  23. return megdnn_layout_msg(src) + ", " +
  24. megdnn_layout_msg(mat) + ", " +
  25. megdnn_layout_msg(mat_idx) + ", " +
  26. megdnn_layout_msg(dst) + ", " +
  27. param_msg();
  28. };
  29. MEGDNN_MARK_USED_VAR(errmsg);
  30. if (param().format == param::WarpPerspective::Format::NHWCD4 ||
  31. param().format == param::WarpPerspective::Format::NCHW4) {
  32. megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
  33. megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
  34. } else {
  35. megdnn_assert(param().format == param::WarpPerspective::Format::NHWC ||
  36. param().format == param::WarpPerspective::Format::NCHW);
  37. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  38. megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
  39. }
  40. megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str());
  41. megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
  42. if (mat_idx.ndim) {
  43. megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1,
  44. "%s", errmsg().c_str());
  45. megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
  46. megdnn_assert_contiguous(mat_idx);
  47. } else {
  48. megdnn_assert(src.shape[0] == dst.shape[0], "%s", errmsg().c_str());
  49. }
  50. megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
  51. megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
  52. if (param().format == param::WarpPerspective::Format::NCHW) {
  53. megdnn_assert(
  54. src.dtype.enumv() == DTypeEnum::Float32 ||
  55. MEGDNN_FLOAT16_SELECT(
  56. (src.dtype.enumv() == DTypeEnum::Float16 ||
  57. src.dtype.enumv() == DTypeEnum::BFloat16),
  58. false) ||
  59. src.dtype.enumv() == DTypeEnum::Int8 ||
  60. src.dtype.enumv() == DTypeEnum::Uint8 ||
  61. (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  62. src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
  63. "WarpPerspective NCHW input dtype should be "
  64. "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
  65. "/Float16/BFloat16", "") ".");
  66. megdnn_assert(
  67. (src.dtype.category() == DTypeCategory::FLOAT &&
  68. (src.dtype == mat.dtype ||
  69. mat.dtype.enumv() == DTypeEnum::Float32)) ||
  70. ((src.dtype.category() == DTypeCategory::INT ||
  71. src.dtype.category() == DTypeCategory::QUANTIZED) &&
  72. mat.dtype.enumv() == DTypeEnum::Float32),
  73. "The input to WarpPerspective is in NCHW format, in this "
  74. "case, if the input dtype is floating point, the "
  75. "transformation matrix should have same dtype as the "
  76. "input, otherwise, it should be in Float32, %s given.",
  77. mat.dtype.name());
  78. megdnn_assert(dst.dtype == src.dtype);
  79. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  80. megdnn_assert(param().imode ==
  81. param::WarpPerspective::InterpolationMode::LINEAR);
  82. megdnn_assert(param().bmode !=
  83. param::WarpPerspective::BorderMode::TRANSPARENT);
  84. megdnn_assert(param().bmode !=
  85. param::WarpPerspective::BorderMode::ISOLATED);
  86. } else if (param().format == param::WarpPerspective::Format::NHWC) {
  87. megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
  88. } else if (param().format == param::WarpPerspective::Format::NCHW4) {
  89. megdnn_assert(dst.dtype == src.dtype);
  90. megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
  91. "src expected QuantizedS8, but got %s", src.dtype.name());
  92. megdnn_assert(mat.dtype == dtype::Float32(),
  93. "matrix dtype expected float, got %s", mat.dtype.name());
  94. megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
  95. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  96. megdnn_assert(param().imode ==
  97. param::WarpPerspective::InterpolationMode::LINEAR);
  98. megdnn_assert(param().bmode !=
  99. param::WarpPerspective::BorderMode::TRANSPARENT);
  100. megdnn_assert(param().bmode !=
  101. param::WarpPerspective::BorderMode::ISOLATED);
  102. } else {
  103. megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
  104. megdnn_assert(
  105. src.dtype == dtype::Float32() ||
  106. MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
  107. src.dtype == dtype::BFloat16()),
  108. false) ||
  109. src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  110. src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
  111. "WarpPerspective NHWCD4 input dtype should be "
  112. "Float32" MEGDNN_FLOAT16_SELECT(
  113. "/Float16/BFloat16",
  114. "") ",QunatizedS8, Quantized8Asymm.");
  115. megdnn_assert(
  116. (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
  117. "The input to WarpPerspective is in NHWCD4 format, in this "
  118. "case, if the input dtype is floating point, the "
  119. "transformation matrix should have same dtype as the "
  120. "input, %s given.",
  121. mat.dtype.name());
  122. megdnn_assert(dst.dtype == src.dtype);
  123. //! number of channels is same
  124. megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
  125. megdnn_assert(param().imode ==
  126. param::WarpPerspective::InterpolationMode::LINEAR);
  127. megdnn_assert(param().bmode !=
  128. param::WarpPerspective::BorderMode::TRANSPARENT);
  129. megdnn_assert(param().bmode !=
  130. param::WarpPerspective::BorderMode::ISOLATED);
  131. }
  132. megdnn_assert(src.format == dst.format);
  133. }
  134. std::string WarpPerspectiveBase::param_msg() const
  135. {
  136. std::string res;
  137. res.append(megdnn_mangle("imode="));
  138. switch (param().imode) {
  139. case InterpolationMode::NEAREST:
  140. res.append(megdnn_mangle("NEAREST"));
  141. break;
  142. case InterpolationMode::LINEAR:
  143. res.append(megdnn_mangle("LINEAR"));
  144. break;
  145. case InterpolationMode::AREA:
  146. res.append(megdnn_mangle("AREA"));
  147. break;
  148. case InterpolationMode::CUBIC:
  149. res.append(megdnn_mangle("CUBIC"));
  150. break;
  151. case InterpolationMode::LANCZOS4:
  152. res.append(megdnn_mangle("LANCZOS4"));
  153. break;
  154. }
  155. res.append(megdnn_mangle("bmode="));
  156. switch (param().bmode) {
  157. case BorderMode::WRAP:
  158. res.append(megdnn_mangle("WRAP"));
  159. break;
  160. case BorderMode::CONSTANT:
  161. res.append(megdnn_mangle("CONSTANT"));
  162. break;
  163. case BorderMode::REFLECT:
  164. res.append(megdnn_mangle("REFLECT"));
  165. break;
  166. case BorderMode::REFLECT_101:
  167. res.append(megdnn_mangle("REFLECT_101"));
  168. break;
  169. case BorderMode::REPLICATE:
  170. res.append(megdnn_mangle("REPLICATE"));
  171. break;
  172. case BorderMode::TRANSPARENT:
  173. res.append(megdnn_mangle("TRANSPARENT"));
  174. break;
  175. case BorderMode::ISOLATED:
  176. res.append(megdnn_mangle("ISOLATED"));
  177. break;
  178. }
  179. if (param().bmode == BorderMode::CONSTANT) {
  180. res.append(", " + std::to_string(param().border_val));
  181. }
  182. return res;
  183. }
  184. int WarpPerspectiveBase::get_real_coord(int p, int len)
  185. {
  186. auto bmode = param().bmode;
  187. if( (unsigned)p < (unsigned)len )
  188. ;
  189. else if( bmode == BorderMode::REPLICATE )
  190. p = p < 0 ? 0 : len - 1;
  191. else if( bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101 )
  192. {
  193. int delta = (bmode == BorderMode::REFLECT_101);
  194. if( len == 1 )
  195. return 0;
  196. do
  197. {
  198. if( p < 0 )
  199. p = -p - 1 + delta;
  200. else
  201. p = len - 1 - (p - len) - delta;
  202. }
  203. while( (unsigned)p >= (unsigned)len );
  204. }
  205. else if( bmode == BorderMode::WRAP )
  206. {
  207. if( p < 0 )
  208. p -= ((p-len+1)/len)*len;
  209. /*
  210. if( p >= len )
  211. p %= len;
  212. */
  213. while (p >= len) {
  214. p -= len;
  215. }
  216. }
  217. else if( bmode == BorderMode::CONSTANT )
  218. p = -1;
  219. return p;
  220. }
  221. void WarpPerspectiveForward::check_exec(const TensorLayout &src,
  222. const TensorLayout &mat,
  223. const TensorLayout &mat_idx,
  224. const TensorLayout &dst,
  225. size_t workspace_in_bytes)
  226. {
  227. check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
  228. if (param().format == Param::Format::NHWC) {
  229. megdnn_assert(!mat_idx.ndim,
  230. "mat_idx not supported for current format");
  231. }
  232. }
  233. void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
  234. const TensorLayout& src, const TensorLayout& mat,
  235. const TensorLayout& mat_idx, const TensorLayout& dst,
  236. size_t workspace_in_bytes) {
  237. check_layout_fwd(src, mat, mat_idx, dst);
  238. auto required_workspace_in_bytes =
  239. get_workspace_in_bytes(src, mat, mat_idx, dst);
  240. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  241. if (param().format != Param::Format::NHWC &&
  242. param().format != Param::Format::NCHW &&
  243. param().format != Param::Format::NCHW4) {
  244. megdnn_assert(!mat_idx.ndim,
  245. "mat_idx not supported for current format");
  246. }
  247. }
  248. void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
  249. const TensorLayout& diff,
  250. const TensorLayout& grad,
  251. size_t workspace_in_bytes) {
  252. check_layout_fwd(grad, mat, diff);
  253. megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
  254. || grad.dtype == dtype::BFloat16()),
  255. "Backward WarpPerspective only supports Float32/BFloat16.");
  256. auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad);
  257. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  258. }
  259. void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
  260. const TensorLayout& mat,
  261. const TensorLayout& diff,
  262. const TensorLayout& grad,
  263. size_t workspace_in_bytes) {
  264. check_layout_fwd(src, mat, diff);
  265. megdnn_assert_eq_layout(mat, grad);
  266. megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
  267. || grad.dtype == dtype::BFloat16()),
  268. "Backward WarpPerspective only supports Float32/BFloat16.");
  269. auto required_workspace_in_bytes =
  270. get_workspace_in_bytes(src, mat, diff, grad);
  271. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  272. }
  273. } // namespace megdnn
  274. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。