im2co and conv1x1 mk4_dot support
GitOrigin-RevId: 096b16a3ab
tags/v0.5.0
| @@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| *outptr++ = *inptr++; | *outptr++ = *inptr++; | ||||
| } | } | ||||
| for (; i < 4; i++) { | for (; i < 4; i++) { | ||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ namespace { | |||||
| megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ | megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ | ||||
| run(static_cast<ctype*>(conv_dst_ptr), \ | run(static_cast<ctype*>(conv_dst_ptr), \ | ||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ | ||||
| N* OC* OH* OW); | |||||
| N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | ||||
| megdnn::arm_common:: \ | megdnn::arm_common:: \ | ||||
| @@ -63,7 +63,7 @@ namespace { | |||||
| static_cast<ctype*>(conv_dst_ptr), \ | static_cast<ctype*>(conv_dst_ptr), \ | ||||
| reinterpret_cast<const ctype*>(bias_ptr), \ | reinterpret_cast<const ctype*>(bias_ptr), \ | ||||
| reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ | ||||
| dst_type, N* OC* OH* OW); | |||||
| dst_type, N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_BIAS(_mode) \ | #define FOR_BIAS(_mode) \ | ||||
| switch (_mode) { \ | switch (_mode) { \ | ||||
| @@ -113,7 +113,6 @@ struct PostProcess { | |||||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | ||||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | ||||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | ||||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||||
| FOR_BIAS(bias_mode) | FOR_BIAS(bias_mode) | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
| _op<opctype, opdtype>, \ | _op<opctype, opdtype>, \ | ||||
| megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ | megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ | ||||
| reinterpret_cast<opdtype*>(dst_ptr), \ | reinterpret_cast<opdtype*>(dst_ptr), \ | ||||
| bias_type, dst_type, N* OC* OH* OW); | |||||
| bias_type, dst_type, \ | |||||
| N* OC* OH* OW* pack_oc_size); | |||||
| #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ | ||||
| megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ | megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ | ||||
| @@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
| reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ | ||||
| dst_type, N, OC, OH* OW, pack_oc_size); | dst_type, N, OC, OH* OW, pack_oc_size); | ||||
| #define HANDLE_IDENTITY(_caller, _op) \ | |||||
| case megdnn::NonlineMode::IDENTITY: \ | |||||
| #define HANDLE_IDENTITY(_caller, _op) \ | |||||
| case megdnn::NonlineMode::IDENTITY: \ | |||||
| _caller(_op) break; | _caller(_op) break; | ||||
| #define FOR_NONLINEAR(_caller) \ | #define FOR_NONLINEAR(_caller) \ | ||||
| @@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| *outptr++ = *inptr++; | *outptr++ = *inptr++; | ||||
| } | } | ||||
| for (; i < 4; i++) { | for (; i < 4; i++) { | ||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = *inptr++; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| *outptr++ = 0; | |||||
| } | } | ||||
| } | } | ||||
| outptr_base += 24; | outptr_base += 24; | ||||
| @@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
| AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | ||||
| if (opr->param().format != param::ConvBias::Format::NCHW && | if (opr->param().format != param::ConvBias::Format::NCHW && | ||||
| opr->param().format != param::ConvBias::Format::NCHW44) | |||||
| opr->param().format != param::ConvBias::Format::NCHW44 && | |||||
| opr->param().format != param::ConvBias::Format::NCHW44_DOT) | |||||
| return false; | return false; | ||||
| size_t FH = param.filter_meta.spatial[0], | size_t FH = param.filter_meta.spatial[0], | ||||
| @@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) | param.nonlineMode != megdnn::NonlineMode::IDENTITY) | ||||
| return false; | return false; | ||||
| if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
| //! nchw44 hybird mode and channel wise is not support | |||||
| if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||||
| opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||||
| if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || | ||||
| param.filter_meta.ocpg == 1) { | param.filter_meta.ocpg == 1) { | ||||
| return false; | return false; | ||||
| @@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
| const ConvBiasImpl::NCBKernSizeParam& param, | const ConvBiasImpl::NCBKernSizeParam& param, | ||||
| MatrixMulImpl::AlgoBase::PackMode pack_mode, | MatrixMulImpl::AlgoBase::PackMode pack_mode, | ||||
| param::ConvBias::Format format) { | param::ConvBias::Format format) { | ||||
| size_t pack_size = get_format_pack_size(format); | |||||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
| return std::make_unique< \ | |||||
| Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
| _postprocess_mode, _packmode>>(pack_size); \ | |||||
| } \ | |||||
| } \ | |||||
| size_t pack_c_size = pack_size(format); | |||||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
| return std::make_unique< \ | |||||
| Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
| _postprocess_mode, _packmode>>( \ | |||||
| pack_c_size); \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END() | MIDOUT_END() | ||||
| #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
| return std::make_unique< \ | |||||
| Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
| DTypeTrait<_i_bias_type>::ctype, \ | |||||
| DTypeTrait<_i_dst_type>::ctype, \ | |||||
| _postprocess_mode, _packmode>>(pack_size); \ | |||||
| } \ | |||||
| } \ | |||||
| #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
| midout_iv(_midout_tag)) { \ | |||||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
| return std::make_unique< \ | |||||
| Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
| DTypeTrait<_i_bias_type>::ctype, \ | |||||
| DTypeTrait<_i_dst_type>::ctype, \ | |||||
| _postprocess_mode, _packmode>>( \ | |||||
| pack_c_size); \ | |||||
| } \ | |||||
| } \ | |||||
| MIDOUT_END() | MIDOUT_END() | ||||
| switch (pack_mode) { | switch (pack_mode) { | ||||
| @@ -12,7 +12,6 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #if MEGDNN_X86 | #if MEGDNN_X86 | ||||
| #include "src/x86/conv_bias/postprocess_helper.h" | #include "src/x86/conv_bias/postprocess_helper.h" | ||||
| @@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | ||||
| (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | ||||
| param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); | ||||
| size_t pack_c_size = 1_z; | |||||
| size_t pack_c_size = pack_size(param.filter_meta.format); | |||||
| auto format = param::MatrixMul::Format::DEFAULT; | auto format = param::MatrixMul::Format::DEFAULT; | ||||
| if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ | |||||
| pack_c_size = 4_z; | |||||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
| format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
| } else if (param.filter_meta.format == | |||||
| param::ConvBias::Format::NCHW44_DOT) { | |||||
| format = param::MatrixMul::Format::MK4_DOT; | |||||
| } | } | ||||
| return {param.filter_type, | return {param.filter_type, | ||||
| param.src_type, | param.src_type, | ||||
| is_dst_8bit ? param.bias_type : param.dst_type, | is_dst_8bit ? param.bias_type : param.dst_type, | ||||
| @@ -15,7 +15,6 @@ | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #include "src/fallback/conv_bias/winograd/strategy.h" | |||||
| #include "src/naive/convolution/helper.h" | #include "src/naive/convolution/helper.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -125,7 +124,7 @@ public: | |||||
| size_t oc_tile_size) { | size_t oc_tile_size) { | ||||
| size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | ||||
| FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
| size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
| size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
| size_t im2col = 0, packb = 0, bias_temp = 0; | size_t im2col = 0, packb = 0, bias_temp = 0; | ||||
| bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | ||||
| megdnn_assert(default_pack, "only support default packa"); | megdnn_assert(default_pack, "only support default packa"); | ||||
| @@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
| size_t ohw_tile_size, | size_t ohw_tile_size, | ||||
| size_t oc_tile_size) const { | size_t oc_tile_size) const { | ||||
| auto format = param::MatrixMul::Format::DEFAULT; | auto format = param::MatrixMul::Format::DEFAULT; | ||||
| size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
| size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | ||||
| format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
| } else if(param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT){ | |||||
| format = param::MatrixMul::Format::MK4_DOT; | |||||
| } | } | ||||
| size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
| size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
| @@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||||
| void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | ||||
| const NCBKernSizeParam& param, size_t& oc_tile_size, | const NCBKernSizeParam& param, size_t& oc_tile_size, | ||||
| size_t& ohw_tile_size, size_t block_m, size_t block_n, | size_t& ohw_tile_size, size_t block_m, size_t block_n, | ||||
| bool need_pack) const { | |||||
| fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const { | |||||
| size_t nr_threads = param.nr_threads; | size_t nr_threads = param.nr_threads; | ||||
| size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
| size_t ohw = param.osz[0] * param.osz[1]; | size_t ohw = param.osz[0] * param.osz[1]; | ||||
| oc_tile_size = DEFAULT_OC_TILE_SIZE; | oc_tile_size = DEFAULT_OC_TILE_SIZE; | ||||
| ohw_tile_size = m_ohw_tile_size; | ohw_tile_size = m_ohw_tile_size; | ||||
| @@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (!need_pack) { //! no pack ,usually in x86 save memroy | |||||
| //! in no_pack mode don't do block operation when using single thread | |||||
| if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||||
| ohw_tile_size = ohw; | ohw_tile_size = ohw; | ||||
| oc_tile_size = OC; | oc_tile_size = OC; | ||||
| } | } | ||||
| @@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
| if (need_pack || only_packA) { | if (need_pack || only_packA) { | ||||
| auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | ||||
| inner_block.n, need_pack); | |||||
| inner_block.n, m_matmul_algo->packmode()); | |||||
| auto im2col_kern_param = get_matmul_kern_param( | auto im2col_kern_param = get_matmul_kern_param( | ||||
| param, ohw_tile_size, only_packA ? oc_tile_size : OC); | param, ohw_tile_size, only_packA ? oc_tile_size : OC); | ||||
| size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
| @@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
| size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
| nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
| need_pack); | |||||
| m_matmul_algo->packmode()); | |||||
| packa_group_size = 0; | packa_group_size = 0; | ||||
| } | } | ||||
| @@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
| if (default_pack || only_packA) { | if (default_pack || only_packA) { | ||||
| auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
| inner_block.m, inner_block.n, default_pack); | |||||
| } else { //! not support pack,not need pack | |||||
| inner_block.m, inner_block.n, | |||||
| m_matmul_algo->packmode()); | |||||
| } else { //! nopack_mode | |||||
| size_t nopack_default_blockm = 8; | size_t nopack_default_blockm = 8; | ||||
| size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
| nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
| no_pack); | |||||
| m_matmul_algo->packmode()); | |||||
| } | } | ||||
| size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | ||||
| size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
| size_t packa_parallel_times = 0; | size_t packa_parallel_times = 0; | ||||
| size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); | |||||
| size_t pack_oc_size = pack_size(param.filter_meta.format); | |||||
| if (only_packA) { | if (only_packA) { | ||||
| packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
| @@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
| ConvBiasImpl* opr, const NCBKernSizeParam& param, | ConvBiasImpl* opr, const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
| MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { | MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { | ||||
| if (opr->param().format != param::ConvBias::Format::NCHW && | |||||
| opr->param().format != param::ConvBias::Format::NCHW44_DOT && | |||||
| opr->param().format != param::ConvBias::Format::NCHW44) { | |||||
| return false; | |||||
| } | |||||
| //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | ||||
| //! identity otherwise return false mean that 8x8x32 and 8x8x16 not support | |||||
| //! PostProcess | |||||
| //! identity otherwise return false mean that 8x8x32 and 8x8x16 not | |||||
| //! support PostProcess | |||||
| if (param.src_type.enumv() == param.filter_type.enumv() && | if (param.src_type.enumv() == param.filter_type.enumv() && | ||||
| ((param.src_type.enumv() == DTypeEnum::Int8 && | ((param.src_type.enumv() == DTypeEnum::Int8 && | ||||
| (param.dst_type.enumv() == DTypeEnum::Int16 || | (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
| @@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
| if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||||
| opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||||
| //! current NCHW44 im2col only support DEFAULT mode matmul | //! current NCHW44 im2col only support DEFAULT mode matmul | ||||
| if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
| if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
| return false; | return false; | ||||
| //! nchw44 hybird mode and channel wise is not support | //! nchw44 hybird mode and channel wise is not support | ||||
| } else if (param.filter_meta.icpg < 4_z || | } else if (param.filter_meta.icpg < 4_z || | ||||
| @@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
| size_t oc_tile_size = 0, ohw_tile_size = 0; | size_t oc_tile_size = 0, ohw_tile_size = 0; | ||||
| Pack_Mode packmode = m_matmul_algo->packmode(); | Pack_Mode packmode = m_matmul_algo->packmode(); | ||||
| bool default_pack = packmode == Pack_Mode::DEFAULT; | bool default_pack = packmode == Pack_Mode::DEFAULT; | ||||
| bool no_pack = packmode == Pack_Mode::NO_PACK; | |||||
| bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | ||||
| if (default_pack || only_packA) { | if (default_pack || only_packA) { | ||||
| auto inner_block = m_matmul_algo->get_inner_block_size(); | auto inner_block = m_matmul_algo->get_inner_block_size(); | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
| inner_block.m, inner_block.n, default_pack); | |||||
| inner_block.m, inner_block.n, | |||||
| m_matmul_algo->packmode()); | |||||
| } else { //! not support pack,not need pack | } else { //! not support pack,not need pack | ||||
| size_t nopack_default_blockm = 8; | size_t nopack_default_blockm = 8; | ||||
| size_t nopack_default_blockn = 16; | size_t nopack_default_blockn = 16; | ||||
| choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | ||||
| nopack_default_blockm, nopack_default_blockn, | nopack_default_blockm, nopack_default_blockn, | ||||
| no_pack); | |||||
| m_matmul_algo->packmode()); | |||||
| } | } | ||||
| fallback::MatrixMulImpl::KernSizeParam matmul_param = | fallback::MatrixMulImpl::KernSizeParam matmul_param = | ||||
| get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | ||||
| bool matmulusable = m_matmul_algo->usable(matmul_param); | bool matmulusable = m_matmul_algo->usable(matmul_param); | ||||
| return matmulusable && | return matmulusable && | ||||
| (opr->param().format == param::ConvBias::Format::NCHW || | |||||
| opr->param().format == param::ConvBias::Format::NCHW44) && | |||||
| (!(param.filter_meta.spatial[0] == | (!(param.filter_meta.spatial[0] == | ||||
| param.filter_meta.spatial[1] && | param.filter_meta.spatial[1] && | ||||
| (param.filter_meta.spatial[0] == 1) && | |||||
| param.filter_meta.spatial[0] == 1 && | |||||
| param.filter_meta.stride[0] == param.filter_meta.stride[1] && | param.filter_meta.stride[0] == param.filter_meta.stride[1] && | ||||
| param.filter_meta.stride[0] == 1)) && | param.filter_meta.stride[0] == 1)) && | ||||
| (param.filter_meta.dilation[0] == | (param.filter_meta.dilation[0] == | ||||
| @@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||||
| const NCBKernSizeParam& param, size_t ohw_tile_size, | const NCBKernSizeParam& param, size_t ohw_tile_size, | ||||
| size_t oc_tile_size) const; | size_t oc_tile_size) const; | ||||
| WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | ||||
| void choice_ohw_oc_block(const NCBKernSizeParam& param, | |||||
| size_t& oc_tile_size, size_t& ohw_tile_size, | |||||
| size_t block_m, size_t block_n, | |||||
| bool pack_default) const; | |||||
| void choice_ohw_oc_block( | |||||
| const NCBKernSizeParam& param, size_t& oc_tile_size, | |||||
| size_t& ohw_tile_size, size_t block_m, size_t block_n, | |||||
| fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const; | |||||
| public: | public: | ||||
| AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) | AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) | ||||
| @@ -230,7 +230,11 @@ public: | |||||
| PostprocessMode::FLOAT, | PostprocessMode::FLOAT, | ||||
| "DefaultStrategyTypeNCHW44::FLOAT"_hash); | "DefaultStrategyTypeNCHW44::FLOAT"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
| megdnn_throw( | |||||
| ssprintf("Current only support layout " | |||||
| "NCHW44/NCHW for im2col " | |||||
| "algo, but got %d\n", | |||||
| uint32_t(format))); | |||||
| } | } | ||||
| break; | break; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| @@ -252,12 +256,17 @@ public: | |||||
| cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | ||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
| "DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||||
| } else if (format == param::ConvBias::Format::NCHW44 || | |||||
| format == param::ConvBias::Format::NCHW44_DOT) { | |||||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | ||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
| "DefaultStrategyType::INT8x8x32"_hash); | "DefaultStrategyType::INT8x8x32"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
| megdnn_throw( | |||||
| ssprintf("Current only support layout " | |||||
| "NCHW44/NCHW/NCHW_DOT for im2col " | |||||
| "algo, but got %d\n", | |||||
| uint32_t(format))); | |||||
| } | } | ||||
| break; | break; | ||||
| @@ -288,13 +297,18 @@ public: | |||||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | ||||
| PostprocessMode::NO_PROCESS, | PostprocessMode::NO_PROCESS, | ||||
| "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||||
| } else if (format == param::ConvBias::Format::NCHW44 || | |||||
| format == param::ConvBias::Format::NCHW44_DOT) { | |||||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | ||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
| "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
| megdnn_throw( | |||||
| ssprintf("Current only support layout " | |||||
| "NCHW44/NCHW/NCHW_DOT for im2col " | |||||
| "algo, but got %d\n", | |||||
| uint32_t(format))); | |||||
| } | } | ||||
| break; | break; | ||||
| @@ -304,17 +318,22 @@ public: | |||||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | ||||
| PostprocessMode::QUANTIZED, | PostprocessMode::QUANTIZED, | ||||
| "DefaultStrategyType::QINT8x8x32x8"_hash); | "DefaultStrategyType::QINT8x8x32x8"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||||
| } else if (format == param::ConvBias::Format::NCHW44 || | |||||
| format == param::ConvBias::Format::NCHW44_DOT) { | |||||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | ||||
| dt_int32, dt_int8, PostprocessMode::QUANTIZED, | dt_int32, dt_int8, PostprocessMode::QUANTIZED, | ||||
| "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | ||||
| } else { | } else { | ||||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||||
| megdnn_throw(ssprintf("Current only support layout " | |||||
| "NCHW44/NCHW/NCHW_DOT for im2col " | |||||
| "algo, but got %d\n", | |||||
| uint32_t(format))); | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| megdnn_throw("error not support strategy type "); | |||||
| megdnn_throw(ssprintf("Unsupported strategy type %u in default mode", | |||||
| uint32_t(strategytype))); | |||||
| } | } | ||||
| static std::unique_ptr<StrategyBase> make_nopack_strategy( | static std::unique_ptr<StrategyBase> make_nopack_strategy( | ||||
| @@ -328,10 +347,6 @@ public: | |||||
| PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | ||||
| break; | break; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| case StrategyType::FLOAT_FP16: | |||||
| cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
| "NoPackStrategyType::FLOAT_FP16"_hash); | |||||
| break; | |||||
| #else | #else | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| case StrategyType::FLOAT16_FLOAT16: | case StrategyType::FLOAT16_FLOAT16: | ||||
| @@ -341,48 +356,24 @@ public: | |||||
| break; | break; | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| case StrategyType::INT8x8x32: | |||||
| cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| "NoPackStrategyType::INT8x8x32"_hash); | |||||
| break; | |||||
| case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
| cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | ||||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | ||||
| "NoPackStrategyType::INT8x8x16"_hash); | "NoPackStrategyType::INT8x8x16"_hash); | ||||
| break; | break; | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| case StrategyType::QUINT8x8x32: | |||||
| cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||||
| PostprocessMode::NO_PROCESS, | |||||
| "NoPackStrategyType::QUINT8x8x32"_hash); | |||||
| break; | |||||
| case StrategyType::QUINT8x8x32x8: | |||||
| cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||||
| PostprocessMode::QUANTIZED, | |||||
| "NoPackStrategyType::QUINT8x8x32x8"_hash); | |||||
| break; | |||||
| #endif | |||||
| case StrategyType::QINT8x8x32: | |||||
| cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
| PostprocessMode::NO_PROCESS, | |||||
| "NoPackStrategyType::QINT8x8x32"_hash); | |||||
| case StrategyType::INT8x8x32: | |||||
| cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| "NoPackStrategyType::INT8x8x32"_hash); | |||||
| break; | break; | ||||
| case StrategyType::QINT8x8x32x8: | |||||
| cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
| PostprocessMode::QUANTIZED, | |||||
| "NoPackStrategyType::QINT8x8x32x8"_hash); | |||||
| default: | |||||
| megdnn_throw( | |||||
| ssprintf("Unsupported strategy type %u in no_pack mode", | |||||
| uint32_t(strategytype))); | |||||
| break; | break; | ||||
| } | } | ||||
| megdnn_throw("error not support strategy type "); | |||||
| megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode", | |||||
| uint32_t(strategytype))); | |||||
| } | } | ||||
| static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | ||||
| @@ -396,63 +387,14 @@ public: | |||||
| PostprocessMode::FLOAT, | PostprocessMode::FLOAT, | ||||
| "OnlyPackaStrategyType::FLOAT"_hash); | "OnlyPackaStrategyType::FLOAT"_hash); | ||||
| break; | break; | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| case StrategyType::FLOAT_FP16: | |||||
| cb1(NCHW, ONLY_PACKA, dt_float16, __fp16, | |||||
| PostprocessMode::FLOAT, | |||||
| "OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||||
| break; | |||||
| #else | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| case StrategyType::FLOAT16_FLOAT16: | |||||
| cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||||
| PostprocessMode::NO_PROCESS, | |||||
| "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | |||||
| break; | |||||
| #endif | |||||
| #endif | |||||
| case StrategyType::INT8x8x32: | |||||
| cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| "OnlyPackaStrategyType::INT8x8x32"_hash); | |||||
| break; | |||||
| case StrategyType::INT8x8x16: | |||||
| cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
| "OnlyPackaStrategyType::INT8x8x16"_hash); | |||||
| break; | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| case StrategyType::QUINT8x8x32: | |||||
| cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, | |||||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
| "OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||||
| break; | |||||
| case StrategyType::QUINT8x8x32x8: | |||||
| cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm, | |||||
| dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, | |||||
| dt_int32, dt_uint8, PostprocessMode::QUANTIZED, | |||||
| "OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||||
| break; | |||||
| #endif | |||||
| case StrategyType::QINT8x8x32: | |||||
| cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||||
| PostprocessMode::NO_PROCESS, | |||||
| "OnlyPackaStrategyType::QINT8x8x32"_hash); | |||||
| break; | |||||
| case StrategyType::QINT8x8x32x8: | |||||
| cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||||
| PostprocessMode::QUANTIZED, | |||||
| "OnlyPackaStrategyType::QINT8x8x32x8"_hash); | |||||
| default: | |||||
| megdnn_throw(ssprintf( | |||||
| "Unsupported strategy type %u in onlypacka mode", | |||||
| uint32_t(strategytype))); | |||||
| break; | break; | ||||
| } | } | ||||
| megdnn_throw("error not support strategy type "); | |||||
| megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode", | |||||
| uint32_t(strategytype))); | |||||
| } | } | ||||
| #undef cb1 | #undef cb1 | ||||
| @@ -11,6 +11,16 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
| #endif | |||||
| using namespace megdnn; | |||||
| #if MEGDNN_X86 | |||||
| using namespace x86; | |||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | ||||
| @@ -72,6 +82,185 @@ public: | |||||
| const StrategyParam& sparam, WorkspaceBundle bundle_thread) = 0; | const StrategyParam& sparam, WorkspaceBundle bundle_thread) = 0; | ||||
| }; | }; | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||||
| FormatMode format> | |||||
| //! this class is a new base class for StrategyDefault StrategyNoPack and so on, | |||||
| //! in order to handle copy pad use the same code | |||||
| class StrategyBridge : public StrategyBase { | |||||
| public: | |||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||||
| StrategyBridge() = default; | |||||
| virtual void copy_padding_kern( | |||||
| WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t pack_oc_size) override { | |||||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
| MEGDNN_MARK_USED_VAR(N); | |||||
| MEGDNN_MARK_USED_VAR(OC); | |||||
| MEGDNN_MARK_USED_VAR(OH); | |||||
| MEGDNN_MARK_USED_VAR(OW); | |||||
| MEGDNN_MARK_USED_VAR(FH); | |||||
| MEGDNN_MARK_USED_VAR(FW); | |||||
| MEGDNN_MARK_USED_VAR(SH); | |||||
| MEGDNN_MARK_USED_VAR(SW); | |||||
| size_t IW2 = IW + 2 * PW; | |||||
| size_t IH2 = IH + 2 * PH; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||||
| size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||||
| PW = PW * pack_oc_size; | |||||
| IW = IW * pack_oc_size; | |||||
| size_t padding_group_size = IH2 * IW2 * IC; | |||||
| size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||||
| size_t workspace_group_offset = group_id * padding_group_size; | |||||
| size_t workspace_batch_offset = | |||||
| param.filter_meta.group * batch_id * padding_group_size; | |||||
| bundle.set(param.workspace_ptr); | |||||
| src_ctype src_zp = static_cast<src_ctype>(0); | |||||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
| } | |||||
| src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||||
| batch_id, group_id, channel_id, 1, pack_oc_size)); | |||||
| src_ctype* src2; | |||||
| src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| workspace_group_offset + workspace_batch_offset + | |||||
| workspace_channel_offset; | |||||
| src_ctype* src2_ptr = src2; | |||||
| const src_ctype* src_ptr = src; | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
| src2_ptr += PH_SIZE; | |||||
| } | |||||
| rep(ih, IH) { | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
| src2_ptr += IW; | |||||
| src_ptr += IW; | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| } | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
| src2_ptr += PH_SIZE; | |||||
| } | |||||
| } | |||||
| }; | |||||
| namespace{ | |||||
| template <typename bias_ctype> | |||||
| inline void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam, | |||||
| size_t matmul_bundle_index) { | |||||
| if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
| return static_cast<void*>(bundle_thread.get(matmul_bundle_index)); | |||||
| } else { | |||||
| bias_ctype* dst = | |||||
| param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw; | |||||
| return static_cast<void*>(dst); | |||||
| } | |||||
| } | |||||
| template <typename bias_ctype> | |||||
| inline void* get_bias_temp_ptr( | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread, size_t bias_bundle_index) { | |||||
| bias_ctype* bias_tmp_ptr = | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? static_cast<bias_ctype*>( | |||||
| bundle_thread.get(bias_bundle_index)) | |||||
| : nullptr; | |||||
| return bias_tmp_ptr; | |||||
| } | |||||
| template <typename dst_ctype> | |||||
| void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam) { | |||||
| if (!sparam.skip_copy_dst) { | |||||
| size_t pack_oc_size = sparam.pack_oc_size; | |||||
| dst_ctype* dst_tmp_ptr = | |||||
| reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
| dst_ctype* dst = | |||||
| param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index * pack_oc_size; | |||||
| size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||||
| for (size_t oc = 0; oc < oc_loop; oc++) { | |||||
| std::memcpy(dst, dst_tmp_ptr, | |||||
| sizeof(dst_ctype) * sparam.output_block_size * | |||||
| pack_oc_size); | |||||
| dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||||
| dst += sparam.ohw * pack_oc_size; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename bias_ctype> | |||||
| void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam, | |||||
| size_t bias_index) { | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| bias_ctype* bias_temp_ptr = static_cast<bias_ctype*>( | |||||
| get_bias_temp_ptr<bias_ctype>(param, bundle_thread, bias_index)); | |||||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
| bias_ctype* copy_dst = bias_temp_ptr; | |||||
| size_t pack_oc_size = sparam.pack_oc_size; | |||||
| const bias_ctype* copy_src = bias_ptr + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index * pack_oc_size; | |||||
| for (size_t oc = sparam.oc_cur_index / pack_oc_size; | |||||
| oc < sparam.oc_end_index / pack_oc_size; oc++) { | |||||
| std::memcpy(copy_dst, copy_src, | |||||
| sizeof(bias_ctype) * sparam.output_block_size * | |||||
| pack_oc_size); | |||||
| copy_dst += sparam.output_block_size * pack_oc_size; | |||||
| copy_src += sparam.ohw * pack_oc_size; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void do_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const StrategyParam& sparam, WorkspaceBundle bundle_thread, | |||||
| size_t matmul_bundle_index, size_t bias_bundle_index) { | |||||
| copy_bias<bias_ctype>(param, bundle_thread, sparam, bias_bundle_index); | |||||
| void* matmul_dst = get_matmul_dst_ptr<bias_ctype>( | |||||
| param, bundle_thread, sparam, matmul_bundle_index); | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| void* bias_temp_ptr = get_bias_temp_ptr<bias_ctype>(param, bundle_thread, | |||||
| bias_bundle_index); | |||||
| void* bias_preprocess_ptr = const_cast<void*>( | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? bias_temp_ptr | |||||
| : static_cast<void*>(const_cast<bias_ctype*>( | |||||
| bias_ptr + sparam.oc_cur_index))); | |||||
| size_t pack_oc_size = sparam.pack_oc_size; | |||||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
| matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||||
| param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||||
| sparam.output_block_oc_size / pack_oc_size, 1_z, | |||||
| sparam.output_block_size, pack_oc_size); | |||||
| copy_dst<dst_ctype>(param, matmul_dst, sparam); | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode, PackMode packmode, | megdnn::PostprocessMode postprocess_mode, PackMode packmode, | ||||
| @@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
| postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||||
| postprocess_mode, PackMode::DEFAULT> | |||||
| : public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
| op_dtype, postprocess_mode, PackMode::DEFAULT, | |||||
| FormatMode::NCHW> { | |||||
| public: | public: | ||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
| @@ -92,13 +284,7 @@ public: | |||||
| Strategy() = default; | Strategy() = default; | ||||
| void copy_padding_kern( | |||||
| WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t pack_size) override; | |||||
| void packA_kern(WorkspaceBundle bundle, | |||||
| virtual void packA_kern(WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
| @@ -120,16 +306,13 @@ public: | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | ||||
| void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
| WorkspaceBundle bundle_thread) override; | |||||
| void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam); | |||||
| void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
| WorkspaceBundle bundle_thread) override { | |||||
| do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode>(param, sparam, bundle_thread, | |||||
| THREAD_BUNDLE_IM2COL_INDEX, | |||||
| THREAD_BUNDLE_BIAS_INDEX); | |||||
| } | |||||
| void* get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread); | |||||
| void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
| const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
| @@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
| postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||||
| postprocess_mode, PackMode::NO_PACK> | |||||
| : public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
| op_dtype, postprocess_mode, PackMode::NO_PACK, | |||||
| FormatMode::NCHW> { | |||||
| public: | public: | ||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
| @@ -173,12 +359,6 @@ public: | |||||
| Strategy() = default; | Strategy() = default; | ||||
| void copy_padding_kern( | |||||
| WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t pack_size) override; | |||||
| void packA_kern(WorkspaceBundle bundle, | void packA_kern(WorkspaceBundle bundle, | ||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
| @@ -198,17 +378,6 @@ public: | |||||
| const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
| const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
| inline void* get_bias_temp_ptr( | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread) { | |||||
| bias_ctype* bias_tmp_ptr = | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? static_cast<bias_ctype*>( | |||||
| bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
| : nullptr; | |||||
| return bias_tmp_ptr; | |||||
| } | |||||
| void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
| const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| @@ -216,19 +385,22 @@ public: | |||||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | ||||
| void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
| WorkspaceBundle bundle_thread) override; | |||||
| void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam); | |||||
| void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
| WorkspaceBundle bundle_thread) override { | |||||
| do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode>(param, sparam, bundle_thread, | |||||
| THREAD_BUNDLE_MATMULDST_INDEX, | |||||
| THREAD_BUNDLE_BIAS_INDEX); | |||||
| } | |||||
| }; | }; | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
| postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||||
| postprocess_mode, PackMode::ONLY_PACKA> | |||||
| : public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype, | |||||
| op_dtype, postprocess_mode, | |||||
| PackMode::ONLY_PACKA,FormatMode::NCHW> { | |||||
| public: | public: | ||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
| @@ -239,12 +411,6 @@ public: | |||||
| Strategy() = default; | Strategy() = default; | ||||
| void copy_padding_kern( | |||||
| WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t pack_size) override; | |||||
| void packA_kern(WorkspaceBundle bundle, | void packA_kern(WorkspaceBundle bundle, | ||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
| @@ -269,24 +435,15 @@ public: | |||||
| void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
| const StrategyParam& sparam); | const StrategyParam& sparam); | ||||
| inline void* get_bias_temp_ptr( | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread) { | |||||
| bias_ctype* bias_tmp_ptr = | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? static_cast<bias_ctype*>( | |||||
| bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
| : nullptr; | |||||
| return bias_tmp_ptr; | |||||
| } | |||||
| void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
| const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
| WorkspaceBundle bundle_thread) override; | |||||
| void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam); | |||||
| void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam); | |||||
| WorkspaceBundle bundle_thread) override { | |||||
| do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode>(param, sparam, bundle_thread, | |||||
| THREAD_BUNDLE_MATMULDST_INDEX, | |||||
| THREAD_BUNDLE_BIAS_INDEX); | |||||
| } | |||||
| }; | }; | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -10,85 +10,9 @@ | |||||
| */ | */ | ||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
| #include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
| #endif | |||||
| using namespace megdnn; | |||||
| #if MEGDNN_X86 | |||||
| using namespace x86; | |||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT>:: | |||||
| copy_padding_kern(WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t pack_oc_size) { | |||||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
| MEGDNN_MARK_USED_VAR(N); | |||||
| MEGDNN_MARK_USED_VAR(OC); | |||||
| MEGDNN_MARK_USED_VAR(OH); | |||||
| MEGDNN_MARK_USED_VAR(OW); | |||||
| MEGDNN_MARK_USED_VAR(FH); | |||||
| MEGDNN_MARK_USED_VAR(FW); | |||||
| MEGDNN_MARK_USED_VAR(SH); | |||||
| MEGDNN_MARK_USED_VAR(SW); | |||||
| size_t IW2 = IW + 2 * PW; | |||||
| size_t IH2 = IH + 2 * PH; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||||
| size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||||
| PW = PW * pack_oc_size; | |||||
| IW = IW * pack_oc_size; | |||||
| size_t padding_group_size = IH2 * IW2 * IC; | |||||
| size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||||
| size_t workspace_group_offset = group_id * padding_group_size; | |||||
| size_t workspace_batch_offset = | |||||
| param.filter_meta.group * batch_id * padding_group_size; | |||||
| bundle.set(param.workspace_ptr); | |||||
| src_ctype src_zp = static_cast<src_ctype>(0); | |||||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
| } | |||||
| src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||||
| batch_id, group_id, channel_id, 1, pack_oc_size)); | |||||
| src_ctype* src2; | |||||
| src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| workspace_group_offset + workspace_batch_offset + | |||||
| workspace_channel_offset; | |||||
| src_ctype* src2_ptr = src2; | |||||
| const src_ctype* src_ptr = src; | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
| src2_ptr += PH_SIZE; | |||||
| } | |||||
| rep(ih, IH) { | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
| src2_ptr += IW; | |||||
| src_ptr += IW; | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| } | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||||
| src2_ptr += PH_SIZE; | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| @@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| matmul_kern_naked(matmul_param, a_panel, b_panel); | matmul_kern_naked(matmul_param, a_panel, b_panel); | ||||
| } | } | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT>:: | |||||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const StrategyParam& sparam, | |||||
| WorkspaceBundle bundle_thread) { | |||||
| copy_bias(param, bundle_thread, sparam); | |||||
| void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| void* bias_temp_ptr = get_bias_temp_ptr(param, bundle_thread); | |||||
| void* bias_preprocess_ptr = const_cast<void*>( | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? bias_temp_ptr | |||||
| : static_cast<void*>(const_cast<bias_ctype*>( | |||||
| bias_ptr + sparam.oc_cur_index))); | |||||
| size_t pack_oc_size = sparam.pack_oc_size; | |||||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
| matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||||
| param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||||
| sparam.output_block_oc_size / pack_oc_size, 1_z, | |||||
| sparam.output_block_size, pack_oc_size); | |||||
| copy_dst(param, matmul_dst, sparam); | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT>:: | |||||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam) { | |||||
| if (!sparam.skip_copy_dst) { | |||||
| size_t pack_oc_size = sparam.pack_oc_size; | |||||
| dst_ctype* dst_tmp_ptr = | |||||
| reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
| dst_ctype* dst = | |||||
| param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index * pack_oc_size; | |||||
| size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||||
| for (size_t oc = 0; oc < oc_loop; oc++) { | |||||
| std::memcpy(dst, dst_tmp_ptr, | |||||
| sizeof(dst_ctype) * sparam.output_block_size * | |||||
| pack_oc_size); | |||||
| dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||||
| dst += sparam.ohw * pack_oc_size; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT>:: | |||||
| get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread) { | |||||
| bias_ctype* bias_tmp_ptr = | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? static_cast<bias_ctype*>( | |||||
| bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX)) | |||||
| : nullptr; | |||||
| return bias_tmp_ptr; | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT>:: | |||||
| copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| bias_ctype* bias_temp_ptr = | |||||
| static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
| bias_ctype* copy_dst = bias_temp_ptr; | |||||
| const bias_ctype* copy_src = bias_ptr + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index; | |||||
| for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
| std::memcpy(copy_dst, copy_src, | |||||
| sizeof(bias_ctype) * sparam.output_block_size); | |||||
| copy_dst += sparam.output_block_size; | |||||
| copy_src += sparam.ohw; | |||||
| } | |||||
| } | |||||
| } | |||||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
| _op_dtype, _postprocess_mode) \ | _op_dtype, _postprocess_mode) \ | ||||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
| @@ -11,81 +11,9 @@ | |||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
| #include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
| #endif | |||||
| using namespace megdnn; | |||||
| #if MEGDNN_X86 | |||||
| using namespace x86; | |||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::NO_PACK>:: | |||||
| copy_padding_kern(WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t) { | |||||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
| MEGDNN_MARK_USED_VAR(N); | |||||
| MEGDNN_MARK_USED_VAR(OC); | |||||
| MEGDNN_MARK_USED_VAR(OH); | |||||
| MEGDNN_MARK_USED_VAR(OW); | |||||
| MEGDNN_MARK_USED_VAR(FH); | |||||
| MEGDNN_MARK_USED_VAR(FW); | |||||
| MEGDNN_MARK_USED_VAR(SH); | |||||
| MEGDNN_MARK_USED_VAR(SW); | |||||
| size_t IW2 = IW + 2 * PW; | |||||
| size_t IH2 = IH + 2 * PH; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||||
| size_t padding_group_size = IH2 * IW2 * IC; | |||||
| size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||||
| size_t workspace_group_offset = group_id * padding_group_size; | |||||
| size_t workspace_batch_offset = | |||||
| param.filter_meta.group * batch_id * padding_group_size; | |||||
| bundle.set(param.workspace_ptr); | |||||
| src_ctype src_zp = static_cast<src_ctype>(0); | |||||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
| } | |||||
| src_ctype* src = const_cast<src_ctype*>( | |||||
| param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
| src_ctype* src2; | |||||
| src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| workspace_group_offset + workspace_batch_offset + | |||||
| workspace_channel_offset; | |||||
| src_ctype* src2_ptr = src2; | |||||
| const src_ctype* src_ptr = src; | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
| src2_ptr += PH * IW2; | |||||
| } | |||||
| rep(ih, IH) { | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
| src2_ptr += IW; | |||||
| src_ptr += IW; | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| } | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
| src2_ptr += PH * IW2; | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| @@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| } | } | ||||
| } | } | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::NO_PACK>:: | |||||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const StrategyParam& sparam, | |||||
| WorkspaceBundle bundle_thread) { | |||||
| copy_bias(param, bundle_thread, sparam); | |||||
| void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| bias_ctype* bias_temp_ptr = | |||||
| static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
| matmul_dst, | |||||
| const_cast<void*>( | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? bias_temp_ptr | |||||
| : static_cast<void*>(const_cast<bias_ctype*>( | |||||
| bias_ptr + sparam.oc_cur_index))), | |||||
| matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||||
| param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||||
| sparam.output_block_size); | |||||
| copy_dst(param, matmul_dst, sparam); | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::NO_PACK>:: | |||||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam) { | |||||
| if (!sparam.skip_copy_dst) { | |||||
| dst_ctype* dst_tmp_ptr = | |||||
| reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
| dst_ctype* dst = | |||||
| param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||||
| for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||||
| std::memcpy(dst, dst_tmp_ptr, | |||||
| sizeof(dst_ctype) * sparam.output_block_size); | |||||
| dst_tmp_ptr += sparam.output_block_size; | |||||
| dst += sparam.ohw; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::NO_PACK>:: | |||||
| copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| bias_ctype* bias_temp_ptr = | |||||
| static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
| bias_ctype* copy_dst = bias_temp_ptr; | |||||
| const bias_ctype* copy_src = bias_ptr + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index; | |||||
| for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
| std::memcpy(copy_dst, copy_src, | |||||
| sizeof(bias_ctype) * sparam.output_block_size); | |||||
| copy_dst += sparam.output_block_size; | |||||
| copy_src += sparam.ohw; | |||||
| } | |||||
| } | |||||
| } | |||||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
| _op_dtype, _postprocess_mode) \ | _op_dtype, _postprocess_mode) \ | ||||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | ||||
| @@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
| megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
| megdnn::PostprocessMode::FLOAT) | |||||
| #else | #else | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | ||||
| megdnn::PostprocessMode::NO_PROCESS) | megdnn::PostprocessMode::NO_PROCESS) | ||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
| megdnn::PostprocessMode::QUANTIZED) | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #endif | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||||
| megdnn::PostprocessMode::QUANTIZED) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #undef INSTANTIAL_CLASS | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -11,81 +11,9 @@ | |||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
| #include "src/fallback/convolution/img2col_helper.h" | #include "src/fallback/convolution/img2col_helper.h" | ||||
| #if MEGDNN_X86 | |||||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
| #endif | |||||
| using namespace megdnn; | |||||
| #if MEGDNN_X86 | |||||
| using namespace x86; | |||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
| copy_padding_kern(WorkspaceBundle bundle, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| size_t) { | |||||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||||
| MEGDNN_MARK_USED_VAR(N); | |||||
| MEGDNN_MARK_USED_VAR(OC); | |||||
| MEGDNN_MARK_USED_VAR(OH); | |||||
| MEGDNN_MARK_USED_VAR(OW); | |||||
| MEGDNN_MARK_USED_VAR(FH); | |||||
| MEGDNN_MARK_USED_VAR(FW); | |||||
| MEGDNN_MARK_USED_VAR(SH); | |||||
| MEGDNN_MARK_USED_VAR(SW); | |||||
| size_t IW2 = IW + 2 * PW; | |||||
| size_t IH2 = IH + 2 * PH; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||||
| size_t padding_group_size = IH2 * IW2 * IC; | |||||
| size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||||
| size_t workspace_group_offset = group_id * padding_group_size; | |||||
| size_t workspace_batch_offset = | |||||
| param.filter_meta.group * batch_id * padding_group_size; | |||||
| bundle.set(param.workspace_ptr); | |||||
| src_ctype src_zp = static_cast<src_ctype>(0); | |||||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||||
| } | |||||
| src_ctype* src = const_cast<src_ctype*>( | |||||
| param.src<src_ctype>(batch_id, group_id, channel_id)); | |||||
| src_ctype* src2; | |||||
| src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| workspace_group_offset + workspace_batch_offset + | |||||
| workspace_channel_offset; | |||||
| src_ctype* src2_ptr = src2; | |||||
| const src_ctype* src_ptr = src; | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
| src2_ptr += PH * IW2; | |||||
| } | |||||
| rep(ih, IH) { | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW); | |||||
| src2_ptr += IW; | |||||
| src_ptr += IW; | |||||
| if (PW != 0) | |||||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||||
| } | |||||
| if (PH != 0) { | |||||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||||
| src2_ptr += PH * IW2; | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| @@ -120,25 +48,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); | matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); | ||||
| } | } | ||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
| get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam) { | |||||
| if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
| return static_cast<void*>( | |||||
| bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||||
| } else { | |||||
| bias_ctype* dst = | |||||
| param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw; | |||||
| return static_cast<void*>(dst); | |||||
| } | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| @@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
| typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const StrategyParam& sparam, | |||||
| WorkspaceBundle bundle_thread) { | |||||
| void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); | |||||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||||
| param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); | |||||
| bias_ctype* bias_temp_ptr = | |||||
| static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); | |||||
| if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
| bias_ctype* copy_dst = bias_temp_ptr; | |||||
| const bias_ctype* copy_src = bias_ptr + | |||||
| sparam.oc_cur_index * sparam.ohw + | |||||
| sparam.ohw_cur_index; | |||||
| for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) { | |||||
| std::memcpy(copy_dst, copy_src, | |||||
| sizeof(bias_ctype) * sparam.output_block_size); | |||||
| copy_dst += sparam.output_block_size; | |||||
| copy_src += sparam.ohw; | |||||
| } | |||||
| } | |||||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||||
| matmul_dst, | |||||
| const_cast<void*>( | |||||
| param.bias_mode == megdnn::BiasMode::BIAS | |||||
| ? bias_temp_ptr | |||||
| : static_cast<void*>(const_cast<bias_ctype*>( | |||||
| bias_ptr + sparam.oc_cur_index))), | |||||
| matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type, | |||||
| param.dst_type, 1_z, sparam.output_block_oc_size, 1_z, | |||||
| sparam.output_block_size); | |||||
| copy_dst(param, matmul_dst, sparam); | |||||
| } | |||||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
| typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const void* matmul_dst, const StrategyParam& sparam) { | |||||
| if (!sparam.skip_copy_dst) { | |||||
| dst_ctype* dst_tmp_ptr = | |||||
| reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||||
| dst_ctype* dst = | |||||
| param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||||
| for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||||
| std::memcpy(dst, dst_tmp_ptr, | |||||
| sizeof(dst_ctype) * sparam.output_block_size); | |||||
| dst_tmp_ptr += sparam.output_block_size; | |||||
| dst += sparam.ohw; | |||||
| } | |||||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::ONLY_PACKA>:: | |||||
| get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam) { | |||||
| if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) { | |||||
| return static_cast<bias_ctype*>( | |||||
| bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX)); | |||||
| } else { | |||||
| bias_ctype* dst = | |||||
| param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) + | |||||
| sparam.oc_cur_index * sparam.ohw; | |||||
| return static_cast<void*>(dst); | |||||
| } | } | ||||
| } | } | ||||
| @@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | ||||
| megdnn::PostprocessMode::FLOAT) | megdnn::PostprocessMode::FLOAT) | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
| megdnn::PostprocessMode::FLOAT) | |||||
| #else | |||||
| #if !MEGDNN_DISABLE_FLOAT16 | |||||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #endif | |||||
| #endif | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||||
| megdnn::PostprocessMode::QUANTIZED) | |||||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #endif | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||||
| megdnn::PostprocessMode::QUANTIZED) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||||
| megdnn::PostprocessMode::NO_PROCESS) | |||||
| #undef INSTANTIAL_CLASS | #undef INSTANTIAL_CLASS | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -26,7 +26,7 @@ | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace fallback; | using namespace fallback; | ||||
| size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { | |||||
| size_t megdnn::fallback::pack_size(param::ConvBias::Format format) { | |||||
| switch (format) { | switch (format) { | ||||
| case param::ConvBias::Format::NCHW44: | case param::ConvBias::Format::NCHW44: | ||||
| case param::ConvBias::Format::NCHW44_DOT: | case param::ConvBias::Format::NCHW44_DOT: | ||||
| @@ -23,8 +23,10 @@ namespace fallback { | |||||
| /*! | /*! | ||||
| * \brief get the pack_size according to the format | * \brief get the pack_size according to the format | ||||
| * Note TODO: when remove format from param, | |||||
| * may using like this "opr::param::format specify" | |||||
| * */ | * */ | ||||
| size_t get_format_pack_size(param::ConvBias::Format format); | |||||
| size_t pack_size(param::ConvBias::Format format); | |||||
| /*! | /*! | ||||
| * \brief fallback conv bias forward impl | * \brief fallback conv bias forward impl | ||||
| @@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> { | |||||
| } | } | ||||
| size_t get_b_workspace_size() const { | size_t get_b_workspace_size() const { | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| size_t new_blockn = m_strategy.block_n; | |||||
| if (m_strategy.KERNEL_W == 6 && m_strategy.UNROLL_K == 4 && | |||||
| m_strategy.KERNEL_H == 8) { | |||||
| new_blockn = round_up<size_t>((m_strategy.block_n-1) % 6, 4) + | |||||
| m_strategy.block_n / 6 * 6; | |||||
| } | |||||
| size_t N = round_up(new_blockn, m_strategy.KERNEL_W); | |||||
| size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | |||||
| return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | |||||
| #else | |||||
| size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); | size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); | ||||
| size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); | ||||
| return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; | ||||
| #endif | |||||
| } | } | ||||
| //! temporary storage for output, post process such as add bias or relu will | //! temporary storage for output, post process such as add bias or relu will | ||||
| @@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | |||||
| benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | ||||
| "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); | ||||
| #else | #else | ||||
| benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | |||||
| benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384", | |||||
| "IM2COLMATMUL:ARMV7_F32:192", true); | "IM2COLMATMUL:ARMV7_F32:192", true); | ||||
| benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", | ||||
| "IM2COLMATMUL:ARMV7_F32:192", false); | "IM2COLMATMUL:ARMV7_F32:192", false); | ||||
| @@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args( | |||||
| std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | ||||
| std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, | std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, | ||||
| bool no_bias = false, bool no_nonlinemode = false, | bool no_bias = false, bool no_nonlinemode = false, | ||||
| bool is_input_nchw = false, bool support_full_bias = false, | |||||
| bool support_sigmoid = false) { | |||||
| bool is_input_nchw = false, bool is_nchw44_dot = false, | |||||
| bool support_full_bias = false, bool support_sigmoid = false, | |||||
| bool only_no_bias = false) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| using NLMode = param::ConvBias::NonlineMode; | using NLMode = param::ConvBias::NonlineMode; | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | ||||
| @@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| size_t kernel_h = kernel; | size_t kernel_h = kernel; | ||||
| size_t kernel_w = kernel; | size_t kernel_w = kernel; | ||||
| param::ConvBias param; | param::ConvBias param; | ||||
| param.format = param::ConvBias::Format::NCHW44; | |||||
| if (!is_nchw44_dot) { | |||||
| param.format = param::ConvBias::Format::NCHW44; | |||||
| } else { | |||||
| param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
| } | |||||
| param.stride_h = stride; | param.stride_h = stride; | ||||
| param.stride_w = stride; | param.stride_w = stride; | ||||
| param.pad_h = pad; | param.pad_h = pad; | ||||
| @@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| if (support_sigmoid) { | if (support_sigmoid) { | ||||
| nonlinemode.emplace_back(NLMode::SIGMOID); | nonlinemode.emplace_back(NLMode::SIGMOID); | ||||
| } | } | ||||
| std::vector<megdnn::BiasMode> bias_mode = { | |||||
| megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; | |||||
| if (no_bias) { | |||||
| std::vector<megdnn::BiasMode> bias_mode; | |||||
| if (!only_no_bias) { | |||||
| bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS); | |||||
| if (no_bias) { | |||||
| bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | |||||
| } | |||||
| } else { | |||||
| bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); | ||||
| } | } | ||||
| if (support_full_bias) { | if (support_full_bias) { | ||||
| bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||||
| bias_mode.emplace_back(megdnn::BiasMode::BIAS); | |||||
| } | } | ||||
| for (auto bias : bias_mode) | for (auto bias : bias_mode) | ||||
| for (auto nlmode : nonlinemode) | for (auto nlmode : nonlinemode) | ||||
| for (size_t n : {1, 2}) | |||||
| for (size_t n : {1,2}) | |||||
| for (size_t kernel : kernel_vec) | for (size_t kernel : kernel_vec) | ||||
| for (size_t oc : {4, 12}) | for (size_t oc : {4, 12}) | ||||
| for (size_t ic : {1, 3, 4, 12}) | for (size_t ic : {1, 3, 4, 12}) | ||||
| @@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { | ||||
| check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, | check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, | ||||
| false, true, true), | |||||
| false, false, true, true), | |||||
| handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { | ||||
| check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, | check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, | ||||
| false, true, true), | |||||
| false, false, true, true), | |||||
| handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { | ||||
| check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, | ||||
| false, false, true, true), | |||||
| false, false, false, true, true), | |||||
| handle(), "F32_CONV_NCHW44_DIRECT"); | handle(), "F32_CONV_NCHW44_DIRECT"); | ||||
| } | } | ||||
| @@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | |||||
| #endif | #endif | ||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| #define cb(name) \ | |||||
| checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \ | |||||
| false, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
| dtype::QuantizedS8(60.25f), name); \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
| dtype::QuantizedS8(60.25f), name); | |||||
| float epsilon = 0.001; | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
| #elif MEGDNN_ARMV7 | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
| true, false, true, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||||
| false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); | |||||
| float epsilon = 0.001; | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
| #elif MEGDNN_ARMV7 | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) { | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \ | |||||
| true, false, true, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
| dtype::Int32(), {}, name); \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \ | |||||
| false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
| dtype::Int32(), {}, name); | |||||
| float epsilon = 0.001; | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
| #elif MEGDNN_ARMV7 | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) { | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| #define cb(name) \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
| dtype::QuantizedS8(60.25f), name); \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||||
| false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \ | |||||
| checker_conv_bias( \ | |||||
| get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \ | |||||
| false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \ | |||||
| dtype::Int32(), {}, name); | |||||
| float epsilon = 0.001; | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"); | |||||
| #elif MEGDNN_ARMV7 | |||||
| cb("CONV1x1:AARCH32_INT8_MK4_8X6X4_DOTPROD"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| #endif | |||||
| // clang-format on | // clang-format on | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { | ||||
| @@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | |||||
| get_nchw44_conv_bias_args({2, 4, 7}, 1); | |||||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
| {2, 4, 7}, 1, false, false, false, false, false, true,true); | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| @@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<conv_bias::TestArg> args = | |||||
| get_nchw44_conv_bias_args({3, 5, 6}, 2); | |||||
| std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
| {3, 5, 6}, 2, false, false, false, false, false, true, true); | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
| #elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
| @@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| #if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS | |||||
| TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32) { | |||||
| using namespace conv_bias; | |||||
| std::vector<TestArg> args; | |||||
| auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||||
| size_t p, NonlineMode nonline_mode) { | |||||
| if (w + 2 * p < kernel || h + 2 * p < kernel) | |||||
| return; | |||||
| param::ConvBias param; | |||||
| param.stride_h = 1; | |||||
| param.stride_w = 1; | |||||
| param.pad_h = p; | |||||
| param.pad_w = p; | |||||
| param.nonlineMode = nonline_mode; | |||||
| //! no bias | |||||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
| TensorShape{oc, ic, kernel, kernel}, TensorShape{}); | |||||
| args.emplace_back(param, TensorShape{1, ic, h, w}, | |||||
| TensorShape{oc, ic, kernel, kernel}, | |||||
| TensorShape{1, oc, 1, 1}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, ic, h, w}, | |||||
| TensorShape{oc, ic, kernel, kernel}, | |||||
| TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1, | |||||
| (w + 2 * p - kernel) / param.stride_w + 1}); | |||||
| }; | |||||
| for (size_t kernel : {2, 3, 4, 5, 6, 7}) | |||||
| for (size_t ic : {1, 4, 8, 16}) | |||||
| for (size_t oc : {1, 4, 8, 16, 300}) | |||||
| for (size_t p : {0, 2}) | |||||
| for (size_t size : {8, 24}) | |||||
| for (NonlineMode nonline_mode : | |||||
| {NonlineMode::IDENTITY, NonlineMode::RELU}) { | |||||
| run(oc, ic, size, size, kernel, p, nonline_mode); | |||||
| } | |||||
| run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY); | |||||
| Checker<ConvBias> checker(handle()); | |||||
| #define cb(algo_name) \ | |||||
| checker.set_before_exec_callback( \ | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \ | |||||
| for (auto&& arg : args) { \ | |||||
| checker.set_param(arg.param).execs( \ | |||||
| {arg.src, arg.filter, arg.bias, {}, {}}); \ | |||||
| } | |||||
| cb("IM2COLMATMUL:X86_F32_BLAS"); | |||||
| #undef cb | |||||
| } | |||||
| #endif | |||||
| #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
| TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { | TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||