You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183
  1. /**
  2. * \file dnn/src/common/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "src/common/utils.h"
  14. using namespace megdnn;
  15. namespace {
  16. template <typename Param>
  17. std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter,
  18. const TensorLayout& dst, const Param& param) {
  19. MEGDNN_MARK_USED_VAR(src);
  20. MEGDNN_MARK_USED_VAR(filter);
  21. MEGDNN_MARK_USED_VAR(dst);
  22. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  23. megdnn_layout_msg(dst) + ", " + megdnn_mangle("is_nchw=") +
  24. std::to_string(param.format == param::Convolution::Format::NCHW) +
  25. ", " + +megdnn_mangle("is_xcorr=") +
  26. std::to_string(
  27. (param.mode == Convolution::Mode::CROSS_CORRELATION)) +
  28. ", " + megdnn_mangle("pad_h=") + std::to_string(param.pad_h) + ", " +
  29. megdnn_mangle("pad_w=") + std::to_string(param.pad_w) + ", " +
  30. megdnn_mangle("stride_h=") + std::to_string(param.stride_h) + ", " +
  31. megdnn_mangle("stride_w=") + std::to_string(param.stride_w) + ", " +
  32. megdnn_mangle("dilate_h=") + std::to_string(param.dilate_h) + ", " +
  33. megdnn_mangle("dilate_w=") + std::to_string(param.dilate_w);
  34. }
  35. template <typename Param, typename Param::Format>
  36. uint32_t spatial_getter(uint32_t filter, const Param&) {
  37. return filter;
  38. }
  39. template <>
  40. uint32_t
  41. spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW_WINOGRAD>(
  42. uint32_t filter, const param::ConvBias& param) {
  43. //! f = m + r - 1 -> r = f + 1 - m
  44. return filter - param.output_block_size + 1;
  45. }
  46. template <>
  47. uint32_t
  48. spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>(
  49. uint32_t filter, const param::ConvBias& param) {
  50. //! f = m + r - 1 -> r = f + 1 - m
  51. return filter - param.output_block_size + 1;
  52. }
  53. template <>
  54. uint32_t
  55. spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW44_WINOGRAD>(
  56. uint32_t filter, const param::ConvBias& param) {
  57. //! f = m + r - 1 -> r = f + 1 - m
  58. return filter - param.output_block_size + 1;
  59. }
  60. template <typename Parameter, typename Param>
  61. void make_canonized_filter_meta_nchw_nhwc(
  62. size_t src_ndim, const TensorLayout& filter, const Param& param,
  63. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  64. megdnn_assert(param.format == Param::Format::NCHW ||
  65. param.format == Param::Format::NHWC ||
  66. param.format == Param::Format::NCHW_WINOGRAD);
  67. auto img_ndim = src_ndim - 2;
  68. size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
  69. if (param.sparse == Param::Sparse::DENSE) {
  70. megdnn_assert(
  71. filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
  72. "bad filter ndim for dense convolution: "
  73. "spatial_ndim=%zu filter_ndim=%zu",
  74. img_ndim, filter.ndim);
  75. ret.group = 1;
  76. flt_start = 0;
  77. } else {
  78. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  79. "invalid convolution sparse type");
  80. megdnn_assert(
  81. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
  82. "bad filter ndim for group convolution: "
  83. "spatial_ndim=%zu filter_ndim=%zu",
  84. img_ndim, filter.ndim);
  85. // grp, oc, ic, dims[]
  86. ret.group = filter[0];
  87. flt_start = 1;
  88. }
  89. uint32_t ic_block_size = 1, oc_block_size = 1;
  90. if (param.format == Param::Format::NCHW) {
  91. // filter should be (oc, ic, fh, fw)
  92. flt_spatial_start = 2;
  93. ocpg_pos = 0;
  94. icpg_pos = 1;
  95. } else if (param.format == Param::Format::NCHW_WINOGRAD) {
  96. // filter should be (alphah, alphaw, ic, oc) or (alphah, alphaw, ocb,
  97. // icb, ic_block_size, oc_block_size)
  98. flt_spatial_start = 0;
  99. if (filter.ndim == flt_start + 4) {
  100. ocpg_pos = 3;
  101. icpg_pos = 2;
  102. } else {
  103. megdnn_assert(filter.ndim == flt_start + 6);
  104. ic_block_size = filter[flt_start + 4];
  105. oc_block_size = filter[flt_start + 5];
  106. ocpg_pos = 2;
  107. icpg_pos = 3;
  108. }
  109. } else {
  110. megdnn_assert(param.format == Param::Format::NHWC,
  111. "invalid conv tensor format");
  112. // filter should be (oc, fh, fw, ic)
  113. flt_spatial_start = 1;
  114. ocpg_pos = 0;
  115. icpg_pos = 3;
  116. }
  117. ret.spatial_ndim = src_ndim - 2;
  118. megdnn_assert(
  119. ret.spatial_ndim == 2,
  120. "only 2D convolution is supported, and input should be 4-dim; "
  121. "got input dim = %zu",
  122. src_ndim);
  123. ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
  124. ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
  125. auto dilation = ret.dilation;
  126. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  127. megdnn_assert(dilation[i] > 0,
  128. "invalid dilation on spatial dim %zu: %u", i,
  129. dilation[i]);
  130. if (param.format == Param::Format::NCHW_WINOGRAD) {
  131. ret.spatial[i] =
  132. spatial_getter<Param, Param::Format::NCHW_WINOGRAD>(
  133. filter[i + flt_start + flt_spatial_start], param);
  134. } else {
  135. ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
  136. filter[i + flt_start + flt_spatial_start], param);
  137. }
  138. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  139. }
  140. }
  141. template <typename Parameter, typename Param>
  142. void make_canonized_filter_meta_nhwcd4(
  143. size_t src_ndim, const TensorLayout& filter, const Param& param,
  144. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  145. /**
  146. * input: N H IC/4 W 4
  147. * Filter:
  148. * OC/4, FH, FW, IC, 4 [dense]
  149. * GROUP, OC/4, FH, FW, IC, 4 [group]
  150. * GROUP/4, 1, FH, FW, 4 [chanwise]
  151. */
  152. megdnn_assert(param.format == Param::Format::NHWCD4);
  153. auto img_ndim = src_ndim - 3;
  154. size_t flt_start = 0, flt_spatial_start = 1;
  155. bool is_chanwise = false;
  156. if (param.sparse == Param::Sparse::DENSE) {
  157. megdnn_assert(filter.ndim == img_ndim + 3,
  158. "bad filter ndim for dense convolution: "
  159. "spatial_ndim=%zu filter_ndim=%zu",
  160. img_ndim, filter.ndim);
  161. // oc, ic, dims[]
  162. ret.group = 1;
  163. flt_start = 0;
  164. } else {
  165. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  166. "invalid convolution sparse type");
  167. megdnn_assert(
  168. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
  169. "bad filter ndim for group convolution: "
  170. "spatial_ndim=%zu filter_ndim=%zu",
  171. img_ndim, filter.ndim);
  172. if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
  173. is_chanwise = true;
  174. ret.group = filter[0] * 4;
  175. } else {
  176. ret.group = filter[0];
  177. }
  178. flt_start = 1;
  179. }
  180. ret.spatial_ndim = src_ndim - 3;
  181. megdnn_assert(
  182. ret.spatial_ndim == 2,
  183. "only 2D convolution is supported, and input should be 4-dim; "
  184. "got input dim = %zu",
  185. src_ndim);
  186. if (is_chanwise) {
  187. ret.ocpg = 1;
  188. ret.icpg = 1;
  189. } else {
  190. ret.ocpg = filter[flt_start] * 4;
  191. ret.icpg = filter[flt_start + 3];
  192. }
  193. auto dilation = ret.dilation;
  194. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  195. megdnn_assert(dilation[i] > 0,
  196. "invalid dilation on spatial dim %zu: %u", i,
  197. dilation[i]);
  198. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  199. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  200. }
  201. }
  202. template <typename Parameter, typename Param>
  203. void make_canonized_filter_meta_nhwcd4_dot(
  204. size_t src_ndim, const TensorLayout& filter, const Param& param,
  205. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  206. /**
  207. * input: N H IC/4 W 4
  208. * Filter:
  209. * GROUP/4, 1, FH, FW, 4 [chanwise]
  210. * OC/4, FH, FW, IC/4, 4, 4 [dense]
  211. * GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
  212. */
  213. megdnn_assert(param.format == Param::Format::NHWCD4);
  214. auto img_ndim = src_ndim - 3;
  215. size_t flt_start = 0, flt_spatial_start = 1;
  216. bool is_chanwise = false;
  217. if (param.sparse == Param::Sparse::DENSE) {
  218. megdnn_assert(filter.ndim == img_ndim + 4,
  219. "bad filter ndim for dense convolution: "
  220. "spatial_ndim=%zu filter_ndim=%zu",
  221. img_ndim, filter.ndim);
  222. // oc, ic, dims[]
  223. ret.group = 1;
  224. flt_start = 0;
  225. } else {
  226. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  227. "invalid convolution sparse type");
  228. megdnn_assert(
  229. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
  230. "bad filter ndim for group convolution: "
  231. "spatial_ndim=%zu filter_ndim=%zu",
  232. img_ndim, filter.ndim);
  233. if (filter.ndim == img_ndim + 3) {
  234. megdnn_assert(filter[1] == 1);
  235. is_chanwise = true;
  236. ret.group = filter[0] * 4;
  237. } else {
  238. ret.group = filter[0];
  239. }
  240. flt_start = 1;
  241. }
  242. ret.spatial_ndim = src_ndim - 3;
  243. megdnn_assert(
  244. ret.spatial_ndim == 2,
  245. "only 2D convolution is supported, and input should be 4-dim; "
  246. "got input dim = %zu",
  247. src_ndim);
  248. if (is_chanwise) {
  249. ret.ocpg = 1;
  250. ret.icpg = 1;
  251. } else {
  252. ret.ocpg = filter[flt_start] * 4;
  253. ret.icpg = filter[flt_start + 3] * 4;
  254. }
  255. auto dilation = ret.dilation;
  256. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  257. megdnn_assert(dilation[i] > 0,
  258. "invalid dilation on spatial dim %zu: %u", i,
  259. dilation[i]);
  260. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  261. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  262. }
  263. }
  264. template <size_t pack_size, typename Parameter, typename Param>
  265. void make_canonized_filter_meta_nchwxx(
  266. size_t src_ndim, const TensorLayout& filter, const Param& param,
  267. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  268. /**
  269. * input: N IC/pack_size, H, W, pack_size
  270. *
  271. ** NCHW44-DOT mode
  272. * filter:
  273. * {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
  274. * [dense]
  275. * {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
  276. * FH, FW, pack_size(OC), pack_size(IC)} [group]
  277. *
  278. * NCHW88 and NCHW44 mode
  279. * filter:
  280. * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
  281. * [dense]
  282. * {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
  283. * FH, FW, pack_size(IC), pack_size(OC)} [group]
  284. * {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
  285. *
  286. ** NCHW88_WINOGRAD and NCHW44_WINOGRAD mode
  287. * filter:
  288. * {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC),
  289. *pack_size(OC)} [dense]
  290. * {GROUP, alpha, alpha, OC_PER_GROUP/pack_size,
  291. * IC_PER_GROUP/pack_size, pack_size(IC), pack_size(OC)} [group]
  292. *
  293. */
  294. megdnn_assert(param.format == Param::Format::NCHW88 ||
  295. param.format == Param::Format::NCHW44 ||
  296. param.format == Param::Format::NCHW44_WINOGRAD ||
  297. param.format == Param::Format::NCHW44_DOT ||
  298. param.format == Param::Format::NCHW88_WINOGRAD);
  299. size_t img_ndim = 2;
  300. size_t flt_start = 0;
  301. size_t flt_spatial_start = 2;
  302. size_t pack_c_size = 0;
  303. if (param.sparse == Param::Sparse::DENSE) {
  304. if (filter.ndim == img_ndim + 4) {
  305. // oihw8i8o case
  306. megdnn_assert((filter[filter.ndim - 2] == pack_size &&
  307. filter[filter.ndim - 1] == pack_size) ||
  308. (filter[filter.ndim - 2] == 2 * pack_size &&
  309. filter[filter.ndim - 1] == 2 * pack_size),
  310. "last 2 dim of filter must be %zu, but got %zu, %zu",
  311. pack_size, filter[filter.ndim - 2],
  312. filter[filter.ndim - 1]);
  313. ret.group = 1;
  314. flt_start = 0;
  315. if (param.format == Param::Format::NCHW88_WINOGRAD ||
  316. param.format == Param::Format::NCHW44_WINOGRAD) {
  317. flt_start = 2;
  318. }
  319. if (filter[filter.ndim - 2] == 2 * pack_size &&
  320. filter[filter.ndim - 1] == 2 * pack_size) {
  321. pack_c_size = 2 * pack_size;
  322. } else {
  323. pack_c_size = pack_size;
  324. }
  325. ret.ocpg = filter[flt_start] * pack_c_size;
  326. ret.icpg = filter[flt_start + 1] * pack_c_size;
  327. } else if (filter.ndim == img_ndim + 3) {
  328. // ohwi8o
  329. megdnn_assert(param.format != Param::Format::NCHW88_WINOGRAD,
  330. "Hybrid nchw88 mode in not support winograd");
  331. megdnn_assert(param.format != Param::Format::NCHW44_WINOGRAD,
  332. "Hybrid nchw44 mode in not support winograd");
  333. flt_start = 0;
  334. flt_spatial_start = 1;
  335. ret.group = 1;
  336. ret.ocpg = filter[flt_start] * pack_size;
  337. ret.icpg = filter[flt_start + 3];
  338. } else {
  339. megdnn_assert(0, "not support nchwxx filter dim = %zu",
  340. filter.ndim);
  341. }
  342. } else {
  343. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  344. "invalid convolution sparse type");
  345. flt_start = 1;
  346. if (param.format == Param::Format::NCHW88_WINOGRAD ||
  347. param.format == Param::Format::NCHW44_WINOGRAD) {
  348. flt_start = 3;
  349. }
  350. auto filter_oc = filter[flt_start];
  351. auto filter_ic = filter[flt_start + 1];
  352. if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4) &&
  353. param.format != Param::Format::NCHW88_WINOGRAD &&
  354. param.format != Param::Format::NCHW44_WINOGRAD) {
  355. // Depthwise case goihw8g
  356. megdnn_assert(filter.ndim == img_ndim + 4,
  357. "bad filter ndim for group convolution: "
  358. "spatial_ndim=%zu filter_ndim=%zu",
  359. img_ndim, filter.ndim);
  360. megdnn_assert(filter[filter.ndim - 1] == pack_size,
  361. "last dim of filter must be %zu, but %zu", pack_size,
  362. filter[filter.ndim - 1]);
  363. ret.group = filter[0] * pack_size;
  364. ret.ocpg = filter_oc;
  365. ret.icpg = filter_ic;
  366. } else {
  367. // norm group case goihw8i8o
  368. megdnn_assert(filter.ndim == img_ndim + 5,
  369. "bad filter ndim for group convolution: "
  370. "spatial_ndim=%zu filter_ndim=%zu",
  371. img_ndim, filter.ndim);
  372. megdnn_assert((filter[filter.ndim - 1] == pack_size &&
  373. filter[filter.ndim - 2] == pack_size) ||
  374. (filter[filter.ndim - 1] == 2 * pack_size &&
  375. filter[filter.ndim - 2] == 2 * pack_size),
  376. "last 2 dim of filter must be %zu, but got %zu, %zu",
  377. pack_size, filter[filter.ndim - 2],
  378. filter[filter.ndim - 1]);
  379. ret.group = filter[0];
  380. if (filter[filter.ndim - 2] == 2 * pack_size &&
  381. filter[filter.ndim - 1] == 2 * pack_size) {
  382. ret.ocpg = filter_oc * 2 * pack_size;
  383. ret.icpg = filter_ic * 2 * pack_size;
  384. } else {
  385. ret.ocpg = filter_oc * pack_size;
  386. ret.icpg = filter_ic * pack_size;
  387. }
  388. }
  389. }
  390. ret.spatial_ndim = 2;
  391. megdnn_assert(ret.spatial_ndim == 2,
  392. "only 2D convolution is supported, and input should be 5-dim "
  393. "for nchwxx; "
  394. "got input dim = %zu",
  395. src_ndim);
  396. auto dilation = ret.dilation;
  397. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  398. megdnn_assert(dilation[i] == 1,
  399. "NCHWXX has invalid dilation on spatial dim %zu: %u, "
  400. "require to be 1",
  401. i, dilation[i]);
  402. if (param.format == Param::Format::NCHW88_WINOGRAD) {
  403. ret.spatial[i] =
  404. spatial_getter<Param, Param::Format::NCHW88_WINOGRAD>(
  405. filter[i + flt_start - 2], param);
  406. } else if (param.format == Param::Format::NCHW44_WINOGRAD) {
  407. ret.spatial[i] =
  408. spatial_getter<Param, Param::Format::NCHW44_WINOGRAD>(
  409. filter[i + flt_start - 2], param);
  410. } else {
  411. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  412. }
  413. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  414. }
  415. }
  416. template <size_t pack_size, typename Parameter, typename Param>
  417. void make_canonized_filter_meta_nchwx(
  418. size_t src_ndim, const TensorLayout& filter, const Param& param,
  419. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  420. /**
  421. * input: N IC/pack_size, H, W, pack_size
  422. * filter:
  423. * OC, IC/pack_size, FH, FW, pack_size [dense]
  424. * GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
  425. */
  426. megdnn_assert(param.format == Param::Format::NCHW4 ||
  427. param.format == Param::Format::NCHW8 ||
  428. param.format == Param::Format::NCHW32);
  429. auto img_ndim = src_ndim - 3;
  430. size_t flt_start = 0, flt_spatial_start = 2;
  431. if (param.sparse == Param::Sparse::DENSE) {
  432. megdnn_assert(filter.ndim == img_ndim + 3,
  433. "bad filter ndim for dense convolution: "
  434. "spatial_ndim=%zu filter_ndim=%zu",
  435. img_ndim, filter.ndim);
  436. // oc, ic, dims[]
  437. ret.group = 1;
  438. flt_start = 0;
  439. } else {
  440. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  441. "invalid convolution sparse type");
  442. megdnn_assert(filter.ndim == img_ndim + 4,
  443. "bad filter ndim for group convolution: "
  444. "spatial_ndim=%zu filter_ndim=%zu",
  445. img_ndim, filter.ndim);
  446. ret.group = filter[0];
  447. flt_start = 1;
  448. }
  449. ret.spatial_ndim = src_ndim - 3;
  450. megdnn_assert(ret.spatial_ndim == 2,
  451. "only 2D convolution is supported, and input should be 5-dim "
  452. "for nchw4; "
  453. "got input dim = %zu",
  454. src_ndim);
  455. ret.ocpg = filter[flt_start];
  456. ret.icpg = filter[flt_start + 1] * pack_size;
  457. auto dilation = ret.dilation;
  458. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  459. megdnn_assert(dilation[i] == 1,
  460. "NCHW4 has invalid dilation on spatial dim %zu: %u, "
  461. "require to be 1",
  462. i, dilation[i]);
  463. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  464. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  465. }
  466. }
  467. template <size_t pack_size, typename Parameter, typename Param>
  468. void make_canonized_filter_meta_chwnx(
  469. size_t src_ndim, const TensorLayout& filter, const Param& param,
  470. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  471. /**
  472. * input: IC / pack_size, H, W, N, pack_size
  473. * Filter:
  474. * IC / pack_size, FH, FW, OC, pack_size [dense]
  475. * GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
  476. * not implemented [chanwise]
  477. */
  478. megdnn_assert(param.format == Param::Format::CHWN4);
  479. auto img_ndim = src_ndim - 3;
  480. size_t flt_start = 0, flt_spatial_start = 1;
  481. if (param.sparse == Param::Sparse::DENSE) {
  482. megdnn_assert(filter.ndim == img_ndim + 3,
  483. "bad filter ndim for dense convolution: "
  484. "spatial_ndim=%zu filter_ndim=%zu",
  485. img_ndim, filter.ndim);
  486. // oc, ic, dims[]
  487. ret.group = 1;
  488. flt_start = 0;
  489. } else {
  490. megdnn_assert(param.sparse == Param::Sparse::GROUP,
  491. "invalid convolution sparse type");
  492. megdnn_assert(filter.ndim == img_ndim + 4,
  493. "bad filter ndim for group convolution: "
  494. "spatial_ndim=%zu filter_ndim=%zu",
  495. img_ndim, filter.ndim);
  496. ret.group = filter[0];
  497. flt_start = 1;
  498. }
  499. ret.spatial_ndim = src_ndim - 3;
  500. megdnn_assert(
  501. ret.spatial_ndim == 2,
  502. "only 2D convolution is supported, and input should be 4-dim; "
  503. "got input dim = %zu",
  504. src_ndim);
  505. ret.icpg = filter[flt_start] * pack_size;
  506. ret.ocpg = filter[flt_start + 3];
  507. auto dilation = ret.dilation;
  508. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  509. megdnn_assert(dilation[i] == 1,
  510. "CHWNx has invalid dilation on spatial dim %zu: %u, "
  511. "require to be 1",
  512. i, dilation[i]);
  513. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  514. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  515. }
  516. }
  517. } // namespace
  518. namespace megdnn {
  519. template <typename Parameter>
  520. typename ConvolutionBase<Parameter>::CanonizedFilterMeta
  521. ConvolutionBase<Parameter>::make_canonized_filter_meta(
  522. size_t src_ndim, const TensorLayout& filter) const {
  523. megdnn_assert_contiguous(filter);
  524. CanonizedFilterMeta ret;
  525. ret.dtype = filter.dtype;
  526. ret.format = param().format;
  527. if (param().mode == Mode::CONVOLUTION) {
  528. ret.should_flip = true;
  529. } else {
  530. megdnn_assert(param().mode == Mode::CROSS_CORRELATION,
  531. "invalid conv mode");
  532. ret.should_flip = false;
  533. }
  534. ret.stride[0] = param().stride_h;
  535. ret.stride[1] = param().stride_w;
  536. ret.padding[0] = param().pad_h;
  537. ret.padding[1] = param().pad_w;
  538. ret.dilation[0] = param().dilate_h;
  539. ret.dilation[1] = param().dilate_w;
  540. if (param().format == Param::Format::NHWCD4) {
  541. if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  542. filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  543. make_canonized_filter_meta_nhwcd4_dot<Parameter>(src_ndim, filter,
  544. param(), ret);
  545. } else {
  546. make_canonized_filter_meta_nhwcd4<Parameter>(src_ndim, filter,
  547. param(), ret);
  548. }
  549. } else if (param().format == Param::Format::NCHW4) {
  550. make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter,
  551. param(), ret);
  552. } else if (param().format == Param::Format::NCHW8) {
  553. make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter,
  554. param(), ret);
  555. } else if (param().format == Param::Format::NCHW88 ||
  556. param().format == Param::Format::NCHW88_WINOGRAD) {
  557. make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter,
  558. param(), ret);
  559. } else if (param().format == Param::Format::NCHW44 ||
  560. param().format == Param::Format::NCHW44_DOT ||
  561. param().format == Param::Format::NCHW44_WINOGRAD) {
  562. make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter,
  563. param(), ret);
  564. } else if (param().format == Param::Format::NCHW32) {
  565. make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter,
  566. param(), ret);
  567. } else if (param().format == Param::Format::CHWN4) {
  568. make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter,
  569. param(), ret);
  570. } else {
  571. megdnn_assert(param().format == Param::Format::NHWC ||
  572. param().format == Param::Format::NCHW ||
  573. param().format == Param::Format::NCHW_WINOGRAD);
  574. make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter,
  575. param(), ret);
  576. }
  577. return ret;
  578. }
  579. template <typename Parameter>
  580. void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
  581. DType filter,
  582. DType& dst) const {
  583. // The first one will be the default choice.
  584. SmallVector<DType> supported_dst_dtype;
  585. // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
  586. if (src.category() == DTypeCategory::FLOAT) {
  587. supported_dst_dtype.push_back(src);
  588. } else if (src.enumv() == DTypeEnum::Int8) {
  589. supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
  590. } else if (src.enumv() == DTypeEnum::QuantizedS8 ||
  591. src.enumv() == DTypeEnum::Quantized8Asymm ||
  592. src.enumv() == DTypeEnum::Quantized4Asymm) {
  593. //! Qint8 winograd compute with float, in order to bringing the filter
  594. //! scale, here just use QuantizedS32 as filter type.
  595. if (src.enumv() == DTypeEnum::QuantizedS8 &&
  596. filter.enumv() == DTypeEnum::QuantizedS32) {
  597. supported_dst_dtype.push_back(dtype::QuantizedS32(
  598. src.param<dtype::QuantizedS8>().scale *
  599. filter.param<dtype::QuantizedS32>().scale));
  600. } else {
  601. supported_dst_dtype.push_back(
  602. dtype::QuantizedS32(mul_scale(src, filter)));
  603. }
  604. if (dst.valid() && dst.enumv() == src.enumv()) {
  605. supported_dst_dtype.push_back(dst);
  606. }
  607. } else if (src.enumv() == DTypeEnum::QuantizedS32) {
  608. //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
  609. megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
  610. supported_dst_dtype.push_back(
  611. dtype::QuantizedS8(src.param<dtype::QuantizedS32>().scale /
  612. filter.param<dtype::QuantizedS8>().scale));
  613. } else {
  614. megdnn_throw(ssprintf("unsupported input / filter DType: %s x %s",
  615. src.name(), filter.name()));
  616. }
  617. if (!dst.valid()) {
  618. dst = supported_dst_dtype.at(0);
  619. } else {
  620. bool dst_supported = false;
  621. for (auto&& dt : supported_dst_dtype) {
  622. if (dtype_almost_equal(dt, dst)) {
  623. dst_supported = true;
  624. break;
  625. }
  626. }
  627. MEGDNN_MARK_USED_VAR(dst_supported);
  628. megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s",
  629. src.name(), filter.name(), dst.name());
  630. }
  631. megdnn_assert((param().compute_mode == Param::ComputeMode::FLOAT32 ||
  632. param().compute_mode == Param::ComputeMode::DEFAULT)
  633. #if !MEGDNN_DISABLE_FLOAT16
  634. || src.enumv() == DTypeEnum::Float16 ||
  635. src.enumv() == DTypeEnum::BFloat16
  636. #endif
  637. ,
  638. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  639. "input / output.");
  640. }
  641. template <typename Parameter>
  642. typename ConvolutionBase<Parameter>::CanonizedFilterMeta
  643. ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
  644. const TensorLayout& filter,
  645. TensorLayout& dst) const {
  646. auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
  647. MEGDNN_MARK_USED_VAR(errmsg);
  648. megdnn_assert_contiguous(src);
  649. megdnn_assert_contiguous(filter);
  650. megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
  651. if ((param().format == Param::Format::NCHW_WINOGRAD ||
  652. param().format == Param::Format::NCHW44_WINOGRAD) &&
  653. src.dtype.category() == DTypeCategory::QUANTIZED) {
  654. megdnn_assert((filter.dtype.enumv() == DTypeEnum::QuantizedS16 ||
  655. filter.dtype.enumv() == DTypeEnum::QuantizedS32),
  656. "%s", errmsg().c_str());
  657. megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  658. src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
  659. "%s", errmsg().c_str());
  660. } else {
  661. megdnn_assert(src.dtype.enumv() == filter.dtype.enumv(), "%s",
  662. errmsg().c_str());
  663. }
  664. check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
  665. size_t img_dim;
  666. if (param().format == Param::Format::NCHW ||
  667. param().format == Param::Format::NHWC ||
  668. param().format == Param::Format::NCHW_WINOGRAD) {
  669. img_dim = src.ndim - 2;
  670. megdnn_assert(filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6,
  671. "%s", errmsg().c_str());
  672. } else {
  673. megdnn_assert(param().format == Param::Format::NHWCD4 ||
  674. param().format == Param::Format::NCHW4 ||
  675. param().format == Param::Format::NCHW44 ||
  676. param().format == Param::Format::NCHW44_DOT ||
  677. param().format == Param::Format::NCHW8 ||
  678. param().format == Param::Format::NCHW32 ||
  679. param().format == Param::Format::NCHW88 ||
  680. param().format == Param::Format::NCHW88_WINOGRAD ||
  681. param().format == Param::Format::NCHW44_WINOGRAD ||
  682. param().format == Param::Format::CHWN4);
  683. img_dim = src.ndim - 3;
  684. if ((param().format == Param::Format::NCHW88 ||
  685. param().format == Param::Format::NCHW44_DOT ||
  686. param().format == Param::Format::NCHW44) &&
  687. filter.ndim == 5) {
  688. img_dim = src.ndim - 2;
  689. }
  690. megdnn_assert(filter.ndim == img_dim + 3 ||
  691. (filter.ndim == img_dim + 2 &&
  692. (param().format == Param::Format::NCHW88 ||
  693. param().format == Param::Format::NCHW44_DOT ||
  694. param().format == Param::Format::NCHW44)) ||
  695. filter.ndim == img_dim + 4 ||
  696. filter.ndim == img_dim + 5,
  697. "%s", errmsg().c_str());
  698. if (param().format == Param::Format::NCHW4) {
  699. megdnn_assert(src.ndim == 5 &&
  700. (filter.ndim == 5 || filter.ndim == 6 ||
  701. filter.ndim == 7) &&
  702. src[src.ndim - 1] == 4 &&
  703. filter[filter.ndim - 1] == 4,
  704. "NCHW4 require src and filter's ndim is 5 or 6, and "
  705. "last shape "
  706. "is 4 "
  707. "but got src %s, filter %s",
  708. src.to_string().c_str(), filter.to_string().c_str());
  709. }
  710. if (param().format == Param::Format::NCHW8) {
  711. megdnn_assert(
  712. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  713. src[src.ndim - 1] == 8 &&
  714. filter[filter.ndim - 1] == 8,
  715. "NCHW8 require src and filter's ndim is 5 or 6, and last "
  716. "shape is 8 "
  717. "but got src %s, filter %s",
  718. src.to_string().c_str(), filter.to_string().c_str());
  719. }
  720. if (param().format == Param::Format::NCHW32) {
  721. megdnn_assert(
  722. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  723. src[src.ndim - 1] == 32 &&
  724. filter[filter.ndim - 1] == 32,
  725. "NCHW32 require src and filter's ndim is 5 or 6, and last "
  726. "shape is 32 "
  727. "but got src %s, filter %s",
  728. src.to_string().c_str(), filter.to_string().c_str());
  729. }
  730. if (param().format == Param::Format::NCHW88 ||
  731. param().format == Param::Format::NCHW88_WINOGRAD) {
  732. megdnn_assert((src.ndim == 4 && filter.ndim == 5 &&
  733. filter[filter.ndim - 1] == 8) ||
  734. (src.ndim == 5 &&
  735. ((filter.ndim == 6 &&
  736. filter[filter.ndim - 1] == 8) ||
  737. (filter.ndim == 7 &&
  738. filter[filter.ndim - 1] == 8 &&
  739. filter[filter.ndim - 2] == 8)) &&
  740. src[src.ndim - 1] == 8),
  741. "NCHW88 require src ndim is 5 and filter's ndim is 6 "
  742. ", and last shape two is 8 but got src %s, filter %s",
  743. src.to_string().c_str(), filter.to_string().c_str());
  744. }
  745. if (param().format == Param::Format::NCHW44 ||
  746. param().format == Param::Format::NCHW44_DOT ||
  747. param().format == Param::Format::NCHW44_WINOGRAD) {
  748. //!support nchw44 filter change to 88 for int8 winogradf23_88 using MK8 mamtul
  749. megdnn_assert((src.ndim == 4 && filter.ndim == 5 &&
  750. filter[filter.ndim - 1] == 4) ||
  751. (src.ndim == 5 &&
  752. ((filter.ndim == 6 &&
  753. (filter[filter.ndim - 1] == 4 ||
  754. filter[filter.ndim - 1] == 8)) ||
  755. (filter.ndim == 7 &&
  756. (filter[filter.ndim - 1] == 4 ||
  757. filter[filter.ndim - 1] == 8) &&
  758. (filter[filter.ndim - 2] == 4 ||
  759. filter[filter.ndim - 2] == 8))) &&
  760. src[src.ndim - 1] == 4),
  761. "NCHW44 require src ndim is 5 and filter's ndim is 6 "
  762. ", and last shape two is 4 but got src %s, filter %s",
  763. src.to_string().c_str(), filter.to_string().c_str());
  764. }
  765. if (param().format == Param::Format::CHWN4) {
  766. megdnn_assert(
  767. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  768. src[src.ndim - 1] == 4 &&
  769. filter[filter.ndim - 1] == 4,
  770. "CHWN4 require src and filter's ndim is 5 or 6, and last "
  771. "shape is 4 "
  772. "but got src %s, filter %s",
  773. src.to_string().c_str(), filter.to_string().c_str());
  774. }
  775. }
  776. megdnn_assert(img_dim == 2,
  777. "currently only convolution on 2D image is supported");
  778. auto cflt = make_canonized_filter_meta(src.ndim, filter);
  779. if (param().format == Param::Format::NCHW ||
  780. param().format == Param::Format::NHWC ||
  781. param().format == Param::Format::NCHW_WINOGRAD) {
  782. size_t src_or_dst_c_pos = 0;
  783. size_t src_or_dst_spatial_start = 0;
  784. if (param().format == Param::Format::NCHW ||
  785. param().format == Param::Format::NCHW_WINOGRAD) {
  786. src_or_dst_c_pos = 1;
  787. src_or_dst_spatial_start = 2;
  788. } else {
  789. megdnn_assert(param().format == Param::Format::NHWC,
  790. "invalid conv format");
  791. src_or_dst_c_pos = 3;
  792. src_or_dst_spatial_start = 1;
  793. }
  794. megdnn_assert(cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
  795. errmsg().c_str());
  796. if (param().format == Param::Format::NCHW_WINOGRAD) {
  797. megdnn_assert(cflt.spatial[0] == cflt.spatial[1],
  798. "NCHW_WINOGRAD only support conv with fh == fw");
  799. }
  800. dst.ndim = src.ndim;
  801. dst[0] = src[0];
  802. dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
  803. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  804. dst[i + src_or_dst_spatial_start] = infer_conv_shape(
  805. src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  806. cflt.stride[i], cflt.padding[i]);
  807. }
  808. dst.init_contiguous_stride();
  809. } else if (param().format == Param::Format::NCHW4) {
  810. megdnn_assert(src.ndim == 5,
  811. "invalid src ndim for NCHW4, expected=5, got=%zu",
  812. src.ndim);
  813. megdnn_assert(cflt.icpg * cflt.group == src[1] * 4,
  814. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  815. cflt.group);
  816. dst.ndim = src.ndim;
  817. dst[0] = src[0];
  818. auto oc = cflt.ocpg * cflt.group;
  819. megdnn_assert(oc % 4 == 0);
  820. dst[1] = oc / 4;
  821. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
  822. cflt.stride[0], cflt.padding[0]);
  823. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  824. cflt.stride[1], cflt.padding[1]);
  825. dst[4] = 4;
  826. } else if (param().format == Param::Format::NCHW8) {
  827. megdnn_assert(src.ndim == 5,
  828. "invalid src ndim for NCHW8, expected=5, got=%zu",
  829. src.ndim);
  830. megdnn_assert(cflt.icpg * cflt.group == src[1] * 8,
  831. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  832. cflt.group);
  833. dst.ndim = src.ndim;
  834. dst[0] = src[0];
  835. auto oc = cflt.ocpg * cflt.group;
  836. megdnn_assert(oc % 8 == 0);
  837. dst[1] = oc / 8;
  838. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
  839. cflt.stride[0], cflt.padding[0]);
  840. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  841. cflt.stride[1], cflt.padding[1]);
  842. dst[4] = 8;
  843. } else if (param().format == Param::Format::NCHW32) {
  844. megdnn_assert(src.ndim == 5,
  845. "invalid src ndim for NCHW32, expected=5, got=%zu",
  846. src.ndim);
  847. megdnn_assert(cflt.icpg * cflt.group == src[1] * 32,
  848. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  849. cflt.group);
  850. dst.ndim = src.ndim;
  851. dst[0] = src[0];
  852. auto oc = cflt.ocpg * cflt.group;
  853. megdnn_assert(oc % 32 == 0);
  854. dst[1] = oc / 32;
  855. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
  856. cflt.stride[0], cflt.padding[0]);
  857. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  858. cflt.stride[1], cflt.padding[1]);
  859. dst[4] = 32;
  860. } else if (param().format == Param::Format::NCHW88 ||
  861. param().format == Param::Format::NCHW88_WINOGRAD) {
  862. megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
  863. "invalid src ndim for NCHW88, expected=5 or 4, got=%zu",
  864. src.ndim);
  865. dst.ndim = 5;
  866. dst[0] = src[0];
  867. auto oc = cflt.ocpg * cflt.group;
  868. megdnn_assert(oc % 8 == 0);
  869. dst[1] = oc / 8;
  870. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
  871. cflt.stride[0], cflt.padding[0]);
  872. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  873. cflt.stride[1], cflt.padding[1]);
  874. dst[4] = 8;
  875. if (cflt.group == 1) {
  876. megdnn_assert(cflt.icpg * cflt.group == src[1] * 8 ||
  877. (cflt.icpg * cflt.group == src[1]),
  878. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  879. cflt.group);
  880. }
  881. } else if (param().format == Param::Format::NCHW44 ||
  882. param().format == Param::Format::NCHW44_DOT ||
  883. param().format == Param::Format::NCHW44_WINOGRAD) {
  884. megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
  885. "invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
  886. src.ndim);
  887. dst.ndim = 5;
  888. dst[0] = src[0];
  889. auto oc = cflt.ocpg * cflt.group;
  890. megdnn_assert(oc % 4 == 0);
  891. dst[1] = oc / 4;
  892. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
  893. cflt.stride[0], cflt.padding[0]);
  894. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  895. cflt.stride[1], cflt.padding[1]);
  896. dst[4] = 4;
  897. if (cflt.group == 1) {
  898. megdnn_assert(cflt.icpg * cflt.group == src[1] * 4 ||
  899. (cflt.icpg * cflt.group == src[1]),
  900. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  901. cflt.group);
  902. }
  903. } else if (param().format == Param::Format::CHWN4) {
  904. megdnn_assert(src.ndim == 5,
  905. "invalid src ndim for CHWN4, expected=5, got=%zu",
  906. src.ndim);
  907. megdnn_assert(cflt.icpg * cflt.group == src[0] * 4,
  908. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  909. cflt.group);
  910. dst.ndim = src.ndim;
  911. dst[3] = src[3];
  912. auto oc = cflt.ocpg * cflt.group;
  913. megdnn_assert(oc % 4 == 0);
  914. dst[0] = oc / 4;
  915. dst[1] = infer_conv_shape(src[1], cflt.dilated_spatial[0],
  916. cflt.stride[0], cflt.padding[0]);
  917. dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1],
  918. cflt.stride[1], cflt.padding[1]);
  919. dst[4] = 4;
  920. } else {
  921. megdnn_assert(param().format == Param::Format::NHWCD4);
  922. megdnn_assert(src.ndim == 5,
  923. "invalid src ndim for NHWCD4, expected=5, got=%zu",
  924. src.ndim);
  925. megdnn_assert(cflt.icpg * cflt.group == src[2] * 4,
  926. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
  927. cflt.group);
  928. dst.ndim = src.ndim;
  929. dst[0] = src[0];
  930. auto oc = cflt.ocpg * cflt.group;
  931. megdnn_assert(oc % 4 == 0);
  932. dst[2] = oc / 4;
  933. dst[1] = infer_conv_shape(src[1], cflt.dilated_spatial[0],
  934. cflt.stride[0], cflt.padding[0]);
  935. dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
  936. cflt.stride[1], cflt.padding[1]);
  937. megdnn_assert(src[4] == 4);
  938. dst[4] = 4;
  939. }
  940. dst.format = src.format;
  941. dst.init_contiguous_stride();
  942. return cflt;
  943. }
  944. /**
  945. * \warning: An explicit specialization shall be declared in a namespace
  946. * enclosing the specialized template. An explicit specialization whose
  947. * declarator-id is not qualified shall be declared in the nearest enclosing
  948. * namespace of the template, or, if the namespace is inline (7.3.1), any
  949. * namespace from its enclosing namespace set.
  950. * refer to:
  951. * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
  952. */
  953. template <>
  954. ConvolutionBase<param::Convolution>::CanonizedFilterMeta
  955. ConvolutionBase<param::Convolution>::check_layout_fwd(
  956. const TensorLayout& src, const TensorLayout& filter,
  957. const TensorLayout& dst) const {
  958. TensorLayout dst_expected;
  959. dst_expected.dtype = dst.dtype;
  960. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  961. megdnn_assert_eq_layout(dst_expected, dst);
  962. return ret;
  963. }
  964. template <>
  965. ConvolutionBase<param::ConvBias>::CanonizedFilterMeta
  966. ConvolutionBase<param::ConvBias>::check_layout_fwd(
  967. const TensorLayout& src, const TensorLayout& filter,
  968. const TensorLayout& dst) const {
  969. TensorLayout dst_expected;
  970. dst_expected.dtype = dst.dtype;
  971. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  972. megdnn_assert_eq_layout(dst_expected, dst);
  973. return ret;
  974. }
  975. template <>
  976. ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta
  977. ConvolutionBase<param::BatchConvBias>::check_layout_fwd(
  978. const TensorLayout& src, const TensorLayout& filter,
  979. const TensorLayout& dst) const {
  980. TensorLayout dst_expected;
  981. dst_expected.dtype = dst.dtype;
  982. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  983. megdnn_assert_eq_layout(dst_expected, dst);
  984. return ret;
  985. }
