You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235
  1. /**
  2. * \file dnn/src/common/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/oprs/nn.h"
  13. #include "src/common/utils.h"
  14. using namespace megdnn;
  15. namespace {
  16. template <typename Param>
  17. std::string get_errmsg(
  18. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  19. const Param& param) {
  20. MEGDNN_MARK_USED_VAR(src);
  21. MEGDNN_MARK_USED_VAR(filter);
  22. MEGDNN_MARK_USED_VAR(dst);
  23. return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
  24. megdnn_layout_msg(dst) + ", " + "is_nchw=" +
  25. std::to_string(param.format == param::Convolution::Format::NCHW) + ", " +
  26. "is_xcorr=" +
  27. std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " +
  28. "pad_h=" + std::to_string(param.pad_h) + ", " +
  29. "pad_w=" + std::to_string(param.pad_w) + ", " +
  30. "stride_h=" + std::to_string(param.stride_h) + ", " +
  31. "stride_w=" + std::to_string(param.stride_w) + ", " +
  32. "dilate_h=" + std::to_string(param.dilate_h) + ", " +
  33. "dilate_w=" + std::to_string(param.dilate_w);
  34. }
  35. template <typename Param, typename Param::Format>
  36. uint32_t spatial_getter(uint32_t filter, const Param&) {
  37. return filter;
  38. }
// Canonize a filter layout in NCHW or NHWC format into `ret`: deduce the
// group count, output/input channels per group, and the (dilated) spatial
// kernel sizes.  `src_ndim` is the ndim of the input tensor (must be 4 for
// the supported 2D case).
template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW || param.format == Param::Format::NHWC);
    auto img_ndim = src_ndim - 2;
    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // grp, oc, ic, dims[]
        ret.group = filter[0];
        flt_start = 1;
    }
    // NOTE(review): both block sizes are always 1 on this path; they are kept
    // as named multipliers for symmetry with the blocked-layout helpers below.
    uint32_t ic_block_size = 1, oc_block_size = 1;
    if (param.format == Param::Format::NCHW) {
        // filter should be (oc, ic, fh, fw)
        flt_spatial_start = 2;
        ocpg_pos = 0;
        icpg_pos = 1;
    } else {
        megdnn_assert(
                param.format == Param::Format::NHWC, "invalid conv tensor format");
        // filter should be (oc, fh, fw, ic)
        flt_spatial_start = 1;
        ocpg_pos = 0;
        icpg_pos = 3;
    }
    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    ret.ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
    ret.icpg = filter[flt_start + icpg_pos] * ic_block_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(
                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
                dilation[i]);
        // spatial_getter is an identity here (NCHW tag passed for both
        // formats); dilated size = (k - 1) * dilation + 1.
        ret.spatial[i] = spatial_getter<Param, Param::Format::NCHW>(
                filter[i + flt_start + flt_spatial_start], param);
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
  100. template <typename Parameter, typename Param>
  101. void make_canonized_filter_meta_nhwcd4(
  102. size_t src_ndim, const TensorLayout& filter, const Param& param,
  103. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  104. /**
  105. * input: N H IC/4 W 4
  106. * Filter:
  107. * OC/4, FH, FW, IC, 4 [dense]
  108. * GROUP, OC/4, FH, FW, IC, 4 [group]
  109. * GROUP/4, 1, FH, FW, 4 [chanwise]
  110. */
  111. megdnn_assert(param.format == Param::Format::NHWCD4);
  112. auto img_ndim = src_ndim - 3;
  113. size_t flt_start = 0, flt_spatial_start = 1;
  114. bool is_chanwise = false;
  115. if (param.sparse == Param::Sparse::DENSE) {
  116. megdnn_assert(
  117. filter.ndim == img_ndim + 3,
  118. "bad filter ndim for dense convolution: "
  119. "spatial_ndim=%zu filter_ndim=%zu",
  120. img_ndim, filter.ndim);
  121. // oc, ic, dims[]
  122. ret.group = 1;
  123. flt_start = 0;
  124. } else {
  125. megdnn_assert(
  126. param.sparse == Param::Sparse::GROUP,
  127. "invalid convolution sparse type");
  128. megdnn_assert(
  129. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4,
  130. "bad filter ndim for group convolution: "
  131. "spatial_ndim=%zu filter_ndim=%zu",
  132. img_ndim, filter.ndim);
  133. if (filter.ndim == img_ndim + 3 && filter[1] == 1) {
  134. is_chanwise = true;
  135. ret.group = filter[0] * 4;
  136. } else {
  137. ret.group = filter[0];
  138. }
  139. flt_start = 1;
  140. }
  141. ret.spatial_ndim = src_ndim - 3;
  142. megdnn_assert(
  143. ret.spatial_ndim == 2,
  144. "only 2D convolution is supported, and input should be 4-dim; "
  145. "got input dim = %zu",
  146. src_ndim);
  147. if (is_chanwise) {
  148. ret.ocpg = 1;
  149. ret.icpg = 1;
  150. } else {
  151. ret.ocpg = filter[flt_start] * 4;
  152. ret.icpg = filter[flt_start + 3];
  153. }
  154. auto dilation = ret.dilation;
  155. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  156. megdnn_assert(
  157. dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
  158. dilation[i]);
  159. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  160. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  161. }
  162. }
  163. template <typename Parameter, typename Param>
  164. void make_canonized_filter_meta_nhwcd4_dot(
  165. size_t src_ndim, const TensorLayout& filter, const Param& param,
  166. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  167. /**
  168. * input: N H IC/4 W 4
  169. * Filter:
  170. * GROUP/4, 1, FH, FW, 4 [chanwise]
  171. * OC/4, FH, FW, IC/4, 4, 4 [dense]
  172. * GROUP, OC/4, FH, FW, IC/4, 4, 4 [group]
  173. */
  174. megdnn_assert(param.format == Param::Format::NHWCD4);
  175. auto img_ndim = src_ndim - 3;
  176. size_t flt_start = 0, flt_spatial_start = 1;
  177. bool is_chanwise = false;
  178. if (param.sparse == Param::Sparse::DENSE) {
  179. megdnn_assert(
  180. filter.ndim == img_ndim + 4,
  181. "bad filter ndim for dense convolution: "
  182. "spatial_ndim=%zu filter_ndim=%zu",
  183. img_ndim, filter.ndim);
  184. // oc, ic, dims[]
  185. ret.group = 1;
  186. flt_start = 0;
  187. } else {
  188. megdnn_assert(
  189. param.sparse == Param::Sparse::GROUP,
  190. "invalid convolution sparse type");
  191. megdnn_assert(
  192. filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
  193. "bad filter ndim for group convolution: "
  194. "spatial_ndim=%zu filter_ndim=%zu",
  195. img_ndim, filter.ndim);
  196. if (filter.ndim == img_ndim + 3) {
  197. megdnn_assert(filter[1] == 1);
  198. is_chanwise = true;
  199. ret.group = filter[0] * 4;
  200. } else {
  201. ret.group = filter[0];
  202. }
  203. flt_start = 1;
  204. }
  205. ret.spatial_ndim = src_ndim - 3;
  206. megdnn_assert(
  207. ret.spatial_ndim == 2,
  208. "only 2D convolution is supported, and input should be 4-dim; "
  209. "got input dim = %zu",
  210. src_ndim);
  211. if (is_chanwise) {
  212. ret.ocpg = 1;
  213. ret.icpg = 1;
  214. } else {
  215. ret.ocpg = filter[flt_start] * 4;
  216. ret.icpg = filter[flt_start + 3] * 4;
  217. }
  218. auto dilation = ret.dilation;
  219. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  220. megdnn_assert(
  221. dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
  222. dilation[i]);
  223. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  224. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  225. }
  226. }
/**
 * \brief Canonize a blocked NCHWxx (NCHW88 / NCHW44 / NCHW44_DOT) filter
 * into \p ret, handling dense, ordinary group and depthwise layouts.
 *
 * input: N IC/pack_size, H, W, pack_size
 *
 ** NCHW44-DOT mode
 * filter:
 *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
 *        [dense]
 *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
 *         FH, FW, pack_size(OC), pack_size(IC)} [group]
 *
 * NCHW88 and NCHW44 mode
 * filter:
 *        {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
 *        [dense]
 *        {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
 *         FH, FW, pack_size(IC), pack_size(OC)} [group]
 *        {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
 */
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwxx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW88 ||
            param.format == Param::Format::NCHW44 ||
            param.format == Param::Format::NCHW44_DOT);
    size_t img_ndim = 2;
    size_t flt_start = 0;
    size_t flt_spatial_start = 2;
    size_t pack_c_size = 0;
    if (param.sparse == Param::Sparse::DENSE) {
        if (filter.ndim == img_ndim + 4) {
            // oihw8i8o case: last two dims carry the channel packing; a
            // doubled (2x) packing is also accepted (e.g. 44 filter used
            // as 88 for winograd).
            megdnn_assert(
                    (filter[filter.ndim - 2] == pack_size &&
                     filter[filter.ndim - 1] == pack_size) ||
                            (filter[filter.ndim - 2] == 2 * pack_size &&
                             filter[filter.ndim - 1] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = 1;
            flt_start = 0;
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                pack_c_size = 2 * pack_size;
            } else {
                pack_c_size = pack_size;
            }
            ret.ocpg = filter[flt_start] * pack_c_size;
            ret.icpg = filter[flt_start + 1] * pack_c_size;
        } else if (filter.ndim == img_ndim + 3) {
            // ohwi8o: only OC is packed, IC is the trailing plain dim
            flt_start = 0;
            flt_spatial_start = 1;
            ret.group = 1;
            ret.ocpg = filter[flt_start] * pack_size;
            ret.icpg = filter[flt_start + 3];
        } else {
            megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim);
        }
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        flt_start = 1;
        auto filter_oc = filter[flt_start];
        auto filter_ic = filter[flt_start + 1];
        if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) {
            // Depthwise case goihw8g: one channel per group, groups packed
            megdnn_assert(
                    filter.ndim == img_ndim + 4,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    filter[filter.ndim - 1] == pack_size,
                    "last dim of filter must be %zu, but %zu", pack_size,
                    filter[filter.ndim - 1]);
            ret.group = filter[0] * pack_size;
            ret.ocpg = filter_oc;
            ret.icpg = filter_ic;
        } else {
            // norm group case goihw8i8o
            megdnn_assert(
                    filter.ndim == img_ndim + 5,
                    "bad filter ndim for group convolution: "
                    "spatial_ndim=%zu filter_ndim=%zu",
                    img_ndim, filter.ndim);
            megdnn_assert(
                    (filter[filter.ndim - 1] == pack_size &&
                     filter[filter.ndim - 2] == pack_size) ||
                            (filter[filter.ndim - 1] == 2 * pack_size &&
                             filter[filter.ndim - 2] == 2 * pack_size),
                    "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size,
                    filter[filter.ndim - 2], filter[filter.ndim - 1]);
            ret.group = filter[0];
            if (filter[filter.ndim - 2] == 2 * pack_size &&
                filter[filter.ndim - 1] == 2 * pack_size) {
                ret.ocpg = filter_oc * 2 * pack_size;
                ret.icpg = filter_ic * 2 * pack_size;
            } else {
                ret.ocpg = filter_oc * pack_size;
                ret.icpg = filter_ic * pack_size;
            }
        }
    }
    ret.spatial_ndim = 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchwxx; "
            "got input dim = %zu",
            src_ndim);
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // blocked layouts do not support dilation
        megdnn_assert(
                dilation[i] == 1,
                "NCHWXX has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
/**
 * \brief Canonize an NCHWx (NCHW4/8/32/64 and the hybrid NCHW4_* / NCHW32_*
 * output formats) filter into \p ret.
 *
 * input: N IC/pack_size, H, W, pack_size
 * filter:
 *        OC, IC/pack_size, FH, FW, pack_size [dense]
 *        GROUP, OC, IC/pack_size, FH, FW, pack_size [group]
 */
template <size_t pack_size, typename Parameter, typename Param>
void make_canonized_filter_meta_nchwx(
        size_t src_ndim, const TensorLayout& filter, const Param& param,
        typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
    megdnn_assert(
            param.format == Param::Format::NCHW4 ||
            param.format == Param::Format::NCHW8 ||
            param.format == Param::Format::NCHW32 ||
            param.format == Param::Format::NCHW4_NCHW ||
            param.format == Param::Format::NCHW4_NHWC ||
            param.format == Param::Format::NCHW4_NCHW32 ||
            param.format == Param::Format::NCHW32_NCHW4 ||
            param.format == Param::Format::NCHW64);
    auto img_ndim = src_ndim - 3;
    size_t flt_start = 0, flt_spatial_start = 2;
    if (param.sparse == Param::Sparse::DENSE) {
        megdnn_assert(
                filter.ndim == img_ndim + 3,
                "bad filter ndim for dense convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        // oc, ic, dims[]
        ret.group = 1;
        flt_start = 0;
    } else {
        megdnn_assert(
                param.sparse == Param::Sparse::GROUP,
                "invalid convolution sparse type");
        megdnn_assert(
                filter.ndim == img_ndim + 4,
                "bad filter ndim for group convolution: "
                "spatial_ndim=%zu filter_ndim=%zu",
                img_ndim, filter.ndim);
        ret.group = filter[0];
        flt_start = 1;
    }
    ret.spatial_ndim = src_ndim - 3;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and input should be 5-dim "
            "for nchw4; "
            "got input dim = %zu",
            src_ndim);
    // OC is stored un-packed; only IC carries the pack factor.
    ret.ocpg = filter[flt_start];
    ret.icpg = filter[flt_start + 1] * pack_size;
    auto dilation = ret.dilation;
    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        // blocked layouts do not support dilation
        megdnn_assert(
                dilation[i] == 1,
                "NCHW4 has invalid dilation on spatial dim %zu: %u, "
                "require to be 1",
                i, dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
    }
}
  415. template <size_t pack_size, typename Parameter, typename Param>
  416. void make_canonized_filter_meta_chwnx(
  417. size_t src_ndim, const TensorLayout& filter, const Param& param,
  418. typename ConvolutionBase<Parameter>::CanonizedFilterMeta& ret) {
  419. /**
  420. * input: IC / pack_size, H, W, N, pack_size
  421. * Filter:
  422. * IC / pack_size, FH, FW, OC, pack_size [dense]
  423. * GROUP, icpg / pack_size, FH, FW, ocpg, pack_size [group]
  424. * not implemented [chanwise]
  425. */
  426. megdnn_assert(param.format == Param::Format::CHWN4);
  427. auto img_ndim = src_ndim - 3;
  428. size_t flt_start = 0, flt_spatial_start = 1;
  429. if (param.sparse == Param::Sparse::DENSE) {
  430. megdnn_assert(
  431. filter.ndim == img_ndim + 3,
  432. "bad filter ndim for dense convolution: "
  433. "spatial_ndim=%zu filter_ndim=%zu",
  434. img_ndim, filter.ndim);
  435. // oc, ic, dims[]
  436. ret.group = 1;
  437. flt_start = 0;
  438. } else {
  439. megdnn_assert(
  440. param.sparse == Param::Sparse::GROUP,
  441. "invalid convolution sparse type");
  442. megdnn_assert(
  443. filter.ndim == img_ndim + 4,
  444. "bad filter ndim for group convolution: "
  445. "spatial_ndim=%zu filter_ndim=%zu",
  446. img_ndim, filter.ndim);
  447. ret.group = filter[0];
  448. flt_start = 1;
  449. }
  450. ret.spatial_ndim = src_ndim - 3;
  451. megdnn_assert(
  452. ret.spatial_ndim == 2,
  453. "only 2D convolution is supported, and input should be 4-dim; "
  454. "got input dim = %zu",
  455. src_ndim);
  456. ret.icpg = filter[flt_start] * pack_size;
  457. ret.ocpg = filter[flt_start + 3];
  458. auto dilation = ret.dilation;
  459. for (size_t i = 0; i < ret.spatial_ndim; ++i) {
  460. megdnn_assert(
  461. dilation[i] == 1,
  462. "CHWNx has invalid dilation on spatial dim %zu: %u, "
  463. "require to be 1",
  464. i, dilation[i]);
  465. ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
  466. ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1;
  467. }
  468. }
  469. } // namespace
  470. namespace megdnn {
// Dispatch filter canonization by param().format: copy the common conv
// attributes (stride/padding/dilation/flip) into the meta, then delegate to
// the format-specific helper above.
template <typename Parameter>
typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
        make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const {
    megdnn_assert_contiguous(filter);
    CanonizedFilterMeta ret;
    ret.dtype = filter.dtype;
    ret.format = param().format;
    // CONVOLUTION flips the kernel; CROSS_CORRELATION does not.
    if (param().mode == Mode::CONVOLUTION) {
        ret.should_flip = true;
    } else {
        megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode");
        ret.should_flip = false;
    }
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;
    if (param().format == Param::Format::NHWCD4) {
        // quantized 8-bit dtypes use the dot-product filter layout
        if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 ||
            filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
            make_canonized_filter_meta_nhwcd4_dot<Parameter>(
                    src_ndim, filter, param(), ret);
        } else {
            make_canonized_filter_meta_nhwcd4<Parameter>(
                    src_ndim, filter, param(), ret);
        }
    } else if (
            param().format == Param::Format::NCHW4 ||
            param().format == Param::Format::NCHW4_NCHW ||
            param().format == Param::Format::NCHW4_NHWC ||
            param().format == Param::Format::NCHW4_NCHW32) {
        make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW8) {
        make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW88) {
        make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW44 ||
            param().format == Param::Format::NCHW44_DOT) {
        make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (
            param().format == Param::Format::NCHW32 ||
            param().format == Param::Format::NCHW32_NCHW4) {
        make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::CHWN4) {
        make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret);
    } else if (param().format == Param::Format::NCHW64) {
        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret);
    } else {
        // fall through: plain NCHW / NHWC
        megdnn_assert(
                param().format == Param::Format::NHWC ||
                param().format == Param::Format::NCHW);
        make_canonized_filter_meta_nchw_nhwc<Parameter>(src_ndim, filter, param(), ret);
    }
    return ret;
}
// Validate (or deduce, when `dst` is invalid) the output dtype for a forward
// convolution given `src` and `filter` dtypes.  The first entry of
// supported_dst_dtype is the default when dst is not given.
template <typename Parameter>
void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
        DType src, DType filter, DType& dst) const {
    // The first one will be the default choice.
    SmallVector<DType> supported_dst_dtype;
    // We rely on megdnn_assert(src.enumv() == filter.enumv()) here.
    if (src.category() == DTypeCategory::FLOAT) {
        // float in -> same float type out
        supported_dst_dtype.push_back(src);
    } else if (src.enumv() == DTypeEnum::Int8) {
        supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
    } else if (
            src.enumv() == DTypeEnum::QuantizedS8 ||
            src.enumv() == DTypeEnum::Quantized8Asymm ||
            src.enumv() == DTypeEnum::QuantizedS4 ||
            src.enumv() == DTypeEnum::Quantized4Asymm) {
        // accumulate in QuantizedS32 with scale = src_scale * filter_scale
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter)));
        // also allow dst of the same enum as src, or 4-bit <-> 8-bit
        // quantized conversions in either direction
        bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() ||
                                        ((dst.enumv() == DTypeEnum::QuantizedS4 ||
                                          dst.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         src.enumv() == DTypeEnum::QuantizedS8) ||
                                        ((src.enumv() == DTypeEnum::QuantizedS4 ||
                                          src.enumv() == DTypeEnum::Quantized4Asymm) &&
                                         dst.enumv() == DTypeEnum::QuantizedS8));
        if (cond_dst) {
            supported_dst_dtype.push_back(dst);
        }
        if (src.enumv() == DTypeEnum::QuantizedS8) {
            supported_dst_dtype.push_back(dtype::Float32());
        }
    } else if (src.enumv() == DTypeEnum::QuantizedS32) {
        //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
        megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
        supported_dst_dtype.push_back(dtype::QuantizedS8(
                src.param<dtype::QuantizedS32>().scale /
                filter.param<dtype::QuantizedS8>().scale));
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / filter DType: %s x %s", src.name(),
                filter.name()));
    }
    if (!dst.valid()) {
        // deduce: pick the default (first) supported dtype
        dst = supported_dst_dtype.at(0);
    } else {
        // check: the given dst must match one of the supported dtypes
        bool dst_supported = false;
        for (auto&& dt : supported_dst_dtype) {
            if (dtype_almost_equal(dt, dst)) {
                dst_supported = true;
                break;
            }
        }
        MEGDNN_MARK_USED_VAR(dst_supported);
        megdnn_assert(
                dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(),
                filter.name(), dst.name());
    }
    // ComputeMode::FLOAT32 only makes sense for (B)Float16 inputs (when
    // fp16 support is compiled in).
    megdnn_assert(
            (param().compute_mode == Param::ComputeMode::FLOAT32 ||
             param().compute_mode == Param::ComputeMode::DEFAULT)
#if !MEGDNN_DISABLE_FLOAT16
                    || src.enumv() == DTypeEnum::Float16 ||
                    src.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}
  595. template <typename Parameter>
  596. typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Parameter>::
  597. deduce_layout_fwd(
  598. const TensorLayout& src, const TensorLayout& filter,
  599. TensorLayout& dst) const {
  600. auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); };
  601. MEGDNN_MARK_USED_VAR(errmsg);
  602. megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str());
  603. megdnn_assert(
  604. ((src.dtype.enumv() == filter.dtype.enumv()) ||
  605. (src.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
  606. filter.dtype.enumv() == DTypeEnum::QuantizedS4)),
  607. "%s", errmsg().c_str());
  608. check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype);
  609. size_t img_dim;
  610. if (param().format == Param::Format::NCHW ||
  611. param().format == Param::Format::NHWC) {
  612. img_dim = src.ndim - 2;
  613. megdnn_assert(
  614. filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s",
  615. errmsg().c_str());
  616. } else {
  617. megdnn_assert(
  618. param().format == Param::Format::NHWCD4 ||
  619. param().format == Param::Format::NCHW4 ||
  620. param().format == Param::Format::NCHW4_NCHW ||
  621. param().format == Param::Format::NCHW4_NHWC ||
  622. param().format == Param::Format::NCHW4_NCHW32 ||
  623. param().format == Param::Format::NCHW44 ||
  624. param().format == Param::Format::NCHW44_DOT ||
  625. param().format == Param::Format::NCHW8 ||
  626. param().format == Param::Format::NCHW32 ||
  627. param().format == Param::Format::NCHW32_NCHW4 ||
  628. param().format == Param::Format::NCHW88 ||
  629. param().format == Param::Format::CHWN4 ||
  630. param().format == Param::Format::NCHW64);
  631. img_dim = src.ndim - 3;
  632. if ((param().format == Param::Format::NCHW88 ||
  633. param().format == Param::Format::NCHW44_DOT ||
  634. param().format == Param::Format::NCHW44) &&
  635. filter.ndim == 5) {
  636. img_dim = src.ndim - 2;
  637. }
  638. megdnn_assert(
  639. filter.ndim == img_dim + 3 ||
  640. (filter.ndim == img_dim + 2 &&
  641. (param().format == Param::Format::NCHW88 ||
  642. param().format == Param::Format::NCHW44_DOT ||
  643. param().format == Param::Format::NCHW44)) ||
  644. filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5,
  645. "%s", errmsg().c_str());
  646. if (param().format == Param::Format::NCHW4 ||
  647. param().format == Param::Format::NCHW4_NCHW ||
  648. param().format == Param::Format::NCHW4_NCHW32) {
  649. megdnn_assert(
  650. src.ndim == 5 &&
  651. (filter.ndim == 5 || filter.ndim == 6 ||
  652. filter.ndim == 7) &&
  653. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  654. "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
  655. "filter's ndim is "
  656. "5 or 6, and "
  657. "last shape "
  658. "is 4 "
  659. "but got src %s, filter %s",
  660. src.to_string().c_str(), filter.to_string().c_str());
  661. }
  662. if (param().format == Param::Format::NCHW8) {
  663. megdnn_assert(
  664. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  665. src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8,
  666. "NCHW8 require src and filter's ndim is 5 or 6, and last "
  667. "shape is 8 "
  668. "but got src %s, filter %s",
  669. src.to_string().c_str(), filter.to_string().c_str());
  670. }
  671. if (param().format == Param::Format::NCHW32 ||
  672. param().format == Param::Format::NCHW32_NCHW4) {
  673. megdnn_assert(
  674. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  675. src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32,
  676. "NCHW32/NCHW32_NCHW4 require src and filter's ndim "
  677. "is 5 or 6, and last "
  678. "shape is 32 "
  679. "but got src %s, filter %s",
  680. src.to_string().c_str(), filter.to_string().c_str());
  681. }
  682. if (param().format == Param::Format::NCHW88) {
  683. megdnn_assert(
  684. (src.ndim == 4 && filter.ndim == 5 &&
  685. filter[filter.ndim - 1] == 8) ||
  686. (src.ndim == 5 &&
  687. ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) ||
  688. (filter.ndim == 7 && filter[filter.ndim - 1] == 8 &&
  689. filter[filter.ndim - 2] == 8)) &&
  690. src[src.ndim - 1] == 8),
  691. "NCHW88 require src ndim is 5 and filter's ndim is 6 "
  692. ", and last shape two is 8 but got src %s, filter %s",
  693. src.to_string().c_str(), filter.to_string().c_str());
  694. }
  695. if (param().format == Param::Format::NCHW44 ||
  696. param().format == Param::Format::NCHW44_DOT) {
  697. //! support nchw44 filter change to 88 for int8 winogradf23_88 using
  698. //! MK8 mamtul
  699. megdnn_assert(
  700. (src.ndim == 4 && filter.ndim == 5 &&
  701. filter[filter.ndim - 1] == 4) ||
  702. (src.ndim == 5 &&
  703. ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 ||
  704. filter[filter.ndim - 1] == 8)) ||
  705. (filter.ndim == 7 &&
  706. (filter[filter.ndim - 1] == 4 ||
  707. filter[filter.ndim - 1] == 8) &&
  708. (filter[filter.ndim - 2] == 4 ||
  709. filter[filter.ndim - 2] == 8))) &&
  710. src[src.ndim - 1] == 4),
  711. "NCHW44 require src ndim is 5 and filter's ndim is 6 "
  712. ", and last shape two is 4 but got src %s, filter %s",
  713. src.to_string().c_str(), filter.to_string().c_str());
  714. }
  715. if (param().format == Param::Format::CHWN4) {
  716. megdnn_assert(
  717. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  718. src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4,
  719. "CHWN4 require src and filter's ndim is 5 or 6, and last "
  720. "shape is 4 "
  721. "but got src %s, filter %s",
  722. src.to_string().c_str(), filter.to_string().c_str());
  723. }
  724. if (param().format == Param::Format::NCHW64) {
  725. megdnn_assert(
  726. src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
  727. src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64,
  728. "NCHW64 require src and filter's ndim is 5 or 6, and "
  729. "last shape is 64 but got src %s, filter %s",
  730. src.to_string().c_str(), filter.to_string().c_str());
  731. }
  732. }
  733. megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported");
  734. auto cflt = make_canonized_filter_meta(src.ndim, filter);
  735. if (param().format == Param::Format::NCHW ||
  736. param().format == Param::Format::NHWC) {
  737. size_t src_or_dst_c_pos = 0;
  738. size_t src_or_dst_spatial_start = 0;
  739. if (param().format == Param::Format::NCHW) {
  740. src_or_dst_c_pos = 1;
  741. src_or_dst_spatial_start = 2;
  742. } else {
  743. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  744. src_or_dst_c_pos = 3;
  745. src_or_dst_spatial_start = 1;
  746. }
  747. megdnn_assert(
  748. cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s",
  749. errmsg().c_str());
  750. dst.ndim = src.ndim;
  751. dst[0] = src[0];
  752. dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group;
  753. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  754. dst[i + src_or_dst_spatial_start] = infer_conv_shape(
  755. src[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  756. cflt.stride[i], cflt.padding[i]);
  757. }
  758. } else if (param().format == Param::Format::NCHW4) {
  759. megdnn_assert(
  760. src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu",
  761. src.ndim);
  762. megdnn_assert(
  763. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  764. errmsg().c_str(), cflt.icpg, cflt.group);
  765. dst.ndim = src.ndim;
  766. dst[0] = src[0];
  767. auto oc = cflt.ocpg * cflt.group;
  768. megdnn_assert(oc % 4 == 0);
  769. dst[1] = oc / 4;
  770. dst[2] = infer_conv_shape(
  771. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  772. dst[3] = infer_conv_shape(
  773. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  774. dst[4] = 4;
  775. } else if (param().format == Param::Format::NCHW8) {
  776. megdnn_assert(
  777. src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu",
  778. src.ndim);
  779. megdnn_assert(
  780. cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u",
  781. errmsg().c_str(), cflt.icpg, cflt.group);
  782. dst.ndim = src.ndim;
  783. dst[0] = src[0];
  784. auto oc = cflt.ocpg * cflt.group;
  785. megdnn_assert(oc % 8 == 0);
  786. dst[1] = oc / 8;
  787. dst[2] = infer_conv_shape(
  788. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  789. dst[3] = infer_conv_shape(
  790. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  791. dst[4] = 8;
  792. } else if (param().format == Param::Format::NCHW32) {
  793. megdnn_assert(
  794. src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu",
  795. src.ndim);
  796. megdnn_assert(
  797. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  798. errmsg().c_str(), cflt.icpg, cflt.group);
  799. dst.ndim = src.ndim;
  800. dst[0] = src[0];
  801. auto oc = cflt.ocpg * cflt.group;
  802. megdnn_assert(oc % 32 == 0);
  803. dst[1] = oc / 32;
  804. dst[2] = infer_conv_shape(
  805. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  806. dst[3] = infer_conv_shape(
  807. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  808. dst[4] = 32;
  809. } else if (param().format == Param::Format::NCHW88) {
  810. megdnn_assert(src.ndim == 5 || src.ndim == 4,
  811. "invalid src ndim for NCHW88, expected=5 or 4, got=%zu",
  812. src.ndim);
  813. dst.ndim = 5;
  814. dst[0] = src[0];
  815. auto oc = cflt.ocpg * cflt.group;
  816. megdnn_assert(oc % 8 == 0);
  817. dst[1] = oc / 8;
  818. dst[2] = infer_conv_shape(
  819. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  820. dst[3] = infer_conv_shape(
  821. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  822. dst[4] = 8;
  823. if (cflt.group == 1) {
  824. megdnn_assert(
  825. cflt.icpg * cflt.group == src[1] * 8 ||
  826. (cflt.icpg * cflt.group == src[1]),
  827. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  828. }
  829. } else if (param().format == Param::Format::NCHW44 ||
  830. param().format == Param::Format::NCHW44_DOT) {
  831. megdnn_assert(src.ndim == 5 || src.ndim == 4,
  832. "invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
  833. src.ndim);
  834. dst.ndim = 5;
  835. dst[0] = src[0];
  836. auto oc = cflt.ocpg * cflt.group;
  837. megdnn_assert(oc % 4 == 0);
  838. dst[1] = oc / 4;
  839. dst[2] = infer_conv_shape(
  840. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  841. dst[3] = infer_conv_shape(
  842. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  843. dst[4] = 4;
  844. if (cflt.group == 1) {
  845. megdnn_assert(
  846. cflt.icpg * cflt.group == src[1] * 4 ||
  847. (cflt.icpg * cflt.group == src[1]),
  848. "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
  849. }
  850. } else if (param().format == Param::Format::CHWN4) {
  851. megdnn_assert(
  852. src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu",
  853. src.ndim);
  854. megdnn_assert(
  855. cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u",
  856. errmsg().c_str(), cflt.icpg, cflt.group);
  857. dst.ndim = src.ndim;
  858. dst[3] = src[3];
  859. auto oc = cflt.ocpg * cflt.group;
  860. megdnn_assert(oc % 4 == 0);
  861. dst[0] = oc / 4;
  862. dst[1] = infer_conv_shape(
  863. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  864. dst[2] = infer_conv_shape(
  865. src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  866. dst[4] = 4;
  867. } else if (param().format == Param::Format::NCHW4_NCHW) {
  868. megdnn_assert(
  869. src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
  870. src.ndim);
  871. megdnn_assert(
  872. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  873. errmsg().c_str(), cflt.icpg, cflt.group);
  874. dst.ndim = 4;
  875. dst[0] = src[0];
  876. auto oc = cflt.ocpg * cflt.group;
  877. dst[1] = oc;
  878. dst[2] = infer_conv_shape(
  879. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  880. dst[3] = infer_conv_shape(
  881. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  882. } else if (param().format == Param::Format::NCHW4_NHWC) {
  883. megdnn_assert(
  884. src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu",
  885. src.ndim);
  886. megdnn_assert(
  887. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  888. errmsg().c_str(), cflt.icpg, cflt.group);
  889. dst.ndim = 4;
  890. dst[0] = src[0];
  891. dst[1] = infer_conv_shape(
  892. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  893. dst[2] = infer_conv_shape(
  894. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  895. auto oc = cflt.ocpg * cflt.group;
  896. dst[3] = oc;
  897. } else if (param().format == Param::Format::NCHW4_NCHW32) {
  898. megdnn_assert(
  899. src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
  900. src.ndim);
  901. megdnn_assert(
  902. cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u",
  903. errmsg().c_str(), cflt.icpg, cflt.group);
  904. dst.ndim = src.ndim;
  905. dst[0] = src[0];
  906. auto oc = cflt.ocpg * cflt.group;
  907. megdnn_assert(oc % 32 == 0);
  908. dst[1] = oc / 32;
  909. dst[2] = infer_conv_shape(
  910. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  911. dst[3] = infer_conv_shape(
  912. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  913. dst[4] = 32;
  914. } else if (param().format == Param::Format::NCHW32_NCHW4) {
  915. megdnn_assert(
  916. src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
  917. src.ndim);
  918. megdnn_assert(
  919. cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u",
  920. errmsg().c_str(), cflt.icpg, cflt.group);
  921. dst.ndim = src.ndim;
  922. dst[0] = src[0];
  923. auto oc = cflt.ocpg * cflt.group;
  924. megdnn_assert(oc % 4 == 0);
  925. dst[1] = oc / 4;
  926. dst[2] = infer_conv_shape(
  927. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  928. dst[3] = infer_conv_shape(
  929. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  930. dst[4] = 4;
  931. } else if (param().format == Param::Format::NCHW64) {
  932. megdnn_assert(
  933. src.ndim == 5, "invalid src ndim for NCHW64, expected=5, got=%zu",
  934. src.ndim);
  935. megdnn_assert(
  936. cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u",
  937. errmsg().c_str(), cflt.icpg, cflt.group);
  938. dst.ndim = src.ndim;
  939. dst[0] = src[0];
  940. auto oc = cflt.ocpg * cflt.group;
  941. megdnn_assert(oc % 64 == 0);
  942. dst[1] = oc / 64;
  943. dst[2] = infer_conv_shape(
  944. src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  945. dst[3] = infer_conv_shape(
  946. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  947. dst[4] = 64;
  948. } else {
  949. megdnn_assert(param().format == Param::Format::NHWCD4);
  950. megdnn_assert(
  951. src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu",
  952. src.ndim);
  953. megdnn_assert(
  954. cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u",
  955. errmsg().c_str(), cflt.icpg, cflt.group);
  956. dst.ndim = src.ndim;
  957. dst[0] = src[0];
  958. auto oc = cflt.ocpg * cflt.group;
  959. megdnn_assert(oc % 4 == 0);
  960. dst[2] = oc / 4;
  961. dst[1] = infer_conv_shape(
  962. src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  963. dst[3] = infer_conv_shape(
  964. src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  965. megdnn_assert(src[4] == 4);
  966. dst[4] = 4;
  967. }
  968. if (!src.format.is_default() && !src.format.is_lowbit_aligned()) { // propagate
  969. dst.format = src.format;
  970. } else { // determined by dtype
  971. dst.format = TensorFormat(dst.dtype);
  972. }
  973. dst.init_contiguous_stride();
  974. return cflt;
  975. }
  976. /**
  977. * \warning: An explicit specialization shall be declared in a namespace
  978. * enclosing the specialized template. An explicit specialization whose
  979. * declarator-id is not qualified shall be declared in the nearest enclosing
  980. * namespace of the template, or, if the namespace is inline (7.3.1), any
  981. * namespace from its enclosing namespace set.
  982. * refer to:
  983. * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace
  984. */
  985. template <>
  986. ConvolutionBase<param::Convolution>::CanonizedFilterMeta ConvolutionBase<
  987. param::Convolution>::
  988. check_layout_fwd(
  989. const TensorLayout& src, const TensorLayout& filter,
  990. const TensorLayout& dst) const {
  991. megdnn_assert_contiguous(src);
  992. megdnn_assert_contiguous(filter);
  993. TensorLayout dst_expected;
  994. dst_expected.dtype = dst.dtype;
  995. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  996. megdnn_assert_eq_layout(dst_expected, dst);
  997. return ret;
  998. }
  999. template <>
  1000. ConvolutionBase<param::ConvBias>::CanonizedFilterMeta ConvolutionBase<param::ConvBias>::
  1001. check_layout_fwd(
  1002. const TensorLayout& src, const TensorLayout& filter,
  1003. const TensorLayout& dst) const {
  1004. megdnn_assert_contiguous(src);
  1005. megdnn_assert_contiguous(filter);
  1006. TensorLayout dst_expected;
  1007. dst_expected.dtype = dst.dtype;
  1008. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1009. megdnn_assert_eq_layout(dst_expected, dst);
  1010. return ret;
  1011. }
  1012. template <>
  1013. ConvolutionBase<param::BatchConvBias>::CanonizedFilterMeta ConvolutionBase<
  1014. param::BatchConvBias>::
  1015. check_layout_fwd(
  1016. const TensorLayout& src, const TensorLayout& filter,
  1017. const TensorLayout& dst) const {
  1018. megdnn_assert_contiguous(src);
  1019. megdnn_assert_contiguous(filter);
  1020. TensorLayout dst_expected;
  1021. dst_expected.dtype = dst.dtype;
  1022. auto ret = deduce_layout_fwd(src, filter, dst_expected);
  1023. megdnn_assert_eq_layout(dst_expected, dst);
  1024. return ret;
  1025. }
//! Deduce (or, if \p dst is already valid, validate) the output dtype of
//! forward convolution from the src and filter dtypes.
void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) {
    check_or_deduce_dtype_fwd(src, filter, dst);
}
//! Deduce the output layout of forward convolution from src and filter;
//! the result is written into \p dst.
void ConvolutionForward::deduce_layout(
        const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) {
    deduce_layout_fwd(src, filter, dst);
}
  1033. ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec(
  1034. const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst,
  1035. size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) {
  1036. auto ret = check_layout_fwd(src, filter, dst);
  1037. auto required_workspace_in_bytes =
  1038. get_workspace_in_bytes(src, filter, dst, preprocessed_filter);
  1039. megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
  1040. return ret;
  1041. }
/**
 * \brief Pre-execution check for backward data.
 *
 * Validation is done by mapping the backward problem onto the forward one:
 * \p grad plays the role of the forward src and \p diff the forward dst.
 * Returns the canonized filter meta.
 */
ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec(
        const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    // local copies so the caller's layouts are untouched
    auto grad_fwd = grad;
    auto filter_fwd = filter;
    auto diff_fwd = diff;
    // swap dtypes so the forward checker sees (src-dtype, dst-dtype) in the
    // forward order (e.g. Int8 input with Int32 output, see deduce_dtype)
    std::swap(grad_fwd.dtype, diff_fwd.dtype);
    grad_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd);
    // workspace is queried with the original (backward) layouts
    auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}
/**
 * \brief Deduce (or validate) the gradient dtype for backward data.
 *
 * Supported combinations:
 *   - float filter/diff of the same category -> grad takes the filter dtype
 *   - Int8 x Int8 -> Int32 grad
 *   - QuantizedS8 x QuantizedS8 or Quantized8Asymm x Quantized8Asymm ->
 *     QuantizedS32 whose scale is the product of the input scales; a
 *     caller-supplied grad with the same enum as diff is also accepted
 * Any other combination throws.
 */
void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) {
    // candidate dtypes grad is allowed to take
    SmallVector<DType> supported_dst_dtype;
    if (filter.category() == diff.category() &&
        filter.category() == DTypeCategory::FLOAT) {
        supported_dst_dtype.push_back(filter);
    } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) {
        supported_dst_dtype.push_back(dtype::Int32());
    } else if (
            (filter.enumv() == DTypeEnum::QuantizedS8 &&
             diff.enumv() == DTypeEnum::QuantizedS8) ||
            (filter.enumv() == DTypeEnum::Quantized8Asymm &&
             diff.enumv() == DTypeEnum::Quantized8Asymm)) {
        // canonical result: QuantizedS32 with scale = filter_scale * diff_scale
        supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff)));
        if (grad.valid() && grad.enumv() == diff.enumv()) {
            supported_dst_dtype.push_back(grad);
        }
    } else {
        megdnn_throw(ssprintf(
                "unsupported input / diff DType: %s x %s", filter.name(), diff.name()));
    }
    if (!grad.valid()) {
        // no dtype requested: pick the first (canonical) candidate
        grad = supported_dst_dtype.at(0);
    } else {
        // a requested dtype must be one of the supported candidates
        megdnn_assert(
                vec_contains(supported_dst_dtype, grad),
                "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(),
                grad.name());
    }
    // ComputeMode::FLOAT32 is only meaningful for half-precision inputs
    megdnn_assert(
            param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16
            || filter.enumv() == DTypeEnum::Float16 ||
            filter.enumv() == DTypeEnum::BFloat16
#endif
            ,
            "ComputeMode::FLOAT32 is only available for Float16/BFloat16 "
            "input / output.");
}
  1094. void ConvolutionBackwardData::deduce_layout(
  1095. const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) {
  1096. auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); };
  1097. MEGDNN_MARK_USED_VAR(errmsg);
  1098. megdnn_assert_contiguous(filter);
  1099. megdnn_assert_contiguous(diff);
  1100. megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
  1101. megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
  1102. deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
  1103. auto cflt = make_canonized_filter_meta(diff.ndim, filter);
  1104. auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) {
  1105. MEGDNN_MARK_USED_VAR(errmsg);
  1106. auto i = (out - 1) * stride + filter;
  1107. megdnn_assert(i > pad * 2, "%s", errmsg().c_str());
  1108. return i - pad * 2;
  1109. };
  1110. if (param().format == Param::Format::NCHW ||
  1111. param().format == Param::Format::NHWC) {
  1112. size_t src_or_dst_c_pos = 0;
  1113. size_t src_or_dst_spatial_start = 0;
  1114. if (param().format == Param::Format::NCHW) {
  1115. src_or_dst_c_pos = 1;
  1116. src_or_dst_spatial_start = 2;
  1117. } else {
  1118. megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
  1119. src_or_dst_c_pos = 3;
  1120. src_or_dst_spatial_start = 1;
  1121. }
  1122. megdnn_assert(
  1123. cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s",
  1124. errmsg().c_str());
  1125. grad.ndim = diff.ndim;
  1126. grad[0] = diff[0];
  1127. grad[src_or_dst_c_pos] = cflt.icpg * cflt.group;
  1128. for (size_t i = 0; i < cflt.spatial_ndim; ++i) {
  1129. grad[i + src_or_dst_spatial_start] =
  1130. deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
  1131. cflt.stride[i], cflt.padding[i]);
  1132. }
  1133. } else if (param().format == Param::Format::NCHW4) {
  1134. megdnn_assert(
  1135. diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
  1136. diff.ndim);
  1137. megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
  1138. megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
  1139. grad.ndim = diff.ndim;
  1140. grad[0] = diff[0];
  1141. auto ic = cflt.icpg * cflt.group;
  1142. megdnn_assert(ic % 4 == 0);
  1143. grad[1] = ic / 4;
  1144. grad[2] = deduce(
  1145. diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1146. grad[3] = deduce(
  1147. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1148. megdnn_assert(diff[4] == 4);
  1149. grad[4] = 4;
  1150. } else {
  1151. megdnn_assert(param().format == Param::Format::NHWCD4);
  1152. megdnn_assert(
  1153. diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu",
  1154. diff.ndim);
  1155. megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str());
  1156. grad.ndim = diff.ndim;
  1157. grad[0] = diff[0];
  1158. auto ic = cflt.icpg * cflt.group;
  1159. megdnn_assert(ic % 4 == 0);
  1160. grad[2] = ic / 4;
  1161. grad[1] = deduce(
  1162. diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]);
  1163. grad[3] = deduce(
  1164. diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
  1165. megdnn_assert(diff[4] == 4);
  1166. grad[4] = 4;
  1167. }
  1168. grad.format = diff.format;
  1169. grad.init_contiguous_stride();
  1170. }
/**
 * \brief Pre-execution check for backward filter.
 *
 * Validation is done by mapping the problem onto the forward one: \p grad
 * (the filter gradient) plays the role of the forward filter and \p diff
 * the forward dst.  Only float dtypes are supported here.
 */
ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    megdnn_assert(
            src.dtype.category() == DTypeCategory::FLOAT &&
                    diff.dtype.category() == DTypeCategory::FLOAT &&
                    grad.dtype.category() == DTypeCategory::FLOAT,
            "only float type is supported for conv backward filter");
    // local copies so the caller's layouts stay untouched
    auto src_fwd = src;
    auto diff_fwd = diff;
    src_fwd.init_contiguous_stride();
    diff_fwd.init_contiguous_stride();
    // forward mapping: (src, filter=grad, dst=diff)
    auto ret = check_layout_fwd(src_fwd, grad, diff_fwd);
    // workspace is queried with the original (backward) layouts
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}
  1188. } // namespace megdnn
  1189. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台