/**
 * \file src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip
 *
 * This file is part of MegDNN, a deep neural network run-time library
 * developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
 */
#include "./inplace_matmul_impl.h.hip"
#include "src/rocm/utils.h.hip"

using namespace megdnn;
using namespace rocm;

namespace {

//! Device-side fetcher that reads floats through a texture object.
struct BufferFetcherTexture {
    hipTextureObject_t tex;

    __device__ __forceinline__ float get(uint32_t offset) {
        return tex1Dfetch<float>(tex, offset);
    }
};

//! Device-side fetcher that reads floats directly through a raw pointer.
struct BufferFetcherRaw {
    const float* ptr;

    __device__ __forceinline__ float get(uint32_t offset) { return ptr[offset]; }
};

/*!
 * Host-side owner of a texture object wrapping a linear float buffer.
 *
 * \p init_succ tells whether the texture object was created; \p val is only
 * meaningful while \p init_succ is true. The texture object is destroyed by
 * reset() or by the destructor.
 */
struct BufferFetcherTextureHost {
    bool init_succ;
    BufferFetcherTexture val;

    BufferFetcherTextureHost(float* p, const size_t n);

    // The destructor releases the texture object, so a copy would release the
    // same handle twice; forbid copying altogether.
    BufferFetcherTextureHost(const BufferFetcherTextureHost&) = delete;
    BufferFetcherTextureHost& operator=(const BufferFetcherTextureHost&) = delete;

    ~BufferFetcherTextureHost() { reset(); }

    //! Destroy the texture object if one is held; safe to call repeatedly.
    void reset() {
        if (init_succ) {
            hip_check(hipDestroyTextureObject(val.tex));
            init_succ = false;
        }
    }
};

/*!
 * Try to create a texture object over \p n floats starting at \p p.
 *
 * On failure no exception is raised: \p init_succ stays false and the pending
 * runtime error is cleared so that the caller can use another fetcher.
 */
BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n)
        : init_succ(false), val() {
    hipTextureObject_t tex_obj;

    hipResourceDesc res_desc;
    memset(&res_desc, 0, sizeof(hipResourceDesc));
    res_desc.resType = hipResourceTypeLinear;
    res_desc.res.linear.devPtr = static_cast<void*>(p);
    res_desc.res.linear.sizeInBytes = n * sizeof(float);
    res_desc.res.linear.desc =
            hipCreateChannelDesc(32, 0, 0, 0, hipChannelFormatKindFloat);

    hipTextureDesc tex_desc;
    memset(&tex_desc, 0, sizeof(hipTextureDesc));

    if (hipCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) ==
        hipSuccess) {
        val.tex = tex_obj;
        init_succ = true;
    } else {
        // Deliberately discard the result: the call is made only to reset the
        // error left behind by the failed creation.
        (void)hipGetLastError();
    }
}

//! Function pointer type of conv_kernel for a given fetcher type.
template <class BufferFetcher>
struct KernelPtr {
    typedef void (*type)(
            BufferFetcher, BufferFetcher, float*, uint32_t, uint32_t, uint32_t,
            uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
            uint32_t, uint32_t, uint32_t, uint32_t);
};

//! 1 -> 0xffffffff, 0 -> 0x00000000
__device__ __forceinline__ uint32_t bool_as_mask(uint32_t cond) {
    return (!cond) - 1u;
}

//! Gives access to the bit pattern of a float.
union FloatAndU32 {
    float f;
    uint32_t u;
};

/*!
 * Fetch \p buf[offset] when \p mask is all ones; when \p mask is zero, element
 * 0 is fetched instead and the returned value has all bits cleared (0.0f).
 *
 * \p mask must be either all 1 or 0 bits
 */
template <class BufferFetcher>
__device__ __forceinline__ float visit_with_mask(
        BufferFetcher buf, uint32_t offset, uint32_t mask) {
    FloatAndU32 f;
    f.f = buf.get(offset & mask);
    f.u &= mask;
    return f.f;
}