//! Deduce the dst dtype of a forward convolution from src/filter dtypes, or
//! validate \p dst if the caller already set it.
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}
//! Deduce the dst layout of a forward convolution from src and filter layouts.
void ConvolutionForward::deduce_layout(const TensorLayout& src,
                                       const TensorLayout& filter,
                                       TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}
  994. ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
  995. const TensorLayout& src, const TensorLayout& filter,
  996. const TensorLayout& dst, size_t workspace_in_bytes) {
  997. auto ret = check_layout_fwd(src, filter, dst);
  998. auto required_workspace_in_bytes =
  999. get_workspace_in_bytes(src, filter, dst, nullptr);
  1000. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1001. return ret;
  1002. }
  1003. ConvolutionBackwardData::CanonizedFilterMeta
  1004. ConvolutionBackwardData::check_exec(const TensorLayout& filter,
  1005. const TensorLayout& diff,
  1006. const TensorLayout& grad,
  1007. size_t workspace_in_bytes) {
  1008. auto grad_fwd = grad;
  1009. auto filter_fwd = filter;
  1010. auto diff_fwd = diff;
  1011. std::swap(grad_fwd.dtype, diff_fwd.dtype);
  1012. grad_fwd.init_contiguous_stride();
  1013. diff_fwd.init_contiguous_stride();
  1014. auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
  1015. auto required_workspace_in_bytes =
  1016. get_workspace_in_bytes(filter, diff, grad);
  1017. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1018. return ret;
  1019. }
