You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

relayout_format.cpp 21 kB


  1. /**
  2. * \file dnn/src/common/relayout_format.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs.h"
  13. #include "megdnn/tensor_format.h"
  14. #include "src/common/utils.h"
  15. using namespace megdnn;
  16. void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
  17. TensorLayout& dst) {
  18. using Param = param::RelayoutFormat;
  19. switch (param().mode) {
  20. case Param::Mode::NCHW_NHWCD4:
  21. case Param::Mode::NCHW_NHWCD4I:
  22. dst.ndim = 5;
  23. dst[0] = src[0];
  24. dst[1] = src[2];
  25. dst[2] = (src[1] + 3) / 4;
  26. dst[3] = src[3];
  27. dst[4] = 4;
  28. break;
  29. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  30. dst.ndim = 5;
  31. megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
  32. dst[0] = src[0];
  33. dst[1] = div_ceil(src[1], 4_z);
  34. dst[2] = src[2];
  35. dst[3] = src[3];
  36. dst[4] = 4;
  37. break;
  38. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  39. megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
  40. megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
  41. dst.ndim = 5;
  42. dst[0] = src[0];
  43. dst[1] = div_ceil(src[1], 4_z);
  44. dst[2] = src[2];
  45. dst[3] = src[3];
  46. dst[4] = 4;
  47. break;
  48. case Param::Mode::NCHW_NCHW88:
  49. dst.ndim = 5;
  50. dst[0] = src[0];
  51. dst[1] = div_ceil(src[1], 8_z);
  52. dst[2] = src[2];
  53. dst[3] = src[3];
  54. dst[4] = 8;
  55. break;
  56. case Param::Mode::NCHW88_NCHW:
  57. dst.ndim = 4;
  58. dst[0] = src[0];
  59. dst[1] = src[1] * 8;
  60. dst[2] = src[2];
  61. dst[3] = src[3];
  62. break;
  63. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  64. megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
  65. dst.ndim = 6;
  66. megdnn_assert(src[0] % 8 == 0,
  67. "NCHW_NCHW88_CONV_DENSE_WEIGHT out channel must "
  68. "align to 8");
  69. dst[0] = src[0] / 8;
  70. dst[1] = div_ceil(src[1], 8_z);
  71. dst[2] = src[2];
  72. dst[3] = src[3];
  73. dst[4] = 8;
  74. dst[5] = 8;
  75. break;
  76. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  77. megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
  78. dst.ndim = 6;
  79. dst[0] = div_ceil(src[0], 8_z);
  80. dst[1] = src[1];
  81. dst[2] = src[2];
  82. dst[3] = src[3];
  83. dst[4] = src[4];
  84. dst[5] = 8;
  85. break;
  86. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  87. megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
  88. dst.ndim = 7;
  89. dst[0] = src[0];
  90. megdnn_assert(src[1] % 8 == 0,
  91. "NCHW_NCHW88_CONV_GROUP_WEIGHT out channel must "
  92. "align to 8");
  93. dst[1] = src[1] / 8;
  94. dst[2] = div_ceil(src[2], 8_z);
  95. dst[3] = src[3];
  96. dst[4] = src[4];
  97. dst[5] = 8;
  98. dst[6] = 8;
  99. break;
  100. case Param::Mode::NHWC_NHWCD4:
  101. case Param::Mode::NHWC_NHWCD4I:
  102. megdnn_assert(src.ndim == 4);
  103. //! channel mod 4 should == 4
  104. megdnn_assert(src[3] % 4 == 0);
  105. dst.ndim = 5;
  106. dst[0] = src[0];
  107. dst[1] = src[1];
  108. dst[2] = src[3] / 4;
  109. dst[3] = src[2];
  110. dst[4] = 4;
  111. break;
  112. case Param::Mode::NHWCD4_NHWC:
  113. megdnn_assert(src.ndim == 5);
  114. dst.ndim = 4;
  115. dst[0] = src[0];
  116. dst[1] = src[1];
  117. dst[2] = src[3];
  118. dst[3] = src[2] * 4;
  119. break;
  120. case Param::Mode::NHWCD4_NCHW:
  121. case Param::Mode::NHWCD4I_NCHW:
  122. megdnn_assert(src.ndim == 5);
  123. dst.ndim = 4;
  124. dst[0] = src[0];
  125. dst[1] = src[2] * 4;
  126. dst[2] = src[1];
  127. dst[3] = src[3];
  128. break;
  129. case Param::Mode::INTER_WEIGHT_DENSE:
  130. case Param::Mode::INTER_WEIGHT_DENSEI:
  131. megdnn_assert(src.ndim == 4);
  132. megdnn_assert(src[0] % 4 == 0);
  133. dst.ndim = 5;
  134. dst[0] = src[0] / 4;
  135. dst[1] = src[2];
  136. dst[2] = src[3];
  137. dst[3] = round_up<size_t>(src[1], 4);
  138. dst[4] = 4;
  139. break;
  140. case Param::Mode::INTER_WEIGHT_GROUP:
  141. case Param::Mode::INTER_WEIGHT_GROUPI:
  142. // group conv filter
  143. megdnn_assert(src.ndim == 5);
  144. megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
  145. dst.ndim = 6;
  146. dst[0] = src[0];
  147. dst[1] = src[1] / 4;
  148. dst[2] = src[3];
  149. dst[3] = src[4];
  150. dst[4] = src[2];
  151. dst[5] = 4;
  152. break;
  153. case Param::Mode::INTER_WEIGHT_CHAN:
  154. case Param::Mode::INTER_WEIGHT_CHANI:
  155. megdnn_assert(src.ndim == 5 && src[1] == 1 && src[2] == 1);
  156. // chanwise conv filter
  157. dst.ndim = 5;
  158. dst[0] = src[0] / 4;
  159. dst[1] = 1;
  160. dst[2] = src[3];
  161. dst[3] = src[4];
  162. dst[4] = 4;
  163. break;
  164. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  165. megdnn_assert(src.ndim == 4);
  166. megdnn_assert(src[0] % 4 == 0);
  167. dst.ndim = 6;
  168. dst[0] = src[0] / 4;
  169. dst[1] = src[2];
  170. dst[2] = src[3];
  171. dst[3] = div_ceil<size_t>(src[1], 4);
  172. dst[4] = 4;
  173. dst[5] = 4;
  174. break;
  175. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  176. megdnn_assert(src.ndim == 5);
  177. megdnn_assert(src[1] % 4 == 0 && src[2] % 4 == 0);
  178. dst.ndim = 7;
  179. dst[0] = src[0];
  180. dst[1] = src[1] / 4;
  181. dst[2] = src[3];
  182. dst[3] = src[4];
  183. dst[4] = src[2] / 4;
  184. dst[5] = 4;
  185. dst[6] = 4;
  186. break;
  187. case Param::Mode::NCHW4_CHWN4:
  188. megdnn_assert(src.ndim == 5);
  189. megdnn_assert(src[4] == 4);
  190. dst.ndim = 5;
  191. dst[0] = src[1];
  192. dst[1] = src[2];
  193. dst[2] = src[3];
  194. dst[3] = src[0];
  195. dst[4] = src[4];
  196. break;
  197. case Param::Mode::CHWN4_NCHW4:
  198. megdnn_assert(src.ndim == 5);
  199. megdnn_assert(src[4] == 4);
  200. dst.ndim = 5;
  201. dst[0] = src[3];
  202. dst[1] = src[0];
  203. dst[2] = src[1];
  204. dst[3] = src[2];
  205. dst[4] = src[4];
  206. break;
  207. case Param::Mode::NCHW_NCHW4:
  208. megdnn_assert(src.ndim == 4);
  209. dst.ndim = 5;
  210. dst[0] = src[0];
  211. dst[1] = div_ceil<size_t>(src[1], 4);
  212. dst[2] = src[2];
  213. dst[3] = src[3];
  214. dst[4] = 4;
  215. break;
  216. default:
  217. megdnn_assert(0, "Invalid RelayoutFormat Mode");
  218. break;
  219. }
  220. TensorFormat dst_fmt;
  221. deduce_format(src.format, dst_fmt);
  222. dst.format = dst_fmt;
  223. if (!dst.dtype.valid()) {
  224. dst.dtype = src.dtype;
  225. }
  226. dst.init_contiguous_stride();
  227. }
  228. void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
  229. deduce_layout_fwd(src, dst);
  230. }
  231. void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
  232. size_t align = handle()->image2d_pitch_alignment();
  233. auto vendor_type = handle()->vendor_type();
  234. using Param = param::RelayoutFormat;
  235. #define CHECK_SRC(_expect) \
  236. megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
  237. _expect.to_string().c_str(), src.to_string().c_str())
  238. switch (param().mode) {
  239. case Param::Mode::NHWC_NHWCD4:
  240. CHECK_SRC(DefaultTensorFormat::make());
  241. dst = src;
  242. break;
  243. case Param::Mode::NHWCD4_NHWC:
  244. CHECK_SRC(DefaultTensorFormat::make());
  245. dst = src;
  246. break;
  247. case Param::Mode::NHWC_NHWCD4I:
  248. CHECK_SRC(DefaultTensorFormat::make());
  249. dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
  250. break;
  251. case Param::Mode::NCHW_NHWCD4:
  252. CHECK_SRC(DefaultTensorFormat::make());
  253. dst = src;
  254. break;
  255. case Param::Mode::NCHW_NCHW4:
  256. CHECK_SRC(DefaultTensorFormat::make());
  257. dst = src;
  258. break;
  259. case Param::Mode::NCHW_NHWCD4I:
  260. CHECK_SRC(DefaultTensorFormat::make());
  261. dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
  262. break;
  263. case Param::Mode::NHWCD4I_NCHW:
  264. CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
  265. dst = DefaultTensorFormat::make();
  266. break;
  267. case Param::Mode::NHWCD4_NCHW:
  268. CHECK_SRC(DefaultTensorFormat::make());
  269. dst = src;
  270. break;
  271. case Param::Mode::INTER_WEIGHT_DENSE:
  272. CHECK_SRC(DefaultTensorFormat::make());
  273. dst = src;
  274. break;
  275. case Param::Mode::INTER_WEIGHT_DENSEI:
  276. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  277. CHECK_SRC(DefaultTensorFormat::make());
  278. dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type);
  279. break;
  280. case Param::Mode::INTER_WEIGHT_GROUP:
  281. CHECK_SRC(DefaultTensorFormat::make());
  282. dst = src;
  283. break;
  284. case Param::Mode::INTER_WEIGHT_GROUPI:
  285. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  286. CHECK_SRC(DefaultTensorFormat::make());
  287. dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type);
  288. break;
  289. case Param::Mode::INTER_WEIGHT_CHAN:
  290. CHECK_SRC(DefaultTensorFormat::make());
  291. dst = src;
  292. break;
  293. case Param::Mode::INTER_WEIGHT_CHANI:
  294. CHECK_SRC(DefaultTensorFormat::make());
  295. dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type);
  296. break;
  297. case Param::Mode::NCHW4_CHWN4:
  298. CHECK_SRC(DefaultTensorFormat::make());
  299. dst = src;
  300. break;
  301. case Param::Mode::CHWN4_NCHW4:
  302. CHECK_SRC(DefaultTensorFormat::make());
  303. dst = src;
  304. break;
  305. case Param::Mode::NCHW_NCHW88:
  306. case Param::Mode::NCHW88_NCHW:
  307. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  308. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  309. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  310. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  311. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  312. CHECK_SRC(DefaultTensorFormat::make());
  313. dst = src;
  314. break;
  315. default:
  316. megdnn_throw("Invalid relayout format mode");
  317. break;
  318. }
  319. if (!dst.is_default() &&
  320. (
  321. handle()->type() != Handle::HandleType::NAIVE)) {
  322. megdnn_throw(
  323. "Only naive and opencl handle support "
  324. "Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 "
  325. "and also export CUDA_VISIBLE_DEVICES=\'\' at CUDA env"
  326. "to enable naive handle");
  327. }
  328. #undef CHECK_SRC
  329. }
  330. void RelayoutFormat::check_layout_fwd(const TensorLayout& src,
  331. const TensorLayout& dst) {
  332. TensorLayout dst_expected;
  333. dst_expected.dtype = dst.dtype;
  334. deduce_layout_fwd(src, dst_expected);
  335. megdnn_assert_eq_layout(dst_expected, dst);
  336. }
  337. void RelayoutFormat::check_exec(const TensorLayout& src,
  338. const TensorLayout& dst,
  339. size_t workspace_in_bytes) {
  340. check_layout_fwd(src, dst);
  341. auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
  342. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  343. }
  344. void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
  345. const TensorLayout& dst,
  346. TensorLayout& exec_src,
  347. TensorLayout& exec_dst) {
  348. check_layout_fwd(src, dst);
  349. using Param = param::RelayoutFormat;
  350. switch (param().mode) {
  351. case Param::Mode::NCHW_NCHW88:
  352. // nchw to nchw8c
  353. {
  354. TensorLayout work_space_layout(
  355. {src[0], round_up(src[1], 8_z), src[2], src[3]},
  356. src.dtype, src.format);
  357. exec_src = work_space_layout
  358. .reshape({src[0], div_ceil(src[1], 8_z), 8,
  359. src[2], src[3]})
  360. .dimshuffle({0, 1, 3, 4, 2});
  361. exec_dst = dst;
  362. }
  363. break;
  364. case Param::Mode::NCHW_NCHW4:
  365. // nchw to nchw4
  366. {
  367. TensorLayout work_space_layout(
  368. {src[0], round_up(src[1], 4_z), src[2], src[3]},
  369. src.dtype, src.format);
  370. exec_src = work_space_layout
  371. .reshape({src[0], div_ceil(src[1], 4_z), 4,
  372. src[2], src[3]})
  373. .dimshuffle({0, 1, 3, 4, 2});
  374. exec_dst = dst;
  375. }
  376. break;
  377. case Param::Mode::NCHW88_NCHW:
  378. // nchw8c to nchw
  379. exec_src = src;
  380. exec_dst = dst.reshape({dst[0], dst[1] / 8, 8, dst[2], dst[3]})
  381. .dimshuffle({0, 1, 3, 4, 2});
  382. break;
  383. case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
  384. // oihw to oihw8i8o
  385. {
  386. megdnn_assert(src.ndim == 4);
  387. megdnn_assert(src[0] % 8 == 0);
  388. TensorLayout work_space_layout(
  389. {src[0], round_up(src[1], 8_z), src[2], src[3]},
  390. src.dtype, src.format);
  391. exec_src =
  392. work_space_layout
  393. .reshape({src[0] / 8, 8, div_ceil(src[1], 8_z),
  394. 8, src[2], src[3]})
  395. .dimshuffle({0, 2, 4, 5, 3, 1});
  396. exec_dst = dst;
  397. }
  398. break;
  399. case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
  400. // goihw to goihw8g
  401. {
  402. megdnn_assert(src.ndim == 5);
  403. TensorLayout work_space_layout(
  404. {round_up(src[0], 8_z), src[1], src[2], src[3], src[4]},
  405. src.dtype, src.format);
  406. exec_src = work_space_layout
  407. .reshape({div_ceil(src[0], 8_z), 8, src[1],
  408. src[2], src[3], src[4]})
  409. .dimshuffle({0, 2, 3, 4, 5, 1});
  410. exec_dst = dst;
  411. }
  412. break;
  413. case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
  414. // goihw to goihw8i8o
  415. {
  416. megdnn_assert(src.ndim == 5);
  417. megdnn_assert(src[1] % 8 == 0);
  418. TensorLayout work_space_layout(
  419. {src[0], src[1], round_up(src[2], 8_z), src[3], src[4]},
  420. src.dtype, src.format);
  421. exec_src = work_space_layout
  422. .reshape({src[0], src[1] / 8, 8,
  423. div_ceil(src[2], 8_z), 8, src[3],
  424. src[4]})
  425. .dimshuffle({0, 1, 3, 5, 6, 4, 2});
  426. exec_dst = dst;
  427. }
  428. break;
  429. case Param::Mode::NCHW_NCHW4_IC_SMALL:
  430. case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
  431. // nchw to nchw4c or oihw to oihw4i
  432. {
  433. TensorLayout work_space_layout(
  434. {src[0], round_up(src[1], 4_z), src[2], src[3]},
  435. src.dtype, src.format);
  436. exec_src = work_space_layout
  437. .reshape({src[0], div_ceil(src[1], 4_z), 4,
  438. src[2], src[3]})
  439. .dimshuffle({0, 1, 3, 4, 2});
  440. exec_dst = dst;
  441. }
  442. break;
  443. case Param::Mode::NCHW_NHWCD4:
  444. case Param::Mode::NCHW_NHWCD4I:
  445. // src is {N, C, H, W}
  446. // dst is {N, H, CB, W, 4}
  447. exec_src = src;
  448. exec_src[1] = (exec_src[1] + 3) / 4 * 4;
  449. exec_src.stride[0] = exec_src[1] * exec_src.stride[1];
  450. exec_src = exec_src.dimshuffle({0, 2, 3, 1});
  451. exec_src = exec_src.reshape({exec_src[0], exec_src[1], exec_src[2],
  452. exec_src[3] / 4, 4})
  453. .dimshuffle({0, 1, 3, 2, 4});
  454. exec_dst = dst;
  455. break;
  456. case Param::Mode::NHWC_NHWCD4:
  457. case Param::Mode::NHWC_NHWCD4I:
  458. // src is {N, H, W, C},
  459. // dst is {N, H, CB, W, 4}
  460. exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4})
  461. .dimshuffle({0, 1, 3, 2, 4});
  462. exec_dst = dst;
  463. break;
  464. case Param::Mode::NHWCD4_NHWC:
  465. // src is {N, H, CB, W, 4}
  466. // dst is {N, H, W, C},
  467. exec_src = src;
  468. exec_dst = dst.reshape({dst[0], dst[1], dst[2], dst[3] / 4, 4})
  469. .dimshuffle({0, 1, 3, 2, 4});
  470. break;
  471. case Param::Mode::NHWCD4_NCHW:
  472. case Param::Mode::NHWCD4I_NCHW:
  473. exec_src = src;
  474. exec_dst = dst.reshape({dst[0], dst[1] / 4, 4, dst[2], dst[3]})
  475. .dimshuffle({0, 3, 1, 4, 2});
  476. break;
  477. case Param::Mode::INTER_WEIGHT_DENSE:
  478. case Param::Mode::INTER_WEIGHT_DENSEI:
  479. // src is {OC, IC, FH, FW}
  480. // dst is {OCB, FH, FW, IC, 4}
  481. exec_src = src.reshape({src[0] / 4, 4, src[1], src[2], src[3]})
  482. .dimshuffle({0, 3, 4, 2, 1});
  483. exec_dst = dst;
  484. // dst[3] may be round_uped, set to the real ic
  485. exec_dst.shape[3] = src[1];
  486. break;
  487. case Param::Mode::INTER_WEIGHT_GROUP:
  488. case Param::Mode::INTER_WEIGHT_GROUPI:
  489. // group conv filter
  490. // src is {G, ocpg, icpg, fh, fw}
  491. // dst is {G, ocpgb, fh, fw, icpg, 4}
  492. exec_src =
  493. src.reshape({src[0], src[1] / 4, 4, src[2], src[3], src[4]})
  494. .dimshuffle({0, 1, 4, 5, 3, 2});
  495. exec_dst = dst;
  496. break;
  497. case Param::Mode::INTER_WEIGHT_CHAN:
  498. case Param::Mode::INTER_WEIGHT_CHANI:
  499. megdnn_assert(src.ndim == 5);
  500. megdnn_assert(src[1] == 1 && src[2] == 1);
  501. // chanwise conv filter
  502. megdnn_assert(src[0] % 4 == 0);
  503. exec_src = src.reshape({src[0] / 4, 4, 1, src[3], src[4]})
  504. .dimshuffle({0, 2, 3, 4, 1});
  505. exec_dst = dst;
  506. break;
  507. case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
  508. // src is {oc, ic, fh , fw}
  509. // dst is {oc/4, fh, fw, ic/4, 4, 4}
  510. exec_src = src;
  511. exec_src[1] = round_up<size_t>(src[1], 4);
  512. exec_src.stride[0] = exec_src.stride[1] * exec_src[1];
  513. exec_src = exec_src.reshape({exec_src[0] / 4, 4, exec_src[1] / 4, 4,
  514. exec_src[2], exec_src[3]})
  515. .dimshuffle({0, 4, 5, 2, 1, 3});
  516. exec_dst = dst;
  517. break;
  518. case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
  519. // src is {G, ocpg, icpg, fh, fw}
  520. // dst is {G, ocpg/4, fh, fw, icpg/4, 4, 4}
  521. exec_src = src.reshape({src[0], src[1] / 4, 4, src[2] / 4, 4,
  522. src[3], src[4]})
  523. .dimshuffle({0, 1, 5, 6, 3, 2, 4});
  524. exec_dst = dst;
  525. break;
  526. case Param::Mode::NCHW4_CHWN4:
  527. // src is {N, C/4, H, W, 4}
  528. // dst is {C/4, H, W, N, 4}
  529. exec_src = src.dimshuffle({1, 2, 3, 0, 4});
  530. exec_dst = dst;
  531. break;
  532. case Param::Mode::CHWN4_NCHW4:
  533. // src is {C/4, H, W, N, 4}
  534. // dst is {N, C/4, H, W, 4}
  535. exec_src = src.dimshuffle({3, 0, 1, 2, 4});
  536. exec_dst = dst;
  537. break;
  538. default:
  539. megdnn_assert(0, "Invalid RelayoutFormat Mode");
  540. }
  541. }
  542. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。