GitOrigin-RevId: 814b8a83f8
tags/v1.11.1
| @@ -38,7 +38,7 @@ void RegionRestrictedConvolutionForward::deduce_dtype( | |||
| "only float type is supported for region_restricted_conv forward"); | |||
| megdnn_assert( | |||
| rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()), | |||
| "the dtype of rin/rout should be Int32, got %s.", rin.name()); | |||
| "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name()); | |||
| } | |||
| void RegionRestrictedConvolutionForward::deduce_layout( | |||
| @@ -91,12 +91,12 @@ RegionRestrictedConvolutionBackwardData::check_exec( | |||
| auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd); | |||
| #define err_msg(lhs, rhs) \ | |||
| megdnn_assert(lhs == rhs, "shape mismatch, #lhs:%zu, #rhs:%zu", lhs, rhs); | |||
| err_msg(rin.shape[0], grad_fwd.shape[0]); | |||
| err_msg(rin.shape[1], grad_fwd.shape[2]); | |||
| err_msg(rin.shape[2], grad_fwd.shape[3]); | |||
| err_msg(rout.shape[0], diff_fwd.shape[0]); | |||
| err_msg(rout.shape[1], diff_fwd.shape[2]); | |||
| err_msg(rout.shape[2], diff_fwd.shape[3]); | |||
| err_msg(rin.shape[0], grad_fwd.shape[0]); // batch | |||
| err_msg(rin.shape[1], grad_fwd.shape[2]); // ih | |||
| err_msg(rin.shape[2], grad_fwd.shape[3]); // iw | |||
| err_msg(rout.shape[0], diff_fwd.shape[0]); // batch | |||
| err_msg(rout.shape[1], diff_fwd.shape[2]); // oh | |||
| err_msg(rout.shape[2], diff_fwd.shape[3]); // ow | |||
| #undef err_msg | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(filter, diff, rin, rout, grad); | |||
| @@ -106,45 +106,22 @@ RegionRestrictedConvolutionBackwardData::check_exec( | |||
| void RegionRestrictedConvolutionBackwardData::deduce_dtype( | |||
| DType filter, DType diff, DType rin, DType rout, DType& grad) { | |||
| SmallVector<DType> supported_dst_dtype; | |||
| if (filter.category() == diff.category() && | |||
| filter.category() == DTypeCategory::FLOAT) { | |||
| supported_dst_dtype.push_back(filter); | |||
| } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) { | |||
| supported_dst_dtype.push_back(dtype::Int32()); | |||
| } else if ( | |||
| (filter.enumv() == DTypeEnum::QuantizedS8 && | |||
| diff.enumv() == DTypeEnum::QuantizedS8) || | |||
| (filter.enumv() == DTypeEnum::Quantized8Asymm && | |||
| diff.enumv() == DTypeEnum::Quantized8Asymm)) { | |||
| supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff))); | |||
| if (grad.valid() && grad.enumv() == diff.enumv()) { | |||
| supported_dst_dtype.push_back(grad); | |||
| } | |||
| } else { | |||
| megdnn_throw(ssprintf( | |||
| "unsupported input / diff DType: %s x %s", filter.name(), diff.name())); | |||
| } | |||
| if (!grad.valid()) { | |||
| grad = supported_dst_dtype.at(0); | |||
| } else { | |||
| megdnn_assert( | |||
| vec_contains(supported_dst_dtype, grad), | |||
| "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(), | |||
| grad.name()); | |||
| } | |||
| megdnn_assert( | |||
| param().compute_mode != Param::ComputeMode::FLOAT32 | |||
| // FIXME: infering dtype of grad via naive impl only support fp32 | |||
| // (lack of quantized dtype infering or others) may not suitable in the furture | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| || filter.enumv() == DTypeEnum::Float16 || | |||
| filter.enumv() == DTypeEnum::BFloat16 | |||
| if (diff.enumv() == DTypeEnum::Float32 || diff.enumv() == DTypeEnum::Float16) { | |||
| grad = diff; | |||
| } | |||
| #endif | |||
| , | |||
| "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " | |||
| "input / output."); | |||
| megdnn_assert(grad.valid(), "dtype of grad requires deducing of assigned"); | |||
| megdnn_assert( | |||
| rin == rout && rin == dtype::Int32(), | |||
| "the dtype of rin/rout should be Int32, got %s.", rin.name()); | |||
| diff.category() == DTypeCategory::FLOAT && | |||
| filter.category() == DTypeCategory::FLOAT && | |||
| grad.category() == DTypeCategory::FLOAT, | |||
| "only float type is supported for region_restricted_conv backward data"); | |||
| megdnn_assert( | |||
| rin == rout && (rin == dtype::Int32() || rin == dtype::Uint8()), | |||
| "the dtype of rin/rout should be Int32 or Uint8, got %s.", rin.name()); | |||
| } | |||
| void RegionRestrictedConvolutionBackwardData::deduce_layout( | |||
| @@ -1,7 +1,7 @@ | |||
| #include "./kern.cuh" | |||
| #include "cuda.h" | |||
| #include "cuda_fp16.h" | |||
| #include "src/cuda/fp16_help.cuh" | |||
| #include "src/cuda/region_restricted_convolution/chanwise/kern.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| @@ -15,7 +15,7 @@ namespace cuda { | |||
| namespace region_restricted_convolution { | |||
| namespace chanwise { | |||
| // =====================================fwd===================================== | |||
| // =====================================bwd===================================== | |||
| template <> | |||
| void run_bwd_depthwise_large_filter( | |||
| @@ -498,16 +498,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||
| SrcGlobal2ShareVisitor gl2sh_src = { | |||
| smem_src, | |||
| static_cast<int>(param.src_w), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_h | |||
| : src_start_h - | |||
| (param.out_h / 2 + param.flt_h / 2 - param.pad_h - | |||
| param.src_h * param.stride_h / 2)), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_w | |||
| : src_start_w - | |||
| (param.out_w / 2 + param.flt_w / 2 - param.pad_w - | |||
| param.src_w * param.stride_w / 2)), | |||
| static_cast<int>(src_start_h), | |||
| static_cast<int>(src_start_w), | |||
| static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | |||
| static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | |||
| is_fwd ? 1 : static_cast<int>(param.stride_h), | |||
| @@ -516,16 +508,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||
| RinGlobal2ShareVisitor gl2sh_rin = { | |||
| smem_rin, | |||
| static_cast<int>(param.src_w), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_h | |||
| : src_start_h - | |||
| (param.out_h / 2 + param.flt_h / 2 - param.pad_h - | |||
| param.src_h * param.stride_h / 2)), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_w | |||
| : src_start_w - | |||
| (param.out_w / 2 + param.flt_w / 2 - param.pad_w - | |||
| param.src_w * param.stride_w / 2)), | |||
| static_cast<int>(src_start_h), | |||
| static_cast<int>(src_start_w), | |||
| static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | |||
| static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | |||
| is_fwd ? 1 : static_cast<int>(param.stride_h), | |||
| @@ -790,14 +774,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||
| out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; | |||
| T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; | |||
| static_assert((FilterTileConfig::unroll_w & 3) == 0); | |||
| static_assert( | |||
| (FilterTileConfig::unroll_w & 3) == 0, "filter tile unroll_w & 3 != 0"); | |||
| int* smem_rin_ptr = smem_rin + (off_ow * FilterTileConfig::unroll_w >> 2); | |||
| T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w; | |||
| T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; | |||
| const uint8_t* rout_base_ptr = rout + batch * param.out_h * param.out_w; | |||
| static_assert((OutTileConfig::unroll_w & 3) == 0); | |||
| static_assert((OutTileConfig::block_w & 3) == 0); | |||
| static_assert((OutTileConfig::unroll_w & 3) == 0, "output tile unroll_w & 3 != 0"); | |||
| static_assert((OutTileConfig::block_w & 3) == 0, "output block_w & 3 != 0"); | |||
| int reg_rout[OutTileConfig::unroll_size] = {0}; | |||
| #pragma unroll | |||
| for (int i = 0; i < OutTileConfig::unroll_h; ++i) { | |||
| @@ -821,16 +806,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||
| SrcGlobal2ShareVisitor gl2sh_src = { | |||
| smem_src, | |||
| static_cast<int>(param.src_w), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_h | |||
| : src_start_h - | |||
| (param.out_h / 2 + param.flt_h / 2 - param.pad_h - | |||
| param.src_h * param.stride_h / 2)), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_w | |||
| : src_start_w - | |||
| (param.out_w / 2 + param.flt_w / 2 - param.pad_w - | |||
| param.src_w * param.stride_w / 2)), | |||
| static_cast<int>(src_start_h), | |||
| static_cast<int>(src_start_w), | |||
| static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | |||
| static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | |||
| is_fwd ? 1 : static_cast<int>(param.stride_h), | |||
| @@ -839,16 +816,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( | |||
| RinGlobal2ShareVisitor gl2sh_rin = { | |||
| smem_rin, | |||
| static_cast<int>(param.src_w), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_h | |||
| : src_start_h - | |||
| (param.out_h / 2 + param.flt_h / 2 - param.pad_h - | |||
| param.src_h * param.stride_h / 2)), | |||
| static_cast<int>( | |||
| is_fwd ? src_start_w | |||
| : src_start_w - | |||
| (param.out_w / 2 + param.flt_w / 2 - param.pad_w - | |||
| param.src_w * param.stride_w / 2)), | |||
| static_cast<int>(src_start_h), | |||
| static_cast<int>(src_start_w), | |||
| static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), | |||
| static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), | |||
| is_fwd ? 1 : static_cast<int>(param.stride_h), | |||
| @@ -1134,14 +1103,20 @@ void LaunchDepthwiseConv2dGPU( | |||
| RinTileCount::smem_size * sizeof(int); | |||
| void (*kernel)(const Param, const T*, const T*, const RT*, const RT*, T*); | |||
| const bool is_fwd = (kDirection == DIRECTION_FORWARD); | |||
| if (param.is_compute_deafult) { | |||
| kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>; | |||
| } else { | |||
| megdnn_assert_internal(0); | |||
| } | |||
| kernel<<<grid, block, shared_storage, stream>>>( | |||
| param, input, filter, rin, rout, output); | |||
| if (is_fwd) { | |||
| kernel<<<grid, block, shared_storage, stream>>>( | |||
| param, input, filter, rin, rout, output); | |||
| } else { | |||
| kernel<<<grid, block, shared_storage, stream>>>( | |||
| param, input, filter, rout, rin, output); | |||
| } | |||
| after_kernel_launch(); | |||
| } | |||
| @@ -55,25 +55,65 @@ size_t RegionRestrictedConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||
| void RegionRestrictedConvolutionBackwardDataImpl::exec( | |||
| _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_in rin, | |||
| _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
| megdnn_throw(ssprintf( | |||
| "unsupported RegionRestrictedConvolutionBackwardData(%s, %s, %s, %s) -> %s", | |||
| filter.layout.dtype.name(), diff.layout.dtype.name(), | |||
| rin.layout.dtype.name(), rout.layout.dtype.name(), | |||
| grad.layout.dtype.name())); | |||
| auto fm = check_exec( | |||
| filter.layout, diff.layout, rin.layout, rout.layout, grad.layout, | |||
| workspace.size); | |||
| // XXX: a naive impl to set deconv padding to param, needs optimization in future. | |||
| [&]() -> void { | |||
| size_t stride = fm.stride[0]; | |||
| size_t src_size = grad.layout.shape[2]; | |||
| size_t fwd_pad = fm.padding[0]; | |||
| size_t filter_size = fm.spatial[0]; | |||
| size_t deconv_pad = (stride * src_size - stride + stride * filter_size - | |||
| src_size - 2 * fwd_pad + filter_size - 1) / | |||
| (2 * stride); | |||
| fm.padding[0] = fm.padding[1] = deconv_pad; | |||
| return; | |||
| }(); | |||
| auto kparam = chanwise::Param::load( | |||
| diff.layout, grad.layout, fm, | |||
| param().compute_mode == Param::ComputeMode::DEFAULT); | |||
| megdnn_assert( | |||
| fm.group > 1 && diff.layout.dtype.category() == DTypeCategory::FLOAT && | |||
| param().compute_mode == Param::ComputeMode::DEFAULT && | |||
| fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 && | |||
| fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip && | |||
| param().stride_h == 1 && param().stride_w == 1); | |||
| // NOTE: uint8 dtype region mask requires the spatial size of src&dst is 4*N | |||
| if (rin.layout.dtype == dtype::Uint8()) { | |||
| megdnn_assert( | |||
| (grad.layout.shape[3] & 3) == 0 && (diff.layout.shape[3] & 3) == 0); | |||
| } | |||
| auto stream = cuda_stream(handle()); | |||
| if (filter.layout.dtype == dtype::Float32() && rin.layout.dtype == dtype::Int32() && | |||
| rout.layout.dtype == dtype::Int32()) { | |||
| chanwise::run_bwd_depthwise_large_filter( | |||
| grad.ptr<dt_float32>(), diff.ptr<dt_float32>(), | |||
| filter.ptr<dt_float32>(), rin.ptr<dt_int32>(), rout.ptr<dt_int32>(), | |||
| kparam, stream); | |||
| } else if ( | |||
| filter.layout.dtype == dtype::Float32() && | |||
| rin.layout.dtype == dtype::Uint8() && rout.layout.dtype == dtype::Uint8()) { | |||
| chanwise::run_bwd_depthwise_large_filter( | |||
| grad.ptr<dt_float32>(), diff.ptr<dt_float32>(), | |||
| filter.ptr<dt_float32>(), rin.ptr<dt_uint8>(), rout.ptr<dt_uint8>(), | |||
| kparam, stream); | |||
| } else { | |||
| megdnn_throw("undefined or unimplemented region restricted conv mode"); | |||
| } | |||
| } | |||
| size_t RegionRestrictedConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& diff, const TensorLayout&, | |||
| const TensorLayout&, const TensorLayout& grad) { | |||
| size_t workspace_size = 0; | |||
| return workspace_size; | |||
| return 0; | |||
| } | |||
| /* ============== RegionRestrictedConvolutionBackwardFilterImpl ============== */ | |||
| void RegionRestrictedConvolutionBackwardFilterImpl::exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in rin, | |||
| _megdnn_tensor_in rout, _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
| megdnn_assert_internal(0); | |||
| megdnn_throw("Region Restricted Conv BackwardFilter unimplemented"); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -117,7 +117,7 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Int32()) | |||
| .set_dtype(3, dtype::Int32()); | |||
| rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng).set_rng(0, &r_rng); | |||
| rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); | |||
| rr_bencher.set_times(nr_times); | |||
| size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); | |||
| @@ -169,6 +169,202 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_FP32) { | |||
| run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); | |||
| } | |||
| TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32) { | |||
| require_compute_capability(7, 5); | |||
| Benchmarker<ConvolutionBackwardData> bencher(handle_cuda()); | |||
| bencher.set_display(false); | |||
| bencher.set_before_exec_callback( | |||
| AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER")); | |||
| Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda()); | |||
| rr_bencher.set_display(false); | |||
| ConvolutionBackwardData::Param param; | |||
| param.format = ConvolutionBackwardData::Param::Format::NCHW; | |||
| param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP; | |||
| RegionRestrictedConvolutionBackwardData::Param rr_param; | |||
| rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW; | |||
| rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; | |||
| UniformIntRNG r_rng{1, 3}; | |||
| auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, | |||
| size_t fw, size_t sh, size_t sw, size_t nr_times) { | |||
| param.pad_h = fh / 2; | |||
| param.pad_w = fw / 2; | |||
| param.stride_h = sh; | |||
| param.stride_w = sw; | |||
| rr_param.pad_h = fh / 2; | |||
| rr_param.pad_w = fw / 2; | |||
| rr_param.stride_h = sh; | |||
| rr_param.stride_w = sw; | |||
| bencher.set_param(param) | |||
| .set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Float32()) | |||
| .set_dtype(4, dtype::Float32()); | |||
| bencher.set_times(nr_times); | |||
| rr_bencher.set_param(rr_param) | |||
| .set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Int32()) | |||
| .set_dtype(3, dtype::Int32()); | |||
| rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); | |||
| rr_bencher.set_times(nr_times); | |||
| size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); | |||
| size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w); | |||
| TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/, | |||
| rin{batch, hi, wi}, rout{batch, ho, wo}, | |||
| out{batch, g, ho, wo} /*output*/; | |||
| float bandwith = static_cast<float>( | |||
| inp.total_nr_elems() + kern.total_nr_elems() + | |||
| out.total_nr_elems()) / | |||
| (1024 * 1024 * 1024) * 1e3; | |||
| float rr_bandwith = static_cast<float>( | |||
| inp.total_nr_elems() + kern.total_nr_elems() + | |||
| rin.total_nr_elems() + rout.total_nr_elems() + | |||
| out.total_nr_elems()) / | |||
| (1024 * 1024 * 1024) * 1e3; | |||
| auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times; | |||
| auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12; | |||
| auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times; | |||
| auto rr_ops = | |||
| 2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12; | |||
| printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: " | |||
| "grad=%s, " | |||
| "kern=%s, diff=%s\n" | |||
| "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n" | |||
| "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n", | |||
| inp.to_string().c_str(), kern.to_string().c_str(), | |||
| out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops, | |||
| bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms, | |||
| time_in_ms / rr_time_in_ms); | |||
| }; | |||
| run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); | |||
| } | |||
| TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_BACKWARD_LARGE_FILTER_FP32_UINT8) { | |||
| require_compute_capability(7, 5); | |||
| Benchmarker<ConvolutionBackwardData> bencher(handle_cuda()); | |||
| bencher.set_display(false); | |||
| bencher.set_before_exec_callback( | |||
| AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER")); | |||
| Benchmarker<RegionRestrictedConvolutionBackwardData> rr_bencher(handle_cuda()); | |||
| rr_bencher.set_display(false); | |||
| ConvolutionBackwardData::Param param; | |||
| param.format = ConvolutionBackwardData::Param::Format::NCHW; | |||
| param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP; | |||
| RegionRestrictedConvolutionBackwardData::Param rr_param; | |||
| rr_param.format = RegionRestrictedConvolutionBackwardData::Param::Format::NCHW; | |||
| rr_param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; | |||
| UniformIntRNG r_rng{1, 3}; | |||
| auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, | |||
| size_t fw, size_t sh, size_t sw, size_t nr_times) { | |||
| param.pad_h = fh / 2; | |||
| param.pad_w = fw / 2; | |||
| param.stride_h = sh; | |||
| param.stride_w = sw; | |||
| rr_param.pad_h = fh / 2; | |||
| rr_param.pad_w = fw / 2; | |||
| rr_param.stride_h = sh; | |||
| rr_param.stride_w = sw; | |||
| bencher.set_param(param) | |||
| .set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Float32()) | |||
| .set_dtype(4, dtype::Float32()); | |||
| bencher.set_times(nr_times); | |||
| rr_bencher.set_param(rr_param) | |||
| .set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dtype::Uint8()) | |||
| .set_dtype(3, dtype::Uint8()); | |||
| rr_bencher.set_rng(2, &r_rng).set_rng(3, &r_rng); | |||
| rr_bencher.set_times(nr_times); | |||
| size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); | |||
| size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w); | |||
| TensorShape inp{batch, g, hi, wi} /*src*/, kern{g, 1, 1, fh, fw} /*filter*/, | |||
| rin{batch, hi, wi}, rout{batch, ho, wo}, | |||
| out{batch, g, ho, wo} /*output*/; | |||
| float bandwith = static_cast<float>( | |||
| inp.total_nr_elems() + kern.total_nr_elems() + | |||
| out.total_nr_elems()) / | |||
| (1024 * 1024 * 1024) * 1e3; | |||
| float rr_bandwith = static_cast<float>( | |||
| inp.total_nr_elems() + kern.total_nr_elems() + | |||
| rin.total_nr_elems() + rout.total_nr_elems() + | |||
| out.total_nr_elems()) / | |||
| (1024 * 1024 * 1024) * 1e3; | |||
| auto time_in_ms = bencher.execs({kern, out, inp}) / nr_times; | |||
| auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12; | |||
| auto rr_time_in_ms = rr_bencher.execs({kern, out, rin, rout, inp}) / nr_times; | |||
| auto rr_ops = | |||
| 2.0 * batch * g * ho * wo * fh * fw / (rr_time_in_ms * 1e-3) * 1e-12; | |||
| printf("[DGRAD]RegionRestrictedDepthwiseLargeFilter vs DepthwiseLargeFilter: " | |||
| "grad=%s, " | |||
| "kern=%s, diff=%s\n" | |||
| "time: %.2f ms, time(rr): %.2f ms, perf: %.2fTops, perf(rr): %.2f Tops\n" | |||
| "bandwidth: %.2fGB/s, bandwidth(rr): %.2fGB/s, speedup: %.2f.\n", | |||
| inp.to_string().c_str(), kern.to_string().c_str(), | |||
| out.to_string().c_str(), time_in_ms, rr_time_in_ms, ops, rr_ops, | |||
| bandwith * 4 / time_in_ms, rr_bandwith * 4 / rr_time_in_ms, | |||
| time_in_ms / rr_time_in_ms); | |||
| }; | |||
| run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10); | |||
| run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); | |||
| } | |||
| TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { | |||
| require_compute_capability(7, 5); | |||
| Benchmarker<ConvBiasForward> bencher(handle_cuda()); | |||
| @@ -271,6 +467,124 @@ TEST_F(CUDA, BENCHMARK_REGION_RESTRICTED_CONV_FORWARD_LARGE_FILTER_UINT8) { | |||
| #endif | |||
| TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32) { | |||
| Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda()); | |||
| for (auto dt : std::vector<DType>{dtype::Int32(), dtype::Uint8()}) { | |||
| auto run = [&checker, &dt]( | |||
| size_t n, size_t g, size_t ih, size_t fh, size_t padding, | |||
| size_t stride) { | |||
| RegionRestrictedConvolutionBackwardData::Param cur_param; | |||
| cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode:: | |||
| CROSS_CORRELATION; | |||
| cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param:: | |||
| ComputeMode::DEFAULT; | |||
| cur_param.sparse = | |||
| RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; | |||
| checker.set_dtype(0, dtype::Float32()) | |||
| .set_dtype(1, dtype::Float32()) | |||
| .set_dtype(2, dt) | |||
| .set_dtype(3, dt); | |||
| float scale = 64.f / sqrt(fh * fh); | |||
| UniformFloatRNG rng(scale, 2 * scale); | |||
| UniformIntRNG r_rng{1, 2}; | |||
| checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng( | |||
| 3, &r_rng); | |||
| cur_param.pad_h = cur_param.pad_w = padding; | |||
| cur_param.stride_h = cur_param.stride_w = stride; | |||
| size_t oh = (ih + 2 * padding - fh + 1) / stride; | |||
| checker.set_param(cur_param).execs({ | |||
| {g, 1, 1, fh, fh}, // filter | |||
| {n, g * 1, oh, oh}, // diff | |||
| {n, ih, ih}, // rin | |||
| {n, oh, oh}, // rout | |||
| {n, g * 1, ih, ih} // grad | |||
| }); | |||
| }; | |||
| if (dt == dtype::Int32()) { | |||
| run(4, 8, 32, 5, 5 / 2, 1); | |||
| run(1, 2, 2, 2, 0, 1); | |||
| run(1, 2, 3, 3, 0, 1); | |||
| run(1, 2, 4, 4, 0, 1); | |||
| run(1, 2, 5, 5, 0, 1); | |||
| run(1, 2, 6, 6, 0, 1); | |||
| run(1, 2, 7, 7, 0, 1); | |||
| } | |||
| run(4, 8, 32, 7, 7 / 2, 1); | |||
| run(4, 8, 32, 9, 9 / 2, 1); | |||
| run(4, 8, 32, 11, 11 / 2, 1); | |||
| run(4, 8, 32, 13, 13 / 2, 1); | |||
| run(4, 8, 32, 15, 15 / 2, 1); | |||
| run(4, 8, 32, 17, 17 / 2, 1); | |||
| run(4, 8, 32, 19, 19 / 2, 1); | |||
| run(4, 8, 32, 21, 21 / 2, 1); | |||
| run(4, 8, 32, 23, 23 / 2, 1); | |||
| run(4, 8, 32, 25, 25 / 2, 1); | |||
| run(4, 8, 32, 27, 27 / 2, 1); | |||
| run(4, 8, 32, 29, 29 / 2, 1); | |||
| run(4, 8, 32, 31, 31 / 2, 1); | |||
| } | |||
| } | |||
| TEST_F(CUDA, REGION_RESTRICTED_CONV_BWD_DATA_FP32_RIN_EQ_ROUT) { | |||
| Checker<RegionRestrictedConvolutionBackwardData> checker(handle_cuda()); | |||
| for (auto dt : std::vector<DType>{dtype::Int32()}) { | |||
| auto run = [&checker, &dt]( | |||
| size_t n, size_t g, size_t ih, size_t fh, size_t padding, | |||
| size_t stride) { | |||
| RegionRestrictedConvolutionBackwardData::Param cur_param; | |||
| cur_param.mode = RegionRestrictedConvolutionBackwardData::Param::Mode:: | |||
| CROSS_CORRELATION; | |||
| cur_param.compute_mode = RegionRestrictedConvolutionBackwardData::Param:: | |||
| ComputeMode::DEFAULT; | |||
| cur_param.sparse = | |||
| RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; | |||
| checker.set_dtype(2, dt).set_dtype(3, dt); | |||
| float scale = 64.f / sqrt(fh * fh); | |||
| UniformFloatRNG rng(scale, 2 * scale); | |||
| // value 0 mask may cause unexpected behaviour. | |||
| UniformIntRNG r_rng{1, 1}; | |||
| checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &r_rng).set_rng( | |||
| 3, &r_rng); | |||
| cur_param.pad_h = cur_param.pad_w = padding; | |||
| cur_param.stride_h = cur_param.stride_w = stride; | |||
| size_t oh = (ih + 2 * padding - fh + 1) / stride; | |||
| checker.set_param(cur_param).execs( | |||
| {/*filter*/ {g, 1, 1, fh, fh}, | |||
| /*diff*/ {n, g * 1, oh, oh}, | |||
| /*rin*/ {n, ih, ih}, | |||
| /*rout*/ {n, oh, oh}, | |||
| /*grad*/ {n, g * 1, ih, ih}}); | |||
| }; | |||
| if (dt == dtype::Int32()) { | |||
| // NOTE: UINT8 assert the spatial size of src&dst is 4*N | |||
| run(4, 8, 32, 5, 5 / 2, 1); | |||
| run(1, 2, 2, 2, 0, 1); | |||
| run(1, 2, 3, 3, 0, 1); | |||
| run(1, 2, 4, 4, 0, 1); | |||
| run(1, 2, 5, 5, 0, 1); | |||
| run(1, 2, 6, 6, 0, 1); | |||
| run(1, 2, 7, 7, 0, 1); | |||
| } | |||
| run(4, 8, 32, 7, 7 / 2, 1); | |||
| run(4, 8, 32, 9, 9 / 2, 1); | |||
| run(4, 8, 32, 11, 11 / 2, 1); | |||
| run(4, 8, 32, 13, 13 / 2, 1); | |||
| run(4, 8, 32, 15, 15 / 2, 1); | |||
| run(4, 8, 32, 17, 17 / 2, 1); | |||
| run(4, 8, 32, 19, 19 / 2, 1); | |||
| run(4, 8, 32, 21, 21 / 2, 1); | |||
| run(4, 8, 32, 23, 23 / 2, 1); | |||
| run(4, 8, 32, 25, 25 / 2, 1); | |||
| run(4, 8, 32, 27, 27 / 2, 1); | |||
| run(4, 8, 32, 29, 29 / 2, 1); | |||
| run(4, 8, 32, 31, 31 / 2, 1); | |||
| } | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -131,4 +131,110 @@ TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD) { | |||
| {}}); | |||
| } | |||
| TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_FORWARD_DENSE_BRUTE) { | |||
| Checker<RegionRestrictedConvolutionForward> checker(handle()); | |||
| RegionRestrictedConvolutionForward::Param param; | |||
| checker.set_param(param).exect( | |||
| Testcase{ | |||
| TensorValue( // src | |||
| {1, 1, 4, 4}, dtype::Float32(), | |||
| {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), | |||
| TensorValue( // filter | |||
| {1, 1, 2, 2}, dtype::Float32(), {1, 1, 1, 1}), | |||
| TensorValue( // rin | |||
| {1, 4, 4}, dtype::Int32(), | |||
| {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}), | |||
| TensorValue( // rout | |||
| {1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}), | |||
| {}, // output | |||
| }, | |||
| Testcase{ | |||
| {}, | |||
| {}, | |||
| {}, | |||
| {}, | |||
| TensorValue( | |||
| {1, 1, 3, 3}, dtype::Float32(), | |||
| {4, 14, 18, 5, 9, 0, 13, 9, 50})}); | |||
| } | |||
| TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_DENSE_BRUTE) { | |||
| Checker<RegionRestrictedConvolutionBackwardData> checker(handle()); | |||
| RegionRestrictedConvolutionBackwardData::Param param; | |||
| checker.set_param(param).exect( | |||
| Testcase{ | |||
| // filter | |||
| TensorValue( | |||
| {1, 1, 2, 2}, // shape | |||
| dtype::Float32(), // dtype | |||
| {1.f, 1.f, 1.f, 1.f}), | |||
| // diff | |||
| TensorValue( | |||
| {1, 1, 3, 3}, dtype::Float32(), | |||
| {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}), | |||
| // rin | |||
| TensorValue( | |||
| {1, 4, 4}, dtype::Int32(), | |||
| {1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1}), | |||
| // rout | |||
| TensorValue({1, 3, 3}, dtype::Int32(), {0, 1, 1, 1, 0, 0, 1, 0, 1}), | |||
| // grad | |||
| {}}, | |||
| Testcase{// filter | |||
| {}, | |||
| // diff | |||
| {}, | |||
| // rin | |||
| {}, | |||
| // rout | |||
| {}, | |||
| // grad | |||
| TensorValue( | |||
| {1, 1, 4, 4}, dtype::Float32(), | |||
| {0., 2., 5., 3., 1., 6., 5., 3., 0., 13., 9., 9., 0., 7., | |||
| 9., 9.})}); | |||
| } | |||
| TEST_F(NAIVE, REGIONRESTRICTEDCONVOLUTION_BWD_DATA_GROUP_BRUTE) { | |||
| Checker<RegionRestrictedConvolutionBackwardData> checker(handle()); | |||
| // params | |||
| RegionRestrictedConvolutionBackwardData::Param param; | |||
| param.sparse = RegionRestrictedConvolutionBackwardData::Param::Sparse::GROUP; | |||
| param.mode = RegionRestrictedConvolutionBackwardData::Mode::CROSS_CORRELATION; | |||
| param.compute_mode = | |||
| RegionRestrictedConvolutionBackwardData::Param::ComputeMode::DEFAULT; | |||
| param.pad_h = param.pad_w = | |||
| 0; // forward param, naive backward data doesn't matter with deconv padding | |||
| param.stride_h = param.stride_w = 1; | |||
| // checker setting | |||
| checker.set_param(param).exect( | |||
| Testcase{// filter | |||
| TensorValue( | |||
| {2, 1, 1, 2, 2}, // shape | |||
| dtype::Float32(), // dtype | |||
| {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}), | |||
| // diff | |||
| TensorValue({1, 2, 1, 1}, dtype::Float32(), {1, 2}), | |||
| // rin | |||
| TensorValue({1, 2, 2}, dtype::Int32(), {1, 1, 1, 1}), | |||
| // rout | |||
| TensorValue({1, 1, 1}, dtype::Int32(), {1}), | |||
| // grad | |||
| {}}, | |||
| Testcase{// filter | |||
| {}, | |||
| // diff | |||
| {}, | |||
| // rin | |||
| {}, | |||
| // rout | |||
| {}, | |||
| // grad | |||
| TensorValue( | |||
| {1, 2, 2, 2}, dtype::Float32(), | |||
| {1, 2, 3, 4, 10, 12, 14, 16})}); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||