/*!
 * Convolution forward computed as the matrix product A * B, where A is the
 * filter viewed as an OC x (IC*FH*FW) matrix and B is an (IC*FH*FW) x (OH*OW)
 * matrix whose entries are gathered from \p src on the fly.
 *
 * Launch layout (see exec_inplace_matmul_fwd): blockDim = (BX, BY),
 * blockIdx.z selects the batch item. Every thread accumulates a 4x4 tile of
 * the output: 4 consecutive rows of A (starting at posy2) times 4 consecutive
 * columns of B (starting at posx2). Two statically sized shared tiles,
 * float4[BY][BM] and float4[BM][BX] with BM = min(BY, BX), are used.
 *
 * \tparam BY block size along y; must equal blockDim.y
 * \tparam BX block size along x; must equal blockDim.x
 * \tparam is_xcorr when false the filter is visited in flipped order
 * \param INP_BS offset between consecutive batch items in \p src, in floats
 * \param OUT_BS offset between consecutive batch items in \p dst, in floats
 */
template <uint32_t BY, uint32_t BX, bool is_xcorr, class BufferFetcher>
__global__ void conv_kernel(
        BufferFetcher src, BufferFetcher filter, float* dst, const uint32_t INP_BS,
        const uint32_t OUT_BS, const uint32_t IC, const uint32_t IH,
        const uint32_t IW, const uint32_t OC, const uint32_t OH, const uint32_t OW,
        const uint32_t FH, const uint32_t FW, const uint32_t SH, const uint32_t SW,
        const uint32_t PH, const uint32_t PW) {
    const uint32_t BM = BY < BX ? BY : BX;
    const uint32_t n = blockIdx.z;
    const uint32_t tidx = threadIdx.x;
    const uint32_t tidy = threadIdx.y;
    const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y;
    // first of the 4 columns / rows handled by this thread
    const uint32_t posx2 = posx << 2;
    const uint32_t posy2 = posy << 2;
    const uint32_t heightA = OC;
    const uint32_t widthA = IC * FH * FW;
    const uint32_t heightB = widthA;
    const uint32_t widthB = OH * OW;

    // For each of the 4 output positions: top-left corner of its window in
    // padded input coordinates (ohK, owK) and the matching linear offset opK.
    const uint32_t oh0 = (posx2 + 0) / OW * SH;
    const uint32_t ow0 = (posx2 + 0) % OW * SW;
    const uint32_t op0 = oh0 * IW + ow0;

    const uint32_t oh1 = (posx2 + 1) / OW * SH;
    const uint32_t ow1 = (posx2 + 1) % OW * SW;
    const uint32_t op1 = oh1 * IW + ow1;

    const uint32_t oh2 = (posx2 + 2) / OW * SH;
    const uint32_t ow2 = (posx2 + 2) % OW * SW;
    const uint32_t op2 = oh2 * IW + ow2;

    const uint32_t oh3 = (posx2 + 3) / OW * SH;
    const uint32_t ow3 = (posx2 + 3) % OW * SW;
    const uint32_t op3 = oh3 * IW + ow3;

    const uint32_t FP = FH * FW;

    __shared__ float4 localA[BY][BM];
    __shared__ float4 localB[BM][BX];

    uint32_t i = 0u;
    uint32_t offsetA = posy2 * widthA + tidx;
    // Start of this batch item shifted back by the padding; unsigned
    // wrap-around here cancels out once a valid (masked-in) offset is added.
    uint32_t offsetB = n * INP_BS - PH * IW - PW;

    float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f},
           sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f};

    // Decomposition of B row index (tidy + i) into (ic, fh, fw); icm is the
    // index inside one FH*FW plane. They are advanced incrementally below.
    uint32_t fh = tidy / FW % FH;
    uint32_t fw = tidy % FW;
    uint32_t ic = tidy / (FH * FW);
    uint32_t icm = tidy % (FH * FW);

    // Per-iteration increments corresponding to a step of BM rows.
    const uint32_t fhs = BM / FW % FH;
    const uint32_t fws = BM % FW;
    const uint32_t ics = BM / (FH * FW);
    const uint32_t icms = BM % (FH * FW);

    for (; i < widthA; i += BM, offsetA += BM) {
        // load localA
        if (tidx < BM) {
            // Mask out columns >= widthA and rows >= heightA so that the
            // fetch never goes past the filter; masked entries become 0 and
            // only ever meet zeroed localB rows or unwritten output rows.
            const uint32_t col_ok = bool_as_mask(tidx + i < widthA);
            localA[tidy][tidx].x = visit_with_mask(
                    filter, offsetA + 0 * widthA,
                    col_ok & bool_as_mask(posy2 + 0 < heightA));
            localA[tidy][tidx].y = visit_with_mask(
                    filter, offsetA + 1 * widthA,
                    col_ok & bool_as_mask(posy2 + 1 < heightA));
            localA[tidy][tidx].z = visit_with_mask(
                    filter, offsetA + 2 * widthA,
                    col_ok & bool_as_mask(posy2 + 2 < heightA));
            localA[tidy][tidx].w = visit_with_mask(
                    filter, offsetA + 3 * widthA,
                    col_ok & bool_as_mask(posy2 + 3 < heightA));
        }

        // load localB
        uint32_t fh2, fw2;
        if (is_xcorr) {
            fh2 = fh;
            fw2 = fw;
        } else {
            // true convolution: visit the filter window flipped
            fh2 = FH - fh - 1;
            fw2 = FW - fw - 1;
        }

        if (tidy < BM) {
            const uint32_t tmp = offsetB + (ic * IH + (fh2)) * IW + (fw2);
            // row of B inside the matrix
            const uint32_t ok = bool_as_mask(tidy + i < heightB);
            // pK: input position of column K lies inside the unpadded image
            const uint32_t p0 = bool_as_mask(
                    fh2 + oh0 >= PH && fh2 + oh0 < IH + PH && fw2 + ow0 >= PW &&
                    fw2 + ow0 < IW + PW);
            const uint32_t p1 = bool_as_mask(
                    fh2 + oh1 >= PH && fh2 + oh1 < IH + PH && fw2 + ow1 >= PW &&
                    fw2 + ow1 < IW + PW);
            const uint32_t p2 = bool_as_mask(
                    fh2 + oh2 >= PH && fh2 + oh2 < IH + PH && fw2 + ow2 >= PW &&
                    fw2 + ow2 < IW + PW);
            const uint32_t p3 = bool_as_mask(
                    fh2 + oh3 >= PH && fh2 + oh3 < IH + PH && fw2 + ow3 >= PW &&
                    fw2 + ow3 < IW + PW);
            localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0);
            localB[tidy][tidx].y = visit_with_mask(src, tmp + op1, ok & p1);
            localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2);
            localB[tidy][tidx].w = visit_with_mask(src, tmp + op3, ok & p3);
        }

        // both tiles must be complete before any thread consumes them
        __syncthreads();

        for (uint32_t j = 0u; j < BM; ++j) {
            float4 tmpA = localA[tidy][j];
            float4 tmpB = localB[j][tidx];

            sum0.x += tmpA.x * tmpB.x;
            sum0.y += tmpA.x * tmpB.y;
            sum0.z += tmpA.x * tmpB.z;
            sum0.w += tmpA.x * tmpB.w;

            sum1.x += tmpA.y * tmpB.x;
            sum1.y += tmpA.y * tmpB.y;
            sum1.z += tmpA.y * tmpB.z;
            sum1.w += tmpA.y * tmpB.w;

            sum2.x += tmpA.z * tmpB.x;
            sum2.y += tmpA.z * tmpB.y;
            sum2.z += tmpA.z * tmpB.z;
            sum2.w += tmpA.z * tmpB.w;

            sum3.x += tmpA.w * tmpB.x;
            sum3.y += tmpA.w * tmpB.y;
            sum3.z += tmpA.w * tmpB.z;
            sum3.w += tmpA.w * tmpB.w;
        }

        // advance (ic, fh, fw) by BM rows, propagating carries without
        // divisions
        fw += fws;
        fh += fhs;
        fh += (fw >= FW);
        fh -= (fh >= FH) * FH;
        fw -= (fw >= FW) * FW;

        ic += ics;
        icm += icms;
        ic += (icm >= FP);
        icm -= (icm >= FP) * FP;

        // keep the tiles intact until every thread is done reading them
        __syncthreads();
    }

    const uint32_t dst_idx = n * OUT_BS + posy2 * widthB + posx2;

    // Guards for the part of the 4x4 tile that falls inside the output.
    bool y0 = (posy2 + 0 < heightA);
    bool y1 = (posy2 + 1 < heightA);
    bool y2 = (posy2 + 2 < heightA);
    bool y3 = (posy2 + 3 < heightA);
    bool x0 = (posx2 + 0 < widthB);
    bool x1 = (posx2 + 1 < widthB);
    bool x2 = (posx2 + 2 < widthB);
    bool x3 = (posx2 + 3 < widthB);

    if (y0) {
        if (x0)
            dst[dst_idx + 0 * widthB + 0] = sum0.x;
        if (x1)
            dst[dst_idx + 0 * widthB + 1] = sum0.y;
        if (x2)
            dst[dst_idx + 0 * widthB + 2] = sum0.z;
        if (x3)
            dst[dst_idx + 0 * widthB + 3] = sum0.w;
    }
    if (y1) {
        if (x0)
            dst[dst_idx + 1 * widthB + 0] = sum1.x;
        if (x1)
            dst[dst_idx + 1 * widthB + 1] = sum1.y;
        if (x2)
            dst[dst_idx + 1 * widthB + 2] = sum1.z;
        if (x3)
            dst[dst_idx + 1 * widthB + 3] = sum1.w;
    }
    if (y2) {
        if (x0)
            dst[dst_idx + 2 * widthB + 0] = sum2.x;
        if (x1)
            dst[dst_idx + 2 * widthB + 1] = sum2.y;
        if (x2)
            dst[dst_idx + 2 * widthB + 2] = sum2.z;
        if (x3)
            dst[dst_idx + 2 * widthB + 3] = sum2.w;
    }
    if (y3) {
        if (x0)
            dst[dst_idx + 3 * widthB + 0] = sum3.x;
        if (x1)
            dst[dst_idx + 3 * widthB + 1] = sum3.y;
        if (x2)
            dst[dst_idx + 3 * widthB + 2] = sum3.z;
        if (x3)
            dst[dst_idx + 3 * widthB + 3] = sum3.w;
    }
}

}  // anonymous namespace

