| @@ -3,3 +3,4 @@ | |||||
| dnn/src/cuda/conv_bias/int8/kimpl/* binary | dnn/src/cuda/conv_bias/int8/kimpl/* binary | ||||
| dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | ||||
| dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | ||||
| dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | |||||
| @@ -236,6 +236,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||||
| } | } | ||||
| #endif | #endif | ||||
| ConvBiasForwardImpl::AlgoBase* | ConvBiasForwardImpl::AlgoBase* | ||||
| ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | ||||
| cudnnConvolutionFwdAlgo_t algo) { | cudnnConvolutionFwdAlgo_t algo) { | ||||
| @@ -14,11 +14,11 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||||
| #include "src/cuda/conv_bias/helper.h" | #include "src/cuda/conv_bias/helper.h" | ||||
| #include "src/cuda/conv_bias/opr_impl.h" | #include "src/cuda/conv_bias/opr_impl.h" | ||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||||
| #include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
| #include "src/cuda/handle.h" | |||||
| #include <cuda.h> | #include <cuda.h> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -521,6 +521,7 @@ private: | |||||
| std::string m_name; | std::string m_name; | ||||
| }; | }; | ||||
| class ConvBiasForwardImpl::AlgoPack { | class ConvBiasForwardImpl::AlgoPack { | ||||
| AlgoPack(const AlgoPack&) = delete; | AlgoPack(const AlgoPack&) = delete; | ||||
| AlgoPack& operator=(const AlgoPack&) = delete; | AlgoPack& operator=(const AlgoPack&) = delete; | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "../elemwise/opr_impl.h" | #include "../elemwise/opr_impl.h" | ||||
| @@ -94,5 +95,5 @@ private: | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,7 +10,7 @@ | |||||
| */ | */ | ||||
| #include "src/cuda/pooling/opr_impl.h" | #include "src/cuda/pooling/opr_impl.h" | ||||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||||
| #include "./pooling2d_int8.cuh" | |||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -67,7 +67,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, | |||||
| kern_param.window_h = window_h, kern_param.window_w = window_w, | kern_param.window_h = window_h, kern_param.window_w = window_w, | ||||
| kern_param.sh = sh, kern_param.sw = sw; | kern_param.sh = sh, kern_param.sw = sw; | ||||
| auto&& stream = cuda_stream(handle()); | auto&& stream = cuda_stream(handle()); | ||||
| return pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||||
| return pooling2d::do_pooling2d_int8_cdiv4hwn4( | |||||
| src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | |||||
| kern_param, stream, static_cast<uint32_t>(param().mode)); | |||||
| } else if (param().format == Format::NCHW4) { | |||||
| pooling2d::Param kern_param; | |||||
| size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], | |||||
| c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; | |||||
| c = c * 4; | |||||
| size_t ph = param().pad_h, pw = param().pad_w; | |||||
| size_t window_h = param().window_h, window_w = param().window_w; | |||||
| size_t sh = param().stride_h, sw = param().stride_w; | |||||
| kern_param.n = n, kern_param.c = c, kern_param.hi = hi, | |||||
| kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, | |||||
| kern_param.ph = ph, kern_param.pw = pw, | |||||
| kern_param.window_h = window_h, kern_param.window_w = window_w, | |||||
| kern_param.sh = sh, kern_param.sw = sw; | |||||
| auto&& stream = cuda_stream(handle()); | |||||
| return pooling2d::do_pooling2d_int8_ncdiv4hw4( | |||||
| src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | ||||
| kern_param, stream, static_cast<uint32_t>(param().mode)); | kern_param, stream, static_cast<uint32_t>(param().mode)); | ||||
| } | } | ||||
| @@ -8,8 +8,9 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||||
| #include "./pooling2d_int8.cuh" | |||||
| #include "src/common/opr_param_defs_enumv.cuh" | #include "src/common/opr_param_defs_enumv.cuh" | ||||
| #include "src/cuda/query_blocksize.cuh" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace cuda; | using namespace cuda; | ||||
| @@ -360,11 +361,65 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( | |||||
| ldg_type res = pooler.get_ans(); | ldg_type res = pooler.get_ans(); | ||||
| *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | ||||
| } | } | ||||
| template <typename Pooler> | |||||
| __global__ void pooling2d_device_template_int8_ncdiv4hw4( | |||||
| const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param) { | |||||
| const int tid = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| using ldg_type = typename Pooler::feed_type; | |||||
| static int constexpr pack_size = 4; | |||||
| static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); | |||||
| MEGDNN_STATIC_ASSERT( | |||||
| ldg_width == 1, | |||||
| "pooling2d (NCHW4) kernel must use 32bit width ldg instruction"); | |||||
| const int wo_ldg = param.wo / ldg_width; | |||||
| const int c_packed = param.c / pack_size; | |||||
| const int batch = tid / (param.ho * wo_ldg * c_packed); | |||||
| const int chw = tid - batch * param.ho * wo_ldg * c_packed; | |||||
| const int oc_packed = chw / (param.ho * wo_ldg); | |||||
| const int hw = chw - oc_packed * param.ho * wo_ldg; | |||||
| const int oh = hw / wo_ldg; | |||||
| const int ow = (hw - wo_ldg * oh) * ldg_width; | |||||
| if (batch >= param.n || oc_packed >= c_packed || oh >= param.ho || | |||||
| ow >= param.wo) | |||||
| return; | |||||
| const int in_batch_stride = param.hi * param.wi * param.c; | |||||
| const int out_batch_stride = param.ho * param.wo * param.c; | |||||
| const int in_channel_stride = param.hi * param.wi * pack_size; | |||||
| const int out_channel_stride = param.ho * param.wo * pack_size; | |||||
| const int8_t* __restrict__ g_src_ptr = | |||||
| src + batch * in_batch_stride + oc_packed * in_channel_stride; | |||||
| int8_t* __restrict__ g_dst_ptr = dst + batch * out_batch_stride + | |||||
| oc_packed * out_channel_stride + | |||||
| (oh * param.wo + ow) * pack_size; | |||||
| Pooler pooler(param.window_h * param.window_w); | |||||
| pooler.init(); | |||||
| for (int fh = 0; fh < param.window_h; fh++) { | |||||
| uint32_t ih = oh * param.sh + fh - param.ph; | |||||
| for (int fw = 0; fw < param.window_w; fw++) { | |||||
| uint32_t iw = ow * param.sw + fw - param.pw; | |||||
| if (ih < param.hi && iw < param.wi) { | |||||
| const int8_t* __restrict__ cur_src_ptr = | |||||
| g_src_ptr + (ih * param.wi + iw) * pack_size; | |||||
| ldg_type sval = __ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr)); | |||||
| pooler.feed(sval); | |||||
| } | |||||
| } | |||||
| } | |||||
| ldg_type res = pooler.get_ans(); | |||||
| *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | |||||
| } | |||||
| }; // namespace | }; // namespace | ||||
| void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||||
| const int8_t* d_src, int8_t* d_dst, const Param& param, | |||||
| cudaStream_t stream, uint32_t mode) { | |||||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, | |||||
| int8_t* d_dst, | |||||
| const Param& param, | |||||
| cudaStream_t stream, | |||||
| uint32_t mode) { | |||||
| using Mode = megdnn::param_enumv::Pooling::Mode; | using Mode = megdnn::param_enumv::Pooling::Mode; | ||||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | ||||
| uint32_t vthreads_x = 0, vthreads_y = param.c / 4; | uint32_t vthreads_x = 0, vthreads_y = param.c / 4; | ||||
| @@ -397,8 +452,7 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||||
| } | } | ||||
| #undef dispatch_pooling_mode | #undef dispatch_pooling_mode | ||||
| constexpr uint32_t threads_x = 16; | constexpr uint32_t threads_x = 16; | ||||
| uint32_t nr_threads = | |||||
| _get_kern_block_size(reinterpret_cast<const void*>(kern)); | |||||
| uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||||
| uint32_t nr_threads_x = std::min(threads_x, vthreads_x), | uint32_t nr_threads_x = std::min(threads_x, vthreads_x), | ||||
| nr_threads_y = std::min(nr_threads / nr_threads_x, vthreads_y); | nr_threads_y = std::min(nr_threads / nr_threads_x, vthreads_y); | ||||
| uint32_t nr_blocks_x = param.ho * param.wo, | uint32_t nr_blocks_x = param.ho * param.wo, | ||||
| @@ -410,4 +464,34 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, | |||||
| int8_t* d_dst, | |||||
| const Param& param, | |||||
| cudaStream_t stream, | |||||
| uint32_t mode) { | |||||
| using Mode = megdnn::param_enumv::Pooling::Mode; | |||||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | |||||
| uint32_t vthreads = param.n * param.c * param.ho * param.wo / 4; | |||||
| switch (mode) { | |||||
| case Mode::MAX: | |||||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||||
| MaxPooler<int8_t, int32_t>>; | |||||
| break; | |||||
| case Mode::AVERAGE: | |||||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||||
| MeanIncludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||||
| break; | |||||
| case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||||
| MeanExcludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||||
| break; | |||||
| default: | |||||
| megdnn_assert(false, "invalid pooling mode"); | |||||
| } | |||||
| uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||||
| nr_threads = std::min(nr_threads, vthreads); | |||||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
| kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -1,12 +1,13 @@ | |||||
| /** | /** | ||||
| * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh | |||||
| * \file dnn/src/cuda/pooling/pooling2d_int8.cuh | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -20,15 +21,16 @@ struct Param { | |||||
| int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; | int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; | ||||
| }; | }; | ||||
| uint32_t _get_kern_block_size(const void* kern); | |||||
| void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||||
| const Param& param, cudaStream_t stream, | |||||
| uint32_t mode); | |||||
| void _do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||||
| const Param& param, cudaStream_t stream, | |||||
| uint32_t mode); | |||||
| void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, | |||||
| const Param& param, cudaStream_t stream, | |||||
| uint32_t mode); | |||||
| } // namespace pooling2d | } // namespace pooling2d | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cuda.doxygen | // vim: syntax=cuda.doxygen | ||||
| @@ -1,27 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||||
| #include "src/cuda/query_blocksize.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace pooling2d { | |||||
| uint32_t _get_kern_block_size(const void* kern) { | |||||
| uint32_t ret = query_blocksize_for_kernel(kern); | |||||
| return ret; | |||||
| } | |||||
| } // namespace pooling2d | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -82,6 +82,11 @@ void cuda::__throw_cusolver_error__(cusolverStatus_t err, const char* msg) { | |||||
| megdnn_throw(s.c_str()); | megdnn_throw(s.c_str()); | ||||
| } | } | ||||
| void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) { | |||||
| auto s = ssprintf("cuda driver error %d occurred; expr: %s", int(err), msg); | |||||
| megdnn_throw(s.c_str()); | |||||
| } | |||||
| void cuda::report_error(const char *msg) { | void cuda::report_error(const char *msg) { | ||||
| megdnn_throw(msg); | megdnn_throw(msg); | ||||
| MEGDNN_MARK_USED_VAR(msg); | MEGDNN_MARK_USED_VAR(msg); | ||||
| @@ -118,9 +123,31 @@ bool cuda::is_compute_capability_required(int major, int minor) { | |||||
| (device_prop.major == major && device_prop.minor >= minor); | (device_prop.major == major && device_prop.minor >= minor); | ||||
| } | } | ||||
| bool cuda::is_compute_capability_equalto(int major, int minor) { | |||||
| auto&& device_prop = cuda::current_device_prop(); | |||||
| return device_prop.major == major && device_prop.minor == minor; | |||||
| } | |||||
| size_t cuda::max_batch_x_channel_size() { | size_t cuda::max_batch_x_channel_size() { | ||||
| return current_device_prop().maxGridSize[2]; | return current_device_prop().maxGridSize[2]; | ||||
| } | } | ||||
| const char* cuda::current_device_arch_name() { | |||||
| auto&& device_prop = current_device_prop(); | |||||
| int cap = 10 * device_prop.major + device_prop.minor; | |||||
| if (cap >= 50 && cap < 60) | |||||
| return "maxwell"; | |||||
| else if (cap >= 60 && cap < 70) | |||||
| return "pascal"; | |||||
| else if (cap >= 70 && cap < 75) | |||||
| return "volta"; | |||||
| else if (cap >= 75 && cap < 80) | |||||
| return "turing"; | |||||
| else if (cap >= 80) | |||||
| return "ampere"; | |||||
| megdnn_throw( | |||||
| ssprintf("unsupported cuda compute capability %d", cap).c_str()); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -53,6 +53,14 @@ | |||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| #define cucheck(_x) \ | |||||
| do { \ | |||||
| CUresult _err = (_x); \ | |||||
| if (_err != CUDA_SUCCESS) { \ | |||||
| ::megdnn::cuda::__throw_cuda_driver_error__(_err, #_x); \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define after_kernel_launch() \ | #define after_kernel_launch() \ | ||||
| do { \ | do { \ | ||||
| cuda_check(cudaGetLastError()); \ | cuda_check(cudaGetLastError()); \ | ||||
| @@ -84,6 +92,7 @@ MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, | |||||
| const char* msg); | const char* msg); | ||||
| MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, | MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, | ||||
| const char* msg); | const char* msg); | ||||
| MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg); | |||||
| MEGDNN_NORETURN void report_error(const char* msg); | MEGDNN_NORETURN void report_error(const char* msg); | ||||
| template <typename T, size_t N> | template <typename T, size_t N> | ||||
| @@ -57,10 +57,15 @@ cudaDeviceProp current_device_prop(); | |||||
| //! check compute capability satisfied with given sm version | //! check compute capability satisfied with given sm version | ||||
| bool is_compute_capability_required(int major, int minor); | bool is_compute_capability_required(int major, int minor); | ||||
| //! check compute capability equal to the given sm version | |||||
| bool is_compute_capability_equalto(int major, int minor); | |||||
| //! get the CUDNN_MAX_BATCH_X_CHANNEL_SIZE, it's just return the max size of the | //! get the CUDNN_MAX_BATCH_X_CHANNEL_SIZE, it's just return the max size of the | ||||
| //! third demension | //! third demension | ||||
| size_t max_batch_x_channel_size(); | size_t max_batch_x_channel_size(); | ||||
| const char* current_device_arch_name(); | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -493,6 +493,7 @@ std::vector<TestArg> get_int8_nchw44_args(size_t kernel_size, size_t pack_size, | |||||
| return args; | return args; | ||||
| } | } | ||||
| std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
| @@ -528,6 +529,7 @@ std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | |||||
| return args; | return args; | ||||
| } | } | ||||
| std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) { | std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) { | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
| @@ -728,7 +730,7 @@ std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size) { | |||||
| void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | ||||
| DType dst_dtype, Handle* handle, const char* algo, | DType dst_dtype, Handle* handle, const char* algo, | ||||
| param::ConvBias::Format format, | param::ConvBias::Format format, | ||||
| const std::vector<TestArg>& args) { | |||||
| const std::vector<TestArg>& args, bool fuse_z) { | |||||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | ||||
| Checker<ConvBiasForward> checker(handle); | Checker<ConvBiasForward> checker(handle); | ||||
| if (algo) { | if (algo) { | ||||
| @@ -758,36 +760,72 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||||
| bias_rng = std::make_unique<NormalRNG>(2.f); | bias_rng = std::make_unique<NormalRNG>(2.f); | ||||
| } | } | ||||
| using Param = param::ConvBias; | |||||
| using Format = Param::Format; | |||||
| auto get_z_shape = [&fuse_z, &format](TestArg arg) -> TensorShape { | |||||
| TensorShape z{}; | |||||
| if (fuse_z) { | |||||
| size_t hi, wi, sh, sw, ph, pw, fh, fw; | |||||
| z = arg.src; | |||||
| size_t spatial_idx = 2; | |||||
| if (format == Format::NCHW4) { | |||||
| hi = arg.src[2]; | |||||
| wi = arg.src[3]; | |||||
| fh = arg.filter[2]; | |||||
| fw = arg.filter[3]; | |||||
| z[1] = arg.filter[0] / 4; | |||||
| } else { | |||||
| megdnn_assert(format == Format::CHWN4); | |||||
| hi = arg.src[1]; | |||||
| wi = arg.src[2]; | |||||
| fh = arg.filter[1]; | |||||
| fw = arg.filter[2]; | |||||
| z[0] = arg.filter[3] / 4; | |||||
| spatial_idx = 1; | |||||
| } | |||||
| sh = arg.param.stride_h; | |||||
| sw = arg.param.stride_w; | |||||
| ph = arg.param.pad_h; | |||||
| pw = arg.param.pad_w; | |||||
| size_t ho = infer_conv_shape(hi, fh, sh, ph); | |||||
| size_t wo = infer_conv_shape(wi, fw, sw, pw); | |||||
| z[spatial_idx] = ho; | |||||
| z[spatial_idx + 1] = wo; | |||||
| } | |||||
| return z; | |||||
| }; | |||||
| megdnn_assert(rng != nullptr && bias_rng != nullptr); | megdnn_assert(rng != nullptr && bias_rng != nullptr); | ||||
| checker.set_rng(0, rng.get()) | |||||
| checker.set_rng(0, rng.get()) | |||||
| .set_rng(1, rng.get()) | .set_rng(1, rng.get()) | ||||
| .set_rng(2, rng.get()) | .set_rng(2, rng.get()) | ||||
| .set_rng(3, rng.get()); | .set_rng(3, rng.get()); | ||||
| if (args.empty()) { | if (args.empty()) { | ||||
| std::vector<TestArg> default_args; | std::vector<TestArg> default_args; | ||||
| using Param = param::ConvBias; | |||||
| using Format = Param::Format; | |||||
| if (format == Format::NCHW4) { | if (format == Format::NCHW4) { | ||||
| default_args = get_int8_nchw4_args(3); | default_args = get_int8_nchw4_args(3); | ||||
| } else if (format == Format::CHWN4) { | } else if (format == Format::CHWN4) { | ||||
| default_args = get_int8_chwn4_args(3); | default_args = get_int8_chwn4_args(3); | ||||
| } | } | ||||
| for (auto&& arg : default_args) { | for (auto&& arg : default_args) { | ||||
| auto z = get_z_shape(arg); | |||||
| checker.set_dtype(0, src_dtype) | checker.set_dtype(0, src_dtype) | ||||
| .set_dtype(1, filter_dtype) | .set_dtype(1, filter_dtype) | ||||
| .set_dtype(2, bias_dtype) | .set_dtype(2, bias_dtype) | ||||
| .set_dtype(3, dst_dtype) | |||||
| .set_dtype(4, dst_dtype) | .set_dtype(4, dst_dtype) | ||||
| .set_param(arg.param) | .set_param(arg.param) | ||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| .execs({arg.src, arg.filter, arg.bias, z, {}}); | |||||
| } | } | ||||
| } else { | } else { | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| auto z = get_z_shape(arg); | |||||
| checker.set_dtype(0, src_dtype) | checker.set_dtype(0, src_dtype) | ||||
| .set_dtype(1, filter_dtype) | .set_dtype(1, filter_dtype) | ||||
| .set_dtype(2, bias_dtype) | .set_dtype(2, bias_dtype) | ||||
| .set_dtype(3, dst_dtype) | |||||
| .set_dtype(4, dst_dtype) | .set_dtype(4, dst_dtype) | ||||
| .set_param(arg.param) | .set_param(arg.param) | ||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| .execs({arg.src, arg.filter, arg.bias, z, {}}); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -66,7 +66,7 @@ void check_conv_bias( | |||||
| DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, | DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, | ||||
| Handle* handle, const char* algo = nullptr, | Handle* handle, const char* algo = nullptr, | ||||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4, | param::ConvBias::Format format = param::ConvBias::Format::NCHW4, | ||||
| const std::vector<TestArg>& args = {}); | |||||
| const std::vector<TestArg>& args = {}, bool fuse_z = false); | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| std::vector<conv_bias::TestArg> get_winograd_benchmark_args( | std::vector<conv_bias::TestArg> get_winograd_benchmark_args( | ||||
| @@ -18,10 +18,14 @@ | |||||
| #include "test/cuda/fixture.h" | #include "test/cuda/fixture.h" | ||||
| #include "test/cuda/utils.h" | #include "test/cuda/utils.h" | ||||
| #define V1(x) #x | |||||
| #define V(x) V1(x) | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace test { | namespace test { | ||||
| #if MEGDNN_WITH_BENCHMARK | |||||
| namespace { | namespace { | ||||
| #if MEGDNN_WITH_BENCHMARK | |||||
| struct BenchArgs { | struct BenchArgs { | ||||
| size_t n, ci, hi, wi, co, f, s; | size_t n, ci, hi, wi, co, f, s; | ||||
| }; | }; | ||||
| @@ -29,9 +33,16 @@ struct BenchArgs { | |||||
| std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { | std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { | ||||
| std::vector<BenchArgs> args; | std::vector<BenchArgs> args; | ||||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | ||||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); | |||||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); | |||||
| args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); | |||||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | ||||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | ||||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | ||||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); | |||||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); | |||||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | ||||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | ||||
| @@ -101,13 +112,12 @@ void benchmark_target_algo( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | ||||
| } | } | ||||
| #define V1(x) #x | |||||
| #define V(x) V1(x) | |||||
| #define CUDNN_VERSION_STRING \ | #define CUDNN_VERSION_STRING \ | ||||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | ||||
| benchmarker_cudnn.set_before_exec_callback( | benchmarker_cudnn.set_before_exec_callback( | ||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | ||||
| "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||||
| "ALGO_IMPLICIT_PRECOMP_" | |||||
| "GEMM" CUDNN_VERSION_STRING)); | "GEMM" CUDNN_VERSION_STRING)); | ||||
| benchmarker.set_dtype(0, src_dtype) | benchmarker.set_dtype(0, src_dtype) | ||||
| @@ -141,6 +151,7 @@ void benchmark_target_algo( | |||||
| {}, | {}, | ||||
| {}}) / | {}}) / | ||||
| RUNS; | RUNS; | ||||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||||
| benchmarker_cudnn.set_param(param); | benchmarker_cudnn.set_param(param); | ||||
| auto time_in_ms_cudnn = | auto time_in_ms_cudnn = | ||||
| benchmarker_cudnn.execs( | benchmarker_cudnn.execs( | ||||
| @@ -162,6 +173,47 @@ void benchmark_target_algo( | |||||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | (flo / (time_in_ms_cudnn * 1e-3)), algo, | ||||
| time_in_ms_cudnn / time_in_ms); | time_in_ms_cudnn / time_in_ms); | ||||
| } | } | ||||
| printf("bench with z tensor\n"); | |||||
| for (auto&& arg : args) { | |||||
| Param param; | |||||
| param.pad_h = param.pad_w = arg.f / 2; | |||||
| param.stride_h = param.stride_w = arg.s; | |||||
| param.format = Format::NCHW4; | |||||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||||
| benchmarker.set_param(param); | |||||
| auto time_in_ms = | |||||
| benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||||
| {1, arg.co / 4, 1, 1, 4}, | |||||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||||
| {}}) / | |||||
| RUNS; | |||||
| param.format = Format::NCHW4; | |||||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||||
| benchmarker_cudnn.set_param(param); | |||||
| auto time_in_ms_cudnn = | |||||
| benchmarker_cudnn.execs( | |||||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||||
| {1, arg.co / 4, 1, 1, 4}, | |||||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||||
| {}}) / | |||||
| RUNS; | |||||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||||
| arg.f / (1e12); | |||||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||||
| "time(cudnn)=%.2f %.2fTops, " | |||||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||||
| time_in_ms_cudnn / time_in_ms); | |||||
| } | |||||
| } else if (format == Format::CHWN4) { | } else if (format == Format::CHWN4) { | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| Param param; | Param param; | ||||
| @@ -222,6 +274,7 @@ void benchmark_target_algo( | |||||
| RUNS; | RUNS; | ||||
| param.format = Format::NCHW4; | param.format = Format::NCHW4; | ||||
| benchmarker_cudnn.set_param(param); | benchmarker_cudnn.set_param(param); | ||||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||||
| auto time_in_ms_cudnn = | auto time_in_ms_cudnn = | ||||
| benchmarker_cudnn.execs( | benchmarker_cudnn.execs( | ||||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | ||||
| @@ -242,7 +295,6 @@ void benchmark_target_algo( | |||||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | (flo / (time_in_ms_cudnn * 1e-3)), algo, | ||||
| time_in_ms_cudnn / time_in_ms); | time_in_ms_cudnn / time_in_ms); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -265,15 +317,14 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| benchmarker.set_before_exec_callback( | benchmarker.set_before_exec_callback( | ||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | ||||
| } else { | } else { | ||||
| benchmarker.set_proxy(proxy); | |||||
| benchmarker.set_proxy(proxy); | |||||
| } | } | ||||
| benchmarker_cudnn.set_before_exec_callback( | benchmarker_cudnn.set_before_exec_callback( | ||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | ||||
| "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||||
| "ALGO_IMPLICIT_PRECOMP_" | |||||
| "GEMM" CUDNN_VERSION_STRING)); | "GEMM" CUDNN_VERSION_STRING)); | ||||
| #undef V1 | |||||
| #undef V | |||||
| #undef CUDNN_VERSION_STRING | #undef CUDNN_VERSION_STRING | ||||
| benchmarker.set_dtype(0, src_dtype) | benchmarker.set_dtype(0, src_dtype) | ||||
| @@ -446,12 +497,10 @@ void benchmark_target_algo_with_cudnn_tsc( | |||||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | (flo / (time_in_ms_cudnn * 1e-3)), algo, | ||||
| time_in_ms_cudnn / time_in_ms); | time_in_ms_cudnn / time_in_ms); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace | |||||
| #endif | #endif | ||||
| } // namespace | |||||
| TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { | TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { | ||||
| require_compute_capability(6, 1); | require_compute_capability(6, 1); | ||||
| @@ -1116,6 +1165,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { | |||||
| conv_bias::get_int8_chwn4_args_small_batch(1)); | conv_bias::get_int8_chwn4_args_small_batch(1)); | ||||
| } | } | ||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { | TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { | ||||
| require_compute_capability(6, 1); | require_compute_capability(6, 1); | ||||
| @@ -1182,9 +1232,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) { | |||||
| dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM", | dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM", | ||||
| param::ConvBias::Format::CHWN4); | param::ConvBias::Format::CHWN4); | ||||
| } | } | ||||
| #endif | #endif | ||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #undef V1 | |||||
| #undef V | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -290,6 +290,26 @@ TEST_F(CUDA, POOLING_FORWARD_CHWN4) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) { | |||||
| require_compute_capability(6, 1); | |||||
| using Param = param::Pooling; | |||||
| Checker<Pooling> checker(handle_cuda()); | |||||
| Param param; | |||||
| auto i8_min = std::numeric_limits<int8_t>().min(); | |||||
| auto i8_max = std::numeric_limits<int8_t>().max(); | |||||
| UniformIntRNG int_rng{i8_min, i8_max}; | |||||
| checker.set_dtype(0, dtype::QuantizedS8(0.1f)); | |||||
| param.format = Param::Format::NCHW4; | |||||
| for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE, | |||||
| Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) { | |||||
| param.mode = mode; | |||||
| checker.set_epsilon(1e-3).set_rng(0, &int_rng); | |||||
| checker.set_param(param).exec({{64, 8, 28, 28, 4}, {}}); | |||||
| checker.set_param(param).exec({{15, 8, 28, 28, 4}, {}}); | |||||
| checker.set_param(param).exec({{30, 8, 28, 28, 4}, {}}); | |||||
| } | |||||
| } | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) { | TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) { | ||||
| CUBenchmarker<Pooling> bencher(handle_cuda()); | CUBenchmarker<Pooling> bencher(handle_cuda()); | ||||
| @@ -20,6 +20,14 @@ bool check_compute_capability(int major, int minor) { | |||||
| cuda_check(cudaGetDeviceProperties(&prop, dev)); | cuda_check(cudaGetDeviceProperties(&prop, dev)); | ||||
| return prop.major > major || (prop.major == major && prop.minor >= minor); | return prop.major > major || (prop.major == major && prop.minor >= minor); | ||||
| } | } | ||||
| bool check_compute_capability_eq(int major, int minor) { | |||||
| int dev; | |||||
| cuda_check(cudaGetDevice(&dev)); | |||||
| cudaDeviceProp prop; | |||||
| cuda_check(cudaGetDeviceProperties(&prop, dev)); | |||||
| return (prop.major == major && prop.minor == minor); | |||||
| } | |||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -26,13 +26,28 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace test { | namespace test { | ||||
| bool check_compute_capability(int major, int minor); | bool check_compute_capability(int major, int minor); | ||||
| bool check_compute_capability_eq(int major, int minor); | |||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #define require_compute_capability(x, y) \ | |||||
| do { \ | |||||
| if (!megdnn::test::check_compute_capability((x), (y))) \ | |||||
| return; \ | |||||
| #define require_compute_capability(x, y) \ | |||||
| do { \ | |||||
| if (!megdnn::test::check_compute_capability((x), (y))) { \ | |||||
| printf("skip testcase due to cuda compute capability not " \ | |||||
| "require.(expected:%d.%d)", \ | |||||
| (x), (y)); \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define require_compute_capability_eq(x, y) \ | |||||
| do { \ | |||||
| if (!megdnn::test::check_compute_capability_eq((x), (y))) { \ | |||||
| printf("skip testcase due to cuda compute capability not " \ | |||||
| "equal to %d.%d", \ | |||||
| (x), (y)); \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||