GitOrigin-RevId: 55fb2a9b25
@@ -87,6 +87,23 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
        return pooling2d::do_pooling2d_int8_ncdiv4hw4(
                src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
                kern_param, stream, static_cast<uint32_t>(param().mode));
    } else if (param().format == Format::NCHW32) {
        pooling2d::Param kern_param;
        size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
               c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
        c = c * 32;
        size_t ph = param().pad_h, pw = param().pad_w;
        size_t window_h = param().window_h, window_w = param().window_w;
        size_t sh = param().stride_h, sw = param().stride_w;
        kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
        kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
        kern_param.ph = ph, kern_param.pw = pw,
        kern_param.window_h = window_h, kern_param.window_w = window_w,
        kern_param.sh = sh, kern_param.sw = sw;
        auto&& stream = cuda_stream(handle());
        return pooling2d::do_pooling2d_int8_ncdiv32hw32(
                src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
                kern_param, stream, static_cast<uint32_t>(param().mode));
    }
    auto handle = cudnn_handle(this->handle());
    setup_descs(src.layout, dst.layout);
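For orientation (not part of the patch): NCHW32 stores a tensor as (N, C/32, H, W, 32), so channels are split into packs of 32 that are contiguous in the innermost dimension, which is why the branch above recovers the real channel count with `c = c * 32` from `src.layout[1]`. A minimal host-side sketch of the address math, with an illustrative helper name:

```cpp
// Illustrative helper, not from the patch: map a logical NCHW coordinate
// into an NCHW32 buffer laid out as (N, C/32, H, W, 32).
#include <cstddef>

inline size_t nchw32_offset(size_t n, size_t c, size_t h, size_t w,
                            size_t C, size_t H, size_t W) {
    const size_t pack = 32;
    // c / pack selects the channel pack; c % pack is the slot inside it.
    return (((n * (C / pack) + c / pack) * H + h) * W + w) * pack + c % pack;
}
```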
@@ -413,6 +413,62 @@ __global__ void pooling2d_device_template_int8_ncdiv4hw4(
    *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res;
}

template <typename Pooler>
__global__ void pooling2d_device_template_int8_ncdiv32hw32(
        const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    using ldg_type = typename Pooler::feed_type;
    static int constexpr pack_size = 32;
    static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t);
    static int constexpr ldg_width_bytes = sizeof(ldg_type);
    static int constexpr section = pack_size / sizeof(ldg_type);
    MEGDNN_STATIC_ASSERT(
            ldg_width == 4,
            "pooling2d (NCHW32) kernel must use 128bit width ldg instruction");
    const int c_packed = param.c / pack_size;
    const int batch = tid / (param.ho * param.wo * c_packed * section);
    const int batch_residual =
            tid - batch * param.ho * param.wo * c_packed * section;
    const int oc = batch_residual / (param.ho * param.wo * section);
    const int oc_residual = batch_residual - oc * param.ho * param.wo * section;
    const int oh = oc_residual / (param.wo * section);
    const int oh_residual = (oc_residual - oh * param.wo * section);
    const int ow = oh_residual / section;
    const int sec = oh_residual - ow * section;
    if (batch >= param.n || oc >= c_packed || oh >= param.ho || ow >= param.wo)
        return;
    const int in_batch_stride = param.hi * param.wi * param.c;
    const int out_batch_stride = param.ho * param.wo * param.c;
    const int in_channel_stride = param.hi * param.wi * pack_size;
    const int out_channel_stride = param.ho * param.wo * pack_size;
    const int8_t* __restrict__ g_src_ptr = src + batch * in_batch_stride +
                                           oc * in_channel_stride +
                                           sec * ldg_width_bytes;
    int8_t* __restrict__ g_dst_ptr =
            dst + batch * out_batch_stride + oc * out_channel_stride +
            (oh * param.wo + ow) * pack_size + sec * ldg_width_bytes;
    Pooler pooler(param.window_h * param.window_w);
    pooler.init();
    for (int fh = 0; fh < param.window_h; fh++) {
        uint32_t ih = oh * param.sh + fh - param.ph;
        for (int fw = 0; fw < param.window_w; fw++) {
            uint32_t iw = ow * param.sw + fw - param.pw;
            if (ih < param.hi && iw < param.wi) {
                const int8_t* __restrict__ cur_src_ptr =
                        g_src_ptr + (ih * param.wi + iw) * pack_size;
                ldg_type sval =
                        __ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr));
                pooler.feed(sval);
            }
        }
    }
    ldg_type res = pooler.get_ans();
    *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res;
}
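Two details worth noting in the kernel above. First, `ldg_type` is the 16-byte `int4`, so `section = 32 / 16 = 2`: each thread owns one 16-byte half of a 32-byte channel pack, and `tid` is peeled apart by repeated division. Second, `ih` and `iw` are `uint32_t`, so an out-of-range coordinate produced by `fh - param.ph` wraps to a huge value, letting the single `ih < param.hi && iw < param.wi` test cover both padding bounds. A host-side mirror of the index decomposition (names illustrative, not from the patch):

```cpp
// Host-side sketch of the tid decomposition used in
// pooling2d_device_template_int8_ncdiv32hw32 above.
struct Coord { int batch, oc, oh, ow, sec; };

Coord decompose(int tid, int c_packed, int ho, int wo, int section) {
    Coord r;
    r.batch = tid / (ho * wo * c_packed * section);
    int rem = tid - r.batch * ho * wo * c_packed * section;
    r.oc = rem / (ho * wo * section);
    rem -= r.oc * ho * wo * section;
    r.oh = rem / (wo * section);
    rem -= r.oh * wo * section;
    r.ow = rem / section;
    r.sec = rem - r.ow * section;  // which 16-byte half of the 32-byte pack
    return r;
}
```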
};  // namespace

void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src,
@@ -494,4 +550,43 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src,
    kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param);
    after_kernel_launch();
}

void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src,
                                                            int8_t* d_dst,
                                                            const Param& param,
                                                            cudaStream_t stream,
                                                            uint32_t mode) {
    using Mode = megdnn::param_enumv::Pooling::Mode;
    void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param);
    uint32_t vthreads = param.n * param.c * param.ho * param.wo / 16;
    switch (mode) {
        case Mode::MAX:
            kern = pooling2d_device_template_int8_ncdiv32hw32<
                    MaxPooler<int8_t, int4>>;
            break;
        case Mode::AVERAGE:
            kern = pooling2d_device_template_int8_ncdiv32hw32<
                    MeanIncludeRoundedPooler<int8_t, int4, int32_t>>;
            break;
        case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
            kern = pooling2d_device_template_int8_ncdiv32hw32<
                    MeanExcludeRoundedPooler<int8_t, int4, int32_t>>;
            break;
        default:
            megdnn_assert(false, "invalid pooling mode");
    }
    uint32_t nr_threads = query_blocksize_for_kernel(kern);
    nr_threads = std::min(nr_threads, vthreads);
    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
    kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param);
    after_kernel_launch();
}
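Since every thread covers 16 bytes, `vthreads` divides the total element count by 16. A worked example under assumed sizes (the actual block size is whatever `query_blocksize_for_kernel` returns; 256 here is only an assumption):

```cpp
// Illustrative launch-config arithmetic for N=64, C=256, Ho=Wo=14.
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t vthreads = 64u * 256u * 14u * 14u / 16u;  // = 200704
    uint32_t nr_threads = 256;  // assumed result of query_blocksize_for_kernel
    uint32_t nr_blocks = (vthreads + nr_threads - 1) / nr_threads;  // = 784
    printf("vthreads=%u, blocks=%u\n", vthreads, nr_blocks);
    return 0;
}
```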
#undef FEED1
#undef FEED2
#undef FEED3
#undef ANS1
#undef ANS2
#undef ANS4

// vim: syntax=cuda.doxygen

@@ -29,6 +29,9 @@ void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst,
                                 const Param& param, cudaStream_t stream,
                                 uint32_t mode);
void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst,
                                   const Param& param, cudaStream_t stream,
                                   uint32_t mode);
}  // namespace pooling2d
}  // namespace cuda
}  // namespace megdnn
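A hypothetical call site for the newly declared entry point, with the Param fields packed the same way `PoolingForwardImpl::exec` does above (this fragment assumes megdnn's internal CUDA headers; `d_src`, `d_dst`, and `stream` are assumed to be a valid device buffer pair and stream):

```cpp
// Hypothetical usage sketch, not from the patch.
void run_nchw32_max_pool(const int8_t* d_src, int8_t* d_dst,
                         cudaStream_t stream) {
    megdnn::cuda::pooling2d::Param p;
    p.n = 64; p.c = 256;                 // 256 channels = 8 packs of 32
    p.hi = p.wi = 28; p.ho = p.wo = 14;  // (28 - 2) / 2 + 1 = 14
    p.ph = p.pw = 0; p.window_h = p.window_w = 2; p.sh = p.sw = 2;
    megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(
            d_src, d_dst, p, stream,
            static_cast<uint32_t>(megdnn::param_enumv::Pooling::Mode::MAX));
}
```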
@@ -310,6 +310,26 @@ TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW4) {
    }
}

TEST_F(CUDA, POOLING_FORWARD_INT8_NCHW32) {
    require_compute_capability(6, 1);
    using Param = param::Pooling;
    Checker<Pooling> checker(handle_cuda());
    Param param;
    auto i8_min = std::numeric_limits<int8_t>().min();
    auto i8_max = std::numeric_limits<int8_t>().max();
    UniformIntRNG int_rng{i8_min, i8_max};
    checker.set_dtype(0, dtype::QuantizedS8(0.1f));
    param.format = Param::Format::NCHW32;
    for (auto mode : {Param::Mode::MAX, Param::Mode::AVERAGE,
                      Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING}) {
        param.mode = mode;
        checker.set_epsilon(1e-3).set_rng(0, &int_rng);
        checker.set_param(param).exec({{64, 8, 28, 28, 32}, {}});
        checker.set_param(param).exec({{15, 8, 28, 28, 32}, {}});
        checker.set_param(param).exec({{30, 8, 28, 28, 32}, {}});
    }
}
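The shapes above are NCHW32 tensors, e.g. `{64, 8, 28, 28, 32}` is plain NCHW `{64, 256, 28, 28}`. For intuition about what the Checker validates (its real reference implementation lives elsewhere in megdnn), a naive MAX-pooling reference over an NCHW32 int8 buffer might look like this; all names are illustrative:

```cpp
// Naive reference (illustrative, not the Checker's actual implementation)
// of MAX pooling on an NCHW32 int8 tensor.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int8_t> ref_max_pool_nchw32(
        const std::vector<int8_t>& src, size_t N, size_t Cb /* = C/32 */,
        size_t Hi, size_t Wi, size_t win, size_t stride, size_t pad) {
    size_t Ho = (Hi + 2 * pad - win) / stride + 1;
    size_t Wo = (Wi + 2 * pad - win) / stride + 1;
    std::vector<int8_t> dst(N * Cb * Ho * Wo * 32);
    for (size_t n = 0; n < N; ++n)
    for (size_t cb = 0; cb < Cb; ++cb)
    for (size_t oh = 0; oh < Ho; ++oh)
    for (size_t ow = 0; ow < Wo; ++ow)
    for (size_t v = 0; v < 32; ++v) {  // slot inside the channel pack
        int8_t best = INT8_MIN;
        for (size_t fh = 0; fh < win; ++fh)
        for (size_t fw = 0; fw < win; ++fw) {
            ptrdiff_t ih = ptrdiff_t(oh * stride + fh) - ptrdiff_t(pad);
            ptrdiff_t iw = ptrdiff_t(ow * stride + fw) - ptrdiff_t(pad);
            if (ih < 0 || iw < 0 || ih >= ptrdiff_t(Hi) || iw >= ptrdiff_t(Wi))
                continue;  // padded region contributes nothing to MAX
            size_t s = (((n * Cb + cb) * Hi + ih) * Wi + iw) * 32 + v;
            best = std::max(best, src[s]);
        }
        dst[(((n * Cb + cb) * Ho + oh) * Wo + ow) * 32 + v] = best;
    }
    return dst;
}
```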
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) {
    CUBenchmarker<Pooling> bencher(handle_cuda());
@@ -331,13 +351,17 @@ TEST_F(CUDA, BENCHMARK_POOLING_CHWN4) {
        param.format = Param::Format::CHWN4;
        bencher.set_param(param);
        auto time_chwn4 = bencher.execs({{C / 4, H, W, N, 4}, {}}) / nr_times;
        auto time_nchw32 =
                bencher.execs({{N, C / 32, H, W, 32}, {}}) / nr_times;
        size_t oh = infer_conv_shape(H, window, stride, padding),
               ow = infer_conv_shape(W, window, stride, padding);
        float io = (N * C * H * W + N * C * oh * ow) * sizeof(int8_t);
        printf("time(cudnn)=%.2f ms, time(chwn4)=%.2f ms, "
               "bandwidth(cudnn)=%.2f Gb/s, bandwidth(chwn4)=%.2f Gb/s\n",
               time_cudnn, time_chwn4, io / (1e6 * time_cudnn),
               io / (1e6 * time_chwn4));
        printf("time(cudnn)=%.2f ms, time(chwn4)=%.2f ms, time(nchw32)=%.2f "
               "ms, "
               "bandwidth(cudnn)=%.2f Gb/s, bandwidth(chwn4)=%.2f Gb/s, "
               "bandwidth(nchw32)=%.2f Gb/s\n",
               time_cudnn, time_chwn4, time_nchw32, io / (1e6 * time_cudnn),
               io / (1e6 * time_chwn4), io / (1e6 * time_nchw32));
    };
    run_bench(64, 64, 112, 112, 2, 1, 2);
    run_bench(256, 64, 112, 112, 2, 1, 2);
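The `io` byte count and the printed bandwidth follow directly from the shapes; note that `io / (1e6 * time_ms)` yields GB/s, since `io` is in bytes and the time is in milliseconds. A worked example for `run_bench(64, 64, 112, 112, 2, 1, 2)` (the measured time is an assumption):

```cpp
// Illustrative arithmetic behind the bandwidth printout above.
#include <cstdio>

int main() {
    int N = 64, C = 64, H = 112, W = 112, window = 2, pad = 1, stride = 2;
    int oh = (H + 2 * pad - window) / stride + 1;  // = 57
    int ow = (W + 2 * pad - window) / stride + 1;  // = 57
    double io = double(N) * C * H * W + double(N) * C * oh * ow;  // bytes
    double time_ms = 1.0;  // assumed measured time
    printf("oh=%d ow=%d io=%.0f B -> %.2f GB/s at %.1f ms\n",
           oh, ow, io, io / (1e6 * time_ms), time_ms);
    return 0;
}
```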
@@ -1090,14 +1090,24 @@ EnableTensorCorePass::make_tensorcore_converter() {
    size_t nr_inps = opr->input().size();
    MGB_MARK_USED_VAR(nr_inps);
    mgb_assert(nr_inps == 1);
    if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
        mgb_assert(opr->input(0)->shape().ndim == 5 &&
                   opr->input(0)->shape()[4] == 4);
        mgb_assert(new_inp[0]->shape().ndim == 5 &&
                   new_inp[0]->shape()[4] == 32);
    size_t nr_channels = opr->input(0)->shape()[1] * 4;
    if (nr_channels % 32 == 0) {  // use nchw32 format
        VarNode* new_inp_var = new_inp[0];
        if (opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
            new_inp_var =
                    RelayoutPlaceholder::make(
                            new_inp[0], RelayoutPlaceholder::LayoutType::
                                                NCHW4_TO_NCHW32)
                            .node();
        } else {
            mgb_assert(opr->input(0)->shape().ndim == 5 &&
                       opr->input(0)->shape()[4] == 4);
            mgb_assert(new_inp[0]->shape().ndim == 5 &&
                       new_inp[0]->shape()[4] == 32);
        }
        auto new_param = pooling.param();
        new_param.format = Format::NCHW32;
        auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
        auto new_pooling = opr::PoolingForward::make(new_inp_var, new_param,
                                                     opr->config());
        return new_pooling.node()->owner_opr();
    }
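The pass rewrites an NCHW4 pooling to NCHW32 whenever the channel count (`shape[1] * 4`) is divisible by 32, inserting a `RelayoutPlaceholder` when the incoming var is still in NCHW4 (shapes equal) and otherwise asserting the input was already converted upstream. Conceptually the relayout gathers eight consecutive 4-channel packs into one 32-channel pack; a host-side sketch (illustrative, not the pass's actual kernel):

```cpp
// Illustrative NCHW4 -> NCHW32 relayout: (N, C/4, H, W, 4) to
// (N, C/32, H, W, 32). All names here are assumptions.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int8_t> nchw4_to_nchw32(const std::vector<int8_t>& src,
                                    size_t N, size_t C, size_t H, size_t W) {
    std::vector<int8_t> dst(src.size());
    for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c)
    for (size_t h = 0; h < H; ++h)
    for (size_t w = 0; w < W; ++w) {
        size_t s = (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 + c % 4;
        size_t d = (((n * (C / 32) + c / 32) * H + h) * W + w) * 32 + c % 32;
        dst[d] = src[s];
    }
    return dst;
}
```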
@@ -1989,6 +1989,74 @@ TEST(TestEnableTensorCore, ConvBiasWithZ) {
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}

TEST(TestEnableTensorCore, Pooling) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
| printf("This testcast ignored due to insufficient cuda cap(got: %d, " | |||
| "expected: %d)\n", | |||
| sm_ver, 75); | |||
| return; | |||
| } | |||
| HostTensorGenerator<dtype::Int8> gen; | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp, | |||
| const DType& dtype) { | |||
| return opr::TypeCvt::make( | |||
| opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), | |||
| dtype); | |||
| }; | |||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||
| const DType& dtype) { | |||
| return opr::TypeCvt::make( | |||
| opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
| .rename(name), | |||
| dtype); | |||
| }; | |||
| auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)), | |||
| w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)), | |||
| b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)), | |||
| z = mkvar("b1", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)); | |||
| opr::ConvBias::Param param; | |||
| param.format = opr::ConvBias::Param::Format::NCHW4; | |||
| param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; | |||
| param.stride_h = param.stride_w = 1; | |||
| param.pad_h = param.pad_w = 1; | |||
| auto y = opr::ConvBias::make(x, w, b, z, param, {}, | |||
| OperatorNodeConfig{dtype::QuantizedS8(2.5f)}); | |||
| opr::Pooling::Param pool_param; | |||
| pool_param.format = opr::Pooling::Param::Format::NCHW4; | |||
| y = opr::Pooling::make(y, pool_param); | |||
| y = opr::TypeCvt::make(y, dtype::Float32()); | |||
| SymbolVar y_opt; | |||
| SymbolVar y_no_tc; | |||
| { | |||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||
| options.enable_fuse_conv_bias_nonlinearity().enable_nchw32(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
| } | |||
| ASSERT_EQ(opr::Pooling::Param::Format::NCHW32, | |||
| find_opr<opr::Pooling>(y_opt).param().format); | |||
| { | |||
| auto options = gopt::OptimizeForInferenceOptions{}; | |||
| options.enable_fuse_conv_bias_nonlinearity(); | |||
| unpack_vector(gopt::optimize_for_inference({y}, options), y_no_tc); | |||
| } | |||
| HostTensorND host_y, host_y_opt; | |||
| auto func = graph->compile({make_callback_copy(y_no_tc, host_y), | |||
| make_callback_copy(y_opt, host_y_opt)}); | |||
| func->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); | |||
| } | |||
TEST(TestGoptInference, EnableTensorCore) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");