/*!
 * Launch conv_kernel on \p stream for a float convolution forward pass.
 *
 * Texture fetchers are used for \p src and \p filter when texture objects can
 * be created for both; otherwise both are read through raw pointers. The
 * block shape (BX, BY) always holds 256 threads and is chosen from OC and
 * OH * OW; the grid covers 4 * BX output positions by 4 * BY output channels
 * per block, with one grid layer per batch item.
 *
 * \param INP_BS offset between consecutive batch items in \p src, in floats
 * \param OUT_BS offset between consecutive batch items in \p dst, in floats
 * \param is_xcorr selects the cross-correlation kernel; otherwise the kernel
 *      that flips the filter is used
 */
void convolution::exec_inplace_matmul_fwd(
        const float* src, const float* filter, float* dst, size_t N, size_t INP_BS,
        size_t OUT_BS, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH,
        size_t OW, size_t FH, size_t FW, size_t PH, size_t PW, size_t SH, size_t SW,
        bool is_xcorr, hipStream_t stream) {
    BufferFetcherTextureHost src_tex(const_cast<float*>(src), N * INP_BS),
            filter_tex(const_cast<float*>(filter), OC * IC * FH * FW);

    BufferFetcherRaw src_buf, filter_buf;
    src_buf.ptr = src;
    filter_buf.ptr = filter;

    // Both operands must use the same fetcher type: if either texture could
    // not be created, drop both and fall back to raw pointers.
    if (!src_tex.init_succ || !filter_tex.init_succ) {
        src_tex.reset();
        filter_tex.reset();
    }

    int m = static_cast<int>(OC);
    int n = static_cast<int>(OH * OW);
    int BY = 1;
    int BX = 1;
    if (m <= 64) {
        // few output channels: shrink BY, give the remaining threads to BX
        while (BY < 16 && (BY << 2) < m)
            BY <<= 1;
        BX = 256 / BY;
    } else if (n <= 64) {
        // few output positions: shrink BX, give the remaining threads to BY
        while (BX < 16 && (BX << 2) < n)
            BX <<= 1;
        BY = 256 / BX;
    } else {
        BX = BY = 16;
    }

    dim3 blocks(
            (OH * OW + BX * 4 - 1) / (BX * 4), (OC + BY * 4 - 1) / (BY * 4), N);
    dim3 threads(BX, BY);

// Inside the macros below BX / BY are compile-time constants that become
// template arguments of conv_kernel.
#define DISPATCH_BX_BY(BX, BY)                                                \
    do {                                                                      \
        if (src_tex.init_succ) {                                              \
            KernelPtr<BufferFetcherTexture>::type kptr;                       \
            if (is_xcorr) {                                                   \
                kptr = conv_kernel<BY, BX, true, BufferFetcherTexture>;       \
            } else {                                                          \
                kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>;      \
            }                                                                 \
            kptr<<<blocks, threads, 0, stream>>>(                             \
                    src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \
                    IW, OC, OH, OW, FH, FW, SH, SW, PH, PW);                  \
        } else {                                                              \
            KernelPtr<BufferFetcherRaw>::type kptr;                           \
            if (is_xcorr) {                                                   \
                kptr = conv_kernel<BY, BX, true, BufferFetcherRaw>;           \
            } else {                                                          \
                kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>;          \
            }                                                                 \
            kptr<<<blocks, threads, 0, stream>>>(                             \
                    src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \
                    OH, OW, FH, FW, SH, SW, PH, PW);                          \
        }                                                                     \
    } while (0)
#define DISPATCH_BX(BX)               \
    do {                              \
        DISPATCH_BX_BY(BX, 256 / BX); \
    } while (0)
#define DISPATCH()                                \
    do {                                          \
        switch (BX) {                             \
            case 1:                               \
                DISPATCH_BX(1);                   \
                break;                            \
            case 2:                               \
                DISPATCH_BX(2);                   \
                break;                            \
            case 4:                               \
                DISPATCH_BX(4);                   \
                break;                            \
            case 8:                               \
                DISPATCH_BX(8);                   \
                break;                            \
            case 16:                              \
                DISPATCH_BX(16);                  \
                break;                            \
            case 32:                              \
                DISPATCH_BX(32);                  \
                break;                            \
            case 64:                              \
                DISPATCH_BX(64);                  \
                break;                            \
            case 128:                             \
                DISPATCH_BX(128);                 \
                break;                            \
            case 256:                             \
                DISPATCH_BX(256);                 \
                break;                            \
            default:                              \
                report_error("no usable kernel"); \
        }                                         \
    } while (0)
    DISPATCH();
#undef DISPATCH
#undef DISPATCH_BX
#undef DISPATCH_BX_BY
    after_kernel_launch();
}

// vim: syntax=cpp.doxygen