/**
 * \file src/opr/impl/imgproc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain/opr/imgproc.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

/* ======================= WarpPerspectiveForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveForward);

WarpPerspectiveForward::WarpPerspectiveForward(VarNode* src, VarNode* mat,
                                               VarNode* mat_idx,
                                               VarNode* out_shape,
                                               const Param& param,
                                               const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "warp_perspective", {src, mat}}) {
    init_megdnn_opr(*this, param);
    if (mat_idx) {
        add_input({src, mat, mat_idx, out_shape});
    } else {
        add_input({src, mat, out_shape});
    }
    outshape_by_symvar_enable(input().size() - 1, input().size() - 1);
}

SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,
                                       SymbolVar i3, const Param& param,
                                       const OperatorNodeConfig& config) {
    return i0.insert_single_output_opr<WarpPerspectiveForward>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}
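
// A minimal usage sketch (the SymbolVar names here are hypothetical, not
// taken from this repo): i0..i3 of make() map to (src, mat, mat_idx,
// out_shape), as wired up by the constructor above, so a caller would write
// something like
//     auto out = WarpPerspectiveForward::make(src, mat, mat_idx, out_shape,
//                                             param, config);
// where mat holds one 3x3 matrix per output image and mat_idx selects the
// source image each matrix applies to.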

void WarpPerspectiveForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

void WarpPerspectiveForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0),
         matshp = shpinfo.shape_inp_shp.at(1);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && matshp.ndim == 3 &&
                       oshp2d.ndim == 2 && matshp.shape[1] == 3 &&
                       matshp.shape[2] == 3,
               "shape mismatch for WarpPerspectiveForward: img=%s mat=%s "
               "out2d=%s",
               imgshp.to_string().c_str(), matshp.to_string().c_str(),
               oshp2d.to_string().c_str());
    if (input().size() == 3) {
        mgb_assert(imgshp[0] == matshp[0],
                   "batchsize mismatch: img=%zu mat=%zu", imgshp[0], matshp[0]);
    } else {
        mgb_assert(input().size() == 4);
        auto mat_idx_shp = shpinfo.shape_inp_shp.at(2);
        mgb_assert(mat_idx_shp[0] == matshp[0] && mat_idx_shp.ndim == 1,
                   "invalid mat_idx shape: mat=%zu mat_idx=%s", matshp[0],
                   mat_idx_shp.to_string().c_str());
    }

    //! Index of the height dimension; e.g. for [b, h, w, c], height_idx is 1
    size_t height_idx = 0;
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW4) {
        height_idx = 2;
    } else {
        height_idx = 1;
    }

    dest = imgshp;
    dest[0] = matshp[0];
    if (param().format == Param::Format::NHWCD4) {
        dest.shape[height_idx] = oshp2d.shape[0];
        dest.shape[height_idx + 2] = oshp2d.shape[1];
    } else {
        for (int i = 0; i < 2; ++i)
            dest.shape[height_idx + i] = oshp2d.shape[i];
    }
}
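
// Worked example of the shape rule above (values are illustrative only):
// with format=NCHW, img=(2, 3, 32, 32), mat=(8, 3, 3), mat_idx=(8,) and
// out_shape evaluating to {16, 16}, the inferred output is (8, 3, 16, 16):
// the batch dimension follows mat, and (h, w) follow out_shape. In the
// NHWCD4 branch the width axis sits two slots after the height axis
// (consistent with an (N, H, C/4, W, 4) layout), hence height_idx + 2.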

void WarpPerspectiveForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void WarpPerspectiveForward::scn_do_execute() {
    if (input().size() == 3) {
        intl::_MegDNNOprMethInvoker<2, 1>::exec(megdnn_opr(), this);
    } else {
        intl::_MegDNNOprMethInvoker<3, 1>::exec(megdnn_opr(), this);
    }
}

size_t WarpPerspectiveForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    if (input().size() == 3) {
        return intl::_MegDNNOprMethInvoker<2, 1>::get_workspace_in_bytes(
                megdnn_opr(), this, input_shapes, output_shapes);
    } else {
        return intl::_MegDNNOprMethInvoker<3, 1>::get_workspace_in_bytes(
                megdnn_opr(), this, input_shapes, output_shapes);
    }
}
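
// Note on the invoker: the <2, 1> / <3, 1> template arguments appear to
// select how many megdnn input tensors (2 without mat_idx, 3 with it) and
// output tensors (1) are forwarded to the underlying kernel; the trailing
// out_shape input is consumed by shape inference and never reaches the
// kernel. This is a reading of the usage here, not of the helper's
// definition in megdnn_opr_wrapper.inl.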

void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
    if (opr.input().size() == 4) {
        if (wrt_idx == 0) {
            // wrt data
            SymbolVar grad = WarpPerspectiveBackwardData::make(
                    opr.input(1), opr.input(2), out_grad[0], opr.input(0),
                    opr.param());
            return grad.node();
        } else if (wrt_idx == 1) {
            // wrt mat
            SymbolVar grad = WarpPerspectiveBackwardMat::make(
                    opr.input(0), opr.input(1), opr.input(2), out_grad[0],
                    opr.param());
            return grad.node();
        } else {
            return InvalidGrad::make(opr, wrt_idx);
        }
    }
    mgb_assert(opr.input().size() == 3);
    if (wrt_idx == 0) {
        // wrt data
        SymbolVar grad = WarpPerspectiveBackwardData::make(
                opr.input(1), out_grad[0], opr.input(0), opr.param());
        return grad.node();
    } else if (wrt_idx == 1) {
        // wrt mat
        SymbolVar grad = WarpPerspectiveBackwardMat::make(
                opr.input(0), opr.input(1), out_grad[0], opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif
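
// Gradient wiring, summarized: the data gradient is computed by
// BackwardData from (mat, optional mat_idx, out_grad) plus the original
// input, which supplies only the target shape; the mat gradient also needs
// the original src. mat_idx (integer indices) and out_shape (a shape value)
// are not differentiable, so any other wrt_idx yields InvalidGrad.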

/* ====================== WarpPerspectiveBackwardData ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardData);

WarpPerspectiveBackwardData::WarpPerspectiveBackwardData(
        VarNode* mat, VarNode* out_diff, VarNode* in_for_shape,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{mat->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_data",
                                          {mat}},
                2, false) {
    init_megdnn_opr(*this, param);
    add_input({mat, out_diff, in_for_shape});
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardData>::apply(*this);
}

WarpPerspectiveBackwardData::WarpPerspectiveBackwardData(
        VarNode* mat, VarNode* mat_idx, VarNode* out_diff,
        VarNode* in_for_shape, const Param& param,
        const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{mat->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_data",
                                          {mat, mat_idx}},
                3, false) {
    init_megdnn_opr(*this, param);
    add_input({mat, mat_idx, out_diff, in_for_shape});
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardData>::apply(*this);
}

SymbolVar WarpPerspectiveBackwardData::make(SymbolVar i0, SymbolVar i1,
                                            SymbolVar i2, const Param& param,
                                            const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardData>::apply(
            param, {&i0, &i1, &i2});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardData>(
            i0.node(), i1.node(), i2.node(), param, config);
}

SymbolVar WarpPerspectiveBackwardData::make(SymbolVar i0, SymbolVar i1,
                                            SymbolVar i2, SymbolVar i3,
                                            const Param& param,
                                            const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardData>::apply(
            param, {&i0, &i1, &i2, &i3});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardData>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}

void WarpPerspectiveBackwardData::scn_do_execute() {
    if (input().size() == 3) {
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    } else {
        mgb_assert(input().size() == 4);
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
}
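
// The two exec() calls mirror the two constructors: without mat_idx the
// megdnn call is exec(mat, diff, grad, workspace); with mat_idx it is
// exec(mat, mat_idx, diff, grad, workspace). output(1) here is the
// workspace var that the megdnn wrapper machinery attaches to the node.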

/* ====================== WarpPerspectiveBackwardMat ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardMat);

WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat(
        VarNode* src, VarNode* mat, VarNode* mat_idx, VarNode* out_diff,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{src->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_mat",
                                          {src, mat, mat_idx}},
                1, true) {
    init_megdnn_opr(*this, param);
    if (mat_idx) {
        add_input({src, mat, mat_idx, out_diff});
    } else {
        add_input({src, mat, out_diff});
    }
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardMat>::apply(*this);
}

void WarpPerspectiveBackwardMat::scn_do_execute() {
    if (input().size() == 3) {
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    } else {
        mgb_assert(input().size() == 4);
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           input(3)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
}

SymbolVar WarpPerspectiveBackwardMat::make(
        SymbolVar i0, SymbolVar i1, SymbolVar i2, SymbolVar i3,
        const Param& param, const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardMat>::apply(
            param, {&i0, &i1, &i2, &i3});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardMat>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}

/* ====================== Cv operator ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RotateForward);
MEGDNN_OPR_INIT1(RotateForward, "rotate")

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CvtColorForward);
MEGDNN_OPR_INIT1(CvtColorForward, "cvt_color")

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GaussianBlurForward);
MEGDNN_OPR_INIT1(GaussianBlurForward, "gaussion_blur")
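
// The MEGDNN_OPR_INIT1/2/3 macros presumably expand to the standard
// constructor-plus-make() boilerplate for megdnn-backed operators with 1,
// 2 or 3 inputs (cf. the hand-written versions above); the extra trailing
// arguments on some uses below, e.g. "1, false", would then be forwarded to
// the Super constructor as in the explicit constructors. This is inferred
// from usage; the definitions are not shown in this file.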

/* ======================= ResizeForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeForward);
MEGDNN_OPR_INIT2(ResizeForward, "resize")

void ResizeForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    outshape_by_symvar_enable(1, 1);
}

void ResizeForward::add_input_layout_constraint() {
    if (param().format != Param::Format::NCHW) {
        input(0)->add_layout_constraint_contiguous();
    }
    input(1)->add_layout_constraint_contiguous();
}

void ResizeForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && oshp2d.ndim == 2,
               "shape mismatch for ResizeForward: img=%s out2d=%s",
               imgshp.to_string().c_str(), oshp2d.to_string().c_str());

    //! Index of the height dimension; e.g. for [b, h, w, c], height_idx is 1
    size_t height_idx = 0;
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW4) {
        height_idx = 2;
    } else {
        height_idx = 1;
    }

    dest = imgshp;
    if (param().format == Param::Format::NHWCD4) {
        dest.shape[height_idx] = oshp2d.shape[0];
        dest.shape[height_idx + 2] = oshp2d.shape[1];
    } else {
        for (int i = 0; i < 2; ++i)
            dest.shape[height_idx + i] = oshp2d.shape[i];
    }
}
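
// Same shape rule as WarpPerspective, minus the batch change: e.g. with
// format=NCHW, img=(2, 3, 32, 32) and out_shape evaluating to {64, 64},
// the inferred output is (2, 3, 64, 64).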

void ResizeForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void ResizeForward::scn_do_execute() {
    intl::MegDNNOprMethInvoker<megdnn::Resize>::exec(megdnn_opr(), this);
}

size_t ResizeForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return intl::MegDNNOprMethInvoker<megdnn::Resize>::get_workspace_in_bytes(
            megdnn_opr(), this, input_shapes, output_shapes);
}

void ResizeForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ResizeForward) {
    mgb_assert(opr.input().size() == 2);
    if (wrt_idx == 0) {
        SymbolVar grad =
                ResizeBackward::make(out_grad[0], opr.input(0), opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif
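
// ResizeBackward receives (out_grad, original input); the original input
// is only needed to recover the shape of the data gradient. The out_shape
// input (wrt_idx == 1) is not differentiable, hence InvalidGrad.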

/* ====================== ResizeBackward ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeBackward);
MEGDNN_OPR_INIT2(ResizeBackward, "resize_bwd", 1, false);

/* ======================= WarpAffineForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpAffineForward);
MEGDNN_OPR_INIT3(WarpAffineForward, "warp_affine")

void WarpAffineForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    outshape_by_symvar_enable(2, 2);
}

void WarpAffineForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void WarpAffineForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0),
         matshp = shpinfo.shape_inp_shp.at(1);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && matshp.ndim == 3 &&
                       oshp2d.ndim == 2 && matshp.shape[0] == imgshp.shape[0] &&
                       matshp.shape[1] == 2 && matshp.shape[2] == 3,
               "shape mismatch for WarpAffineForward: img=%s mat=%s out2d=%s",
               imgshp.to_string().c_str(), matshp.to_string().c_str(),
               oshp2d.to_string().c_str());

    size_t height_idx = 0;
    if (param().format == Param::Format::NCHW) {
        height_idx = 2;
    } else {
        height_idx = 1;
    }

    dest = imgshp;
    if (param().format == Param::Format::NHWCD4) {
        dest.shape[height_idx] = oshp2d.shape[0];
        dest.shape[height_idx + 2] = oshp2d.shape[1];
    } else {
        for (int i = 0; i < 2; ++i)
            dest.shape[height_idx + i] = oshp2d.shape[i];
    }
}
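
// Unlike WarpPerspective, an affine matrix is 2x3, so matshp is asserted
// to be [n, 2, 3] rather than [n, 3, 3]; its batch size must equal the
// image batch (there is no mat_idx input), and the output batch follows
// the image rather than the matrix tensor.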

void WarpAffineForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void WarpAffineForward::scn_do_execute() {
    intl::MegDNNOprMethInvoker<megdnn::WarpAffine>::exec(megdnn_opr(), this);
}

size_t WarpAffineForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return intl::MegDNNOprMethInvoker<
            megdnn::WarpAffine>::get_workspace_in_bytes(megdnn_opr(), this,
                                                        input_shapes,
                                                        output_shapes);
}

void WarpAffineForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

/* ======================= RemapForward ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapForward);
MEGDNN_OPR_INIT2(RemapForward, "remap")

void RemapForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemapForward) {
    mgb_assert(opr.input().size() == 2);
    if (wrt_idx == 0) {
        SymbolVar grad = RemapBackwardData::make(opr.input(1), out_grad[0],
                                                 opr.input(0), opr.param());
        return grad.node();
    } else if (wrt_idx == 1) {
        SymbolVar grad = RemapBackwardMat::make(opr.input(0), opr.input(1),
                                                out_grad[0], opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif

/* ====================== RemapBackward ====================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData);
MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat);
MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true);

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep-learning development on cloud GPU compute, you are welcome to visit the MegStudio platform.