Browse Source

fix(dnn/cuda): fix recursive algo search for fallback_nchw_qs8

GitOrigin-RevId: 6be2991224
tags/v1.5.0
Megvii Engine Team huangxinda 4 years ago
parent
commit
00083d13b6
2 changed files with 34 additions and 1 deletions
  1. +4
    -1
      dnn/src/cuda/conv_bias/algo.h
  2. +30
    -0
      dnn/src/cuda/conv_bias/conv_nchwqs8.cpp

+ 4
- 1
dnn/src/cuda/conv_bias/algo.h View File

@@ -575,7 +575,10 @@ public:
return AlgoAttribute::REPRODUCIBLE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)

// Lists the sub-operator search items this algorithm depends on, given the
// caller-visible layouts and the owning operator. The definition in
// conv_nchwqs8.cpp returns a single CONVBIAS_FORWARD item describing the
// inner convolution on the layouts produced by make_inner_layout().
std::vector<SearchItem> get_subopr_list(
const TensorLayoutArray& layouts,
const OperatorBase* opr) const override;
private:
void make_inner_layout(const SizeArgs& args, TensorLayout& inner_src_layout,
TensorLayout& inner_weight_layout,


+ 30
- 0
dnn/src/cuda/conv_bias/conv_nchwqs8.cpp View File

@@ -69,6 +69,32 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout(
}
};

std::vector<Algorithm::SearchItem>
ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const ConvBiasForwardImpl* o = static_cast<const ConvBiasForwardImpl*>(opr);
SizeArgs args(const_cast<ConvBiasForwardImpl*>(o), layouts[0], layouts[1],
layouts[2], layouts[3], layouts[4], nullptr);
TensorLayout inner_src_layout;
TensorLayout inner_weight_layout;
TensorLayout inner_dst_layout;
TensorLayout inner_bias_layout;
TensorLayout inner_z_layout;
make_inner_layout(args, inner_src_layout, inner_weight_layout,
inner_dst_layout, inner_bias_layout, inner_z_layout);

Param inner_conv_param = o->param();
inner_conv_param.format = Param::Format::NCHW4;
std::string param_str;
Algorithm::serialize_write_pod(inner_conv_param, param_str);

return {{Algorithm::OprType::CONVBIAS_FORWARD,
param_str,
{inner_src_layout, inner_weight_layout, inner_bias_layout,
inner_z_layout, inner_dst_layout}}};
}

bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
@@ -109,6 +135,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle(
}
auto opr = args.handle->create_operator<ConvBiasForward>();
opr->param() = inner_conv_param;
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
opr.get());
return WorkspaceBundle(
ptr,
{inner_src_layout.span().dist_byte(),
@@ -164,6 +192,8 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
inner_conv_param.format =
dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
auto inner_opr = args.handle->create_operator<ConvBiasForward>();
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
inner_opr.get());
inner_opr->param() = inner_conv_param;

relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});


Loading…
Cancel
Save