|
|
|
@@ -69,6 +69,32 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::make_inner_layout( |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
std::vector<Algorithm::SearchItem> |
|
|
|
ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list( |
|
|
|
const TensorLayoutArray& layouts, const OperatorBase* opr) const { |
|
|
|
const ConvBiasForwardImpl* o = static_cast<const ConvBiasForwardImpl*>(opr); |
|
|
|
SizeArgs args(const_cast<ConvBiasForwardImpl*>(o), layouts[0], layouts[1], |
|
|
|
layouts[2], layouts[3], layouts[4], nullptr); |
|
|
|
TensorLayout inner_src_layout; |
|
|
|
TensorLayout inner_weight_layout; |
|
|
|
TensorLayout inner_dst_layout; |
|
|
|
TensorLayout inner_bias_layout; |
|
|
|
TensorLayout inner_z_layout; |
|
|
|
make_inner_layout(args, inner_src_layout, inner_weight_layout, |
|
|
|
inner_dst_layout, inner_bias_layout, inner_z_layout); |
|
|
|
|
|
|
|
Param inner_conv_param = o->param(); |
|
|
|
inner_conv_param.format = Param::Format::NCHW4; |
|
|
|
|
|
|
|
std::string param_str; |
|
|
|
Algorithm::serialize_write_pod(inner_conv_param, param_str); |
|
|
|
|
|
|
|
return {{Algorithm::OprType::CONVBIAS_FORWARD, |
|
|
|
param_str, |
|
|
|
{inner_src_layout, inner_weight_layout, inner_bias_layout, |
|
|
|
inner_z_layout, inner_dst_layout}}}; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( |
|
|
|
const SizeArgs& args) const { |
|
|
|
if (!args.src_layout->is_contiguous() || |
|
|
|
@@ -109,6 +135,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle( |
|
|
|
} |
|
|
|
auto opr = args.handle->create_operator<ConvBiasForward>(); |
|
|
|
opr->param() = inner_conv_param; |
|
|
|
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr, |
|
|
|
opr.get()); |
|
|
|
return WorkspaceBundle( |
|
|
|
ptr, |
|
|
|
{inner_src_layout.span().dist_byte(), |
|
|
|
@@ -164,6 +192,8 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( |
|
|
|
inner_conv_param.format = |
|
|
|
dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4; |
|
|
|
auto inner_opr = args.handle->create_operator<ConvBiasForward>(); |
|
|
|
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr, |
|
|
|
inner_opr.get()); |
|
|
|
inner_opr->param() = inner_conv_param; |
|
|
|
|
|
|
|
relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {}); |
|
|
|
|