You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

inplace_matmul_impl.cpp.hip 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. /**
  2. * \file src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip
  3. *
  4. * This file is part of MegDNN, a deep neural network run-time library
  5. * developed by Megvii.
  6. *
  7. * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
  8. */
  9. #include "./inplace_matmul_impl.h.hip"
  10. #include "src/rocm/utils.h.hip"
  11. using namespace megdnn;
  12. using namespace rocm;
  13. namespace {
  14. struct BufferFetcherTexture {
  15. hipTextureObject_t tex;
  16. __device__ __forceinline__ float get(uint32_t offset) {
  17. return tex1Dfetch<float>(tex, offset);
  18. }
  19. };
  20. struct BufferFetcherRaw {
  21. const float* ptr;
  22. __device__ __forceinline__ float get(uint32_t offset) {
  23. return ptr[offset];
  24. }
  25. };
  26. struct BufferFetcherTextureHost {
  27. bool init_succ;
  28. BufferFetcherTexture val;
  29. BufferFetcherTextureHost(float* p, const size_t n);
  30. ~BufferFetcherTextureHost() { reset(); }
  31. void reset() {
  32. if (init_succ) {
  33. hip_check(hipDestroyTextureObject(val.tex));
  34. init_succ = false;
  35. }
  36. }
  37. };
  38. BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) {
  39. init_succ = false;
  40. hipTextureObject_t tex_obj;
  41. hipResourceDesc res_desc;
  42. memset(&res_desc, 0, sizeof(hipResourceDesc));
  43. res_desc.resType = hipResourceTypeLinear;
  44. res_desc.res.linear.devPtr = static_cast<void*>(p);
  45. res_desc.res.linear.sizeInBytes = n * sizeof(float);
  46. res_desc.res.linear.desc =
  47. hipCreateChannelDesc(32, 0, 0, 0, hipChannelFormatKindFloat);
  48. hipTextureDesc tex_desc;
  49. memset(&tex_desc, 0, sizeof(hipTextureDesc));
  50. if (hipCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) ==
  51. hipSuccess) {
  52. val.tex = tex_obj;
  53. init_succ = true;
  54. } else {
  55. hipGetLastError(); // reset error
  56. }
  57. }
  58. template <class BufferFetcher>
  59. struct KernelPtr {
  60. typedef void (*type)(BufferFetcher, BufferFetcher, float*, uint32_t,
  61. uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
  62. uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
  63. uint32_t, uint32_t, uint32_t);
  64. };
  65. //! 1 -> 0xffffffff, 0 -> 0x00000000
  66. __device__ __forceinline__ uint32_t bool_as_mask(uint32_t cond) {
  67. return (!cond) - 1u;
  68. }
