You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

warp_perspective.cpp 16 kB


  1. /**
  2. * \file dnn/src/common/warp_perspective.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "src/common/utils.h"
  14. namespace megdnn {
  15. void WarpPerspectiveBase::check_layout_fwd(
  16. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx,
  17. const TensorLayout& dst) {
  18. megdnn_assert_contiguous(mat);
  19. megdnn_assert_contiguous(src);
  20. megdnn_assert_contiguous(dst);
  21. auto errmsg = [&]() {
  22. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " +
  23. megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) + ", " +
  24. param_msg();
  25. };
  26. MEGDNN_MARK_USED_VAR(errmsg);
  27. if (param().format == param::WarpPerspective::Format::NHWCD4 ||
  28. param().format == param::WarpPerspective::Format::NCHW4 ||
  29. param().format == param::WarpPerspective::Format::NCHW64) {
  30. megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
  31. megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
  32. } else if (
  33. param().format == param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
  34. param().format == param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
  35. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  36. megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
  37. } else {
  38. megdnn_assert(
  39. param().format == param::WarpPerspective::Format::NHWC ||
  40. param().format == param::WarpPerspective::Format::NCHW ||
  41. param().format == param::WarpPerspective::Format::NHWC_NCHW);
  42. megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
  43. megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
  44. }
  45. megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str());
  46. megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
  47. if (mat_idx.ndim) {
  48. megdnn_assert(
  49. mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, "%s",
  50. errmsg().c_str());
  51. megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
  52. megdnn_assert_contiguous(mat_idx);
  53. } else {
  54. megdnn_assert(src.shape[0] == dst.shape[0], "%s", errmsg().c_str());
  55. }
  56. megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
  57. megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
  58. if (src.format == dst.format && dst.dtype == src.dtype) {
  59. if (param().format == param::WarpPerspective::Format::NCHW) {
  60. megdnn_assert(
  61. src.dtype.enumv() == DTypeEnum::Float32 ||
  62. DNN_FLOAT16_SELECT(
  63. (src.dtype.enumv() == DTypeEnum::Float16 ||
  64. src.dtype.enumv() == DTypeEnum::BFloat16),
  65. false) ||
  66. src.dtype.enumv() == DTypeEnum::Int8 ||
  67. src.dtype.enumv() == DTypeEnum::Uint8 ||
  68. (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  69. src.dtype.enumv() == DTypeEnum::Quantized8Asymm) ||
  70. src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  71. src.dtype.enumv() == DTypeEnum::Quantized4Asymm,
  72. "WarpPerspective NCHW input dtype should be "
  73. "Float32/Int8/Uint8/QInt8/QUint8/QInt4/QUInt4" DNN_FLOAT16_SELECT(
  74. "/Float16/BFloat16", "") ".");
  75. megdnn_assert(
  76. (src.dtype.category() == DTypeCategory::FLOAT &&
  77. (src.dtype == mat.dtype ||
  78. mat.dtype.enumv() == DTypeEnum::Float32)) ||
  79. ((src.dtype.category() == DTypeCategory::INT ||
  80. src.dtype.category() == DTypeCategory::QUANTIZED) &&
  81. mat.dtype.enumv() == DTypeEnum::Float32),
  82. "The input to WarpPerspective is in NCHW format, in this "
  83. "case, if the input dtype is floating point, the "
  84. "transformation matrix should have same dtype as the "
  85. "input, otherwise, it should be in Float32, %s given.",
  86. mat.dtype.name());
  87. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  88. megdnn_assert(
  89. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  90. megdnn_assert(
  91. param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  92. megdnn_assert(
  93. param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  94. } else if (param().format == param::WarpPerspective::Format::NHWC) {
  95. megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
  96. } else if (param().format == param::WarpPerspective::Format::NCHW4) {
  97. megdnn_assert(
  98. src.dtype.enumv() == DTypeEnum::QuantizedS8,
  99. "src expected QuantizedS8, but got %s", src.dtype.name());
  100. megdnn_assert(
  101. mat.dtype == dtype::Float32(),
  102. "matrix dtype expected float, got %s", mat.dtype.name());
  103. megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
  104. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  105. megdnn_assert(
  106. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  107. megdnn_assert(
  108. param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  109. megdnn_assert(
  110. param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  111. } else if (param().format == param::WarpPerspective::Format::NCHW64) {
  112. megdnn_assert(
  113. (src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  114. src.dtype.enumv() == DTypeEnum::Quantized4Asymm),
  115. "src expected QuantizedS4/Quantized4Asymm, but got %s",
  116. src.dtype.name());
  117. megdnn_assert(
  118. mat.dtype == dtype::Float32(),
  119. "matrix dtype expected float, got %s", mat.dtype.name());
  120. megdnn_assert(src.shape[4] == 64 && dst.shape[4] == 64);
  121. megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
  122. megdnn_assert(
  123. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  124. megdnn_assert(
  125. param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  126. megdnn_assert(
  127. param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  128. } else {
  129. megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
  130. megdnn_assert(
  131. src.dtype == dtype::Float32() ||
  132. DNN_FLOAT16_SELECT(
  133. (src.dtype == dtype::Float16() ||
  134. src.dtype == dtype::BFloat16()),
  135. false) ||
  136. src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  137. src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
  138. "WarpPerspective NHWCD4 input dtype should be "
  139. "Float32" DNN_FLOAT16_SELECT(
  140. "/Float16/BFloat16", "") ",QunatizedS8, Quantized8Asymm.");
  141. megdnn_assert(
  142. (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
  143. "The input to WarpPerspective is in NHWCD4 format, in this "
  144. "case, if the input dtype is floating point, the "
  145. "transformation matrix should have same dtype as the "
  146. "input, %s given.",
  147. mat.dtype.name());
  148. //! number of channels is same
  149. megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
  150. megdnn_assert(
  151. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  152. megdnn_assert(
  153. param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  154. megdnn_assert(
  155. param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  156. }
  157. } else if (
  158. param().format == param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
  159. param().format == param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
  160. megdnn_assert(
  161. (src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  162. src.dtype.enumv() == DTypeEnum::Uint8),
  163. "src expected Quantized8Asymm or Uint8, but got %s", src.dtype.name());
  164. megdnn_assert(
  165. mat.dtype == dtype::Float32(), "matrix dtype expected float, got %s",
  166. mat.dtype.name());
  167. megdnn_assert(dst.shape[4] == 4);
  168. megdnn_assert(
  169. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  170. megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  171. megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  172. } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) {
  173. megdnn_assert(
  174. (src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  175. src.dtype.enumv() == DTypeEnum::Uint8),
  176. "src expected Quantized8Asymm or Uint8, but got %s", src.dtype.name());
  177. megdnn_assert(
  178. mat.dtype == dtype::Float32(), "matrix dtype expected float, got %s",
  179. mat.dtype.name());
  180. megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str());
  181. megdnn_assert(
  182. param().imode == param::WarpPerspective::InterpolationMode::LINEAR);
  183. megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT);
  184. megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED);
  185. } else {
  186. megdnn_assert(param().format == param::WarpPerspective::Format::NCHW);
  187. megdnn_assert(
  188. (src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
  189. src.dtype.enumv() == DTypeEnum::Uint8) &&
  190. dst.dtype.enumv() == DTypeEnum::Float32);
  191. }
  192. }
  193. std::string WarpPerspectiveBase::param_msg() const {
  194. std::string res;
  195. res.append("imode=");
  196. switch (param().imode) {
  197. case InterpolationMode::NEAREST:
  198. res.append("NEAREST");
  199. break;
  200. case InterpolationMode::LINEAR:
  201. res.append("LINEAR");
  202. break;
  203. case InterpolationMode::AREA:
  204. res.append("AREA");
  205. break;
  206. case InterpolationMode::CUBIC:
  207. res.append("CUBIC");
  208. break;
  209. case InterpolationMode::LANCZOS4:
  210. res.append("LANCZOS4");
  211. break;
  212. }
  213. res.append(", bmode=");
  214. switch (param().bmode) {
  215. case BorderMode::WRAP:
  216. res.append("WRAP");
  217. break;
  218. case BorderMode::CONSTANT:
  219. res.append("CONSTANT");
  220. break;
  221. case BorderMode::REFLECT:
  222. res.append("REFLECT");
  223. break;
  224. case BorderMode::REFLECT_101:
  225. res.append("REFLECT_101");
  226. break;
  227. case BorderMode::REPLICATE:
  228. res.append("REPLICATE");
  229. break;
  230. case BorderMode::TRANSPARENT:
  231. res.append("TRANSPARENT");
  232. break;
  233. case BorderMode::ISOLATED:
  234. res.append("ISOLATED");
  235. break;
  236. }
  237. if (param().bmode == BorderMode::CONSTANT) {
  238. res.append(", " + std::to_string(param().border_val));
  239. }
  240. return res;
  241. }
  242. int WarpPerspectiveBase::get_real_coord(int p, int len) {
  243. auto bmode = param().bmode;
  244. if ((unsigned)p < (unsigned)len)
  245. ;
  246. else if (bmode == BorderMode::REPLICATE)
  247. p = p < 0 ? 0 : len - 1;
  248. else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
  249. int delta = (bmode == BorderMode::REFLECT_101);
  250. if (len == 1)
  251. return 0;
  252. do {
  253. if (p < 0)
  254. p = -p - 1 + delta;
  255. else
  256. p = len - 1 - (p - len) - delta;
  257. } while ((unsigned)p >= (unsigned)len);
  258. } else if (bmode == BorderMode::WRAP) {
  259. if (p < 0)
  260. p -= ((p - len + 1) / len) * len;
  261. /*
  262. if( p >= len )
  263. p %= len;
  264. */
  265. while (p >= len) {
  266. p -= len;
  267. }
  268. } else if (bmode == BorderMode::CONSTANT)
  269. p = -1;
  270. return p;
  271. }
  272. void WarpPerspectiveForward::check_exec(
  273. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx,
  274. const TensorLayout& dst, size_t workspace_in_bytes) {
  275. check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
  276. }
  277. void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
  278. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx,
  279. const TensorLayout& dst, size_t workspace_in_bytes) {
  280. check_layout_fwd(src, mat, mat_idx, dst);
  281. auto required_workspace_in_bytes = get_workspace_in_bytes(src, mat, mat_idx, dst);
  282. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  283. if (param().format != Param::Format::NHWC &&
  284. param().format != Param::Format::NCHW &&
  285. param().format != Param::Format::NCHW4 &&
  286. param().format != Param::Format::NHWC_NCHW &&
  287. param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
  288. param().format != Param::Format::NCHW_NCHW4_IC_SMALL &&
  289. param().format != Param::Format::NCHW64) {
  290. megdnn_assert(!mat_idx.ndim, "mat_idx not supported for current format");
  291. }
  292. }
  293. void WarpPerspectiveBackwardData::check_exec(
  294. const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff,
  295. const TensorLayout& grad, size_t workspace_in_bytes) {
  296. check_layout_fwd(grad, mat, mat_idx, diff);
  297. megdnn_assert(
  298. grad.dtype == dtype::Float32()
  299. DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
  300. "Backward WarpPerspective only supports Float32/BFloat16.");
  301. auto required_workspace_in_bytes = get_workspace_in_bytes(mat, mat_idx, diff, grad);
  302. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  303. }
  304. void WarpPerspectiveBackwardMat::check_exec(
  305. const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx,
  306. const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
  307. check_layout_fwd(src, mat, mat_idx, diff);
  308. megdnn_assert_eq_layout(mat, grad);
  309. megdnn_assert(
  310. grad.dtype == dtype::Float32()
  311. DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
  312. "Backward WarpPerspective only supports Float32/BFloat16.");
  313. auto required_workspace_in_bytes =
  314. get_workspace_in_bytes(src, mat, mat_idx, diff, grad);
  315. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  316. }
  317. } // namespace megdnn
  318. // vim: syntax=cpp.doxygen