GitOrigin-RevId: 0dba39f56b
tags/v0.4.0
| @@ -1,145 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/group_local/cuda_interface.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./cuda_interface.h" | |||||
| #include "src/cuda/utils.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| // src layout is (N, G, IC, IH, IW) | |||||
| // filter layout is (G, OH, OW, IC, FH, FW, OC) | |||||
| // dst layout is (N, G, OC, OH, OW) | |||||
| // NR_THREADS is 256 | |||||
| // gridDim.z is G | |||||
| // gridDim.y is OC*OH*OW/NR_THREADS | |||||
| // gridDim.x is N/NB | |||||
| // blockDim.x is NR_THREADS | |||||
| // INs and ONs are the stride on the src/dst batch size dim | |||||
| // IC and OC are nr. channels per group | |||||
| // Each thread tackles with NB (actually NB_cur if non-multiple-of-NB N is considered). | |||||
| // Let oid = blockIdx.y*NR_THREADS + threadIdx.x (global thread ID along block | |||||
| // axis y), and we flatten (OC, OH, OW) into one dimension, then each thread | |||||
| // calculates the answer at dst position (n, blockIdx.z, oid), where n ranges | |||||
| // from blockDim.x*NB + 0 to blockDim.x*NB + (NB-1). | |||||
| // IC is processed at stride of ICB. On entrance of each iteration of the loop, | |||||
| // NB * ICB spatial src planes are loaded into shared memory (presumably src | |||||
| // spatial size is small). | |||||
| template <uint32_t NB, uint32_t ICB, bool is_xcorr> | |||||
| __global__ void forward_kernel(const float * __restrict__ src, | |||||
| const float * __restrict__ filter, | |||||
| float * __restrict__ dst, | |||||
| uint32_t N, | |||||
| uint32_t IC, uint32_t IH, uint32_t IW, | |||||
| uint32_t OC, uint32_t OH, uint32_t OW, | |||||
| uint32_t FH, uint32_t FW, | |||||
| uint32_t G, | |||||
| uint32_t INs, uint32_t ONs, | |||||
| uint32_t PH, uint32_t PW, | |||||
| uint32_t SH, uint32_t SW) | |||||
| { | |||||
| // NB * ICB * sizeof(float) * IH * IW | |||||
| extern __shared__ float shared_mem[]; | |||||
| float *src_cache = shared_mem; | |||||
| uint32_t tid = threadIdx.x; | |||||
| uint32_t tstride = blockDim.x; | |||||
| uint32_t oid = tid + blockIdx.y * tstride; | |||||
| src += blockIdx.x*NB * INs + blockIdx.z*IC*IH*IW; | |||||
| dst += blockIdx.x*NB * ONs + blockIdx.z*OC*OH*OW; | |||||
| filter += blockIdx.z*OH*OW*IC*FH*FW*OC; | |||||
| uint32_t op = oid / OC; | |||||
| uint32_t oc = oid % OC; | |||||
| uint32_t oh = op / OW; | |||||
| uint32_t ow = op % OW; | |||||
| float dst_reg[NB]; | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) dst_reg[nb] = 0.0f; | |||||
| uint32_t NB_cur = min(N-blockIdx.x*NB, NB); | |||||
| for (uint32_t ic = 0; ic < IC; ic += ICB) { | |||||
| // read ICB-channel src | |||||
| // (NB, ICB, IHs, IWs) | |||||
| uint32_t ICB_cur = min(ICB, IC-ic); | |||||
| for (uint32_t i = tid; i < NB_cur*ICB*IH*IW; i += tstride) { | |||||
| uint32_t ip = i % (IH*IW); | |||||
| uint32_t icb = i / (IH*IW) % ICB; | |||||
| uint32_t nb = i / (IH*IW) / ICB; | |||||
| src_cache[i] = | |||||
| (icb < ICB_cur) * src[nb*INs + min(IC-1, (ic+icb))*IH*IW + ip]; | |||||
| } | |||||
| __syncthreads(); | |||||
| if (oid < OC*OH*OW) | |||||
| for (uint32_t fh = 0; fh < FH; ++fh) | |||||
| { | |||||
| uint32_t ih; | |||||
| if (is_xcorr) ih = oh*SH + fh - PH; else ih = oh*SH + (FH-fh-1) - PH; | |||||
| if (ih < IH) | |||||
| for (uint32_t fw = 0; fw < FW; ++fw) | |||||
| { | |||||
| uint32_t iw; | |||||
| if (is_xcorr) iw = ow*SW + fw - PW; else iw = ow*SW + (FW-fw-1) - PW; | |||||
| if (iw < IW) | |||||
| for (uint32_t icb = 0; icb < ICB_cur; ++icb) { | |||||
| uint32_t fid = op*IC*FH*FW*OC + (ic+icb)*FH*FW*OC + | |||||
| fh*FW*OC + fw*OC + oc; | |||||
| float fval = filter[fid]; | |||||
| float src_reg[NB]; | |||||
| #pragma unroll | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) { | |||||
| src_reg[nb] = src_cache[nb*ICB*IH*IW + icb*IH*IW + ih*IW + iw]; | |||||
| } | |||||
| #pragma unroll | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) { | |||||
| dst_reg[nb] += src_reg[nb]*fval; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| if (oid < OC*OH*OW) { | |||||
| for (uint32_t nb = 0; nb < NB_cur; ++nb) { | |||||
| dst[nb*ONs + oc*OH*OW + op] = dst_reg[nb]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void run_inference_kernel(const float *src, const float *filter, float *dst, | |||||
| float *wptr, | |||||
| uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, | |||||
| uint32_t OC, uint32_t OH, uint32_t OW, | |||||
| uint32_t FH, uint32_t FW, | |||||
| uint32_t G, | |||||
| uint32_t PH, uint32_t PW, | |||||
| uint32_t SH, uint32_t SW, | |||||
| cudaStream_t stream) | |||||
| { | |||||
| MEGDNN_MARK_USED_VAR(wptr); | |||||
| size_t threads = 256; | |||||
| const size_t NB = 4, ICB = 4; | |||||
| dim3 blocks = dim3(DIVUP(N, NB), DIVUP(OC*OH*OW, threads), G); | |||||
| uint32_t INs = G*IC*IH*IW, ONs = G*OC*OH*OW; | |||||
| forward_kernel<NB, ICB, true><<<blocks, threads, | |||||
| NB*ICB*sizeof(float)*IH*IW, stream>>>(src, filter, dst, | |||||
| N, | |||||
| IC, IH, IW, | |||||
| OC, OH, OW, | |||||
| FH, FW, | |||||
| G, | |||||
| INs, ONs, | |||||
| PH, PW, | |||||
| SH, SW); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,146 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/group_local/forward/kern.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/group_local/forward/kern.cuh" | |||||
| #include "src/cuda/utils.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| namespace { | |||||
| constexpr size_t NB = 4, ICB = 4; | |||||
| // src layout is (N, G, IC, IH, IW) | |||||
| // filter layout is (G, OH, OW, IC, FH, FW, OC) | |||||
| // dst layout is (N, G, OC, OH, OW) | |||||
| // NR_THREADS is 256 | |||||
| // gridDim.z is G | |||||
| // gridDim.y is OC*OH*OW/NR_THREADS | |||||
| // gridDim.x is N/NB | |||||
| // blockDim.x is NR_THREADS | |||||
| // INs and ONs are the stride on the src/dst batch size dim | |||||
| // IC and OC are nr. channels per group | |||||
| // Each thread tackles with NB (actually NB_cur if non-multiple-of-NB N is | |||||
| // considered). Let oid = blockIdx.y*NR_THREADS + threadIdx.x (global thread ID | |||||
| // along block axis y), and we flatten (OC, OH, OW) into one dimension, then | |||||
| // each thread calculates the answer at dst position (n, blockIdx.z, oid), where | |||||
| // n ranges from blockDim.x*NB + 0 to blockDim.x*NB + (NB-1). IC is processed at | |||||
| // stride of ICB. On entrance of each iteration of the loop, NB * ICB spatial | |||||
| // src planes are loaded into shared memory (presumably src spatial size is | |||||
| // small). | |||||
| template <uint32_t NB, uint32_t ICB, bool is_xcorr> | |||||
| __global__ void forward_kernel(const float* __restrict__ src, | |||||
| const float* __restrict__ filter, | |||||
| float* __restrict__ dst, uint32_t N, uint32_t IC, | |||||
| uint32_t IH, uint32_t IW, uint32_t OC, | |||||
| uint32_t OH, uint32_t OW, uint32_t FH, | |||||
| uint32_t FW, uint32_t INs, uint32_t ONs, | |||||
| uint32_t PH, uint32_t PW, uint32_t SH, | |||||
| uint32_t SW) { | |||||
| // NB * ICB * sizeof(float) * IH * IW | |||||
| extern __shared__ float shared_mem[]; | |||||
| float* src_cache = shared_mem; | |||||
| uint32_t tid = threadIdx.x; | |||||
| uint32_t tstride = blockDim.x; | |||||
| uint32_t oid = tid + blockIdx.y * tstride; | |||||
| src += blockIdx.x * NB * INs + blockIdx.z * IC * IH * IW; | |||||
| dst += blockIdx.x * NB * ONs + blockIdx.z * OC * OH * OW; | |||||
| filter += blockIdx.z * OH * OW * IC * FH * FW * OC; | |||||
| uint32_t op = oid / OC; | |||||
| uint32_t oc = oid % OC; | |||||
| uint32_t oh = op / OW; | |||||
| uint32_t ow = op % OW; | |||||
| float dst_reg[NB]; | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) | |||||
| dst_reg[nb] = 0.0f; | |||||
| uint32_t NB_cur = min(N - blockIdx.x * NB, NB); | |||||
| for (uint32_t ic = 0; ic < IC; ic += ICB) { | |||||
| // read ICB-channel src | |||||
| // (NB, ICB, IHs, IWs) | |||||
| uint32_t ICB_cur = min(ICB, IC - ic); | |||||
| for (uint32_t i = tid; i < NB_cur * ICB * IH * IW; i += tstride) { | |||||
| uint32_t ip = i % (IH * IW); | |||||
| uint32_t icb = i / (IH * IW) % ICB; | |||||
| uint32_t nb = i / (IH * IW) / ICB; | |||||
| src_cache[i] = | |||||
| (icb < ICB_cur) * | |||||
| src[nb * INs + min(IC - 1, (ic + icb)) * IH * IW + ip]; | |||||
| } | |||||
| __syncthreads(); | |||||
| if (oid < OC * OH * OW) | |||||
| for (uint32_t fh = 0; fh < FH; ++fh) { | |||||
| uint32_t ih; | |||||
| if (is_xcorr) | |||||
| ih = oh * SH + fh - PH; | |||||
| else | |||||
| ih = oh * SH + (FH - fh - 1) - PH; | |||||
| if (ih < IH) | |||||
| for (uint32_t fw = 0; fw < FW; ++fw) { | |||||
| uint32_t iw; | |||||
| if (is_xcorr) | |||||
| iw = ow * SW + fw - PW; | |||||
| else | |||||
| iw = ow * SW + (FW - fw - 1) - PW; | |||||
| if (iw < IW) | |||||
| for (uint32_t icb = 0; icb < ICB_cur; ++icb) { | |||||
| uint32_t fid = op * IC * FH * FW * OC + | |||||
| (ic + icb) * FH * FW * OC + | |||||
| fh * FW * OC + fw * OC + oc; | |||||
| float fval = filter[fid]; | |||||
| float src_reg[NB]; | |||||
| #pragma unroll | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) { | |||||
| src_reg[nb] = src_cache[nb * ICB * IH * IW + | |||||
| icb * IH * IW + | |||||
| ih * IW + iw]; | |||||
| } | |||||
| #pragma unroll | |||||
| for (uint32_t nb = 0; nb < NB; ++nb) { | |||||
| dst_reg[nb] += src_reg[nb] * fval; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| __syncthreads(); | |||||
| } | |||||
| if (oid < OC * OH * OW) { | |||||
| for (uint32_t nb = 0; nb < NB_cur; ++nb) { | |||||
| dst[nb * ONs + oc * OH * OW + op] = dst_reg[nb]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void group_local::exec(const float* src, const float* filter, float* dst, | |||||
| float* wptr, uint32_t N, uint32_t IC, uint32_t IH, | |||||
| uint32_t IW, uint32_t OC, uint32_t OH, uint32_t OW, | |||||
| uint32_t FH, uint32_t FW, uint32_t G, uint32_t PH, | |||||
| uint32_t PW, uint32_t SH, uint32_t SW, | |||||
| cudaStream_t stream) { | |||||
| MEGDNN_MARK_USED_VAR(wptr); | |||||
| size_t threads = 256; | |||||
| dim3 blocks = dim3(DIVUP(N, NB), DIVUP(OC * OH * OW, threads), G); | |||||
| uint32_t INs = G * IC * IH * IW, ONs = G * OC * OH * OW; | |||||
| forward_kernel<NB, ICB, true> | |||||
| <<<blocks, threads, NB * ICB * sizeof(float) * IH * IW, stream>>>( | |||||
| src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs, | |||||
| ONs, PH, PW, SH, SW); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| size_t group_local::get_share_mem_in_bytes(uint32_t IH, uint32_t IW) { | |||||
| return NB * ICB * sizeof(float) * IH * IW; | |||||
| } | |||||
| @@ -14,8 +14,9 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| namespace group_local { | |||||
| void run_inference_kernel(const float *src, const float *filter, float *dst, | |||||
| void exec(const float *src, const float *filter, float *dst, | |||||
| float *wptr, | float *wptr, | ||||
| uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, | uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, | ||||
| uint32_t OC, uint32_t OH, uint32_t OW, | uint32_t OC, uint32_t OH, uint32_t OW, | ||||
| @@ -25,6 +26,10 @@ void run_inference_kernel(const float *src, const float *filter, float *dst, | |||||
| uint32_t SH, uint32_t SW, | uint32_t SH, uint32_t SW, | ||||
| cudaStream_t stream); | cudaStream_t stream); | ||||
| size_t get_share_mem_in_bytes(uint32_t IH, uint32_t IW); | |||||
| } // namespace group_local | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * \file dnn/src/cuda/group_local/fwd.cpp | |||||
| * \file dnn/src/cuda/group_local/forward/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| @@ -14,7 +14,7 @@ | |||||
| #include "src/cuda/local/local.cuh" | #include "src/cuda/local/local.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #include "./cuda_interface.h" | |||||
| #include "src/cuda/group_local/forward/kern.cuh" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -46,7 +46,7 @@ void GroupLocalForwardImpl::exec(_megdnn_tensor_in src, | |||||
| auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
| auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
| if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) { | if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) { | ||||
| run_inference_kernel(sptr, fptr, dptr, wptr, | |||||
| group_local::exec(sptr, fptr, dptr, wptr, | |||||
| N, IC, IH, IW, | N, IC, IH, IW, | ||||
| OC, OH, OW, | OC, OH, OW, | ||||
| FH, FW, | FH, FW, | ||||
| @@ -141,11 +141,14 @@ bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src, | |||||
| const TensorLayout &filter, | const TensorLayout &filter, | ||||
| const TensorLayout &dst) | const TensorLayout &dst) | ||||
| { | { | ||||
| megdnn_ignore(filter); | |||||
| megdnn_ignore(dst); | |||||
| return src.shape[0] <= 8; | |||||
| MEGDNN_MARK_USED_VAR(filter); | |||||
| MEGDNN_MARK_USED_VAR(dst); | |||||
| auto handle = concrete_handle(this->handle()); | |||||
| size_t N = src.shape[0], IH = src.shape[2], IW = src.shape[3]; | |||||
| return N <= 8 && handle->device_prop().sharedMemPerBlock >= | |||||
| group_local::get_share_mem_in_bytes(IH, IW); | |||||
| } | } | ||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||