//! Deduce the grad (input-gradient) dtype from filter/diff dtypes, or
//! validate a caller-provided grad dtype against the supported combinations:
//!   * matching float filter/diff        -> same float dtype
//!   * Int8 filter x Int8 diff           -> Int32
//!   * QuantizedS8 x QuantizedS8 or Quantized8Asymm x Quantized8Asymm
//!                                       -> QuantizedS32 (scales multiplied);
//!     a caller-provided grad whose enum matches diff is also accepted
//! Any other combination throws.
void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff,
                                           DType& grad) {
    SmallVector<DType> supported_dst_dtype;
    if (filter.category() == diff.category() &&
        filter.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(filter);
    } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
        supported_dst_dtype.push_back(dtype::Int32());
    } else if ((filter.enumv() == DTypeEnum::QuantizedS8 &&
                diff.enumv() == DTypeEnum::QuantizedS8) ||
               (filter.enumv() == DTypeEnum::Quantized8Asymm &&
                diff.enumv() == DTypeEnum::Quantized8Asymm)) {
        supported_dst_dtype.push_back(
                dtype::QuantizedS32(mul_scale(filter, diff)));
        if (grad.valid() && grad.enumv() == diff.enumv()) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
        megdnn_throw(ssprintf("unsupported input / diff DType: %s x %s",
                              filter.name(), diff.name()));
    }
    if (!grad.valid()) {
        // no dtype given by the caller: pick the first (canonical) candidate
        grad = supported_dst_dtype.at(0);
    } else {
        megdnn_assert(vec_contains(supported_dst_dtype, grad),
                      "unsupported ConvBwd(%s, %s) -> %s", filter.name(),
                      diff.name(), grad.name());
    }
    // ComputeMode::FLOAT32 is only meaningful for half-precision inputs
    // (fp32 accumulation); reject it for any other filter dtype
    megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
                          || filter.enumv() == DTypeEnum::Float16
                          || filter.enumv() == DTypeEnum::BFloat16
