| @@ -38,23 +38,6 @@ public: | |||||
| const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8DirectStride1NCHW44 final : public AlgoBase { | |||||
| public: | |||||
| AlgoS8DirectStride1NCHW44() {} | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "S8_NCHW44_DIRECT_STRD1"; } | |||||
| bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(fallback::ConvBiasImpl*, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| virtual SmallVector<NCBKern> dispatch_kerns( | |||||
| fallback::ConvBiasImpl* opr, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| bool is_preferred(megdnn::fallback::ConvBiasImpl*, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| }; | |||||
| class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
| bool m_large_group; | bool m_large_group; | ||||
| @@ -74,11 +57,11 @@ public: | |||||
| const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoS8DirectStride2NCHW44 final : public AlgoBase { | |||||
| class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||||
| public: | public: | ||||
| AlgoS8DirectStride2NCHW44() {} | |||||
| AlgoS8DirectNCHW44() {} | |||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { return "S8_NCHW44_DIRECT_STRD2"; } | |||||
| const char* name() const override { return "S8_NCHW44_DIRECT"; } | |||||
| bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
| size_t get_workspace(fallback::ConvBiasImpl*, | size_t get_workspace(fallback::ConvBiasImpl*, | ||||
| @@ -245,8 +228,8 @@ private: | |||||
| //=======================input int8 compute fp32 output int8============ | //=======================input int8 compute fp32 output int8============ | ||||
| class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoS8CF32WinogradF23_4x4_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
| uint32_t tile_size) | |||||
| AlgoS8CF32WinogradF23_4x4_NCHW44( | |||||
| fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) | |||||
| : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| @@ -277,7 +260,7 @@ private: | |||||
| class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, | ||||
| uint32_t tile_size) | |||||
| uint32_t tile_size) | |||||
| : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| const char* name() const override { | const char* name() const override { | ||||
| @@ -36,26 +36,6 @@ KERN(stride2, 7, nchw) | |||||
| #undef KERN | #undef KERN | ||||
| #define KERN(stride, i, layout) \ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> \ | |||||
| void conv_direct_##stride##_##i##x##i##_int8_##layout( \ | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, \ | |||||
| int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \ | |||||
| const size_t IH, const size_t IW, const size_t OH, \ | |||||
| const size_t OW, const Op& op); | |||||
| KERN(stride1, 2, nchw44) | |||||
| KERN(stride1, 3, nchw44) | |||||
| KERN(stride1, 5, nchw44) | |||||
| KERN(stride1, 7, nchw44) | |||||
| KERN(stride2, 2, nchw44) | |||||
| KERN(stride2, 3, nchw44) | |||||
| KERN(stride2, 5, nchw44) | |||||
| KERN(stride2, 7, nchw44) | |||||
| #undef KERN | |||||
| void nchw44_pack_filter(const int8_t* src, int8_t* dst, int filter); | |||||
| void nchw44_pack_src(const int8_t* src, int8_t* dst, int length); | |||||
| } // namespace conv_bias | } // namespace conv_bias | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_algo.cpp | |||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct.h" | #include "src/arm_common/conv_bias/int8/direct.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | #include "src/arm_common/conv_bias/int8/strategy.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| @@ -25,28 +26,19 @@ using conv_fun = std::function<void( | |||||
| WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | ||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
| const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | ||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride2) | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44) | |||||
| static void get_rectified_size( | static void get_rectified_size( | ||||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
| size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||||
| int& iw2) { | |||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| size_t SW = fm.stride[1]; | |||||
| size_t IH = param.isz[0]; | |||||
| size_t IW = param.isz[1]; | |||||
| size_t OH = param.osz[0]; | |||||
| size_t OW = param.osz[1]; | |||||
| size_t FH = fm.spatial[0]; | |||||
| size_t FW = fm.spatial[1]; | |||||
| int ih = param.isz[0]; | |||||
| int iw = param.isz[1]; | |||||
| int ph = fm.padding[0]; | |||||
| int pw = fm.padding[1]; | |||||
| OH2 = OH; | |||||
| OW2 = (OW + 7) & ~7; | |||||
| IH2 = SW * OH + FH - SW; | |||||
| IW2 = SW * OW2 + FW - SW; | |||||
| // Because stride is 2, sometimes IW == IW2+1. Do a max update to | |||||
| // handle this case. | |||||
| IH2 = std::max(IH2, IH); | |||||
| IW2 = std::max(IW2, IW); | |||||
| ih2 = ih + ph * 2; | |||||
| iw2 = iw + pw * 2; | |||||
| } | } | ||||
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | ||||
| constexpr size_t src_expand = 4; | constexpr size_t src_expand = 4; | ||||
| @@ -57,8 +49,8 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| size_t OC = fm.ocpg; | size_t OC = fm.ocpg; | ||||
| size_t FH = fm.spatial[0]; | size_t FH = fm.spatial[0]; | ||||
| size_t FW = fm.spatial[1]; | size_t FW = fm.spatial[1]; | ||||
| size_t IH2, IW2, OH2, OW2; | |||||
| get_rectified_size(param, IH2, IW2, OH2, OW2); | |||||
| int IH2, IW2; | |||||
| get_rectified_size(param, IH2, IW2); | |||||
| if (group == 1) { | if (group == 1) { | ||||
| size_t src_size = | size_t src_size = | ||||
| batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | ||||
| @@ -76,16 +68,16 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | const ConvBiasImpl::NCBKernParam& kern_param, | ||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
| const CpuNDRange& workspace_ids) { | const CpuNDRange& workspace_ids) { | ||||
| size_t IH = kern_param.isz[0]; | |||||
| size_t IW = kern_param.isz[1]; | |||||
| size_t IC = kern_param.filter_meta.icpg; | |||||
| size_t PH = kern_param.filter_meta.padding[0]; | |||||
| size_t PW = kern_param.filter_meta.padding[1]; | |||||
| size_t GROUP = kern_param.filter_meta.group; | |||||
| int IH = kern_param.isz[0]; | |||||
| int IW = kern_param.isz[1]; | |||||
| int IC = kern_param.filter_meta.icpg; | |||||
| int PH = kern_param.filter_meta.padding[0]; | |||||
| int PW = kern_param.filter_meta.padding[1]; | |||||
| int GROUP = kern_param.filter_meta.group; | |||||
| size_t IH2, IW2, OH2, OW2; | |||||
| get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
| size_t padding_group_size = IH2 * IW2 * IC; | |||||
| int IH2, IW2; | |||||
| get_rectified_size(kern_param, IH2, IW2); | |||||
| int padding_group_size = IH2 * IW2 * IC; | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| //! Used for get the workspace offset | //! Used for get the workspace offset | ||||
| constexpr int pack_ic = 4; | constexpr int pack_ic = 4; | ||||
| @@ -100,16 +92,10 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | size_t group_id = ncb_index.ndrange_id[1]; | ||||
| size_t group_pack_size = 1; | size_t group_pack_size = 1; | ||||
| int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||||
| int nr_pad_w = PW * pack_ic * expend_element; | int nr_pad_w = PW * pack_ic * expend_element; | ||||
| int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element; | |||||
| int row_last_pad = ((int)IW2 - (int)IW - 2 * (int)PW) >= 0 | |||||
| ? nr_pad_w + over_pad | |||||
| : (IW2 - IW - PW) * pack_ic * expend_element; | |||||
| int col_last_pad = | |||||
| ((int)IH2 - (int)IH - 2 * (int)PH) >= 0 | |||||
| ? nr_pad_h | |||||
| : (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||||
| int nr_pad_h = PH * IW2 * pack_ic * expend_element; | |||||
| int row_last_pad = (IW2 - IW - PW) * pack_ic * expend_element; | |||||
| int col_last_pad = (IH2 - IH - PH) * IW2 * pack_ic * expend_element; | |||||
| const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>( | ||||
| batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); | ||||
| @@ -129,7 +115,7 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
| rep(ih_idx, IH) { | rep(ih_idx, IH) { | ||||
| std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | ||||
| sptr_base += nr_pad_w; | sptr_base += nr_pad_w; | ||||
| conv_bias::nchw44_pack_src(sptr, sptr_base, IW); | |||||
| nchw44_pack_src(sptr, sptr_base, IW); | |||||
| sptr_base += IW * pack_ic * expend_element; | sptr_base += IW * pack_ic * expend_element; | ||||
| sptr += IW * pack_ic; | sptr += IW * pack_ic; | ||||
| std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | ||||
| @@ -140,7 +126,8 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
| } | } | ||||
| } | } | ||||
| template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain> | |||||
| template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| typename DstType, int stride> | |||||
| static void do_conv_kern(WorkspaceBundle bundle, | static void do_conv_kern(WorkspaceBundle bundle, | ||||
| const ConvBiasImpl::NCBKernParam& kern_param, | const ConvBiasImpl::NCBKernParam& kern_param, | ||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
| @@ -153,12 +140,12 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||||
| size_t IC = kern_param.filter_meta.icpg; | size_t IC = kern_param.filter_meta.icpg; | ||||
| size_t OC = kern_param.filter_meta.ocpg; | size_t OC = kern_param.filter_meta.ocpg; | ||||
| size_t GROUP = kern_param.filter_meta.group; | size_t GROUP = kern_param.filter_meta.group; | ||||
| size_t IH2, IW2, OH2, OW2; | |||||
| get_rectified_size(kern_param, IH2, IW2, OH2, OW2); | |||||
| int IH2, IW2; | |||||
| get_rectified_size(kern_param, IH2, IW2); | |||||
| bool need_post_process = | bool need_post_process = | ||||
| kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
| //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) | ||||
| Op op = Op(1.0f, 4.0f); | |||||
| Op op(1.f, 4.f); | |||||
| if (need_post_process) { | if (need_post_process) { | ||||
| float scale_bias = | float scale_bias = | ||||
| kern_param.bias_type.param<dtype::QuantizedS32>().scale; | kern_param.bias_type.param<dtype::QuantizedS32>().scale; | ||||
| @@ -191,49 +178,43 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||||
| const int8_t* fptr = | const int8_t* fptr = | ||||
| kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC; | ||||
| void* dst = reinterpret_cast<void*>( | |||||
| reinterpret_cast<ptrdiff_t>( | |||||
| kern_param.dst<void>(batch_id, group_id)) + | |||||
| oc_idx * OH * OW); | |||||
| DstType* dst = reinterpret_cast<DstType*>( | |||||
| kern_param.dst<void>(batch_id, group_id, oc_idx)); | |||||
| const int32_t* bptr = | const int32_t* bptr = | ||||
| kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | ||||
| auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | ||||
| group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | ||||
| conv_bias::nchw44_pack_filter(fptr, packed_weight, | |||||
| oc_block / 4 * IC / 4 * FH * FW); | |||||
| #define KERN1_NCHW44_CONV(filter) \ | |||||
| conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw44< \ | |||||
| bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \ | |||||
| static_cast<int8_t*>(dst), oc_block, IC, \ | |||||
| IH2, IW2, OH, OW, op) | |||||
| DISPATCH_FILTER(filter, KERN1_NCHW44_CONV) | |||||
| #undef KERN1_NCHW44_CONV | |||||
| nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); | |||||
| conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>( | |||||
| sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst), | |||||
| oc_block, IC, IH2, IW2, OH, OW, op); | |||||
| } | } | ||||
| /* ===================== stride2 algo ===================== */ | |||||
| bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::usable( | |||||
| bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | ||||
| AlgoSelectionStrategy algo_selection_strategy) const { | AlgoSelectionStrategy algo_selection_strategy) const { | ||||
| MEGDNN_MARK_USED_VAR(algo_selection_strategy); | MEGDNN_MARK_USED_VAR(algo_selection_strategy); | ||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | |||||
| auto OC = fm.ocpg; | |||||
| auto IC = fm.icpg; | |||||
| bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
| const int fh = fm.spatial[0]; | |||||
| const int fw = fm.spatial[1]; | |||||
| const int oc = fm.ocpg; | |||||
| const int ic = fm.icpg; | |||||
| const bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
| ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | ||||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | ||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | ||||
| (fm.format == param::Convolution::Format::NCHW44) && | (fm.format == param::Convolution::Format::NCHW44) && | ||||
| (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||||
| (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | ||||
| fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
| FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||||
| fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && | |||||
| (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && | |||||
| (fh == 2 || fh == 3 || fh == 5 || fh == 7) && | |||||
| param.bias_mode != BiasMode::BIAS; | param.bias_mode != BiasMode::BIAS; | ||||
| return avaible; | return avaible; | ||||
| } | } | ||||
| bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||||
| bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( | |||||
| megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | ||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| // TODO: benchmark and fix | // TODO: benchmark and fix | ||||
| @@ -242,13 +223,13 @@ bool ConvBiasImpl::AlgoS8DirectStride2NCHW44::is_preferred( | |||||
| return false; | return false; | ||||
| } | } | ||||
| size_t ConvBiasImpl::AlgoS8DirectStride2NCHW44::get_workspace( | |||||
| size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| } | } | ||||
| SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
| ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||||
| ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | ||||
| auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
| size_t N = param.n; | size_t N = param.n; | ||||
| @@ -261,97 +242,129 @@ ConvBiasImpl::AlgoS8DirectStride2NCHW44::dispatch_kerns( | |||||
| WorkspaceBundle wbundle = get_bundle(param); | WorkspaceBundle wbundle = get_bundle(param); | ||||
| conv_fun do_conv_fun = nullptr; | conv_fun do_conv_fun = nullptr; | ||||
| int ow_remain = OW % 8; | int ow_remain = OW % 8; | ||||
| bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||||
| // NOTE: remain_w is not used to gen hash of midout for compatible with changing | // NOTE: remain_w is not used to gen hash of midout for compatible with changing | ||||
| // shape runtime | // shape runtime | ||||
| #define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op) \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ | |||||
| midout_iv(#filter #bias_mode #op##_hash)) { \ | |||||
| do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \ | |||||
| } \ | |||||
| #define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ | |||||
| midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ | |||||
| do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \ | |||||
| stride>; \ | |||||
| } \ | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| #define GET_OP_PARAM(filter, bias_mode, remain_w) \ | |||||
| switch (param.nonlineMode) { \ | |||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::RELU: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, remain_w, \ | |||||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| #define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \ | |||||
| if (need_post_process) { \ | |||||
| switch (param.nonlineMode) { \ | |||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
| remain_w, \ | |||||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::RELU: \ | |||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
| remain_w, \ | |||||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | |||||
| remain_w, \ | |||||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0, "no supported noline mode"); \ | |||||
| break; \ | |||||
| } \ | |||||
| } else { \ | |||||
| switch (param.nonlineMode) { \ | |||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
| DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ | |||||
| remain_w, NoneOp<dt_int32>) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert( \ | |||||
| 0, \ | |||||
| "only support IDENTITY mode when dst is not qint8"); \ | |||||
| break; \ | |||||
| } \ | |||||
| } | } | ||||
| #define GET_REMAIN_W_PARAM(filter, bias_mode) \ | |||||
| switch (ow_remain) { \ | |||||
| case 0: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 0); \ | |||||
| break; \ | |||||
| case 1: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 1); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 2); \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 3); \ | |||||
| break; \ | |||||
| case 4: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 4); \ | |||||
| break; \ | |||||
| case 5: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 5); \ | |||||
| break; \ | |||||
| case 6: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 6); \ | |||||
| break; \ | |||||
| case 7: \ | |||||
| GET_OP_PARAM(filter, bias_mode, 7); \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| #define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ | |||||
| switch (ow_remain) { \ | |||||
| case 0: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 0); \ | |||||
| break; \ | |||||
| case 1: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 1); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 2); \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 3); \ | |||||
| break; \ | |||||
| case 4: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 4); \ | |||||
| break; \ | |||||
| case 5: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 5); \ | |||||
| break; \ | |||||
| case 6: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 6); \ | |||||
| break; \ | |||||
| case 7: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 7); \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| } | } | ||||
| #define GET_BIAS_MODE_PARAM(filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| #define GET_BIAS_MODE_PARAM(stride, filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(stride, filter, \ | |||||
| BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| } | } | ||||
| #define DISPATCH_CONV_KERN() \ | |||||
| #define DISPATCH_CONV_KERN(stride) \ | |||||
| switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
| case 2: \ | case 2: \ | ||||
| GET_BIAS_MODE_PARAM(2) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 2) \ | |||||
| break; \ | break; \ | ||||
| case 3: \ | case 3: \ | ||||
| GET_BIAS_MODE_PARAM(3) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 3) \ | |||||
| break; \ | break; \ | ||||
| case 5: \ | case 5: \ | ||||
| GET_BIAS_MODE_PARAM(5) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 5) \ | |||||
| break; \ | break; \ | ||||
| case 7: \ | case 7: \ | ||||
| GET_BIAS_MODE_PARAM(7) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 7) \ | |||||
| break; \ | break; \ | ||||
| default: \ | default: \ | ||||
| megdnn_assert(0); \ | megdnn_assert(0); \ | ||||
| break; \ | break; \ | ||||
| } | } | ||||
| DISPATCH_CONV_KERN(); | |||||
| switch (param.filter_meta.stride[0]) { | |||||
| case 1: | |||||
| DISPATCH_CONV_KERN(1); | |||||
| break; | |||||
| case 2: | |||||
| DISPATCH_CONV_KERN(2); | |||||
| break; | |||||
| default: | |||||
| megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", | |||||
| param.filter_meta.stride[0]) | |||||
| .c_str()); | |||||
| break; | |||||
| } | |||||
| #undef DO_CONV_KERN_FUN | #undef DO_CONV_KERN_FUN | ||||
| #undef GET_REMAIN_W_PARAM | #undef GET_REMAIN_W_PARAM | ||||
| @@ -1,393 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_algo.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||||
| #include "src/arm_common/conv_bias/int8/strategy.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | |||||
| #include "midout.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| using conv_fun = std::function<void( | |||||
| WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_stride1) | |||||
| static void get_rectified_size( | |||||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, | |||||
| size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto SW = fm.stride[1]; | |||||
| auto OH = param.osz[0]; | |||||
| auto OW = param.osz[1]; | |||||
| auto FH = fm.spatial[0]; | |||||
| auto FW = fm.spatial[1]; | |||||
| OH2 = OH; | |||||
| OW2 = (OW + 7) & ~7; | |||||
| IH2 = SW * OH + FH - SW; | |||||
| IW2 = SW * OW2 + FW - SW; | |||||
| } | |||||
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| constexpr size_t src_expand = 4; | |||||
| auto&& fm = param.filter_meta; | |||||
| size_t group = fm.group; | |||||
| size_t batch = param.n; | |||||
| size_t IC = fm.icpg; | |||||
| size_t OC = fm.ocpg; | |||||
| size_t FH = fm.spatial[0]; | |||||
| size_t FW = fm.spatial[1]; | |||||
| size_t IH2, IW2, OH2, OW2; | |||||
| get_rectified_size(param, IH2, IW2, OH2, OW2); | |||||
| if (group == 1) { | |||||
| size_t src_size = | |||||
| batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||||
| size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||||
| return {nullptr, {src_size, weight_size}}; | |||||
| } else { | |||||
| size_t src_size = | |||||
| param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; | |||||
| size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); | |||||
| return {nullptr, {src_size, weight_size}}; | |||||
| } | |||||
| }; | |||||
//! Copy one (batch, group, 4-channel IC block) slice of the NCHW44 int8
//! src into workspace bundle slot 0 with zero padding applied, so the
//! conv kernels never have to test image borders.  Each int8 value is
//! expanded to 4 bytes (expend_element) by conv_bias::nchw44_pack_src.
//! \param bundle         workspace bundle; slot 0 receives the padded src
//! \param kern_param     per-run conv parameters (shapes, pointers)
//! \param ncb_index      dispatch index {batch, group, ic-block}
//! \param workspace_ids  index used to address the workspace slice; may
//!                       differ from ncb_index (see the grouped-conv path
//!                       of dispatch_kerns, which passes thread_id)
static void copy_padding_kern(WorkspaceBundle bundle,
                              const ConvBiasImpl::NCBKernParam& kern_param,
                              const ConvBiasImpl::NCBKernIndex& ncb_index,
                              const CpuNDRange& workspace_ids) {
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    //! IH2/IW2 are the rectified (padded) input sizes
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);
    //! constants used to compute the workspace offsets
    constexpr int pack_ic = 4;         //! NCHW44: 4 channels interleaved
    constexpr int expend_element = 4;  //! each int8 expanded to 4 bytes
    // TODO: block dim is better to get from arg
    size_t workspace_ic_block = 4;
    size_t workspace_batch_id = workspace_ids[0];
    size_t workspace_group_id = workspace_ids[1];
    size_t workspace_ic_id = workspace_ids[2];
    size_t workspace_ic = workspace_ic_id * workspace_ic_block;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t group_pack_size = 1;
    //! byte counts of the zero regions: whole padded rows (top/bottom),
    //! left/right padding columns, and extra right-side columns coming
    //! from the rectified width
    int nr_pad_h = PH * IW2 * pack_ic * expend_element;
    int nr_pad_w = PW * pack_ic * expend_element;
    //! NOTE(review): IW2/IW/PW are size_t, so IW2 - IW - 2 * PW wraps
    //! before std::max can clamp it if IW2 were ever < IW + 2 * PW --
    //! presumably get_rectified_size guarantees IW2 >= IW + 2 * PW;
    //! confirm.
    int over_pad = std::max(0_z, IW2 - IW - 2 * PW) * pack_ic * expend_element;
    //! copy to sptr_base to eliminate padding effect
    const int8_t* sptr = static_cast<const int8_t*>(kern_param.src<int8_t>(
            batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic));
    int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
                        (workspace_batch_id * GROUP * padding_group_size +
                         workspace_group_id * padding_group_size +
                         workspace_ic * IH2 * IW2) *
                                expend_element;
    size_t nr_ic = workspace_ic_block;
    //! grouped conv dispatches one task per group, which must copy all IC
    if (GROUP > 1) {
        nr_ic = IC;
    }
    rep_step(ic_idx, nr_ic, pack_ic) {
        //! top padding rows
        std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
        sptr_base += nr_pad_h;
        rep(ih_idx, IH) {
            //! left padding columns
            std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t));
            sptr_base += nr_pad_w;
            //! expand + copy one input row
            conv_bias::nchw44_pack_src(sptr, sptr_base, IW);
            sptr_base += IW * pack_ic * expend_element;
            sptr += IW * pack_ic;
            //! right padding columns (plus rectification columns)
            std::memset(sptr_base, 0, (nr_pad_w + over_pad) * sizeof(int8_t));
            sptr_base += nr_pad_w + over_pad;
        }
        //! bottom padding rows
        std::memset(sptr_base, 0, nr_pad_h * sizeof(int8_t));
        sptr_base += nr_pad_h;
    }
}
//! Compute one (batch, group, OC block) of the stride-1 NCHW44 int8
//! direct convolution.  Reads the zero-padded src prepared by
//! copy_padding_kern from bundle slot 0, packs this OC block's weights
//! into bundle slot 1, then dispatches the kernel matching the
//! compile-time filter size / bias mode / activation Op / ow remainder.
template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain>
static void do_conv_kern(WorkspaceBundle bundle,
                         const ConvBiasImpl::NCBKernParam& kern_param,
                         const ConvBiasImpl::NCBKernIndex& ncb_index,
                         const CpuNDRange& workspace_ids,
                         const CpuNDRange& ncb_range) {
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t FH = kern_param.filter_meta.spatial[0];
    size_t FW = kern_param.filter_meta.spatial[1];
    size_t IC = kern_param.filter_meta.icpg;
    size_t OC = kern_param.filter_meta.ocpg;
    size_t GROUP = kern_param.filter_meta.group;
    size_t IH2, IW2, OH2, OW2;
    get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
    bool need_post_process =
            kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
    //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f)
    Op op = Op(1.0f, 4.0f);
    if (need_post_process) {
        float scale_bias =
                kern_param.bias_type.param<dtype::QuantizedS32>().scale;
        float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
        op = Op(scale_bias, scale_dst);
    }
    size_t padding_group_size = IH2 * IW2 * IC;
    bundle.set(kern_param.workspace_ptr);
    constexpr size_t pack_c = 4;           //! NCHW44 channel pack size
    constexpr size_t src_expand_size = 4;  //! bytes per expanded src int8
    const size_t workspace_batch_id = workspace_ids[0];
    const size_t workspace_group_id = workspace_ids[1];
    const size_t batch_id = ncb_index.ndrange_id[0];
    const size_t group_id = ncb_index.ndrange_id[1];
    const size_t oc_id = ncb_index.ndrange_id[2];
    const size_t oc_block_num = ncb_range[2];
    //! split OC into oc_block_num blocks of whole 4-channel packs;
    //! the last block takes whatever channels remain
    size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
    size_t oc_block = nr_pack_per_step * pack_c;
    const size_t oc_idx = oc_id * oc_block;
    if (oc_id == (oc_block_num - 1)) {
        oc_block = OC - oc_id * nr_pack_per_step * pack_c;
    }
    megdnn_assert(oc_block % pack_c == 0,
                  "oc must be devisible by 4, but oc = %zu", oc_block);
    //! padded src slice written by copy_padding_kern (bundle slot 0)
    const int8_t* sptr =
            static_cast<int8_t*>(bundle.get(0)) +
            workspace_batch_id * GROUP * padding_group_size * src_expand_size +
            workspace_group_id * padding_group_size * src_expand_size;
    const int8_t* fptr =
            kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
    //! NOTE(review): this offset is in bytes and assumes 1-byte dst
    //! elements; for a QuantizedS32 dst it looks too small by 4x --
    //! confirm against the kernel's dst addressing.
    void* dst = reinterpret_cast<void*>(
            reinterpret_cast<ptrdiff_t>(
                    kern_param.dst<void>(batch_id, group_id)) +
            oc_idx * OH * OW);
    const int32_t* bptr =
            kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
    //! NOTE(review): weights of this OC block are re-packed on every
    //! invocation (so once per batch as well); different batches write
    //! identical bytes to the same slot -- looks redundant, confirm it
    //! is intentional.
    auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
                         group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;
    conv_bias::nchw44_pack_filter(fptr, packed_weight,
                                  oc_block / 4 * IC / 4 * FH * FW);
//! select the stride-1 kernel specialized for the filter size
#define KERN1_NCHW44_CONV(filter)                                         \
    conv_bias::conv_direct_stride1_##filter##x##filter##_int8_nchw44<     \
            bias_mode, Op, ow_remain>(sptr, packed_weight, bptr, nullptr, \
                                      static_cast<int8_t*>(dst), oc_block, IC, \
                                      IH2, IW2, OH, OW, op)
    DISPATCH_FILTER(filter, KERN1_NCHW44_CONV)
#undef KERN1_NCHW44_CONV
}
| /* ===================== stride1 algo ===================== */ | |||||
| bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::usable( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const { | |||||
| MEGDNN_MARK_USED_VAR(algo_selection_strategy); | |||||
| auto&& fm = param.filter_meta; | |||||
| auto FH = fm.spatial[0]; | |||||
| auto OC = fm.ocpg; | |||||
| auto IC = fm.icpg; | |||||
| bool avaible = //! src and filter are qint8, dst is qint8 or qint32 | |||||
| ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && | |||||
| (fm.format == param::Convolution::Format::NCHW44) && | |||||
| (OC % 4 == 0 && IC % 4 == 0 && OC >= 4) && !fm.should_flip && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
| fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && | |||||
| FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && | |||||
| param.bias_mode != BiasMode::BIAS; | |||||
| return avaible; | |||||
| } | |||||
| bool ConvBiasImpl::AlgoS8DirectStride1NCHW44::is_preferred( | |||||
| megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr, | |||||
| const NCBKernSizeParam& param) const { | |||||
| // TODO: benchmark and fix | |||||
| MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr); | |||||
| MEGDNN_MARK_USED_VAR(param); | |||||
| return false; | |||||
| } | |||||
| size_t ConvBiasImpl::AlgoS8DirectStride1NCHW44::get_workspace( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||||
| return get_bundle(param).total_size_in_bytes(); | |||||
| } | |||||
//! Build the NCB kernel list for this algo.  The macro ladder below
//! resolves the runtime (filter size, bias mode, ow remainder,
//! activation) tuple to one concrete do_conv_kern instantiation; the
//! tail then schedules copy_padding_kern plus the conv kernels.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1NCHW44::dispatch_kerns(
        fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
    auto fm = param.filter_meta;
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t OW = param.osz[1];
    size_t group = fm.group;
    size_t fh = fm.spatial[0];
    size_t fw = fm.spatial[1];
    WorkspaceBundle wbundle = get_bundle(param);
    conv_fun do_conv_fun = nullptr;
    //! kernels emit 8 output columns per step; ow_remain selects the
    //! tail specialization (remain_w template parameter)
    int ow_remain = OW % 8;
