GitOrigin-RevId: 3d97fedc8f
tags/v1.0.0-rc1
| @@ -100,7 +100,6 @@ namespace { | |||||
| MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
| break; \ | break; \ | ||||
| default: \ | default: \ | ||||
| megdnn_throw("no quantized unsupported biasmode"); \ | |||||
| break; \ | break; \ | ||||
| } | } | ||||
| @@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
| #undef FOR_NONLINEAR_NOBIAS | #undef FOR_NONLINEAR_NOBIAS | ||||
| #undef FOR_NONLINEAR | #undef FOR_NONLINEAR | ||||
| #undef FOR_BIAS | #undef FOR_BIAS | ||||
| #define FOR_BINARY_BROADCAST(_op) \ | |||||
| megdnn::arm_common:: \ | |||||
| OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
| dst_type, N, OC, OH* OW); | |||||
| #define FOR_BINARY_BROADCAST_NCHW44(_op) \ | |||||
| megdnn::arm_common::OpCallerBinary<_op<ctype>, \ | |||||
| megdnn::arm_common::VEC_BCAST101x4>:: \ | |||||
| run(static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
| dst_type, N, OC, OH* OW, pack_oc_size); | |||||
| #define FOR_BINARY(_op) \ | |||||
| megdnn::arm_common:: \ | |||||
| OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \ | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | |||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | |||||
| dst_type, N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_BIAS(_bias_mode, OH, OW) \ | |||||
| switch (_bias_mode) { \ | |||||
| case megdnn::BiasMode::NO_BIAS: \ | |||||
| break; \ | |||||
| case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| if (pack_oc_size == 1) { \ | |||||
| FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ | |||||
| } else { \ | |||||
| megdnn_assert(pack_oc_size == 4, \ | |||||
| "Only support nchw44 in ARM"); \ | |||||
| FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ | |||||
| } \ | |||||
| break; \ | |||||
| case megdnn::BiasMode::BIAS: \ | |||||
| FOR_BINARY(CONCAT_OP(AddOp)); \ | |||||
| break; \ | |||||
| default: \ | |||||
| break; \ | |||||
| } | |||||
| template <typename ctype, typename dtype> | |||||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
| megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY); | |||||
| FOR_BIAS(bias_mode, OH, OW); | |||||
| } | |||||
| }; | |||||
| #undef FOR_BINARY_BROADCAST | |||||
| #undef FOR_BINARY_BROADCAST_NCHW44 | |||||
| #undef FOR_BINARY | |||||
| #undef FOR_BIAS | |||||
| #undef CB | #undef CB | ||||
| #undef CONCAT_OP | #undef CONCAT_OP | ||||
| #undef CONCAT_NL | #undef CONCAT_NL | ||||
| @@ -158,9 +158,11 @@ private: \ | |||||
| uint32_t m_tile_size; | uint32_t m_tile_size; | ||||
| enum class PostprocessMode : uint8_t { | enum class PostprocessMode : uint8_t { | ||||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||||
| NO_PROCESS, ///<support non bias and identity | |||||
| QUANTIZED,///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode | |||||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||||
| NO_PROCESS, ///< support no bias and identity nonline mode | |||||
| QUANTIZED, ///< support NO_BIAS, BROADCAST_CHANNEL_BIAS and relu, hswish, | |||||
| ///< identity nonline mode | |||||
| ADD_BIAS, ///< only add bias | |||||
| }; | }; | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | ||||
| param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
| } \ | } \ | ||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| #define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
| conv1x1_gemv_worker = \ | |||||
| Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
| _bias_ctype, _dst_ctype, \ | |||||
| _postprocess_mode, _format>::exec; \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END() | |||||
| switch (param.filter_meta.format) { | switch (param.filter_meta.format) { | ||||
| case param::ConvBias::Format::NCHW: | case param::ConvBias::Format::NCHW: | ||||
| @@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
| PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||||
| dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||||
| dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NCHW::GEMV::INT8x8x32_INT32"_hash); | "NCHW::GEMV::INT8x8x32_INT32"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||||
| dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
| cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16, | |||||
| dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
| "NCHW::GEMV::INT8x8x16_INT16"_hash); | "NCHW::GEMV::INT8x8x16_INT16"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||||
| cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
| dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NCHW::GEMV::QINT8x8x32_QINT32"_hash); | "NCHW::GEMV::QINT8x8x32_QINT32"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
| dt_int8, PostprocessMode::QUANTIZED, | dt_int8, PostprocessMode::QUANTIZED, | ||||
| "NCHW::GEMV::QINT8x8x32_QINT8"_hash); | "NCHW::GEMV::QINT8x8x32_QINT8"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||||
| cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32, | ||||
| dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | "NCHW::GEMV::QUINT8x8x32_QINT32"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm, | ||||
| dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32, | ||||
| @@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
| break; | break; | ||||
| case param::ConvBias::Format::NCHW44_DOT: | case param::ConvBias::Format::NCHW44_DOT: | ||||
| cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||||
| cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||||
| dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, | |||||
| PostprocessMode::ADD_BIAS, | |||||
| "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | "NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
| cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
| dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | "NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | ||||
| cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
| @@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
| } | } | ||||
| #undef cb1 | #undef cb1 | ||||
| #undef cb2 | #undef cb2 | ||||
| #undef cb3 | |||||
| megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker"); | ||||
| @@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param, | |||||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | if (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
| param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
| } \ | } \ | ||||
| } \ | } \ | ||||
| MIDOUT_END() | MIDOUT_END() | ||||
| #define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
| return std::make_unique<Conv1x1Strategy< \ | |||||
| _src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||||
| _dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END() | |||||
| switch (pack_mode) { | switch (pack_mode) { | ||||
| case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | ||||
| @@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
| "Default::FLOAT16_FLOAT16"_hash); | "Default::FLOAT16_FLOAT16"_hash); | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32, | |||||
| dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash); | |||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||||
| PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash); | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16, | |||||
| dt_int16, dt_int8, dt_int16, dt_int16, | dt_int16, dt_int8, dt_int16, dt_int16, | ||||
| PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash); | |||||
| PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash); | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | |||||
| dtype::Quantized8Asymm, dtype::QuantizedS32, | dtype::Quantized8Asymm, dtype::QuantizedS32, | ||||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, | |||||
| PostprocessMode::ADD_BIAS, | |||||
| "Default::QUINT8x8x32_QINT32"_hash); | "Default::QUINT8x8x32_QINT32"_hash); | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, | ||||
| dtype::Quantized8Asymm, dtype::QuantizedS32, | dtype::Quantized8Asymm, dtype::QuantizedS32, | ||||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | ||||
| PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash); | ||||
| #endif | #endif | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
| dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "Default::QINT8x8x32_QINT32"_hash); | "Default::QINT8x8x32_QINT32"_hash); | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | ||||
| @@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
| cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32, | ||||
| dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash); | ||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16, | |||||
| dt_int16, dt_int8, dt_int16, dt_int16, | dt_int16, dt_int8, dt_int16, dt_int16, | ||||
| PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash); | |||||
| PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash); | |||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32, | |||||
| dt_int32, dt_int8, dt_int32, dt_int32, | dt_int32, dt_int8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash); | |||||
| PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash); | |||||
| cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||||
| cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | ||||
| dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NoPack::QINT8x8x32_QINT32"_hash); | "NoPack::QINT8x8x32_QINT32"_hash); | ||||
| break; | break; | ||||
| @@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
| } | } | ||||
| #undef cb1 | #undef cb1 | ||||
| #undef cb2 | #undef cb2 | ||||
| #undef cb3 | |||||
| megdnn_throw("Invalid Data Type"); | megdnn_throw("Invalid Data Type"); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy( | |||||
| } // namespace fallback | } // namespace fallback | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
| if (param.dst_type.enumv() == DTypeEnum::Int16 || | if (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
| param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
| if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -213,6 +213,22 @@ public: | |||||
| } \ | } \ | ||||
| MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
| return {}; | return {}; | ||||
| #define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||||
| _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||||
| _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
| return std::make_unique< \ | |||||
| Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \ | |||||
| _dst_ctype, _postprocess_mode, \ | |||||
| PackMode::_packmode, FormatMode::_format>>(); \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| return {}; | |||||
| static std::unique_ptr<StrategyBase> make_default_strategy( | static std::unique_ptr<StrategyBase> make_default_strategy( | ||||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
| @@ -279,13 +295,13 @@ public: | |||||
| #endif | #endif | ||||
| case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
| if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
| cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44 || | } else if (format == param::ConvBias::Format::NCHW44 || | ||||
| format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw( | megdnn_throw( | ||||
| @@ -299,12 +315,12 @@ public: | |||||
| case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
| if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
| cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyType::INT8x8x16"_hash); | "DefaultStrategyType::INT8x8x16"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44) { | } else if (format == param::ConvBias::Format::NCHW44) { | ||||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyType::INT8x8x16"_hash); | "DefaultStrategyType::INT8x8x16"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw( | megdnn_throw( | ||||
| @@ -316,9 +332,9 @@ public: | |||||
| break; | break; | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| case StrategyType::QUINT8x8x32: | case StrategyType::QUINT8x8x32: | ||||
| cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
| cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, | |||||
| PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyType::QUINT8x8x32"_hash); | "DefaultStrategyType::QUINT8x8x32"_hash); | ||||
| break; | break; | ||||
| @@ -331,15 +347,15 @@ public: | |||||
| #endif | #endif | ||||
| case StrategyType::QINT8x8x32: | case StrategyType::QINT8x8x32: | ||||
| if (format == param::ConvBias::Format::NCHW) { | if (format == param::ConvBias::Format::NCHW) { | ||||
| cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, | |||||
| PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44 || | } else if (format == param::ConvBias::Format::NCHW44 || | ||||
| format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
| cb3(NCHW44, DEFAULT, dtype::QuantizedS8, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | ||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw( | megdnn_throw( | ||||
| @@ -467,13 +483,13 @@ public: | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
| cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::ADD_BIAS, | |||||
| "NoPackStrategyType::INT8x8x16"_hash); | "NoPackStrategyType::INT8x8x16"_hash); | ||||
| break; | break; | ||||
| case StrategyType::INT8x8x32: | case StrategyType::INT8x8x32: | ||||
| cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::ADD_BIAS, | |||||
| "NoPackStrategyType::INT8x8x32"_hash); | "NoPackStrategyType::INT8x8x32"_hash); | ||||
| break; | break; | ||||
| default: | default: | ||||
| @@ -509,6 +525,7 @@ public: | |||||
| #undef cb1 | #undef cb1 | ||||
| #undef cb2 | #undef cb2 | ||||
| #undef cb3 | |||||
| static std::unique_ptr<StrategyBase> make_strategy( | static std::unique_ptr<StrategyBase> make_strategy( | ||||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
| @@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
| //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | ||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | ||||
| megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| #endif | #endif | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
| megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| #undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
| //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | ||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | ||||
| megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| #endif | #endif | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | ||||
| megdnn::PostprocessMode::QUANTIZED) | megdnn::PostprocessMode::QUANTIZED) | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| #undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
| megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| megdnn::PostprocessMode::ADD_BIAS) | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| #else | #else | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| @@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||||
| #undef FOR_BIAS | #undef FOR_BIAS | ||||
| } | } | ||||
| }; | }; | ||||
| #undef CALL_BINARY | |||||
| #undef CALL_BINARY_BROADCAST | |||||
| #define CALL_BINARY(_op, _simd_type) \ | |||||
| thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||||
| DType, size_t)> \ | |||||
| run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||||
| megdnn::x86::BcastType::VEC_VEC>::run; \ | |||||
| run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, \ | |||||
| N* OC* OH* OW); | |||||
| #define CALL_BINARY_BROADCAST(_op, _simd_type) \ | |||||
| thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \ | |||||
| DType, size_t, size_t, size_t)> \ | |||||
| run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \ | |||||
| megdnn::x86::BcastType::VEC_BCAST101>::run; \ | |||||
| run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \ | |||||
| reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, \ | |||||
| OC, OH* OW); | |||||
| #define FOR_SIMD(CALLER) \ | |||||
| if (is_supported(SIMDType::AVX2)) { \ | |||||
| CALLER(AddOp, SIMDType::AVX2) \ | |||||
| } else if (is_supported(SIMDType::SSE4_2)) { \ | |||||
| CALLER(AddOp, SIMDType::SSE4_2) \ | |||||
| } else { \ | |||||
| CALLER(AddOp, SIMDType::NONE) \ | |||||
| } | |||||
| #define FOR_BIAS(bias_mode) \ | |||||
| switch (bias_mode) { \ | |||||
| case BiasMode::BIAS: \ | |||||
| FOR_SIMD(CALL_BINARY); \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| FOR_SIMD(CALL_BINARY_BROADCAST); \ | |||||
| break; \ | |||||
| default: \ | |||||
| break; \ | |||||
| } | |||||
| template <typename ctype, typename dtype> | |||||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||||
| megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
| megdnn_assert(pack_oc_size == 1, | |||||
| "PostProcess only support nchw in x86"); | |||||
| megdnn_assert( | |||||
| nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY, | |||||
| "Add bias PostProcess only support IDENTITY"); | |||||
| if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||||
| return; | |||||
| } | |||||
| FOR_BIAS(bias_mode); | |||||
| #undef CALL_BINARY | |||||
| #undef CALL_BINARY_BROADCAST | |||||
| #undef FOR_SIMD | |||||
| #undef FOR_BIAS | |||||
| } | |||||
| }; | |||||
| #undef cb_unary | #undef cb_unary | ||||
| #undef cb_binary | #undef cb_binary | ||||
| #undef BIAS_CASE | #undef BIAS_CASE | ||||
| @@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8, | |||||
| using AddOpBase::operator(); \ | using AddOpBase::operator(); \ | ||||
| }; | }; | ||||
| OP(dt_int32, SIMDType::NONE); | |||||
| OP(dt_int16, SIMDType::NONE); | |||||
| OP(dt_float32, SIMDType::NONE); | OP(dt_float32, SIMDType::NONE); | ||||
| #undef OP | #undef OP | ||||
| } // namespace x86 | } // namespace x86 | ||||
| @@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||||
| #define cb(name) \ | #define cb(name) \ | ||||
| checker_conv_bias( \ | checker_conv_bias( \ | ||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
| true, false, true, false, false, true), \ | |||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||||
| true, false, true, true, false, false), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | ||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | ||||
| checker_conv_bias( \ | checker_conv_bias( \ | ||||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | ||||
| false, false, true), \ | |||||
| true, false, false), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | ||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | ||||
| @@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||||
| #define cb(name) \ | #define cb(name) \ | ||||
| checker_conv_bias( \ | checker_conv_bias( \ | ||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
| true, false, true, false, false, true), \ | |||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \ | |||||
| true, false, true, true, false, false), \ | |||||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | ||||
| dtype::Int32(), {}, name); \ | dtype::Int32(), {}, name); \ | ||||
| checker_conv_bias( \ | checker_conv_bias( \ | ||||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | ||||
| false, false, true), \ | |||||
| true, false, false), \ | |||||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | ||||
| dtype::Int32(), {}, name); | dtype::Int32(), {}, name); | ||||
| @@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | ||||
| NormalRNG rng(128.f); | NormalRNG rng(128.f); | ||||
| #define cb(name) \ | #define cb(name) \ | ||||
| checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | ||||
| false, true, true), \ | false, true, true), \ | ||||
| @@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
| handle(), &rng, epsilon, \ | |||||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||||
| &rng, epsilon, \ | |||||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||||
| #define cb(name) \ | |||||
| checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
| true, true, false), \ | |||||
| handle(), &rng, epsilon, \ | |||||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); \ | |||||
| checker_conv_bias( \ | |||||
| get_conv_bias_args({1}, 2, false, false, true, true, false), \ | |||||
| handle(), &rng, epsilon, \ | |||||
| dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||||
| dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| std::vector<conv_bias::TestArg> args_nchw44 = | std::vector<conv_bias::TestArg> args_nchw44 = | ||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||||
| false, false, false, false, true); | |||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true, | |||||
| false, false, true, false, false); | |||||
| std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | ||||
| get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||||
| false, false, true); | |||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \ | |||||
| &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false, | |||||
| true, false, false); | |||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||||
| checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| dtype::Int16{}, dtype::Int16{}, name); | dtype::Int16{}, dtype::Int16{}, name); | ||||
| #define cb_nchw44(name) \ | #define cb_nchw44(name) \ | ||||
| @@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| #define cb(name) \ | |||||
| check_conv_bias_preprocess( \ | |||||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||||
| check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||||
| #define cb(name) \ | |||||
| check_conv_bias_preprocess( \ | |||||
| get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||||
| dtype::Int16{}, dtype::Int16{}, name); \ | |||||
| check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \ | |||||
| name); | name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| @@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||||
| checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
| .set_dtype(4, {}) | |||||
| .set_dtype(4, dtype::QuantizedS32(6.25f)) | |||||
| .set_rng(0, &rng) | .set_rng(0, &rng) | ||||
| .set_rng(1, &rng) | .set_rng(1, &rng) | ||||
| .set_rng(2, &rng) | .set_rng(2, &rng) | ||||
| @@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||||
| checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | checker.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | .set_dtype(1, dtype::QuantizedS8(2.5f)) | ||||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
| .set_dtype(4, {}) | |||||
| .set_dtype(4, dtype::QuantizedS32(6.25f)) | |||||
| .set_rng(0, &rng) | .set_rng(0, &rng) | ||||
| .set_rng(1, &rng) | .set_rng(1, &rng) | ||||
| .set_rng(2, &rng) | .set_rng(2, &rng) | ||||
| @@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| @@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true); | |||||
| get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | |||||
| #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| @@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
| get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | |||||
| get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true); | |||||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| @@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
| std::vector<conv_bias::TestArg> args = | |||||
| get_conv_bias_1x1_args(false, true, false, false); | |||||
| std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | ||||
| {1}, 1, true, true, true, false, false, false, false, true); | |||||
| {1}, 1, true, true, true, false, false, true, false, false); | |||||
| #define cb(name) \ | #define cb(name) \ | ||||
| checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | ||||
| dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | ||||
| @@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
| std::vector<conv_bias::TestArg> args = | |||||
| get_conv_bias_1x1_args(false, true, false, false); | |||||
| #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name); | ||||
| @@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) { | |||||
| //! no bias | //! no bias | ||||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | args.emplace_back(param, TensorShape{1, ic, h, w}, | ||||
| TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | ||||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
| TensorShape{oc, ic, kernel, kernel}, | |||||
| TensorShape{1, oc, 1, 1}); | |||||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
| TensorShape{oc, ic, kernel, kernel}, | |||||
| TensorShape{1, oc, (h + 2 * p - kernel) + 1, | |||||
| (h + 2 * p - kernel) + 1}); | |||||
| }; | }; | ||||
| for (size_t kernel : {2, 3, 4, 5, 6, 7}) | for (size_t kernel : {2, 3, 4, 5, 6, 7}) | ||||
| @@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||||
| #if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
| if (x86::is_supported(x86::SIMDType::VNNI)) { | if (x86::is_supported(x86::SIMDType::VNNI)) { | ||||
| checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||
| @@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| float epsilon = 0.001; | float epsilon = 0.001; | ||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||||
| std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true); | |||||
| #if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
| if (x86::is_supported(x86::SIMDType::VNNI)) { | if (x86::is_supported(x86::SIMDType::VNNI)) { | ||||
| checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, | ||||