GitOrigin-RevId: 85592bca6b
@@ -19,10 +19,12 @@ using namespace cuda;
 ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
     non_cudnn_algos.push_back(&chanwise);
     non_cudnn_algos.push_back(&chanwise_small);
+    non_cudnn_algos.push_back(&depthwise_large_filter);
     non_cudnn_algos.push_back(&matmul);
     all_algos.push_back(&chanwise);        // prefer chanwise
     all_algos.push_back(&chanwise_small);  // prefer small chanwise
+    all_algos.push_back(&depthwise_large_filter);
     fill_cudnn_algos();
     for (auto&& i : cudnn) {
@@ -37,6 +37,7 @@ public:
         CUDA_MATMUL,
         CUDA_CHANWISE,
         CUDA_CHANWISE_SMALL,
+        CUDA_DEPTHWISE_LARGE_FILTER,
         CUDA_BFLOAT16,
         CUDA_GROUP_CONV_GENERAL,
         CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
@@ -192,6 +193,20 @@ public:
     }
 };
+class ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
+public:
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+    const char* name() const override { return "DEPTHWISE_LARGE_FILTER"; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+private:
+    mutable std::string m_name;
+};
 class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
 public:
     bool is_available(const SizeArgs& args) const override;
@@ -411,6 +426,7 @@ public:
     AlgoMatmul matmul;
     AlgoChanwise chanwise;
     AlgoChanwiseSmall chanwise_small;
+    AlgoDepthwiseLargeFilter depthwise_large_filter;
     AlgoBFloat16 bfloat16;
     AlgoGroupConvGeneral group;
     std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
@@ -0,0 +1,86 @@
+/**
+ * \file dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
+ */
+#include "src/cuda/convolution/backward_data/algo.h"
+#include "src/cuda/convolution/chanwise/kern.cuh"
+#include "src/cuda/utils.h"
+using namespace megdnn;
+using namespace cuda;
+using namespace convolution;
+namespace {
+inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
+    auto&& device_prop = cuda::current_device_prop();
+    int flt_smem_w = (param.flt_w + 3) / 4 * 4;
+    int flt_smem_h = 3;
+    int flt_reg_per_thread =
+            flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
+    int ow = param.out_w > 64 ? 64 : param.out_w;
+    int src_smem_w = ow + flt_smem_w - 1;
+    int src_smem_h = flt_smem_h + param.flt_h - 1;
+    int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
+                                              : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
+    int out_reg_per_thread = (ow + 3) / 4 * 4;
+    if (device_prop.regsPerBlock <
+                4 * 32 * (flt_reg_per_thread + src_reg_per_thread + out_reg_per_thread) ||
+        device_prop.sharedMemPerBlock <
+                static_cast<size_t>(
+                        flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
+        return false;
+    }
+    return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
+           param.src_w == param.out_w;
+}
+}  // anonymous namespace
+bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
+        const SizeArgs& args) const {
+    if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) {
+        return false;
+    }
+    if (args.diff_layout->dtype != args.filter_layout->dtype &&
+        args.diff_layout->dtype != dtype::Float32()) {
+        return false;
+    }
+    auto param = chanwise::Param::from_fwd_args(args.as_fwd_args());
+    auto&& fm = args.filter_meta;
+    return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
+           args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
+           args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
+           fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
+           fm.dilation[1] == 1 && !fm.should_flip &&
+           is_available_depthwise_large_filter(param);
+}
+size_t ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return 0;
+}
+void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
+        const ExecArgs& args) const {
+    auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
+    auto stream = cuda_stream(args.handle);
+    switch (args.diff_layout->dtype.enumv()) {
+        case DTypeEnum::Float32:
+            chanwise::run_bwd_depthwise_large_filter(
+                    args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(),
+                    args.filter_tensor->ptr<float>(), kparam, stream);
+            break;
+        default:
+            megdnn_assert_internal(0);
+    }
+}
+// vim: syntax=cpp.doxygen
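
The resource check in is_available_depthwise_large_filter above can be exercised in isolation. The following standalone C++ sketch is not part of this diff; it mirrors the register and shared-memory arithmetic, while the 64K-register / 48 KB per-block limits and the helper name fits_on_device are illustrative assumptions rather than values taken from the commit.

    #include <cstdio>

    // Mirrors the per-block register / shared-memory estimate used by
    // is_available_depthwise_large_filter; the limits below are assumed, not queried.
    bool fits_on_device(int flt_h, int flt_w, int out_w, int regs_per_block = 65536,
                        long smem_per_block = 48 * 1024) {
        int flt_smem_w = (flt_w + 3) / 4 * 4;
        int flt_smem_h = 3;
        int flt_reg_per_thread =
                flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
        int ow = out_w > 64 ? 64 : out_w;
        int src_smem_w = ow + flt_smem_w - 1;
        int src_smem_h = flt_smem_h + flt_h - 1;
        int src_reg_per_thread = src_smem_w > 128
                                         ? (flt_smem_w + 127) / 128
                                         : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
        int out_reg_per_thread = (ow + 3) / 4 * 4;
        int regs = 4 * 32 * (flt_reg_per_thread + src_reg_per_thread + out_reg_per_thread);
        long smem = flt_smem_w * flt_smem_h + src_smem_w * src_smem_h;
        printf("flt=%dx%d ow=%d -> regs=%d, smem entries=%ld\n", flt_h, flt_w, ow, regs, smem);
        return regs <= regs_per_block && smem <= smem_per_block;
    }

    int main() {
        fits_on_device(31, 31, 32);  // largest filter covered by the tests below
        fits_on_device(31, 31, 64);  // 64-wide output rows also fit
        return 0;
    }

For a 31x31 filter with 32-wide output rows this evaluates to 7296 registers and 2175 shared-memory entries per block, comfortably inside the limits of any recent GPU.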
@@ -0,0 +1,45 @@
+/**
+ * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
+ */
+#include "./kern.cuh"
+#include "./kern_helper.cuh"
+#include "cuda.h"
+#include "cuda_fp16.h"
+#include "src/cuda/convolution/chanwise/launch_config.cuh"
+#include "src/cuda/fp16_help.cuh"
+using namespace megdnn;
+using namespace cuda;
+using namespace convolution;
+using namespace chanwise;
+#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"
+namespace megdnn {
+namespace cuda {
+namespace convolution {
+namespace chanwise {
+// =====================================bwd=====================================
+template <>
+void run_bwd_depthwise_large_filter(
+        float* dst, const float* src, const float* flt, const Param& param,
+        cudaStream_t stream) {
+    INSTANCE(DepthwiseConv2dDirection::DIRECTION_BACKWARD)
+}
+}  // namespace chanwise
+}  // namespace convolution
+}  // namespace cuda
+}  // namespace megdnn
+// vim: syntax=cuda.doxygen
@@ -63,6 +63,10 @@ void run_bwd_data(
         T* src_grad, const T* dst_grad, const T* flt, const Param& param,
         cudaStream_t stream);
+template <typename T>
+void run_bwd_depthwise_large_filter(
+        T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);
 template <typename T>
 void run_bwd_filter(
         T* filter_grad, const T* src, const T* dst_grad, const Param& param,
@@ -97,6 +97,7 @@ public:
     class AlgoMatmul;
     class AlgoChanwise;
     class AlgoChanwiseSmall;
+    class AlgoDepthwiseLargeFilter;
     class AlgoGroupConvGeneral;
     class AlgoBFloat16;
     class AlgoInt8NCHW4DotProdImplicitGemm;
@@ -724,6 +724,55 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_1) {
             TensorLayoutArray{filter, dst, src});
 }
+TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
+    Checker<ConvolutionBackwardData> checker(handle_cuda());
+    checker.set_before_exec_callback(
+            AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
+    for (auto dtype : std::vector<DType>{dtype::Float32()}) {
+        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+            param::Convolution param;
+            param.stride_h = param.stride_w = 1;
+            param.pad_h = param.pad_w = fh / 2;
+            param.mode = Convolution::Mode::CROSS_CORRELATION;
+            param.sparse = param::Convolution::Sparse::GROUP;
+            checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
+            checker.set_param(param).execs(
+                    {{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}});
+        };
+        run(4, 8, 32, 5);
+        run(4, 8, 32, 7);
+        run(4, 8, 32, 9);
+        run(4, 8, 32, 11);
+        run(4, 8, 32, 13);
+        run(4, 8, 32, 15);
+        run(4, 8, 32, 17);
+        run(4, 8, 32, 19);
+        run(4, 8, 32, 21);
+        run(4, 8, 32, 23);
+        run(4, 8, 32, 25);
+        run(4, 8, 32, 27);
+        run(4, 8, 32, 29);
+        run(4, 8, 32, 31);
+        run(4, 8, 64, 7);
+        run(4, 8, 64, 5);
+        run(4, 8, 64, 9);
+        run(4, 8, 64, 11);
+        run(4, 8, 64, 13);
+        run(4, 8, 64, 15);
+        run(4, 8, 64, 17);
+        run(4, 8, 64, 19);
+        run(4, 8, 64, 21);
+        run(4, 8, 64, 23);
+        run(4, 8, 64, 25);
+        run(4, 8, 64, 27);
+        run(4, 8, 64, 29);
+        run(4, 8, 64, 31);
+        run(1, 2, 128, 31);
+        run(1, 2, 256, 31);
+    }
+}
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(CUDA, CONV_FWD_BENCHMARK) {
     auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW, size_t SH = 1,
@@ -901,24 +950,23 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
     run(32, 64, 64, 56, 56, 1, 1, 0);
 }
-TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
-    CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
-    std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
-            new OprProxy<ConvolutionBackwardData>{true}};
-    size_t RUNS = 10;
-    bench.set_proxy(proxy).set_times(RUNS);
+TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
+    CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(
+            AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
     auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
-                   size_t SH, size_t PH) {
-        bench.set_dtype(0, dtype::Float32())
+                   size_t SH, size_t nr_times) {
+        bencher.set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
                 .set_dtype(2, dtype::Float32());
         param::Convolution param;
         param.stride_h = param.stride_w = SH;
         param.pad_h = param.pad_w = FH / 2;
         param.sparse = param::Convolution::Sparse::GROUP;
-        bench.set_param(param);
-        bench.proxy()->target_execution_policy.algo.reset();
+        bencher.set_param(param);
+        bencher.set_times(nr_times);
         TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
                 filter{{g, 1, 1, FH, FH}, dtype::Float32()};
         TensorLayout dst;
@@ -927,15 +975,28 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
             opr->param() = param;
             opr->deduce_layout(src, filter, dst);
         }
-        auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS;
+        auto time_ms_fp32 = bencher.execl({filter, dst, src}) / nr_times;
         float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
         printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
                filter.to_string().c_str(), dst.to_string().c_str());
         printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
                (flo / (time_ms_fp32 * 1e9)));
     };
-    run(64, 384, 384, 32, 32, 31, 1, 15);
+    run(64, 384, 384, 32, 32, 3, 1, 10);
+    run(64, 384, 384, 32, 32, 5, 1, 10);
+    run(64, 384, 384, 32, 32, 7, 1, 10);
+    run(64, 384, 384, 32, 32, 9, 1, 10);
+    run(64, 384, 384, 32, 32, 11, 1, 10);
+    run(64, 384, 384, 32, 32, 13, 1, 10);
+    run(64, 384, 384, 32, 32, 15, 1, 10);
+    run(64, 384, 384, 32, 32, 17, 1, 10);
+    run(64, 384, 384, 32, 32, 19, 1, 10);
+    run(64, 384, 384, 32, 32, 21, 1, 10);
+    run(64, 384, 384, 32, 32, 23, 1, 10);
+    run(64, 384, 384, 32, 32, 25, 1, 10);
+    run(64, 384, 384, 32, 32, 27, 1, 10);
+    run(64, 384, 384, 32, 32, 29, 1, 10);
+    run(64, 384, 384, 32, 32, 31, 1, 10);
 }
 TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
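
The DEPTHWISE_LARGE_FILTER benchmark above derives its TFLOPS figure from 2 FLOPs per filter tap per output element. A worked example of that conversion, using the largest benchmarked shape and a purely hypothetical 1 ms timing (not a measured result):

    #include <cstdio>

    int main() {
        // Shape from run(64, 384, 384, 32, 32, 31, 1, 10): N=64, g=384 groups,
        // 32x32 output (stride 1 and pad FH/2 keep the spatial size), FH=31.
        double N = 64, g = 384, OH = 32, OW = 32, FH = 31;
        double flo = 2.0 * N * g * OH * OW * FH * FH;  // ~4.84e10 FLOPs per run
        double time_ms = 1.0;                          // hypothetical measured time
        // Same conversion as the benchmark: FLOPs / (ms * 1e9) gives TFLOPS.
        printf("%.3f TFLOPS\n", flo / (time_ms * 1e9));
        return 0;
    }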
@@ -1103,7 +1164,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) {
     run(32, 64, 64, 56, 56, 1, 1, 0);
 }
-TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
+TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_DEPTHWISE_LARGE_FILTER) {
     CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()};
     std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
             new OprProxy<ConvolutionBackwardFilter>{true}};
@@ -57,6 +57,7 @@
 #cmakedefine01 MEGDNN_64_BIT
 #cmakedefine01 MEGDNN_THREADS_512
 #cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS
+#cmakedefine01 MEGDNN_WITH_BENCHMARK
 // whether atlas is available
 #ifndef MGB_ATLAS