// NOTE: remain_w is not used to gen hash of midout for compatible with changing
// shape runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, remain_w, op)            \
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1,    \
                 midout_iv(#filter #bias_mode #op##_hash)) {         \
        do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w>; \
    }                                                                \
    MIDOUT_END();
//! innermost dispatch level: activation (nonlinearity) Op
#define GET_OP_PARAM(filter, bias_mode, remain_w)                    \
    switch (param.nonlineMode) {                                     \
        case param::ConvBias::NonlineMode::IDENTITY:                 \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        case param::ConvBias::NonlineMode::RELU:                     \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        case param::ConvBias::NonlineMode::H_SWISH:                  \
            DO_CONV_KERN_FUN(filter, bias_mode, remain_w,            \
                             HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
            break;                                                   \
        default:                                                     \
            megdnn_assert(0);                                        \
            break;                                                   \
    }
//! dispatch level: OW % 8 remainder (0..7)
#define GET_REMAIN_W_PARAM(filter, bias_mode) \
    switch (ow_remain) {                      \
        case 0:                               \
            GET_OP_PARAM(filter, bias_mode, 0); \
            break;                            \
        case 1:                               \
            GET_OP_PARAM(filter, bias_mode, 1); \
            break;                            \
        case 2:                               \
            GET_OP_PARAM(filter, bias_mode, 2); \
            break;                            \
        case 3:                               \
            GET_OP_PARAM(filter, bias_mode, 3); \
            break;                            \
        case 4:                               \
            GET_OP_PARAM(filter, bias_mode, 4); \
            break;                            \
        case 5:                               \
            GET_OP_PARAM(filter, bias_mode, 5); \
            break;                            \
        case 6:                               \
            GET_OP_PARAM(filter, bias_mode, 6); \
            break;                            \
        case 7:                               \
            GET_OP_PARAM(filter, bias_mode, 7); \
            break;                            \
        default:                              \
            megdnn_assert(0);                 \
    }
//! dispatch level: bias mode (full BIAS was rejected in usable())
#define GET_BIAS_MODE_PARAM(filter)                                  \
    switch (param.bias_mode) {                                       \
        case BiasMode::NO_BIAS:                                      \
            GET_REMAIN_W_PARAM(filter, BiasMode::NO_BIAS)            \
            break;                                                   \
        case BiasMode::BROADCAST_CHANNEL_BIAS:                       \
            GET_REMAIN_W_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
            break;                                                   \
        default:                                                     \
            megdnn_assert(0);                                        \
            break;                                                   \
    }
//! outermost dispatch level: filter size (matches usable())
#define DISPATCH_CONV_KERN()                \
    switch (param.filter_meta.spatial[0]) { \
        case 2:                             \
            GET_BIAS_MODE_PARAM(2)          \
            break;                          \
        case 3:                             \
            GET_BIAS_MODE_PARAM(3)          \
            break;                          \
        case 5:                             \
            GET_BIAS_MODE_PARAM(5)          \
            break;                          \
        case 7:                             \
            GET_BIAS_MODE_PARAM(7)          \
            break;                          \
        default:                            \
            megdnn_assert(0);               \
            break;                          \
    }
    DISPATCH_CONV_KERN();
#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN
    megdnn_assert(do_conv_fun);
    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
    WorkspaceBundle bundle = wbundle;
    constexpr size_t pack_oc = 4;
    size_t oc_step = pack_oc;
    //! the 2x2 kernel has an oc=8 variant, so use larger OC blocks there
    if (fh == 2 && fw == 2 && OC >= 8) {
        oc_step = 8;
    }
    if (group == 1) {
        //! dense conv: padding and conv run as two separately
        //! parallelised passes over {batch, group, block}
        CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              ncb_index.ndrange_id);
        };
        constexpr size_t pack_ic = 4;
        ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}});
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
                        ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    } else {
        //! grouped conv: one task per (batch, group) pads its own input
        //! then convolves, addressing the workspace by thread_id so
        //! concurrent tasks do not collide
        CpuNDRange ncb_range = {N, group, 1};
        auto do_conv = [bundle, do_conv_fun, ncb_range](
                               const NCBKernParam& kern_param,
                               const NCBKernIndex& ncb_index) {
            copy_padding_kern(bundle, kern_param, ncb_index,
                              {0, ncb_index.thread_id, 0});
            do_conv_fun(bundle, kern_param, ncb_index,
                        {0, ncb_index.thread_id, 0}, ncb_range);
        };
        ret_kerns.push_back({do_conv, ncb_range});
    }
    return ret_kerns;
}
| // vim: syntax=cpp.doxygen | |||||
| @@ -1,791 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_stride1_nchw44_kern.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| namespace { | |||||
| /** | |||||
 | dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc> | |||||
| example: (format like weight<oc, ic>) | |||||
| packed weight | |||||
| low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
| --------------------------------------------------------------------- | |||||
| high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
| dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||||
| **/ | |||||
| // TODO: can try oh = 2 impl, oc = 8 impl | |||||
//! 3x3 stride-1 kernel: one 4-OC pack x 8 OW outputs.
//! c[i] accumulates the 4 OCs of output column i (8 accumulators);
//! src[] holds the 8 + 3 - 1 = 10 input columns this output row needs,
//! each column being 4 IC x 4 expanded bytes = 16B; weight[] holds the
//! 3 horizontal taps of one filter row, combined via the vdotq_s32_h
//! dot trick described at the top of this file.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;        //! NCHW44 channel pack
    constexpr int loop_ic_step = 4;   //! 4 IC consumed per iteration
    constexpr int ld_weight_ic4 = 16; //! bytes per packed filter tap
    constexpr int pack_iw_len = 4;    //! src expansion factor
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[3];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    //! seed accumulators with bias (or zero, per bias_mode)
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load the 10 input columns of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            //! load the 3 horizontal taps of this filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            //! output columns 0-1
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            //! output columns 2-3
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            //! output columns 4-5
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            //! output columns 6-7
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    //! apply Op (requant/activation) and store, honoring remain_w tail
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
//! 2x2 stride-1 kernel: two 4-OC packs (8 OCs) x 8 OW outputs.
//! c[p][i] accumulates output column i of OC pack p; the second pack's
//! weights start ld_weight_oc4 bytes after the first.  src[] holds the
//! 8 + 2 - 1 = 9 input columns needed (16B each: 4 IC x 4 expanded).
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;        //! NCHW44 channel pack
    constexpr int oc_step = 4;        //! OCs per weight pack
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16; //! bytes per packed filter tap
    constexpr int pack_iw_len = 4;    //! src expansion factor
    const int ic_stride = ih * iw * pack_iw_len;
    //! byte offset from OC pack 0 to OC pack 1 in the packed weights
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int32x4_t c[2][8];
    int8x16_t weight[2][2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    //! seed accumulators with bias (or zero, per bias_mode)
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load the 9 input columns of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            //! load both taps of this row for both OC packs
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            //! output columns 0-1, both OC packs
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
            //! output columns 2-3, both OC packs
            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
            //! output columns 4-5, both OC packs
            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
            //! output columns 6-7, both OC packs
            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    //! apply Op and store both OC packs (ld_dst_oc separates them)
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
//! 2x2 stride-1 kernel: one 4-OC pack x 8 OW outputs (fallback when
//! fewer than 8 OCs remain for the oc8 variant).  src[] holds the
//! 8 + 2 - 1 = 9 input columns needed (16B each: 4 IC x 4 expanded).
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;        //! NCHW44 channel pack
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16; //! bytes per packed filter tap
    constexpr int pack_iw_len = 4;    //! src expansion factor
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[2];
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    //! seed accumulators with bias (or zero, per bias_mode)
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load the 9 input columns of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            //! load both horizontal taps of this filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            //! output columns 0-1
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            //! output columns 2-3
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            //! output columns 4-5
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            //! output columns 6-7
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    //! apply Op (requant/activation) and store, honoring remain_w tail
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
//! 5x5 stride-1 kernel: one 4-OC pack x 8 OW outputs.
//! Needs 8 + 5 - 1 = 12 input columns per filter row, but the register
//! array holds only 10; columns 10 and 11 are loaded late into slots
//! src[0]/src[1], whose original values are no longer needed by then.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_5x5s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;        //! NCHW44 channel pack
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16; //! bytes per packed filter tap
    constexpr int pack_iw_len = 4;    //! src expansion factor
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];
    int8x16_t weight[5];
    int8x16_t src[8 + 2];
    int16x8_t temp_c[2];
    //! seed accumulators with bias (or zero, per bias_mode)
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            //! load the first 10 input columns of this filter row
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            //! load the 5 horizontal taps of this filter row
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            //! output columns 0-1
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]);
            //! output columns 2-3
            c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]);
            //! output columns 4-5
            c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]);
            //! reuse slots 0/1 for input columns 10/11
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            //! output columns 6-7
            c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    //! apply Op (requant/activation) and store, honoring remain_w tail
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size> | |||||
| static void ker_neon_dirctconv_7x7s1_oc4_ow8(const int8_t* src_ptr, | |||||
| const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, | |||||
| int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, const Op& op) { | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int ic_step = 4; | |||||
| constexpr int loop_ic_step = 4; | |||||
| constexpr int ld_weight_ic4 = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| int32x4_t c[2 * 4]; | |||||
| int8x16_t weight[7]; | |||||
| int8x16_t src[8 + 2]; | |||||
| int16x8_t temp_c[2]; | |||||
| init_oc4_ow8<bias_mode>(c, bias_ptr); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
| const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
| fh_idx * iw * ic_step * pack_iw_len; | |||||
| src[0] = vld1q_s8(src_ic_0_3); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); | |||||
| src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); | |||||
| // oc == 0 | |||||
| const int8_t* read_weight_ptr = | |||||
| weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
| weight[0] = vld1q_s8(read_weight_ptr); | |||||
| weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
| weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
| weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
| weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
| weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||||
| weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||||
| c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[0], src[1], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[1], src[2], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[2], src[3], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[3], src[4], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[4], src[5], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[5], src[6], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[6], src[7], c[1], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[0], src[2], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[0], src[3], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[1], src[3], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[1], src[4], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[2], src[4], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[2], src[5], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[3], src[5], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[3], src[6], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[4], src[6], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[4], src[7], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[5], src[7], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[5], src[8], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[6], src[8], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[6], src[9], c[3], temp_c[1]); | |||||
| src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); | |||||
| c[4] = vdotq_s32_h(weight[0], src[4], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[0], src[5], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[1], src[5], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[1], src[6], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[2], src[6], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[2], src[7], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[3], src[7], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[3], src[8], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[4], src[8], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[4], src[9], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[5], src[9], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[5], src[0], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[6], src[0], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[6], src[1], c[5], temp_c[1]); | |||||
| src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); | |||||
| c[6] = vdotq_s32_h(weight[0], src[6], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[0], src[7], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[1], src[7], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[1], src[8], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[2], src[8], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[2], src[9], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[3], src[9], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[3], src[0], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[4], src[0], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[4], src[1], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[5], src[1], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[5], src[2], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[6], src[2], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[6], src[3], c[7], temp_c[1]); | |||||
| } | |||||
| weight_ptr += fh * fw * ld_weight_ic4; | |||||
| } | |||||
| store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr); | |||||
| } | |||||
| } // namespace | |||||
| /** | |||||
| origin weight shape <oc/4, ic/4, fh, fw, 4, 4> | |||||
| packed weight shape <oc/4, ic/4, fh, fw, 16> | |||||
| example: (format like weight<oc, ic>) | |||||
| origin | |||||
| <0, 0> <1, 0> <2, 0> <3, 0> | |||||
| <0, 1> <1, 1> <2, 1> <3, 1> | |||||
| <0, 2> <1, 2> <2, 2> <3, 2> | |||||
| <0, 3> <1, 3> <2, 3> <3, 3> | |||||
| packed | |||||
| low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
| --------------------------------------------------------------------- | |||||
| high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
| **/ | |||||
| void conv_bias::nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { | |||||
| static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, | |||||
| 12, 8, 5, 1, 14, 10, 7, 3}; | |||||
| constexpr int simd_len = 16; | |||||
| uint8x16_t weight_idx = vld1q_u8(weight_idx_buffer); | |||||
| for (int i = 0; i < length; i++) { | |||||
| int8x16_t result = vldq_tbl_s8(src + i * simd_len, weight_idx); | |||||
| vst1q_s8(dst + i * simd_len, result); | |||||
| } | |||||
| } | |||||
| /** | |||||
| origin src shape <n, ic/4, h, w, 4> | |||||
| packed src shape <n, ic/4, h, w, 16> | |||||
| example: (format like <ic>) | |||||
| origin | |||||
| <0> <0> <0> <0> | |||||
| packed | |||||
| low 64 bit <0> <1> <2> <3> | <0> <1> <2> <3> | |||||
| --------------------------------------------------------------------- | |||||
| high 64 bit <3> <2> <1> <0> | <3> <2> <1> <0> | |||||
| **/ | |||||
| void conv_bias::nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { | |||||
| static const uint8_t src_idx_buffer[16] = {0, 1, 2, 3, 0, 1, 2, 3, | |||||
| 3, 2, 1, 0, 3, 2, 1, 0}; | |||||
| constexpr int pack_ic = 4; | |||||
| constexpr int simd_len = 16; | |||||
| uint8x16_t src_idx = vld1q_u8(src_idx_buffer); | |||||
| for (int i = 0; i < length; i++) { | |||||
| int8x16_t result = vld_dup_tbl_s32(src + i * pack_ic, src_idx); | |||||
| vst1q_s8(dst + i * simd_len, result); | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride1_2x2_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 2; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t big_oc_step = 8; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
| const size_t oc_remain = oc - oc_end; | |||||
| const int ld_oc = oh * ow * ic_step; | |||||
| for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| const size_t oc_idx = oc_end; | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride1_3x3_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 3; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_3x3s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride1_5x5_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 5; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_5x5s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride1_7x7_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 7; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * iw + ow_end) * ic_step * pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_7x7s1_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
// Explicit instantiations of every (filter size, bias mode, remain_w, Op)
// combination the runtime dispatcher can select, so their definitions are
// emitted in this translation unit.
#define INSTANTIATION(stride, i, bias, remain_w, Op)                          \
    template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44<  \
            bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
                                int32_t*, int8_t*, const size_t, const size_t,\
                                const size_t, const size_t, const size_t,     \
                                const size_t, const Op&);
