diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 1e1622db..0624f02e 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -431,7 +431,9 @@ ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_dat } } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { return ConvolutionImpl::AlgoDataType::QUINT8X8X32; - } else if (src_type.enumv() == DTypeEnum::QuantizedS4) { + } else if ( + src_type.enumv() == DTypeEnum::QuantizedS4 || + src_type.enumv() == DTypeEnum::Quantized4Asymm) { return ConvolutionImpl::AlgoDataType::QINT4x4x32; } else { megdnn_throw(ssprintf( @@ -477,7 +479,8 @@ void ConvolutionBackwardDataImpl::exec( _megdnn_workspace workspace) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace); } @@ -499,7 +502,8 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes( filter, diff, grad); @@ -514,7 +518,8 @@ std::vector ConvolutionBackwardDataImpl const TensorLayout& grad) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_all_algorithms( filter, diff, grad); @@ -541,7 +546,8 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: const AlgoAttribute& negative_attr) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( filter, diff, grad, workspace_limit_in_bytes, positive_attr,