You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_conv.cpp 17 kB


  1. /**
  2. * \file dnn/src/fallback/convolution/run_conv.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/fallback/convolution/run_conv.h"
  12. #include "src/common/utils.h"
  13. #include "midout.h"
  14. MIDOUT_DECL(megdnn_fallback_conv)
  15. namespace {
  /**
   * Tells whether a kernel-size-specialised implementation exists for this
   * problem: only square filters with a side length in [1, 7] have one.
   * All other geometry arguments are accepted unconditionally.
   */
  bool can_run_xcorr_single_channel_templated(
          size_t /* IH */, size_t /* IW */, size_t FH, size_t FW,
          size_t /* OH */, size_t /* OW */, size_t /* PH */, size_t /* PW */,
          size_t /* SH */, size_t /* SW */) {
      const size_t max_templated_kernel = 7;
      if (FH != FW)
          return false;
      return FH != 0 && FH <= max_templated_kernel;
  }
  /**
   * Single-channel 2-D cross-correlation with a ker_size x ker_size filter
   * whose size is known at compile time:
   *
   *   dst(oh, ow) = sum_{fh, fw} filter(fh, fw) *
   *                 src(oh*SH + fh - PH, ow*SW + fw - PW)
   *
   * Taps that land outside the IH x IW input (in the zero padding) are
   * skipped. All planes are row-major.
   *
   * \param src      IH x IW input plane
   * \param filter   ker_size x ker_size filter
   * \param dst      OH x OW output plane
   * \param PH, PW   zero padding before the first row / column
   * \param SH, SW   vertical / horizontal stride
   * \param add_mode accumulate into dst instead of overwriting it
   */
  template <int ker_size>
  void run_xcorr_single_channel_templated_impl(
          const float* __restrict src, const float* __restrict filter,
          float* __restrict dst, size_t IH, size_t IW, size_t OH, size_t OW,
          size_t PH, size_t PW, size_t SH, size_t SW, bool add_mode) {
      const size_t K = ker_size;
      // Output rows [oh_start, oh_end) and columns [ow_start, ow_end) form the
      // "good region" in which the whole window lies inside the input:
      //   oh*SH - PH >= 0   and   oh*SH - PH + K <= IH   (likewise for ow)
      // Pixels there need no bounds checks.
      size_t oh_start = (PH + SH - 1) / SH;
      size_t oh_end = IH + PH >= K ? (IH + PH - K) / SH + 1 : 0;
      size_t ow_start = (PW + SW - 1) / SW;
      size_t ow_end = IW + PW >= K ? (IW + PW - K) / SW + 1 : 0;
      // Never let the unchecked region run past the caller's output extent.
      if (oh_end > OH)
          oh_end = OH;
      if (ow_end > OW)
          ow_end = OW;
      if (oh_start > oh_end)
          oh_start = oh_end = 0;
      if (ow_start > ow_end)
          ow_start = ow_end = 0;

      // Computes one output pixel. Bounds are checked only along the axes whose
      // flag is set; out-of-range taps contribute nothing. Accumulating in a
      // local is equivalent to updating dst in place because the pointers are
      // declared __restrict.
      auto compute = [&](size_t oh, size_t ow, bool check_h, bool check_w) {
          float acc = add_mode ? dst[oh * OW + ow] : 0.f;
          // Top-left input coordinate of the window; negative inside padding.
          const int ih = static_cast<int>(oh * SH) - static_cast<int>(PH);
          const int iw = static_cast<int>(ow * SW) - static_cast<int>(PW);
          for (int fh = 0; fh < ker_size; ++fh) {
              const int y = ih + fh;
              if (check_h && (y < 0 || y >= static_cast<int>(IH)))
                  continue;
              const float* src_row = src + static_cast<size_t>(y) * IW;
              const float* flt_row = filter + fh * ker_size;
              for (int fw = 0; fw < ker_size; ++fw) {
                  const int x = iw + fw;
                  if (check_w && (x < 0 || x >= static_cast<int>(IW)))
                      continue;
                  acc += flt_row[fw] * src_row[x];
              }
          }
          dst[oh * OW + ow] = acc;
      };

      for (size_t oh = 0; oh < OH; ++oh) {
          if (oh < oh_start || oh >= oh_end) {
              // Window overlaps the top or bottom padding: check both axes.
              for (size_t ow = 0; ow < OW; ++ow)
                  compute(oh, ow, true, true);
              continue;
          }
          for (size_t ow = 0; ow < ow_start; ++ow)
              compute(oh, ow, false, true);
          for (size_t ow = ow_start; ow < ow_end; ++ow)
              compute(oh, ow, false, false);
          for (size_t ow = ow_end; ow < OW; ++ow)
              compute(oh, ow, false, true);
      }
  }
  106. void run_xcorr_single_channel_templated(
  107. const float *src, const float *filter, float *dst,
  108. size_t IH, size_t IW, size_t FH, size_t FW,
  109. size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, size_t SW,
  110. bool add_mode)
  111. {
  112. (void)FW;
  113. #define DISPATCH(ker_size) \
  114. if (FH == ker_size) { \
  115. MIDOUT_BEGIN(megdnn_fallback_conv, ker_size) { \
  116. run_xcorr_single_channel_templated_impl<ker_size>( \
  117. src, filter, dst, \
  118. IH, IW, OH, OW, PH, PW, SH, SW, add_mode); \
  119. } MIDOUT_END(); \
  120. return; \
  121. }
  122. DISPATCH(1)
  123. DISPATCH(2)
  124. DISPATCH(3)
  125. DISPATCH(4)
  126. DISPATCH(5)
  127. DISPATCH(6)
  128. DISPATCH(7)
  129. #undef DISPATCH
  130. megdnn_throw("internal error in conv template dispatching: impossible");
  131. }
  /**
   * Single-channel 2-D cross-correlation for an arbitrary FH_ x FW_ filter:
   *
   *   dst(oh, ow) = sum_{fh, fw} filter(fh, fw) *
   *                 src(oh*SH + fh - PH, ow*SW + fw - PW)
   *
   * Taps that land outside the IH x IW input (in the zero padding) are
   * skipped. All planes are row-major.
   *
   * \param src      IH x IW input plane
   * \param filter   FH_ x FW_ filter
   * \param dst      OH x OW output plane
   * \param PH, PW   zero padding before the first row / column
   * \param SH, SW   vertical / horizontal stride
   * \param add_mode accumulate into dst instead of overwriting it
   */
  void run_xcorr_single_channel_nontemplated(
          const float* src, const float* filter, float* dst, size_t IH,
          size_t IW, size_t FH_, size_t FW_, size_t OH, size_t OW, size_t PH,
          size_t PW, size_t SH, size_t SW, bool add_mode) {
      const int FH = static_cast<int>(FH_);
      const int FW = static_cast<int>(FW_);
      // Output rows [oh_start, oh_end) and columns [ow_start, ow_end) form the
      // "good region" in which the whole window lies inside the input:
      //   oh*SH - PH >= 0   and   oh*SH - PH + FH <= IH   (likewise for ow)
      // Pixels there need no bounds checks.
      size_t oh_start = (PH + SH - 1) / SH;
      size_t oh_end = IH + PH >= FH_ ? (IH + PH - FH_) / SH + 1 : 0;
      size_t ow_start = (PW + SW - 1) / SW;
      size_t ow_end = IW + PW >= FW_ ? (IW + PW - FW_) / SW + 1 : 0;
      // Never let the unchecked region run past the caller's output extent.
      if (oh_end > OH)
          oh_end = OH;
      if (ow_end > OW)
          ow_end = OW;
      if (oh_start > oh_end)
          oh_start = oh_end = 0;
      if (ow_start > ow_end)
          ow_start = ow_end = 0;

      // Computes one output pixel. Bounds are checked only along the axes whose
      // flag is set; out-of-range taps contribute nothing. dst is updated in
      // place since these pointers carry no no-alias guarantee.
      auto compute = [&](size_t oh, size_t ow, bool check_h, bool check_w) {
          float& out = dst[oh * OW + ow];
          if (!add_mode)
              out = 0.f;
          // Top-left input coordinate of the window; negative inside padding.
          const int ih = static_cast<int>(oh * SH) - static_cast<int>(PH);
          const int iw = static_cast<int>(ow * SW) - static_cast<int>(PW);
          for (int fh = 0; fh < FH; ++fh) {
              const int y = ih + fh;
              if (check_h && (y < 0 || y >= static_cast<int>(IH)))
                  continue;
              const float* src_row = src + static_cast<size_t>(y) * IW;
              const float* flt_row = filter + fh * FW;
              for (int fw = 0; fw < FW; ++fw) {
                  const int x = iw + fw;
                  if (check_w && (x < 0 || x >= static_cast<int>(IW)))
                      continue;
                  out += flt_row[fw] * src_row[x];
              }
          }
      };

      for (size_t oh = 0; oh < OH; ++oh) {
          if (oh < oh_start || oh >= oh_end) {
              // Window overlaps the top or bottom padding: check both axes.
              for (size_t ow = 0; ow < OW; ++ow)
                  compute(oh, ow, true, true);
              continue;
          }
          for (size_t ow = 0; ow < ow_start; ++ow)
              compute(oh, ow, false, true);
          for (size_t ow = ow_start; ow < ow_end; ++ow)
              compute(oh, ow, false, false);
          for (size_t ow = ow_end; ow < OW; ++ow)
              compute(oh, ow, false, true);
      }
  }
  212. void run_xcorr_single_channel(const float *src, const float *filter, float *dst,
  213. size_t IH, size_t IW,
  214. size_t FH, size_t FW,
  215. size_t OH, size_t OW,
  216. size_t PH, size_t PW,
  217. size_t SH, size_t SW,
  218. bool add_mode)
  219. {
  220. if (can_run_xcorr_single_channel_templated(IH, IW, FH, FW, OH, OW,
  221. PH, PW, SH, SW)) {
  222. run_xcorr_single_channel_templated(src, filter, dst,
  223. IH, IW, FH, FW, OH, OW, PH, PW, SH, SW,
  224. add_mode);
  225. } else {
  226. MIDOUT_BEGIN(megdnn_fallback_conv, void) {
  227. run_xcorr_single_channel_nontemplated(src, filter, dst,
  228. IH, IW, FH, FW, OH, OW, PH, PW, SH, SW,
  229. add_mode);
  230. } MIDOUT_END();
  231. }
  232. }
  233. /*================ ConvolutionBackwardData =============*/
  /**
   * Backward-data ("scatter") pass of a single-channel cross-correlation with
   * a ker_size x ker_size filter known at compile time:
   *
   *   grad(ih*SH - PH + fh, iw*SW - PW + fw) += filter(fh, fw) * diff(ih, iw)
   *
   * Targets that fall outside the OH x OW grad plane (the padding) are
   * dropped. grad is only accumulated into, never cleared here.
   *
   * Naming: IH x IW is the extent of diff (the forward output) and OH x OW
   * the extent of grad (the forward input). All planes are row-major.
   */
  template <int ker_size>
  void conv_backdata_single_channel_templated_impl(
          const float* __restrict diff, const float* __restrict filter,
          float* __restrict grad, size_t IH, size_t IW, size_t OH, size_t OW,
          size_t PH, size_t PW, size_t SH, size_t SW) {
      const size_t K = ker_size;
      // diff rows [ih_start, ih_end) and columns [iw_start, iw_end) scatter
      // their whole window inside grad, so they need no bounds checks.
      size_t ih_start = (PH + SH - 1) / SH;
      size_t ih_end = OH + PH >= K ? (OH + PH - K) / SH + 1 : 0;
      size_t iw_start = (PW + SW - 1) / SW;
      size_t iw_end = OW + PW >= K ? (OW + PW - K) / SW + 1 : 0;
      // Never let the unchecked region run past the diff extent.
      if (ih_end > IH)
          ih_end = IH;
      if (iw_end > IW)
          iw_end = IW;
      if (ih_start > ih_end)
          ih_start = ih_end = 0;
      if (iw_start > iw_end)
          iw_start = iw_end = 0;

      // Scatters diff(ih, iw) through the filter into grad. Bounds are checked
      // only along the axes whose flag is set.
      auto scatter = [&](size_t ih, size_t iw, bool check_h, bool check_w) {
          const float val = diff[ih * IW + iw];
          // Top-left grad coordinate of the window; negative inside padding.
          const int oh = static_cast<int>(ih * SH) - static_cast<int>(PH);
          const int ow = static_cast<int>(iw * SW) - static_cast<int>(PW);
          for (int fh = 0; fh < ker_size; ++fh) {
              const int y = oh + fh;
              if (check_h && (y < 0 || y >= static_cast<int>(OH)))
                  continue;
              float* grad_row = grad + static_cast<size_t>(y) * OW;
              const float* flt_row = filter + fh * ker_size;
              for (int fw = 0; fw < ker_size; ++fw) {
                  const int x = ow + fw;
                  if (check_w && (x < 0 || x >= static_cast<int>(OW)))
                      continue;
                  grad_row[x] += flt_row[fw] * val;
              }
          }
      };

      // diff is visited in row-major order, which fixes the floating-point
      // summation order into each grad element.
      for (size_t ih = 0; ih < IH; ++ih) {
          if (ih < ih_start || ih >= ih_end) {
              // Window overlaps the top or bottom padding: check both axes.
              for (size_t iw = 0; iw < IW; ++iw)
                  scatter(ih, iw, true, true);
              continue;
          }
          for (size_t iw = 0; iw < iw_start; ++iw)
              scatter(ih, iw, false, true);
          for (size_t iw = iw_start; iw < iw_end; ++iw)
              scatter(ih, iw, false, false);
          for (size_t iw = iw_end; iw < IW; ++iw)
              scatter(ih, iw, false, true);
      }
  }
  305. void conv_backdata_single_channel_templated(
  306. const float *src, const float *filter, float *dst,
  307. size_t IH, size_t IW, size_t FH, size_t FW,
  308. size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, size_t SW)
  309. {
  310. megdnn_ignore(FW);
  311. #define DISPATCH(ker_size) \
  312. if (FH == ker_size) { \
  313. MIDOUT_BEGIN(megdnn_fallback_conv, ker_size) { \
  314. conv_backdata_single_channel_templated_impl<ker_size>( \
  315. src, filter, dst, \
  316. IH, IW, OH, OW, PH, PW, SH, SW); \
  317. } MIDOUT_END(); \
  318. return; \
  319. }
  320. DISPATCH(1)
  321. DISPATCH(2)
  322. DISPATCH(3)
  323. DISPATCH(4)
  324. DISPATCH(5)
  325. DISPATCH(6)
  326. DISPATCH(7)
  327. #undef DISPATCH
  328. megdnn_throw(
  329. "internal error in conv_backdata template dispatching: impossible");
  330. }
  /**
   * Backward-data ("scatter") pass of a single-channel cross-correlation for
   * an arbitrary FH_ x FW_ filter:
   *
   *   grad(ih*SH - PH + fh, iw*SW - PW + fw) += filter(fh, fw) * diff(ih, iw)
   *
   * Targets that fall outside the OH x OW grad plane (the padding) are
   * dropped. grad is only accumulated into, never cleared here.
   *
   * Naming: IH x IW is the extent of diff (the forward output) and OH x OW
   * the extent of grad (the forward input). All planes are row-major.
   */
  void conv_backdata_single_channel_nontemplated(
          const float* diff, const float* filter, float* grad, size_t IH,
          size_t IW, size_t FH_, size_t FW_, size_t OH, size_t OW, size_t PH,
          size_t PW, size_t SH, size_t SW) {
      const int FH = static_cast<int>(FH_);
      const int FW = static_cast<int>(FW_);
      // diff rows [ih_start, ih_end) and columns [iw_start, iw_end) scatter
      // their whole window inside grad, so they need no bounds checks.
      size_t ih_start = (PH + SH - 1) / SH;
      size_t ih_end = OH + PH >= FH_ ? (OH + PH - FH_) / SH + 1 : 0;
      size_t iw_start = (PW + SW - 1) / SW;
      size_t iw_end = OW + PW >= FW_ ? (OW + PW - FW_) / SW + 1 : 0;
      // Never let the unchecked region run past the diff extent.
      if (ih_end > IH)
          ih_end = IH;
      if (iw_end > IW)
          iw_end = IW;
      if (ih_start > ih_end)
          ih_start = ih_end = 0;
      if (iw_start > iw_end)
          iw_start = iw_end = 0;

      // Scatters diff(ih, iw) through the filter into grad. Bounds are checked
      // only along the axes whose flag is set.
      auto scatter = [&](size_t ih, size_t iw, bool check_h, bool check_w) {
          const float val = diff[ih * IW + iw];
          // Top-left grad coordinate of the window; negative inside padding.
          const int oh = static_cast<int>(ih * SH) - static_cast<int>(PH);
          const int ow = static_cast<int>(iw * SW) - static_cast<int>(PW);
          for (int fh = 0; fh < FH; ++fh) {
              const int y = oh + fh;
              if (check_h && (y < 0 || y >= static_cast<int>(OH)))
                  continue;
              float* grad_row = grad + static_cast<size_t>(y) * OW;
              const float* flt_row = filter + fh * FW;
              for (int fw = 0; fw < FW; ++fw) {
                  const int x = ow + fw;
                  if (check_w && (x < 0 || x >= static_cast<int>(OW)))
                      continue;
                  grad_row[x] += flt_row[fw] * val;
              }
          }
      };

      // diff is visited in row-major order, which fixes the floating-point
      // summation order into each grad element.
      for (size_t ih = 0; ih < IH; ++ih) {
          if (ih < ih_start || ih >= ih_end) {
              // Window overlaps the top or bottom padding: check both axes.
              for (size_t iw = 0; iw < IW; ++iw)
                  scatter(ih, iw, true, true);
              continue;
          }
          for (size_t iw = 0; iw < iw_start; ++iw)
              scatter(ih, iw, false, true);
          for (size_t iw = iw_start; iw < iw_end; ++iw)
              scatter(ih, iw, false, false);
          for (size_t iw = iw_end; iw < IW; ++iw)
              scatter(ih, iw, false, true);
      }
  }
  401. void conv_backdata_single_channel(const float *diff, const float *filter, float *grad,
  402. size_t IH, size_t IW,
  403. size_t FH, size_t FW,
  404. size_t OH, size_t OW,
  405. size_t PH, size_t PW,
  406. size_t SH, size_t SW)
  407. {
  408. if (can_run_xcorr_single_channel_templated(IH, IW, FH, FW, OH, OW,
  409. PH, PW, SH, SW)) {
  410. conv_backdata_single_channel_templated(diff, filter, grad,
  411. IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
  412. } else {
  413. MIDOUT_BEGIN(megdnn_fallback_conv, void) {
  414. conv_backdata_single_channel_nontemplated(diff, filter, grad,
  415. IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
  416. } MIDOUT_END();
  417. }
  418. }
  419. } // anonymous namespace
  420. namespace megdnn {
  421. namespace fallback {
  422. namespace convolution {
  423. void run_conv(const float *src, const float *filter, float *dst, void *workspace,
  424. size_t IH, size_t IW, size_t IC,
  425. size_t FH, size_t FW,
  426. size_t OH, size_t OW, size_t OC,
  427. size_t PH, size_t PW,
  428. size_t SH, size_t SW,
  429. bool xcorr)
  430. {
  431. for (size_t oc = 0; oc < OC; ++oc)
  432. for (size_t ic = 0; ic < IC; ++ic)
  433. {
  434. // ut for untransposed
  435. const float *fut = filter + oc*IC*FH*FW + ic*FH*FW;
  436. const float *f;
  437. if (!xcorr) {
  438. // need transpose
  439. f = (float *)workspace;
  440. for (size_t fh = 0; fh < FH; ++fh)
  441. for (size_t fw = 0; fw < FW; ++fw)
  442. {
  443. ((float *)f)[fh*FW + fw] = fut[(FH-fh-1)*FW + (FW-fw-1)];
  444. }
  445. } else {
  446. // do not need transpose
  447. f = fut;
  448. }
  449. run_xcorr_single_channel(src + ic*IH*IW, f, dst + oc*OH*OW,
  450. IH, IW, FH, FW, OH, OW, PH, PW, SH, SW,
  451. ic > 0);
  452. }
  453. }
  454. void run_conv_backward_data(const float* diff, const float* filter, float* grad,
  455. void* workspace, size_t IH, size_t IW, size_t IC,
  456. size_t FH, size_t FW, size_t OH, size_t OW,
  457. size_t OC, size_t PH, size_t PW, size_t SH,
  458. size_t SW, bool xcorr) {
  459. std::memset(grad, 0, sizeof(float) * IC * OH * OW);
  460. for (size_t oc = 0; oc < OC; ++oc)
  461. for (size_t ic = 0; ic < IC; ++ic) {
  462. // ut for untransposed
  463. const float* fut = filter + oc * IC * FH * FW + ic * FH * FW;
  464. const float* f;
  465. if (!xcorr) {
  466. // need transpose
  467. f = (float*)workspace;
  468. for (size_t fh = 0; fh < FH; ++fh)
  469. for (size_t fw = 0; fw < FW; ++fw) {
  470. ((float*)f)[fh * FW + fw] =
  471. fut[(FH - fh - 1) * FW + (FW - fw - 1)];
  472. }
  473. } else {
  474. // do not need transpose
  475. f = fut;
  476. }
  477. conv_backdata_single_channel(diff + oc * IH * IW, f,
  478. grad + ic * OH * OW, IH, IW, FH, FW,
  479. OH, OW, PH, PW, SH, SW);
  480. }
  481. }
  482. } // namespace convolution
  483. } // namespace fallback
  484. } // namespace megdnn
  485. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。