//! Lets a float's bit pattern be manipulated as a uint32_t, so a loaded
//! value can be zeroed with a bitwise AND instead of a branch (see
//! visit_with_mask below).
//! NOTE(review): union type punning is formally UB in ISO C++ but is the
//! conventional, compiler-supported idiom in CUDA/HIP device code.
union FloatAndU32 {
    float f;
    uint32_t u;
};
//! \p mask must be either all 1 or 0 bits
//!
//! Branch-free masked load: with an all-ones mask this returns
//! buf.get(offset) unchanged; with an all-zeros mask the offset is forced
//! to 0 (so the fetch stays at a valid element instead of going
//! out-of-bounds) and the loaded bits are then cleared, yielding 0.0f.
template <class BufferFetcher>
__device__ __forceinline__ float visit_with_mask(BufferFetcher buf,
                                                 uint32_t offset,
                                                 uint32_t mask) {
    FloatAndU32 f;
    f.f = buf.get(offset & mask);  // offset & 0 == 0: harmless dummy load
    f.u &= mask;                   // zero the result when masked out
    return f.f;
}
//! Convolution forward expressed as a matrix multiply computed in place:
//! for batch n = blockIdx.z it evaluates
//!     dst[n] = A * B,  A = filter (OC x IC*FH*FW),
//!                      B = im2col(src[n]) (IC*FH*FW x OH*OW),
//! where B is never materialized -- its elements are gathered on the fly
//! from src, with padding handled by bitmasks (visit_with_mask).
//!
//! Expected launch layout: blockDim = (BX, BY),
//! gridDim = (ceil(OH*OW / (BX*4)), ceil(OC / (BY*4)), N); each thread
//! accumulates a 4x4 tile of the output matrix. BM = min(BY, BX) is the
//! k-depth of one shared-memory tile. is_xcorr selects cross-correlation
//! (filter as stored) vs true convolution (filter flipped in h and w).
template <uint32_t BY, uint32_t BX, bool is_xcorr, class BufferFetcher>
__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst,
                            const uint32_t INP_BS, const uint32_t OUT_BS,
                            const uint32_t IC, const uint32_t IH,
                            const uint32_t IW, const uint32_t OC,
                            const uint32_t OH, const uint32_t OW,
                            const uint32_t FH, const uint32_t FW,
                            const uint32_t SH, const uint32_t SW,
                            const uint32_t PH, const uint32_t PW) {
    // k-depth of one shared-memory tile
    const uint32_t BM = BY < BX ? BY : BX;
    const uint32_t n = blockIdx.z;  // batch index
    const uint32_t tidx = threadIdx.x;
    const uint32_t tidy = threadIdx.y;
    const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y;
    // each thread owns a 4x4 output tile starting at (posy2, posx2)
    const uint32_t posx2 = posx << 2;
    const uint32_t posy2 = posy << 2;
    // matmul dimensions of the A (filter) and B (im2col) views
    const uint32_t heightA = OC;
    const uint32_t widthA = IC * FH * FW;
    const uint32_t heightB = widthA;
    const uint32_t widthB = OH * OW;
    // stride-scaled top-left input coordinates for the 4 output columns
    // handled by this thread; opK is the flattened (h, w) input offset
    const uint32_t oh0 = (posx2 + 0) / OW * SH;
    const uint32_t ow0 = (posx2 + 0) % OW * SW;
    const uint32_t op0 = oh0 * IW + ow0;
    const uint32_t oh1 = (posx2 + 1) / OW * SH;
    const uint32_t ow1 = (posx2 + 1) % OW * SW;
    const uint32_t op1 = oh1 * IW + ow1;
    const uint32_t oh2 = (posx2 + 2) / OW * SH;
    const uint32_t ow2 = (posx2 + 2) % OW * SW;
    const uint32_t op2 = oh2 * IW + ow2;
    const uint32_t oh3 = (posx2 + 3) / OW * SH;
    const uint32_t ow3 = (posx2 + 3) % OW * SW;
    const uint32_t op3 = oh3 * IW + ow3;
    const uint32_t FP = FH * FW;  // filter plane size
    // per-block staging tiles: 4 rows of A and 4 columns of B per thread
    __shared__ float4 localA[BY][BM];
    __shared__ float4 localB[BM][BX];
    uint32_t i = 0u;  // current position along the shared k dimension
    uint32_t offsetA = posy2 * widthA + tidx;
    // batch base minus the padding offset; uint32_t wraparound here is
    // intentional -- the masked adds below restore a valid index whenever
    // the corresponding mask is non-zero
    uint32_t offsetB = n * INP_BS - PH * IW - PW;
    float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f},
           sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f};
    // (ic, fh, fw) decomposition of k = tidy + i, tracked incrementally;
    // icm is k modulo the filter plane
    uint32_t fh = tidy / FW % FH;
    uint32_t fw = tidy % FW;
    uint32_t ic = tidy / (FH * FW);
    uint32_t icm = tidy % (FH * FW);
    // per-iteration deltas for advancing k by BM without div/mod in loop
    const uint32_t fhs = BM / FW % FH;
    const uint32_t fws = BM % FW;
    const uint32_t ics = BM / (FH * FW);
    const uint32_t icms = BM % (FH * FW);
    for (; i < widthA; i += BM, offsetA += BM) {
        // load localA: 4 consecutive filter rows at column tidx + i.
        // NOTE(review): neither the row (posy2 + 0..3 vs heightA) nor the
        // column (tidx + i vs widthA) is bounds-checked here; the results
        // land in sums that are discarded by the y/x guards below, but the
        // raw-pointer fetcher may read past the filter buffer -- presumably
        // tolerated by upstream allocation; verify against callers.
        if (tidx < BM) {
            localA[tidy][tidx].x = filter.get(offsetA + 0 * widthA);
            localA[tidy][tidx].y = filter.get(offsetA + 1 * widthA);
            localA[tidy][tidx].z = filter.get(offsetA + 2 * widthA);
            localA[tidy][tidx].w = filter.get(offsetA + 3 * widthA);
        }
        // load localB: effective filter tap (fh2, fw2) -- flipped unless
        // cross-correlation
        uint32_t fh2, fw2;
        if (is_xcorr) {
            fh2 = fh;
            fw2 = fw;
        } else {
            fh2 = FH - fh - 1;
            fw2 = FW - fw - 1;
        }
        if (tidy < BM) {
            // ok: k in range; p0..p3: the sampled input position falls
            // inside the un-padded image for each of the 4 output columns
            uint32_t tmp = offsetB + (ic * IH + (fh2)) * IW + (fw2),
                     ok = bool_as_mask(tidy + i < heightB),
                     p0 = bool_as_mask(fh2 + oh0 >= PH && fh2 + oh0 < IH + PH &&
                                       fw2 + ow0 >= PW && fw2 + ow0 < IW + PW),
                     p1 = bool_as_mask(fh2 + oh1 >= PH && fh2 + oh1 < IH + PH &&
                                       fw2 + ow1 >= PW && fw2 + ow1 < IW + PW),
                     p2 = bool_as_mask(fh2 + oh2 >= PH && fh2 + oh2 < IH + PH &&
                                       fw2 + ow2 >= PW && fw2 + ow2 < IW + PW),
                     p3 = bool_as_mask(fh2 + oh3 >= PH && fh2 + oh3 < IH + PH &&
                                       fw2 + ow3 >= PW && fw2 + ow3 < IW + PW);
            localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0);
            localB[tidy][tidx].y = visit_with_mask(src, tmp + op1, ok & p1);
            localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2);
            localB[tidy][tidx].w = visit_with_mask(src, tmp + op3, ok & p3);
        }
        __syncthreads();  // tiles fully staged before use
        // 4x4 outer-product accumulation over the BM-deep tile
        for (uint32_t j = 0u; j < BM; ++j) {
            float4 tmpA = localA[tidy][j];
            float4 tmpB = localB[j][tidx];
            sum0.x += tmpA.x * tmpB.x;
            sum0.y += tmpA.x * tmpB.y;
            sum0.z += tmpA.x * tmpB.z;
            sum0.w += tmpA.x * tmpB.w;
            sum1.x += tmpA.y * tmpB.x;
            sum1.y += tmpA.y * tmpB.y;
            sum1.z += tmpA.y * tmpB.z;
            sum1.w += tmpA.y * tmpB.w;
            sum2.x += tmpA.z * tmpB.x;
            sum2.y += tmpA.z * tmpB.y;
            sum2.z += tmpA.z * tmpB.z;
            sum2.w += tmpA.z * tmpB.w;
            sum3.x += tmpA.w * tmpB.x;
            sum3.y += tmpA.w * tmpB.y;
            sum3.z += tmpA.w * tmpB.z;
            sum3.w += tmpA.w * tmpB.w;
        }
        // advance (fh, fw) and (ic, icm) by BM, with carry/wraparound
        fw += fws;
        fh += fhs;
        fh += (fw >= FW);
        fh -= (fh >= FH) * FH;
        fw -= (fw >= FW) * FW;
        ic += ics;
        icm += icms;
        ic += (icm >= FP);
        icm -= (icm >= FP) * FP;
        __syncthreads();  // tiles consumed before the next staging pass
    }
    // write back the 4x4 tile, guarding each row/column against the
    // matrix boundary
    const uint32_t dst_idx = n * OUT_BS + posy2 * widthB + posx2;
    bool y0 = (posy2 + 0 < heightA);
    bool y1 = (posy2 + 1 < heightA);
    bool y2 = (posy2 + 2 < heightA);
    bool y3 = (posy2 + 3 < heightA);
    bool x0 = (posx2 + 0 < widthB);
    bool x1 = (posx2 + 1 < widthB);
    bool x2 = (posx2 + 2 < widthB);
    bool x3 = (posx2 + 3 < widthB);
    if (y0) {
        if (x0)
            dst[dst_idx + 0 * widthB + 0] = sum0.x;
        if (x1)
            dst[dst_idx + 0 * widthB + 1] = sum0.y;
        if (x2)
            dst[dst_idx + 0 * widthB + 2] = sum0.z;
        if (x3)
            dst[dst_idx + 0 * widthB + 3] = sum0.w;
    }
    if (y1) {
        if (x0)
            dst[dst_idx + 1 * widthB + 0] = sum1.x;
        if (x1)
            dst[dst_idx + 1 * widthB + 1] = sum1.y;
        if (x2)
            dst[dst_idx + 1 * widthB + 2] = sum1.z;
        if (x3)
            dst[dst_idx + 1 * widthB + 3] = sum1.w;
    }
    if (y2) {
        if (x0)
            dst[dst_idx + 2 * widthB + 0] = sum2.x;
        if (x1)
            dst[dst_idx + 2 * widthB + 1] = sum2.y;
        if (x2)
            dst[dst_idx + 2 * widthB + 2] = sum2.z;
        if (x3)
            dst[dst_idx + 2 * widthB + 3] = sum2.w;
    }
    if (y3) {
        if (x0)
            dst[dst_idx + 3 * widthB + 0] = sum3.x;
        if (x1)
            dst[dst_idx + 3 * widthB + 1] = sum3.y;
        if (x2)
            dst[dst_idx + 3 * widthB + 2] = sum3.z;
        if (x3)
            dst[dst_idx + 3 * widthB + 3] = sum3.w;
    }
}
  247. } // anonymous namespace