// One instantiation per supported post-op (identity/relu/hswish quantized).
#define FOR_OP(stride, i, bias, remain_w)                 \
    INSTANTIATION(stride, i, bias, remain_w,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(stride, i, bias, remain_w,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(stride, i, bias, remain_w,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// remain_w covers every possible ow % 8 tail width (0 = full 8-wide tile).
#define FOR_REMAIN(stride, i, bias) \
    FOR_OP(stride, i, bias, 0)      \
    FOR_OP(stride, i, bias, 1)      \
    FOR_OP(stride, i, bias, 2)      \
    FOR_OP(stride, i, bias, 3)      \
    FOR_OP(stride, i, bias, 4)      \
    FOR_OP(stride, i, bias, 5)      \
    FOR_OP(stride, i, bias, 6)      \
    FOR_OP(stride, i, bias, 7)
#define FOR_BIAS(stride, i)                          \
    FOR_REMAIN(stride, i, BiasMode::NO_BIAS)         \
    FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
// Supported square filter sizes: 2, 3, 5, 7.
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(stride1)
// NOTE(review): FOR_STRIDE, FOR_IC and FOR_NONLINEAR are not defined in the
// visible part of this file — these #undefs look like copy-paste leftovers
// (#undef of an undefined macro is harmless); confirm before removing.
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
| @@ -1,793 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw44_kern.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| namespace { | |||||
| /** | |||||
dot like impl. dot 4 ic to 1 oc, accumulate to c <ow, oc>
| example: (format like weight<oc, ic>) | |||||
| packed weight | |||||
| low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
| --------------------------------------------------------------------- | |||||
| high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
| dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||||
| **/ | |||||
| // TODO: can try oh = 2 impl, oc = 8 impl | |||||
/**
 * 3x3 stride-2 direct-conv kernel (NCHW44 int8): computes one output row of
 * 8 columns for a single group of 4 output channels, accumulating over all
 * input channels. remain_w selects how many of the 8 columns the final
 * store writes (0 = all 8).
 * NOTE(review): vdotq_s32_h appears to fold an int8 dot product into the
 * int32x4 accumulator using temp_c as 16-bit scratch — confirm against
 * intrinsic_helper.h.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_3x3s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;        // 4 input channels packed per pixel
    constexpr int loop_ic_step = 4;   // ic consumed 4 channels at a time
    constexpr int ld_weight_ic4 = 16; // bytes per packed <4 oc, 4 ic> tap
    constexpr int pack_iw_len = 4;    // nchw44_pack_src widens each pixel 4x
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];     // 8 output columns, one int32x4 (4 oc) each
    int8x16_t weight[3];    // one filter row (3 taps)
    int8x16_t src[8 + 2];   // sliding window of packed input pixels
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Each 16-byte load is one packed input pixel.
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            // Stride 2: output column j reads input pixels 2j .. 2j+2.
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            // Reuse src[] slots: pixels 10..12 overwrite slots 0..2 once the
            // earlier columns no longer need the old values.
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
            src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
        }
        // Advance to this oc group's weights for the next ic group.
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write the first `remain_w` (or all 8) columns.
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
/**
 * 2x2 stride-2 direct-conv kernel (NCHW44 int8): computes one output row of
 * 8 columns for TWO groups of 4 output channels at once (c[0][*] / c[1][*]),
 * accumulating over all input channels. ld_dst_oc is the byte distance in
 * dst between the two oc groups. remain_w selects how many of the 8 columns
 * the final store writes.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed <4 oc, 4 ic> tap
    constexpr int pack_iw_len = 4;     // nchw44_pack_src widens each pixel 4x
    const int ic_stride = ih * iw * pack_iw_len;
    // Distance between consecutive 4-oc groups in the packed filter.
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    int32x4_t c[2][8];        // [oc group][output column]
    int8x16_t weight[2][2];   // [oc group][filter tap]
    int8x16_t src[8 + 1];     // sliding window of packed input pixels
    int16x8_t temp_c[4];
    init_oc8_ow8<bias_mode>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Each 16-byte load is one packed input pixel.
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // Stride 2: output column j reads input pixels 2j and 2j+1.
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
            // Reuse src[] slots: pixels 9..11 overwrite slots 0..2.
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        // Advance to the next ic group of the first oc group's weights.
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op; writes both oc groups, ld_dst_oc apart.
    store_oc8_ow8_remain_static<remain_w>(c, op, dst_ptr, ld_dst_oc);
}
/**
 * 2x2 stride-2 direct-conv kernel (NCHW44 int8): computes one output row of
 * 8 columns for a single group of 4 output channels, accumulating over all
 * input channels. remain_w selects how many of the 8 columns the final
 * store writes (0 = all 8).
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes per packed <4 oc, 4 ic> tap
    constexpr int pack_iw_len = 4;     // nchw44_pack_src widens each pixel 4x
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];     // 8 output columns, one int32x4 (4 oc) each
    int8x16_t weight[2];    // one filter row (2 taps)
    int8x16_t src[8 + 1];   // sliding window of packed input pixels
    int16x8_t temp_c[2];
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Each 16-byte load is one packed input pixel.
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            // Stride 2: output column j reads input pixels 2j and 2j+1.
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            // Reuse src[] slots: pixels 9..11 overwrite slots 0..2.
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[2], c[5], temp_c[1]);
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[6] = vdotq_s32_h(weight[0], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[6], c[7], temp_c[1]);
        }
        // Advance to this oc group's weights for the next ic group.
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // Apply post-op and write the first `remain_w` (or all 8) columns.
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size> | |||||
| static void ker_neon_dirctconv_5x5s2_oc4_ow8(const int8_t* src_ptr, | |||||
| const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, | |||||
| int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, const Op& op) { | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int ic_step = 4; | |||||
| constexpr int loop_ic_step = 4; | |||||
| constexpr int ld_weight_ic4 = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| int32x4_t c[2 * 4]; | |||||
| int8x16_t weight[5]; | |||||
| int8x16_t src[8 + 2]; | |||||
| int16x8_t temp_c[2]; | |||||
| init_oc4_ow8<bias_mode>(c, bias_ptr); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
| const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
| fh_idx * iw * ic_step * pack_iw_len; | |||||
| src[0] = vld1q_s8(src_ic_0_3); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); | |||||
| src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); | |||||
| // oc == 0 | |||||
| const int8_t* read_weight_ptr = | |||||
| weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
| weight[0] = vld1q_s8(read_weight_ptr); | |||||
| weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
| weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
| weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
| weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
| c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]); | |||||
| c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]); | |||||
| c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]); | |||||
| src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
| c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]); | |||||
| c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]); | |||||
| c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 12 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 14 * 16)); | |||||
| c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]); | |||||
| c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]); | |||||
| c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 15 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 16 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 17 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 18 * 16)); | |||||
| c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]); | |||||
| c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]); | |||||
| c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]); | |||||
| } | |||||
| weight_ptr += fh * fw * ld_weight_ic4; | |||||
| } | |||||
| store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr); | |||||
| } | |||||
// Direct-convolution micro kernel: one NCHW44 pack of 4 output channels,
// 8 consecutive output-width positions, 7x7 filter.  The source index
// pattern (accumulator c[i] starts at packed column 2*i) implies a
// stride-2 convolution; remain_w selects how many of the 8 columns are
// actually written back by store_oc4_ow8_remain_static.
//
// src_ptr    : packed input, 4 interleaved input channels (pack_iw_len = 4).
// weight_ptr : weights for this oc pack, laid out fh*fw taps of 16 bytes each.
// bias_ptr   : per-channel int32 bias, consumed according to bias_mode.
// dst_ptr    : int8 output for the 8 columns produced here.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size>
static void ker_neon_dirctconv_7x7s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             int8_t* dst_ptr, int ic, int ih,
                                             int iw, const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;         // input channels per packed group
    constexpr int loop_ic_step = 4;
    constexpr int ld_weight_ic4 = 16;  // bytes of weight per filter tap (4 ic x 4 oc)
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    int32x4_t c[2 * 4];    // 8 accumulators, one int32x4 per output column
    int8x16_t weight[7];   // one filter row (fw == 7) kept in registers
    int8x16_t src[8 + 2];  // ring buffer of packed source columns
    int16x8_t temp_c[2];   // scratch for the widening multiply in vdotq_s32_h
    init_oc4_ow8<bias_mode>(c, bias_ptr);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            // Preload packed columns 0..9 of this input row.
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
            weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
            weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
            weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
            weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
            // Output columns 0 and 1: consume source columns 0..8.
            c[0] = vdotq_s32_h(weight[0], src[0], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[0], src[2], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[1], src[1], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[1], src[3], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[2], src[2], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[2], src[4], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[3], src[3], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[3], src[5], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[4], src[4], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[4], src[6], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[5], src[5], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[5], src[7], c[1], temp_c[1]);
            c[0] = vdotq_s32_h(weight[6], src[6], c[0], temp_c[0]);
            c[1] = vdotq_s32_h(weight[6], src[8], c[1], temp_c[1]);
            // Recycle ring-buffer slots 0..2 with source columns 10..12.
            src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
            // Output columns 2 and 3: source columns 4..12.
            c[2] = vdotq_s32_h(weight[0], src[4], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[0], src[6], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[1], src[5], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[1], src[7], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[2], src[6], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[2], src[8], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[3], src[7], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[3], src[9], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[4], src[8], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[4], src[0], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[5], src[9], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[5], src[1], c[3], temp_c[1]);
            c[2] = vdotq_s32_h(weight[6], src[0], c[2], temp_c[0]);
            c[3] = vdotq_s32_h(weight[6], src[2], c[3], temp_c[1]);
            // Recycle slots 3..6 with source columns 13..16.
            src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
            // Output columns 4 and 5: source columns 8..16.
            c[4] = vdotq_s32_h(weight[0], src[8], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[0], src[0], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[1], src[9], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[1], src[1], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[2], src[0], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[2], src[2], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[3], src[1], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[3], src[3], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[4], src[2], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[4], src[4], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[5], src[3], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[5], src[5], c[5], temp_c[1]);
            c[4] = vdotq_s32_h(weight[6], src[4], c[4], temp_c[0]);
            c[5] = vdotq_s32_h(weight[6], src[6], c[5], temp_c[1]);
            // Recycle slots 7..9 and 0 with source columns 17..20.
            src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
            src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
            src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
            // Output columns 6 and 7: source columns 12..20.
            c[6] = vdotq_s32_h(weight[0], src[2], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[0], src[4], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[1], src[3], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[1], src[5], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[2], src[4], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[2], src[6], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[3], src[5], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[3], src[7], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[4], src[6], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[4], src[8], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[5], src[7], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[5], src[9], c[7], temp_c[1]);
            c[6] = vdotq_s32_h(weight[6], src[8], c[6], temp_c[0]);
            c[7] = vdotq_s32_h(weight[6], src[0], c[7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    store_oc4_ow8_remain_static<remain_w, Op>(c, op, dst_ptr);
}
| } // namespace | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride2_2x2_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 2; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t big_oc_step = 8; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr size_t stride_h = 2; | |||||
| constexpr size_t stride_w = 2; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t out_img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
| const size_t oc_remain = oc - oc_end; | |||||
| const int ld_oc = oh * ow * ic_step; | |||||
| for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * out_img_stride + | |||||
| (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * out_img_stride + | |||||
| (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, ld_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| const size_t oc_idx = oc_end; | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * out_img_stride + | |||||
| (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * out_img_stride + | |||||
| (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride2_3x3_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 3; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr size_t stride_h = 2; | |||||
| constexpr size_t stride_w = 2; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_3x3s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride2_5x5_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 5; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr size_t stride_h = 2; | |||||
| constexpr size_t stride_w = 2; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_5x5s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int remain_w> | |||||
| void conv_bias::conv_direct_stride2_7x7_int8_nchw44( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const size_t oc, const size_t ic, | |||||
| const size_t ih, const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr size_t filter_size = 7; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = filter_size; | |||||
| constexpr size_t ic_step = 4; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr size_t stride_h = 2; | |||||
| constexpr size_t stride_w = 2; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, 0, filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * | |||||
| pack_iw_len; | |||||
| const size_t dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| ker_neon_dirctconv_7x7s2_oc4_ow8<bias_mode, Op, remain_w, | |||||
| filter_size>( | |||||
| src + src_offset, filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
// Explicit template instantiation ladder.  The macro nesting expands to
// every combination of {op kind} x {remain_w 0..7} x {bias mode} x
// {filter size 2,3,5,7} for the stride-2 drivers above, so their symbols
// exist in this translation unit.  (Comments cannot be placed inside the
// backslash-continued #defines: line splicing happens before comment
// removal, so a // there would swallow the continuation.)

// Instantiate one driver for a concrete (stride, filter, bias, remain_w, Op).
#define INSTANTIATION(stride, i, bias, remain_w, Op)                         \
    template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw44< \
            bias, Op, remain_w>(const int8_t*, const int8_t*, const int32_t*, \
                                int32_t*, int8_t*, const size_t, const size_t, \
                                const size_t, const size_t, const size_t,    \
                                const size_t, const Op&);

// Expand over the supported post-ops (quantized typecvt / relu / hswish).
#define FOR_OP(stride, i, bias, remain_w)            \
    INSTANTIATION(stride, i, bias, remain_w,         \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(stride, i, bias, remain_w,         \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(stride, i, bias, remain_w,         \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

// Expand over every possible right-edge remainder width (0..7).
#define FOR_REMAIN(stride, i, bias) \
    FOR_OP(stride, i, bias, 0)      \
    FOR_OP(stride, i, bias, 1)      \
    FOR_OP(stride, i, bias, 2)      \
    FOR_OP(stride, i, bias, 3)      \
    FOR_OP(stride, i, bias, 4)      \
    FOR_OP(stride, i, bias, 5)      \
    FOR_OP(stride, i, bias, 6)      \
    FOR_OP(stride, i, bias, 7)

// Expand over the supported bias modes.
#define FOR_BIAS(stride, i)                      \
    FOR_REMAIN(stride, i, BiasMode::NO_BIAS)     \
    FOR_REMAIN(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

// Expand over the supported filter sizes.
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)

FOR_FILTER(stride2)

// Tear the helper macros back down.  FOR_STRIDE / FOR_IC / FOR_NONLINEAR are
// not defined in this chunk; #undef of an undefined name is a no-op, so the
// extra lines are harmless (presumably leftovers from a sibling file — TODO
// confirm and prune).
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
| @@ -46,11 +46,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | AlgoQU8DirectStride1 qu8_direct_stride1_small_group{false}; | ||||
| AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | AlgoS8DirectStride2 s8_direct_stride2_large_group{true}; | ||||
| AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | AlgoS8DirectStride2 s8_direct_stride2_small_group{false}; | ||||
| AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44; | |||||
| AlgoS8DirectNCHW44 s8_direct_nchw44; | |||||
| AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44; | ||||
| AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | AlgoS8DirectStride1 s8_direct_stride1_large_group{true}; | ||||
| AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | AlgoS8DirectStride1 s8_direct_stride1_small_group{false}; | ||||
| AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44; | |||||
| AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; | ||||
| AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | ||||
| @@ -114,11 +113,10 @@ public: | |||||
| direct_algos.emplace_back(&qu8_direct_stride1_small_group); | direct_algos.emplace_back(&qu8_direct_stride1_small_group); | ||||
| direct_algos.emplace_back(&s8_direct_stride2_large_group); | direct_algos.emplace_back(&s8_direct_stride2_large_group); | ||||
| direct_algos.emplace_back(&s8_direct_stride2_small_group); | direct_algos.emplace_back(&s8_direct_stride2_small_group); | ||||
| direct_algos.emplace_back(&s8_direct_stride2_nchw44); | |||||
| direct_algos.emplace_back(&s8_direct_nchw44); | |||||
| direct_algos.emplace_back(&s8_direct_nchw_nchw44); | direct_algos.emplace_back(&s8_direct_nchw_nchw44); | ||||
| direct_algos.emplace_back(&s8_direct_stride1_large_group); | direct_algos.emplace_back(&s8_direct_stride1_large_group); | ||||
| direct_algos.emplace_back(&s8_direct_stride1_small_group); | direct_algos.emplace_back(&s8_direct_stride1_small_group); | ||||
| direct_algos.emplace_back(&s8_direct_stride1_nchw44); | |||||
| direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); | ||||
| direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); | ||||
| @@ -37,9 +37,8 @@ protected: | |||||
| private: | private: | ||||
| class AlgoS8DirectStride1; | class AlgoS8DirectStride1; | ||||
| class AlgoS8DirectStride1NCHW44; | |||||
| class AlgoS8DirectStride2; | class AlgoS8DirectStride2; | ||||
| class AlgoS8DirectStride2NCHW44; | |||||
| class AlgoS8DirectNCHW44; | |||||
| class AlgoS8DirectNCHWNCHW44; | class AlgoS8DirectNCHWNCHW44; | ||||
| class AlgoQU8DirectStride1; | class AlgoQU8DirectStride1; | ||||
| class AlgoQU8DirectStride2; | class AlgoQU8DirectStride2; | ||||
| @@ -27,6 +27,8 @@ struct NoneOp; | |||||
| #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | #define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ | ||||
| template <> \ | template <> \ | ||||
| struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | struct NoneOp<_ctype> : NoneOpBase<_ctype> { \ | ||||
| NoneOp(){}; \ | |||||
| NoneOp(float, float){}; \ | |||||
| using NoneOpBase::NoneOpBase; \ | using NoneOpBase::NoneOpBase; \ | ||||
| using NoneOpBase::operator(); \ | using NoneOpBase::operator(); \ | ||||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | constexpr static size_t SIMD_WIDTH = _simd_width; \ | ||||
| @@ -226,7 +226,15 @@ static void benchmark_convbias(Handle* handle, std::string int_name, | |||||
| run(1, 3, 32, 224, 224, 5, 1, true); | run(1, 3, 32, 224, 224, 5, 1, true); | ||||
| run(1, 3, 64, 224, 224, 7, 1, true); | run(1, 3, 64, 224, 224, 7, 1, true); | ||||
| for (size_t stride : {1, 2}) { | |||||
| run(1, 64, 128, 56, 56, 3, 2, false); | |||||
| run(1, 128, 256, 28, 28, 3, 2, false); | |||||
| run(1, 256, 512, 14, 14, 3, 2, false); | |||||
| run(1, 128, 128, 28, 28, 3, 1, false); | |||||
| run(1, 256, 256, 14, 14, 3, 1, false); | |||||
| run(1, 512, 512, 7, 7, 3, 1, false); | |||||
| for (size_t stride : {1}) { | |||||
| printf("stride %zu\n", stride); | printf("stride %zu\n", stride); | ||||
| for (size_t filter_size : {2, 3, 5, 7}) { | for (size_t filter_size : {2, 3, 5, 7}) { | ||||
| for (size_t img_size : {32}) { | for (size_t img_size : {32}) { | ||||
| @@ -527,12 +527,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_SMALL_GROUP) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) { | ||||
| checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false), | ||||
| handle(), "S8_NCHW44_DIRECT_STRD1"); | |||||
| handle(), "S8_NCHW44_DIRECT"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) { | |||||
| checker_conv_bias_qint8x8x32( | |||||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, true), | |||||
| handle(), "S8_NCHW44_DIRECT"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) { | |||||
| checker_conv_bias_qint8x8x32( | |||||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, true), | |||||
| handle(), "S8_NCHW44_DIRECT"); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) { | ||||
| checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | ||||
| handle(), "S8_NCHW44_DIRECT_STRD2"); | |||||
| handle(), "S8_NCHW44_DIRECT"); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) { | ||||
| checker_conv_bias_qint8x8x8( | checker_conv_bias_qint8x8x8( | ||||
| @@ -1085,7 +1095,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) { | |||||
| dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| @@ -1096,17 +1105,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
| param::MatrixMul::Format format, float eps) { | param::MatrixMul::Format format, float eps) { | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| for (uint32_t m : out_size) { | for (uint32_t m : out_size) { | ||||
| checker.set_extra_opr_impl(std::bind( | |||||
| winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
| arg.param, handle, format)); | |||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | |||||
| checker.set_extra_opr_impl(std::bind( | |||||
| winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
| arg.param, handle, format)); | |||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -1118,7 +1127,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
| ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str())); | ||||
| std::vector<TestArg> quantized_args = get_int8_nchw44_args (3,4); | |||||
| std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4); | |||||
| UniformIntRNG int_rng{-50, 50}; | UniformIntRNG int_rng{-50, 50}; | ||||
| checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | ||||
| run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(2.5f), | ||||
| @@ -1126,8 +1135,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
| dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
| @@ -1137,17 +1146,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||||
| param::MatrixMul::Format format, float eps) { | param::MatrixMul::Format format, float eps) { | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| for (uint32_t m : out_size) { | for (uint32_t m : out_size) { | ||||
| checker.set_extra_opr_impl(std::bind( | |||||
| winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
| arg.param, handle, format)); | |||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | |||||
| checker.set_extra_opr_impl(std::bind( | |||||
| winograd_algo_extra_impl, std::placeholders::_1, m, | |||||
| arg.param, handle, format)); | |||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -1168,7 +1177,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPM | |||||
| dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | dtype::QuantizedS8(60.25f), param::MatrixMul::Format::MK8, 1e-3); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
| @@ -1196,21 +1206,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| const char* matmul_name = "AARCH64_F32_MK4_4x16"; | const char* matmul_name = "AARCH64_F32_MK4_4x16"; | ||||
| #else | #else | ||||
| const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
| const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
| #endif | #endif | ||||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
| ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ||||
| std::vector<TestArg> quantized_args = | |||||
| get_int8_nchw44_args(3, 4, true); | |||||
| std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true); | |||||
| UniformIntRNG int_rng{-50, 50}; | UniformIntRNG int_rng{-50, 50}; | ||||
| checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | ||||
| run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | ||||
| dtype::QuantizedS8(0.01887994f), | dtype::QuantizedS8(0.01887994f), | ||||
| dtype::QuantizedS32(0.41113496f * 0.01887994f), | dtype::QuantizedS32(0.41113496f * 0.01887994f), | ||||
| dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||||
| dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||||
| epsilon); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
| CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
| @@ -1238,7 +1249,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| const char* matmul_name = "AARCH64_F32_MK4_4x16"; | const char* matmul_name = "AARCH64_F32_MK4_4x16"; | ||||
| #else | #else | ||||
| const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
| const char* matmul_name = "ARMV7_F32_MK4_4x8"; | |||||
| #endif | #endif | ||||
| checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | ||||
| ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str())); | ||||
| @@ -1249,10 +1260,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F | |||||
| run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | run(handle(), quantized_args, {2}, dtype::QuantizedS8(0.41113496f), | ||||
| dtype::QuantizedS8(0.01887994f), | dtype::QuantizedS8(0.01887994f), | ||||
| dtype::QuantizedS32(0.41113496f * 0.01887994f), | dtype::QuantizedS32(0.41113496f * 0.01887994f), | ||||
| dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, epsilon); | |||||
| dtype::QuantizedS8(0.49550694f), param::MatrixMul::Format::MK4, | |||||
| epsilon); | |||||
| } | } | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||