#endif
                  ,
                  "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
                  "input / output.");
}
//! Deduce the layout of grad (the input gradient, i.e. the forward-conv src)
//! from the filter and diff (output gradient) layouts.
//! Supports NCHW, NHWC and NHWCD4 formats only.
void ConvolutionBackwardData::deduce_layout(const TensorLayout& filter,
                                            const TensorLayout& diff,
                                            TensorLayout& grad) {
    auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
    MEGDNN_MARK_USED_VAR(errmsg);
    megdnn_assert_contiguous(filter);
    megdnn_assert_contiguous(diff);
    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s",
                  errmsg().c_str());
    megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
    deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
    auto cflt = make_canonized_filter_meta(diff.ndim, filter);
    // inverse of infer_conv_shape: recover the input spatial extent from the
    // output extent, (dilated) filter size, stride and padding
    auto deduce = [&errmsg](size_t out, size_t filter, size_t stride,
                            size_t pad) {
        MEGDNN_MARK_USED_VAR(errmsg);
        auto i = (out - 1) * stride + filter;
        megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
        return i - pad * 2;
    };
    if (param().format == Param::Format::NCHW ||
        param().format == Param::Format::NHWC) {
        // channel axis and first spatial axis depend on the layout format
        size_t src_or_dst_c_pos = 0;
        size_t src_or_dst_spatial_start = 0;
        if (param().format == Param::Format::NCHW) {
            src_or_dst_c_pos = 1;
            src_or_dst_spatial_start = 2;
        } else {
            megdnn_assert(param().format == Param::Format::NHWC,
                          "invalid conv format");
            src_or_dst_c_pos = 3;
            src_or_dst_spatial_start = 1;
        }
        megdnn_assert(cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
                      errmsg().c_str());
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
        for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
            grad[i + src_or_dst_spatial_start] = deduce(
                    diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                    cflt.stride[i], cflt.padding[i]);
        }
    } else {
        // NHWCD4: 5-dim layout with channels packed by 4 in the last axis;
        // channel-group count lives at index 2
        megdnn_assert(param().format == Param::Format::NHWCD4);
        megdnn_assert(diff.ndim == 5,
                      "valid diff ndim for NHWCD4, expected=5, got=%zu",
                      diff.ndim);
        megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s",
                      errmsg().c_str());
        grad.ndim = diff.ndim;
        grad[0] = diff[0];
        auto ic = cflt.icpg * cflt.group;
        megdnn_assert(ic % 4 == 0);
        grad[2] = ic / 4;
        grad[1] = deduce(diff[1], cflt.dilated_spatial[0], cflt.stride[0],
                         cflt.padding[0]);
        grad[3] = deduce(diff[3], cflt.dilated_spatial[1], cflt.stride[1],
                         cflt.padding[1]);
        megdnn_assert(diff[4] == 4);
        grad[4] = 4;
    }
    grad.format = diff.format;
    grad.init_contiguous_stride();
}
  1121. ConvolutionBackwardFilter::CanonizedFilterMeta
  1122. ConvolutionBackwardFilter::check_exec(const TensorLayout& src,
  1123. const TensorLayout& diff,
  1124. const TensorLayout& grad,
  1125. size_t workspace_in_bytes) {
  1126. megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT &&
  1127. diff.dtype.category() == DTypeCategory::FLOAT &&
  1128. grad.dtype.category() == DTypeCategory::FLOAT,
  1129. "only float type is supported for conv backward filter");
  1130. auto ret = check_layout_fwd(src, grad, diff);
  1131. auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
  1132. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1133. return ret;
  1134. }
  1135. } // namespace megdnn
  1136. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台