
bwd_data.cpp.hip 6.4 kB

/**
 * \file src/rocm/convolution/chanwise/bwd_data.cpp.hip
 *
 * This file is part of MegDNN, a deep neural network run-time library
 * developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 */
#include "hip_header.h"
#include "./kern.h.hip"
#include "./kern_helper.h.hip"

using namespace megdnn;
using namespace rocm;
using namespace convolution;
using namespace chanwise;
namespace {

// grid idx is (inp_chl, worker_index)
// each y-slice of a block works on an (N, IH, IW) spatial image at given
// inp_chl
template <typename T, int CHL_MUL_SET, int FH_SET, int FW_SET, int SH_SET,
          int SW_SET>
__global__ void kern_bwd_data(T* src_grad, const T* dst_grad, const T* flt_tot,
                              Param param) {
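    // The *_SET template arguments are compile-time copies of the channel
    // multiplier, filter shape and strides; a value of 0 means "not known
    // at compile time, read it from param instead". Fixing them at compile
    // time lets the #pragma unroll fast path below fully unroll.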
    extern __shared__ uint8_t flt_storage[];
    T* const flt = reinterpret_cast<T*>(flt_storage);
    const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x,
                   IH = param.src_h, IW = param.src_w,
                   CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul,
                   FH = FH_SET ? FH_SET : param.flt_h,
                   FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FH * FW,
                   PH = param.pad_h, PW = param.pad_w,
                   SH = SH_SET ? SH_SET : param.stride_h,
                   SW = SW_SET ? SW_SET : param.stride_w, OH = param.out_h,
                   OW = param.out_w, TOT_OUT = N * IH * IW;
    block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL);
    dst_grad += ic * CHL_MUL * OH * OW;
    src_grad += ic * IH * IW;
    uint32_t out_idx_ = blockIdx.y * blockDim.x + threadIdx.x,
             nr_out_per_launch = blockDim.x * gridDim.y;
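    // Grid-stride loop: the blockDim.x * gridDim.y threads assigned to this
    // input channel cooperatively cover all TOT_OUT = N * IH * IW gradient
    // elements; div_mod below peels the linear index apart into (n, ih, iw).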
    for (; out_idx_ < TOT_OUT; out_idx_ += nr_out_per_launch) {
        uint32_t out_idx = out_idx_, n, ih, iw;
        out_idx = div_mod(out_idx, IW, iw);
        out_idx = div_mod(out_idx, IH, ih);
        n = out_idx;
        const T* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW);
        T sum(0);
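        // Find the range of output rows/cols whose receptive field covers
        // this input pixel: output position o touches input position i iff
        // the filter tap f = i - o*S + P lies in [0, F).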
        // o >= max(0, ceil_div((i+P-F+1), S)); the ceiling is obtained by
        // adding S to i+P-F before the truncating division
        uint32_t ohmin = max(int32_t(ih + PH - FH + SH), 0) / SH,
                 owmin = max(int32_t(iw + PW - FW + SW), 0) / SW,
                 ohmax = min((ih + PH) / SH, OH - 1),
                 owmax = min((iw + PW) / SW, OW - 1);
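        // Fast path: with compile-time filter size and stride 1, the valid
        // output range spans at most FH x FW positions, so the loops below
        // have fixed trip counts and can be fully unrolled.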
        if (SH_SET == 1 && SW_SET == 1 && FH_SET && FW_SET) {
#pragma unroll
            for (uint32_t doh = 0; doh < FH; ++doh) {
                uint32_t oh = ohmin + doh;
                if (oh <= ohmax) {
                    uint32_t fh = ih - oh * SH + PH;
#pragma unroll
                    for (uint32_t dow = 0; dow < FW; ++dow) {
                        uint32_t ow = owmin + dow;
                        if (ow <= owmax) {
                            uint32_t fw = iw - ow * SW + PW;
                            const T* pd = dst_grad_base + oh * OW + ow;
                            const T* pf = flt + fh * FW + fw;
#pragma unroll
                            for (uint32_t chl_mul = 0; chl_mul < CHL_MUL;
                                 ++chl_mul) {
                                sum += *pd * *pf;
                                pd += OH * OW;
                                pf += FSIZE;
                            }
                        }
                    }
                }
            }
        } else {
            for (uint32_t oh = ohmin; oh <= ohmax; ++oh) {
                uint32_t fh = ih - oh * SH + PH;
                for (uint32_t ow = owmin; ow <= owmax; ++ow) {
                    uint32_t fw = iw - ow * SW + PW;
                    const T* pd = dst_grad_base + oh * OW + ow;
                    const T* pf = flt + fh * FW + fw;
#pragma unroll
                    for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) {
                        sum += *pd * *pf;
                        pd += OH * OW;
                        pf += FSIZE;
                    }
                }
            }
        }
        src_grad[(n * (IC * IH) + ih) * IW + iw] = sum;
    }
}
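// Maps the runtime (chl_mul, fh, fw, sh, sw) configuration to the matching
// specialized instantiation of kern_bwd_data; shapes without a dedicated
// specialization fall back to the fully dynamic <0, 0, 0, 0, 0> variant.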
template <typename T>
class KernDispatch {
public:
    typedef void (*kern_ptr_t)(T*, const T*, const T*, Param);

    static kern_ptr_t dispatch(int chl_mul, int fh, int fw, int sh, int sw) {
        if (chl_mul == 1) {
            if (fh == 3 && fw == 3)
                return d1<1, 3, 3>(sh, sw);
            if (fh == 4 && fw == 4)
                return d1<1, 4, 4>(sh, sw);
        }
        return d1<0, 0, 0>(sh, sw);
    }

private:
    template <int chl_mul, int fh, int fw>
    static kern_ptr_t d1(int sh, int sw) {
        if (sh == 1 && sw == 1)
            return kern_bwd_data<T, chl_mul, fh, fw, 1, 1>;
        if (sh == 1 && sw == 2)
            return kern_bwd_data<T, chl_mul, fh, fw, 1, 2>;
        if (sh == 2 && sw == 1)
            return kern_bwd_data<T, chl_mul, fh, fw, 2, 1>;
        if (sh == 2 && sw == 2)
            return kern_bwd_data<T, chl_mul, fh, fw, 2, 2>;
        return kern_bwd_data<T, chl_mul, fh, fw, 0, 0>;
    }
};

}  // anonymous namespace
template <typename T>
void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt,
                            const Param& param, hipStream_t stream) {
    typename KernDispatch<T>::kern_ptr_t kern =
            KernDispatch<T>::dispatch(param.chl_mul, param.flt_h, param.flt_w,
                                      param.stride_h, param.stride_w);
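    // Launch geometry: blockIdx.x selects the input channel; gridDim.y is
    // capped at 512 and sized so each thread handles roughly 4 elements.
    // Dynamic shared memory holds one channel's filter slice.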
    int nr_thread = 256, nr_out_dimx = param.src_h * param.src_w * param.batch;
    dim3 nr_block(param.src_chl,
                  std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
    uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
    kern<<<nr_block, nr_thread, shared, stream>>>(src_grad, dst_grad, flt,
                                                  param);
    after_kernel_launch();
}
namespace megdnn {
namespace rocm {
namespace convolution {
namespace chanwise {

#define INST(_dt)                                                   \
    template void run_bwd_data(                                     \
            DTypeTrait<_dt>::ctype*, const DTypeTrait<_dt>::ctype*, \
            const DTypeTrait<_dt>::ctype*, const Param&, hipStream_t);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST)
#undef INST
#undef DO_INST

}  // namespace chanwise
}  // namespace convolution
}  // namespace rocm
}  // namespace megdnn

// vim: syntax=cuda.doxygen
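For context, here is a minimal host-side sketch of how run_bwd_data might be driven. This is an illustration under stated assumptions, not code from the repository: the exact Param definition lives in kern.h.hip, the field names below are the ones the kernel reads, and the output-size formula assumes standard non-dilated convolution arithmetic.

#include <hip/hip_runtime.h>

using namespace megdnn::rocm::convolution::chanwise;

// Hypothetical driver for a 3x3, stride-1 depthwise layer (chl_mul = 1).
void demo_bwd_data(hipStream_t stream) {
    Param p;
    p.batch = 16;
    p.src_chl = 32;
    p.src_h = p.src_w = 28;
    p.chl_mul = 1;
    p.flt_h = p.flt_w = 3;
    p.pad_h = p.pad_w = 1;
    p.stride_h = p.stride_w = 1;
    // Standard convolution output size (assumption: no dilation).
    p.out_h = (p.src_h + 2 * p.pad_h - p.flt_h) / p.stride_h + 1;
    p.out_w = (p.src_w + 2 * p.pad_w - p.flt_w) / p.stride_w + 1;

    size_t src_elems = size_t(p.batch) * p.src_chl * p.src_h * p.src_w;
    size_t dst_elems =
            size_t(p.batch) * p.src_chl * p.chl_mul * p.out_h * p.out_w;
    size_t flt_elems = size_t(p.src_chl) * p.chl_mul * p.flt_h * p.flt_w;

    float *src_grad = nullptr, *dst_grad = nullptr, *flt = nullptr;
    hipMalloc(reinterpret_cast<void**>(&src_grad), src_elems * sizeof(float));
    hipMalloc(reinterpret_cast<void**>(&dst_grad), dst_elems * sizeof(float));
    hipMalloc(reinterpret_cast<void**>(&flt), flt_elems * sizeof(float));
    // ... dst_grad and flt would be filled by the framework ...

    // Computes the gradient w.r.t. the convolution input into src_grad.
    run_bwd_data<float>(src_grad, dst_grad, flt, p, stream);
}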

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.