You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; may include dashes ('-'); and can be up to 35 characters long.

convolution.cpp 56 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286
  1. #include "megdnn/oprs/nn.h"
  2. #include "src/common/utils.h"
  3. using namespace megdnn;
  4. namespace {
  5. template <typename Param>
  6. std::string get_errmsg(
  7. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  8. const Param& param) {
  9. MEGDNN_MARK_USED_VAR(src);
  10. MEGDNN_MARK_USED_VAR(filter);
  11. MEGDNN_MARK_USED_VAR(dst);
  12. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  13. megdnn_layout_msg(dst) + ", " + "is_nchw=" +
  14. std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
  15. "is_xcorr=" +
  16. std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
  17. "pad_h=" + std::to_string(param.pad_h) + ", " +
  18. "pad_w=" + std::to_string(param.pad_w) + ", " +
  19. "stride_h=" + std::to_string(param.stride_h) + ", " +
  20. "stride_w=" + std::to_string(param.stride_w) + ", " +
  21. "dilate_h=" + std::to_string(param.dilate_h) + ", " +
  22. "dilate_w=" + std::to_string(param.dilate_w);
  23. }
  24. template <typename Param, typename Param::Format>
  25. uint32_t spatial_getter(uint32_t filter, const Param&) {
  26. return filter;
  27. }
//! Canonize an NCHW / NHWC filter layout into `ret`: deduce group count,
//! output/input channels per group, and the (dilated) spatial kernel shape.
//! Supports dense (oc, ic, fh, fw / oc, fh, fw, ic) and group
//! (g, oc, ic, fh, fw / g, oc, fh, fw, ic) sparse modes.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;
    }
    // Plain NCHW/NHWC have no channel packing, so block sizes stay 1.
    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        // NOTE(review): spatial_getter is instantiated with Format::NCHW even
        // when param.format is NHWC; the generic overload is an identity so
        // this is harmless today — confirm if format-specific getters exist.
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize an NHWCD4 filter layout (non int8-dot variant) into `ret`.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        OC/4, FH, FW, IC, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC, 4 [group]
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // ndim == img_ndim + 3 with a unit second dim identifies the chanwise
        // layout (GROUP/4, 1, FH, FW, 4); groups are packed by 4.
        if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        // Chanwise: exactly one input and one output channel per group.
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // OC is stored divided by 4; IC is stored unpacked at flt_start + 3.
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3];
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize an NHWCD4 filter for the int8-dot variant; differs from the
//! plain NHWCD4 path in that IC is also stored divided by 4 (hence the
//! extra *4 when computing icpg) and the non-chanwise group filter has one
//! more packed dim.
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nhwcd4_dot(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N H IC/4 W 4
     * Filter:
     *        GROUP/4, 1, FH, FW, 4 [chanwise]
     *        OC/4, FH, FW, IC/4, 4, 4 [dense]
     *        GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
     */
    megdnn_assert(param.format == Param::Format::NHWCD4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    bool is_chanwise = false;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // The shorter rank marks the chanwise layout; its second dim must be 1.
        if (filter.ndim == img_ndim + 3) {
            megdnn_assert(filter[1] == 1);
            is_chanwise = true;
            ret.group = filter[0] * 4;
        } else {
            ret.group = filter[0];
        }
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    if (is_chanwise) {
        // Chanwise: one input and one output channel per group.
        ret.ocpg = 1;
        ret.icpg = 1;
    } else {
        // Both OC and IC are stored divided by 4 in the dot layout.
        ret.ocpg = filter[flt_start] * 4;
        ret.icpg = filter[flt_start + 3] * 4;
    }
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a channel-packed NCHWxx filter (NCHW88 / NCHW44 / NCHW44_DOT)
//! into `ret`: group, ocpg, icpg and the spatial kernel shape.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     *
     ** NCHW44-DOT mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
     *        [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *        FH, FW, pack_size(OC), pack_size(IC)} [group]
     *
     * NCHW88 and NCHW44 mode
     * filter:
     *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
     *        [dense]
     *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
     *        FH, FW, pack_size(IC), pack_size(OC)} [group]
     *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
     *
     *
     */
    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
    size_t pack_c_size = 0;
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case
            // The last two dims may both be pack_size, or both 2*pack_size
            // (double-packed variant); note the message only mentions the
            // former even though the latter is accepted too.
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = 1;
            flt_start = 0;
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];
        } else {
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
        }
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        flt_start = 1;
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
        // Unit per-group channels with rank img_ndim+4 marks depthwise.
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
            // Depthwise case goihw8g
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
            ret.group = filter[0] * pack_size;
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;
        } else {
            // norm group case goihw8i8o
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = filter[0];
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
        }
    }
    ret.spatial_ndim = 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // NCHWxx layouts do not support dilated convolution.
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize an NCHWx (NCHW4/8/32/64 and the mixed NCHW4_*, NCHW32_NCHW4
//! variants) filter into `ret`. Only IC is packed by pack_size; OC is
//! stored unpacked at flt_start.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: N IC/pack_size, H, W, pack_size
     * filter:
     *        OC, IC/pack_size, FH, FW, pack_size [dense]
     *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
     */
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start];
    ret.icpg = filter[flt_start + 1] * pack_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // NCHWx layouts do not support dilated convolution.
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
//! Canonize a CHWNx (currently CHWN4 only) filter into `ret`. In this
//! layout IC is packed by pack_size and OC sits unpacked at flt_start + 3.
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_chwnx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    /**
     * input: IC / pack_size, H, W, N, pack_size
     * Filter:
     *        IC / pack_size, FH, FW, OC, pack_size [dense]
     *        GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
     *        not implemented [chanwise]
     */
    megdnn_assert(param.format == Param::Format::CHWN4);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 1;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.icpg = filter[flt_start] * pack_size;
    ret.ocpg = filter[flt_start + 3];
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // CHWNx layouts do not support dilated convolution.
        megdnn_assert(
                dilation[i] == 1,
                "CHWNx has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
  458. } // namespace
  459. namespace megdnn {
  460. template <typename Parameter>
  461. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  462. make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
  463. megdnn_assert_contiguous(filter);
  464. CanonizedFilterMeta ret;
  465. ret.dtype = filter.dtype;
  466. ret.format = param().format;
  467. if (param().mode == Mode::CONVOLUTION) {
  468. ret.should_flip = true;
  469. } else {
  470. megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
  471. ret.should_flip = false;
  472. }
  473. ret.stride[0] = param().stride_h;
  474. ret.stride[1] = param().stride_w;
  475. ret.padding[0] = param().pad_h;
  476. ret.padding[1] = param().pad_w;
  477. ret.dilation[0] = param().dilate_h;
  478. ret.dilation[1] = param().dilate_w;
  479. if (param().format == Param::Format::NHWCD4) {
  480. if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
  481. filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
  482. make_canonized_filter_meta_nhwcd4_dot<Parameter>(
  483. src_ndim, filter, param(), ret);
  484. } else {
  485. make_canonized_filter_meta_nhwcd4<Parameter>(
  486. src_ndim, filter, param(), ret);
  487. }
  488. } else if (
  489. param().format == Param::Format::NCHW4 ||
  490. param().format == Param::Format::NCHW4_NCHW ||
  491. param().format == Param::Format::NCHW4_NHWC ||
  492. param().format == Param::Format::NCHW4_NCHW32) {
  493. make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
  494. } else if (param().format == Param::Format::NCHW8) {
  495. make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
  496. } else if (param().format == Param::Format::NCHW88) {
  497. make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
  498. } else if (
  499. param().format == Param::Format::NCHW44 ||
  500. param().format == Param::Format::NCHW44_DOT) {
  501. make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
  502. } else if (
  503. param().format == Param::Format::NCHW32 ||
  504. param().format == Param::Format::NCHW32_NCHW4) {
  505. make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
  506. } else if (param().format == Param::Format::CHWN4) {
  507. make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
  508. } else if (param().format == Param::Format::NCHW64) {
  509. make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
  510. } else {
  511. megdnn_assert(
  512. param().format == Param::Format::NHWC ||
  513. param().format == Param::Format::NCHW);
  514. make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
  515. }
  516. return ret;
  517. }
