From 00083d13b6b01a3e72d2b688c9608d47e7255335 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 20 May 2021 14:41:06 +0800
Subject: [PATCH] fix(dnn/cuda): fix recursive algo search for
 fallback_nchw_qs8

GitOrigin-RevId: 6be2991224bced3a38a17b6b888fd4f324d03f9f
---
 dnn/src/cuda/conv_bias/algo.h           |  5 ++++-
 dnn/src/cuda/conv_bias/conv_nchwqs8.cpp | 30 +++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h
index e4cd1a2a..869777b6 100644
--- a/dnn/src/cuda/conv_bias/algo.h
+++ b/dnn/src/cuda/conv_bias/algo.h
@@ -575,7 +575,10 @@ public:
         return AlgoAttribute::REPRODUCIBLE;
     }
     MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
-
+    std::vector<SearchItem> get_subopr_list(
+            const TensorLayoutArray& layouts,
+            const OperatorBase* opr) const override;
+
 private:
     void make_inner_layout(const SizeArgs& args, TensorLayout& inner_src_layout,
                            TensorLayout& inner_weight_layout,
diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
index 04a6697a..c8b395e6 100644
--- a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
+++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp
@@ -69,6 +69,32 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
     }
 };
 
+std::vector<Algorithm::SearchItem>
+ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
+        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
+    const ConvBiasForwardImpl* o = static_cast<const ConvBiasForwardImpl*>(opr);
+    SizeArgs args(const_cast<ConvBiasForwardImpl*>(o), layouts[0], layouts[1],
+                  layouts[2], layouts[3], layouts[4], nullptr);
+    TensorLayout inner_src_layout;
+    TensorLayout inner_weight_layout;
+    TensorLayout inner_dst_layout;
+    TensorLayout inner_bias_layout;
+    TensorLayout inner_z_layout;
+    make_inner_layout(args, inner_src_layout, inner_weight_layout,
+                      inner_dst_layout, inner_bias_layout, inner_z_layout);
+
+    Param inner_conv_param = o->param();
+    inner_conv_param.format = Param::Format::NCHW4;
+
+    std::string param_str;
+    Algorithm::serialize_write_pod(inner_conv_param, param_str);
+
+    return {{Algorithm::OprType::CONVBIAS_FORWARD,
+             param_str,
+             {inner_src_layout, inner_weight_layout, inner_bias_layout,
+              inner_z_layout, inner_dst_layout}}};
+}
+
 bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
         const SizeArgs& args) const {
     if (!args.src_layout->is_contiguous() ||
@@ -109,6 +135,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
     }
     auto opr = args.handle->create_operator<ConvBiasForward>();
     opr->param() = inner_conv_param;
+    set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
+                                                            opr.get());
     return WorkspaceBundle(
             ptr, {inner_src_layout.span().dist_byte(),
@@ -164,6 +192,8 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
     inner_conv_param.format =
             dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
     auto inner_opr = args.handle->create_operator<ConvBiasForward>();
+    set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
+                                                            inner_opr.get());
     inner_opr->param() = inner_conv_param;
     relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});