GitOrigin-RevId: c8d3d55b25
tags/v0.5.0
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/conv_bias/block_helper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| namespace { | |||
| // l2_block_helper is used to calculate the oh block size | |||
| static inline int l2_block_helper(const int nthread, const int amount, | |||
| const int size_per_unit) { | |||
| constexpr int l2_cache_size = 256 * 1024; | |||
| const int block_per_thread = div_ceil(amount, nthread); | |||
| const int best_block = std::min( | |||
| amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); | |||
| const int max_block_num = div_ceil(block_per_thread, best_block); | |||
| const int min_block_num = std::max(max_block_num - 1, 1); | |||
| const int max_block = div_ceil(block_per_thread, max_block_num); | |||
| const int min_block = div_ceil(block_per_thread, min_block_num); | |||
| const int max_loss = std::abs(max_block_num * max_block - block_per_thread); | |||
| const int min_loss = std::abs(min_block_num * min_block - block_per_thread); | |||
| int block = max_loss > min_loss ? min_block : max_block; | |||
| return block; | |||
| } | |||
| } // namespace | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
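For intuition, here is a minimal standalone sketch of what `l2_block_helper` computes, with `div_ceil` reimplemented locally and the input numbers chosen purely for illustration:

```cpp
#include <algorithm>
#include <cstdio>
#include <cstdlib>

static int div_ceil(int a, int b) { return (a + b - 1) / b; }

// Mirror of l2_block_helper above: pick an oh block whose working set fits a
// 256 KiB L2 slice, then snap it so the per-thread rows split evenly.
static int l2_block_helper(int nthread, int amount, int size_per_unit) {
    constexpr int l2_cache_size = 256 * 1024;
    const int block_per_thread = div_ceil(amount, nthread);
    const int best_block = std::min(
            amount, (l2_cache_size + size_per_unit / 2) / size_per_unit);
    const int max_block_num = div_ceil(block_per_thread, best_block);
    const int min_block_num = std::max(max_block_num - 1, 1);
    const int max_block = div_ceil(block_per_thread, max_block_num);
    const int min_block = div_ceil(block_per_thread, min_block_num);
    const int max_loss = std::abs(max_block_num * max_block - block_per_thread);
    const int min_loss = std::abs(min_block_num * min_block - block_per_thread);
    return max_loss > min_loss ? min_block : max_block;
}

int main() {
    // 2 threads, oh = 100, 10 KiB of input bytes per output row (illustrative):
    // block_per_thread = 50, best_block = 26, so 2 blocks of 25 rows each.
    printf("%d\n", l2_block_helper(2, 100, 10 * 1024));  // prints 25
}
```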
| @@ -11,6 +11,7 @@ | |||
| */ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/block_helper.h" | |||
| #include "src/arm_common/conv_bias/fp32/algos.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" | |||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" | |||
| @@ -26,22 +27,7 @@ using conv_fun = std::function<void( | |||
| const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) | |||
| namespace { | |||
| // block_helper is used to calculate oh block size | |||
| static inline int block_helper(const int nthread, const int amount, | |||
| const int size_per_unit) { | |||
| constexpr int l2_cache_size = 256 * 1024; | |||
| const int block_per_thread = div_ceil(amount, nthread); | |||
| const int best_block = std::min( | |||
| amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); | |||
| const int max_block_num = div_ceil(block_per_thread, best_block); | |||
| const int min_block_num = std::max(max_block_num - 1, 1); | |||
| const int max_block = div_ceil(block_per_thread, max_block_num); | |||
| const int min_block = div_ceil(block_per_thread, min_block_num); | |||
| const int max_loss = std::abs(max_block_num * max_block - block_per_thread); | |||
| const int min_loss = std::abs(min_block_num * min_block - block_per_thread); | |||
| int block = max_loss > min_loss ? min_block : max_block; | |||
| return block; | |||
| } | |||
| static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, | |||
| const int iw2) { | |||
| // border_size is used to avoid reading illegal memory | |||
| @@ -60,7 +46,7 @@ static void get_rectified_size( | |||
| ow2 = ow; | |||
| constexpr int cacheline = 64 / sizeof(float); | |||
| int block_oh = | |||
| block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); | |||
| l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2); | |||
| auto&& fm = param.filter_meta; | |||
| const int stride_h = static_cast<int>(fm.stride[0]); | |||
| const int filter_h = static_cast<int>(fm.spatial[0]); | |||
| @@ -106,8 +92,8 @@ static void do_conv_kern(WorkspaceBundle bundle, | |||
| const int group_id = ncb_index.ndrange_id[1]; | |||
| constexpr int oc_idx = 0; | |||
| int oc_block = oc; | |||
| int oh_block = block_helper(kern_param.nr_threads, oh2, | |||
| ic * iw * sizeof(float) * stride_h); | |||
| int oh_block = l2_block_helper(kern_param.nr_threads, oh2, | |||
| ic * iw * sizeof(float) * stride_h); | |||
| const int oh_idx = ncb_index.ndrange_id[2]; | |||
| const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); | |||
| const int ih_real = oh_block_real * stride_h + fh - stride_h; | |||
| @@ -298,8 +284,8 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( | |||
| int ic = param.filter_meta.icpg; | |||
| int iw = param.isz[1]; | |||
| int stride_h = param.filter_meta.stride[0]; | |||
| int oh_block = block_helper(param.nr_threads, oh, | |||
| ic * iw * sizeof(float) * stride_h); | |||
| int oh_block = l2_block_helper(param.nr_threads, oh, | |||
| ic * iw * sizeof(float) * stride_h); | |||
| CpuNDRange ncb_range = {static_cast<size_t>(batch), | |||
| static_cast<size_t>(group), | |||
| static_cast<size_t>(div_ceil(oh, oh_block))}; | |||
| @@ -133,6 +133,21 @@ public: | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { | |||
| public: | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; } | |||
| bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(FallbackConvBiasImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| virtual SmallVector<NCBKern> dispatch_kerns( | |||
| fallback::ConvBiasImpl* opr, | |||
| const NCBKernSizeParam& param) const override; | |||
| }; | |||
| class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
| bool m_large_group; | |||
| @@ -0,0 +1,321 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express | |||
| * or implied. | |||
| */ | |||
| #if __ARM_FEATURE_DOTPROD | |||
| #include "megdnn/oprs.h" | |||
| #include "src/arm_common/conv_bias/block_helper.h" | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "midout.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| using conv_fun = std::function<void( | |||
| WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot) | |||
| namespace { | |||
| static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, | |||
| const int iw2, | |||
| const int stride) { | |||
| //! border_size is used to avoid reading illegal memory | |||
| constexpr int cacheline_size = 64; | |||
| constexpr int border_size = 2 * cacheline_size; | |||
| const int pack_iw_len = stride == 1 ? 4 : 1; | |||
| return round_up( | |||
| ic * ih2 * iw2 * pack_iw_len * (int)sizeof(int8_t) + border_size, | |||
| cacheline_size); | |||
| } | |||
| static inline size_t get_temp_bytes(const int iw, const int pw) { | |||
| //! border_size is used to avoid reading illegal memory | |||
| constexpr int cacheline_size = 64; | |||
| constexpr int border_size = 1 * cacheline_size; | |||
| return round_up(iw + pw * 2, cacheline_size) + border_size; | |||
| } | |||
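//! note: round_up here presumably rounds to the next multiple of
//! cacheline_size, so e.g. get_temp_bytes(iw = 28, pw = 1) gives
//! round_up(30, 64) + 64 = 128 bytes of stride-1 repack scratch per thread.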
| static void get_rectified_size( | |||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||
| int& iw2) { | |||
| auto&& fm = param.filter_meta; | |||
| const int stride_h = static_cast<int>(fm.stride[0]); | |||
| const int filter_h = static_cast<int>(fm.spatial[0]); | |||
| int ic = param.filter_meta.icpg; | |||
| int iw = param.isz[1]; | |||
| int oh = param.osz[0]; | |||
| int block_oh = l2_block_helper(param.nr_threads, oh, | |||
| ic * iw * sizeof(int8_t) * stride_h); | |||
| ih2 = block_oh * stride_h + filter_h - stride_h; | |||
| iw2 = iw + 2 * static_cast<int>(fm.padding[1]); | |||
| } | |||
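// note: ih2 is the input-row footprint of one oh block:
// (block_oh - 1) * stride_h + filter_h, rewritten as
// block_oh * stride_h + filter_h - stride_h (e.g. 8 * 2 + 3 - 2 = 17).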
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
| auto&& fm = param.filter_meta; | |||
| int ic = fm.icpg; | |||
| int fh = fm.spatial[0]; | |||
| int fw = fm.spatial[1]; | |||
| int iw = param.isz[1]; | |||
| int pw = param.filter_meta.padding[1]; | |||
| int stride_w = param.filter_meta.stride[1]; | |||
| int ih2, iw2; | |||
| get_rectified_size(param, ih2, iw2); | |||
| size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w); | |||
| size_t weight_size = fm.group * fm.icpg * fm.ocpg * fh * round_up(fw, 4); | |||
| size_t temp_size = 0; | |||
| if (fm.stride[0] == 1) { | |||
| temp_size = get_temp_bytes(iw, pw); | |||
| } | |||
| return {nullptr, | |||
| {src_size * param.nr_threads, weight_size, | |||
| temp_size * param.nr_threads}}; | |||
| }; | |||
| void do_weight_trans(WorkspaceBundle bundle, | |||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||
| const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) { | |||
| const int ic = kern_param.filter_meta.icpg; | |||
| const int oc = kern_param.filter_meta.ocpg; | |||
| const int fh = kern_param.filter_meta.spatial[0]; | |||
| const int fw = kern_param.filter_meta.spatial[1]; | |||
| const int fw2 = round_up(fw, 4); | |||
| bundle.set(kern_param.workspace_ptr); | |||
| auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)); | |||
| auto origin_weight = kern_param.filter<dt_int8>(); | |||
| pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh, | |||
| fw, fw2); | |||
| } | |||
| template <size_t filter, BiasMode bias_mode, typename Op, int stride> | |||
| static void do_conv_kern(WorkspaceBundle bundle, | |||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||
| const CpuNDRange&, const CpuNDRange&) { | |||
| const int oh = kern_param.osz[0]; | |||
| const int ow = kern_param.osz[1]; | |||
| const int fh = kern_param.filter_meta.spatial[0]; | |||
| const int fw = kern_param.filter_meta.spatial[1]; | |||
| const int ic = kern_param.filter_meta.icpg; | |||
| const int oc = kern_param.filter_meta.ocpg; | |||
| const int ih = kern_param.isz[0]; | |||
| const int iw = kern_param.isz[1]; | |||
| const int stride_h = kern_param.filter_meta.stride[0]; | |||
| const int stride_w = kern_param.filter_meta.stride[1]; | |||
| const int ph = kern_param.filter_meta.padding[0]; | |||
| const int pw = kern_param.filter_meta.padding[1]; | |||
| int ih2 = 0; | |||
| int iw2 = 0; | |||
| get_rectified_size(kern_param, ih2, iw2); | |||
| bundle.set(kern_param.workspace_ptr); | |||
| constexpr int pack_c = 4; | |||
| const int batch_id = ncb_index.ndrange_id[0]; | |||
| const int group_id = ncb_index.ndrange_id[1]; | |||
| constexpr int oc_idx = 0; | |||
| int oc_block = oc; | |||
| int oh_block = l2_block_helper(kern_param.nr_threads, oh, | |||
| ic * iw * sizeof(int8_t) * stride_h); | |||
| const int oh_idx = ncb_index.ndrange_id[2]; | |||
| const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); | |||
| const int ih_real = oh_block_real * stride_h + fh - stride_h; | |||
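// note: the two pads below split the global padding per oh block:
// src_top_pad is the part of the top padding (ph) still visible to this
// block, and src_bottom_pad is how far the block's last filter window
// reads past the real input bottom.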
| const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); | |||
| const int src_bottom_pad = std::max( | |||
| (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, | |||
| 0); | |||
| const int remain_right_pad = std::max(iw2 - iw - pw, 0); | |||
| const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; | |||
| const int8_t* origin_sptr = | |||
| static_cast<const int8_t*>( | |||
| kern_param.src<int8_t>(batch_id, group_id, 0, 1, 1)) + | |||
| src_offset; | |||
| const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w); | |||
| int8_t* sptr = reinterpret_cast<int8_t*>(bundle.get(0)) + | |||
| ncb_index.thread_id * src_size; | |||
| int8_t* tmp_ptr = nullptr; | |||
| if (stride == 1) { | |||
| const size_t tmp_size = get_temp_bytes(iw, pw); | |||
| tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | |||
| ncb_index.thread_id * tmp_size; | |||
| } | |||
| pack_src_int8_nchw_nchw44_dot<stride>( | |||
| sptr, origin_sptr, ph, pw, remain_right_pad, | |||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | |||
| src_bottom_pad, ic, ih * iw, tmp_ptr); | |||
| const int8_t* fptr = | |||
| reinterpret_cast<int8_t*>(bundle.get(1)) + oc_idx * fh * fw * ic; | |||
| int8_t* dst = kern_param.dst<int8_t>(batch_id, group_id) + | |||
| oh_idx * oh_block * ow * pack_c; | |||
| const int bias_offset = oc_idx; | |||
| const int32_t* bptr = | |||
| kern_param.bias<dt_int32>(batch_id, group_id) + bias_offset; | |||
| float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
| float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
| Op op(scale_bias, scale_dst); | |||
| conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>( | |||
| sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, | |||
| oh_block_real, ow, op); | |||
| } | |||
| } // namespace | |||
| bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy) const { | |||
| auto&& fm = param.filter_meta; | |||
| auto fh = fm.spatial[0]; | |||
| int oc = fm.ocpg; | |||
| int ic = fm.icpg; | |||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) && | |||
| (fm.format == param::Convolution::Format::NCHW44); | |||
| bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4); | |||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||
| (fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == fm.stride[1] && | |||
| (fm.stride[0] == 1 || fm.stride[0] == 2); | |||
| bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||
| bool available = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||
| return available; | |||
| } | |||
| size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | |||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| return get_bundle(param).total_size_in_bytes(); | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( | |||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| auto fm = param.filter_meta; | |||
| const int batch = param.n; | |||
| const int group = fm.group; | |||
| WorkspaceBundle wbundle = get_bundle(param); | |||
| conv_fun do_conv_fun = nullptr; | |||
| // NOTE: remain_w is not used to generate the midout hash, to stay | |||
| // compatible with runtime shapes | |||
| #define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ | |||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \ | |||
| midout_iv(#stride #filter #bias_mode #op##_hash)) { \ | |||
| do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \ | |||
| } \ | |||
| MIDOUT_END(); | |||
| #define GET_OP_PARAM(stride, filter, bias_mode) \ | |||
| switch (param.nonlineMode) { \ | |||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||
| DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| break; \ | |||
| case param::ConvBias::NonlineMode::RELU: \ | |||
| DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| break; \ | |||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||
| DO_CONV_KERN_FUN(stride, filter, bias_mode, \ | |||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| #define GET_BIAS_MODE_PARAM(stride, filter) \ | |||
| switch (param.bias_mode) { \ | |||
| case BiasMode::NO_BIAS: \ | |||
| GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||
| break; \ | |||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| #define DISPATCH_CONV_KERN(stride) \ | |||
| switch (param.filter_meta.spatial[0]) { \ | |||
| case 2: \ | |||
| GET_BIAS_MODE_PARAM(stride, 2) \ | |||
| break; \ | |||
| case 3: \ | |||
| GET_BIAS_MODE_PARAM(stride, 3) \ | |||
| break; \ | |||
| case 5: \ | |||
| GET_BIAS_MODE_PARAM(stride, 5) \ | |||
| break; \ | |||
| case 7: \ | |||
| GET_BIAS_MODE_PARAM(stride, 7) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| switch (param.filter_meta.stride[0]) { | |||
| case 1: | |||
| DISPATCH_CONV_KERN(1); | |||
| break; | |||
| case 2: | |||
| DISPATCH_CONV_KERN(2); | |||
| break; | |||
| default: | |||
| megdnn_assert(0); | |||
| break; | |||
| } | |||
| #undef DO_CONV_KERN_FUN | |||
| #undef GET_REMAIN_W_PARAM | |||
| #undef GET_OP_PARAM | |||
| #undef GET_BIAS_MODE_PARAM | |||
| #undef DISPATCH_CONV_KERN | |||
| megdnn_assert(do_conv_fun); | |||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
| WorkspaceBundle bundle = wbundle; | |||
| int oh = param.osz[0]; | |||
| int ic = param.filter_meta.icpg; | |||
| int iw = param.isz[1]; | |||
| int stride_h = param.filter_meta.stride[0]; | |||
| int oh_block = l2_block_helper(param.nr_threads, oh, | |||
| ic * iw * sizeof(int8_t) * stride_h); | |||
| CpuNDRange ncb_range = {static_cast<size_t>(batch), | |||
| static_cast<size_t>(group), | |||
| static_cast<size_t>(div_ceil(oh, oh_block))}; | |||
| auto do_trans_weight = [bundle](const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) { | |||
| do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
| }; | |||
| ret_kerns.push_back({do_trans_weight, {1}}); | |||
| auto do_conv = [bundle, do_conv_fun, ncb_range]( | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) { | |||
| do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, | |||
| ncb_range); | |||
| }; | |||
| ret_kerns.push_back({do_conv, ncb_range}); | |||
| return ret_kerns; | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
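The kernel list returned by `dispatch_kerns` runs the weight pack once (ndrange `{1}`) and then the conv body over a batch × group × oh-block grid. A small sketch of that grid arithmetic, with shapes chosen purely for illustration:

```cpp
#include <cstdio>

static int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Illustrative shapes: batch 2, 1 group, oh = 100, and assume
    // l2_block_helper chose an oh block of 25 rows.
    const int batch = 2, group = 1, oh = 100, oh_block = 25;
    // ncb_range = {batch, group, div_ceil(oh, oh_block)} -> {2, 1, 4}:
    // the conv kernel body is invoked once per (batch, group, oh-block).
    printf("ncb_range = {%d, %d, %d}\n", batch, group, div_ceil(oh, oh_block));
    return 0;
}
```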
| @@ -0,0 +1,779 @@ | |||
| /** | |||
| * \file | |||
| * dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #if __ARM_FEATURE_DOTPROD | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| namespace { | |||
| template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, | |||
| int stride, typename T, typename T2, typename T3, typename T4> | |||
| struct ShiftCalHelper { | |||
| static void impl(T& c, T2& src, T3& weight); | |||
| }; | |||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||
| typename T2, typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> { | |||
| static void impl(T& c, T2& src, T3& weight) { | |||
| #define cb(step) \ | |||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step * 2], weight[0][weight_idx], \ | |||
| src[0][(src_idx + step) / 4]); \ | |||
| c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[1][step * 2], weight[1][weight_idx], \ | |||
| src[0][(src_idx + step) / 4]); \ | |||
| c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step * 2 + 1], weight[0][weight_idx], \ | |||
| src[1][(src_idx + step) / 4]); \ | |||
| c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[1][step * 2 + 1], weight[1][weight_idx], \ | |||
| src[1][(src_idx + step) / 4]); | |||
| UNROLL_CALL_RAW(4, cb); | |||
| #undef cb | |||
| } | |||
| }; | |||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||
| typename T2, typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> { | |||
| static void impl(T& c, T2& src, T3& weight) { | |||
| #define cb(step) \ | |||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step * 2], weight[0][weight_idx], \ | |||
| src[0][(src_idx + step) / 4]); \ | |||
| c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step * 2 + 1], weight[0][weight_idx], \ | |||
| src[1][(src_idx + step) / 4]); | |||
| UNROLL_CALL_RAW(4, cb); | |||
| #undef cb | |||
| } | |||
| }; | |||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||
| typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> { | |||
| static void impl(T& c, T2& src, T3& weight) { | |||
| #define cb(step) \ | |||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \ | |||
| c[1][step] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]); | |||
| UNROLL_CALL_RAW(8, cb); | |||
| #undef cb | |||
| } | |||
| }; | |||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||
| typename T3, typename T4> | |||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | |||
| static void impl(T& c, T2& src, T3& weight) { | |||
| #define cb(step) \ | |||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | |||
| c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); | |||
| UNROLL_CALL_RAW(8, cb); | |||
| #undef cb | |||
| } | |||
| }; | |||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, | |||
| int stride, typename T, typename T2, typename T3> | |||
| inline void cal_helper(T& c, T2& src, T3& weight) { | |||
| ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, stride, T, T2, | |||
| T3, int>::impl(c, src, weight); | |||
| }; | |||
| //! OCHelper is used to translate oc_block into the number of result register rows | |||
| template <int oc> | |||
| struct OCHelper { | |||
| public: | |||
| static const int val = -1; | |||
| }; | |||
| template <> | |||
| struct OCHelper<4> { | |||
| public: | |||
| static const int val = 1; | |||
| }; | |||
| #if MEGDNN_AARCH64 | |||
| template <> | |||
| struct OCHelper<8> { | |||
| public: | |||
| static const int val = 2; | |||
| }; | |||
| #endif | |||
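//! note: oc_block == 4 maps to one row of accumulator registers; AArch64,
//! with its 32 SIMD registers, also supports oc_block == 8 (two rows). Any
//! other oc_block falls through to val == -1.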
| /** | |||
| * oc8_ow8 (m = 8, n = 8) and oc4_ow8 (m = 4, n = 8) GEMM-like kernels | |||
| * */ | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||
| int oc_block, int ow_block, int stride> | |||
| struct KerNeonDotXXs2Nchw44Int8 { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op); | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block, int stride> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||
| stride> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int filter_hight = 2; | |||
| constexpr int filter_width = 4; | |||
| constexpr int weight_reg = 1; | |||
| constexpr int src_reg = 1; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 1; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[2][src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| // row 0 | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||
| src, src_ptr + 0 * iw, stride); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 1 | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||
| src, src_ptr + 1 * iw, stride); | |||
| load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block, int stride> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||
| stride> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int filter_hight = 3; | |||
| constexpr int filter_width = 4; | |||
| constexpr int weight_reg = 1; | |||
| constexpr int src_reg = 1; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 1; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[2][src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| // row 0 | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||
| src, src_ptr + 0 * iw, stride); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 1 | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||
| src, src_ptr + 1 * iw, stride); | |||
| load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 2 | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||
| src, src_ptr + 2 * iw, stride); | |||
| load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block, int stride> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||
| stride> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int filter_hight = 5; | |||
| constexpr int filter_width = 8; | |||
| constexpr int src_reg = 2; | |||
| constexpr int weight_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 1; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[2][src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| #define cb(step) \ | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \ | |||
| stride); \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||
| weight); \ | |||
| cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||
| UNROLL_CALL_RAW(5, cb); | |||
| #undef cb | |||
| src_ptr += ic_stride; | |||
| weight_ptr += 5 * 32; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| /** | |||
| * oc = 8, ow = 8 | |||
| * dot 4 elements at a time; the filter row is zero-padded from 7 to 8 so | |||
| * each row is handled by two dot operations, laid out as below | |||
| * -------------------------- | |||
| * |x, x, x, x,| x, x, x, 0 | | |||
| * -------------------------- | |||
| **/ | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block, int stride> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||
| stride> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int filter_hight = 7; | |||
| constexpr int filter_width = 8; | |||
| constexpr int src_reg = 2; | |||
| constexpr int weight_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 1; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[2][src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| #define cb(step) \ | |||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \ | |||
| stride); \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||
| weight); \ | |||
| cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||
| UNROLL_CALL_RAW(7, cb); | |||
| #undef cb | |||
| src_ptr += ic_stride; | |||
| weight_ptr += 7 * 32; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| ////////////////////stride 1/////////////////// | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||
| 1> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_hight = 2; | |||
| constexpr int filter_width = 4; | |||
| constexpr int weight_reg = 2; | |||
| constexpr int src_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 4; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| // row 0 | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 1 | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||
| 1> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_hight = 3; | |||
| constexpr int filter_width = 4; | |||
| constexpr int weight_reg = 3; | |||
| constexpr int src_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 4; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| // row 0 | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
| weight, weight_ptr, ld_weight_oc); | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 1 | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| // row 2 | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
| src, src_ptr + 2 * iw * pack_iw_len, 0); | |||
| cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
| weight); | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||
| 1> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_hight = 5; | |||
| constexpr int filter_width = 8; | |||
| constexpr int src_reg = 3; | |||
| constexpr int weight_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 4; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| #define cb(step) \ | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, src_ptr + step * iw * pack_iw_len, 0); \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||
| weight); \ | |||
| cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||
| UNROLL_CALL_RAW(5, cb); | |||
| #undef cb | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
| int ow_block> | |||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||
| 1> { | |||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
| int iw, int ld_dst_oc, const Op& op) { | |||
| constexpr int stride = 1; | |||
| constexpr int filter_hight = 7; | |||
| constexpr int filter_width = 8; | |||
| constexpr int src_reg = 3; | |||
| constexpr int weight_reg = 2; | |||
| constexpr int oc_step = 4; | |||
| constexpr int ic_step = 1; | |||
| constexpr int pack_iw_len = 4; | |||
| constexpr int simd_len = 16; | |||
| const int ld_bias = oc_step; | |||
| const int ic_stride = ih * iw * pack_iw_len; | |||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||
| int32x4_t c[c_dim][8]; | |||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
| int8x16_t src[src_reg]; | |||
| int8x16_t weight[c_dim][weight_reg]; | |||
| #define cb(step) \ | |||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||
| src, src_ptr + step * iw * pack_iw_len, 0); \ | |||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||
| weight); \ | |||
| cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||
| UNROLL_CALL_RAW(7, cb); | |||
| #undef cb | |||
| src_ptr += ic_stride; | |||
| weight_ptr += filter_hight * filter_width * oc_step; | |||
| } | |||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
| c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| }; | |||
| template <int stride> | |||
| void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, | |||
| const int, const int pw, const int, | |||
| const int ih, const int iw, const int iw2, | |||
| const int pad_top, const int pad_bottom, | |||
| const int ic, const int ic_stride, int8_t*) { | |||
| constexpr int ic_step = 1; | |||
| rep_step(ic_idx, ic, ic_step) { | |||
| const int8_t* sptr = sptr_origin + ic_idx * ic_stride; | |||
| memset(sptr_base, 0, | |||
| sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); | |||
| sptr_base += iw2 * pad_top * ic_step; | |||
| rep(ih_idx, ih) { | |||
| memcpy(sptr_base + pw * ic_step, sptr, | |||
| sizeof(int8_t) * iw * ic_step); | |||
| sptr_base += iw2 * ic_step; | |||
| sptr += iw * ic_step; | |||
| } | |||
| sptr_base += iw2 * pad_bottom * ic_step; | |||
| } | |||
| } | |||
| template <> | |||
| void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, | |||
| const int8_t* sptr_origin, const int, | |||
| const int pw, const int, const int ih, | |||
| const int iw, const int iw2, | |||
| const int pad_top, const int pad_bottom, | |||
| const int ic, const int ic_stride, | |||
| int8_t* temp_ptr) { | |||
| static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||
| 2, 3, 4, 5, 3, 4, 5, 6}; | |||
| uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); | |||
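// note: vqtbl1q_s8 with this table maps input bytes {0, 1, ..., 15} to
// {0,1,2,3, 1,2,3,4, 2,3,4,5, 3,4,5,6}: four overlapping stride-1 windows
// of width 4, matching the pack_iw_len == 4 layout stored below.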
| constexpr int iw_step = 16; | |||
| constexpr int pack_iw_len = 4; | |||
| const int iw_with_pad = iw + 2 * pw; | |||
| const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; | |||
| rep(ic_idx, ic) { | |||
| const int8_t* sptr = sptr_origin + ic_idx * ic_stride; | |||
| memset(sptr_base, 0, | |||
| sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * | |||
| pack_iw_len); | |||
| sptr_base += iw2 * pad_top * pack_iw_len; | |||
| rep(ih_idx, ih) { | |||
| memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); | |||
| memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); | |||
| for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { | |||
| int8x16_t src[4]; | |||
| int8x16_t dst[4]; | |||
| src[0] = vld1q_s8(temp_ptr + iw_idx); | |||
| src[1] = vld1q_s8(temp_ptr + iw_idx + 4); | |||
| src[2] = vld1q_s8(temp_ptr + iw_idx + 8); | |||
| src[3] = vld1q_s8(temp_ptr + iw_idx + 12); | |||
| dst[0] = vqtbl1q_s8(src[0], tbl_idx); | |||
| dst[1] = vqtbl1q_s8(src[1], tbl_idx); | |||
| dst[2] = vqtbl1q_s8(src[2], tbl_idx); | |||
| dst[3] = vqtbl1q_s8(src[3], tbl_idx); | |||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); | |||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); | |||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); | |||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); | |||
| } | |||
| for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { | |||
| *(sptr_base + iw_idx * pack_iw_len + 0) = | |||
| *(temp_ptr + iw_idx + 0); | |||
| *(sptr_base + iw_idx * pack_iw_len + 1) = | |||
| *(temp_ptr + iw_idx + 1); | |||
| *(sptr_base + iw_idx * pack_iw_len + 2) = | |||
| *(temp_ptr + iw_idx + 2); | |||
| *(sptr_base + iw_idx * pack_iw_len + 3) = | |||
| *(temp_ptr + iw_idx + 3); | |||
| } | |||
| sptr_base += iw2 * pack_iw_len; | |||
| sptr += iw; | |||
| } | |||
| sptr_base += iw2 * pad_bottom * pack_iw_len; | |||
| } | |||
| } | |||
| static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, | |||
| const int8_t* src_ptr, | |||
| const int oc, const int ic, | |||
| const int fh, const int fw, | |||
| const int fw2) { | |||
| constexpr int oc_step = 4; | |||
| const int fw_remain = fw2 - fw; | |||
| const int dst_ic_stride = fh * fw2; | |||
| const int oc_step_stride = fh * fw2 * ic * oc_step; | |||
| static const uint8_t transpose_4x4_idx[16] = {0, 4, 8, 12, 1, 5, 9, 13, | |||
| 2, 6, 10, 14, 3, 7, 11, 15}; | |||
| uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); | |||
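// note: this table performs a 4x4 int8 transpose within each 16-byte
// vector: bytes laid out row-major as a 4x4 matrix come out column-major.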
| rep_step(oc_idx, oc, oc_step) { | |||
| int32_t* dst_temp_ptr = | |||
| reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2); | |||
| const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>( | |||
| src_ptr + oc_idx * ic * fh * fw); | |||
| // transpose ic and pad | |||
| rep(fh_idx, fh) { | |||
| rep(fw_idx, fw) { | |||
| rep(ic_idx, ic) { | |||
| *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; | |||
| src_temp_ptr++; | |||
| } | |||
| dst_temp_ptr++; | |||
| } | |||
| rep(ic_idx, ic) { | |||
| memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, | |||
| sizeof(int8_t) * oc_step * fw_remain); | |||
| } | |||
| dst_temp_ptr += fw_remain; | |||
| } | |||
| // transpose fw oc | |||
| int8_t* trans_dst_temp_ptr = | |||
| reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2); | |||
| rep_step(idx, oc_step_stride, 16) { | |||
| int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); | |||
| vst1q_s8(trans_dst_temp_ptr + idx, | |||
| vqtbl1q_s8(temp, tbl_transpose_4x4)); | |||
| } | |||
| } | |||
| } | |||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||
| static void conv_direct_int8_nchw_nchw44_dot( | |||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||
| int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, | |||
| const int iw, const int oh, const int oh_block, const int ow, | |||
| const Op& op) { | |||
| MEGDNN_MARK_USED_VAR(temp); | |||
| constexpr int fh = filter_size; | |||
| constexpr int fw = (filter_size + 3) / 4 * 4; | |||
| #if MEGDNN_AARCH64 | |||
| constexpr int big_oc_step = 8; | |||
| #else | |||
| constexpr int big_oc_step = 4; | |||
| #endif | |||
| constexpr int oc_step = 4; | |||
| constexpr int ih_step = 1; | |||
| constexpr int oh_step = 1; | |||
| constexpr int ow_step = 8; | |||
| constexpr int stride_h = stride; | |||
| constexpr int stride_w = stride; | |||
| constexpr int pack_iw_len = stride == 2 ? 1 : 4; | |||
| const int img_stride = oh * ow; | |||
| const int ow_end = ow / ow_step * ow_step; | |||
| const int ow_remain = ow - ow_end; | |||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||
| const int oc_remain = oc - oc_end; | |||
| const int ld_dst_oc = oc_step * img_stride; | |||
| using remain_fun = | |||
| std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, | |||
| int ih, int iw, int ld_dst_oc, const Op& op)>; | |||
| remain_fun kern_big_oc_remain = nullptr; | |||
| remain_fun kern_small_oc_remain = nullptr; | |||
| switch (ow_remain) { | |||
| #define cb(step) \ | |||
| case step: \ | |||
| kern_big_oc_remain = \ | |||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||
| big_oc_step, ow_step, stride>::impl; \ | |||
| kern_small_oc_remain = \ | |||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||
| oc_step, ow_step, stride>::impl; \ | |||
| break; | |||
| UNROLL_CALL_RAW(8, cb); | |||
| default: | |||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||
| } | |||
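// note: UNROLL_CALL_RAW(8, cb) expands cases 0..7, covering every value of
// ow_remain = ow - ow_end (ow_step == 8), so the default branch should be
// unreachable.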
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||
| const int weight_offset = oc_idx * ic * fh * fw; | |||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
| const int src_offset = | |||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
| pack_iw_len; | |||
| const int dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||
| big_oc_step, ow_step, | |||
| stride>::impl(src + src_offset, | |||
| filter + weight_offset, | |||
| bias + oc_idx, | |||
| dst + dst_offset, ic, ih, | |||
| iw, ld_dst_oc, op); | |||
| } | |||
| if (ow_remain > 0) { | |||
| const int src_offset = | |||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
| pack_iw_len; | |||
| const int dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||
| ld_dst_oc, op); | |||
| } | |||
| } | |||
| } | |||
| if (oc_remain > 0) { | |||
| int oc_idx = oc_end; | |||
| const int weight_offset = oc_idx * ic * fh * fw; | |||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||
| const int src_offset = | |||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||
| pack_iw_len; | |||
| const int dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||
| oc_step, ow_step, | |||
| stride>::impl(src + src_offset, | |||
| filter + weight_offset, | |||
| bias + oc_idx, | |||
| dst + dst_offset, ic, ih, | |||
| iw, ld_dst_oc, op); | |||
| } | |||
| if (ow_remain > 0) { | |||
| const int src_offset = | |||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||
| pack_iw_len; | |||
| const int dst_offset = | |||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||
| iw, ld_dst_oc, op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
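A self-contained check of the dot-product primitive the kernels above are built on (a sketch assuming an AArch64 toolchain with the dotprod extension enabled; all names are local to this snippet):

```cpp
#include <arm_neon.h>
#include <cstdio>

// vdotq_laneq_s32(c, w, s, lane): for each of the four int32 lanes i,
// c[i] += w[4*i+0..3] . s[4*lane+0..3] -- one 4-wide int8 dot product per
// output channel, which is how each cal_helper step accumulates.
int main() {
    int8_t w[16], s[16];
    for (int i = 0; i < 16; ++i) {
        w[i] = 1;          // all-ones weights
        s[i] = (int8_t)i;  // src bytes 0..15
    }
    int32x4_t c = vdupq_n_s32(0);
    c = vdotq_laneq_s32(c, vld1q_s8(w), vld1q_s8(s), 0);  // lane 0: s bytes 0..3
    int32_t out[4];
    vst1q_s32(out, c);
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 6 6 6 6
    return 0;
}
```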
| @@ -176,187 +176,202 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, | |||
| StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| ////////////////////Store_OCX_OW8_Remain///////////////////////// | |||
| template <int c_dim, int ow_remain, typename Op, typename T> | |||
| template <int c_dim, int ow_remain, typename Op, typename T, typename T2, | |||
| typename T3> | |||
| struct StoreOcxOw8Remain { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc); | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 0, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||
| op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| op({{c[1][4], c[1][5]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
| op({{c[1][6], c[1][7]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 8, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||
| op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| op({{c[1][4], c[1][5]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
| op({{c[1][6], c[1][7]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 7, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op(c[0][6], dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||
| op(c[1][6], dst_ptr + ld_dst_oc + 24); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| op({{c[1][4], c[1][5]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
| op(c[1][6], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 6, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| op({{c[1][4], c[1][5]}}, | |||
| reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 5, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op(c[0][4], dst_ptr + 16); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 5, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op(c[1][4], dst_ptr + ld_dst_oc + 16); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| op(c[1][4], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 16)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 4, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 4, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op({{c[1][2], c[1][3]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 3, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op(c[0][2], dst_ptr + 8); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 3, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| op(c[1][2], dst_ptr + ld_dst_oc + 8); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| op(c[1][2], reinterpret_cast<T3>(dst_ptr + ld_dst_oc + 8)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 2, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 2, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[1][0], c[1][1]}}, reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<2, 1, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||
| op(c[0][0], dst_ptr); | |||
| op(c[1][0], dst_ptr + ld_dst_oc); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<2, 1, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { | |||
| op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||
| op(c[1][0], reinterpret_cast<T3>(dst_ptr + ld_dst_oc)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 0, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 0, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 8, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 8, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op({{c[0][6], c[0][7]}}, reinterpret_cast<T3>(dst_ptr + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 7, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| op(c[0][6], dst_ptr + 24); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 7, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| op(c[0][6], reinterpret_cast<T3>(dst_ptr + 24)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 6, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 6, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op({{c[0][4], c[0][5]}}, reinterpret_cast<T3>(dst_ptr + 16)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 5, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| op(c[0][4], dst_ptr + 16); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 5, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| op(c[0][4], reinterpret_cast<T3>(dst_ptr + 16)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 4, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 4, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op({{c[0][2], c[0][3]}}, reinterpret_cast<T3>(dst_ptr + 8)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 3, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| op(c[0][2], dst_ptr + 8); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 3, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| op(c[0][2], reinterpret_cast<T3>(dst_ptr + 8)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 2, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 2, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op({{c[0][0], c[0][1]}}, reinterpret_cast<T3>(dst_ptr)); | |||
| } | |||
| }; | |||
| template <typename Op, typename T> | |||
| struct StoreOcxOw8Remain<1, 1, Op, T> { | |||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||
| op(c[0][0], dst_ptr); | |||
| template <typename Op, typename T, typename T2, typename T3> | |||
| struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { | |||
| static void impl(T& c, const Op& op, T2 dst_ptr, int) { | |||
| op(c[0][0], reinterpret_cast<T3>(dst_ptr)); | |||
| } | |||
| }; | |||
| template <int c_dim, int ow_remain, typename Op, typename T> | |||
| inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr, | |||
| template <int c_dim, int ow_remain, typename Op, typename T, typename T2> | |||
| inline void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, | |||
| int ld_dst_oc) { | |||
| StoreOcxOw8Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); | |||
| StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T2>::impl(c, op, dst_ptr, | |||
| ld_dst_oc); | |||
| } | |||
| template <int c_dim, int ow_remain, typename Op, typename T3, typename T, | |||
| typename T2> | |||
| inline void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr, | |||
| int ld_dst_oc) { | |||
| StoreOcxOw8Remain<c_dim, ow_remain, Op, T, T2, T3>::impl(c, op, dst_ptr, | |||
| ld_dst_oc); | |||
| } | |||
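| // Illustration only (not from the original source): the extra T2/T3 | |||
| // parameters let one family of specializations serve kernels whose | |||
| // accumulator and destination types differ. A hypothetical quantized kernel | |||
| // with int32 accumulators and an int8 destination could dispatch like this; | |||
| // QuantOp stands in for one of the existing postprocess ops: | |||
| template <int remain, typename QuantOp> | |||
| static inline void example_store_s8(int32x4_t (&c)[2][8], const QuantOp& op, | |||
| int8_t* dst_ptr, int ld_dst_oc) { | |||
| // c_dim = 2 and T3 = int8_t* are explicit; T and T2 are deduced. | |||
| store_ocx_ow8_remain_static_dt<2, remain, QuantOp, int8_t*>(c, op, dst_ptr, | |||
| ld_dst_oc); | |||
| } | |||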
| ////////////////////Store_OC8_OW8_Remain///////////////////////// | |||
| @@ -522,68 +537,84 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, | |||
| } | |||
| } | |||
| /////////////////////////init_ocx_ow8//////////////////// | |||
| inline float32x4_t neon_vdupq_n(float val) { | |||
| return vdupq_n_f32(val); | |||
| } | |||
| inline int32x4_t neon_vdupq_n(int val) { | |||
| return vdupq_n_s32(val); | |||
| } | |||
| inline float32x4_t neon_vld1q(const float* ptr) { | |||
| return vld1q_f32(ptr); | |||
| } | |||
| inline int32x4_t neon_vld1q(const int* ptr) { | |||
| return vld1q_s32(ptr); | |||
| } | |||
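| // The plain-function overloads above keep the templated initializers below | |||
| // type-agnostic: overload resolution picks the f32 or s32 intrinsic from the | |||
| // scalar type of the argument. A minimal sketch of the mechanism | |||
| // (illustration only): | |||
| static inline void example_neon_overloads() { | |||
| const float fbias[4] = {0.f, 1.f, 2.f, 3.f}; | |||
| const int ibias[4] = {0, 1, 2, 3}; | |||
| float32x4_t fv = neon_vld1q(fbias);  // resolves to vld1q_f32 | |||
| int32x4_t iv = neon_vld1q(ibias);    // resolves to vld1q_s32 | |||
| (void)fv; | |||
| (void)iv; | |||
| } | |||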
| template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> | |||
| struct InitOcxOw8 { | |||
| static void impl(T& c, T2 bias_ptr, int oc_step); | |||
| static void impl(T& c, const T2* bias_ptr, int oc_step); | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t*, int) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vdupq_n_f32(0); \ | |||
| c[1][step] = vdupq_n_f32(0); | |||
| static void impl(T& c, const T2*, int) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ | |||
| c[1][step] = neon_vdupq_n(static_cast<T2>(0)); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t*, int) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vdupq_n_f32(0); \ | |||
| c[1][step] = vdupq_n_f32(0); | |||
| static void impl(T& c, const T2*, int) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vdupq_n(static_cast<T2>(0)); \ | |||
| c[1][step] = neon_vdupq_n(static_cast<T2>(0)); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int oc_step) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vld1q_f32(bias_ptr); \ | |||
| c[1][step] = vld1q_f32(bias_ptr + oc_step); | |||
| static void impl(T& c, const T2* bias_ptr, int oc_step) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vld1q(bias_ptr); \ | |||
| c[1][step] = neon_vld1q(bias_ptr + oc_step); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int oc_step) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vld1q_f32(bias_ptr); \ | |||
| c[1][step] = vld1q_f32(bias_ptr + oc_step); | |||
| static void impl(T& c, const T2* bias_ptr, int oc_step) { | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vld1q(bias_ptr); \ | |||
| c[1][step] = neon_vld1q(bias_ptr + oc_step); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int oc_step) { | |||
| static void impl(T& c, const T2* bias_ptr, int oc_step) { | |||
| constexpr int simd_len = 4; | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ | |||
| c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ | |||
| c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int oc_step) { | |||
| static void impl(T& c, const T2* bias_ptr, int oc_step) { | |||
| constexpr int simd_len = 4; | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = vld1q_f32(bias_ptr + step * simd_len); \ | |||
| c[1][step] = vld1q_f32(bias_ptr + oc_step + step * simd_len); | |||
| #define BAIS_INIT(step) \ | |||
| c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ | |||
| c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| @@ -591,57 +622,57 @@ struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t*, int) { | |||
| #define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); | |||
| static void impl(T& c, const T2*, int) { | |||
| #define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t*, int) { | |||
| #define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); | |||
| static void impl(T& c, const T2*, int) { | |||
| #define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast<T2>(0)); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int) { | |||
| #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); | |||
| static void impl(T& c, const T2* bias_ptr, int) { | |||
| #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int) { | |||
| #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); | |||
| static void impl(T& c, const T2* bias_ptr, int) { | |||
| #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int) { | |||
| static void impl(T& c, const T2* bias_ptr, int) { | |||
| constexpr int simd_len = 4; | |||
| #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); | |||
| #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); | |||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <typename T, typename T2> | |||
| struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { | |||
| static void impl(T& c, const float32_t* bias_ptr, int) { | |||
| static void impl(T& c, const T2* bias_ptr, int) { | |||
| constexpr int simd_len = 4; | |||
| #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr + step * simd_len); | |||
| #define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); | |||
| UNROLL_CALL_RAW(4, BAIS_INIT); | |||
| #undef BAIS_INIT | |||
| } | |||
| }; | |||
| template <int c_dim, BiasMode bias_mode, int ow_block, typename T, typename T2> | |||
| inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) { | |||
| inline void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { | |||
| InitOcxOw8<c_dim, bias_mode, ow_block, T, T2>::impl(c, bias_ptr, oc_step); | |||
| } | |||
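| // Usage sketch (illustration only): T2 is deduced from the bias pointer, so | |||
| // an int8 kernel can initialize int32 accumulators straight from an int32_t | |||
| // bias array: | |||
| template <BiasMode bias_mode> | |||
| static inline void example_init_s32(int32x4_t (&c)[2][8], | |||
| const int32_t* bias_ptr, int oc_step) { | |||
| // T = int32x4_t[2][8] and T2 = int32_t; the chosen specialization then | |||
| // calls the matching neon_vdupq_n / neon_vld1q overload. | |||
| init_ocx_ow8<2, bias_mode, 8>(c, bias_ptr, oc_step); | |||
| } | |||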
| /////////////////////init_ocx_ow4///////////////////// | |||
| @@ -55,6 +55,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44; | |||
| AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true}; | |||
| AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false}; | |||
| AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true}; | |||
| @@ -93,6 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| public: | |||
| AlgoPack() { | |||
| #if __ARM_FEATURE_DOTPROD | |||
| direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44); | |||
| direct_algos.emplace_back(&ds8_direct_stride1_large_group); | |||
| direct_algos.emplace_back(&ds8_direct_stride1_small_group); | |||
| direct_algos.emplace_back(&ds8_direct_stride2_large_group); | |||
| @@ -62,6 +62,7 @@ private: | |||
| class AlgoFP16WinogradF23_8x8; | |||
| #endif | |||
| #if __ARM_FEATURE_DOTPROD | |||
| class AlgoDotS8DirectNCHWNCHW44; | |||
| class AlgoDotS8DirectStride1; | |||
| class AlgoDotS8DirectStride2; | |||
| class AlgoDotU8DirectStride1; | |||
| @@ -60,6 +60,14 @@ struct Vfmaq_laneq_f32 { | |||
| return vfmaq_laneq_f32(a, b, v, lane); | |||
| } | |||
| }; | |||
| #if __ARM_FEATURE_DOTPROD | |||
| struct Vdotq_laneq_s32 { | |||
| template <const int lane> | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | |||
| return vdotq_laneq_s32(a, b, v, lane); | |||
| } | |||
| }; | |||
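| // Wrapping the intrinsic in a struct lets kernel templates take the | |||
| // multiply-accumulate operation as a type parameter while the lane index | |||
| // stays a compile-time constant, mirroring Vfmaq_laneq_f32 above. A minimal | |||
| // sketch (illustration only): | |||
| template <int lane> | |||
| static inline int32x4_t example_dot_step(int32x4_t acc, int8x16_t weight, | |||
| int8x16_t src) { | |||
| return Vdotq_laneq_s32::impl<lane>(acc, weight, src); | |||
| } | |||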
| #endif | |||
| } // namespace | |||
| } // namespace megdnn | |||
| @@ -481,37 +481,71 @@ UNROLL_CALL_RAW(4, cb); | |||
| #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec) | |||
| namespace { | |||
| template <int lane> | |||
| struct Vfmap_laneq_f32_armv7 { | |||
| struct Vfmaq_laneq_f32_armv7 { | |||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); | |||
| }; | |||
| template <> | |||
| struct Vfmap_laneq_f32_armv7<0> { | |||
| struct Vfmaq_laneq_f32_armv7<0> { | |||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||
| return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vfmap_laneq_f32_armv7<1> { | |||
| struct Vfmaq_laneq_f32_armv7<1> { | |||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||
| return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vfmap_laneq_f32_armv7<2> { | |||
| struct Vfmaq_laneq_f32_armv7<2> { | |||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||
| return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vfmap_laneq_f32_armv7<3> { | |||
| struct Vfmaq_laneq_f32_armv7<3> { | |||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||
| return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); | |||
| } | |||
| }; | |||
| } // namespace | |||
| #define vfmaq_laneq_f32(a, b, v, lane) \ | |||
| Vfmap_laneq_f32_armv7<lane>::impl(a, b, v) | |||
| Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v) | |||
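| // armv7 has no vfmaq_laneq_f32 intrinsic, so the macro above remaps each | |||
| // lane onto vmlaq_lane_f32 with the matching half of v; kernel code can then | |||
| // be written once for both architectures. Sketch (illustration only): | |||
| static inline float32x4_t example_fma_lane2(float32x4_t acc, float32x4_t b, | |||
| float32x4_t v) { | |||
| // expands to Vfmaq_laneq_f32_armv7<2>::impl(acc, b, v), i.e. | |||
| // vmlaq_lane_f32(acc, b, vget_high_f32(v), 0) | |||
| return vfmaq_laneq_f32(acc, b, v, 2); | |||
| } | |||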
| #if __ARM_FEATURE_DOTPROD | |||
| template <int lane> | |||
| struct Vdotq_laneq_s32_armv7 { | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v); | |||
| }; | |||
| template <> | |||
| struct Vdotq_laneq_s32_armv7<0> { | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | |||
| return vdotq_lane_s32(a, b, vget_low_s8(v), 0); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vdotq_laneq_s32_armv7<1> { | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | |||
| return vdotq_lane_s32(a, b, vget_low_s8(v), 1); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vdotq_laneq_s32_armv7<2> { | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | |||
| return vdotq_lane_s32(a, b, vget_high_s8(v), 0); | |||
| } | |||
| }; | |||
| template <> | |||
| struct Vdotq_laneq_s32_armv7<3> { | |||
| static int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { | |||
| return vdotq_lane_s32(a, b, vget_high_s8(v), 1); | |||
| } | |||
| }; | |||
| #define vdotq_laneq_s32(a, b, v, lane) \ | |||
| Vdotq_laneq_s32_armv7<lane>::impl(a, b, v) | |||
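| // The dot-product emulation follows the same pattern: vdotq_laneq_s32 does | |||
| // not exist on armv7, so v is split into halves and the lane remapped onto | |||
| // vdotq_lane_s32. Sketch (illustration only): | |||
| static inline int32x4_t example_dot_lane3(int32x4_t acc, int8x16_t b, | |||
| int8x16_t v) { | |||
| // expands to vdotq_lane_s32(acc, b, vget_high_s8(v), 1) | |||
| return vdotq_laneq_s32(acc, b, v, 3); | |||
| } | |||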
| #endif | |||
| #endif | |||
| @@ -109,14 +109,12 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { | |||
| .set_dtype(4, dtype::QuantizedS8(60.25)) | |||
| .set_display(false); | |||
| benchmarker_int.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
| "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384")); | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+")); | |||
| Benchmarker<ConvBias> benchmarker_float(handle); | |||
| benchmarker_float.set_display(false).set_times(RUNS); | |||
| benchmarker_float.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
| "IM2COLMATMUL:AARCH64_F32K8X12X1:192")); | |||
| conv_bias::ConvBiasAlgoChecker<ConvBias>("IM2COLMATMUL:.+")); | |||
| Benchmarker<ConvBias> benchmarker_nchw44(handle); | |||
| if (is_fp32) { | |||
| @@ -213,6 +211,15 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { | |||
| run(1, 256, 256, 14, 14, 3, 1, false); | |||
| run(1, 512, 512, 7, 7, 3, 1, false); | |||
| } else { | |||
| run(1, 1, 4, 112, 112, 2, 2, true); | |||
| run(1, 3, 32, 224, 224, 3, 2, true); | |||
| run(1, 3, 32, 224, 224, 5, 2, true); | |||
| run(1, 3, 64, 224, 224, 7, 2, true); | |||
| run(1, 1, 4, 112, 112, 2, 1, true); | |||
| run(1, 3, 32, 224, 224, 3, 1, true); | |||
| run(1, 3, 32, 224, 224, 5, 1, true); | |||
| run(1, 3, 64, 224, 224, 7, 1, true); | |||
| for (size_t stride : {1, 2}) { | |||
| printf("stride %zu\n", stride); | |||
| for (size_t filter_size : {2, 3, 5, 7}) { | |||
| @@ -228,9 +235,11 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { | |||
| } | |||
| TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { | |||
| benchmark_convbias(handle(), true); | |||
| benchmark_convbias(handle(), false); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | |||
| benchmark_convbias(handle(), true); | |||
| benchmark_convbias(handle(), false); | |||
| } | |||
| #endif | |||
| @@ -557,6 +557,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) { | |||
| /****************************dot qint8 direct*************************/ | |||
| #if __ARM_FEATURE_DOTPROD | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||
| checker_conv_bias_qint8x8x8( | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false, | |||
| true), | |||
| handle(), "ARMDOTS8_NCHW_NCHW44"); | |||
| checker_conv_bias_qint8x8x8( | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false, | |||
| true), | |||
| handle(), "ARMDOTS8_NCHW_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, | |||
| CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) { | |||
| checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||