You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convolution.cpp 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/common/utils.h"
  3. using namespace megdnn;
  4. namespace {
  5. template <typename Param>
  6. std::string get_errmsg(
  7. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  8. const Param& param) {
  9. MEGDNN_MARK_USED_VAR(src);
  10. MEGDNN_MARK_USED_VAR(filter);
  11. MEGDNN_MARK_USED_VAR(dst);
  12. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  13. megdnn_layout_msg(dst) + ", " + "is_nchw=" +
  14. std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
  15. "is_xcorr=" +
  16. std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
  17. "pad_h=" + std::to_string(param.pad_h) + ", " +
  18. "pad_w=" + std::to_string(param.pad_w) + ", " +
  19. "stride_h=" + std::to_string(param.stride_h) + ", " +
  20. "stride_w=" + std::to_string(param.stride_w) + ", " +
  21. "dilate_h=" + std::to_string(param.dilate_h) + ", " +
  22. "dilate_w=" + std::to_string(param.dilate_w);
  23. }
  /**
 * Map a raw filter spatial extent to the canonical spatial size.
 *
 * The format template argument is accepted for call-site uniformity only;
 * this overload returns the extent unchanged and ignores the param object.
 */
template <typename Param, typename Param::Format>
uint32_t spatial_getter(uint32_t filter, const Param& /*param*/) {
    const uint32_t extent = filter;
    return extent;
}
  28. template <typename Parameter, typename Param>
  29. void make_canonized_filter_meta_nchw_nhwc(
  30. size_t src_ndim, const TensorLayout& filter, const Param& param,
  31. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  32. megdnn_assert(
  33. param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
  34. auto img_ndim = src_ndim - 2;
  35. size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
  36. if (param.sparse == Param::Sparse::DENSE) {
  37. megdnn_assert(
  38. filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
  39. "bad filter ndim for dense convolution: "
  40. "spatial_ndim=%zu filter_ndim=%zu",
  41. img_ndim, filter.ndim);
  42. ret.group = 1;
  43. flt_start = 0;
  44. } else {
  45. megdnn_assert(
  46. param.sparse == Param::Sparse::GROUP,
  47. "invalid convolution sparse type");
  48. megdnn_assert(
  49. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
  50. "bad filter ndim for group convolution: "
  51. "spatial_ndim=%zu filter_ndim=%zu",
  52. img_ndim, filter.ndim);
  53. // grp, oc, ic, dims[]
  54. ret.group = filter[0];
  55. flt_start = 1;
  56. }
  57. uint32_t ic_block_size = 1, oc_block_size = 1;
  58. if (param.format == Param::Format::NCHW) {
  59. // filter should be (oc, ic, fh, fw)
  60. flt_spatial_start = 2;
  61. ocpg_pos = 0;
  62. icpg_pos = 1;
  63. } else {
  64. megdnn_assert(
  65. param.format == Param::Format::NHWC, "invalid conv tensor format");
  66. // filter should be (oc, fh, fw, ic)
  67. flt_spatial_start = 1;
  68. ocpg_pos = 0;
  69. icpg_pos = 3;
  70. }
  71. ret.spatial_ndim = src_ndim - 2;
  72. megdnn_assert(
  73. ret.spatial_ndim == 2,
  74. "only 2D convolution is supported, and input should be 4-dim; "
  75. "got input dim = %zu",
  76. src_ndim);
  77. ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
  78. ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
  79. auto dilation = ret.dilation;
  80. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  81. megdnn_assert(
  82. dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
  83. dilation[i]);
  84. ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
  85. filter[i + flt_start + flt_spatial_start], param);
  86. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  87. }
  88. }
  89. template <typename Parameter, typename Param>
  90. void make_canonized_filter_meta_nhwcd4(
  91. size_t src_ndim, const TensorLayout& filter, const Param& param,
  92. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  93. /**
  94. * input: N H IC/4 W 4
  95. * Filter:
  96. * OC/4, FH, FW, IC, 4 [dense]
  97. * GROUP, OC/4, FH, FW, IC, 4 [group]
  98. * GROUP/4, 1, FH, FW, 4 [chanwise]
  99. */
  100. megdnn_assert(param.format == Param::Format::NHWCD4);
  101. auto img_ndim = src_ndim - 3;
  102. size_t flt_start = 0, flt_spatial_start = 1;
  103. bool is_chanwise = false;
  104. if (param.sparse == Param::Sparse::DENSE) {
  105. megdnn_assert(
  106. filter.ndim == img_ndim + 3,
  107. "bad filter ndim for dense convolution: "
  108. "spatial_ndim=%zu filter_ndim=%zu",
  109. img_ndim, filter.ndim);
  110. // oc, ic, dims[]
  111. ret.group = 1;
  112. flt_start = 0;
  113. } else {
  114. megdnn_assert(
  115. param.sparse == Param::Sparse::GROUP,
  116. "invalid convolution sparse type");
  117. megdnn_assert(
  118. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
  119. "bad filter ndim for group convolution: "
  120. "spatial_ndim=%zu filter_ndim=%zu",
  121. img_ndim, filter.ndim);
  122. if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
  123. is_chanwise = true;
  124. ret.group = filter[0] * 4;
  125. } else {
  126. ret.group = filter[0];
  127. }
  128. flt_start = 1;
  129. }
  130. ret.spatial_ndim = src_ndim - 3;
  131. megdnn_assert(
  132. ret.spatial_ndim == 2,
  133. "only 2D convolution is supported, and input should be 4-dim; "
  134. "got input dim = %zu",
  135. src_ndim);
  136. if (is_chanwise) {
  137. ret.ocpg = 1;
  138. ret.icpg = 1;
  139. } else {
  140. ret.ocpg = filter[flt_start] * 4;
  141. ret.icpg = filter[flt_start + 3];
  142. }
  143. auto dilation = ret.dilation;
  144. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  145. megdnn_assert(
  146. dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
  147. dilation[i]);
  148. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  149. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  150. }
  151. }
  152. template <typename Parameter, typename Param>
  153. void make_canonized_filter_meta_nhwcd4_dot(
  154. size_t src_ndim, const TensorLayout& filter, const Param& param,
  155. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  156. /**
  157. * input: N H IC/4 W 4
  158. * Filter:
  159. * GROUP/4, 1, FH, FW, 4 [chanwise]
  160. * OC/4, FH, FW, IC/4, 4, 4 [dense]
  161. * GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
  162. */
  163. megdnn_assert(param.format == Param::Format::NHWCD4);
  164. auto img_ndim = src_ndim - 3;
  165. size_t flt_start = 0, flt_spatial_start = 1;
  166. bool is_chanwise = false;
  167. if (param.sparse == Param::Sparse::DENSE) {
  168. megdnn_assert(
  169. filter.ndim == img_ndim + 4,
  170. "bad filter ndim for dense convolution: "
  171. "spatial_ndim=%zu filter_ndim=%zu",
  172. img_ndim, filter.ndim);
  173. // oc, ic, dims[]
  174. ret.group = 1;
  175. flt_start = 0;
  176. } else {
  177. megdnn_assert(
  178. param.sparse == Param::Sparse::GROUP,
  179. "invalid convolution sparse type");
  180. megdnn_assert(
  181. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
  182. "bad filter ndim for group convolution: "
  183. "spatial_ndim=%zu filter_ndim=%zu",
  184. img_ndim, filter.ndim);
  185. if (filter.ndim == img_ndim + 3) {
  186. megdnn_assert(filter[1] == 1);
  187. is_chanwise = true;
  188. ret.group = filter[0] * 4;
  189. } else {
  190. ret.group = filter[0];
  191. }
  192. flt_start = 1;
  193. }
  194. ret.spatial_ndim = src_ndim - 3;
  195. megdnn_assert(
  196. ret.spatial_ndim == 2,
  197. "only 2D convolution is supported, and input should be 4-dim; "
  198. "got input dim = %zu",
  199. src_ndim);
  200. if (is_chanwise) {
  201. ret.ocpg = 1;
  202. ret.icpg = 1;
  203. } else {
  204. ret.ocpg = filter[flt_start] * 4;
  205. ret.icpg = filter[flt_start + 3] * 4;
  206. }
  207. auto dilation = ret.dilation;
  208. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  209. megdnn_assert(
  210. dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
  211. dilation[i]);
  212. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  213. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  214. }
  215. }
  216. template <size_t pack_size, typename Parameter, typename Param>
  217. void make_canonized_filter_meta_nchwxx(
  218. size_t src_ndim, const TensorLayout& filter, const Param& param,
  219. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  220. /**
  221. * input: N IC/pack_size, H, W, pack_size
  222. *
  223. ** NCHW44-DOT mode
  224. * filter:
  225. * {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
  226. * [dense]
  227. * {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
  228. * FH, FW, pack_size(OC), pack_size(IC)} [group]
  229. *
  230. * NCHW88 and NCHW44 mode
  231. * filter:
  232. * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
  233. * [dense]
  234. * {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
  235. * FH, FW, pack_size(IC), pack_size(OC)} [group]
  236. * {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
  237. *
  238. *
  239. */
  240. megdnn_assert(
  241. param.format == Param::Format::NCHW88 ||
  242. param.format == Param::Format::NCHW44 ||
  243. param.format == Param::Format::NCHW44_DOT);
  244. size_t img_ndim = 2;
  245. size_t flt_start = 0;
  246. size_t flt_spatial_start = 2;
  247. size_t pack_c_size = 0;
  248. if (param.sparse == Param::Sparse::DENSE) {
  249. if (filter.ndim == img_ndim + 4) {
  250. // oihw8i8o case
  251. megdnn_assert(
  252. (filter[filter.ndim - 2] == pack_size &&
  253. filter[filter.ndim - 1] == pack_size) ||
  254. (filter[filter.ndim - 2] == 2 * pack_size &&
  255. filter[filter.ndim - 1] == 2 * pack_size),
  256. "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
  257. filter[filter.ndim - 2], filter[filter.ndim - 1]);
  258. ret.group = 1;
  259. flt_start = 0;
  260. if (filter[filter.ndim - 2] == 2 * pack_size &&
  261. filter[filter.ndim - 1] == 2 * pack_size) {
  262. pack_c_size = 2 * pack_size;
  263. } else {
  264. pack_c_size = pack_size;
  265. }
  266. ret.ocpg = filter[flt_start] * pack_c_size;
  267. ret.icpg = filter[flt_start + 1] * pack_c_size;
  268. } else if (filter.ndim == img_ndim + 3) {
  269. // ohwi8o
  270. flt_start = 0;
  271. flt_spatial_start = 1;
  272. ret.group = 1;
  273. ret.ocpg = filter[flt_start] * pack_size;
  274. ret.icpg = filter[flt_start + 3];
  275. } else {
  276. megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
  277. }
  278. } else {
  279. megdnn_assert(
  280. param.sparse == Param::Sparse::GROUP,
  281. "invalid convolution sparse type");
  282. flt_start = 1;
  283. auto filter_oc = filter[flt_start];
  284. auto filter_ic = filter[flt_start + 1];
  285. if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
  286. // Depthwise case goihw8g
  287. megdnn_assert(
  288. filter.ndim == img_ndim + 4,
  289. "bad filter ndim for group convolution: "
  290. "spatial_ndim=%zu filter_ndim=%zu",
  291. img_ndim, filter.ndim);
  292. megdnn_assert(
  293. filter[filter.ndim - 1] == pack_size,
  294. "last dim of filter must be %zu, but %zu", pack_size,
  295. filter[filter.ndim - 1]);
  296. ret.group = filter[0] * pack_size;
  297. ret.ocpg = filter_oc;
  298. ret.icpg = filter_ic;
  299. } else {
  300. // norm group case goihw8i8o
  301. megdnn_assert(
  302. filter.ndim == img_ndim + 5,
  303. "bad filter ndim for group convolution: "
  304. "spatial_ndim=%zu filter_ndim=%zu",
  305. img_ndim, filter.ndim);
  306. megdnn_assert(
  307. (filter[filter.ndim - 1] == pack_size &&
  308. filter[filter.ndim - 2] == pack_size) ||
  309. (filter[filter.ndim - 1] == 2 * pack_size &&
  310. filter[filter.ndim - 2] == 2 * pack_size),
  311. "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
  312. filter[filter.ndim - 2], filter[filter.ndim - 1]);
  313. ret.group = filter[0];
  314. if (filter[filter.ndim - 2] == 2 * pack_size &&
  315. filter[filter.ndim - 1] == 2 * pack_size) {
  316. ret.ocpg = filter_oc * 2 * pack_size;
  317. ret.icpg = filter_ic * 2 * pack_size;
  318. } else {
  319. ret.ocpg = filter_oc * pack_size;
  320. ret.icpg = filter_ic * pack_size;
  321. }
  322. }
  323. }
  324. ret.spatial_ndim = 2;
  325. megdnn_assert(
  326. ret.spatial_ndim == 2,
  327. "only 2D convolution is supported, and input should be 5-dim "
  328. "for nchwxx; "
  329. "got input dim = %zu",
  330. src_ndim);
  331. auto dilation = ret.dilation;
  332. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  333. megdnn_assert(
  334. dilation[i] == 1,
  335. "NCHWXX has invalid dilation on spatial dim %zu: %u, "
  336. "require to be 1",
  337. i, dilation[i]);
  338. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  339. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  340. }
  341. }
  342. template <size_t pack_size, typename Parameter, typename Param>
  343. void make_canonized_filter_meta_nchwx(
  344. size_t src_ndim, const TensorLayout& filter, const Param& param,
  345. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  346. /**
  347. * input: N IC/pack_size, H, W, pack_size
  348. * filter:
  349. * OC, IC/pack_size, FH, FW, pack_size [dense]
  350. * GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
  351. */
  352. megdnn_assert(
  353. param.format == Param::Format::NCHW4 ||
  354. param.format == Param::Format::NCHW8 ||
  355. param.format == Param::Format::NCHW32 ||
  356. param.format == Param::Format::NCHW4_NCHW ||
  357. param.format == Param::Format::NCHW4_NHWC ||
  358. param.format == Param::Format::NCHW4_NCHW32 ||
  359. param.format == Param::Format::NCHW32_NCHW4 ||
  360. param.format == Param::Format::NCHW64);
  361. auto img_ndim = src_ndim - 3;
  362. size_t flt_start = 0, flt_spatial_start = 2;
  363. if (param.sparse == Param::Sparse::DENSE) {
  364. megdnn_assert(
  365. filter.ndim == img_ndim + 3,
  366. "bad filter ndim for dense convolution: "
  367. "spatial_ndim=%zu filter_ndim=%zu",
  368. img_ndim, filter.ndim);
  369. // oc, ic, dims[]
  370. ret.group = 1;
  371. flt_start = 0;
  372. } else {
  373. megdnn_assert(
  374. param.sparse == Param::Sparse::GROUP,
  375. "invalid convolution sparse type");
  376. megdnn_assert(
  377. filter.ndim == img_ndim + 4,
  378. "bad filter ndim for group convolution: "
  379. "spatial_ndim=%zu filter_ndim=%zu",
  380. img_ndim, filter.ndim);
  381. ret.group = filter[0];
  382. flt_start = 1;
  383. }
  384. ret.spatial_ndim = src_ndim - 3;
  385. megdnn_assert(
  386. ret.spatial_ndim == 2,
  387. "only 2D convolution is supported, and input should be 5-dim "
  388. "for nchw4; "
  389. "got input dim = %zu",
  390. src_ndim);
  391. ret.ocpg = filter[flt_start];
  392. ret.icpg = filter[flt_start + 1] * pack_size;
  393. auto dilation = ret.dilation;
  394. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  395. megdnn_assert(
  396. dilation[i] == 1,
  397. "NCHW4 has invalid dilation on spatial dim %zu: %u, "
  398. "require to be 1",
  399. i, dilation[i]);
  400. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  401. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  402. }
  403. }
  404. template <size_t pack_size, typename Parameter, typename Param>
  405. void make_canonized_filter_meta_chwnx(
  406. size_t src_ndim, const TensorLayout& filter, const Param& param,
  407. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  408. /**
  409. * input: IC / pack_size, H, W, N, pack_size
  410. * Filter:
  411. * IC / pack_size, FH, FW, OC, pack_size [dense]
  412. * GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
  413. * not implemented [chanwise]
  414. */
  415. megdnn_assert(param.format == Param::Format::CHWN4);
  416. auto img_ndim = src_ndim - 3;
  417. size_t flt_start = 0, flt_spatial_start = 1;
  418. if (param.sparse == Param::Sparse::DENSE) {
  419. megdnn_assert(
  420. filter.ndim == img_ndim + 3,
  421. "bad filter ndim for dense convolution: "
  422. "spatial_ndim=%zu filter_ndim=%zu",
  423. img_ndim, filter.ndim);
  424. // oc, ic, dims[]
  425. ret.group = 1;
  426. flt_start = 0;
  427. } else {
  428. megdnn_assert(
  429. param.sparse == Param::Sparse::GROUP,
  430. "invalid convolution sparse type");
  431. megdnn_assert(
  432. filter.ndim == img_ndim + 4,
  433. "bad filter ndim for group convolution: "
  434. "spatial_ndim=%zu filter_ndim=%zu",
  435. img_ndim, filter.ndim);
  436. ret.group = filter[0];
  437. flt_start = 1;
  438. }
  439. ret.spatial_ndim = src_ndim - 3;
  440. megdnn_assert(
  441. ret.spatial_ndim == 2,
  442. "only 2D convolution is supported, and input should be 4-dim; "
  443. "got input dim = %zu",
  444. src_ndim);
  445. ret.icpg = filter[flt_start] * pack_size;
  446. ret.ocpg = filter[flt_start + 3];
  447. auto dilation = ret.dilation;
  448. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  449. megdnn_assert(
  450. dilation[i] == 1,
  451. "CHWNx has invalid dilation on spatial dim %zu: %u, "
  452. "require to be 1",
  453. i, dilation[i]);
  454. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  455. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  456. }
  457. }
  458. } // namespace
  459. namespace megdnn {
  460. template <typename Parameter>
  461. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  462. make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
  463. megdnn_assert_contiguous(filter);
  464. CanonizedFilterMeta ret;
  465. ret.dtype = filter.dtype;
  466. ret.format = param().format;
  467. if (param().mode == Mode::CONVOLUTION) {
  468. ret.should_flip = true;
  469. } else {
  470. megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
  471. ret.should_flip = false;
  472. }
  473. ret.stride[0] = param().stride_h;
  474. ret.stride[1] = param().stride_w;
  475. ret.padding[0] = param().pad_h;
  476. ret.padding[1] = param().pad_w;
  477. ret.dilation[0] = param().dilate_h;
  478. ret.dilation[1] = param().dilate_w;
  479. if (param().format == Param::Format::NHWCD4) {
  480. if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  481. filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  482. make_canonized_filter_meta_nhwcd4_dot<Parameter>(
  483. src_ndim, filter, param(), ret);
  484. } else {
  485. make_canonized_filter_meta_nhwcd4<Parameter>(
  486. src_ndim, filter, param(), ret);
  487. }
  488. } else if (
  489. param().format == Param::Format::NCHW4 ||
  490. param().format == Param::Format::NCHW4_NCHW ||
  491. param().format == Param::Format::NCHW4_NHWC ||
  492. param().format == Param::Format::NCHW4_NCHW32) {
  493. make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
  494. } else if (param().format == Param::Format::NCHW8) {
  495. make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
  496. } else if (param().format == Param::Format::NCHW88) {
  497. make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
  498. } else if (
  499. param().format == Param::Format::NCHW44 ||
  500. param().format == Param::Format::NCHW44_DOT) {
  501. make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
  502. } else if (
  503. param().format == Param::Format::NCHW32 ||
  504. param().format == Param::Format::NCHW32_NCHW4) {
  505. make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
  506. } else if (param().format == Param::Format::CHWN4) {
  507. make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
  508. } else if (param().format == Param::Format::NCHW64) {
  509. make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
  510. } else {
  511. megdnn_assert(
  512. param().format == Param::Format::NHWC ||
  513. param().format == Param::Format::NCHW);
  514. make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
  515. }
  516. return ret;
  517. }
  518. template <typename Parameter>
  519. void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
  520. DType src, DType filter, DType& dst) const {
  521. // The first one will be the default choice.
  522. SmallVector<DType> supported_dst_dtype;
  523. // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
  524. if (src.category() == DTypeCategory::FLOAT) {
  525. supported_dst_dtype.push_back(src);
  526. } else if (src.enumv() == DTypeEnum::Int8) {
  527. supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
  528. } else if (
  529. src.enumv() == DTypeEnum::QuantizedS8 ||
  530. src.enumv() == DTypeEnum::Quantized8Asymm ||
  531. src.enumv() == DTypeEnum::QuantizedS4 ||
  532. src.enumv() == DTypeEnum::Quantized4Asymm ||
  533. src.enumv() == DTypeEnum::QuantizedS1) {
  534. supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
  535. bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
  536. ((dst.enumv() == DTypeEnum::QuantizedS4 ||
  537. dst.enumv() == DTypeEnum::Quantized4Asymm) &&
  538. src.enumv() == DTypeEnum::QuantizedS8) ||
  539. ((src.enumv() == DTypeEnum::QuantizedS4 ||
  540. src.enumv() == DTypeEnum::Quantized4Asymm) &&
  541. dst.enumv() == DTypeEnum::QuantizedS8));
  542. if (cond_dst) {
  543. supported_dst_dtype.push_back(dst);
  544. }
  545. if (src.enumv() == DTypeEnum::QuantizedS8) {
  546. supported_dst_dtype.push_back(dtype::Float32());
  547. }
  548. } else if (src.enumv() == DTypeEnum::QuantizedS32) {
  549. //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
  550. megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
  551. supported_dst_dtype.push_back(dtype::QuantizedS8(
  552. src.param<dtype::QuantizedS32>().scale /
  553. filter.param<dtype::QuantizedS8>().scale));
  554. } else {
  555. megdnn_throw(ssprintf(
  556. "unsupported input / filter DType: %s x %s", src.name(),
  557. filter.name()));
  558. }
  559. if (!dst.valid()) {
  560. dst = supported_dst_dtype.at(0);
  561. } else {
  562. bool dst_supported = false;
  563. for (auto&& dt : supported_dst_dtype) {
  564. if (dtype_almost_equal(dt, dst)) {
  565. dst_supported = true;
  566. break;
  567. }
  568. }
  569. MEGDNN_MARK_USED_VAR(dst_supported);
  570. megdnn_assert(
  571. dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
  572. filter.name(), dst.name());
  573. }
  574. megdnn_assert(
  575. (param().compute_mode == Param::ComputeMode::FLOAT32 ||
  576. param().compute_mode == Param::ComputeMode::DEFAULT)
  577. #if !MEGDNN_DISABLE_FLOAT16
  578. || src.enumv() == DTypeEnum::Float16 ||
  579. src.enumv() == DTypeEnum::BFloat16
  580. #endif
  581. ,
  582. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  583. "input / output.");
  584. }
  585. template <typename Parameter>
  586. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  587. deduce_layout_fwd(
  588. const TensorLayout& src, const TensorLayout& filter,
  589. TensorLayout& dst) const {
  590. auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
  591. MEGDNN_MARK_USED_VAR(errmsg);
  592. megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
  593. megdnn_assert(
  594. ((src.dtype.enumv() == filter.dtype.enumv()) ||
  595. (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
  596. filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
  597. "%s", errmsg().c_str());
  598. check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
  599. size_t img_dim;
  600. if (param().format == Param::Format::NCHW ||
  601. param().format == Param::Format::NHWC) {
  602. img_dim = src.ndim - 2;
  603. megdnn_assert(
  604. filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
  605. errmsg().c_str());
  606. } else {
  607. megdnn_assert(
  608. param().format == Param::Format::NHWCD4 ||
  609. param().format == Param::Format::NCHW4 ||
  610. param().format == Param::Format::NCHW4_NCHW ||
  611. param().format == Param::Format::NCHW4_NHWC ||
  612. param().format == Param::Format::NCHW4_NCHW32 ||
  613. param().format == Param::Format::NCHW44 ||
  614. param().format == Param::Format::NCHW44_DOT ||
  615. param().format == Param::Format::NCHW8 ||
  616. param().format == Param::Format::NCHW32 ||
  617. param().format == Param::Format::NCHW32_NCHW4 ||
  618. param().format == Param::Format::NCHW88 ||
  619. param().format == Param::Format::CHWN4 ||
  620. param().format == Param::Format::NCHW64);
  621. img_dim = src.ndim - 3;
  622. if ((param().format == Param::Format::NCHW88 ||
  623. param().format == Param::Format::NCHW44_DOT ||
  624. param().format == Param::Format::NCHW44) &&
  625. filter.ndim == 5) {
  626. img_dim = src.ndim - 2;
  627. }
  628. megdnn_assert(
  629. filter.ndim == img_dim + 3 ||
  630. (filter.ndim == img_dim + 2 &&
  631. (param().format == Param::Format::NCHW88 ||
  632. param().format == Param::Format::NCHW44_DOT ||
  633. param().format == Param::Format::NCHW44)) ||
  634. filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
  635. "%s", errmsg().c_str());
  636. if (param().format == Param::Format::NCHW4 ||
  637. param().format == Param::Format::NCHW4_NCHW ||
  638. param().format == Param::Format::NCHW4_NCHW32) {
  639. megdnn_assert(
  640. src.ndim == 5 &&
  641. (filter.ndim == 5 || filter.ndim == 6 ||
  642. filter.ndim == 7) &&
  643. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  644. "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
  645. "filter's ndim is "
  646. "5 or 6, and "
  647. "last shape "
  648. "is 4 "
  649. "but got src %s, filter %s",
  650. src.to_string().c_str(), filter.to_string().c_str());
  651. }
  652. if (param().format == Param::Format::NCHW8) {
  653. megdnn_assert(
  654. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  655. src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
  656. "NCHW8 require src and filter's ndim is 5 or 6, and last "
  657. "shape is 8 "
  658. "but got src %s, filter %s",
  659. src.to_string().c_str(), filter.to_string().c_str());
  660. }
  661. if (param().format == Param::Format::NCHW32 ||
  662. param().format == Param::Format::NCHW32_NCHW4) {
  663. megdnn_assert(
  664. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  665. src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
  666. "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
  667. "is 5 or 6, and last "
  668. "shape is 32 "
  669. "but got src %s, filter %s",
  670. src.to_string().c_str(), filter.to_string().c_str());
  671. }
  672. if (param().format == Param::Format::NCHW88) {
  673. megdnn_assert(
  674. (src.ndim == 4 && filter.ndim == 5 &&
  675. filter[filter.ndim - 1] == 8) ||
  676. (src.ndim == 5 &&
  677. ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
  678. (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
  679. filter[filter.ndim - 2] == 8)) &&
  680. src[src.ndim - 1] == 8),
  681. "NCHW88 require src ndim is 5 and filter's ndim is 6 "
  682. ", and last shape two is 8 but got src %s, filter %s",
  683. src.to_string().c_str(), filter.to_string().c_str());
  684. }
  685. if (param().format == Param::Format::NCHW44 ||
  686. param().format == Param::Format::NCHW44_DOT) {
  687. //! support nchw44 filter change to 88 for int8 winogradf23_88 using
  688. //! MK8 mamtul
  689. megdnn_assert(
  690. (src.ndim == 4 && filter.ndim == 5 &&
  691. filter[filter.ndim - 1] == 4) ||
  692. (src.ndim == 5 &&
  693. ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
  694. filter[filter.ndim - 1] == 8)) ||
  695. (filter.ndim == 7 &&
  696. (filter[filter.ndim - 1] == 4 ||
  697. filter[filter.ndim - 1] == 8) &&
  698. (filter[filter.ndim - 2] == 4 ||
  699. filter[filter.ndim - 2] == 8))) &&
  700. src[src.ndim - 1] == 4),
  701. "NCHW44 require src ndim is 5 and filter's ndim is 6 "
  702. ", and last shape two is 4 but got src %s, filter %s",
  703. src.to_string().c_str(), filter.to_string().c_str());
  704. }
  705. if (param().format == Param::Format::CHWN4) {
  706. megdnn_assert(
  707. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  708. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  709. "CHWN4 require src and filter's ndim is 5 or 6, and last "
  710. "shape is 4 "
  711. "but got src %s, filter %s",
  712. src.to_string().c_str(), filter.to_string().c_str());
  713. }
  714. if (param().format == Param::Format::NCHW64) {
  715. megdnn_assert(
  716. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  717. src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
  718. "NCHW64 require src and filter's ndim is 5 or 6, and "
  719. "last shape is 64 but got src %s, filter %s",
  720. src.to_string().c_str(), filter.to_string().c_str());
  721. }
  722. }
  723. megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
  724. auto cflt = make_canonized_filter_meta(src.ndim, filter);
  725. if (param().format == Param::Format::NCHW ||
  726. param().format == Param::Format::NHWC) {
  727. size_t src_or_dst_c_pos = 0;
  728. size_t src_or_dst_spatial_start = 0;
  729. if (param().format == Param::Format::NCHW) {
  730. src_or_dst_c_pos = 1;
  731. src_or_dst_spatial_start = 2;
  732. } else {
  733. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  734. src_or_dst_c_pos = 3;
  735. src_or_dst_spatial_start = 1;
  736. }
  737. megdnn_assert(
  738. cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
  739. errmsg().c_str());
  740. dst.ndim = src.ndim;
  741. dst[0] = src[0];
  742. dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
  743. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  744. dst[i + src_or_dst_spatial_start] = infer_conv_shape(
  745. src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  746. cflt.stride[i], cflt.padding[i]);
  747. }
  748. } else if (param().format == Param::Format::NCHW4) {
  749. megdnn_assert(
  750. src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
  751. src.ndim);
  752. megdnn_assert(
  753. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  754. errmsg().c_str(), cflt.icpg, cflt.group);
  755. dst.ndim = src.ndim;
  756. dst[0] = src[0];
  757. auto oc = cflt.ocpg * cflt.group;
  758. megdnn_assert(oc % 4 == 0);
  759. dst[1] = oc / 4;
  760. dst[2] = infer_conv_shape(
  761. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  762. dst[3] = infer_conv_shape(
  763. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  764. dst[4] = 4;
  765. } else if (param().format == Param::Format::NCHW8) {
  766. megdnn_assert(
  767. src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
  768. src.ndim);
  769. megdnn_assert(
  770. cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
  771. errmsg().c_str(), cflt.icpg, cflt.group);
  772. dst.ndim = src.ndim;
  773. dst[0] = src[0];
  774. auto oc = cflt.ocpg * cflt.group;
  775. megdnn_assert(oc % 8 == 0);
  776. dst[1] = oc / 8;
  777. dst[2] = infer_conv_shape(
  778. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  779. dst[3] = infer_conv_shape(
  780. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  781. dst[4] = 8;
  782. } else if (param().format == Param::Format::NCHW32) {
  783. megdnn_assert(
  784. src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
  785. src.ndim);
  786. megdnn_assert(
  787. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  788. errmsg().c_str(), cflt.icpg, cflt.group);
  789. dst.ndim = src.ndim;
  790. dst[0] = src[0];
  791. auto oc = cflt.ocpg * cflt.group;
  792. megdnn_assert(oc % 32 == 0);
  793. dst[1] = oc / 32;
  794. dst[2] = infer_conv_shape(
  795. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  796. dst[3] = infer_conv_shape(
  797. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  798. dst[4] = 32;
  799. } else if (param().format == Param::Format::NCHW88) {
  800. megdnn_assert(
  801. src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
  802. "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
  803. dst.ndim = 5;
  804. dst[0] = src[0];
  805. auto oc = cflt.ocpg * cflt.group;
  806. megdnn_assert(oc % 8 == 0);
  807. dst[1] = oc / 8;
  808. dst[2] = infer_conv_shape(
  809. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  810. dst[3] = infer_conv_shape(
  811. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  812. dst[4] = 8;
  813. if (cflt.group == 1) {
  814. megdnn_assert(
  815. cflt.icpg * cflt.group == src[1] * 8 ||
  816. (cflt.icpg * cflt.group == src[1]),
  817. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  818. }
  819. } else if (
  820. param().format == Param::Format::NCHW44 ||
  821. param().format == Param::Format::NCHW44_DOT) {
  822. megdnn_assert(
  823. src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
  824. "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
  825. dst.ndim = 5;
  826. dst[0] = src[0];
  827. auto oc = cflt.ocpg * cflt.group;
  828. megdnn_assert(oc % 4 == 0);
  829. dst[1] = oc / 4;
  830. dst[2] = infer_conv_shape(
  831. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  832. dst[3] = infer_conv_shape(
  833. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  834. dst[4] = 4;
  835. if (cflt.group == 1) {
  836. megdnn_assert(
  837. cflt.icpg * cflt.group == src[1] * 4 ||
  838. (cflt.icpg * cflt.group == src[1]),
  839. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  840. }
  841. } else if (param().format == Param::Format::CHWN4) {
  842. megdnn_assert(
  843. src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
  844. src.ndim);
  845. megdnn_assert(
  846. cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
  847. errmsg().c_str(), cflt.icpg, cflt.group);
  848. dst.ndim = src.ndim;
  849. dst[3] = src[3];
  850. auto oc = cflt.ocpg * cflt.group;
  851. megdnn_assert(oc % 4 == 0);
  852. dst[0] = oc / 4;
  853. dst[1] = infer_conv_shape(
  854. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  855. dst[2] = infer_conv_shape(
  856. src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  857. dst[4] = 4;
  858. } else if (param().format == Param::Format::NCHW4_NCHW) {
  859. megdnn_assert(
  860. src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
  861. src.ndim);
  862. megdnn_assert(
  863. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  864. errmsg().c_str(), cflt.icpg, cflt.group);
  865. dst.ndim = 4;
  866. dst[0] = src[0];
  867. auto oc = cflt.ocpg * cflt.group;
  868. dst[1] = oc;
  869. dst[2] = infer_conv_shape(
  870. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  871. dst[3] = infer_conv_shape(
  872. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  873. } else if (param().format == Param::Format::NCHW4_NHWC) {
  874. megdnn_assert(
  875. src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
  876. src.ndim);
  877. megdnn_assert(
  878. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  879. errmsg().c_str(), cflt.icpg, cflt.group);
  880. dst.ndim = 4;
  881. dst[0] = src[0];
  882. dst[1] = infer_conv_shape(
  883. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  884. dst[2] = infer_conv_shape(
  885. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  886. auto oc = cflt.ocpg * cflt.group;
  887. dst[3] = oc;
  888. } else if (param().format == Param::Format::NCHW4_NCHW32) {
  889. megdnn_assert(
  890. src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
  891. src.ndim);
  892. megdnn_assert(
  893. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  894. errmsg().c_str(), cflt.icpg, cflt.group);
  895. dst.ndim = src.ndim;
  896. dst[0] = src[0];
  897. auto oc = cflt.ocpg * cflt.group;
  898. megdnn_assert(oc % 32 == 0);
  899. dst[1] = oc / 32;
  900. dst[2] = infer_conv_shape(
  901. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  902. dst[3] = infer_conv_shape(
  903. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  904. dst[4] = 32;
  905. } else if (param().format == Param::Format::NCHW32_NCHW4) {
  906. megdnn_assert(
  907. src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
  908. src.ndim);
  909. megdnn_assert(
  910. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  911. errmsg().c_str(), cflt.icpg, cflt.group);
  912. dst.ndim = src.ndim;
  913. dst[0] = src[0];
  914. auto oc = cflt.ocpg * cflt.group;
  915. megdnn_assert(oc % 4 == 0);
  916. dst[1] = oc / 4;
  917. dst[2] = infer_conv_shape(
  918. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  919. dst[3] = infer_conv_shape(
  920. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  921. dst[4] = 4;
  922. } else if (param().format == Param::Format::NCHW64) {
  923. megdnn_assert(
  924. src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
  925. src.ndim);
  926. megdnn_assert(
  927. cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
  928. errmsg().c_str(), cflt.icpg, cflt.group);
  929. dst.ndim = src.ndim;
  930. dst[0] = src[0];
  931. auto oc = cflt.ocpg * cflt.group;
  932. megdnn_assert(oc % 64 == 0);
  933. dst[1] = oc / 64;
  934. dst[2] = infer_conv_shape(
  935. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  936. dst[3] = infer_conv_shape(
  937. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  938. dst[4] = 64;
  939. } else {
  940. megdnn_assert(param().format == Param::Format::NHWCD4);
  941. megdnn_assert(
  942. src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
  943. src.ndim);
  944. megdnn_assert(
  945. cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
  946. errmsg().c_str(), cflt.icpg, cflt.group);
  947. dst.ndim = src.ndim;
  948. dst[0] = src[0];
  949. auto oc = cflt.ocpg * cflt.group;
  950. megdnn_assert(oc % 4 == 0);
  951. dst[2] = oc / 4;
  952. dst[1] = infer_conv_shape(
  953. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  954. dst[3] = infer_conv_shape(
  955. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  956. megdnn_assert(src[4] == 4);
  957. dst[4] = 4;
  958. }
  959. if (!src.format.is_default() && !src.format.is_lowbit_aligned()) { // propagate
  960. dst.format = src.format;
  961. } else { // determined by dtype
  962. dst.format = TensorFormat(dst.dtype);
  963. }
  964. dst.init_contiguous_stride();
  965. return cflt;
  966. }
  967. /**
  968. * \warning: An explicit specialization shall be declared in a namespace
  969. * enclosing the specialized template. An explicit specialization whose
  970. * declarator-id is not qualified shall be declared in the nearest enclosing
  971. * namespace of the template, or, if the namespace is inline (7.3.1), any
  972. * namespace from its enclosing namespace set.
  973. * refer to:
  974. * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
  975. */
  976. template <>
  977. ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
  978. param::Convolution>::
  979. check_layout_fwd(
  980. const TensorLayout& src, const TensorLayout& filter,
  981. const TensorLayout& dst) const {
  982. megdnn_assert_contiguous(src);
  983. megdnn_assert_contiguous(filter);
  984. TensorLayout dst_expected;
  985. dst_expected.dtype = dst.dtype;
  986. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  987. megdnn_assert_eq_layout(dst_expected, dst);
  988. return ret;
  989. }
  990. template <>
  991. ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
  992. check_layout_fwd(
  993. const TensorLayout& src, const TensorLayout& filter,
  994. const TensorLayout& dst) const {
  995. megdnn_assert_contiguous(src);
  996. megdnn_assert_contiguous(filter);
  997. TensorLayout dst_expected;
  998. dst_expected.dtype = dst.dtype;
  999. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1000. megdnn_assert_eq_layout(dst_expected, dst);
  1001. return ret;
  1002. }
  1003. template <>
  1004. ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
  1005. param::BatchConvBias>::
  1006. check_layout_fwd(
  1007. const TensorLayout& src, const TensorLayout& filter,
  1008. const TensorLayout& dst) const {
  1009. megdnn_assert_contiguous(src);
  1010. megdnn_assert_contiguous(filter);
  1011. TensorLayout dst_expected;
  1012. dst_expected.dtype = dst.dtype;
  1013. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1014. megdnn_assert_eq_layout(dst_expected, dst);
  1015. return ret;
  1016. }
  1017. void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
  1018. check_or_deduce_dtype_fwd(src, filter, dst);
  1019. }
  1020. void ConvolutionForward::deduce_layout(
  1021. const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
  1022. deduce_layout_fwd(src, filter, dst);
  1023. }
  1024. ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
  1025. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  1026. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
  1027. auto ret = check_layout_fwd(src, filter, dst);
  1028. auto required_workspace_in_bytes =
  1029. get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
  1030. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1031. return ret;
  1032. }
  1033. ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
  1034. const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
  1035. size_t workspace_in_bytes) {
  1036. auto grad_fwd = grad;
  1037. auto filter_fwd = filter;
  1038. auto diff_fwd = diff;
  1039. std::swap(grad_fwd.dtype, diff_fwd.dtype);
  1040. grad_fwd.init_contiguous_stride();
  1041. diff_fwd.init_contiguous_stride();
  1042. auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
  1043. auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
  1044. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1045. return ret;
  1046. }
  1047. void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
  1048. SmallVector<DType> supported_dst_dtype;
  1049. if (filter.category() == diff.category() &&
  1050. filter.category() == DTypeCategory::FLOAT) {
  1051. supported_dst_dtype.push_back(filter);
  1052. } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
  1053. supported_dst_dtype.push_back(dtype::Int32());
  1054. } else if (
  1055. (filter.enumv() == DTypeEnum::QuantizedS8 &&
  1056. diff.enumv() == DTypeEnum::QuantizedS8) ||
  1057. (filter.enumv() == DTypeEnum::Quantized8Asymm &&
  1058. diff.enumv() == DTypeEnum::Quantized8Asymm)) {
  1059. supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
  1060. if (grad.valid() && grad.enumv() == diff.enumv()) {
  1061. supported_dst_dtype.push_back(grad);
  1062. }
  1063. } else {
  1064. megdnn_throw(ssprintf(
  1065. "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
  1066. }
  1067. if (!grad.valid()) {
  1068. grad = supported_dst_dtype.at(0);
  1069. } else {
  1070. megdnn_assert(
  1071. vec_contains(supported_dst_dtype, grad),
  1072. "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
  1073. grad.name());
  1074. }
  1075. megdnn_assert(
  1076. param().compute_mode != Param::ComputeMode::FLOAT32
  1077. #if !MEGDNN_DISABLE_FLOAT16
  1078. || filter.enumv() == DTypeEnum::Float16 ||
  1079. filter.enumv() == DTypeEnum::BFloat16
  1080. #endif
  1081. ,
  1082. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  1083. "input / output.");
  1084. }
  1085. void ConvolutionBackwardData::deduce_layout(
  1086. const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
  1087. auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
  1088. MEGDNN_MARK_USED_VAR(errmsg);
  1089. megdnn_assert_contiguous(filter);
  1090. megdnn_assert_contiguous(diff);
  1091. megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
  1092. megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
  1093. deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
  1094. auto cflt = make_canonized_filter_meta(diff.ndim, filter);
  1095. auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
  1096. MEGDNN_MARK_USED_VAR(errmsg);
  1097. auto i = (out - 1) * stride + filter;
  1098. megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
  1099. return i - pad * 2;
  1100. };
  1101. if (param().format == Param::Format::NCHW ||
  1102. param().format == Param::Format::NHWC) {
  1103. size_t src_or_dst_c_pos = 0;
  1104. size_t src_or_dst_spatial_start = 0;
  1105. if (param().format == Param::Format::NCHW) {
  1106. src_or_dst_c_pos = 1;
  1107. src_or_dst_spatial_start = 2;
  1108. } else {
  1109. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  1110. src_or_dst_c_pos = 3;
  1111. src_or_dst_spatial_start = 1;
  1112. }
  1113. megdnn_assert(
  1114. cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
  1115. errmsg().c_str());
  1116. grad.ndim = diff.ndim;
  1117. grad[0] = diff[0];
  1118. grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
  1119. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  1120. grad[i + src_or_dst_spatial_start] =
  1121. deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  1122. cflt.stride[i], cflt.padding[i]);
  1123. }
  1124. } else if (param().format == Param::Format::NCHW4) {
  1125. megdnn_assert(
  1126. diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
  1127. diff.ndim);
  1128. megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
  1129. megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
  1130. grad.ndim = diff.ndim;
  1131. grad[0] = diff[0];
  1132. auto ic = cflt.icpg * cflt.group;
  1133. megdnn_assert(ic % 4 == 0);
  1134. grad[1] = ic / 4;
  1135. grad[2] = deduce(
  1136. diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1137. grad[3] = deduce(
  1138. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1139. megdnn_assert(diff[4] == 4);
  1140. grad[4] = 4;
  1141. } else {
  1142. megdnn_assert(param().format == Param::Format::NHWCD4);
  1143. megdnn_assert(
  1144. diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
  1145. diff.ndim);
  1146. megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
  1147. grad.ndim = diff.ndim;
  1148. grad[0] = diff[0];
  1149. auto ic = cflt.icpg * cflt.group;
  1150. megdnn_assert(ic % 4 == 0);
  1151. grad[2] = ic / 4;
  1152. grad[1] = deduce(
  1153. diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1154. grad[3] = deduce(
  1155. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1156. megdnn_assert(diff[4] == 4);
  1157. grad[4] = 4;
  1158. }
  1159. grad.format = diff.format;
  1160. grad.init_contiguous_stride();
  1161. }
  1162. ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
  1163. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
  1164. size_t workspace_in_bytes) {
  1165. megdnn_assert(
  1166. src.dtype.category() == DTypeCategory::FLOAT &&
  1167. diff.dtype.category() == DTypeCategory::FLOAT &&
  1168. grad.dtype.category() == DTypeCategory::FLOAT,
  1169. "only float type is supported for conv backward filter");
  1170. auto src_fwd = src;
  1171. auto diff_fwd = diff;
  1172. src_fwd.init_contiguous_stride();
  1173. diff_fwd.init_contiguous_stride();
  1174. auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
  1175. auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
  1176. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1177. return ret;
  1178. }
  1179. } // namespace megdnn
  1180. // vim: syntax=cpp.doxygen