GitOrigin-RevId: 3d97fedc8f
tags/v1.0.0-rc1
| @@ -100,7 +100,6 @@ namespace { | |||
| MIDOUT_END(); \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_throw("no quantized unsupported biasmode"); \ | |||
| break; \ | |||
| } | |||
| @@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| #undef FOR_NONLINEAR_NOBIAS | |||
| #undef FOR_NONLINEAR | |||
| #undef FOR_BIAS | |||
| #define FOR_BINARY_BROADCAST(_op) /* Calls OpCallerBinary<_op<ctype>, VEC_BCAST101>::run; FOR_BIAS uses it for BROADCAST_CHANNEL_BIAS when pack_oc_size == 1. */ \ | |||
| megdnn::arm_common:: \ | |||
| OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), /* first operand */ \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), /* second operand */ \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, /* bias_type is passed for both operand dtypes */ \ | |||
| dst_type, N, OC, OH* OW); | |||
| #define FOR_BINARY_BROADCAST_NCHW44(_op) /* Calls OpCallerBinary<_op<ctype>, VEC_BCAST101x4>::run; FOR_BIAS uses it for BROADCAST_CHANNEL_BIAS when pack_oc_size != 1. */ \ | |||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||
| megdnn::arm_common::VEC_BCAST101x4>:: \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), /* first operand */ \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), /* second operand */ \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, /* bias_type is passed for both operand dtypes */ \ | |||
| dst_type, N, OC, OH* OW, pack_oc_size); /* unlike FOR_BINARY_BROADCAST, pack_oc_size is forwarded as an extra argument */ | |||
| #define FOR_BINARY(_op) /* Calls OpCallerBinary<_op<ctype>, VEC_VEC>::run with a single flat size; FOR_BIAS uses it for BiasMode::BIAS. */ \ | |||
| megdnn::arm_common:: \ | |||
| OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||
| static_cast<ctype*>(conv_dst_ptr), /* first operand */ \ | |||
| reinterpret_cast<const ctype*>(bias_ptr), /* second operand */ \ | |||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, /* bias_type is passed for both operand dtypes */ \ | |||
| dst_type, N* OC* OH* OW* pack_oc_size); /* NOTE(review): presumably OC is already divided by pack_oc_size for packed layouts - confirm against callers */ | |||
| #define FOR_BIAS(_bias_mode, OH, OW) /* Dispatches on _bias_mode to one of the FOR_BINARY* macros with AddOp. OH and OW do not appear in this macro's own replacement list. */ \ | |||
| switch (_bias_mode) { \ | |||
| case megdnn::BiasMode::NO_BIAS: /* nothing is added and nothing is written here */ \ | |||
| break; \ | |||
| case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| if (pack_oc_size == 1) { \ | |||
| FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); /* CONCAT_OP is defined outside this view */ \ | |||
| } else { \ | |||
| megdnn_assert(pack_oc_size == 4, /* any pack size other than 1 or 4 fails this assertion */ \ | |||
| "Only support nchw44 in ARM"); \ | |||
| FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ | |||
| } \ | |||
| break; \ | |||
| case megdnn::BiasMode::BIAS: \ | |||
| FOR_BINARY(CONCAT_OP(AddOp)); \ | |||
| break; \ | |||
| default: /* every other bias mode is ignored without an error */ \ | |||
| break; \ | |||
| } | |||
| template <typename ctype, typename dtype> /* PostProcess specialization for PostprocessMode::ADD_BIAS. */ | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, /* the void pointers are cast to ctype pointers inside the FOR_BINARY* macros */ | |||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY); /* only IDENTITY is accepted by this mode */ | |||
| FOR_BIAS(bias_mode, OH, OW); /* expands to the bias_mode switch; its NO_BIAS and default cases do nothing */ | |||
| } | |||
| }; | |||
| #undef FOR_BINARY_BROADCAST | |||
| #undef FOR_BINARY_BROADCAST_NCHW44 | |||
| #undef FOR_BINARY | |||
| #undef FOR_BIAS | |||
| #undef CB | |||
| #undef CONCAT_OP | |||
| #undef CONCAT_NL | |||
| @@ -158,9 +158,11 @@ private: \ | |||
| uint32_t m_tile_size; | |||
| enum class PostprocessMode : uint8_t { | |||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
| NO_PROCESS, ///<support non bias and identity | |||
| QUANTIZED,///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode | |||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
| NO_PROCESS, ///< support no bias and identity only | |||
| QUANTIZED, ///< support NO_BIAS, BROADCAST_CHANNEL_BIAS and the relu, hswish, | |||
| ///< identity nonline modes | |||
| ADD_BIAS, ///< only add bias (no nonlinearity) | |||
| }; | |||
| } // namespace megdnn | |||
| @@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| } | |||
| @@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| #define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, /* Selects a Conv1x1GemvWorker when param's dtypes match. _i_bias_type is not referenced in the body. */ \ | |||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && /* filter and src must share one dtype enum */ \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| conv1x1_gemv_worker = /* on a match only the worker is assigned; control falls through to the next statement */ \ | |||
| Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
| _bias_ctype, _dst_ctype, /* _bias_ctype and _dst_ctype are passed a second time as the 4th and 5th template arguments */ \ | |||
| _postprocess_mode, _format>::exec; \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| switch (param.filter_meta.format) { | |||
| case param::ConvBias::Format::NCHW: | |||
| @@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | |||
| #endif | |||
| #endif | |||
| cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||
| dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||
| dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NCHW::GEMV::INT8x8x32_INT32"_hash); | |||
| cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||
| dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||
| dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
| "NCHW::GEMV::INT8x8x16_INT16"_hash); | |||
| cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
| cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NCHW::GEMV::QINT8x8x32_QINT32"_hash); | |||
| cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
| dt_int8, PostprocessMode::QUANTIZED, | |||
| "NCHW::GEMV::QINT8x8x32_QINT8"_hash); | |||
| cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
| cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | |||
| cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||
| dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | |||
| @@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| break; | |||
| case param::ConvBias::Format::NCHW44_DOT: | |||
| cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||
| cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||
| dt_int32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| PostprocessMode::ADD_BIAS, | |||
| "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | |||
| cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
| cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | |||
| cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
| @@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| } | |||
| #undef cb1 | |||
| #undef cb2 | |||
| #undef cb3 | |||
| megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | |||
| @@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param, | |||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| } | |||
| @@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| #define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, /* Returns a Conv1x1Strategy when param's dtypes match. _i_bias_type is not referenced in the body. */ \ | |||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && /* filter and src must share one dtype enum */ \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique<Conv1x1Strategy< /* returns from the enclosing function on a match */ \ | |||
| _src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, /* _bias_ctype and _dst_ctype are repeated as the 4th and 5th template arguments */ \ | |||
| _dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END() | |||
| switch (pack_mode) { | |||
| case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | |||
| @@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| "Default::FLOAT16_FLOAT16"_hash); | |||
| #endif | |||
| #endif | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||
| dt_int32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||
| PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash); | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||
| dt_int16, dt_int8, dt_int16, dt_int16, | |||
| PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash); | |||
| PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash); | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
| dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| PostprocessMode::ADD_BIAS, | |||
| "Default::QUINT8x8x32_QINT32"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||
| dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
| PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | |||
| #endif | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, PostprocessMode::ADD_BIAS, | |||
| "Default::QINT8x8x32_QINT32"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||
| @@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | |||
| dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||
| dt_int16, dt_int8, dt_int16, dt_int16, | |||
| PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash); | |||
| PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||
| dt_int32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash); | |||
| PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash); | |||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NoPack::QINT8x8x32_QINT32"_hash); | |||
| break; | |||
| @@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
| } | |||
| #undef cb1 | |||
| #undef cb2 | |||
| #undef cb3 | |||
| megdnn_throw("Invalid Data Type"); | |||
| return nullptr; | |||
| } | |||
| @@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy( | |||
| } // namespace fallback | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
| param.dst_type.enumv() == DTypeEnum::Int32 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| } | |||
| @@ -213,6 +213,22 @@ public: | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return {}; | |||
| #define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, /* Returns a Strategy when param's dtypes match, otherwise an empty value. _i_bias_type is not referenced in the body. */ \ | |||
| _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||
| _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && /* filter and src must share one dtype enum */ \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, /* _bias_ctype and _dst_ctype are repeated as the 4th and 5th template arguments */ \ | |||
| _dst_ctype, _postprocess_mode, \ | |||
| PackMode::_packmode, FormatMode::_format>>(); /* _packmode and _format are bare enumerator names, qualified here */ \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return {}; /* reached when the dtype check fails: the macro always returns from the enclosing function */ | |||
| static std::unique_ptr<StrategyBase> make_default_strategy( | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| @@ -279,13 +295,13 @@ public: | |||
| #endif | |||
| case StrategyType::INT8x8x32: | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyType::INT8x8x32"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44 || | |||
| format == param::ConvBias::Format::NCHW44_DOT) { | |||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyType::INT8x8x32"_hash); | |||
| } else { | |||
| megdnn_throw( | |||
| @@ -299,12 +315,12 @@ public: | |||
| case StrategyType::INT8x8x16: | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyType::INT8x8x16"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyType::INT8x8x16"_hash); | |||
| } else { | |||
| megdnn_throw( | |||
| @@ -316,9 +332,9 @@ public: | |||
| break; | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| case StrategyType::QUINT8x8x32: | |||
| cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyType::QUINT8x8x32"_hash); | |||
| break; | |||
| @@ -331,15 +347,15 @@ public: | |||
| #endif | |||
| case StrategyType::QINT8x8x32: | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44 || | |||
| format == param::ConvBias::Format::NCHW44_DOT) { | |||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
| cb3(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
| "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||
| } else { | |||
| megdnn_throw( | |||
| @@ -467,13 +483,13 @@ public: | |||
| #endif | |||
| #endif | |||
| case StrategyType::INT8x8x16: | |||
| cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||
| "NoPackStrategyType::INT8x8x16"_hash); | |||
| break; | |||
| case StrategyType::INT8x8x32: | |||
| cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||
| "NoPackStrategyType::INT8x8x32"_hash); | |||
| break; | |||
| default: | |||
| @@ -509,6 +525,7 @@ public: | |||
| #undef cb1 | |||
| #undef cb2 | |||
| #undef cb3 | |||
| static std::unique_ptr<StrategyBase> make_strategy( | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| @@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| //! x86 does not have uint8 matmul, so only armv7 and armv8 support uint8 | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| #undef INSTANTIAL_CLASS | |||
| } // namespace megdnn | |||
| @@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| //! x86 does not have uint8 matmul, so only armv7 and armv8 support uint8 | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| #undef INSTANTIAL_CLASS | |||
| } // namespace megdnn | |||
| @@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| megdnn::PostprocessMode::ADD_BIAS) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| @@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| #undef FOR_BIAS | |||
| } | |||
| }; | |||
| #undef CALL_BINARY | |||
| #undef CALL_BINARY_BROADCAST | |||
| #define CALL_BINARY(_op, _simd_type) /* Binds OpCallerBinary<..., VEC_VEC>::run to a local thin_function named run and calls it with a single flat size. */ \ | |||
| thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||
| DType, size_t)> \ | |||
| run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||
| megdnn::x86::BcastType::VEC_VEC>::run; \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), /* operands: conv_dst_ptr, then bias_ptr */ \ | |||
| reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, /* bias_type is passed for both operand dtypes */ \ | |||
| N* OC* OH* OW); | |||
| #define CALL_BINARY_BROADCAST(_op, _simd_type) /* Binds OpCallerBinary<..., VEC_BCAST101>::run to a local thin_function named run and calls it with three sizes. */ \ | |||
| thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||
| DType, size_t, size_t, size_t)> \ | |||
| run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||
| megdnn::x86::BcastType::VEC_BCAST101>::run; \ | |||
| run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), /* operands: conv_dst_ptr, then bias_ptr */ \ | |||
| reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, /* bias_type is passed for both operand dtypes */ \ | |||
| OC, OH* OW); | |||
| #define FOR_SIMD(CALLER) /* Invokes CALLER with AddOp and the first SIMDType that is_supported() accepts: AVX2, then SSE4_2, else NONE. */ \ | |||
| if (is_supported(SIMDType::AVX2)) { \ | |||
| CALLER(AddOp, SIMDType::AVX2) \ | |||
| } else if (is_supported(SIMDType::SSE4_2)) { \ | |||
| CALLER(AddOp, SIMDType::SSE4_2) \ | |||
| } else { \ | |||
| CALLER(AddOp, SIMDType::NONE) /* fallback when neither check succeeds */ \ | |||
| } | |||
| #define FOR_BIAS(bias_mode) /* Dispatches on bias_mode: BIAS and BROADCAST_CHANNEL_BIAS go through FOR_SIMD. */ \ | |||
| switch (bias_mode) { \ | |||
| case BiasMode::BIAS: \ | |||
| FOR_SIMD(CALL_BINARY); \ | |||
| break; \ | |||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| FOR_SIMD(CALL_BINARY_BROADCAST); \ | |||
| break; \ | |||
| default: /* every other bias mode is ignored without an error */ \ | |||
| break; \ | |||
| } | |||
| template <typename ctype, typename dtype> /* PostProcess specialization for PostprocessMode::ADD_BIAS. */ | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, /* the void pointers are cast to ctype and dtype pointers inside the CALL_BINARY* macros */ | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); /* presumably keeps the parameter referenced when the assert below is compiled out - confirm */ | |||
| megdnn_assert(pack_oc_size == 1, /* any pack size other than 1 is rejected */ | |||
| "PostProcess only support nchw in x86"); | |||
| megdnn_assert( | |||
| nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | |||
| "Add bias PostProcess only support IDENTITY"); | |||
| if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { /* NO_BIAS returns before any call into OpCallerBinary */ | |||
| return; | |||
| } | |||
| FOR_BIAS(bias_mode); /* expands to the bias_mode switch; modes other than BIAS and BROADCAST_CHANNEL_BIAS do nothing */ | |||
| #undef CALL_BINARY /* the helper macros are undefined inside the function body, directly after their only use here */ | |||
| #undef CALL_BINARY_BROADCAST | |||
| #undef FOR_SIMD | |||
| #undef FOR_BIAS | |||
| } | |||
| }; | |||
| #undef cb_unary | |||
| #undef cb_binary | |||
| #undef BIAS_CASE | |||
| @@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8, | |||
| using AddOpBase::operator(); \ | |||
| }; | |||
| OP(dt_int32, SIMDType::NONE); | |||
| OP(dt_int16, SIMDType::NONE); | |||
| OP(dt_float32, SIMDType::NONE); | |||
| #undef OP | |||
| } // namespace x86 | |||
| @@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||
| #define cb(name) \ | |||
| checker_conv_bias( \ | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
| true, false, true, false, false, true), \ | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||
| true, false, true, true, false, false), \ | |||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||
| checker_conv_bias( \ | |||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
| false, false, true), \ | |||
| true, false, false), \ | |||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | |||
| @@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||
| #define cb(name) \ | |||
| checker_conv_bias( \ | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
| true, false, true, false, false, true), \ | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||
| true, false, true, true, false, false), \ | |||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
| dtype::Int32(), {}, name); \ | |||
| checker_conv_bias( \ | |||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||
| false, false, true), \ | |||
| true, false, false), \ | |||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||
| dtype::Int32(), {}, name); | |||
| @@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | |||
| NormalRNG rng(128.f); | |||
| #define cb(name) \ | |||
| checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
| false, true, true), \ | |||
| @@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| checker_conv_bias( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
| handle(), &rng, epsilon, \ | |||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||
| &rng, epsilon, \ | |||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||
| #define cb(name) \ | |||
| checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||
| true, true, false), \ | |||
| handle(), &rng, epsilon, \ | |||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||
| checker_conv_bias( \ | |||
| get_conv_bias_args({1}, 2, false, false, true, true, false), \ | |||
| handle(), &rng, epsilon, \ | |||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||
| #if MEGDNN_AARCH64 | |||
| #if __ARM_FEATURE_DOTPROD | |||
| @@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| std::vector<conv_bias::TestArg> args_nchw44 = | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||
| false, false, false, false, true); | |||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true, | |||
| false, false, true, false, false); | |||
| std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | |||
| get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||
| false, false, true); | |||
| #define cb(name) \ | |||
| checker_conv_bias( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||
| &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false, | |||
| true, false, false); | |||
| #define cb(name) \ | |||
| checker_conv_bias( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| dtype::Int16{}, dtype::Int16{}, name); | |||
| #define cb_nchw44(name) \ | |||
| @@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| #define cb(name) \ | |||
| check_conv_bias_preprocess( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||
| check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, \ | |||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||
| #define cb(name) \ | |||
| check_conv_bias_preprocess( \ | |||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||
| check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \ | |||
| handle(), &rng, epsilon, dtype::Int8{}, \ | |||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||
| name); | |||
| #if MEGDNN_AARCH64 | |||
| @@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||
| checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .set_dtype(4, {}) | |||
| .set_dtype(4, dtype::QuantizedS32(6.25f)) | |||
| .set_rng(0, &rng) | |||
| .set_rng(1, &rng) | |||
| .set_rng(2, &rng) | |||
| @@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||
| checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .set_dtype(4, {}) | |||
| .set_dtype(4, dtype::QuantizedS32(6.25f)) | |||
| .set_rng(0, &rng) | |||
| .set_rng(1, &rng) | |||
| .set_rng(2, &rng) | |||
| @@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
| #if MEGDNN_AARCH64 | |||
| @@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||
| #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | |||
| #if MEGDNN_AARCH64 | |||
| @@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||
| get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true); | |||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
| #if MEGDNN_AARCH64 | |||
| @@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_conv_bias_1x1_args(false, true, false, false); | |||
| std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | |||
| {1}, 1, true, true, true, false, false, false, false, true); | |||
| {1}, 1, true, true, true, false, false, true, false, false); | |||
| #define cb(name) \ | |||
| checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||
| @@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | |||
| using namespace conv_bias; | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
| std::vector<conv_bias::TestArg> args = | |||
| get_conv_bias_1x1_args(false, true, false, false); | |||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | |||
| @@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) { | |||
| //! no bias | |||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
| TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
| TensorShape{oc, ic, kernel, kernel}, | |||
| TensorShape{1, oc, 1, 1}); | |||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||
| TensorShape{oc, ic, kernel, kernel}, | |||
| TensorShape{1, oc, (h + 2 * p - kernel) + 1, | |||
| (h + 2 * p - kernel) + 1}); | |||
| }; | |||
| for (size_t kernel : {2, 3, 4, 5, 6, 7}) | |||
| @@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||
| using namespace conv_bias; | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||
| #if MEGDNN_X86_WITH_MKL_DNN | |||
| if (x86::is_supported(x86::SIMDType::VNNI)) { | |||
| checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | |||
| @@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) { | |||
| using namespace conv_bias; | |||
| UniformIntRNG rng{-50, 50}; | |||
| float epsilon = 0.001; | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||
| #if MEGDNN_X86_WITH_VNNI | |||
| if (x86::is_supported(x86::SIMDType::VNNI)) { | |||
| checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | |||