GitOrigin-RevId: 85592bca6b
tags/v1.8.2
| @@ -19,10 +19,12 @@ using namespace cuda; | |||||
| ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | ||||
| non_cudnn_algos.push_back(&chanwise); | non_cudnn_algos.push_back(&chanwise); | ||||
| non_cudnn_algos.push_back(&chanwise_small); | non_cudnn_algos.push_back(&chanwise_small); | ||||
| non_cudnn_algos.push_back(&depthwise_large_filter); | |||||
| non_cudnn_algos.push_back(&matmul); | non_cudnn_algos.push_back(&matmul); | ||||
| all_algos.push_back(&chanwise); // prefer chanwise | all_algos.push_back(&chanwise); // prefer chanwise | ||||
| all_algos.push_back(&chanwise_small); // prefer small chanwise | all_algos.push_back(&chanwise_small); // prefer small chanwise | ||||
| all_algos.push_back(&depthwise_large_filter); | |||||
| fill_cudnn_algos(); | fill_cudnn_algos(); | ||||
| for (auto&& i : cudnn) { | for (auto&& i : cudnn) { | ||||
| @@ -37,6 +37,7 @@ public: | |||||
| CUDA_MATMUL, | CUDA_MATMUL, | ||||
| CUDA_CHANWISE, | CUDA_CHANWISE, | ||||
| CUDA_CHANWISE_SMALL, | CUDA_CHANWISE_SMALL, | ||||
| CUDA_DEPTHWISE_LARGE_FILTER, | |||||
| CUDA_BFLOAT16, | CUDA_BFLOAT16, | ||||
| CUDA_GROUP_CONV_GENERAL, | CUDA_GROUP_CONV_GENERAL, | ||||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | ||||
| @@ -192,6 +193,20 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| class ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter final : public AlgoBase { | |||||
| public: | |||||
| bool is_available(const SizeArgs& args) const override; | |||||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
| void exec(const ExecArgs& args) const override; | |||||
| const char* name() const override { return "DEPTHWISE_LARGE_FILTER"; } | |||||
| MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER) | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| private: | |||||
| mutable std::string m_name; | |||||
| }; | |||||
| class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | ||||
| public: | public: | ||||
| bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
| @@ -411,6 +426,7 @@ public: | |||||
| AlgoMatmul matmul; | AlgoMatmul matmul; | ||||
| AlgoChanwise chanwise; | AlgoChanwise chanwise; | ||||
| AlgoChanwiseSmall chanwise_small; | AlgoChanwiseSmall chanwise_small; | ||||
| AlgoDepthwiseLargeFilter depthwise_large_filter; | |||||
| AlgoBFloat16 bfloat16; | AlgoBFloat16 bfloat16; | ||||
| AlgoGroupConvGeneral group; | AlgoGroupConvGeneral group; | ||||
| std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | ||||
| @@ -0,0 +1,86 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/cuda/convolution/backward_data/algo.h" | |||||
| #include "src/cuda/convolution/chanwise/kern.cuh" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace convolution; | |||||
| namespace { | |||||
| inline bool is_available_depthwise_large_filter(const chanwise::Param& param) { | |||||
| auto&& device_prop = cuda::current_device_prop(); | |||||
| int flt_smem_w = (param.flt_w + 3) / 4 * 4; | |||||
| int flt_smem_h = 3; | |||||
| int flt_reg_per_thread = | |||||
| flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4; | |||||
| int ow = param.out_w > 64 ? 64 : param.out_w; | |||||
| int src_smem_w = ow + flt_smem_w - 1; | |||||
| int src_smem_h = flt_smem_h + param.flt_h - 1; | |||||
| int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128 | |||||
| : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1; | |||||
| int out_reg_per_thread = (ow + 3) / 4 * 4; | |||||
| if (device_prop.regsPerBlock < 4 * 32 * | |||||
| (flt_reg_per_thread + src_reg_per_thread + | |||||
| out_reg_per_thread) || | |||||
| device_prop.sharedMemPerBlock < | |||||
| static_cast<size_t>( | |||||
| flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) { | |||||
| return false; | |||||
| } | |||||
| return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h && | |||||
| param.src_w == param.out_w; | |||||
| } | |||||
| } // anonymous namespace | |||||
| bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available( | |||||
| const SizeArgs& args) const { | |||||
| if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { | |||||
| return false; | |||||
| } | |||||
| if (args.diff_layout->dtype != args.filter_layout->dtype && | |||||
| args.diff_layout->dtype != dtype::Float32()) { | |||||
| return false; | |||||
| } | |||||
| auto param = chanwise::Param::from_fwd_args(args.as_fwd_args()); | |||||
| auto&& fm = args.filter_meta; | |||||
| return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW && | |||||
| args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||||
| args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && | |||||
| fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && | |||||
| fm.dilation[1] == 1 && !fm.should_flip && | |||||
| is_available_depthwise_large_filter(param); | |||||
| } | |||||
| size_t ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes( | |||||
| const SizeArgs& args) const { | |||||
| return 0; | |||||
| } | |||||
| void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec( | |||||
| const ExecArgs& args) const { | |||||
| auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); | |||||
| auto stream = cuda_stream(args.handle); | |||||
| switch (args.diff_layout->dtype.enumv()) { | |||||
| case DTypeEnum::Float32: | |||||
| chanwise::run_bwd_depthwise_large_filter( | |||||
| args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(), | |||||
| args.filter_tensor->ptr<float>(), kparam, stream); | |||||
| break; | |||||
| default: | |||||
| megdnn_assert_internal(0); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "./kern.cuh" | |||||
| #include "./kern_helper.cuh" | |||||
| #include "cuda.h" | |||||
| #include "cuda_fp16.h" | |||||
| #include "src/cuda/convolution/chanwise/launch_config.cuh" | |||||
| #include "src/cuda/fp16_help.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace convolution; | |||||
| using namespace chanwise; | |||||
| #include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace convolution { | |||||
| namespace chanwise { | |||||
| // =====================================fwd===================================== | |||||
| template <> | |||||
| void run_bwd_depthwise_large_filter( | |||||
| float* dst, const float* src, const float* flt, const Param& param, | |||||
| cudaStream_t stream) { | |||||
| INSTANCE(DepthwiseConv2dDirection::DIRECTION_BACKWARD) | |||||
| } | |||||
| } // namespace chanwise | |||||
| } // namespace convolution | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cuda.doxygen | |||||
| @@ -63,6 +63,10 @@ void run_bwd_data( | |||||
| T* src_grad, const T* dst_grad, const T* flt, const Param& param, | T* src_grad, const T* dst_grad, const T* flt, const Param& param, | ||||
| cudaStream_t stream); | cudaStream_t stream); | ||||
| template <typename T> | |||||
| void run_bwd_depthwise_large_filter( | |||||
| T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); | |||||
| template <typename T> | template <typename T> | ||||
| void run_bwd_filter( | void run_bwd_filter( | ||||
| T* filter_grad, const T* src, const T* dst_grad, const Param& param, | T* filter_grad, const T* src, const T* dst_grad, const Param& param, | ||||
| @@ -97,6 +97,7 @@ public: | |||||
| class AlgoMatmul; | class AlgoMatmul; | ||||
| class AlgoChanwise; | class AlgoChanwise; | ||||
| class AlgoChanwiseSmall; | class AlgoChanwiseSmall; | ||||
| class AlgoDepthwiseLargeFilter; | |||||
| class AlgoGroupConvGeneral; | class AlgoGroupConvGeneral; | ||||
| class AlgoBFloat16; | class AlgoBFloat16; | ||||
| class AlgoInt8NCHW4DotProdImplicitGemm; | class AlgoInt8NCHW4DotProdImplicitGemm; | ||||
| @@ -724,6 +724,55 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_1) { | |||||
| TensorLayoutArray{filter, dst, src}); | TensorLayoutArray{filter, dst, src}); | ||||
| } | } | ||||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) { | |||||
| Checker<ConvolutionBackwardData> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER")); | |||||
| for (auto dtype : std::vector<DType>{dtype::Float32()}) { | |||||
| auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) { | |||||
| param::Convolution param; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = fh / 2; | |||||
| param.mode = Convolution::Mode::CROSS_CORRELATION; | |||||
| param.sparse = param::Convolution::Sparse::GROUP; | |||||
| checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype); | |||||
| checker.set_param(param).execs( | |||||
| {{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}}); | |||||
| }; | |||||
| run(4, 8, 32, 5); | |||||
| run(4, 8, 32, 7); | |||||
| run(4, 8, 32, 9); | |||||
| run(4, 8, 32, 11); | |||||
| run(4, 8, 32, 13); | |||||
| run(4, 8, 32, 15); | |||||
| run(4, 8, 32, 17); | |||||
| run(4, 8, 32, 19); | |||||
| run(4, 8, 32, 21); | |||||
| run(4, 8, 32, 23); | |||||
| run(4, 8, 32, 25); | |||||
| run(4, 8, 32, 27); | |||||
| run(4, 8, 32, 29); | |||||
| run(4, 8, 32, 31); | |||||
| run(4, 8, 64, 7); | |||||
| run(4, 8, 64, 5); | |||||
| run(4, 8, 64, 9); | |||||
| run(4, 8, 64, 11); | |||||
| run(4, 8, 64, 13); | |||||
| run(4, 8, 64, 15); | |||||
| run(4, 8, 64, 17); | |||||
| run(4, 8, 64, 19); | |||||
| run(4, 8, 64, 21); | |||||
| run(4, 8, 64, 23); | |||||
| run(4, 8, 64, 25); | |||||
| run(4, 8, 64, 27); | |||||
| run(4, 8, 64, 29); | |||||
| run(4, 8, 64, 31); | |||||
| run(1, 2, 128, 31); | |||||
| run(1, 2, 256, 31); | |||||
| } | |||||
| } | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, CONV_FWD_BENCHMARK) { | TEST_F(CUDA, CONV_FWD_BENCHMARK) { | ||||
| auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW, size_t SH = 1, | auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW, size_t SH = 1, | ||||
| @@ -901,24 +950,23 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) { | |||||
| run(32, 64, 64, 56, 56, 1, 1, 0); | run(32, 64, 64, 56, 56, 1, 1, 0); | ||||
| } | } | ||||
| TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) { | |||||
| CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()}; | |||||
| std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{ | |||||
| new OprProxy<ConvolutionBackwardData>{true}}; | |||||
| size_t RUNS = 10; | |||||
| bench.set_proxy(proxy).set_times(RUNS); | |||||
| TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) { | |||||
| CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()}; | |||||
| bencher.set_display(false); | |||||
| bencher.set_before_exec_callback( | |||||
| AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER")); | |||||
| auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH, | auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH, | ||||
| size_t SH, size_t PH) { | |||||
| bench.set_dtype(0, dtype::Float32()) | |||||
| size_t SH, size_t nr_times) { | |||||
| bencher.set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | .set_dtype(1, dtype::Float32()) | ||||
| .set_dtype(2, dtype::Float32()); | .set_dtype(2, dtype::Float32()); | ||||
| param::Convolution param; | param::Convolution param; | ||||
| param.stride_h = param.stride_w = SH; | param.stride_h = param.stride_w = SH; | ||||
| param.pad_h = param.pad_w = FH / 2; | param.pad_h = param.pad_w = FH / 2; | ||||
| param.sparse = param::Convolution::Sparse::GROUP; | param.sparse = param::Convolution::Sparse::GROUP; | ||||
| bench.set_param(param); | |||||
| bench.proxy()->target_execution_policy.algo.reset(); | |||||
| bencher.set_param(param); | |||||
| bencher.set_times(nr_times); | |||||
| TensorLayout src{{N, g, IH, IW}, dtype::Float32()}, | TensorLayout src{{N, g, IH, IW}, dtype::Float32()}, | ||||
| filter{{g, 1, 1, FH, FH}, dtype::Float32()}; | filter{{g, 1, 1, FH, FH}, dtype::Float32()}; | ||||
| TensorLayout dst; | TensorLayout dst; | ||||
| @@ -927,15 +975,28 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) { | |||||
| opr->param() = param; | opr->param() = param; | ||||
| opr->deduce_layout(src, filter, dst); | opr->deduce_layout(src, filter, dst); | ||||
| } | } | ||||
| auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS; | |||||
| auto time_ms_fp32 = bencher.execl({filter, dst, src}) / nr_times; | |||||
| float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH; | float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH; | ||||
| printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(), | printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(), | ||||
| filter.to_string().c_str(), dst.to_string().c_str()); | filter.to_string().c_str(), dst.to_string().c_str()); | ||||
| printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32, | printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32, | ||||
| (flo / (time_ms_fp32 * 1e9))); | (flo / (time_ms_fp32 * 1e9))); | ||||
| }; | }; | ||||
| run(64, 384, 384, 32, 32, 31, 1, 15); | |||||
| run(64, 384, 384, 32, 32, 3, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 5, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 7, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 9, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 11, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 13, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 15, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 17, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 19, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 21, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 23, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 25, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 27, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 29, 1, 10); | |||||
| run(64, 384, 384, 32, 32, 31, 1, 10); | |||||
| } | } | ||||
| TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) { | TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) { | ||||
| @@ -1103,7 +1164,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) { | |||||
| run(32, 64, 64, 56, 56, 1, 1, 0); | run(32, 64, 64, 56, 56, 1, 1, 0); | ||||
| } | } | ||||
| TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) { | |||||
| TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_DEPTHWISE_LARGE_FILTER) { | |||||
| CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()}; | CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()}; | ||||
| std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{ | std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{ | ||||
| new OprProxy<ConvolutionBackwardFilter>{true}}; | new OprProxy<ConvolutionBackwardFilter>{true}}; | ||||
| @@ -57,6 +57,7 @@ | |||||
| #cmakedefine01 MEGDNN_64_BIT | #cmakedefine01 MEGDNN_64_BIT | ||||
| #cmakedefine01 MEGDNN_THREADS_512 | #cmakedefine01 MEGDNN_THREADS_512 | ||||
| #cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS | #cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS | ||||
| #cmakedefine01 MEGDNN_WITH_BENCHMARK | |||||
| // whether atlas is available | // whether atlas is available | ||||
| #ifndef MGB_ATLAS | #ifndef MGB_ATLAS | ||||