GitOrigin-RevId: 5a3bfedd8a
tags/v0.5.0
| @@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
| if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | |||
| return false; | |||
| if (param.src_type.enumv() != param.filter_type.enumv() && | |||
| param.src_type.enumv() != DTypeEnum::Int8 && | |||
| param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||
| param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| param.src_type.enumv() != DTypeEnum::Float16 && | |||
| #endif | |||
| param.src_type.enumv() != DTypeEnum::Float32) { | |||
| return false; | |||
| } | |||
| //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode | |||
| //! is identity otherwise return false mean that 8x8x32 and 8x8x16 | |||
| //! not support PostProcess | |||
| if (param.src_type.enumv() == param.filter_type.enumv() && | |||
| (param.src_type.enumv() == DTypeEnum::Int8 && | |||
| (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32)) && | |||
| param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
| return false; | |||
| if (param.src_type.enumv() == param.filter_type.enumv() && | |||
| ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) && | |||
| param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
| return false; | |||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| } | |||
| if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
| opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
| @@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
| return false; | |||
| } | |||
| if (param.src_type.enumv() != param.filter_type.enumv() && | |||
| param.src_type.enumv() != DTypeEnum::Int8 && | |||
| param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||
| param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| param.src_type.enumv() != DTypeEnum::Float16 && | |||
| #endif | |||
| param.src_type.enumv() != DTypeEnum::Float32) { | |||
| return false; | |||
| } | |||
| //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | |||
| //! identity otherwise return false mean that 8x8x32 and 8x8x16 not | |||
| //! support PostProcess | |||
| if (param.src_type.enumv() == param.filter_type.enumv() && | |||
| ((param.src_type.enumv() == DTypeEnum::Int8 && | |||
| (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32)) || | |||
| ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32)) && | |||
| param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| } | |||
| if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
| opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
| @@ -188,6 +188,24 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle, | |||
| } | |||
| } | |||
| TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | |||
| using namespace conv_bias; | |||
| param::ConvBias cur_param; | |||
| using NLMode = param::ConvBias::NonlineMode; | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_args( | |||
| {1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true); | |||
| NormalRNG default_rng; | |||
| Checker<ConvBias> checker(handle()); | |||
| checker.set_dtype(0, dtype::Int8{}); | |||
| checker.set_dtype(1, dtype::Int8{}); | |||
| checker.set_dtype(2, dtype::Int16{}); | |||
| checker.set_dtype(4, dtype::Int16{}); | |||
| for (auto&& arg : args) { | |||
| checker.set_param(arg.param).execs( | |||
| {arg.src, arg.filter, arg.bias, {}, {}}); | |||
| } | |||
| } | |||
| TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | |||
| using namespace conv_bias; | |||
| param::ConvBias cur_param; | |||
| @@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { | |||
| rewriter.get_var(typecvt->input(0))->owner_opr()); | |||
| if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || | |||
| typecvt->output(0)->dtype().enumv() != | |||
| DTypeTrait<dtype::QuantizedS8>::enumv) | |||
| DTypeTrait<dtype::QuantizedS8>::enumv || | |||
| typecvt->input(0)->dtype().enumv() != | |||
| DTypeTrait<dtype::QuantizedS32>::enumv) | |||
| return nullptr; | |||
| auto config = conv_bias->config(); | |||