GitOrigin-RevId: f36495a46a
tags/v0.4.0
| @@ -14,43 +14,7 @@ | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/handle.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace local { | |||
| void check_input(size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||
| size_t INs, size_t ONs, | |||
| size_t PH, size_t PW, | |||
| size_t SH, size_t SW, | |||
| bool is_xcorr) | |||
| { | |||
| megdnn_ignore(N); | |||
| megdnn_ignore(IC); | |||
| megdnn_ignore(IH); | |||
| megdnn_ignore(IW); | |||
| megdnn_ignore(OC); | |||
| megdnn_ignore(OH); | |||
| megdnn_ignore(OW); | |||
| megdnn_ignore(FH); | |||
| megdnn_ignore(FW); | |||
| megdnn_ignore(INs); | |||
| megdnn_ignore(ONs); | |||
| megdnn_ignore(PH); | |||
| megdnn_ignore(PW); | |||
| megdnn_ignore(SH); | |||
| megdnn_ignore(SW); | |||
| megdnn_ignore(is_xcorr); | |||
| // shared memory constraint | |||
| megdnn_assert(IH*IW <= 768, "spatial size should not be larger than 768."); | |||
| // megdnn_assert(4 * 4 * 4 * IH * IW <= 49152); | |||
| } | |||
| } // namespace local | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| #include "src/common/utils.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -94,13 +58,9 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src, | |||
| param().stride_h, param().stride_w, | |||
| cublas, stream, | |||
| one, zero); | |||
| } else { | |||
| local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW, | |||
| IC*IH*IW, OC*OH*OW, | |||
| param().pad_h, param().pad_w, | |||
| param().stride_h, param().stride_w, | |||
| is_xcorr); | |||
| local::forward_proxy_weiming(src.ptr<dt_float32>(), | |||
| } else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <= | |||
| handle->device_prop().sharedMemPerBlock) { | |||
| local::forward_proxy_default(src.ptr<dt_float32>(), | |||
| filter.ptr<dt_float32>(), | |||
| dst.ptr<dt_float32>(), | |||
| N, | |||
| @@ -112,6 +72,11 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src, | |||
| param().stride_h, param().stride_w, | |||
| is_xcorr, | |||
| stream); | |||
| } else { | |||
| megdnn_throw(ssprintf( | |||
| "No usable kernel for local conv, src: %s filter: %s \n", | |||
| src.layout.to_string().c_str(), | |||
| filter.layout.to_string().c_str())); | |||
| } | |||
| } | |||
| @@ -18,6 +18,12 @@ namespace megdnn { | |||
| namespace cuda { | |||
| namespace local { | |||
| constexpr size_t Ns = 4, ICs = 4; | |||
| size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) { | |||
| return Ns * ICs * sizeof(float) * IH * IW; | |||
| } | |||
| // blockIdx.y is OC*OH*OW/1024 | |||
| // blockIdx.x is N/4 | |||
| // threadIdx.x is [0, 1024) | |||
| @@ -96,7 +102,7 @@ __global__ void forward_kernel(const float * __restrict__ src, | |||
| } | |||
| } | |||
| void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
| void forward_proxy_default(const float *src, const float *filter, float *dst, | |||
| size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| @@ -108,7 +114,6 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
| cudaStream_t stream) | |||
| { | |||
| size_t threads = 256; | |||
| const size_t Ns = 4, ICs = 4; | |||
| dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads)); | |||
| if (is_xcorr) { | |||
| forward_kernel<Ns, ICs, true><<<blocks, threads, | |||
| @@ -17,17 +17,10 @@ namespace megdnn { | |||
| namespace cuda { | |||
| namespace local { | |||
| void check_input(size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||
| size_t INs, size_t ONs, | |||
| size_t PH, size_t PW, | |||
| size_t SH, size_t SW, | |||
| bool is_xcorr); | |||
| size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW); | |||
| void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
| size_t N, | |||
| void forward_proxy_default(const float *src, const float *filter, float *dst, | |||
| size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||
| @@ -39,7 +32,7 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst, | |||
| /// forward | |||
| bool can_forward_proxy_convnet(size_t N, | |||
| bool can_forward_proxy_convnet(size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||
| @@ -70,7 +63,7 @@ size_t get_workspace_in_floats_forward_proxy_convnet(size_t N, | |||
| /// bwd data | |||
| bool can_backward_data_proxy_convnet(size_t N, | |||
| bool can_backward_data_proxy_convnet(size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||
| @@ -78,7 +71,7 @@ bool can_backward_data_proxy_convnet(size_t N, | |||
| size_t PH, size_t PW, | |||
| size_t SH, size_t SW); | |||
| void backward_data_proxy_convnet(const float *filter, | |||
| void backward_data_proxy_convnet(const float *filter, | |||
| const float *diff, | |||
| float *grad, | |||
| float *workspace, | |||
| @@ -103,7 +96,7 @@ size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N, | |||
| /// bwd filter | |||
| bool can_backward_filter_proxy_convnet(size_t N, | |||
| bool can_backward_filter_proxy_convnet(size_t N, | |||
| size_t IC, size_t IH, size_t IW, | |||
| size_t OC, size_t OH, size_t OW, | |||
| size_t FH, size_t FW, | |||