GitOrigin-RevId: d326035202
tags/v0.4.0
| @@ -17,18 +17,14 @@ | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| #include "src/fallback/conv_bias/winograd/strategy.h" | |||
| #include "src/naive/convolution/helper.h" | |||
| #if MEGDNN_X86 | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #endif | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_fallback_im2col) | |||
| using namespace megdnn; | |||
| using namespace fallback; | |||
| using namespace im2col; | |||
| #if MEGDNN_X86 | |||
| using namespace x86; | |||
| #endif | |||
| /*======================== AlgoIm2col=======================*/ | |||
| /*! | |||
| @@ -47,8 +43,8 @@ using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
| static void copy_padding_kern(WorkspaceBundle bundle, | |||
| const ConvBiasImpl::NCBKernParam& param, | |||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| StrategyBase* im2colstrategy) { | |||
| im2colstrategy->copy_padding_kern(bundle, param, ncb_index); | |||
| StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||
| im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size); | |||
| } | |||
| //! packA_kern | |||
| @@ -57,9 +53,9 @@ static void packA_kern(WorkspaceBundle bundle, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| StrategyBase* im2colstrategy) { | |||
| StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||
| im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | |||
| ncb_index); | |||
| ncb_index, pack_oc_size); | |||
| } | |||
| /*! | |||
| @@ -129,14 +125,17 @@ public: | |||
| size_t oc_tile_size) { | |||
| size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], | |||
| FW = param.filter_meta.spatial[1]; | |||
| size_t pack_oc_size = 1; | |||
| size_t im2col = 0, packb = 0, bias_temp = 0; | |||
| bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
| megdnn_assert(default_pack, "only support default packa"); | |||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||
| pack_oc_size = 4; | |||
| } | |||
| size_t im2col_dst_size = | |||
| IC * FH * FW * ohw_tile_size * sizeof(param.src_type); | |||
| size_t matmul_dst_size = | |||
| oc_tile_size * ohw_tile_size * sizeof(param.bias_type); | |||
| size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * | |||
| sizeof(param.bias_type); | |||
| //! matmul_dst and im2col_dst use the same memory | |||
| WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param); | |||
| packb = wb.get_size(1); | |||
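A minimal sketch (not part of the diff) of how the per-thread buffer sizes above scale once pack_oc_size enters: the im2col buffer is unchanged, while the matmul destination grows by the pack factor. Struct and helper names are illustrative; src_elem/bias_elem stand for the element sizes of src_type and bias_type.

```cpp
#include <cstddef>

struct TileCfg {
    size_t IC, FH, FW;          // input channels and filter spatial size
    size_t oc_tile, ohw_tile;   // blocking chosen by choice_ohw_oc_block
    size_t src_elem, bias_elem; // element sizes of src_type / bias_type
    size_t pack_oc_size;        // 1 for NCHW, 4 for NCHW44
};

// im2col buffer: (IC*FH*FW) x ohw_tile elements, independent of packing.
inline size_t im2col_dst_bytes(const TileCfg& c) {
    return c.IC * c.FH * c.FW * c.ohw_tile * c.src_elem;
}

// matmul destination: each packed row carries pack_oc_size output channels,
// so the tile grows by that factor.
inline size_t matmul_dst_bytes(const TileCfg& c) {
    return c.pack_oc_size * c.oc_tile * c.ohw_tile * c.bias_elem;
}
```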
| @@ -318,17 +317,18 @@ public: | |||
| } | |||
| }; | |||
| #undef FILL_IM2COL_STRATEGY_PARAM | |||
| fallback::MatrixMulImpl::KernSizeParam | |||
| ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
| size_t ohw_tile_size, | |||
| size_t oc_tile_size) const { | |||
| bool is_nchw44 = | |||
| param.filter_meta.format == param::ConvBias::Format::NCHW44; | |||
| size_t M = oc_tile_size; | |||
| size_t N = ohw_tile_size; | |||
| size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] * | |||
| param.filter_meta.spatial[1]; | |||
| size_t LDA = K, LDB = N, LDC = N; | |||
| size_t pack_oc_size = is_nchw44 ? 4 : 1; | |||
| size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N; | |||
| bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||
| (param.src_type.enumv() == DTypeEnum::Quantized8Asymm && | |||
| @@ -345,7 +345,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, | |||
| false, | |||
| false, | |||
| param::MatrixMul::ComputeMode::DEFAULT, | |||
| param::MatrixMul::Format::DEFAULT}; | |||
| is_nchw44 ? param::MatrixMul::Format::MK4 | |||
| : param::MatrixMul::Format::DEFAULT}; | |||
| } | |||
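The leading-dimension change above is the heart of the NCHW44 path: the A and B strides grow by the pack factor and the matmul format switches to MK4, while M, N, K keep their NCHW meaning. A self-contained sketch of that arithmetic with simplified names (not the MegDNN KernSizeParam API):

```cpp
#include <cstddef>

struct MatmulShape {
    size_t M, N, K, LDA, LDB, LDC;
};

MatmulShape make_shape(size_t oc_tile, size_t ohw_tile, size_t icpg,
                       size_t fh, size_t fw, bool is_nchw44) {
    const size_t pack = is_nchw44 ? 4 : 1;
    MatmulShape s;
    s.M = oc_tile;          // output channels per tile
    s.N = ohw_tile;         // output pixels per tile
    s.K = icpg * fh * fw;   // reduction length
    // In MK4 every A/B row interleaves 4 channels, so their strides scale.
    s.LDA = pack * s.K;
    s.LDB = pack * s.N;
    s.LDC = s.N;            // mirrors the size-param hunk above
    return s;
}
```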
| void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( | |||
| @@ -405,6 +406,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
| size_t GROUP = param.filter_meta.group; | |||
| bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
| bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | |||
| if (need_pack || only_packA) { | |||
| auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
| choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack); | |||
| @@ -421,16 +423,19 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
| need_pack); | |||
| packa_group_size = 0; | |||
| } | |||
| if (no_need_pading) { | |||
| padding = 0; //! not need padding | |||
| } else { | |||
| padding = (GROUP * N * IC * IH2 * IW2) * | |||
| sizeof(param.src_type); //! for padding | |||
| } | |||
| packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size | |||
| WorkspaceBundle ws = {nullptr, {}}; | |||
| auto im2col_kern_param = | |||
| get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | |||
| if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | |||
| Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | |||
| ws = defaultkern.get_thread_bundle(param, im2col_kern_param, | |||
| @@ -447,6 +452,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
| m_matmul_algo, m_ohw_tile_size, | |||
| m_oc_tile_size); | |||
| } | |||
| return {nullptr, | |||
| {padding, packa_size, ws.total_size_in_bytes() * nr_threads}}; | |||
| } | |||
| @@ -461,7 +467,7 @@ size_t ConvBiasImpl::AlgoIm2col::get_workspace( | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
| ConvBiasImpl* opr, const NCBKernSizeParam& param) const { | |||
| ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) { | |||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
| MEGDNN_MARK_USED_VAR(SH); | |||
| @@ -473,7 +479,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
| size_t ohw = OH * OW; | |||
| size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size); | |||
| size_t GROUP = param.filter_meta.group; | |||
| WorkspaceBundle bundle = get_bundle(param); | |||
| WorkspaceBundle bundle_thread = {nullptr, {}}; | |||
| size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | |||
| @@ -483,11 +488,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
| bool no_pack = packmode == Pack_Mode::NO_PACK; | |||
| bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||
| size_t packa_parallel_times = 0; | |||
| size_t pack_oc_size = | |||
| (param.filter_meta.format == param::ConvBias::Format::NCHW ? 1 | |||
| : 4); | |||
| if (only_packA) { | |||
| packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size); | |||
| } else if (default_pack) { | |||
| packa_parallel_times = div_ceil<size_t>( | |||
| OC, m_matmul_algo->get_inner_block_size().m); | |||
| OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size); | |||
| } | |||
| auto matmul_param = get_matmul_kern_param( | |||
| @@ -520,25 +528,29 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
| strategyparam.skip_copy_dst = | |||
| strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit; | |||
| strategyparam.oc_tile_size = m_oc_tile_size; | |||
| strategyparam.pack_oc_size = pack_oc_size; | |||
| SmallVector<ConvBiasImpl::NCBKern> ret_kern; | |||
| MIDOUT_BEGIN( | |||
| megdnn_fallback_im2col, | |||
| midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) { | |||
| StrategyBase* im2colstrategy = Factory::get_im2col_strategy( | |||
| param, m_matmul_algo, opr->param().format); | |||
| auto kern_padding = [bundle, im2colstrategy]( | |||
| StrategyBase* im2colstrategy = | |||
| Factory::get_im2col_strategy(param, m_matmul_algo); | |||
| auto kern_padding = [bundle, im2colstrategy, | |||
| pack_oc_size = pack_oc_size]( | |||
| const NCBKernParam& param, | |||
| const NCBKernIndex& ncb_index) { | |||
| copy_padding_kern(bundle, param, ncb_index, im2colstrategy); | |||
| copy_padding_kern(bundle, param, ncb_index, im2colstrategy, | |||
| pack_oc_size); | |||
| }; | |||
| auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | |||
| matmul_param, | |||
| im2colstrategy](const NCBKernParam& param, | |||
| const NCBKernIndex& ncb_index) { | |||
| matmul_param, im2colstrategy, | |||
| pack_oc_size = pack_oc_size]( | |||
| const NCBKernParam& param, | |||
| const NCBKernIndex& ncb_index) { | |||
| packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | |||
| im2colstrategy); | |||
| im2colstrategy, pack_oc_size); | |||
| }; | |||
| if (default_pack) { | |||
| auto kern_compute_default = | |||
| @@ -556,7 +568,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
| ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | |||
| if (need_padding) { | |||
| ret_kern.push_back({kern_padding, {param.n, GROUP, IC}}); | |||
| ret_kern.push_back({kern_padding, | |||
| {param.n, GROUP, IC / pack_oc_size}}); | |||
| } | |||
| ret_kern.push_back( | |||
| {kern_compute_default, | |||
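A short sketch of how the dispatch extents change with pack_oc_size, mirroring the packa_parallel_times and kern_padding ndrange hunks above; div_ceil is a stand-in for megdnn's helper and the struct is illustrative.

```cpp
#include <cstddef>

inline size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }

struct DispatchExtents {
    size_t packa_parallel_times;  // second axis of {GROUP, packa_parallel_times}
    size_t padding_channels;      // third axis of {N, GROUP, padding_channels}
};

DispatchExtents make_extents(size_t OC, size_t IC, size_t inner_block_m,
                             size_t pack_oc_size) {
    DispatchExtents e;
    // one packA step already covers pack_oc_size output channels
    e.packa_parallel_times = div_ceil(OC, inner_block_m * pack_oc_size);
    // padding runs once per packed channel group
    e.padding_channels = IC / pack_oc_size;
    return e;
}
```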
| @@ -629,19 +642,25 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
| param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
| return false; | |||
| } | |||
| //! currently, im2col with NCHW44 only supports int8 and quantized s8 | |||
| if (opr->param().format == param::ConvBias::Format::NCHW44 && | |||
| (param.src_type.enumv() == param.filter_type.enumv() && | |||
| (param.src_type.enumv() != DTypeEnum::Int8) && | |||
| (param.src_type.enumv() != DTypeEnum::QuantizedS8))) { | |||
| return false; | |||
| } | |||
| fallback::MatrixMulImpl::KernSizeParam matmul_param = | |||
| get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size); | |||
| bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
| return matmulusable && | |||
| (opr->param().format == param::ConvBias::Format::NCHW) && | |||
| ((param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
| (param.filter_meta.spatial[0] <= 7) && | |||
| (param.filter_meta.spatial[0] >= 2)) || | |||
| (param.filter_meta.spatial[0] != param.filter_meta.spatial[1] && | |||
| (param.filter_meta.spatial[0] <= 7) && | |||
| (param.filter_meta.spatial[0] >= 1) && | |||
| (param.filter_meta.spatial[1] <= 7) && | |||
| (param.filter_meta.spatial[1] >= 1))) && | |||
| (opr->param().format == param::ConvBias::Format::NCHW || | |||
| opr->param().format == param::ConvBias::Format::NCHW44) && | |||
| (!(param.filter_meta.spatial[0] == | |||
| param.filter_meta.spatial[1] && | |||
| (param.filter_meta.spatial[0] == 1) && | |||
| param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
| param.filter_meta.stride[0] == 1)) && | |||
| (param.filter_meta.dilation[0] == | |||
| param.filter_meta.dilation[1] && | |||
| param.filter_meta.dilation[0] == 1) && | |||
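The new NCHW44 branch in usable() only lets int8 and quantized-int8 data through; other formats fall back to the existing checks. A compressed sketch of that gate, using local stand-in enums instead of the MegDNN types:

```cpp
enum class Fmt { NCHW, NCHW44 };
enum class DT { Int8, QuantizedS8, Other };

// Returns false exactly when the layout is NCHW44 and src/filter share a
// dtype that is neither Int8 nor QuantizedS8, matching the early return above.
bool nchw44_dtype_ok(DT src, DT filter, Fmt fmt) {
    if (fmt != Fmt::NCHW44)
        return true;
    if (src == filter && src != DT::Int8 && src != DT::QuantizedS8)
        return false;
    return true;
}
```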
| @@ -36,7 +36,6 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { | |||
| const NCBKernSizeParam& param, size_t ohw_tile_size, | |||
| size_t oc_tile_size) const; | |||
| WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | |||
| WorkspaceBundle get_thread_bundle(const NCBKernSizeParam& param) const; | |||
| void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m, | |||
| size_t block_n, bool pack_default) const; | |||
| @@ -23,19 +23,11 @@ namespace im2col { | |||
| enum class StrategyType : uint32_t { | |||
| FLOAT = 0, | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| FLOAT_FP16 = 1, | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| FLOAT16_FLOAT16 = 2, | |||
| #endif | |||
| #endif | |||
| INT8x8x32 = 3, | |||
| INT8x8x16 = 4, | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| QUINT8x8x32 = 5, | |||
| QUINT8x8x32x8 = 6, | |||
| #endif | |||
| QINT8x8x32 = 7, | |||
| QINT8x8x32x8 = 8 | |||
| }; | |||
| @@ -107,8 +99,7 @@ public: | |||
| ~StrategyDelegationStorage() = default; | |||
| template <typename Strategy> | |||
| Strategy* get(param::ConvBias::Format format, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| Strategy* get(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| StrategyType stype); | |||
| }; | |||
| @@ -117,12 +108,10 @@ class Factory { | |||
| public: | |||
| static StrategyBase* get_im2col_strategy( | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| param::ConvBias::Format format) { | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
| static StrategyDelegationStorage storage; | |||
| StrategyType strategytype = get_strategy_type(param); | |||
| return storage.get<StrategyBase>(format, matmul_algo, param, | |||
| strategytype); | |||
| return storage.get<StrategyBase>(matmul_algo, param, strategytype); | |||
| } | |||
| static StrategyType get_strategy_type( | |||
| @@ -141,12 +130,8 @@ public: | |||
| } | |||
| cb1(dt_float32, dt_float32, StrategyType::FLOAT); | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| cb1(dt_float16, __fp16, StrategyType::FLOAT_FP16); | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| cb1(dt_float16, dt_float16, StrategyType::FLOAT16_FLOAT16); | |||
| #endif | |||
| #endif | |||
| cb2(dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, dt_int32, | |||
| @@ -155,13 +140,6 @@ public: | |||
| cb2(dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, dt_int16, | |||
| StrategyType::INT8x8x16); | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::QuantizedS32, | |||
| dt_uint8, dt_int32, dt_int32, StrategyType::QUINT8x8x32); | |||
| cb2(dtype::Quantized8Asymm, dtype::QuantizedS32, dtype::Quantized8Asymm, | |||
| dt_uint8, dt_int32, dt_uint8, StrategyType::QUINT8x8x32x8); | |||
| #endif | |||
| cb2(dtype::QuantizedS8, dtype::QuantizedS32, dtype::QuantizedS32, | |||
| dt_int8, dt_int32, dt_int32, StrategyType::QINT8x8x32); | |||
| @@ -172,98 +150,106 @@ public: | |||
| megdnn_throw("not support datatype in im2col strategy\n"); | |||
| } | |||
| #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
| _postprocess_mode, PackMode::_packmode>>(); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| #define cb1(_format, _packmode, _dt, _post_ctype, _postprocess_mode, \ | |||
| _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
| _postprocess_mode, PackMode::_packmode, \ | |||
| FormatMode::_format>>(); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return {}; | |||
| #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
| _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique< \ | |||
| Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
| DTypeTrait<_i_bias_type>::ctype, \ | |||
| DTypeTrait<_i_dst_type>::ctype, \ | |||
| _postprocess_mode, PackMode::_packmode>>(); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| #define cb2(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \ | |||
| _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \ | |||
| _midout_tag) \ | |||
| MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \ | |||
| midout_iv(_midout_tag)) { \ | |||
| if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
| param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
| param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
| return std::make_unique<Strategy< \ | |||
| _src_ctype, _bias_ctype, _dst_ctype, \ | |||
| DTypeTrait<_i_bias_type>::ctype, \ | |||
| DTypeTrait<_i_dst_type>::ctype, _postprocess_mode, \ | |||
| PackMode::_packmode, FormatMode::_format>>(); \ | |||
| } \ | |||
| } \ | |||
| MIDOUT_END(); \ | |||
| return {}; | |||
| static std::unique_ptr<StrategyBase> make_default_strategy( | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| param::ConvBias::Format format, StrategyType strategytype) { | |||
| StrategyType strategytype) { | |||
| MEGDNN_MARK_USED_VAR(matmul_algo); | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| param::ConvBias::Format format = param.filter_meta.format; | |||
| switch (strategytype) { | |||
| case StrategyType::FLOAT: | |||
| cb1(DEFAULT, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
| "DefaultStrategyType::FLOAT"_hash); | |||
| break; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| case StrategyType::FLOAT_FP16: | |||
| cb1(DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
| "DefaultStrategyType::FLOAT_FP16"_hash); | |||
| cb1(NCHW, DEFAULT, dt_float32, dt_float32, | |||
| PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); | |||
| break; | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| case StrategyType::FLOAT16_FLOAT16: | |||
| cb1(DEFAULT, dt_float16, dt_float16, | |||
| cb1(NCHW, DEFAULT, dt_float16, dt_float16, | |||
| PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::FLOAT16_FLOAT16"_hash); | |||
| break; | |||
| #endif | |||
| #endif | |||
| case StrategyType::INT8x8x32: | |||
| cb2(DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::INT8x8x32"_hash); | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::INT8x8x32"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||
| cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::INT8x8x32"_hash); | |||
| } else { | |||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||
| } | |||
| break; | |||
| case StrategyType::INT8x8x16: | |||
| cb2(DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
| dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::INT8x8x16"_hash); | |||
| break; | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| case StrategyType::QUINT8x8x32: | |||
| cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::QUINT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::QUINT8x8x32x8: | |||
| cb2(DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
| PostprocessMode::QUANTIZED, | |||
| "DefaultStrategyType::QUINT8x8x32x8"_hash); | |||
| break; | |||
| #endif | |||
| case StrategyType::QINT8x8x32: | |||
| cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyType::QINT8x8x32"_hash); | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyTypeNCHW::QINT8x8x32"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); | |||
| } else { | |||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||
| } | |||
| break; | |||
| case StrategyType::QINT8x8x32x8: | |||
| cb2(DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
| PostprocessMode::QUANTIZED, | |||
| "DefaultStrategyType::QINT8x8x32x8"_hash); | |||
| if (format == param::ConvBias::Format::NCHW) { | |||
| cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
| PostprocessMode::QUANTIZED, | |||
| "DefaultStrategyType::QINT8x8x32x8"_hash); | |||
| } else if (format == param::ConvBias::Format::NCHW44) { | |||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | |||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | |||
| dt_int32, dt_int8, PostprocessMode::QUANTIZED, | |||
| "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | |||
| } else { | |||
| megdnn_throw("not support format except nchw44 and nchw\n"); | |||
| } | |||
| break; | |||
| } | |||
| megdnn_throw("error not support strategy type "); | |||
| @@ -272,63 +258,41 @@ public: | |||
| static std::unique_ptr<StrategyBase> make_nopack_strategy( | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| param::ConvBias::Format format, StrategyType strategytype) { | |||
| StrategyType strategytype) { | |||
| MEGDNN_MARK_USED_VAR(matmul_algo); | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| switch (strategytype) { | |||
| case StrategyType::FLOAT: | |||
| cb1(NO_PACK, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
| "NoPackStrategyType::FLOAT"_hash); | |||
| break; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| case StrategyType::FLOAT_FP16: | |||
| cb1(NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
| "NoPackStrategyType::FLOAT_FP16"_hash); | |||
| cb1(NCHW, NO_PACK, dt_float32, dt_float32, | |||
| PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); | |||
| break; | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| case StrategyType::FLOAT16_FLOAT16: | |||
| cb1(NO_PACK, dt_float16, dt_float16, PostprocessMode::NO_PROCESS, | |||
| cb1(NCHW, NO_PACK, dt_float16, dt_float16, | |||
| PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::FLOAT16_FLOAT16"_hash); | |||
| break; | |||
| #endif | |||
| #endif | |||
| case StrategyType::INT8x8x32: | |||
| cb2(NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::INT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::INT8x8x16: | |||
| cb2(NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
| dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::INT8x8x16"_hash); | |||
| break; | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| case StrategyType::QUINT8x8x32: | |||
| cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::QUINT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::QUINT8x8x32x8: | |||
| cb2(NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
| PostprocessMode::QUANTIZED, | |||
| "NoPackStrategyType::QUINT8x8x32x8"_hash); | |||
| break; | |||
| #endif | |||
| case StrategyType::QINT8x8x32: | |||
| cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "NoPackStrategyType::QINT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::QINT8x8x32x8: | |||
| cb2(NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
| PostprocessMode::QUANTIZED, | |||
| "NoPackStrategyType::QINT8x8x32x8"_hash); | |||
| @@ -340,64 +304,42 @@ public: | |||
| static std::unique_ptr<StrategyBase> make_onlypacka_strategy( | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| param::ConvBias::Format format, StrategyType strategytype) { | |||
| StrategyType strategytype) { | |||
| MEGDNN_MARK_USED_VAR(matmul_algo); | |||
| MEGDNN_MARK_USED_VAR(format); | |||
| switch (strategytype) { | |||
| case StrategyType::FLOAT: | |||
| cb1(ONLY_PACKA, dt_float32, dt_float32, PostprocessMode::FLOAT, | |||
| cb1(NCHW, ONLY_PACKA, dt_float32, dt_float32, | |||
| PostprocessMode::FLOAT, | |||
| "OnlyPackaStrategyType::FLOAT"_hash); | |||
| break; | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| case StrategyType::FLOAT_FP16: | |||
| cb1(ONLY_PACKA, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
| "OnlyPackaStrategyType::FLOAT_FP16"_hash); | |||
| break; | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| case StrategyType::FLOAT16_FLOAT16: | |||
| cb1(ONLY_PACKA, dt_float16, dt_float16, | |||
| cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16, | |||
| PostprocessMode::NO_PROCESS, | |||
| "OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash); | |||
| break; | |||
| #endif | |||
| #endif | |||
| case StrategyType::INT8x8x32: | |||
| cb2(ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, dt_int32, | |||
| dt_int32, PostprocessMode::NO_PROCESS, | |||
| cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8, | |||
| dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||
| "OnlyPackaStrategyType::INT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::INT8x8x16: | |||
| cb2(ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, dt_int16, | |||
| dt_int16, PostprocessMode::NO_PROCESS, | |||
| cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8, | |||
| dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
| "OnlyPackaStrategyType::INT8x8x16"_hash); | |||
| break; | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| case StrategyType::QUINT8x8x32: | |||
| cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "OnlyPackaStrategyType::QUINT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::QUINT8x8x32x8: | |||
| cb2(ONLY_PACKA, dtype::Quantized8Asymm, dtype::QuantizedS32, | |||
| dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8, | |||
| PostprocessMode::QUANTIZED, | |||
| "OnlyPackaStrategyType::QUINT8x8x32x8"_hash); | |||
| break; | |||
| #endif | |||
| case StrategyType::QINT8x8x32: | |||
| cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, | |||
| PostprocessMode::NO_PROCESS, | |||
| "OnlyPackaStrategyType::QINT8x8x32"_hash); | |||
| break; | |||
| case StrategyType::QINT8x8x32x8: | |||
| cb2(ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32, | |||
| dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, | |||
| PostprocessMode::QUANTIZED, | |||
| "OnlyPackaStrategyType::QINT8x8x32x8"_hash); | |||
| @@ -410,21 +352,19 @@ public: | |||
| #undef cb2 | |||
| static std::unique_ptr<StrategyBase> make_strategy( | |||
| param::ConvBias::Format format, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| fallback::MatrixMulImpl::AlgoBase::PackMode packmode, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| StrategyType stype) { | |||
| switch (packmode) { | |||
| case MatrixMulImpl::AlgoBase::PackMode::DEFAULT: | |||
| return make_default_strategy(matmul_algo, param, format, stype); | |||
| return make_default_strategy(matmul_algo, param, stype); | |||
| break; | |||
| case MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA: | |||
| return make_onlypacka_strategy(matmul_algo, param, format, | |||
| stype); | |||
| return make_onlypacka_strategy(matmul_algo, param, stype); | |||
| break; | |||
| case MatrixMulImpl::AlgoBase::PackMode::NO_PACK: | |||
| return make_nopack_strategy(matmul_algo, param, format, stype); | |||
| return make_nopack_strategy(matmul_algo, param, stype); | |||
| break; | |||
| default: | |||
| megdnn_throw( | |||
| @@ -432,14 +372,12 @@ public: | |||
| "nopack"); | |||
| break; | |||
| } | |||
| megdnn_throw( | |||
| "factory make Strategy error please check your code"); | |||
| megdnn_throw("factory make Strategy error please check your code"); | |||
| } | |||
| }; | |||
| template <typename Strategy> | |||
| Strategy* StrategyDelegationStorage::get( | |||
| param::ConvBias::Format format, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||
| StrategyType stype) { | |||
| @@ -455,14 +393,14 @@ Strategy* StrategyDelegationStorage::get( | |||
| } | |||
| StrategyHashParam sparam; | |||
| sparam.param = param; | |||
| sparam.format = format; | |||
| sparam.format = param.filter_meta.format; | |||
| sparam.packmode = packmode; | |||
| sparam.block_m = block_m; | |||
| sparam.block_n = block_n; | |||
| sparam.block_k = block_k; | |||
| if (map_strategys.find(sparam) == map_strategys.end()) { | |||
| MEGDNN_LOCK_GUARD(m_mtx); | |||
| auto strategy = Factory::make_strategy(format, matmul_algo, packmode, | |||
| auto strategy = Factory::make_strategy(matmul_algo, packmode, | |||
| param, stype); | |||
| map_strategys[sparam] = std::move(strategy); | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| namespace megdnn { | |||
| using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
| using FormatMode = param::ConvBias::Format; | |||
| struct StrategyParam { | |||
| size_t batch_id; | |||
| @@ -28,6 +29,7 @@ struct StrategyParam { | |||
| size_t block_m; | |||
| size_t block_n; | |||
| size_t block_k; | |||
| size_t pack_oc_size; | |||
| bool skip_copy_dst; | |||
| bool is_dst_8bit; | |||
| bool is_ohw_size_bigger; | |||
| @@ -40,13 +42,15 @@ public: | |||
| virtual void copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) = 0; | |||
| virtual void packA_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) = 0; | |||
| virtual void exec_im2col( | |||
| WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| @@ -70,14 +74,16 @@ public: | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode, PackMode packmode> | |||
| megdnn::PostprocessMode postprocess_mode, PackMode packmode, | |||
| FormatMode format> | |||
| class Strategy; | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::DEFAULT> : public StrategyBase { | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW> | |||
| : public StrategyBase { | |||
| public: | |||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
| @@ -85,24 +91,26 @@ public: | |||
| constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||
| constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
| Strategy(); | |||
| Strategy() = default; | |||
| void copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| void packA_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| void packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| virtual void exec_im2col( | |||
| WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
| void exec_matmul( | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| @@ -132,7 +140,32 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::NO_PACK> : public StrategyBase { | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44> | |||
| : public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::DEFAULT, | |||
| FormatMode::NCHW> { | |||
| public: | |||
| const size_t BUNDLE_PADDING_INDEX = 0; | |||
| const size_t BUNDLE_PACKA_INDEX = 1; | |||
| const size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||
| const size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||
| const size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
| Strategy() = default; | |||
| void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
| }; | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW> | |||
| : public StrategyBase { | |||
| public: | |||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
| @@ -141,19 +174,20 @@ public: | |||
| constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
| constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3; | |||
| Strategy(); | |||
| Strategy() = default; | |||
| void copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| void packA_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| void packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| void exec_matmul( | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| @@ -197,7 +231,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW> | |||
| : public StrategyBase { | |||
| public: | |||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
| @@ -206,19 +241,20 @@ public: | |||
| constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2; | |||
| constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3; | |||
| Strategy(); | |||
| Strategy() = default; | |||
| void copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| void packA_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| void packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_size) override; | |||
| void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| @@ -8,8 +8,6 @@ | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
| #include "src/fallback/convolution/img2col_helper.h" | |||
| #if MEGDNN_X86 | |||
| @@ -22,22 +20,15 @@ using namespace x86; | |||
| #endif | |||
| namespace megdnn { | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>::Strategy() | |||
| : StrategyBase() {} | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| copy_padding_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_oc_size) { | |||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
| MEGDNN_MARK_USED_VAR(N); | |||
| MEGDNN_MARK_USED_VAR(OC); | |||
| @@ -53,9 +44,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||
| size_t group_id = ncb_index.ndrange_id[1]; | |||
| size_t channel_id = ncb_index.ndrange_id[2]; | |||
| size_t PH_SIZE = PH * IW2 * pack_oc_size; | |||
| PW = PW * pack_oc_size; | |||
| IW = IW * pack_oc_size; | |||
| size_t padding_group_size = IH2 * IW2 * IC; | |||
| size_t workspace_channel_offset = IH2 * IW2 * channel_id; | |||
| size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id; | |||
| size_t workspace_group_offset = group_id * padding_group_size; | |||
| size_t workspace_batch_offset = | |||
| param.filter_meta.group * batch_id * padding_group_size; | |||
| @@ -65,8 +60,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
| src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point; | |||
| } | |||
| src_ctype* src = const_cast<src_ctype*>( | |||
| param.src<src_ctype>(batch_id, group_id, channel_id)); | |||
| src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>( | |||
| batch_id, group_id, channel_id, 1, pack_oc_size)); | |||
| src_ctype* src2; | |||
| src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
| workspace_group_offset + workspace_batch_offset + | |||
| @@ -74,8 +69,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| src_ctype* src2_ptr = src2; | |||
| const src_ctype* src_ptr = src; | |||
| if (PH != 0) { | |||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
| src2_ptr += PH * IW2; | |||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
| src2_ptr += PH_SIZE; | |||
| } | |||
| rep(ih, IH) { | |||
| if (PW != 0) | |||
| @@ -87,8 +82,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| rep(pw, PW) * (src2_ptr++) = src_zp; | |||
| } | |||
| if (PH != 0) { | |||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2); | |||
| src2_ptr += PH * IW2; | |||
| std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE); | |||
| src2_ptr += PH_SIZE; | |||
| } | |||
| } | |||
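In the copy_padding_kern above, PH_SIZE, PW and IW are all scaled by pack_oc_size because each padded row of an NCHW44 plane carries four interleaved channels. A self-contained sketch of the same fill pattern for one plane, assuming int8 data and a fixed pack size of 4:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

void pad_nchw44_plane(const int8_t* src, int8_t* dst, size_t IH, size_t IW,
                      size_t PH, size_t PW, int8_t zp) {
    const size_t pack = 4;
    const size_t IW2 = IW + 2 * PW;         // padded row width in pixels
    int8_t* out = dst;
    std::memset(out, zp, PH * IW2 * pack);  // top padding rows
    out += PH * IW2 * pack;
    for (size_t ih = 0; ih < IH; ++ih) {
        std::memset(out, zp, PW * pack);    // left padding
        out += PW * pack;
        std::memcpy(out, src + ih * IW * pack, IW * pack);
        out += IW * pack;
        std::memset(out, zp, PW * pack);    // right padding
        out += PW * pack;
    }
    std::memset(out, zp, PH * IW2 * pack);  // bottom padding rows
}
```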
| @@ -96,12 +91,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t pack_oc_size) { | |||
| bundle.set(param.workspace_ptr); | |||
| fallback::MatrixMulImpl::KernParam matmul_param; | |||
| size_t group_id = ncb_index.ndrange_id[0]; | |||
| @@ -114,38 +110,38 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| matmul_algo->get_packA_type_size(); | |||
| size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | |||
| int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | |||
| group_id * packA_group_size + a_panel_offset; | |||
| group_id * packA_group_size + | |||
| (pack_oc_size == 4 ? 0 : a_panel_offset); | |||
| matmul_param.A_ptr = | |||
| const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | |||
| matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | |||
| matmul_algo->get_inner_block_size().m); | |||
| matmul_algo->get_inner_block_size().m * pack_oc_size); | |||
| } | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
| ) { | |||
| size_t m_sh = param.filter_meta.stride[0]; | |||
| size_t m_sw = param.filter_meta.stride[1]; | |||
| size_t m_oc = param.filter_meta.ocpg; | |||
| size_t m_oh = param.osz[0]; | |||
| size_t m_ow = param.osz[1]; | |||
| size_t m_ic = param.filter_meta.icpg; | |||
| size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t m_fh = param.filter_meta.spatial[0]; | |||
| size_t m_fw = param.filter_meta.spatial[1]; | |||
| size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
| size_t sh = param.filter_meta.stride[0]; | |||
| size_t sw = param.filter_meta.stride[1]; | |||
| size_t oc = param.filter_meta.ocpg; | |||
| size_t oh = param.osz[0]; | |||
| size_t ow = param.osz[1]; | |||
| size_t ic = param.filter_meta.icpg; | |||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t fh = param.filter_meta.spatial[0]; | |||
| size_t fw = param.filter_meta.spatial[1]; | |||
| size_t is_xcorr = !param.filter_meta.should_flip; | |||
| size_t input_offset = | |||
| m_ih * m_iw * m_ic * | |||
| ih * iw * ic * | |||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
| sizeof(src_ctype); | |||
| @@ -160,26 +156,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
| bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
| if (m_sh == 1 && m_sw == 1) { | |||
| if (m_is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| if (sh == 1 && sw == 1) { | |||
| if (is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } else { | |||
| img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } | |||
| } else { | |||
| if (m_is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
| m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| if (is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } else { | |||
| img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
| m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } | |||
| } | |||
| @@ -199,7 +191,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const WorkspaceBundle& bundle_thread, | |||
| const StrategyParam& sparam) { | |||
| @@ -218,7 +210,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, WorkspaceBundle bundle, | |||
| WorkspaceBundle bundle_thread, | |||
| @@ -240,11 +232,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| src_ctype* b_panel = | |||
| reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | |||
| bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||
| size_t pack_oc_size = sparam.pack_oc_size; | |||
| matmul_param.M = sparam.output_block_oc_size; | |||
| matmul_param.N = sparam.output_block_size; | |||
| matmul_param.LDB = sparam.output_block_size; | |||
| matmul_param.LDC = sparam.output_block_size; | |||
| matmul_param.LDB = pack_oc_size * sparam.output_block_size; | |||
| matmul_param.LDC = pack_oc_size * sparam.output_block_size; | |||
| matmul_param.C_ptr = matmul_dst; | |||
| auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param); | |||
| @@ -255,7 +247,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, | |||
| WorkspaceBundle bundle_thread) { | |||
| @@ -274,7 +266,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
| matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode, | |||
| param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
| sparam.output_block_oc_size, 1_z, sparam.output_block_size); | |||
| sparam.output_block_oc_size, 1_z, sparam.output_block_size, | |||
| sparam.pack_oc_size); | |||
| copy_dst(param, matmul_dst, sparam); | |||
| } | |||
| @@ -282,20 +275,24 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const void* matmul_dst, const StrategyParam& sparam) { | |||
| if (!sparam.skip_copy_dst) { | |||
| size_t pack_oc_size = sparam.pack_oc_size; | |||
| dst_ctype* dst_tmp_ptr = | |||
| reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst)); | |||
| dst_ctype* dst = | |||
| param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) + | |||
| sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index; | |||
| for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) { | |||
| sparam.oc_cur_index * sparam.ohw + | |||
| sparam.ohw_cur_index * pack_oc_size; | |||
| size_t oc_loop = sparam.output_block_oc_size / pack_oc_size; | |||
| for (size_t oc = 0; oc < oc_loop; oc++) { | |||
| std::memcpy(dst, dst_tmp_ptr, | |||
| sizeof(dst_ctype) * sparam.output_block_size); | |||
| dst_tmp_ptr += sparam.output_block_size; | |||
| dst += sparam.ohw; | |||
| sizeof(dst_ctype) * sparam.output_block_size * | |||
| pack_oc_size); | |||
| dst_tmp_ptr += sparam.output_block_size * pack_oc_size; | |||
| dst += sparam.ohw * pack_oc_size; | |||
| } | |||
| } | |||
| } | |||
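copy_dst now moves pack_oc_size output channels per memcpy and advances both pointers by the packed stride, so the loop count drops by the pack factor. A minimal sketch of that loop with float data and illustrative parameter names:

```cpp
#include <cstddef>
#include <cstring>

void copy_dst_packed(float* dst, const float* matmul_dst, size_t block_oc,
                     size_t block_ohw, size_t ohw, size_t pack_oc_size) {
    const size_t oc_loop = block_oc / pack_oc_size;  // packed channel groups
    for (size_t oc = 0; oc < oc_loop; ++oc) {
        std::memcpy(dst, matmul_dst,
                    sizeof(float) * block_ohw * pack_oc_size);
        matmul_dst += block_ohw * pack_oc_size;
        dst += ohw * pack_oc_size;
    }
}
```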
| @@ -304,7 +301,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const WorkspaceBundle& bundle_thread) { | |||
| bias_ctype* bias_tmp_ptr = | |||
| @@ -319,7 +316,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::DEFAULT>:: | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>:: | |||
| copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
| @@ -340,31 +337,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| } | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, PackMode::DEFAULT>; | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||
| FormatMode::NCHW>; | |||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| //! x86 do not have uint8 matmul so only armv7 armv8 support uint8 | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
| #include "src/fallback/convolution/img2col_helper.h" | |||
| #if MEGDNN_X86 | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #endif | |||
| using namespace megdnn; | |||
| #if MEGDNN_X86 | |||
| using namespace x86; | |||
| #endif | |||
| namespace megdnn { | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>:: | |||
| exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
| size_t sh = param.filter_meta.stride[0]; | |||
| size_t sw = param.filter_meta.stride[1]; | |||
| size_t oc = param.filter_meta.ocpg; | |||
| size_t oh = param.osz[0]; | |||
| size_t ow = param.osz[1]; | |||
| size_t ic = param.filter_meta.icpg; | |||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t fh = param.filter_meta.spatial[0]; | |||
| size_t fw = param.filter_meta.spatial[1]; | |||
| size_t is_xcorr = !param.filter_meta.should_flip; | |||
| constexpr static size_t pack_size = 4; | |||
| size_t input_offset = | |||
| ih * iw * ic * | |||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
| sizeof(src_ctype); | |||
| src_ctype* src2 = reinterpret_cast<src_ctype*>( | |||
| reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
| input_offset); | |||
| bool is_phpwzero = param.filter_meta.padding[0] == 0 && | |||
| param.filter_meta.padding[1] == 0; | |||
| if (is_phpwzero) { | |||
| src2 = const_cast<src_ctype*>( | |||
| param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||
| } | |||
| src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
| bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
| if (is_xcorr) { | |||
| if (sh == sw && sh == 1) { | |||
| img2col_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } else { | |||
| img2col_stride_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, | |||
| fh, fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } | |||
| } else { | |||
| if (sh == sw && sh == 1) { | |||
| img2col_nchw4<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } else { | |||
| img2col_stride_nchw4<false>( | |||
| src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, sh, sw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } | |||
| } | |||
| matmul_param.M = sparam.output_block_oc_size; | |||
| matmul_param.N = sparam.output_block_size; | |||
| matmul_param.LDB = pack_size * sparam.output_block_size; | |||
| matmul_param.LDC = pack_size * sparam.output_block_size; | |||
| matmul_param.B_ptr = im2col_dst; | |||
| src_ctype* b_panel = | |||
| reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>( | |||
| bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||
| matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||
| } | |||
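After filling the NCHW44 im2col tile, the code above rewrites the matmul descriptor before packing B: each output pixel carries pack_size interleaved channel values, so the row strides LDB and LDC are scaled by pack_size. A stand-alone sketch of that arithmetic with made-up tile sizes (not part of the patch):

#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t pack_size = 4;             // NCHW44: 4 channels interleaved per pixel
    const std::size_t output_block_oc_size = 8;  // hypothetical OC tile
    const std::size_t output_block_size = 192;   // hypothetical OH*OW tile

    const std::size_t M = output_block_oc_size;
    const std::size_t N = output_block_size;
    const std::size_t LDB = pack_size * N;       // 4 interleaved values per pixel in B
    const std::size_t LDC = pack_size * N;       // and in the matmul destination
    std::printf("M=%zu N=%zu LDB=%zu LDC=%zu\n", M, N, LDB, LDC);
    return 0;
}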
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, PackMode::DEFAULT, \ | |||
| FormatMode::NCHW44>; | |||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #undef INSTANTIAL_CLASS | |||
| } // namespace megdnn | |||
| @@ -9,8 +9,6 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
| #include "src/fallback/convolution/img2col_helper.h" | |||
| #if MEGDNN_X86 | |||
| @@ -22,22 +20,16 @@ using namespace megdnn; | |||
| using namespace x86; | |||
| #endif | |||
| namespace megdnn { | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>::Strategy() | |||
| : StrategyBase() {} | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| copy_padding_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t) { | |||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
| MEGDNN_MARK_USED_VAR(N); | |||
| MEGDNN_MARK_USED_VAR(OC); | |||
| @@ -96,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t) { | |||
| MEGDNN_MARK_USED_VAR(bundle); | |||
| MEGDNN_MARK_USED_VAR(param); | |||
| MEGDNN_MARK_USED_VAR(matmulparam); | |||
| @@ -115,7 +108,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const WorkspaceBundle& bundle_thread, | |||
| const StrategyParam& sparam) { | |||
| @@ -134,7 +127,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, WorkspaceBundle bundle, | |||
| WorkspaceBundle bundle_thread, | |||
| @@ -167,29 +160,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
| ) { | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
| MEGDNN_MARK_USED_VAR(matmul_param); | |||
| MEGDNN_MARK_USED_VAR(matmul_algo); | |||
| size_t m_sh = param.filter_meta.stride[0]; | |||
| size_t m_sw = param.filter_meta.stride[1]; | |||
| size_t m_oc = param.filter_meta.ocpg; | |||
| size_t m_oh = param.osz[0]; | |||
| size_t m_ow = param.osz[1]; | |||
| size_t m_ic = param.filter_meta.icpg; | |||
| size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t m_fh = param.filter_meta.spatial[0]; | |||
| size_t m_fw = param.filter_meta.spatial[1]; | |||
| size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
| size_t sh = param.filter_meta.stride[0]; | |||
| size_t sw = param.filter_meta.stride[1]; | |||
| size_t oc = param.filter_meta.ocpg; | |||
| size_t oh = param.osz[0]; | |||
| size_t ow = param.osz[1]; | |||
| size_t ic = param.filter_meta.icpg; | |||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t fh = param.filter_meta.spatial[0]; | |||
| size_t fw = param.filter_meta.spatial[1]; | |||
| size_t is_xcorr = !param.filter_meta.should_flip; | |||
| size_t input_offset = | |||
| m_ih * m_iw * m_ic * | |||
| ih * iw * ic * | |||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
| sizeof(src_ctype); | |||
| @@ -205,26 +197,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
| bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
| if (m_sh == 1 && m_sw == 1) { | |||
| if (m_is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| if (sh == 1 && sw == 1) { | |||
| if (is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } else { | |||
| img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } | |||
| } else { | |||
| if (m_is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
| m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| if (is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } else { | |||
| img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
| m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } | |||
| } | |||
| @@ -234,7 +222,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, | |||
| WorkspaceBundle bundle_thread) { | |||
| @@ -262,7 +250,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const void* matmul_dst, const StrategyParam& sparam) { | |||
| if (!sparam.skip_copy_dst) { | |||
| @@ -284,7 +272,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::NO_PACK>:: | |||
| postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>:: | |||
| copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| WorkspaceBundle bundle_thread, const StrategyParam& sparam) { | |||
| const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( | |||
| @@ -305,31 +293,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| } | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, PackMode::NO_PACK>; | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, PackMode::NO_PACK, \ | |||
| FormatMode::NCHW>; | |||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| //! x86 does not have uint8 matmul, so only armv7/armv8 support uint8 | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| @@ -9,7 +9,6 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
| #include "src/fallback/convolution/img2col_helper.h" | |||
| #if MEGDNN_X86 | |||
| @@ -21,22 +20,16 @@ using namespace megdnn; | |||
| using namespace x86; | |||
| #endif | |||
| namespace megdnn { | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>::Strategy() | |||
| : StrategyBase() {} | |||
| template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| copy_padding_kern( | |||
| WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| copy_padding_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t) { | |||
| UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
| MEGDNN_MARK_USED_VAR(N); | |||
| MEGDNN_MARK_USED_VAR(OC); | |||
| @@ -95,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| packA_kern(WorkspaceBundle bundle, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| size_t) { | |||
| bundle.set(param.workspace_ptr); | |||
| fallback::MatrixMulImpl::KernParam matmul_param; | |||
| static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) = | |||
| @@ -128,7 +122,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const WorkspaceBundle& bundle_thread, | |||
| const StrategyParam& sparam) { | |||
| @@ -147,7 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, WorkspaceBundle bundle, | |||
| WorkspaceBundle bundle_thread, | |||
| @@ -185,29 +179,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
| const StrategyParam& sparam, | |||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo | |||
| ) { | |||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
| MEGDNN_MARK_USED_VAR(matmul_param); | |||
| MEGDNN_MARK_USED_VAR(matmul_algo); | |||
| size_t m_sh = param.filter_meta.stride[0]; | |||
| size_t m_sw = param.filter_meta.stride[1]; | |||
| size_t m_oc = param.filter_meta.ocpg; | |||
| size_t m_oh = param.osz[0]; | |||
| size_t m_ow = param.osz[1]; | |||
| size_t m_ic = param.filter_meta.icpg; | |||
| size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t m_fh = param.filter_meta.spatial[0]; | |||
| size_t m_fw = param.filter_meta.spatial[1]; | |||
| size_t m_is_xcorr = !param.filter_meta.should_flip; | |||
| size_t sh = param.filter_meta.stride[0]; | |||
| size_t sw = param.filter_meta.stride[1]; | |||
| size_t oc = param.filter_meta.ocpg; | |||
| size_t oh = param.osz[0]; | |||
| size_t ow = param.osz[1]; | |||
| size_t ic = param.filter_meta.icpg; | |||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
| size_t fh = param.filter_meta.spatial[0]; | |||
| size_t fw = param.filter_meta.spatial[1]; | |||
| size_t is_xcorr = !param.filter_meta.should_flip; | |||
| size_t input_offset = | |||
| m_ih * m_iw * m_ic * | |||
| ih * iw * ic * | |||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
| sizeof(src_ctype); | |||
| @@ -222,26 +215,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| src_ctype* im2col_dst = static_cast<src_ctype*>( | |||
| bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
| if (m_sh == 1 && m_sw == 1) { | |||
| if (m_is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| if (sh == 1 && sw == 1) { | |||
| if (is_xcorr) { | |||
| img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } else { | |||
| img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw, | |||
| m_fh, m_fw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, | |||
| sparam.ohw_cur_index, sparam.output_block_size); | |||
| } | |||
| } else { | |||
| if (m_is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, | |||
| m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| if (is_xcorr) { | |||
| img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } else { | |||
| img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, | |||
| m_ih, m_iw, m_fh, m_fw, m_sh, m_sw, | |||
| sparam.ohw_cur_index, | |||
| img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, | |||
| fw, sh, sw, sparam.ohw_cur_index, | |||
| sparam.output_block_size); | |||
| } | |||
| } | |||
| @@ -251,7 +240,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const StrategyParam& sparam, | |||
| WorkspaceBundle bundle_thread) { | |||
| @@ -292,7 +281,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
| typename op_ctype, typename op_dtype, | |||
| megdnn::PostprocessMode postprocess_mode> | |||
| void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| postprocess_mode,PackMode::ONLY_PACKA>:: | |||
| postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>:: | |||
| copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
| const void* matmul_dst, const StrategyParam& sparam) { | |||
| if (!sparam.skip_copy_dst) { | |||
| @@ -310,31 +299,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
| } | |||
| } | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
| _op_ctype, _op_dtype, _postprocess_mode,PackMode::ONLY_PACKA>; | |||
| #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode) \ | |||
| template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ | |||
| _op_dtype, _postprocess_mode, \ | |||
| PackMode::ONLY_PACKA, FormatMode::NCHW>; | |||
| INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
| megdnn::PostprocessMode::FLOAT) | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| #endif | |||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| //! x86 does not have uint8 matmul, so only armv7/armv8 support uint8 | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32, | |||
| megdnn::PostprocessMode::NO_PROCESS) | |||
| #endif | |||
| INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8, | |||
| megdnn::PostprocessMode::QUANTIZED) | |||
| @@ -9,7 +9,6 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/common/utils.h" | |||
| namespace { | |||
| template <bool is_xcorr, typename dtype> | |||
| @@ -41,7 +40,326 @@ void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | |||
| } | |||
| } | |||
| //! helpers added for the multithreaded im2col matmul | |||
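| //! Both nchw4 helpers below copy the window of output pixels | |||
| //! [cur_index, cur_index + block_size) into the im2col buffer. The window is | |||
| //! split into a partial first output row, a run of full rows and a partial | |||
| //! last row; e.g. with OW = 28, cur_index = 50 and block_size = 24 this gives | |||
| //! start_h = 1, cur_remain_w = 22, end_h = 2, end_remain_w = 18. For 1-byte | |||
| //! element types the 4 packed channels of each pixel are moved as one | |||
| //! uint32_t. | |||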
| template <bool is_xcorr, typename dtype> | |||
| void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||
| const int OC, const int OH, const int OW, const int IC, | |||
| const int IH, const int IW, const int FH, const int FW, | |||
| const int SH, const int SW, const int cur_index, | |||
| const int block_size) { | |||
| MEGDNN_MARK_USED_VAR(OC); | |||
| MEGDNN_MARK_USED_VAR(OH); | |||
| int start_h = cur_index / OW; | |||
| int cur_remain_w = cur_index % OW; | |||
| int end_h = (cur_index + block_size) / OW; | |||
| int end_remain_w = (cur_index + block_size) % OW; | |||
| bool same_line = false; | |||
| if (start_h == end_h) { | |||
| same_line = true; | |||
| } | |||
| size_t newIC = IC / 4; | |||
| size_t i = 0; | |||
| if (sizeof(dtype) != 1) { | |||
| if (same_line) { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
| size_t index = 4 * (ic * IH * IW + | |||
| (start_h * SH + fh2) * IW + | |||
| (w * SW + fw2)); | |||
| dst[i++] = src[index]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < OW; w++) { | |||
| size_t index = 4 * (ic * IH * IW + | |||
| (start_h * SH + fh2) * IW + | |||
| (w * SW + fw2)); | |||
| dst[i++] = src[index + 0]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| for (int h = start_h + 1; h < end_h; h++) { | |||
| rep(ow, OW) { | |||
| size_t index = 4 * (ic * IH * IW + | |||
| (h * SH + fh2) * IW + | |||
| (ow * SW + fw2)); | |||
| dst[i++] = src[index + 0]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| for (int w = 0; w < end_remain_w; w++) { | |||
| size_t index = 4 * (ic * IH * IW + | |||
| (end_h * SH + fh2) * IW + | |||
| (w * SW + fw2)); | |||
| dst[i++] = src[index + 0]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| uint32_t* output = nullptr; | |||
| const uint32_t* uint32_src = | |||
| static_cast<const uint32_t*>(static_cast<const void*>(src)); | |||
| output = static_cast<uint32_t*>(static_cast<void*>(dst)); | |||
| if (same_line) { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| size_t index = | |||
| (ic * IH * IW + (start_h * SH + fh2) * IW + | |||
| (cur_remain_w * SW + fw2)); | |||
| for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
| output[i++] = uint32_src[index]; | |||
| index += SW; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| size_t index = ic * IH * IW + | |||
| (start_h * SH + fh2) * IW + | |||
| cur_remain_w * SW + fw2; | |||
| for (int w = cur_remain_w; w < OW; w++) { | |||
| output[i++] = uint32_src[index]; | |||
| index += SW; | |||
| } | |||
| for (int h = start_h + 1; h < end_h; h++) { | |||
| index = ic * IH * IW + (h * SH + fh2) * IW + fw2; | |||
| rep(ow, OW) { | |||
| output[i++] = uint32_src[index]; | |||
| index += SW; | |||
| } | |||
| } | |||
| index = ic * IH * IW + (end_h * SH + fh2) * IW + fw2; | |||
| for (int w = 0; w < end_remain_w; w++) { | |||
| output[i++] = uint32_src[index]; | |||
| index += SW; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
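The 4-channel interleaving above follows the usual NCHW4 addressing rule: channel c of pixel (h, w) is stored at 4 * ((c / 4) * IH * IW + h * IW + w) + c % 4. A tiny stand-alone sketch with made-up toy sizes (not part of the patch) to illustrate:

#include <cstdio>
#include <vector>

int main() {
    const int IC = 8, IH = 3, IW = 3;  // toy sizes, chosen only for illustration
    std::vector<int> nchw4(IC * IH * IW);
    for (int c = 0; c < IC; ++c)
        for (int h = 0; h < IH; ++h)
            for (int w = 0; w < IW; ++w)
                // tag each element with its (c, h, w) so reads are easy to check
                nchw4[4 * ((c / 4) * IH * IW + h * IW + w) + c % 4] =
                        c * 100 + h * 10 + w;
    // channel 5 of pixel (2, 1): prints 521
    std::printf("%d\n", nchw4[4 * ((5 / 4) * IH * IW + 2 * IW + 1) + 5 % 4]);
    return 0;
}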
| template <bool is_xcorr, typename dtype> | |||
| void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst, | |||
| const int OC, const int OH, const int OW, const int IC, | |||
| const int IH, const int IW, const int FH, const int FW, | |||
| const int SH, const int SW, const int cur_index, | |||
| const int block_size) { | |||
| MEGDNN_MARK_USED_VAR(OC); | |||
| MEGDNN_MARK_USED_VAR(OH); | |||
| MEGDNN_MARK_USED_VAR(SH); | |||
| MEGDNN_MARK_USED_VAR(SW); | |||
| int start_h = cur_index / OW; | |||
| int cur_remain_w = cur_index % OW; | |||
| int end_h = (cur_index + block_size) / OW; | |||
| int end_remain_w = (cur_index + block_size) % OW; | |||
| bool same_line = false; | |||
| if (start_h == end_h) { | |||
| same_line = true; | |||
| } | |||
| size_t newIC = IC / 4; | |||
| size_t i = 0; | |||
| if (sizeof(dtype) != 1) { | |||
| if (same_line) { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
| size_t index = | |||
| 4 * (ic * IH * IW + (start_h + fh2) * IW + | |||
| (w + fw2)); | |||
| dst[i++] = src[index]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < OW; w++) { | |||
| size_t index = ic * IH * IW + (start_h + fh2) * IW + | |||
| (w + fw2); | |||
| dst[i++] = src[4 * index]; | |||
| dst[i++] = src[4 * index + 1]; | |||
| dst[i++] = src[4 * index + 2]; | |||
| dst[i++] = src[4 * index + 3]; | |||
| } | |||
| for (int h = start_h + 1; h < end_h; h++) { | |||
| rep(ow, OW) { | |||
| size_t index = | |||
| 4 * (ic * IH * IW + (h + fh2) * IW + | |||
| (ow + fw2)); | |||
| dst[i++] = src[index + 0]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| for (int w = 0; w < end_remain_w; w++) { | |||
| size_t index = 4 * (ic * IH * IW + | |||
| (end_h + fh2) * IW + (w + fw2)); | |||
| dst[i++] = src[index + 0]; | |||
| dst[i++] = src[index + 1]; | |||
| dst[i++] = src[index + 2]; | |||
| dst[i++] = src[index + 3]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| uint32_t* output = nullptr; | |||
| const uint32_t* uint32_src = | |||
| static_cast<const uint32_t*>(static_cast<const void*>(src)); | |||
| output = static_cast<uint32_t*>(static_cast<void*>(dst)); | |||
| if (same_line) { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < end_remain_w; w++) { | |||
| size_t index = (ic * IH * IW + | |||
| (start_h + fh2) * IW + (w + fw2)); | |||
| output[i++] = uint32_src[index]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| rep(ic, newIC) { | |||
| rep(fh, FH) { | |||
| rep(fw, FW) { | |||
| int fh2, fw2; | |||
| if (is_xcorr) { | |||
| fh2 = fh; | |||
| fw2 = fw; | |||
| } else { | |||
| fh2 = FH - fh - 1; | |||
| fw2 = FW - fw - 1; | |||
| } | |||
| for (int w = cur_remain_w; w < OW; w++) { | |||
| size_t index = ic * IH * IW + (start_h + fh2) * IW + | |||
| (w + fw2); | |||
| output[i++] = uint32_src[index]; | |||
| } | |||
| for (int h = start_h + 1; h < end_h; h++) { | |||
| rep(ow, OW) { | |||
| size_t index = (ic * IH * IW + (h + fh2) * IW + | |||
| (ow + fw2)); | |||
| output[i++] = uint32_src[index]; | |||
| } | |||
| } | |||
| for (int w = 0; w < end_remain_w; w++) { | |||
| size_t index = (ic * IH * IW + (end_h + fh2) * IW + | |||
| (w + fw2)); | |||
| output[i++] = uint32_src[index]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
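In both helpers the sizeof(dtype) != 1 path copies the four channel values of a pixel one by one, while the 1-byte path reinterprets the buffers as uint32_t so each pixel moves as a single word. A stand-alone sketch of that idea with made-up data (the patch casts the pointers directly; memcpy here is just the portable way to spell the same 4-bytes-as-one-word move):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const int8_t pixel[4] = {1, -2, 3, -4};   // one NCHW4 pixel: 4 int8 channels
    uint32_t word;
    std::memcpy(&word, pixel, sizeof(word));  // single 32-bit load
    int8_t copied[4];
    std::memcpy(copied, &word, sizeof(word)); // single 32-bit store at the destination
    std::printf("%d %d %d %d\n", copied[0], copied[1], copied[2], copied[3]);
    return 0;
}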
| template <bool is_xcorr, typename dtype> | |||
| void img2col_stride(const dtype* __restrict src, dtype* __restrict dst, | |||
| @@ -124,7 +124,8 @@ struct PostProcess { | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW) { | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
| @@ -154,7 +155,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> { | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW) { | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||
| @@ -185,7 +187,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBias::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW) { | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| MEGDNN_MARK_USED_VAR(conv_dst_ptr); | |||
| MEGDNN_MARK_USED_VAR(bias_ptr); | |||
| MEGDNN_MARK_USED_VAR(dst_ptr); | |||
| @@ -292,7 +295,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| megdnn::ConvBiasForward::BiasMode bias_mode, | |||
| megdnn::param::ConvBiasV0::NonlineMode nonlineMode, | |||
| DType bias_type, DType dst_type, size_t N, size_t OC, | |||
| size_t OH, size_t OW) { | |||
| size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size); | |||
| megdnn::param::Elemwise::Mode elem_mode = | |||
| megdnn::param::Elemwise::Mode::ADD; | |||
| if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { | |||