GitOrigin-RevId: c0530a949e
tags/v1.3.0
| @@ -36,6 +36,9 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| int8_algos.push_back(&algo); | |||
| } | |||
| int8_algos.push_back(&int8_nchw_dotprod); | |||
| all_algos.push_back(&int8_nchw_dotprod); | |||
| all_algos.reserve(all_algos.size() * 2); | |||
| // add gconv algos by AlgoGroupConvGeneral | |||
| @@ -39,7 +39,8 @@ public: | |||
| CUDA_CHANWISE_SMALL, | |||
| CUDA_BFLOAT16, | |||
| CUDA_GROUP_CONV_GENERAL, | |||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8 | |||
| CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8, | |||
| CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8 | |||
| }; | |||
| using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
| @@ -254,12 +255,6 @@ public: | |||
| int warp_k; | |||
| int stage; | |||
| std::string to_string() { | |||
| /// default algorithm | |||
| if (threadblock_m == 128 && threadblock_n == 128 && | |||
| threadblock_k == 32 && warp_m == 32 && warp_n == 64 && | |||
| warp_k == 32 && stage == 2) { | |||
| return ""; | |||
| } | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage); | |||
| @@ -284,6 +279,24 @@ private: | |||
| std::string m_name; | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm final | |||
| : public AlgoBase { | |||
| public: | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| void exec(const ExecArgs& args) const override; | |||
| const char* name() const override { | |||
| return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM"; | |||
| } | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8); | |||
| private: | |||
| WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | |||
| const SizeArgs& args) const; | |||
| }; | |||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
| // defined in cudnn.cpp | |||
| void fill_cudnn_algos(); | |||
| @@ -303,6 +316,7 @@ public: | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| AlgoBFloat16 bfloat16; | |||
| std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod; | |||
| AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * \file dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -0,0 +1,154 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/utils.h" | |||
| #include "src/cuda/convolution_helper/parameter.cuh" | |||
| #include "src/cuda/convolution/backward_data/cutlass_deconvolution_wrapper.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| is_available(const SizeArgs& args) const { | |||
| auto&& fm = args.filter_meta; | |||
| if (fm.format != Param::Format::NCHW) | |||
| return false; | |||
| bool available = true; | |||
| auto src_dtype = args.diff_layout->dtype, | |||
| filter_dtype = args.filter_layout->dtype, | |||
| dst_dtype = args.grad_layout->dtype; | |||
| available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| filter_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| dst_dtype.enumv() == DTypeEnum::QuantizedS8); | |||
| // TODO support group deconv int8 | |||
| available &= (fm.group == 1); | |||
| // ic and oc must be multiples of 4 | |||
| available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); | |||
| // mode must be cross correlation | |||
| available &= !fm.should_flip; | |||
| // mode must be 2D | |||
| available &= fm.spatial_ndim == 2; | |||
| // TODO: support dialtion | |||
| available &= (fm.dilation[0] == 1 && fm.dilation[1] == 1); | |||
| // FIXME: too large filter size is not supported now | |||
| available &= fm.spatial[0] * fm.spatial[1] <= 64; | |||
| // only support sm_61 or later, platform should have fast native int8 | |||
| // support | |||
| available &= is_compute_capability_required(6, 1); | |||
| return available; | |||
| } | |||
| WorkspaceBundle ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { | |||
| size_t ws_filter = args.filter_layout->span().dist_byte(); | |||
| size_t ws_diff = args.diff_layout->span().dist_byte(); | |||
| size_t ws_grad = args.grad_layout->span().dist_byte(); | |||
| return WorkspaceBundle{raw_ptr, {ws_filter, ws_diff, ws_grad}}; | |||
| } | |||
| size_t ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: | |||
| get_workspace_in_bytes(const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
| void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||
| const ExecArgs& args) const { | |||
| auto&& fm = args.filter_meta; | |||
| size_t n = args.diff_layout->operator[](0), | |||
| co = args.diff_layout->operator[](1), | |||
| ho = args.diff_layout->operator[](2), | |||
| wo = args.diff_layout->operator[](3); | |||
| size_t ci = args.grad_layout->operator[](1), | |||
| hi = args.grad_layout->operator[](2), | |||
| wi = args.grad_layout->operator[](3); | |||
| size_t fh = fm.spatial[0], fw = fm.spatial[1]; | |||
| size_t sh = fm.stride[0], sw = fm.stride[1]; | |||
| size_t ph = fm.padding[0], pw = fm.padding[1]; | |||
| auto&& stream = cuda_stream(args.opr->handle()); | |||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
| int8_t* inner_filter_ptr = nullptr; | |||
| int8_t* inner_diff_ptr = nullptr; | |||
| // TODO: weight preprocess | |||
| { | |||
| inner_filter_ptr = reinterpret_cast<int8_t*>(bundle.get(0)); | |||
| // reformat filter from nchw to n4hwc4 | |||
| TensorLayout exec_src{{co / 4, 4, ci, fh, fw}, dtype::Int8()}; | |||
| TensorLayout exec_dst{{co / 4, fh, fw, ci, 4}, dtype::Int8()}; | |||
| exec_src = exec_src.dimshuffle({0, 3, 4, 2, 1}); | |||
| auto&& relayout = | |||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||
| relayout->exec({args.filter_tensor->raw_ptr, exec_src}, | |||
| {inner_filter_ptr, exec_dst}); | |||
| } | |||
| { | |||
| inner_diff_ptr = reinterpret_cast<int8_t*>(bundle.get(1)); | |||
| // reformat diff from nchw to nchw4 | |||
| TensorLayout exec_src{{n, co / 4, 4, ho, wo}, dtype::Int8()}; | |||
| TensorLayout exec_dst{{n, co / 4, ho, wo, 4}, dtype::Int8()}; | |||
| exec_src = exec_src.dimshuffle({0, 1, 3, 4, 2}); | |||
| auto&& relayout = | |||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||
| relayout->exec({args.diff_tensor->raw_ptr, exec_src}, | |||
| {inner_diff_ptr, exec_dst}); | |||
| } | |||
| int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | |||
| convolution::ConvParam kern_param; | |||
| kern_param.n = n, kern_param.co = co, kern_param.ci = ci, | |||
| kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, | |||
| kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, | |||
| kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, | |||
| kern_param.fw = fw; | |||
| float diff_scale = | |||
| args.diff_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| filter_scale = | |||
| args.filter_layout->dtype.param<dtype::QuantizedS8>().scale, | |||
| grad_scale = | |||
| args.grad_layout->dtype.param<dtype::QuantizedS8>().scale; | |||
| float alpha = diff_scale * filter_scale / grad_scale; | |||
| // only use 16x64x8_16x64x8_2stages impl | |||
| cutlass_wrapper::do_deconv_int8_implicit_gemm_dp4a_ncdiv4hw4( | |||
| inner_diff_ptr, inner_filter_ptr, inner_grad_ptr, nullptr, | |||
| kern_param, alpha, cutlass_wrapper::GemmCoord{16, 64, 8}, | |||
| cutlass_wrapper::GemmCoord{16, 64, 8}, 2, stream); | |||
| after_kernel_launch(); | |||
| { | |||
| // reformat grad from nchw4 to nchw | |||
| TensorLayout exec_src{{n, ci / 4, hi, wi, 4}, dtype::Int8()}; | |||
| TensorLayout exec_dst{{n, ci / 4, 4, hi, wi}, dtype::Int8()}; | |||
| exec_src = exec_src.dimshuffle({0, 1, 4, 2, 3}); | |||
| auto&& relayout = | |||
| args.opr->handle()->create_operator<RelayoutForward>(); | |||
| relayout->exec({inner_grad_ptr, exec_src}, | |||
| {args.grad_tensor->raw_ptr, exec_dst}); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -106,6 +106,7 @@ public: | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoInt8NCHW4DotProdImplicitGemm; | |||
| class AlgoInt8NCHWDotProdImplicitGemm; | |||
| class AlgoPack; | |||
| @@ -434,7 +434,7 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { | |||
| param::Convolution cur_param; | |||
| // clang-format off | |||
| for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) { | |||
| for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) { | |||
| for (size_t b : {64, 16}) { | |||
| for (size_t ic : {16, 32}) { | |||
| for (size_t oc : {16, 32}) { | |||
| @@ -449,8 +449,8 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { | |||
| size_t f = kernel_size; | |||
| cur_param.mode = mode; | |||
| cur_param.format = param::ConvBias::Format::NCHW4; | |||
| cur_param.sparse = param::ConvBias::Sparse::DENSE; | |||
| cur_param.format = param::Convolution::Format::NCHW4; | |||
| cur_param.sparse = param::Convolution::Sparse::DENSE; | |||
| cur_param.pad_h = cur_param.pad_w = p; | |||
| cur_param.stride_h = cur_param.stride_w = s; | |||
| @@ -460,6 +460,54 @@ std::vector<TestArg> convolution::get_args_int8_nchw4_conv_bwd_data() { | |||
| } } } } } } } } } | |||
| // clang-format on | |||
| cur_param.pad_h = cur_param.pad_w = 1; | |||
| cur_param.stride_h = cur_param.stride_w = 1; | |||
| args.emplace_back(cur_param, TensorShape{16, 4, 8, 11, 4}, | |||
| TensorShape{16, 4, 3, 3, 4}); | |||
| return args; | |||
| } | |||
| std::vector<TestArg> convolution::get_args_int8_nchw_conv_bwd_data() { | |||
| std::vector<TestArg> args; | |||
| param::Convolution cur_param; | |||
| // clang-format off | |||
| for (auto mode : {param::Convolution::Mode::CROSS_CORRELATION}) { | |||
| for (size_t b : {64, 16}) { | |||
| for (size_t ic : {16, 32}) { | |||
| for (size_t oc : {16, 32}) { | |||
| for (size_t h : {8}) { | |||
| for (size_t w : {8, 11}) { | |||
| for (size_t kernel_size : {3, 4, 5, 7}) { | |||
| for (int p : {0, static_cast<int>(kernel_size / 2)}) { | |||
| for (size_t s : {2}) { | |||
| if (kernel_size >= 7) { | |||
| b = std::min(b, 32_z); | |||
| } | |||
| size_t f = kernel_size; | |||
| cur_param.mode = mode; | |||
| cur_param.format = param::Convolution::Format::NCHW; | |||
| cur_param.sparse = param::Convolution::Sparse::DENSE; | |||
| cur_param.pad_h = cur_param.pad_w = p; | |||
| cur_param.stride_h = cur_param.stride_w = s; | |||
| //! bias channel | |||
| args.emplace_back(cur_param, TensorShape{b, ic, h, w}, | |||
| TensorShape{oc, ic, f, f}); | |||
| } } } } } } } } } | |||
| // clang-format on | |||
| // test stride = 1 | |||
| cur_param.pad_h = cur_param.pad_w = 1; | |||
| cur_param.stride_h = cur_param.stride_w = 1; | |||
| args.emplace_back(cur_param, TensorShape{16, 16, 8, 11}, | |||
| TensorShape{16, 16, 3, 3}); | |||
| return args; | |||
| } | |||
| @@ -49,6 +49,7 @@ std::vector<TestArg> get_1x1_args(); | |||
| std::vector<TestArg> get_dilated_args(); | |||
| std::vector<TestArg> get_chanwise_args(); | |||
| std::vector<TestArg> get_args_int8_nchw4_conv_bwd_data(); | |||
| std::vector<TestArg> get_args_int8_nchw_conv_bwd_data(); | |||
| //! \param stage 0 for fwd, 1 for bwd data, 2 for bwd filter | |||
| using ConvEPSGetter = | |||
| @@ -266,19 +266,78 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) { | |||
| } | |||
| } | |||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_DP4A) { | |||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW4_DP4A) { | |||
| if (!cuda::is_compute_capability_required(6, 1)) { | |||
| printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_DP4A test as current " | |||
| "device doesn't support\n"); | |||
| printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NCHW4_DP4A test as " | |||
| "current device doesn't support\n"); | |||
| return; | |||
| } | |||
| using namespace convolution; | |||
| std::vector<TestArg> args = get_args_int8_nchw4_conv_bwd_data(); | |||
| struct AlgoParam { | |||
| int threadblock_m; | |||
| int threadblock_n; | |||
| int threadblock_k; | |||
| int warp_m; | |||
| int warp_n; | |||
| int warp_k; | |||
| int stage; | |||
| std::string to_string() { | |||
| return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, | |||
| threadblock_n, threadblock_k, warp_m, warp_n, | |||
| warp_k, stage); | |||
| } | |||
| }; | |||
| std::vector<AlgoParam> all_params; | |||
| all_params.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); | |||
| all_params.emplace_back(AlgoParam{16, 128, 16, 16, 64, 16, 2}); | |||
| all_params.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); | |||
| all_params.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); | |||
| all_params.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); | |||
| for (auto algo_param : all_params) { | |||
| Checker<ConvolutionBackwardData> checker(handle_cuda()); | |||
| std::string algo_name(ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", | |||
| algo_param.to_string().c_str())); | |||
| checker.set_before_exec_callback( | |||
| AlgoChecker<ConvolutionBackwardData>(algo_name.c_str())); | |||
| checker.set_epsilon(1 + 1e-3).set_max_avg_error(1e-1); | |||
| for (auto&& arg : args) { | |||
| UniformIntRNG rng(-3, 3); | |||
| auto src = TensorLayout(arg.src, dtype::QuantizedS8{1.2f}); | |||
| auto filter = TensorLayout(arg.filter, dtype::QuantizedS8{1.3f}); | |||
| TensorLayout dst; | |||
| dst.dtype = dtype::QuantizedS8{1.2f}; | |||
| { | |||
| auto opr = handle_cuda()->create_operator<Convolution>(); | |||
| opr->param() = arg.param; | |||
| opr->deduce_layout(src, filter, dst); | |||
| } | |||
| checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec( | |||
| TensorLayoutArray{filter, dst, src}); | |||
| } | |||
| } | |||
| } | |||
| TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) { | |||
| if (!cuda::is_compute_capability_required(6, 1)) { | |||
| printf("Skip CUDA.CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A test as " | |||
| "current device doesn't support\n"); | |||
| return; | |||
| } | |||
| using namespace convolution; | |||
| std::vector<TestArg> args = get_args_int8_nchw_conv_bwd_data(); | |||
| Checker<ConvolutionBackwardData> checker(handle_cuda()); | |||
| checker.set_before_exec_callback(AlgoChecker<ConvolutionBackwardData>( | |||
| "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM")); | |||
| "INT8_NCHW_DOTPROD_IMPLICIT_GEMM")); | |||
| checker.set_epsilon(1 + 1e-3).set_max_avg_error(1e-1); | |||
| @@ -459,7 +459,8 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { | |||
| opr::WarpPerspective::Param new_param, | |||
| megdnn::DType dst_dtype, | |||
| SymbolVar& new_warp) { | |||
| OperatorNodeConfig new_config(dst_dtype); | |||
| OperatorNodeConfig new_config = warp->config(); | |||
| new_config.output_dtype(dst_dtype); | |||
| if (warp->input().size() == 3) { | |||
| auto src = rewriter.get_var(warp->input(0)), | |||
| mat = rewriter.get_var(warp->input(1)), | |||
| @@ -1514,6 +1514,46 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||
| return new_opr; | |||
| }; | |||
| auto replace_deconv_opr = [trans_nchw4, conv_format]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| auto& deconv_opr = opr->cast_final_safe<opr::ConvolutionBackwardData>(); | |||
| if ((deconv_opr.param().format != | |||
| megdnn::param::Convolution::Format::NCHW) || | |||
| (deconv_opr.param().sparse != | |||
| megdnn::param::Convolution::Sparse::DENSE)) { | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; | |||
| auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); | |||
| // src: NCHW --> NCWH4 | |||
| if (deconv_src->shape().ndim != 5) { | |||
| mgb_assert(deconv_src->shape().ndim == 4); | |||
| auto new_src = | |||
| RelayoutPlaceholder::make(deconv_src, deconv_mode.src); | |||
| deconv_src = new_src.node(); | |||
| } | |||
| // weight: NCHW --> NCHW4 | |||
| auto new_filter = | |||
| RelayoutPlaceholder::make(deconv_filter, deconv_mode.weight); | |||
| deconv_filter = new_filter.node(); | |||
| // format: NCHW --> NCHW4 | |||
| auto new_param = deconv_opr.param(); | |||
| new_param.format = conv_format; | |||
| // dst | |||
| auto new_deconv_opr = opr::ConvolutionBackwardData::make_deconv( | |||
| deconv_src, deconv_filter, new_param, | |||
| deconv_opr.execution_policy(), deconv_opr.config()); | |||
| OperatorNodeBase* new_opr = new_deconv_opr.node()->owner_opr(); | |||
| return new_opr; | |||
| }; | |||
| auto replace_batch_conv_bias_opr = [batch_conv_bias_format, | |||
| src_to_nchw4_mode]( | |||
| OperatorNodeBase* opr, | |||
| @@ -1806,6 +1846,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||
| auto&& replace_func = ret->m_opr_replace_func; | |||
| //! supportted nchw4 | |||
| replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | |||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = | |||
| replace_deconv_opr; | |||
| replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | |||
| replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; | |||
| replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||
| @@ -1818,8 +1860,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||
| replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; | |||
| //! not supported nchw4 | |||
| replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; | |||
| replace_func[opr::ConvolutionBackwardData::typeinfo()] = | |||
| relayout_inp_to_nchw; | |||
| replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw; | |||
| replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw; | |||
| replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw; | |||
| @@ -2923,6 +2923,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||
| auto conv1 = opr::ConvBiasForward::make( | |||
| x, w1, b1, param_conv_bias, {}, | |||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||
| // group | |||
| // icpg != 1 && ocpg != 1 | |||
| param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||
| @@ -2932,8 +2933,19 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||
| conv1, w2, b2, param_conv_bias, {}, | |||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||
| auto conv2_fp32 = opr::TypeCvt::make(conv2, dtype::Float32()); | |||
| auto y = conv2_fp32 + opr::TypeCvt::make(b2, dtype::Float32()); | |||
| opr::Convolution::Param param_deconv; | |||
| param_deconv.format = opr::Convolution::Param::Format::NCHW; | |||
| param_deconv.stride_h = param_deconv.stride_w = 2; | |||
| param_deconv.pad_h = param_deconv.pad_w = 2; | |||
| // dense | |||
| param_deconv.sparse = opr::Convolution::Param::Sparse::DENSE; | |||
| auto w3 = mkcvar("w3", {8, 8, 4, 4}, dtype::QuantizedS8(2.5f)); | |||
| auto deconv1 = opr::ConvolutionBackwardData::make_deconv( | |||
| conv2, w3, param_deconv, {}, | |||
| OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); | |||
| auto deconv1_fp32 = opr::TypeCvt::make(deconv1, dtype::Float32()); | |||
| auto y = deconv1_fp32 + opr::TypeCvt::make(b2, dtype::Float32()); | |||
| SymbolVar y_opt; | |||
| { | |||
| @@ -2944,6 +2956,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { | |||
| ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, | |||
| find_opr<opr::ConvBias>(y_opt).param().format); | |||
| ASSERT_EQ(opr::ConvolutionBackwardData::Param::Format::NCHW4, | |||
| find_opr<opr::ConvolutionBackwardData>(y_opt).param().format); | |||
| auto nr_reshape = find_opr_num<mgb::opr::Reshape>(y_opt); | |||
| ASSERT_EQ(2u, nr_reshape); | |||
| @@ -51,7 +51,7 @@ decl_opr('ConvolutionBackwardData', | |||
| ], | |||
| desc='batched deconvolution on channeled 2D images; the underlying ' | |||
| 'computation is in fact gradient of convolution w.r.t. data', | |||
| version=2) | |||
| version=2, has_out_dtype=True) | |||
| decl_opr('MaskConvolution', | |||
| inputs=[Doc('src', | |||
| @@ -609,10 +609,11 @@ TEST(TestOprDNN, DeconvolutionExePolicy_QuantizedS8) { | |||
| using S = Policy::Strategy; | |||
| #if MGB_ENABLE_FASTRUN | |||
| for (auto strategy : {S::PROFILE, S::HEURISTIC, S::PROFILE_REPRODUCIBLE, | |||
| S::PROFILE_HEURISTIC}) { | |||
| for (auto strategy : | |||
| {S::PROFILE, S::HEURISTIC, S(S::PROFILE | S::REPRODUCIBLE), | |||
| S(S::PROFILE | S::HEURISTIC)}) { | |||
| #else | |||
| for (auto strategy : {S : HEURISTIC, S::PROFILE_HEURISTIC}) { | |||
| for (auto strategy: {S:HEURISTIC, S(S::PROFILE | S::HEURISTIC)}) { | |||
| #endif | |||
| auto graph = ComputingGraph::make(); | |||
| HostTensorGenerator<> gen; | |||