
imgproc.cpp

/**
 * \file src/opr/impl/imgproc.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "./internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

/* ======================= WarpPerspectiveForward ======================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveForward);

WarpPerspectiveForward::WarpPerspectiveForward(VarNode* src, VarNode* mat,
                                               VarNode* mat_idx,
                                               VarNode* out_shape,
                                               const Param& param,
                                               const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "warp_perspective", {src, mat}}) {
    init_megdnn_opr(*this, param);
    if (mat_idx) {
        add_input({src, mat, mat_idx, out_shape});
    } else {
        add_input({src, mat, out_shape});
    }
    outshape_by_symvar_enable(input().size() - 1, input().size() - 1);
}

SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,
                                       SymbolVar i3, const Param& param,
                                       const OperatorNodeConfig& config) {
    return i0.insert_single_output_opr<WarpPerspectiveForward>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}
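
// A minimal usage sketch (added commentary, not part of the original file):
// the make() above maps i0..i3 to (src, mat, mat_idx, out_shape). The
// variable names below (src, mat, mat_idx, out_shape, param) are hypothetical
// placeholders for SymbolVars already built on the same graph.
//
//     auto dst = opr::WarpPerspectiveForward::make(src, mat, mat_idx,
//                                                  out_shape, param);
//
// With mat_idx present, mat may hold more 3x3 matrices than the batch size
// and mat_idx selects the source image for each matrix; without it, the
// constructor takes the three-input path and batch sizes must match.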

void WarpPerspectiveForward::init_output_dtype() {
    if (config().output_dtype().valid()) {
        output(0)->dtype(config().output_dtype());
    } else {
        output(0)->dtype(input(0)->dtype());
    }
}

void WarpPerspectiveForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0),
         matshp = shpinfo.shape_inp_shp.at(1);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && matshp.ndim == 3 &&
                       oshp2d.ndim == 2 && matshp.shape[1] == 3 &&
                       matshp.shape[2] == 3,
               "shape mismatch for WarpPerspectiveForward: img=%s mat=%s "
               "out2d=%s",
               imgshp.to_string().c_str(), matshp.to_string().c_str(),
               oshp2d.to_string().c_str());
    if (input().size() == 3) {
        mgb_assert(imgshp[0] == matshp[0],
                   "batchsize mismatch: img=%zu mat=%zu", imgshp[0], matshp[0]);
    } else {
        mgb_assert(input().size() == 4);
        auto mat_idx_shp = shpinfo.shape_inp_shp.at(2);
        mgb_assert(mat_idx_shp[0] == matshp[0] && mat_idx_shp.ndim == 1,
                   "invalid mat_idx shape: mat=%zu mat_idx=%s", matshp[0],
                   mat_idx_shp.to_string().c_str());
    }
    switch (param().format) {
        case Param::Format::NCHW_NCHW4_IC_SMALL:
        case Param::Format::NHWC_NCHW4_IC_SMALL:
            dest.ndim = 5;
            dest[0] = matshp[0];
            dest.shape[1] = 1;
            dest.shape[2] = oshp2d.shape[0];
            dest.shape[3] = oshp2d.shape[1];
            dest.shape[4] = 4;
            break;
        case Param::Format::NHWC_NCHW:
            dest[0] = matshp[0];
            dest.shape[1] = imgshp.shape[3];
            dest.shape[2] = oshp2d.shape[0];
            dest.shape[3] = oshp2d.shape[1];
            break;
        default:
            size_t height_idx = 0;
            if (param().format == Param::Format::NCHW ||
                param().format == Param::Format::NCHW4) {
                height_idx = 2;
            } else {
                height_idx = 1;
            }
            dest = imgshp;
            dest[0] = matshp[0];
            if (param().format == Param::Format::NHWCD4) {
                dest.shape[height_idx] = oshp2d.shape[0];
                dest.shape[height_idx + 2] = oshp2d.shape[1];
            } else {
                for (int i = 0; i < 2; ++i)
                    dest.shape[height_idx + i] = oshp2d.shape[i];
            }
            break;
    }
}
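
// Worked shape example (added commentary, following the default branch
// above): with param().format == NCHW, img = (8, 3, 32, 32),
// mat = (16, 3, 3), mat_idx = (16,) and out2d = (24, 24), the inferred
// output is dest = (16, 3, 24, 24): the batch comes from mat, the channels
// from the image, and the spatial dims from the symbolic out_shape input.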

void WarpPerspectiveForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void WarpPerspectiveForward::scn_do_execute() {
    if (input().size() == 3) {
        intl::_MegDNNOprMethInvoker<2, 1>::exec(megdnn_opr(), this);
    } else {
        intl::_MegDNNOprMethInvoker<3, 1>::exec(megdnn_opr(), this);
    }
}

size_t WarpPerspectiveForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    if (input().size() == 3) {
        return intl::_MegDNNOprMethInvoker<2, 1>::get_workspace_in_bytes(
                megdnn_opr(), this, input_shapes, output_shapes);
    } else {
        return intl::_MegDNNOprMethInvoker<3, 1>::get_workspace_in_bytes(
                megdnn_opr(), this, input_shapes, output_shapes);
    }
}

void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
    if (opr.input().size() == 4) {
        if (wrt_idx == 0) {
            // wrt data
            SymbolVar grad = WarpPerspectiveBackwardData::make(
                    opr.input(1), opr.input(2), out_grad[0], opr.input(0),
                    opr.param());
            return grad.node();
        } else if (wrt_idx == 1) {
            // wrt mat
            SymbolVar grad = WarpPerspectiveBackwardMat::make(
                    opr.input(0), opr.input(1), opr.input(2), out_grad[0],
                    opr.param());
            return grad.node();
        } else {
            return InvalidGrad::make(opr, wrt_idx);
        }
    }
    mgb_assert(opr.input().size() == 3);
    if (wrt_idx == 0) {
        // wrt data
        SymbolVar grad = WarpPerspectiveBackwardData::make(
                opr.input(1), out_grad[0], opr.input(0), opr.param());
        return grad.node();
    } else if (wrt_idx == 1) {
        // wrt mat
        SymbolVar grad = WarpPerspectiveBackwardMat::make(
                opr.input(0), opr.input(1), out_grad[0], opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif
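
// Gradient wiring summary (added commentary, not in the original file):
//   wrt_idx == 0 (src) -> WarpPerspectiveBackwardData, which receives the
//       forward src only as an in_for_shape hint for the data gradient;
//   wrt_idx == 1 (mat) -> WarpPerspectiveBackwardMat;
//   wrt_idx >= 2 (mat_idx / out_shape) -> InvalidGrad, since integer index
//       and shape inputs are not differentiable.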

/* ====================== WarpPerspectiveBackwardData ====================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardData);

WarpPerspectiveBackwardData::WarpPerspectiveBackwardData(
        VarNode* mat, VarNode* out_diff, VarNode* in_for_shape,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{mat->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_data",
                                          {mat}},
                2, false) {
    init_megdnn_opr(*this, param);
    add_input({mat, out_diff, in_for_shape});
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardData>::apply(*this);
}

WarpPerspectiveBackwardData::WarpPerspectiveBackwardData(
        VarNode* mat, VarNode* mat_idx, VarNode* out_diff,
        VarNode* in_for_shape, const Param& param,
        const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{mat->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_data",
                                          {mat, mat_idx}},
                3, false) {
    init_megdnn_opr(*this, param);
    add_input({mat, mat_idx, out_diff, in_for_shape});
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardData>::apply(*this);
}

SymbolVar WarpPerspectiveBackwardData::make(SymbolVar i0, SymbolVar i1,
                                            SymbolVar i2, const Param& param,
                                            const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardData>::apply(
            param, {&i0, &i1, &i2});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardData>(
            i0.node(), i1.node(), i2.node(), param, config);
}

SymbolVar WarpPerspectiveBackwardData::make(SymbolVar i0, SymbolVar i1,
                                            SymbolVar i2, SymbolVar i3,
                                            const Param& param,
                                            const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardData>::apply(
            param, {&i0, &i1, &i2, &i3});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardData>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}

void WarpPerspectiveBackwardData::scn_do_execute() {
    if (input().size() == 3) {
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    } else {
        mgb_assert(input().size() == 4);
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
}

/* ====================== WarpPerspectiveBackwardMat ====================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpPerspectiveBackwardMat);

WarpPerspectiveBackwardMat::WarpPerspectiveBackwardMat(
        VarNode* src, VarNode* mat, VarNode* mat_idx, VarNode* out_diff,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{src->owner_graph(),
                                          config,
                                          "warp_perspective_bwd_mat",
                                          {src, mat, mat_idx}},
                1, true) {
    init_megdnn_opr(*this, param);
    if (mat_idx) {
        add_input({src, mat, mat_idx, out_diff});
    } else {
        add_input({src, mat, out_diff});
    }
    intl::MegDNNOprInitPostCtor<WarpPerspectiveBackwardMat>::apply(*this);
}

void WarpPerspectiveBackwardMat::scn_do_execute() {
    if (input().size() == 3) {
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    } else {
        mgb_assert(input().size() == 4);
        megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                           input(1)->dev_tensor().as_megdnn(),
                           input(2)->dev_tensor().as_megdnn(),
                           input(3)->dev_tensor().as_megdnn(),
                           output(0)->dev_tensor().as_megdnn(),
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
}

SymbolVar WarpPerspectiveBackwardMat::make(SymbolVar i0, SymbolVar i1,
                                           SymbolVar i2, SymbolVar i3,
                                           const Param& param,
                                           const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<WarpPerspectiveBackwardMat>::apply(
            param, {&i0, &i1, &i2, &i3});
    return i0.insert_single_output_opr<WarpPerspectiveBackwardMat>(
            i0.node(), i1.node(), i2.node(), i3.node(), param, config);
}

/* ====================== Cv operator ====================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RotateForward);
MEGDNN_OPR_INIT1(RotateForward, "rotate")

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CvtColorForward);
MEGDNN_OPR_INIT1(CvtColorForward, "cvt_color")

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GaussianBlurForward);
MEGDNN_OPR_INIT1(GaussianBlurForward, "gaussion_blur")

/* ======================= ResizeForward ======================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeForward);
MEGDNN_OPR_INIT2(ResizeForward, "resize")

void ResizeForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    outshape_by_symvar_enable(1, 1);
}

void ResizeForward::add_input_layout_constraint() {
    if (param().format != Param::Format::NCHW) {
        input(0)->add_layout_constraint_contiguous();
    }
    input(1)->add_layout_constraint_contiguous();
}

void ResizeForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && oshp2d.ndim == 2,
               "shape mismatch for ResizeForward: img=%s out2d=%s",
               imgshp.to_string().c_str(), oshp2d.to_string().c_str());
    //! Index of the height dimension: e.g. for an [n, h, w, c] layout,
    //! height_idx is 1
    size_t height_idx = 0;
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NCHW4) {
        height_idx = 2;
    } else {
        height_idx = 1;
    }
    dest = imgshp;
    if (param().format == Param::Format::NHWCD4) {
        dest.shape[height_idx] = oshp2d.shape[0];
        dest.shape[height_idx + 2] = oshp2d.shape[1];
    } else {
        for (int i = 0; i < 2; ++i)
            dest.shape[height_idx + i] = oshp2d.shape[i];
    }
}
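
// Note on the NHWCD4 branch above (added commentary; layout assumption:
// NHWCD4 stores tensors as (N, H, C/4, W, 4)): height sits at index 1 and
// width at index 3, so the width is written at height_idx + 2 rather than
// height_idx + 1, skipping the chunked-channel axis in between.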

void ResizeForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void ResizeForward::scn_do_execute() {
    intl::MegDNNOprMethInvoker<megdnn::Resize>::exec(megdnn_opr(), this);
}

size_t ResizeForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return intl::MegDNNOprMethInvoker<megdnn::Resize>::get_workspace_in_bytes(
            megdnn_opr(), this, input_shapes, output_shapes);
}

void ResizeForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ResizeForward) {
    mgb_assert(opr.input().size() == 2);
    if (wrt_idx == 0) {
        SymbolVar grad =
                ResizeBackward::make(out_grad[0], opr.input(0), opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif

/* ====================== ResizeBackward ====================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ResizeBackward);
MEGDNN_OPR_INIT2(ResizeBackward, "resize_bwd", 1, false);

/* ======================= WarpAffineForward ======================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(WarpAffineForward);
MEGDNN_OPR_INIT3(WarpAffineForward, "warp_affine")

void WarpAffineForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
    outshape_by_symvar_enable(2, 2);
}

void WarpAffineForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

void WarpAffineForward::outshape_by_symvar_do_get_output_shape(
        TensorShape& dest, const ShapeInferInfo& shpinfo) {
    TensorShape oshp2d;
    cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
    auto imgshp = shpinfo.shape_inp_shp.at(0),
         matshp = shpinfo.shape_inp_shp.at(1);
    mgb_assert((imgshp.ndim == 4 || imgshp.ndim == 5) && matshp.ndim == 3 &&
                       oshp2d.ndim == 2 && matshp.shape[0] == imgshp.shape[0] &&
                       matshp.shape[1] == 2 && matshp.shape[2] == 3,
               "shape mismatch for WarpAffineForward: img=%s mat=%s out2d=%s",
               imgshp.to_string().c_str(), matshp.to_string().c_str(),
               oshp2d.to_string().c_str());
    size_t height_idx = 0;
    if (param().format == Param::Format::NCHW) {
        height_idx = 2;
    } else {
        height_idx = 1;
    }
    dest = imgshp;
    if (param().format == Param::Format::NHWCD4) {
        dest.shape[height_idx] = oshp2d.shape[0];
        dest.shape[height_idx + 2] = oshp2d.shape[1];
    } else {
        for (int i = 0; i < 2; ++i)
            dest.shape[height_idx + i] = oshp2d.shape[i];
    }
}

void WarpAffineForward::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(false);
}

void WarpAffineForward::scn_do_execute() {
    intl::MegDNNOprMethInvoker<megdnn::WarpAffine>::exec(megdnn_opr(), this);
}

size_t WarpAffineForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return intl::MegDNNOprMethInvoker<megdnn::WarpAffine>::
            get_workspace_in_bytes(megdnn_opr(), this, input_shapes,
                                   output_shapes);
}

void WarpAffineForward::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
}
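
// Added note: unlike WarpPerspectiveForward, the affine matrix here must be
// (N, 2, 3) with N equal to the image batch size (asserted above), and there
// is no mat_idx input; the MEGDNN_OPR_INIT3-generated make() takes the three
// inputs (src, mat, out_shape), with out_shape handled symbolically via
// outshape_by_symvar_enable(2, 2).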

/* ======================= RemapForward ======================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapForward);
MEGDNN_OPR_INIT2(RemapForward, "remap")

void RemapForward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemapForward) {
    mgb_assert(opr.input().size() == 2);
    if (wrt_idx == 0) {
        SymbolVar grad = RemapBackwardData::make(opr.input(1), out_grad[0],
                                                 opr.input(0), opr.param());
        return grad.node();
    } else if (wrt_idx == 1) {
        SymbolVar grad = RemapBackwardMat::make(opr.input(0), opr.input(1),
                                                out_grad[0], opr.param());
        return grad.node();
    } else
        return InvalidGrad::make(opr, wrt_idx);
}
#endif

/* ====================== RemapBackward ====================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData);
MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat);
MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true);

/* ======================= DctChannelSelectForward ======================= */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(DctChannelSelectForward);

namespace mgb {
namespace opr {
namespace intl {
template <>
struct MegDNNOprInitPostCtor<DctChannelSelectForward> {
    static void apply(cg::OperatorNodeBase& opr) {
        if (opr.config().output_dtype().valid()) {
            opr.output(0)->dtype(opr.config().output_dtype());
        } else {
            opr.output(0)->dtype(dtype::Float32());
        }
    }
};
}  // namespace intl
}  // namespace opr
}  // namespace mgb
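
// Added note: this specialization is invoked at the end of the
// DctChannelSelectForward constructor (via MegDNNOprInitPostCtor::apply
// below), so the output dtype defaults to Float32 unless an explicit
// output_dtype is set in the OperatorNodeConfig.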

void DctChannelSelectForward::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    auto mo = megdnn_opr();
    TensorLayout dst;
    dst.dtype = output(0)->dtype();
    if (inp_shape.size() == 1) {
        mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
                          {}, {}, dst);
    } else {
        mgb_assert(inp_shape.size() == 3, "unsupported input tensor num %zu",
                   inp_shape.size());
        mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
                          {inp_shape[1], input(1)->dtype(), input(1)->format()},
                          {inp_shape[2], input(2)->dtype(), input(2)->format()},
                          dst);
    }
    out_shape[0] = dst;
}

size_t DctChannelSelectForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    auto mo = megdnn_opr();
    return mo->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()}, {}, {},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}

void DctChannelSelectForward::scn_do_execute() {
    auto&& inp = input();
    auto mo = megdnn_opr();
    if (inp.size() == 1) {
        mo->exec(inp[0]->dev_tensor().as_megdnn(), {}, {},
                 output(0)->dev_tensor().as_megdnn(),
                 intl::get_megdnn_workspace_from_var(output().back()));
    } else {
        mgb_assert(inp.size() == 3, "unsupported input tensor num %zu",
                   inp.size());
        mo->exec(inp[0]->dev_tensor().as_megdnn(),
                 inp[1]->dev_tensor().as_megdnn(),
                 inp[2]->dev_tensor().as_megdnn(),
                 output(0)->dev_tensor().as_megdnn(),
                 intl::get_megdnn_workspace_from_var(output().back()));
    }
}

void DctChannelSelectForward::valid_mask(const int* mask_offset, int mask_len,
                                         const int* mask_val, int mask_val_len,
                                         const Param& param) {
    if (mask_len <= 0)
        return;
    mgb_assert(mask_offset[0] == 0,
               "the first element of mask_offset must be zero, but got %d; "
               "for example, mask_offset [0, 15, 20] means there are 2 input "
               "channels, where ic_0 has (15 - 0) output channels and ic_1 "
               "has (20 - 15)",
               mask_offset[0]);
    for (int i = 1; i < mask_len; ++i) {
        if (param.format == Param::Format::NCHW4) {
            mgb_assert(mask_offset[i] % 4 == 0,
                       "invalid mask offset %d at index %d: it must be a "
                       "multiple of 4 when using the NCHW4 format",
                       mask_offset[i], i);
        }
        mgb_assert(mask_offset[i] >= mask_offset[i - 1],
                   "mask offsets must be non-decreasing, but %d (at index %d) "
                   "is less than %d (at index %d)",
                   mask_offset[i], i, mask_offset[i - 1], i - 1);
    }
    const int max_mask = param.dct_block_size * param.dct_block_size;
    for (int i = 0; i < mask_val_len; ++i) {
        mgb_assert(0 <= mask_val[i] && mask_val[i] < max_mask,
                   "invalid mask_val: expect 0 <= mask_val[%d] < %d, got %d",
                   i, max_mask, mask_val[i]);
    }
}
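
// Added illustration of the mask format validated above: with
// dct_block_size == 8 (so max_mask == 64) and the NCHW4 format,
//   mask_offset = {0, 4, 8}                -> 2 input channels, each keeping
//                                             4 DCT coefficients (every
//                                             offset a multiple of 4);
//   mask_val    = {0, 1, 8, 9, 0, 1, 8, 9} -> the kept coefficient indices
//                                             per channel, each in [0, 64).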

DctChannelSelectForward::DctChannelSelectForward(
        VarNode* src, VarNode* mask_offset, VarNode* mask_val,
        const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  src->owner_graph(), config, "dct_channel_select", {src}}) {
    init_megdnn_opr(*this, param);
    add_input({src, mask_offset, mask_val});
    if (mask_offset != nullptr) {
        mgb_assert(mask_val,
                   "mask_val should not be null when mask_offset is not null");
        auto host_offset = mask_offset->owner_opr()
                                   ->cast_final_safe<opr::ImmutableTensor>()
                                   .host_value();
        auto host_val = mask_val->owner_opr()
                                ->cast_final_safe<opr::ImmutableTensor>()
                                .host_value();
        valid_mask(host_offset.ptr<int>(),
                   host_offset.layout().total_nr_elems(), host_val.ptr<int>(),
                   host_val.layout().total_nr_elems(), param);
    }
    intl::MegDNNOprInitPostCtor<DctChannelSelectForward>::apply(*this);
}

SymbolVar DctChannelSelectForward::make(SymbolVar src, SymbolVar mask_offset,
                                        SymbolVar mask_val, const Param& param,
                                        const OperatorNodeConfig& config) {
    intl::MegDNNOprInitInputsModifier<DctChannelSelectForward>::apply(
            param, {&src, &mask_offset, &mask_val});
    return src.insert_single_output_opr<DctChannelSelectForward>(
            src.node(), mask_offset.node(), mask_val.node(), param, config);
}

MEGDNN_OPR_INIT1(DctChannelSelectForward, "dct_channel_select")

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
