| @@ -3,3 +3,4 @@ | |||
| dnn/src/cuda/conv_bias/int8/kimpl/* binary | |||
| dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | |||
| dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | |||
| dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | |||
| @@ -236,6 +236,7 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { | |||
| } | |||
| #endif | |||
| ConvBiasForwardImpl::AlgoBase* | |||
| ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | |||
| cudnnConvolutionFwdAlgo_t algo) { | |||
| @@ -14,11 +14,11 @@ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
| #include "src/cuda/conv_bias/helper.h" | |||
| #include "src/cuda/conv_bias/opr_impl.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/conv_bias/conv_bias_int8.cuh" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/handle.h" | |||
| #include <cuda.h> | |||
| #include <memory> | |||
| @@ -521,6 +521,7 @@ private: | |||
| std::string m_name; | |||
| }; | |||
| class ConvBiasForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "../elemwise/opr_impl.h" | |||
| @@ -94,5 +95,5 @@ private: | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #include "src/cuda/pooling/opr_impl.h" | |||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
| #include "./pooling2d_int8.cuh" | |||
| #include "src/cuda/utils.h" | |||
| namespace megdnn { | |||
| @@ -67,7 +67,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, | |||
| kern_param.window_h = window_h, kern_param.window_w = window_w, | |||
| kern_param.sh = sh, kern_param.sw = sw; | |||
| auto&& stream = cuda_stream(handle()); | |||
| return pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
| return pooling2d::do_pooling2d_int8_cdiv4hwn4( | |||
| src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | |||
| kern_param, stream, static_cast<uint32_t>(param().mode)); | |||
| } else if (param().format == Format::NCHW4) { | |||
| pooling2d::Param kern_param; | |||
| size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], | |||
| c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; | |||
| c = c * 4; | |||
| size_t ph = param().pad_h, pw = param().pad_w; | |||
| size_t window_h = param().window_h, window_w = param().window_w; | |||
| size_t sh = param().stride_h, sw = param().stride_w; | |||
| kern_param.n = n, kern_param.c = c, kern_param.hi = hi, | |||
| kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, | |||
| kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.window_h = window_h, kern_param.window_w = window_w, | |||
| kern_param.sh = sh, kern_param.sw = sw; | |||
| auto&& stream = cuda_stream(handle()); | |||
| return pooling2d::do_pooling2d_int8_ncdiv4hw4( | |||
| src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(), | |||
| kern_param, stream, static_cast<uint32_t>(param().mode)); | |||
| } | |||
| @@ -8,8 +8,9 @@ | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
| #include "./pooling2d_int8.cuh" | |||
| #include "src/common/opr_param_defs_enumv.cuh" | |||
| #include "src/cuda/query_blocksize.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -360,11 +361,65 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( | |||
| ldg_type res = pooler.get_ans(); | |||
| *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | |||
| } | |||
| template <typename Pooler> | |||
| __global__ void pooling2d_device_template_int8_ncdiv4hw4( | |||
| const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param) { | |||
| const int tid = blockIdx.x * blockDim.x + threadIdx.x; | |||
| using ldg_type = typename Pooler::feed_type; | |||
| static int constexpr pack_size = 4; | |||
| static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); | |||
| MEGDNN_STATIC_ASSERT( | |||
| ldg_width == 1, | |||
| "pooling2d (NCHW4) kernel must use 32bit width ldg instruction"); | |||
| const int wo_ldg = param.wo / ldg_width; | |||
| const int c_packed = param.c / pack_size; | |||
| const int batch = tid / (param.ho * wo_ldg * c_packed); | |||
| const int chw = tid - batch * param.ho * wo_ldg * c_packed; | |||
| const int oc_packed = chw / (param.ho * wo_ldg); | |||
| const int hw = chw - oc_packed * param.ho * wo_ldg; | |||
| const int oh = hw / wo_ldg; | |||
| const int ow = (hw - wo_ldg * oh) * ldg_width; | |||
| if (batch >= param.n || oc_packed >= c_packed || oh >= param.ho || | |||
| ow >= param.wo) | |||
| return; | |||
| const int in_batch_stride = param.hi * param.wi * param.c; | |||
| const int out_batch_stride = param.ho * param.wo * param.c; | |||
| const int in_channel_stride = param.hi * param.wi * pack_size; | |||
| const int out_channel_stride = param.ho * param.wo * pack_size; | |||
| const int8_t* __restrict__ g_src_ptr = | |||
| src + batch * in_batch_stride + oc_packed * in_channel_stride; | |||
| int8_t* __restrict__ g_dst_ptr = dst + batch * out_batch_stride + | |||
| oc_packed * out_channel_stride + | |||
| (oh * param.wo + ow) * pack_size; | |||
| Pooler pooler(param.window_h * param.window_w); | |||
| pooler.init(); | |||
| for (int fh = 0; fh < param.window_h; fh++) { | |||
| uint32_t ih = oh * param.sh + fh - param.ph; | |||
| for (int fw = 0; fw < param.window_w; fw++) { | |||
| uint32_t iw = ow * param.sw + fw - param.pw; | |||
| if (ih < param.hi && iw < param.wi) { | |||
| const int8_t* __restrict__ cur_src_ptr = | |||
| g_src_ptr + (ih * param.wi + iw) * pack_size; | |||
| ldg_type sval = __ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr)); | |||
| pooler.feed(sval); | |||
| } | |||
| } | |||
| } | |||
| ldg_type res = pooler.get_ans(); | |||
| *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; | |||
| } | |||
| }; // namespace | |||
| void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
| const int8_t* d_src, int8_t* d_dst, const Param& param, | |||
| cudaStream_t stream, uint32_t mode) { | |||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, | |||
| int8_t* d_dst, | |||
| const Param& param, | |||
| cudaStream_t stream, | |||
| uint32_t mode) { | |||
| using Mode = megdnn::param_enumv::Pooling::Mode; | |||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | |||
| uint32_t vthreads_x = 0, vthreads_y = param.c / 4; | |||
| @@ -397,8 +452,7 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
| } | |||
| #undef dispatch_pooling_mode | |||
| constexpr uint32_t threads_x = 16; | |||
| uint32_t nr_threads = | |||
| _get_kern_block_size(reinterpret_cast<const void*>(kern)); | |||
| uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||
| uint32_t nr_threads_x = std::min(threads_x, vthreads_x), | |||
| nr_threads_y = std::min(nr_threads / nr_threads_x, vthreads_y); | |||
| uint32_t nr_blocks_x = param.ho * param.wo, | |||
| @@ -410,4 +464,34 @@ void megdnn::cuda::pooling2d::_do_pooling2d_int8_cdiv4hwn4( | |||
| after_kernel_launch(); | |||
| } | |||
| void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, | |||
| int8_t* d_dst, | |||
| const Param& param, | |||
| cudaStream_t stream, | |||
| uint32_t mode) { | |||
| using Mode = megdnn::param_enumv::Pooling::Mode; | |||
| void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); | |||
| uint32_t vthreads = param.n * param.c * param.ho * param.wo / 4; | |||
| switch (mode) { | |||
| case Mode::MAX: | |||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
| MaxPooler<int8_t, int32_t>>; | |||
| break; | |||
| case Mode::AVERAGE: | |||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
| MeanIncludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||
| break; | |||
| case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: | |||
| kern = pooling2d_device_template_int8_ncdiv4hw4< | |||
| MeanExcludeRoundedPooler<int8_t, int32_t, int32_t>>; | |||
| break; | |||
| default: | |||
| megdnn_assert(false, "invalid pooling mode"); | |||
| } | |||
| uint32_t nr_threads = query_blocksize_for_kernel(kern); | |||
| nr_threads = std::min(nr_threads, vthreads); | |||
| uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
| kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); | |||
| after_kernel_launch(); | |||
| } | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,12 +1,13 @@ | |||
| /** | |||
| * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cuh | |||
| * \file dnn/src/cuda/pooling/pooling2d_int8.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -20,15 +21,16 @@ struct Param { | |||
| int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; | |||
| }; | |||
| uint32_t _get_kern_block_size(const void* kern); | |||
| void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||
| const Param& param, cudaStream_t stream, | |||
| uint32_t mode); | |||
| void _do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, | |||
| const Param& param, cudaStream_t stream, | |||
| uint32_t mode); | |||
| void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, | |||
| const Param& param, cudaStream_t stream, | |||
| uint32_t mode); | |||
| } // namespace pooling2d | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cuda.doxygen | |||
| @@ -1,27 +0,0 @@ | |||
| /** | |||
| * \file dnn/src/cuda/pooling/pooling2d_int8_cdiv4hwn4.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./pooling2d_int8_cdiv4hwn4.cuh" | |||
| #include "src/cuda/query_blocksize.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace pooling2d { | |||
| uint32_t _get_kern_block_size(const void* kern) { | |||
| uint32_t ret = query_blocksize_for_kernel(kern); | |||
| return ret; | |||
| } | |||
| } // namespace pooling2d | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -82,6 +82,11 @@ void cuda::__throw_cusolver_error__(cusolverStatus_t err, const char* msg) { | |||
| megdnn_throw(s.c_str()); | |||
| } | |||
| void cuda::__throw_cuda_driver_error__(CUresult err, const char* msg) { | |||
| auto s = ssprintf("cuda driver error %d occurred; expr: %s", int(err), msg); | |||
| megdnn_throw(s.c_str()); | |||
| } | |||
| void cuda::report_error(const char *msg) { | |||
| megdnn_throw(msg); | |||
| MEGDNN_MARK_USED_VAR(msg); | |||
| @@ -118,9 +123,31 @@ bool cuda::is_compute_capability_required(int major, int minor) { | |||
| (device_prop.major == major && device_prop.minor >= minor); | |||
| } | |||
| bool cuda::is_compute_capability_equalto(int major, int minor) { | |||
| auto&& device_prop = cuda::current_device_prop(); | |||
| return device_prop.major == major && device_prop.minor == minor; | |||
| } | |||
| size_t cuda::max_batch_x_channel_size() { | |||
| return current_device_prop().maxGridSize[2]; | |||
| } | |||
| const char* cuda::current_device_arch_name() { | |||
| auto&& device_prop = current_device_prop(); | |||
| int cap = 10 * device_prop.major + device_prop.minor; | |||
| if (cap >= 50 && cap < 60) | |||
| return "maxwell"; | |||
| else if (cap >= 60 && cap < 70) | |||
| return "pascal"; | |||
| else if (cap >= 70 && cap < 75) | |||
| return "volta"; | |||
| else if (cap >= 75 && cap < 80) | |||
| return "turing"; | |||
| else if (cap >= 80) | |||
| return "ampere"; | |||
| megdnn_throw( | |||
| ssprintf("unsupported cuda compute capability %d", cap).c_str()); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -53,6 +53,14 @@ | |||
| } \ | |||
| } while (0) | |||
| #define cucheck(_x) \ | |||
| do { \ | |||
| CUresult _err = (_x); \ | |||
| if (_err != CUDA_SUCCESS) { \ | |||
| ::megdnn::cuda::__throw_cuda_driver_error__(_err, #_x); \ | |||
| } \ | |||
| } while (0) | |||
| #define after_kernel_launch() \ | |||
| do { \ | |||
| cuda_check(cudaGetLastError()); \ | |||
| @@ -84,6 +92,7 @@ MEGDNN_NORETURN void __throw_cublas_error__(cublasStatus_t err, | |||
| const char* msg); | |||
| MEGDNN_NORETURN void __throw_cusolver_error__(cusolverStatus_t err, | |||
| const char* msg); | |||
| MEGDNN_NORETURN void __throw_cuda_driver_error__(CUresult err, const char* msg); | |||
| MEGDNN_NORETURN void report_error(const char* msg); | |||
| template <typename T, size_t N> | |||
| @@ -57,10 +57,15 @@ cudaDeviceProp current_device_prop(); | |||
| //! check compute capability satisfied with given sm version | |||
| bool is_compute_capability_required(int major, int minor); | |||
| //! check compute capability equal to the given sm version | |||
| bool is_compute_capability_equalto(int major, int minor); | |||
| //! get the CUDNN_MAX_BATCH_X_CHANNEL_SIZE, it's just return the max size of the | |||
| //! third demension | |||
| size_t max_batch_x_channel_size(); | |||
| const char* current_device_arch_name(); | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -493,6 +493,7 @@ std::vector<TestArg> get_int8_nchw44_args(size_t kernel_size, size_t pack_size, | |||
| return args; | |||
| } | |||
| std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | |||
| std::vector<TestArg> args; | |||
| param::ConvBias cur_param; | |||
| @@ -528,6 +529,7 @@ std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) { | |||
| return args; | |||
| } | |||
| std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) { | |||
| std::vector<TestArg> args; | |||
| param::ConvBias cur_param; | |||
| @@ -728,7 +730,7 @@ std::vector<TestArg> get_int8_chwn4_tensorcore_args(size_t kernel_size) { | |||
| void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
| DType dst_dtype, Handle* handle, const char* algo, | |||
| param::ConvBias::Format format, | |||
| const std::vector<TestArg>& args) { | |||
| const std::vector<TestArg>& args, bool fuse_z) { | |||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
| Checker<ConvBiasForward> checker(handle); | |||
| if (algo) { | |||
| @@ -758,36 +760,72 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
| bias_rng = std::make_unique<NormalRNG>(2.f); | |||
| } | |||
| using Param = param::ConvBias; | |||
| using Format = Param::Format; | |||
| auto get_z_shape = [&fuse_z, &format](TestArg arg) -> TensorShape { | |||
| TensorShape z{}; | |||
| if (fuse_z) { | |||
| size_t hi, wi, sh, sw, ph, pw, fh, fw; | |||
| z = arg.src; | |||
| size_t spatial_idx = 2; | |||
| if (format == Format::NCHW4) { | |||
| hi = arg.src[2]; | |||
| wi = arg.src[3]; | |||
| fh = arg.filter[2]; | |||
| fw = arg.filter[3]; | |||
| z[1] = arg.filter[0] / 4; | |||
| } else { | |||
| megdnn_assert(format == Format::CHWN4); | |||
| hi = arg.src[1]; | |||
| wi = arg.src[2]; | |||
| fh = arg.filter[1]; | |||
| fw = arg.filter[2]; | |||
| z[0] = arg.filter[3] / 4; | |||
| spatial_idx = 1; | |||
| } | |||
| sh = arg.param.stride_h; | |||
| sw = arg.param.stride_w; | |||
| ph = arg.param.pad_h; | |||
| pw = arg.param.pad_w; | |||
| size_t ho = infer_conv_shape(hi, fh, sh, ph); | |||
| size_t wo = infer_conv_shape(wi, fw, sw, pw); | |||
| z[spatial_idx] = ho; | |||
| z[spatial_idx + 1] = wo; | |||
| } | |||
| return z; | |||
| }; | |||
| megdnn_assert(rng != nullptr && bias_rng != nullptr); | |||
| checker.set_rng(0, rng.get()) | |||
| checker.set_rng(0, rng.get()) | |||
| .set_rng(1, rng.get()) | |||
| .set_rng(2, rng.get()) | |||
| .set_rng(3, rng.get()); | |||
| if (args.empty()) { | |||
| std::vector<TestArg> default_args; | |||
| using Param = param::ConvBias; | |||
| using Format = Param::Format; | |||
| if (format == Format::NCHW4) { | |||
| default_args = get_int8_nchw4_args(3); | |||
| } else if (format == Format::CHWN4) { | |||
| default_args = get_int8_chwn4_args(3); | |||
| } | |||
| for (auto&& arg : default_args) { | |||
| auto z = get_z_shape(arg); | |||
| checker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype) | |||
| .set_param(arg.param) | |||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
| .execs({arg.src, arg.filter, arg.bias, z, {}}); | |||
| } | |||
| } else { | |||
| for (auto&& arg : args) { | |||
| auto z = get_z_shape(arg); | |||
| checker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype) | |||
| .set_param(arg.param) | |||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
| .execs({arg.src, arg.filter, arg.bias, z, {}}); | |||
| } | |||
| } | |||
| } | |||
| @@ -66,7 +66,7 @@ void check_conv_bias( | |||
| DType src_dtype, DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
| Handle* handle, const char* algo = nullptr, | |||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4, | |||
| const std::vector<TestArg>& args = {}); | |||
| const std::vector<TestArg>& args = {}, bool fuse_z = false); | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| std::vector<conv_bias::TestArg> get_winograd_benchmark_args( | |||
| @@ -18,10 +18,14 @@ | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/cuda/utils.h" | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| namespace megdnn { | |||
| namespace test { | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| struct BenchArgs { | |||
| size_t n, ci, hi, wi, co, f, s; | |||
| }; | |||
| @@ -29,9 +33,16 @@ struct BenchArgs { | |||
| std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | |||
| @@ -101,13 +112,12 @@ void benchmark_target_algo( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
| } | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| #define CUDNN_VERSION_STRING \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_" | |||
| "GEMM" CUDNN_VERSION_STRING)); | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| @@ -141,6 +151,7 @@ void benchmark_target_algo( | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| @@ -162,6 +173,47 @@ void benchmark_target_algo( | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| printf("bench with z tensor\n"); | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = Format::NCHW4; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| auto time_in_ms = | |||
| benchmarker.execs({{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| param.format = Format::NCHW4; | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| {arg.co, arg.ci / 4, arg.f, arg.f, 4}, | |||
| {1, arg.co / 4, 1, 1, 4}, | |||
| {arg.n, arg.co / 4, ho, wo, 4}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * | |||
| arg.f / (1e12); | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}; | |||
| printf("src=%s, filter=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), algo, | |||
| time_in_ms, (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } else if (format == Format::CHWN4) { | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| @@ -222,6 +274,7 @@ void benchmark_target_algo( | |||
| RUNS; | |||
| param.format = Format::NCHW4; | |||
| benchmarker_cudnn.set_param(param); | |||
| param.nonlineMode = Param::NonlineMode::IDENTITY; | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs( | |||
| {{arg.n, arg.ci / 4, arg.hi, arg.wi, 4}, | |||
| @@ -242,7 +295,6 @@ void benchmark_target_algo( | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| } | |||
| @@ -265,15 +317,14 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
| benchmarker.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo)); | |||
| } else { | |||
| benchmarker.set_proxy(proxy); | |||
| benchmarker.set_proxy(proxy); | |||
| } | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_" | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_" | |||
| "GEMM" CUDNN_VERSION_STRING)); | |||
| #undef V1 | |||
| #undef V | |||
| #undef CUDNN_VERSION_STRING | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| @@ -446,12 +497,10 @@ void benchmark_target_algo_with_cudnn_tsc( | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| #endif | |||
| } // namespace | |||
| TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { | |||
| require_compute_capability(6, 1); | |||
| @@ -1116,6 +1165,7 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { | |||
| conv_bias::get_int8_chwn4_args_small_batch(1)); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { | |||
| require_compute_capability(6, 1); | |||
| @@ -1182,9 +1232,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) { | |||
| dtype::QuantizedS8{1.0f}, "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM", | |||
| param::ConvBias::Format::CHWN4); | |||
| } | |||
| #endif | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| #undef V1 | |||
| #undef V | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -290,6 +290,26 @@ TEST_F(CUDA, POOLING_FORWARD_CHWN4) { | |||
| } | |||
| } | |||
| TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) { | |||
| require_compute_capability(6, 1); | |||
| using Param = param::Pooling; | |||
| Checker<Pooling> checker(handle_cuda()); | |||
| Param param; | |||
| auto i8_min = std::numeric_limits<int8_t>().min(); | |||
| auto i8_max = std::numeric_limits<int8_t>().max(); | |||
| UniformIntRNG int_rng{i8_min, i8_max}; | |||
| checker.set_dtype(0, dtype::QuantizedS8(0.1f)); | |||
| param.format = Param::Format::NCHW4; | |||
| for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE, | |||
| Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) { | |||
| param.mode = mode; | |||
| checker.set_epsilon(1e-3).set_rng(0, &int_rng); | |||
| checker.set_param(param).exec({{64, 8, 28, 28, 4}, {}}); | |||
| checker.set_param(param).exec({{15, 8, 28, 28, 4}, {}}); | |||
| checker.set_param(param).exec({{30, 8, 28, 28, 4}, {}}); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) { | |||
| CUBenchmarker<Pooling> bencher(handle_cuda()); | |||
| @@ -20,6 +20,14 @@ bool check_compute_capability(int major, int minor) { | |||
| cuda_check(cudaGetDeviceProperties(&prop, dev)); | |||
| return prop.major > major || (prop.major == major && prop.minor >= minor); | |||
| } | |||
| bool check_compute_capability_eq(int major, int minor) { | |||
| int dev; | |||
| cuda_check(cudaGetDevice(&dev)); | |||
| cudaDeviceProp prop; | |||
| cuda_check(cudaGetDeviceProperties(&prop, dev)); | |||
| return (prop.major == major && prop.minor == minor); | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -26,13 +26,28 @@ | |||
| namespace megdnn { | |||
| namespace test { | |||
| bool check_compute_capability(int major, int minor); | |||
| bool check_compute_capability_eq(int major, int minor); | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| #define require_compute_capability(x, y) \ | |||
| do { \ | |||
| if (!megdnn::test::check_compute_capability((x), (y))) \ | |||
| return; \ | |||
| #define require_compute_capability(x, y) \ | |||
| do { \ | |||
| if (!megdnn::test::check_compute_capability((x), (y))) { \ | |||
| printf("skip testcase due to cuda compute capability not " \ | |||
| "require.(expected:%d.%d)", \ | |||
| (x), (y)); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| #define require_compute_capability_eq(x, y) \ | |||
| do { \ | |||
| if (!megdnn::test::check_compute_capability_eq((x), (y))) { \ | |||
| printf("skip testcase due to cuda compute capability not " \ | |||
| "equal to %d.%d", \ | |||
| (x), (y)); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| // vim: syntax=cpp.doxygen | |||