//! Host-side launcher for the in-place-matmul convolution forward kernel.
//!
//! \param src     input, N batches laid out with per-batch element stride
//!                INP_BS (the kernel indexes src at n * INP_BS)
//! \param filter  weights, OC * IC * FH * FW floats
//! \param dst     output, per-batch element stride OUT_BS
//! \param is_xcorr true for cross-correlation, false for convolution
//! \param stream  HIP stream the kernel is enqueued on
//!
//! Strategy: wrap src/filter in texture objects for the texture-cache load
//! path; if either texture cannot be created, both are released and the
//! kernel falls back to raw global-memory loads. Block shape is chosen so
//! that BX * BY == 256 threads, skewing the shape toward whichever of the
//! two output dimensions (OC rows vs OH*OW columns) is small.
void convolution::exec_inplace_matmul_fwd(
        const float* src, const float* filter, float* dst, size_t N,
        size_t INP_BS, size_t OUT_BS, size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH,
        size_t PW, size_t SH, size_t SW, bool is_xcorr, hipStream_t stream) {
    // texture-backed fetchers (init_succ tells whether creation worked)
    BufferFetcherTextureHost src_tex(const_cast<float*>(src), N * INP_BS),
            filter_tex(const_cast<float*>(filter), OC * IC * FH * FW);
    // raw-pointer fallback fetchers
    BufferFetcherRaw src_buf, filter_buf;
    src_buf.ptr = src;
    filter_buf.ptr = filter;
    // use textures only if BOTH buffers were bound successfully
    if (!src_tex.init_succ || !filter_tex.init_succ) {
        src_tex.reset();
        filter_tex.reset();
    }
    // m x n output matrix per batch: OC rows, OH*OW columns
    int m = OC;
    int n = OH * OW;
    int BY = 1;
    int BX = 1;
    // pick BY (or BX) so 4*BY covers the small dimension, capped at 16;
    // the other dimension takes the rest of the 256-thread budget
    if (m <= 64) {
        while (BY < 16 && (BY << 2) < m)
            BY <<= 1;
        BX = 256 / BY;
    } else if (n <= 64) {
        while (BX < 16 && (BX << 2) < n)
            BX <<= 1;
        BY = 256 / BX;
    } else {
        BX = BY = 16;
    }
    // each thread produces a 4x4 tile, hence the /4 in both grid dims
    dim3 blocks((OH * OW + BX * 4 - 1) / (BX * 4), (OC + BY * 4 - 1) / (BY * 4),
                N);
    dim3 threads(BX, BY);
// Launch the kernel instantiation matching the chosen (BX, BY) and the
// texture-vs-raw fetcher selected above.
#define DISPATCH_BX_BY(BX, BY) \
    do { \
        if (src_tex.init_succ) { \
            KernelPtr<BufferFetcherTexture>::type kptr; \
            if (is_xcorr) { \
                kptr = conv_kernel<BY, BX, true, BufferFetcherTexture>; \
            } else { \
                kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>; \
            } \
            kptr<<<blocks, threads, 0, stream>>>( \
                    src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \
                    IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \
        } else { \
            KernelPtr<BufferFetcherRaw>::type kptr; \
            if (is_xcorr) { \
                kptr = conv_kernel<BY, BX, true, BufferFetcherRaw>; \
            } else { \
                kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>; \
            } \
            kptr<<<blocks, threads, 0, stream>>>( \
                    src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \
                    OH, OW, FH, FW, SH, SW, PH, PW); \
        } \
    } while (0)
// BY is always the complement of BX within the 256-thread block
#define DISPATCH_BX(BX) \
    do { \
        DISPATCH_BX_BY(BX, 256 / BX); \
    } while (0)
// BX is a power of two in [1, 256] by construction; anything else is a bug
#define DISPATCH() \
    do { \
        switch (BX) { \
            case 1: \
                DISPATCH_BX(1); \
                break; \
            case 2: \
                DISPATCH_BX(2); \
                break; \
            case 4: \
                DISPATCH_BX(4); \
                break; \
            case 8: \
                DISPATCH_BX(8); \
                break; \
            case 16: \
                DISPATCH_BX(16); \
                break; \
            case 32: \
                DISPATCH_BX(32); \
                break; \
            case 64: \
                DISPATCH_BX(64); \
                break; \
            case 128: \
                DISPATCH_BX(128); \
                break; \
            case 256: \
                DISPATCH_BX(256); \
                break; \
            default: \
                report_error("no usable kernel"); \
        } \
    } while (0)
    DISPATCH();
#undef DISPATCH
#undef DISPATCH_BX
#undef DISPATCH_BX_BY
    after_kernel_launch();
}
  348. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台