//! Validate the (src, filter) dtype pair and either check `dst` against the
//! supported output dtypes or, if `dst` is invalid, deduce it as the first
//! (default) supported choice. Throws for unsupported input pairs.
template <typename Parameter>
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
            src.enumv() == DTypeEnum::Quantized4Asymm ||
            src.enumv() == DTypeEnum::QuantizedS1) {
        // Default: quantized s32 accumulator with combined scale.
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        // Also allow a quantized dst that matches src, or 4-bit <-> 8-bit
        // requantization in either direction.
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
        if (cond_dst) {
            supported_dst_dtype.push_back(dst);
        }
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
                "runtime does not support input / filter DType: %s x %s"
                "now support case list: FLOAT x FLOAT\n"
                "                       Int8 x Int8\n"
                "                       QuantizedS8 x QuantizedS8\n"
                "                       Quantized8Asymm x Quantized8Asymm\n"
                "                       QuantizedS4 x QuantizedS4\n"
                "                       Quantized4Asymm x Quantized4Asymm\n"
                "                       QuantizedS1 x QuantizedS1\n",
                src.name(), filter.name()));
    }
    if (!dst.valid()) {
        // Deduce: take the default (first) supported dst dtype.
        dst = supported_dst_dtype.at(0);
    } else {
        // Check: the caller-supplied dst must match one supported dtype.
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        MEGDNN_MARK_USED_VAR(dst_supported);
        megdnn_assert(
                dst_supported,
                "runtime does not support Conv(%s, %s) -> %s"
                "now support case list: Conv(FLOAT x FLOAT) -> FLOAT\n"
                "                       Conv(Int8 x Int8) -> Int32\n"
                "                       Conv(QuantizedS8 x QuantizedS8) -> "
                "QuantizedS32\n"
                "                       Conv(Quantized8Asymm x Quantized8Asymm) -> "
                "Quantized32Asymm\n"
                "                       Conv(QuantizedS4 x QuantizedS4) -> "
                "QuantizedS32\n"
                "                       Conv(Quantized4Asymm x Quantized4Asymm) -> "
                "Quantized32Asymm\n"
                "                       Conv(QuantizedS1 x QuantizedS1) -> "
                "QuantizedS32\n",
                src.name(), filter.name(), dst.name());
    }
    // NOTE(review): the left disjunct accepts both FLOAT32 and DEFAULT, so
    // this assertion cannot fire when compute_mode is FLOAT32 even for
    // non-fp16 inputs, despite what the message claims. Confirm whether the
    // intended condition was `compute_mode == DEFAULT || src is (B)Float16`.
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
#if !MEGDNN_DISABLE_FLOAT16
            || src.enumv() == DTypeEnum::Float16 ||
            src.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}
  605. template <typename Parameter>
  606. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  607. deduce_layout_fwd(
  608. const TensorLayout& src, const TensorLayout& filter,
  609. TensorLayout& dst) const {
  610. auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
  611. MEGDNN_MARK_USED_VAR(errmsg);
  612. megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
  613. megdnn_assert(
  614. ((src.dtype.enumv() == filter.dtype.enumv()) ||
  615. (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
  616. filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
  617. "%s", errmsg().c_str());
  618. check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
  619. size_t img_dim;
  620. if (param().format == Param::Format::NCHW ||
  621. param().format == Param::Format::NHWC) {
  622. img_dim = src.ndim - 2;
  623. megdnn_assert(
  624. filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
  625. errmsg().c_str());
  626. } else {
  627. megdnn_assert(
  628. param().format == Param::Format::NHWCD4 ||
  629. param().format == Param::Format::NCHW4 ||
  630. param().format == Param::Format::NCHW4_NCHW ||
  631. param().format == Param::Format::NCHW4_NHWC ||
  632. param().format == Param::Format::NCHW4_NCHW32 ||
  633. param().format == Param::Format::NCHW44 ||
  634. param().format == Param::Format::NCHW44_DOT ||
  635. param().format == Param::Format::NCHW8 ||
  636. param().format == Param::Format::NCHW32 ||
  637. param().format == Param::Format::NCHW32_NCHW4 ||
  638. param().format == Param::Format::NCHW88 ||
  639. param().format == Param::Format::CHWN4 ||
  640. param().format == Param::Format::NCHW64);
  641. img_dim = src.ndim - 3;
  642. if ((param().format == Param::Format::NCHW88 ||
  643. param().format == Param::Format::NCHW44_DOT ||
  644. param().format == Param::Format::NCHW44) &&
  645. filter.ndim == 5) {
  646. img_dim = src.ndim - 2;
  647. }
  648. megdnn_assert(
  649. filter.ndim == img_dim + 3 ||
  650. (filter.ndim == img_dim + 2 &&
  651. (param().format == Param::Format::NCHW88 ||
  652. param().format == Param::Format::NCHW44_DOT ||
  653. param().format == Param::Format::NCHW44)) ||
  654. filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
  655. "%s", errmsg().c_str());
  656. if (param().format == Param::Format::NCHW4 ||
  657. param().format == Param::Format::NCHW4_NCHW ||
  658. param().format == Param::Format::NCHW4_NCHW32) {
  659. megdnn_assert(
  660. src.ndim == 5 &&
  661. (filter.ndim == 5 || filter.ndim == 6 ||
  662. filter.ndim == 7) &&
  663. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  664. "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
  665. "filter's ndim is "
  666. "5 or 6, and "
  667. "last shape "
  668. "is 4 "
  669. "but got src %s, filter %s",
  670. src.to_string().c_str(), filter.to_string().c_str());
  671. }
  672. if (param().format == Param::Format::NCHW8) {
  673. megdnn_assert(
  674. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  675. src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
  676. "NCHW8 require src and filter's ndim is 5 or 6, and last "
  677. "shape is 8 "
  678. "but got src %s, filter %s",
  679. src.to_string().c_str(), filter.to_string().c_str());
  680. }
  681. if (param().format == Param::Format::NCHW32 ||
  682. param().format == Param::Format::NCHW32_NCHW4) {
  683. megdnn_assert(
  684. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  685. src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
  686. "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
  687. "is 5 or 6, and last "
  688. "shape is 32 "
  689. "but got src %s, filter %s",
  690. src.to_string().c_str(), filter.to_string().c_str());
  691. }
  692. if (param().format == Param::Format::NCHW88) {
  693. megdnn_assert(
  694. (src.ndim == 4 && filter.ndim == 5 &&
  695. filter[filter.ndim - 1] == 8) ||
  696. (src.ndim == 5 &&
  697. ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
  698. (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
  699. filter[filter.ndim - 2] == 8)) &&
  700. src[src.ndim - 1] == 8),
  701. "NCHW88 require src ndim is 5 and filter's ndim is 6 "
  702. ", and last shape two is 8 but got src %s, filter %s",
  703. src.to_string().c_str(), filter.to_string().c_str());
  704. }
  705. if (param().format == Param::Format::NCHW44 ||
  706. param().format == Param::Format::NCHW44_DOT) {
  707. //! support nchw44 filter change to 88 for int8 winogradf23_88 using
  708. //! MK8 mamtul
  709. megdnn_assert(
  710. (src.ndim == 4 && filter.ndim == 5 &&
  711. filter[filter.ndim - 1] == 4) ||
  712. (src.ndim == 5 &&
  713. ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
  714. filter[filter.ndim - 1] == 8)) ||
  715. (filter.ndim == 7 &&
  716. (filter[filter.ndim - 1] == 4 ||
  717. filter[filter.ndim - 1] == 8) &&
  718. (filter[filter.ndim - 2] == 4 ||
  719. filter[filter.ndim - 2] == 8))) &&
  720. src[src.ndim - 1] == 4),
  721. "NCHW44 require src ndim is 5 and filter's ndim is 6 "
  722. ", and last shape two is 4 but got src %s, filter %s",
  723. src.to_string().c_str(), filter.to_string().c_str());
  724. }
  725. if (param().format == Param::Format::CHWN4) {
  726. megdnn_assert(
  727. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  728. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  729. "CHWN4 require src and filter's ndim is 5 or 6, and last "
  730. "shape is 4 "
  731. "but got src %s, filter %s",
  732. src.to_string().c_str(), filter.to_string().c_str());
  733. }
  734. if (param().format == Param::Format::NCHW64) {
  735. megdnn_assert(
  736. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  737. src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
  738. "NCHW64 require src and filter's ndim is 5 or 6, and "
  739. "last shape is 64 but got src %s, filter %s",
  740. src.to_string().c_str(), filter.to_string().c_str());
  741. }
  742. }
  743. megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
  744. auto cflt = make_canonized_filter_meta(src.ndim, filter);
  745. if (param().format == Param::Format::NCHW ||
  746. param().format == Param::Format::NHWC) {
  747. size_t src_or_dst_c_pos = 0;
  748. size_t src_or_dst_spatial_start = 0;
  749. if (param().format == Param::Format::NCHW) {
  750. src_or_dst_c_pos = 1;
  751. src_or_dst_spatial_start = 2;
  752. } else {
  753. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  754. src_or_dst_c_pos = 3;
  755. src_or_dst_spatial_start = 1;
  756. }
  757. megdnn_assert(
  758. cflt.icpg * cflt.group == src[src_or_dst_c_pos],
  759. "group conv channel mismatch : input channel got %zu, and "
  760. "filter channel got %u. More details for src, filter and dst : \n%s",
  761. src[src_or_dst_c_pos], cflt.icpg * cflt.group, errmsg().c_str());
  762. dst.ndim = src.ndim;
  763. dst[0] = src[0];
  764. dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
  765. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  766. dst[i + src_or_dst_spatial_start] = infer_conv_shape(
  767. src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  768. cflt.stride[i], cflt.padding[i]);
  769. }
  770. } else if (param().format == Param::Format::NCHW4) {
  771. megdnn_assert(
  772. src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
  773. src.ndim);
  774. megdnn_assert(
  775. cflt.icpg * cflt.group == src[1] * 4,
  776. "group conv channel mismatch : input channel got %zu, and "
  777. "filter channel got %u. More details for src, filter and dst : \n%s",
  778. src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  779. dst.ndim = src.ndim;
  780. dst[0] = src[0];
  781. auto oc = cflt.ocpg * cflt.group;
  782. megdnn_assert(oc % 4 == 0);
  783. dst[1] = oc / 4;
  784. dst[2] = infer_conv_shape(
  785. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  786. dst[3] = infer_conv_shape(
  787. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  788. dst[4] = 4;
  789. } else if (param().format == Param::Format::NCHW8) {
  790. megdnn_assert(
  791. src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
  792. src.ndim);
  793. megdnn_assert(
  794. cflt.icpg * cflt.group == src[1] * 8,
  795. "group conv channel mismatch : input channel got %zu, and "
  796. "filter channel got %u. More details for src, filter and dst : \n%s",
  797. src[1] * 8, cflt.icpg * cflt.group, errmsg().c_str());
  798. dst.ndim = src.ndim;
  799. dst[0] = src[0];
  800. auto oc = cflt.ocpg * cflt.group;
  801. megdnn_assert(oc % 8 == 0);
  802. dst[1] = oc / 8;
  803. dst[2] = infer_conv_shape(
  804. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  805. dst[3] = infer_conv_shape(
  806. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  807. dst[4] = 8;
  808. } else if (param().format == Param::Format::NCHW32) {
  809. megdnn_assert(
  810. src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
  811. src.ndim);
  812. megdnn_assert(
  813. cflt.icpg * cflt.group == src[1] * 32,
  814. "group conv channel mismatch : input channel got %zu, and "
  815. "filter channel got %u. More details for src, filter and dst : \n%s",
  816. src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
  817. dst.ndim = src.ndim;
  818. dst[0] = src[0];
  819. auto oc = cflt.ocpg * cflt.group;
  820. megdnn_assert(oc % 32 == 0);
  821. dst[1] = oc / 32;
  822. dst[2] = infer_conv_shape(
  823. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  824. dst[3] = infer_conv_shape(
  825. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  826. dst[4] = 32;
  827. } else if (param().format == Param::Format::NCHW88) {
  828. megdnn_assert(
  829. src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
  830. "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
  831. dst.ndim = 5;
  832. dst[0] = src[0];
  833. auto oc = cflt.ocpg * cflt.group;
  834. megdnn_assert(oc % 8 == 0);
  835. dst[1] = oc / 8;
  836. dst[2] = infer_conv_shape(
  837. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  838. dst[3] = infer_conv_shape(
  839. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  840. dst[4] = 8;
  841. if (cflt.group == 1) {
  842. megdnn_assert(
  843. cflt.icpg * cflt.group == src[1] * 8 ||
  844. (cflt.icpg * cflt.group == src[1]),
  845. "group conv channel mismatch : input channel got %zu, and "
  846. "filter channel got %u. More details about src, filter and dst : "
  847. "\n%s",
  848. src.ndim == 5 ? src[1] * 8 : src[1], cflt.icpg * cflt.group,
  849. errmsg().c_str());
  850. }
  851. } else if (
  852. param().format == Param::Format::NCHW44 ||
  853. param().format == Param::Format::NCHW44_DOT) {
  854. megdnn_assert(
  855. src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
  856. "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
  857. dst.ndim = 5;
  858. dst[0] = src[0];
  859. auto oc = cflt.ocpg * cflt.group;
  860. megdnn_assert(oc % 4 == 0);
  861. dst[1] = oc / 4;
  862. dst[2] = infer_conv_shape(
  863. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  864. dst[3] = infer_conv_shape(
  865. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  866. dst[4] = 4;
  867. if (cflt.group == 1) {
  868. megdnn_assert(
  869. cflt.icpg * cflt.group == src[1] * 4 ||
  870. (cflt.icpg * cflt.group == src[1]),
  871. "group conv channel mismatch : input channel got %zu, and "
  872. "filter channel got %u. More details about src, filter and dst : "
  873. "\n%s",
  874. src.ndim == 5 ? src[1] * 4 : src[1], cflt.icpg * cflt.group,
  875. errmsg().c_str());
  876. }
  877. } else if (param().format == Param::Format::CHWN4) {
  878. megdnn_assert(
  879. src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
  880. src.ndim);
  881. megdnn_assert(
  882. cflt.icpg * cflt.group == src[0] * 4,
  883. "group conv channel mismatch : input channel got %zu, and "
  884. "filter channel got %u. More details for src, filter and dst : \n%s",
  885. src[0] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  886. dst.ndim = src.ndim;
  887. dst[3] = src[3];
  888. auto oc = cflt.ocpg * cflt.group;
  889. megdnn_assert(oc % 4 == 0);
  890. dst[0] = oc / 4;
  891. dst[1] = infer_conv_shape(
  892. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  893. dst[2] = infer_conv_shape(
  894. src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  895. dst[4] = 4;
  896. } else if (param().format == Param::Format::NCHW4_NCHW) {
  897. megdnn_assert(
  898. src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
  899. src.ndim);
  900. megdnn_assert(
  901. cflt.icpg * cflt.group == src[1] * 4,
  902. "group conv channel mismatch : input channel got %zu, and "
  903. "filter channel got %u. More details for src, filter and dst : \n%s",
  904. src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  905. dst.ndim = 4;
  906. dst[0] = src[0];
  907. auto oc = cflt.ocpg * cflt.group;
  908. dst[1] = oc;
  909. dst[2] = infer_conv_shape(
  910. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  911. dst[3] = infer_conv_shape(
  912. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  913. } else if (param().format == Param::Format::NCHW4_NHWC) {
  914. megdnn_assert(
  915. src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
  916. src.ndim);
  917. megdnn_assert(
  918. cflt.icpg * cflt.group == src[1] * 4,
  919. "group conv channel mismatch : input channel got %zu, and "
  920. "filter channel got %u. More details for src, filter and dst : \n%s",
  921. src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  922. dst.ndim = 4;
  923. dst[0] = src[0];
  924. dst[1] = infer_conv_shape(
  925. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  926. dst[2] = infer_conv_shape(
  927. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  928. auto oc = cflt.ocpg * cflt.group;
  929. dst[3] = oc;
  930. } else if (param().format == Param::Format::NCHW4_NCHW32) {
  931. megdnn_assert(
  932. src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
  933. src.ndim);
  934. megdnn_assert(
  935. cflt.icpg * cflt.group == src[1] * 4,
  936. "group conv channel mismatch : input channel got %zu, and "
  937. "filter channel got %u. More details for src, filter and dst : \n%s",
  938. src[1] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  939. dst.ndim = src.ndim;
  940. dst[0] = src[0];
  941. auto oc = cflt.ocpg * cflt.group;
  942. megdnn_assert(oc % 32 == 0);
  943. dst[1] = oc / 32;
  944. dst[2] = infer_conv_shape(
  945. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  946. dst[3] = infer_conv_shape(
  947. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  948. dst[4] = 32;
  949. } else if (param().format == Param::Format::NCHW32_NCHW4) {
  950. megdnn_assert(
  951. src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
  952. src.ndim);
  953. megdnn_assert(
  954. cflt.icpg * cflt.group == src[1] * 32,
  955. "group conv channel mismatch : input channel got %zu, and "
  956. "filter channel got %u. More details for src, filter and dst : \n%s",
  957. src[1] * 32, cflt.icpg * cflt.group, errmsg().c_str());
  958. dst.ndim = src.ndim;
  959. dst[0] = src[0];
  960. auto oc = cflt.ocpg * cflt.group;
  961. megdnn_assert(oc % 4 == 0);
  962. dst[1] = oc / 4;
  963. dst[2] = infer_conv_shape(
  964. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  965. dst[3] = infer_conv_shape(
  966. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  967. dst[4] = 4;
  968. } else if (param().format == Param::Format::NCHW64) {
  969. megdnn_assert(
  970. src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
  971. src.ndim);
  972. megdnn_assert(
  973. cflt.icpg * cflt.group == src[1] * 64,
  974. "group conv channel mismatch : input channel got %zu, and "
  975. "filter channel got %u. More details for src, filter and dst : \n%s",
  976. src[1] * 64, cflt.icpg * cflt.group, errmsg().c_str());
  977. dst.ndim = src.ndim;
  978. dst[0] = src[0];
  979. auto oc = cflt.ocpg * cflt.group;
  980. megdnn_assert(oc % 64 == 0);
  981. dst[1] = oc / 64;
  982. dst[2] = infer_conv_shape(
  983. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  984. dst[3] = infer_conv_shape(
  985. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  986. dst[4] = 64;
  987. } else {
  988. megdnn_assert(param().format == Param::Format::NHWCD4);
  989. megdnn_assert(
  990. src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
  991. src.ndim);
  992. megdnn_assert(
  993. cflt.icpg * cflt.group == src[2] * 4,
  994. "group conv channel mismatch : input channel got %zu, and "
  995. "filter channel got %u. More details for src, filter and dst : \n%s",
  996. src[2] * 4, cflt.icpg * cflt.group, errmsg().c_str());
  997. dst.ndim = src.ndim;
  998. dst[0] = src[0];
  999. auto oc = cflt.ocpg * cflt.group;
  1000. megdnn_assert(oc % 4 == 0);
  1001. dst[2] = oc / 4;
  1002. dst[1] = infer_conv_shape(
  1003. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1004. dst[3] = infer_conv_shape(
  1005. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1006. megdnn_assert(src[4] == 4);
  1007. dst[4] = 4;
  1008. }
  1009. if (!src.format.is_default() && !src.format.is_lowbit_aligned()) { // propagate
  1010. dst.format = src.format;
  1011. } else { // determined by dtype
  1012. dst.format = TensorFormat(dst.dtype);
  1013. }
  1014. dst.init_contiguous_stride();
  1015. return cflt;
  1016. }
  1017. /**
  1018. * \warning: An explicit specialization shall be declared in a namespace
  1019. * enclosing the specialized template. An explicit specialization whose
  1020. * declarator-id is not qualified shall be declared in the nearest enclosing
  1021. * namespace of the template, or, if the namespace is inline (7.3.1), any
  1022. * namespace from its enclosing namespace set.
  1023. * refer to:
  1024. * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
  1025. */
  1026. template <>
  1027. ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
  1028. param::Convolution>::
  1029. check_layout_fwd(
  1030. const TensorLayout& src, const TensorLayout& filter,
  1031. const TensorLayout& dst) const {
  1032. megdnn_assert_contiguous(src);
  1033. megdnn_assert_contiguous(filter);
  1034. TensorLayout dst_expected;
  1035. dst_expected.dtype = dst.dtype;
  1036. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1037. megdnn_assert_eq_layout(dst_expected, dst);
  1038. return ret;
  1039. }
  1040. template <>
  1041. ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
  1042. check_layout_fwd(
  1043. const TensorLayout& src, const TensorLayout& filter,
  1044. const TensorLayout& dst) const {
  1045. megdnn_assert_contiguous(src);
  1046. megdnn_assert_contiguous(filter);
  1047. TensorLayout dst_expected;
  1048. dst_expected.dtype = dst.dtype;
  1049. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1050. megdnn_assert_eq_layout(dst_expected, dst);
  1051. return ret;
  1052. }
  1053. template <>
  1054. ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
  1055. param::BatchConvBias>::
  1056. check_layout_fwd(
  1057. const TensorLayout& src, const TensorLayout& filter,
  1058. const TensorLayout& dst) const {
  1059. megdnn_assert_contiguous(src);
  1060. megdnn_assert_contiguous(filter);
  1061. TensorLayout dst_expected;
  1062. dst_expected.dtype = dst.dtype;
  1063. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1064. megdnn_assert_eq_layout(dst_expected, dst);
  1065. return ret;
  1066. }
  1067. void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
  1068. check_or_deduce_dtype_fwd(src, filter, dst);
  1069. }
  1070. void ConvolutionForward::deduce_layout(
  1071. const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
  1072. deduce_layout_fwd(src, filter, dst);
  1073. }
  1074. ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
  1075. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  1076. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
  1077. auto ret = check_layout_fwd(src, filter, dst);
  1078. auto required_workspace_in_bytes =
  1079. get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
  1080. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1081. return ret;
  1082. }
  1083. ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
  1084. const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
  1085. size_t workspace_in_bytes) {
  1086. auto grad_fwd = grad;
  1087. auto filter_fwd = filter;
  1088. auto diff_fwd = diff;
  1089. std::swap(grad_fwd.dtype, diff_fwd.dtype);
  1090. grad_fwd.init_contiguous_stride();
  1091. diff_fwd.init_contiguous_stride();
  1092. auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
  1093. auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
  1094. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1095. return ret;
  1096. }
  1097. void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
  1098. SmallVector<DType> supported_dst_dtype;
  1099. if (filter.category() == diff.category() &&
  1100. filter.category() == DTypeCategory::FLOAT) {
  1101. supported_dst_dtype.push_back(filter);
  1102. } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
  1103. supported_dst_dtype.push_back(dtype::Int32());
  1104. } else if (
  1105. (filter.enumv() == DTypeEnum::QuantizedS8 &&
  1106. diff.enumv() == DTypeEnum::QuantizedS8) ||
  1107. (filter.enumv() == DTypeEnum::Quantized8Asymm &&
  1108. diff.enumv() == DTypeEnum::Quantized8Asymm)) {
  1109. supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
  1110. if (grad.valid() && grad.enumv() == diff.enumv()) {
  1111. supported_dst_dtype.push_back(grad);
  1112. }
  1113. } else {
  1114. megdnn_throw(ssprintf(
  1115. "runtime does not support input / diff DType: %s x %s"
  1116. "now support case list: FLOAT x FLOAT\n"
  1117. " Int8 x Int8\n"
  1118. " QuantizedS8 x QuantizedS8\n"
  1119. " Quantized8Asymm x Quantized8Asymm\n",
  1120. filter.name(), diff.name()));
  1121. }
  1122. if (!grad.valid()) {
  1123. grad = supported_dst_dtype.at(0);
  1124. } else {
  1125. megdnn_assert(
  1126. vec_contains(supported_dst_dtype, grad),
  1127. "runtime does not support ConvBwd(%s, %s) -> %s"
  1128. "now support case list: ConvBwd(FLOAT x FLOAT) -> FLOAT\n"
  1129. " ConvBwd(Int8 x Int8) -> Int32\n"
  1130. " ConvBwd(QuantizedS8 x QuantizedS8) -> "
  1131. "QuantizedS32\n"
  1132. " ConvBwd(Quantized8Asymm x Quantized8Asymm) -> "
  1133. "Quantized32Asymm\n",
  1134. filter.name(), diff.name(), grad.name());
  1135. }
  1136. megdnn_assert(
  1137. param().compute_mode != Param::ComputeMode::FLOAT32
  1138. #if !MEGDNN_DISABLE_FLOAT16
  1139. || filter.enumv() == DTypeEnum::Float16 ||
  1140. filter.enumv() == DTypeEnum::BFloat16
  1141. #endif
  1142. ,
  1143. "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
  1144. "input / output.");
  1145. }
  1146. void ConvolutionBackwardData::deduce_layout(
  1147. const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
  1148. auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
  1149. MEGDNN_MARK_USED_VAR(errmsg);
  1150. megdnn_assert_contiguous(filter);
  1151. megdnn_assert_contiguous(diff);
  1152. megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
  1153. megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
  1154. deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
  1155. auto cflt = make_canonized_filter_meta(diff.ndim, filter);
  1156. auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
  1157. MEGDNN_MARK_USED_VAR(errmsg);
  1158. auto i = (out - 1) * stride + filter;
  1159. megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
  1160. return i - pad * 2;
  1161. };
  1162. if (param().format == Param::Format::NCHW ||
  1163. param().format == Param::Format::NHWC) {
  1164. size_t src_or_dst_c_pos = 0;
  1165. size_t src_or_dst_spatial_start = 0;
  1166. if (param().format == Param::Format::NCHW) {
  1167. src_or_dst_c_pos = 1;
  1168. src_or_dst_spatial_start = 2;
  1169. } else {
  1170. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  1171. src_or_dst_c_pos = 3;
  1172. src_or_dst_spatial_start = 1;
  1173. }
  1174. megdnn_assert(
  1175. cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
  1176. errmsg().c_str());
  1177. grad.ndim = diff.ndim;
  1178. grad[0] = diff[0];
  1179. grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
  1180. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  1181. grad[i + src_or_dst_spatial_start] =
  1182. deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  1183. cflt.stride[i], cflt.padding[i]);
  1184. }
  1185. } else if (param().format == Param::Format::NCHW4) {
  1186. megdnn_assert(
  1187. diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
  1188. diff.ndim);
  1189. megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
  1190. megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
  1191. grad.ndim = diff.ndim;
  1192. grad[0] = diff[0];
  1193. auto ic = cflt.icpg * cflt.group;
  1194. megdnn_assert(ic % 4 == 0);
  1195. grad[1] = ic / 4;
  1196. grad[2] = deduce(
  1197. diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1198. grad[3] = deduce(
  1199. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1200. megdnn_assert(diff[4] == 4);
  1201. grad[4] = 4;
  1202. } else {
  1203. megdnn_assert(param().format == Param::Format::NHWCD4);
  1204. megdnn_assert(
  1205. diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
  1206. diff.ndim);
  1207. megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
  1208. grad.ndim = diff.ndim;
  1209. grad[0] = diff[0];
  1210. auto ic = cflt.icpg * cflt.group;
  1211. megdnn_assert(ic % 4 == 0);
  1212. grad[2] = ic / 4;
  1213. grad[1] = deduce(
  1214. diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1215. grad[3] = deduce(
  1216. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1217. megdnn_assert(diff[4] == 4);
  1218. grad[4] = 4;
  1219. }
  1220. grad.format = diff.format;
  1221. grad.init_contiguous_stride();
  1222. }
  1223. ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
  1224. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
  1225. size_t workspace_in_bytes) {
  1226. megdnn_assert(
  1227. src.dtype.category() == DTypeCategory::FLOAT &&
  1228. diff.dtype.category() == DTypeCategory::FLOAT &&
  1229. grad.dtype.category() == DTypeCategory::FLOAT,
  1230. "only float type is supported for conv backward filter");
  1231. auto src_fwd = src;
  1232. auto diff_fwd = diff;
  1233. src_fwd.init_contiguous_stride();
  1234. diff_fwd.init_contiguous_stride();
  1235. auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
  1236. auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
  1237. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1238. return ret;
  1239. }
  1240. } // namespace megdnn
  1241. // vim: syntax=cpp.doxygen