GitOrigin-RevId: a28a97fcb5
tags/v1.10.0
| @@ -557,7 +557,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||||
| 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | ||||
| Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod: ' | Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod: ' | ||||
| '(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). If transposeA the ' | '(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). If transposeA the ' | ||||
| 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | |||||
| 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'), | |||||
| Doc('N32K4_DOT = 4', 'Split 32 from N and 4 from K, better for neon gevm dotprod: ' | |||||
| '(N/32, K/4, 32(n), 4(k))') | |||||
| ) | |||||
| ) | ) | ||||
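Note: the doc string above gives only the layout shorthand. As a hedged sketch, the element addressing it implies for the packed RHS matrix is the following (the helper name is illustrative, not part of this patch; K is assumed padded up to a multiple of 4, matching the zero-filled round_filter buffers used below):

    // Offset of element (n, k) under N32K4_DOT, i.e. (N/32, K/4, 32(n), 4(k)):
    // N-blocks of 32 columns, each holding K-blocks of 4 rows.
    static inline size_t n32k4_offset(size_t n, size_t k, size_t K) {
        const size_t round_k = (K + 3) / 4 * 4;  // K rounded up to 4
        return (n / 32) * 32 * round_k           // which N-block
             + (k / 4) * 32 * 4                  // which K-block inside it
             + (n % 32) * 4                      // column inside the N-block
             + (k % 4);                          // element inside the K-block
    }

This matches the index expression used by the im2col packing later in this patch: dst[bn * block_n * round_filter + bk * block_n * block_k + ni * block_k + ki].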
| (pdef('SVD'). | (pdef('SVD'). | ||||
| @@ -127,6 +127,37 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | ||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase { | |||||
| public: | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; } | |||||
| bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) | |||||
| const override; | |||||
| size_t get_workspace(const NCBKernSizeParam&) const override; | |||||
| virtual SmallVector<NCBKern> dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const override; | |||||
| ConvAlgoTypePack get_algo_type() const override { | |||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8) | |||||
| }; | |||||
| class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase { | |||||
| public: | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; } | |||||
| bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) | |||||
| const override; | |||||
| size_t get_workspace(const NCBKernSizeParam&) const override; | |||||
| virtual SmallVector<NCBKern> dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const override; | |||||
| ConvAlgoTypePack get_algo_type() const override { | |||||
| return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8) | |||||
| }; | |||||
| class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | ||||
| @@ -0,0 +1,270 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <arm_neon.h> | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel) | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
| using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
| using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; | |||||
| namespace { | |||||
| class DirectConvRunner { | |||||
| public: | |||||
| DirectConvRunner(size_t flt_size, size_t stride) { | |||||
| if (flt_size == 9 && stride == 1) { | |||||
| m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; | |||||
| } else { | |||||
| megdnn_assert(flt_size == 9 && stride == 2); | |||||
| m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; | |||||
| } | |||||
| } | |||||
| size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto FW = fm.spatial[1]; | |||||
| return round_up((size_t)FW, m_block_k); | |||||
| } | |||||
| size_t get_round_iw(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| size_t SW = fm.stride[1]; | |||||
| size_t OW = param.osz[1]; | |||||
| size_t round_ow = round_up(OW, m_block_ow); | |||||
| size_t round_fw = get_round_fw(param); | |||||
| size_t pad_iw = round_ow * SW - SW + round_fw; | |||||
| return round_up(pad_iw, m_align_iw); | |||||
| } | |||||
| size_t get_round_ih(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| size_t SH = fm.stride[0]; | |||||
| size_t OH = param.osz[0]; | |||||
| auto FH = fm.spatial[0]; | |||||
| size_t round_oh = round_up(OH, m_block_oh); | |||||
| return round_oh * SH - SH + FH; | |||||
| } | |||||
| WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto FH = fm.spatial[0]; | |||||
| size_t round_filter = get_round_fw(param) * FH; | |||||
| size_t round_ih = get_round_ih(param); | |||||
| size_t round_iw = get_round_iw(param); | |||||
| size_t pad_src = round_iw * round_ih; | |||||
| return {nullptr, {pad_src, round_filter}}; | |||||
| } | |||||
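Worked example (a sketch for the 9x9, stride-1 case this runner supports): round_fw = round_up(9, 4) = 12, so the packed filter occupies round_filter = 12 * 9 = 108 bytes; with OH = OW = 112, round_ow = round_up(112, 16) = 112, pad_iw = round_up(112 - 1 + 12, 16) = 128 and round_ih = round_up(112, 4) - 1 + 9 = 120, so get_sub_bundle reserves a 128 * 120 = 15360-byte padded source plus the 108-byte packed filter per thread.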
| WorkspaceBundle get_total_bundle( | |||||
| const ConvBiasImpl::NCBKernSizeParam& param) const { | |||||
| auto sub_bundle = get_sub_bundle(param); | |||||
| auto sub_bundle_size = sub_bundle.total_size_in_bytes(); | |||||
| size_t nr_threads = param.nr_threads; | |||||
| SmallVector<size_t> sizes_in_bytes; | |||||
| for (size_t i = 0; i < nr_threads; ++i) { | |||||
| sizes_in_bytes.push_back(sub_bundle_size); | |||||
| } | |||||
| WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); | |||||
| return total_bundle; | |||||
| } | |||||
| void run( | |||||
| const int8_t* pad_src_ptr, const int8_t* round_filter_ptr, int32_t bias, | |||||
| int8_t* dst_ptr, size_t OH, size_t OW, size_t pad_iw, float scale, | |||||
| int8_t relu_val) const { | |||||
| const size_t ow_end = OW / m_block_ow * m_block_ow; | |||||
| const size_t ow_remain = OW - ow_end; | |||||
| const size_t oh_end = OH / m_block_oh * m_block_oh; | |||||
| const size_t oh_remain = OH - oh_end; | |||||
| int8_t cache[4 * 16]; | |||||
| for (size_t oh = 0; oh < oh_end; oh += m_block_oh) { | |||||
| for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { | |||||
| m_func(pad_src_ptr, round_filter_ptr, bias, dst_ptr, oh, ow, OH, OW, | |||||
| pad_iw, scale, relu_val); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| m_func(pad_src_ptr, round_filter_ptr, bias, | |||||
| &cache[0] - (oh * m_block_ow + ow_end), oh, ow_end, OH, | |||||
| m_block_ow, pad_iw, scale, relu_val); | |||||
| for (size_t i = 0; i < m_block_oh; ++i) { | |||||
| for (size_t j = 0; j < ow_remain; ++j) { | |||||
| dst_ptr[(i + oh) * OW + (j + ow_end)] = cache[i * 16 + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oh_remain > 0) { | |||||
| for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { | |||||
| m_func(pad_src_ptr, round_filter_ptr, bias, | |||||
| &cache[0] - (oh_end * m_block_ow + ow), oh_end, ow, OH, | |||||
| m_block_ow, pad_iw, scale, relu_val); | |||||
| for (size_t i = 0; i < oh_remain; ++i) { | |||||
| for (size_t j = 0; j < m_block_ow; ++j) { | |||||
| dst_ptr[(i + oh_end) * OW + (j + ow)] = cache[i * 16 + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| m_func(pad_src_ptr, round_filter_ptr, bias, | |||||
| &cache[0] - (oh_end * m_block_ow + ow_end), oh_end, ow_end, OH, | |||||
| m_block_ow, pad_iw, scale, relu_val); | |||||
| for (size_t i = 0; i < oh_remain; ++i) { | |||||
| for (size_t j = 0; j < ow_remain; ++j) { | |||||
| dst_ptr[(i + oh_end) * OW + (j + ow_end)] = cache[i * 16 + j]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private: | |||||
| std::function<void( | |||||
| const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, | |||||
| size_t oh, size_t ow, size_t OH, size_t OW, size_t pad_iw, | |||||
| const float scale, int8_t relu_val)> | |||||
| m_func; | |||||
| size_t m_block_oh{4}; | |||||
| size_t m_block_ow{16}; | |||||
| size_t m_block_k{4}; | |||||
| size_t m_align_iw{16}; | |||||
| }; | |||||
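Note on the `&cache[0] - (oh * m_block_ow + ow_end)` style arguments in run() above: the kernel stores its 4x16 tile at dst + oh * OW + ow, so passing OW = m_block_ow together with a base biased by exactly -(oh * m_block_ow + ow_end) makes the tile land at cache[0]; the scalar loops then copy only the valid oh_remain/ow_remain fringe into the real output. Strictly speaking, forming a pointer before the start of cache is undefined behavior in C++ even though nothing outside cache is ever written.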
| void do_conv( | |||||
| const WorkspaceBundle& bundle, const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index, const DirectConvRunner& runner) { | |||||
| auto&& fm = kern_param.filter_meta; | |||||
| size_t PH = kern_param.filter_meta.padding[0]; | |||||
| size_t PW = kern_param.filter_meta.padding[1]; | |||||
| size_t OH = kern_param.osz[0]; | |||||
| size_t OW = kern_param.osz[1]; | |||||
| size_t IH = kern_param.isz[0]; | |||||
| size_t IW = kern_param.isz[1]; | |||||
| size_t FH = fm.spatial[0]; | |||||
| size_t FW = fm.spatial[1]; | |||||
| float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||||
| float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||||
| float scale_dst_div = 1.f / scale_dst; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0)); | |||||
| int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1)); | |||||
| const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id); | |||||
| const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id); | |||||
| const int8_t* fptr = kern_param.filter<dt_int8>(group_id); | |||||
| void* dst = kern_param.dst<void>(batch_id, group_id); | |||||
| size_t pad_iw = runner.get_round_iw(kern_param); | |||||
| memset(pad_src_ptr, 0, bundle.get_size(0)); | |||||
| rep(ih, IH) { | |||||
| std::memcpy( | |||||
| pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, | |||||
| sizeof(int8_t) * IW); | |||||
| } | |||||
| memset(round_filter_ptr, 0, bundle.get_size(1)); | |||||
| size_t round_fw = runner.get_round_fw(kern_param); | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| std::memcpy(round_filter_ptr + fh * round_fw, fptr + fh * FW, FW); | |||||
| } | |||||
| int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128; | |||||
| int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; | |||||
| int8_t* dst_ptr = (int8_t*)dst; | |||||
| runner.run( | |||||
| pad_src_ptr, round_filter_ptr, bias_val, dst_ptr, OH, OW, pad_iw, | |||||
| scale_bias * scale_dst_div, relu_val); | |||||
| } | |||||
| } // namespace | |||||
| bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( | |||||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
| if (!cpuinfo_has_arm_neon_dot()) { | |||||
| return false; | |||||
| } | |||||
| auto&& fm = param.filter_meta; | |||||
| auto FH = fm.spatial[0]; | |||||
| auto FW = fm.spatial[1]; | |||||
| auto SH = fm.stride[0]; | |||||
| auto SW = fm.stride[1]; | |||||
| auto noline = param.nonlineMode; | |||||
| auto bias_mode = param.bias_mode; | |||||
| bool available = | |||||
| //! src, filter and dst are all qint8 | |||||
| (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && | |||||
| fm.format == param::Convolution::Format::NCHW && !fm.should_flip && | |||||
| (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && | |||||
| (bias_mode == BiasMode::NO_BIAS || | |||||
| bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && | |||||
| fm.ocpg == 1; | |||||
| return available; | |||||
| } | |||||
| size_t ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::get_workspace( | |||||
| const NCBKernSizeParam& param) const { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, | |||||
| midout_iv("AlgoDotS8DirectChanWiseLarge::get_workspace"_hash)) { | |||||
| auto&& fm = param.filter_meta; | |||||
| DirectConvRunner runner(fm.spatial[0], fm.stride[0]); | |||||
| auto total_bundle = runner.get_total_bundle(param); | |||||
| return total_bundle.total_size_in_bytes(); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return 0; | |||||
| } | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8DirectChanWiseLarge:: | |||||
| dispatch_kerns(const NCBKernSizeParam& param) const { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, | |||||
| midout_iv("AlgoDotS8DirectChanWiseLarge::dispatch_kerns"_hash)) { | |||||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||||
| auto&& fm = param.filter_meta; | |||||
| DirectConvRunner runner(fm.spatial[0], fm.stride[0]); | |||||
| WorkspaceBundle wbundle = runner.get_sub_bundle(param); | |||||
| WorkspaceBundle total_bundle = runner.get_total_bundle(param); | |||||
| auto exec_one_group = [wbundle, total_bundle, runner]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| WorkspaceBundle temp_total_bundle = total_bundle; | |||||
| temp_total_bundle.set(kern_param.workspace_ptr); | |||||
| WorkspaceBundle temp_bundle = wbundle; | |||||
| temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); | |||||
| do_conv(temp_bundle, kern_param, ncb_index, runner); | |||||
| }; | |||||
| size_t N = param.n; | |||||
| size_t group = fm.group; | |||||
| ret_kerns.push_back({exec_one_group, {N, group}}); | |||||
| return ret_kerns; | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return {}; | |||||
| } | |||||
| #endif | |||||
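Note: get_workspace() and dispatch_kerns() follow a per-thread workspace idiom that the im2col algo below repeats: the total bundle reserves param.nr_threads copies of the single-thread sub-bundle, and the by-value, mutable lambda rebinds its bundle copies to kern_param.workspace_ptr plus the slot indexed by ncb_index.thread_id, so concurrently scheduled (batch, group) kernels never share scratch memory.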
| @@ -0,0 +1,425 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/arch.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel) | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||||
| using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||||
| using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; | |||||
| namespace { | |||||
| constexpr size_t block_n = 32; | |||||
| constexpr size_t block_k = 4; | |||||
| WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto OH = param.osz[0]; | |||||
| auto OW = param.osz[1]; | |||||
| size_t IH = param.isz[0]; | |||||
| size_t IW = param.isz[1]; | |||||
| auto FH = fm.spatial[0]; | |||||
| auto FW = fm.spatial[1]; | |||||
| size_t PH = param.filter_meta.padding[0]; | |||||
| size_t PW = param.filter_meta.padding[1]; | |||||
| size_t round_ohw = round_up((size_t)OH * OW, block_n); | |||||
| size_t round_filter = round_up((size_t)FW, block_k) * FH; | |||||
| size_t pad_src = (IW + PW * 2) * (IH + PH * 2); | |||||
| return {nullptr, | |||||
| {pad_src, round_filter, round_ohw * round_filter, | |||||
| round_ohw * sizeof(int32_t)}}; | |||||
| } | |||||
| WorkspaceBundle get_total_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| auto sub_bundle = get_sub_bundle(param); | |||||
| auto sub_bundle_size = sub_bundle.total_size_in_bytes(); | |||||
| size_t nr_threads = param.nr_threads; | |||||
| SmallVector<size_t> sizes_in_bytes; | |||||
| for (size_t i = 0; i < nr_threads; ++i) { | |||||
| sizes_in_bytes.push_back(sub_bundle_size); | |||||
| } | |||||
| WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); | |||||
| return total_bundle; | |||||
| } | |||||
| template <size_t flt_size, size_t stride> | |||||
| void im2col( | |||||
| const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||||
| size_t round_filter) { | |||||
| constexpr size_t FH = flt_size; | |||||
| constexpr size_t FW = flt_size; | |||||
| constexpr size_t SH = stride; | |||||
| constexpr size_t SW = stride; | |||||
| constexpr size_t FW_ROUND = (FW + 3) / 4 * 4; | |||||
| int bn = 0; | |||||
| int ni = 0; | |||||
| for (size_t oh = 0; oh < OH; ++oh) | |||||
| for (size_t ow = 0; ow < OW; ++ow) { | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| int bk = 0; | |||||
| int ki = 0; | |||||
| for (size_t fh = 0; fh < FH; ++fh) | |||||
| for (size_t fw = 0; fw < FW_ROUND; ++fw) { | |||||
| dst[bn * block_n * round_filter + bk * block_n * block_k + | |||||
| ni * block_k + ki] = src_n[fh * pad_iw + fw]; | |||||
| ++ki; | |||||
| if (ki == block_k) { | |||||
| ki = 0; | |||||
| ++bk; | |||||
| } | |||||
| } | |||||
| ++ni; | |||||
| if (ni == block_n) { | |||||
| ni = 0; | |||||
| ++bn; | |||||
| } | |||||
| } | |||||
| } | |||||
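The generic template above is the scalar reference for the packing; the explicit <9, 1> and <9, 2> specializations that follow should produce byte-identical output given the zero-initialized destination. A minimal cross-check sketch (im2col_ref_9x9s1 is a hypothetical helper, hand-instantiated from the template body because the explicit specializations shadow it; block_n/block_k are the constants defined above):

    static void im2col_ref_9x9s1(
            const int8_t* src, int8_t* dst, size_t OH, size_t OW,
            size_t pad_iw, size_t round_filter) {
        constexpr size_t FH = 9, FW_ROUND = 12;  // round_up(9, block_k)
        size_t bn = 0, ni = 0;
        for (size_t oh = 0; oh < OH; ++oh)
            for (size_t ow = 0; ow < OW; ++ow) {
                const int8_t* src_n = src + oh * pad_iw + ow;  // SH = SW = 1
                size_t bk = 0, ki = 0;
                for (size_t fh = 0; fh < FH; ++fh)
                    for (size_t fw = 0; fw < FW_ROUND; ++fw) {
                        dst[bn * block_n * round_filter +
                            bk * block_n * block_k + ni * block_k + ki] =
                                src_n[fh * pad_iw + fw];
                        if (++ki == block_k) { ki = 0; ++bk; }
                    }
                if (++ni == block_n) { ni = 0; ++bn; }
            }
    }

Comparing its output buffer against im2col<9, 1> over a zero-padded random source is a cheap way to validate the NEON path.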
| template <> | |||||
| void im2col<9, 1>( | |||||
| const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||||
| size_t round_filter) { | |||||
| constexpr size_t FH = 9; | |||||
| constexpr size_t SH = 1; | |||||
| constexpr size_t SW = 1; | |||||
| constexpr size_t k_block_stride = block_k * block_n; | |||||
| constexpr size_t ow_block = 16; | |||||
| static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||||
| 2, 3, 4, 5, 3, 4, 5, 6}; | |||||
| static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, | |||||
| 6, 7, 8, 9, 7, 8, 9, 10}; | |||||
| static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, | |||||
| 10, 11, 12, 13, 11, 12, 13, 14}; | |||||
| uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||||
| uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||||
| uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); | |||||
| int bn = 0; | |||||
| int ni = 0; | |||||
| for (size_t oh = 0; oh < OH; ++oh) | |||||
| for (size_t ow = 0; ow < OW;) { | |||||
| if (ow + ow_block <= OW && ni + ow_block <= block_n) { | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| int8x16_t read_w[2]; | |||||
| read_w[0] = vld1q_s8(src_n); | |||||
| read_w[1] = vld1q_s8(src_n + 16); | |||||
| int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||||
| int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||||
| int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); | |||||
| int8x16_t ncdef_0 = | |||||
| vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); | |||||
| int8x16_t n0123_1 = n4567_0; | |||||
| int8x16_t n4567_1 = n89ab_0; | |||||
| int8x16_t n89ab_1 = ncdef_0; | |||||
| int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||||
| int8x16_t n0123_2 = n89ab_0; | |||||
| int8x16_t n4567_2 = ncdef_0; | |||||
| int8x16_t n89ab_2 = ncdef_1; | |||||
| int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||||
| vst1q_s8(dst_n + 0 * 16, n0123_0); | |||||
| vst1q_s8(dst_n + 1 * 16, n4567_0); | |||||
| vst1q_s8(dst_n + 2 * 16, n89ab_0); | |||||
| vst1q_s8(dst_n + 3 * 16, ncdef_0); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); | |||||
| dst_n += 3 * k_block_stride; | |||||
| src_n += pad_iw; | |||||
| } | |||||
| ni += ow_block; | |||||
| ow += ow_block; | |||||
| if (ni == block_n) { | |||||
| ni = 0; | |||||
| ++bn; | |||||
| } | |||||
| } else { | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| int8x16_t read_w[1]; | |||||
| read_w[0] = vld1q_s8(src_n); | |||||
| int32x4_t read_s32 = vreinterpretq_s32_s8(read_w[0]); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 0 * k_block_stride), read_s32, 0); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 1 * k_block_stride), read_s32, 1); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 2 * k_block_stride), read_s32, 2); | |||||
| dst_n += 3 * k_block_stride; | |||||
| src_n += pad_iw; | |||||
| } | |||||
| ++ni; | |||||
| ++ow; | |||||
| if (ni == block_n) { | |||||
| ni = 0; | |||||
| ++bn; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <> | |||||
| void im2col<9, 2>( | |||||
| const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||||
| size_t round_filter) { | |||||
| constexpr size_t FH = 9; | |||||
| constexpr size_t SH = 2; | |||||
| constexpr size_t SW = 2; | |||||
| constexpr size_t k_block_stride = block_k * block_n; | |||||
| constexpr size_t ow_block = 16; | |||||
| static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, | |||||
| 4, 5, 6, 7, 6, 7, 8, 9}; | |||||
| static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, | |||||
| 8, 9, 10, 11, 10, 11, 12, 13}; | |||||
| uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||||
| uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||||
| int bn = 0; | |||||
| int ni = 0; | |||||
| for (size_t oh = 0; oh < OH; ++oh) | |||||
| for (size_t ow = 0; ow < OW;) { | |||||
| if (ow + ow_block <= OW && ni + ow_block <= block_n) { | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| int8x16_t read_w[3]; | |||||
| read_w[0] = vld1q_s8(src_n); | |||||
| read_w[1] = vld1q_s8(src_n + 16); | |||||
| read_w[2] = vld1q_s8(src_n + 32); | |||||
| int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); | |||||
| int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); | |||||
| int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||||
| int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); | |||||
| int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||||
| int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); | |||||
| int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||||
| int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); | |||||
| int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||||
| int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); | |||||
| int8x16_t n0123_2 = n4567_0; | |||||
| int8x16_t n4567_2 = n89ab_0; | |||||
| int8x16_t n89ab_2 = ncdef_0; | |||||
| int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||||
| vst1q_s8(dst_n + 0 * 16, n0123_0); | |||||
| vst1q_s8(dst_n + 1 * 16, n4567_0); | |||||
| vst1q_s8(dst_n + 2 * 16, n89ab_0); | |||||
| vst1q_s8(dst_n + 3 * 16, ncdef_0); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); | |||||
| vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); | |||||
| vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); | |||||
| dst_n += 3 * k_block_stride; | |||||
| src_n += pad_iw; | |||||
| } | |||||
| ni += ow_block; | |||||
| ow += ow_block; | |||||
| if (ni == block_n) { | |||||
| ni = 0; | |||||
| ++bn; | |||||
| } | |||||
| } else { | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| int8x16_t read_w[1]; | |||||
| read_w[0] = vld1q_s8(src_n); | |||||
| int32x4_t read_s32 = vreinterpretq_s32_s8(read_w[0]); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 0 * k_block_stride), read_s32, 0); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 1 * k_block_stride), read_s32, 1); | |||||
| vst1q_lane_s32((int32_t*)(dst_n + 2 * k_block_stride), read_s32, 2); | |||||
| dst_n += 3 * k_block_stride; | |||||
| src_n += pad_iw; | |||||
| } | |||||
| ++ni; | |||||
| ++ow; | |||||
| if (ni == block_n) { | |||||
| ni = 0; | |||||
| ++bn; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
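Note on the vqtbl1q_s8 tables shared by both specializations: each 16-byte table extracts four consecutive 4-byte windows from a loaded source row, one window per packed output column. For stride 1, tbl_array_0 = {0,1,2,3, 1,2,3,4, 2,3,4,5, 3,4,5,6} produces filter taps 0..3 for output columns n..n+3, and tbl_array_1/tbl_array_2 shift the window by 4 and 8 taps; for stride 2 the windows step by two source bytes instead. The three stores per fh fill the three 4-wide k-blocks that cover the FW_ROUND = 12 taps of one filter row, which is why dst_n advances by 3 * k_block_stride.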
| void do_conv( | |||||
| const WorkspaceBundle& bundle, const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) { | |||||
| auto&& fm = kern_param.filter_meta; | |||||
| size_t PH = kern_param.filter_meta.padding[0]; | |||||
| size_t PW = kern_param.filter_meta.padding[1]; | |||||
| size_t OH = kern_param.osz[0]; | |||||
| size_t OW = kern_param.osz[1]; | |||||
| size_t IH = kern_param.isz[0]; | |||||
| size_t IW = kern_param.isz[1]; | |||||
| size_t FH = fm.spatial[0]; | |||||
| size_t FW = fm.spatial[1]; | |||||
| size_t SH = fm.stride[0]; | |||||
| float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||||
| float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||||
| float scale_dst_div = 1.f / scale_dst; | |||||
| size_t batch_id = ncb_index.ndrange_id[0]; | |||||
| size_t group_id = ncb_index.ndrange_id[1]; | |||||
| int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0)); | |||||
| int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1)); | |||||
| int8_t* im2col_ptr = static_cast<int8_t*>(bundle.get(2)); | |||||
| int32_t* i32_ptr = static_cast<int32_t*>(bundle.get(3)); | |||||
| const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id); | |||||
| const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id); | |||||
| const int8_t* fptr = kern_param.filter<dt_int8>(group_id); | |||||
| void* dst = kern_param.dst<void>(batch_id, group_id); | |||||
| size_t round_filter = round_up(FW, block_k) * FH; | |||||
| size_t pad_iw = IW + 2 * PW; | |||||
| memset(pad_src_ptr, 0, bundle.get_size(0)); | |||||
| rep(ih, IH) { | |||||
| std::memcpy( | |||||
| pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, | |||||
| sizeof(int8_t) * IW); | |||||
| } | |||||
| memset(round_filter_ptr, 0, bundle.get_size(1)); | |||||
| size_t fh_stride = round_up(FW, block_k); | |||||
| for (size_t fh = 0; fh < FH; ++fh) { | |||||
| std::memcpy(round_filter_ptr + fh * fh_stride, fptr + fh * FW, FW); | |||||
| } | |||||
| memset(im2col_ptr, 0, bundle.get_size(2)); | |||||
| if (SH == 1) { | |||||
| im2col<9, 1>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); | |||||
| } else { | |||||
| im2col<9, 2>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); | |||||
| } | |||||
| gevm_naive_n32k4_dot( | |||||
| round_filter_ptr, im2col_ptr, i32_ptr, 1, OH * OW, round_filter, 0, 0, 0); | |||||
| int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; | |||||
| int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128; | |||||
| int8_t* dst_ptr = (int8_t*)dst; | |||||
| for (size_t i = 0; i < OH * OW; ++i) { | |||||
| //! optimize by tbl | |||||
| int val = roundf(scale_bias * scale_dst_div * (i32_ptr[i] + bias_val)); | |||||
| val = val < -128 ? -128 : val; | |||||
| val = val > 127 ? 127 : val; | |||||
| val = val > relu_val ? val : relu_val; | |||||
| dst_ptr[i] = val; | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
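Note: unlike the direct path, requantization here is a plain scalar loop; the `//! optimize by tbl` comment marks it as a known follow-up. Semantically it matches the quant_store_s8 helper used by the direct kernels later in this patch: round to nearest (ties away from zero), saturate to int8, then clamp from below by relu_val.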
| bool ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::usable( | |||||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
| if (!cpuinfo_has_arm_neon_dot()) { | |||||
| return false; | |||||
| } | |||||
| auto&& fm = param.filter_meta; | |||||
| auto FH = fm.spatial[0]; | |||||
| auto FW = fm.spatial[1]; | |||||
| auto SH = fm.stride[0]; | |||||
| auto SW = fm.stride[1]; | |||||
| auto noline = param.nonlineMode; | |||||
| auto bias_mode = param.bias_mode; | |||||
| bool available = | |||||
| //! src, filter and dst are all qint8 | |||||
| (param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && | |||||
| fm.format == param::Convolution::Format::NCHW && !fm.should_flip && | |||||
| (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && | |||||
| (bias_mode == BiasMode::NO_BIAS || | |||||
| bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && | |||||
| fm.ocpg == 1; | |||||
| return available; | |||||
| } | |||||
| size_t ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::get_workspace( | |||||
| const NCBKernSizeParam& param) const { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, | |||||
| midout_iv("AlgoDotS8Im2colChanWiseLarge::get_workspace"_hash)) { | |||||
| auto bundle = get_total_bundle(param); | |||||
| return bundle.total_size_in_bytes(); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return 0; | |||||
| } | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge:: | |||||
| dispatch_kerns(const NCBKernSizeParam& param) const { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, | |||||
| midout_iv("AlgoDotS8Im2colChanWiseLarge::dispatch_kerns"_hash)) { | |||||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||||
| auto fm = param.filter_meta; | |||||
| size_t N = param.n; | |||||
| size_t group = fm.group; | |||||
| WorkspaceBundle wbundle = get_sub_bundle(param); | |||||
| WorkspaceBundle total_bundle = get_total_bundle(param); | |||||
| auto exec_one_group = [wbundle, total_bundle]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| WorkspaceBundle temp_total_bundle = total_bundle; | |||||
| temp_total_bundle.set(kern_param.workspace_ptr); | |||||
| WorkspaceBundle temp_bundle = wbundle; | |||||
| temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); | |||||
| do_conv(temp_bundle, kern_param, ncb_index); | |||||
| }; | |||||
| ret_kerns.push_back({exec_one_group, {N, group}}); | |||||
| return ret_kerns; | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return {}; | |||||
| } | |||||
| #endif | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/arch.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( | |||||
| const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||||
| size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||||
| int8_t relu_val); | |||||
| void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||||
| const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||||
| size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||||
| int8_t relu_val); | |||||
| #endif | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/arch.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| static inline void quant_store_s8( | |||||
| float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3, int8_t* ptr, | |||||
| int8x16_t relu_reg) { | |||||
| int32x4_t i0 = vcvtaq_s32_f32(v0); | |||||
| int32x4_t i1 = vcvtaq_s32_f32(v1); | |||||
| int32x4_t i2 = vcvtaq_s32_f32(v2); | |||||
| int32x4_t i3 = vcvtaq_s32_f32(v3); | |||||
| int16x4_t i16_0 = vqmovn_s32(i0); | |||||
| int16x4_t i16_1 = vqmovn_s32(i1); | |||||
| int16x4_t i16_2 = vqmovn_s32(i2); | |||||
| int16x4_t i16_3 = vqmovn_s32(i3); | |||||
| int8x8_t i8_0 = vqmovn_s16(vcombine_s16(i16_0, i16_1)); | |||||
| int8x8_t i8_1 = vqmovn_s16(vcombine_s16(i16_2, i16_3)); | |||||
| int8x16_t rst = vcombine_s8(i8_0, i8_1); | |||||
| rst = vmaxq_s8(rst, relu_reg); | |||||
| vst1q_s8(ptr, rst); | |||||
| } | |||||
| #endif | |||||
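A scalar sketch of what quant_store_s8 computes per lane, to make the semantics explicit (assumes <cmath> and <algorithm>; vcvtaq_s32_f32 rounds to nearest with ties away from zero, the two vqmovn steps saturate through int16 to int8, and vmaxq_s8 applies the ReLU floor):

    static inline int8_t quant_one_ref(float v, int8_t relu_val) {
        float r = std::round(v);                   // ties away from zero
        r = std::min(std::max(r, -128.f), 127.f);  // saturate to int8 range
        int8_t q = static_cast<int8_t>(r);
        return q > relu_val ? q : relu_val;        // relu_val is 0 or -128
    }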
| @@ -0,0 +1,221 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/arch.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( | |||||
| const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||||
| size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||||
| int8_t relu_val) { | |||||
| //! 4x16 | |||||
| const size_t SH = 1; | |||||
| const size_t SW = 1; | |||||
| static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||||
| 2, 3, 4, 5, 3, 4, 5, 6}; | |||||
| static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, | |||||
| 6, 7, 8, 9, 7, 8, 9, 10}; | |||||
| static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, | |||||
| 10, 11, 12, 13, 11, 12, 13, 14}; | |||||
| uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||||
| uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||||
| uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| //! init | |||||
| int32x4_t c[4][4]; | |||||
| #define cb(step) \ | |||||
| c[step][0] = vdupq_n_s32(bias); \ | |||||
| c[step][1] = vdupq_n_s32(bias); \ | |||||
| c[step][2] = vdupq_n_s32(bias); \ | |||||
| c[step][3] = vdupq_n_s32(bias); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| int8x16_t flt[4]; | |||||
| flt[0] = vld1q_s8(weight + 0 * 16); | |||||
| flt[1] = vld1q_s8(weight + 1 * 16); | |||||
| flt[2] = vld1q_s8(weight + 2 * 16); | |||||
| flt[3] = vld1q_s8(weight + 3 * 16); | |||||
| //! row 0 | |||||
| int8x16_t read_w[2]; | |||||
| read_w[0] = vld1q_s8(src_n + 0 * pad_iw); | |||||
| read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16); | |||||
| int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||||
| int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||||
| int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); | |||||
| int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); | |||||
| int8x16_t n0123_1 = n4567_0; | |||||
| int8x16_t n4567_1 = n89ab_0; | |||||
| int8x16_t n89ab_1 = ncdef_0; | |||||
| int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||||
| int8x16_t n0123_2 = n89ab_0; | |||||
| int8x16_t n4567_2 = ncdef_0; | |||||
| int8x16_t n89ab_2 = ncdef_1; | |||||
| int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||||
| #define CAL_C(oh, flt_start) \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); | |||||
| CAL_C(0, 0); | |||||
| //! row 1 | |||||
| #define LOAD_SRC(row_id) \ | |||||
| read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||||
| read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||||
| n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||||
| n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||||
| n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \ | |||||
| ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \ | |||||
| n0123_1 = n4567_0; \ | |||||
| n4567_1 = n89ab_0; \ | |||||
| n89ab_1 = ncdef_0; \ | |||||
| ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||||
| n0123_2 = n89ab_0; \ | |||||
| n4567_2 = ncdef_0; \ | |||||
| n89ab_2 = ncdef_1; \ | |||||
| ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||||
| LOAD_SRC(1); | |||||
| CAL_C(0, 3); | |||||
| CAL_C(1, 0); | |||||
| //! row 2 | |||||
| LOAD_SRC(2); | |||||
| CAL_C(0, 3 * 2); | |||||
| CAL_C(1, 3 * 1); | |||||
| CAL_C(2, 3 * 0); | |||||
| //! row 3 | |||||
| LOAD_SRC(3); | |||||
| CAL_C(0, 3 * 3); | |||||
| CAL_C(1, 3 * 2); | |||||
| CAL_C(2, 3 * 1); | |||||
| CAL_C(3, 3 * 0); | |||||
| //! row 4 | |||||
| LOAD_SRC(4); | |||||
| CAL_C(0, 3 * 4); | |||||
| CAL_C(1, 3 * 3); | |||||
| CAL_C(2, 3 * 2); | |||||
| CAL_C(3, 3 * 1); | |||||
| //! update flt 4 -> 0 | |||||
| flt[0] = vld1q_s8(weight + 4 * 16); | |||||
| //! row 5 | |||||
| LOAD_SRC(5); | |||||
| CAL_C(0, 3 * 5); | |||||
| CAL_C(1, 3 * 4); | |||||
| CAL_C(2, 3 * 3); | |||||
| CAL_C(3, 3 * 2); | |||||
| //! update flt 5 -> 1 | |||||
| flt[1] = vld1q_s8(weight + 5 * 16); | |||||
| //! row 6 | |||||
| LOAD_SRC(6); | |||||
| CAL_C(0, 3 * 6); | |||||
| CAL_C(1, 3 * 5); | |||||
| CAL_C(2, 3 * 4); | |||||
| CAL_C(3, 3 * 3); | |||||
| //! update flt 6 -> 2 | |||||
| flt[2] = vld1q_s8(weight + 6 * 16); | |||||
| //! row 7 | |||||
| LOAD_SRC(7); | |||||
| CAL_C(0, 3 * 7); | |||||
| CAL_C(1, 3 * 6); | |||||
| CAL_C(2, 3 * 5); | |||||
| CAL_C(3, 3 * 4); | |||||
| //! row 8 | |||||
| LOAD_SRC(8); | |||||
| CAL_C(0, 3 * 8); | |||||
| CAL_C(1, 3 * 7); | |||||
| CAL_C(2, 3 * 6); | |||||
| CAL_C(3, 3 * 5); | |||||
| //! row 9 | |||||
| LOAD_SRC(9); | |||||
| CAL_C(1, 3 * 8); | |||||
| CAL_C(2, 3 * 7); | |||||
| CAL_C(3, 3 * 6); | |||||
| //! row 10 | |||||
| LOAD_SRC(10); | |||||
| CAL_C(2, 3 * 8); | |||||
| CAL_C(3, 3 * 7); | |||||
| //! row 11 | |||||
| LOAD_SRC(11); | |||||
| CAL_C(3, 3 * 8); | |||||
| float32x4_t dst_reg[4][4]; | |||||
| #define cb(step) \ | |||||
| dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||||
| dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||||
| dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||||
| dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| #define cb(step) \ | |||||
| dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||||
| dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||||
| dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||||
| dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| int8_t* dst_store = dst + oh * OW + ow; | |||||
| int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||||
| #define cb(step) \ | |||||
| quant_store_s8( \ | |||||
| dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||||
| dst_store + step * OW, relu_reg); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| } | |||||
| #endif | |||||
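Note on the accumulation schedule above: each padded input row is loaded exactly once. Output row r (0..3) needs filter rows 0..8 against input rows r..r+8, so input row `row` issues CAL_C(r, 3 * (row - r)) for every output row still in flight. Filter group g = flt_start + i lives in register flt[g / 4 % 4] at lane g % 4, and flt[0..2] are refreshed from weight + 64/80/96 bytes as soon as their original groups are dead, keeping all 27 groups (108 padded weights) within four q-registers. The s2 kernel below uses the same schedule with output row r consuming input rows 2r..2r+8, which fits all weights in seven registers with no reloads.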
| @@ -0,0 +1,250 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "megdnn/arch.h" | |||||
| #if MGB_ENABLE_DOT | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||||
| const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||||
| size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||||
| int8_t relu_val) { | |||||
| //! 4x16 | |||||
| const size_t SH = 2; | |||||
| const size_t SW = 2; | |||||
| static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, | |||||
| 4, 5, 6, 7, 6, 7, 8, 9}; | |||||
| static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, | |||||
| 8, 9, 10, 11, 10, 11, 12, 13}; | |||||
| uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||||
| uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||||
| const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||||
| //! init | |||||
| int32x4_t c[4][4]; | |||||
| #define cb(step) \ | |||||
| c[step][0] = vdupq_n_s32(bias); \ | |||||
| c[step][1] = vdupq_n_s32(bias); \ | |||||
| c[step][2] = vdupq_n_s32(bias); \ | |||||
| c[step][3] = vdupq_n_s32(bias); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| constexpr int flt_reg = 7; | |||||
| constexpr int flt_per_reg = 4; | |||||
| int8x16_t flt[7]; | |||||
| flt[0] = vld1q_s8(weight + 0 * 16); | |||||
| flt[1] = vld1q_s8(weight + 1 * 16); | |||||
| flt[2] = vld1q_s8(weight + 2 * 16); | |||||
| flt[3] = vld1q_s8(weight + 3 * 16); | |||||
| flt[4] = vld1q_s8(weight + 4 * 16); | |||||
| flt[5] = vld1q_s8(weight + 5 * 16); | |||||
| flt[6] = vld1q_s8(weight + 6 * 16); | |||||
| #define CAL_C(oh, flt_start) \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 0) % flt_per_reg); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 0) % flt_per_reg); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 0) % flt_per_reg); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 0) % flt_per_reg); \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 1) % flt_per_reg); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 1) % flt_per_reg); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 1) % flt_per_reg); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 1) % flt_per_reg); \ | |||||
| c[oh][0] = vdotq_laneq_s32( \ | |||||
| c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 2) % flt_per_reg); \ | |||||
| c[oh][1] = vdotq_laneq_s32( \ | |||||
| c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 2) % flt_per_reg); \ | |||||
| c[oh][2] = vdotq_laneq_s32( \ | |||||
| c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 2) % flt_per_reg); \ | |||||
| c[oh][3] = vdotq_laneq_s32( \ | |||||
| c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||||
| (flt_start + 2) % flt_per_reg); | |||||
| #define LOAD_SRC(row_id) \ | |||||
| read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||||
| read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||||
| read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \ | |||||
| ext_8 = vextq_s8(read_w[0], read_w[1], 8); \ | |||||
| ext_24 = vextq_s8(read_w[1], read_w[2], 8); \ | |||||
| n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||||
| n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \ | |||||
| n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||||
| ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \ | |||||
| n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||||
| n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \ | |||||
| n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \ | |||||
| ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \ | |||||
| n0123_2 = n4567_0; \ | |||||
| n4567_2 = n89ab_0; \ | |||||
| n89ab_2 = ncdef_0; \ | |||||
| ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||||
| //! row 0 | |||||
| int8x16_t read_w[3]; | |||||
| read_w[0] = vld1q_s8(src_n); | |||||
| read_w[1] = vld1q_s8(src_n + 16); | |||||
| read_w[2] = vld1q_s8(src_n + 32); | |||||
| int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); | |||||
| int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); | |||||
| int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||||
| int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); | |||||
| int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||||
| int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); | |||||
| int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||||
| int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); | |||||
| int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||||
| int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); | |||||
| int8x16_t n0123_2 = n4567_0; | |||||
| int8x16_t n4567_2 = n89ab_0; | |||||
| int8x16_t n89ab_2 = ncdef_0; | |||||
| int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||||
| CAL_C(0, 0); | |||||
| //! row 1 | |||||
| LOAD_SRC(1); | |||||
| CAL_C(0, 3 * 1); | |||||
| //! row 2 | |||||
| LOAD_SRC(2); | |||||
| CAL_C(0, 3 * 2); | |||||
| CAL_C(1, 3 * 0); | |||||
| //! row 3 | |||||
| LOAD_SRC(3); | |||||
| CAL_C(0, 3 * 3); | |||||
| CAL_C(1, 3 * 1); | |||||
| //! row 4 | |||||
| LOAD_SRC(4); | |||||
| CAL_C(0, 3 * 4); | |||||
| CAL_C(1, 3 * 2); | |||||
| CAL_C(2, 3 * 0); | |||||
| //! row 5 | |||||
| LOAD_SRC(5); | |||||
| CAL_C(0, 3 * 5); | |||||
| CAL_C(1, 3 * 3); | |||||
| CAL_C(2, 3 * 1); | |||||
| //! row 6 | |||||
| LOAD_SRC(6); | |||||
| CAL_C(0, 3 * 6); | |||||
| CAL_C(1, 3 * 4); | |||||
| CAL_C(2, 3 * 2); | |||||
| CAL_C(3, 3 * 0); | |||||
| //! row 7 | |||||
| LOAD_SRC(7); | |||||
| CAL_C(0, 3 * 7); | |||||
| CAL_C(1, 3 * 5); | |||||
| CAL_C(2, 3 * 3); | |||||
| CAL_C(3, 3 * 1); | |||||
| //! row 8 | |||||
| LOAD_SRC(8); | |||||
| CAL_C(0, 3 * 8); | |||||
| CAL_C(1, 3 * 6); | |||||
| CAL_C(2, 3 * 4); | |||||
| CAL_C(3, 3 * 2); | |||||
| //! row 9 | |||||
| LOAD_SRC(9); | |||||
| CAL_C(1, 3 * 7); | |||||
| CAL_C(2, 3 * 5); | |||||
| CAL_C(3, 3 * 3); | |||||
| //! row 10 | |||||
| LOAD_SRC(10); | |||||
| CAL_C(1, 3 * 8); | |||||
| CAL_C(2, 3 * 6); | |||||
| CAL_C(3, 3 * 4); | |||||
| //! row 11 | |||||
| LOAD_SRC(11); | |||||
| CAL_C(2, 3 * 7); | |||||
| CAL_C(3, 3 * 5); | |||||
| //! row 12 | |||||
| LOAD_SRC(12); | |||||
| CAL_C(2, 3 * 8); | |||||
| CAL_C(3, 3 * 6); | |||||
| //! row 13 | |||||
| LOAD_SRC(13); | |||||
| CAL_C(3, 3 * 7); | |||||
| //! row 14 | |||||
| LOAD_SRC(14); | |||||
| CAL_C(3, 3 * 8); | |||||
| float32x4_t dst_reg[4][4]; | |||||
| #define cb(step) \ | |||||
| dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||||
| dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||||
| dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||||
| dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| #define cb(step) \ | |||||
| dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||||
| dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||||
| dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||||
| dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| int8_t* dst_store = dst + oh * OW + ow; | |||||
| int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||||
| #define cb(step) \ | |||||
| quant_store_s8( \ | |||||
| dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||||
| dst_store + step * OW, relu_reg); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| } | |||||
| #endif | |||||
| @@ -54,6 +54,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | ||||
| AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | ||||
| AlgoDotS8Im2colChanWiseLarge ds8_im2col_large_chanwise; | |||||
| AlgoDotS8DirectChanWiseLarge ds8_direct_large_chanwise; | |||||
| #endif | #endif | ||||
| AlgoI8x8x16Direct i8x8x16_direct; | AlgoI8x8x16Direct i8x8x16_direct; | ||||
| @@ -75,6 +77,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| public: | public: | ||||
| AlgoPack() { | AlgoPack() { | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| m_direct_algos.emplace_back(&ds8_direct_large_chanwise); | |||||
| m_direct_algos.emplace_back(&ds8_im2col_large_chanwise); | |||||
| m_direct_algos.emplace_back(&ds8_direct_stride1); | m_direct_algos.emplace_back(&ds8_direct_stride1); | ||||
| m_direct_algos.emplace_back(&ds8_direct_stride2); | m_direct_algos.emplace_back(&ds8_direct_stride2); | ||||
| m_direct_algos.emplace_back(&du8_direct_stride1); | m_direct_algos.emplace_back(&du8_direct_stride1); | ||||
| @@ -51,6 +51,8 @@ private: | |||||
| #endif | #endif | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| class AlgoDotS8DirectNCHWNCHW44; | class AlgoDotS8DirectNCHWNCHW44; | ||||
| class AlgoDotS8DirectChanWiseLarge; | |||||
| class AlgoDotS8Im2colChanWiseLarge; | |||||
| class AlgoDotS8DirectStride1; | class AlgoDotS8DirectStride1; | ||||
| class AlgoDotS8DirectStride2; | class AlgoDotS8DirectStride2; | ||||
| class AlgoDotU8DirectStride1; | class AlgoDotU8DirectStride1; | ||||
| @@ -143,8 +143,89 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||||
| const KernSizeParam&) const { | const KernSizeParam&) const { | ||||
| return int8x8x32_gemv_mk4_kern; | return int8x8x32_gemv_mk4_kern; | ||||
| } | } | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| namespace { | |||||
| void int8x8x32_gevm_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) { | |||||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
| const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
| auto Cptr = kern_param.C<dt_int32>(); | |||||
| gevm_naive_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
| } // anonymous namespace | |||||
| bool MatrixMulImpl::AlgoInt8x8x32GevmDot::usable( | |||||
| const KernSizeParam& kern_size_param) const { | |||||
| if (!cpuinfo_has_arm_neon_dot()) { | |||||
| return false; | |||||
| } | |||||
| auto M = kern_size_param.M; | |||||
| bool is_dtype_ok = | |||||
| kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | |||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||||
| kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||||
| kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||||
| return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
| kern_size_param.format == param::MatrixMul::Format::DEFAULT && is_dtype_ok && | |||||
| !kern_size_param.trA && !kern_size_param.trB && M == 1; | |||||
| } | |||||
| bool MatrixMulImpl::AlgoInt8x8x32GevmDot::preferred( | |||||
| const KernSizeParam&) const { | |||||
| return true; | |||||
| } | |||||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmDot::get_kern( | |||||
| const KernSizeParam&) const { | |||||
| return int8x8x32_gevm_dot_kern; | |||||
| } | |||||
| namespace { | |||||
| void int8x8x32_gevm_n32k4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
| MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) { | |||||
| auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
| auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
| const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
| auto Cptr = kern_param.C<dt_int32>(); | |||||
| gevm_naive_n32k4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
| } // anonymous namespace | |||||
| bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::usable( | |||||
| const KernSizeParam& kern_size_param) const { | |||||
| if (!cpuinfo_has_arm_neon_dot()) { | |||||
| return false; | |||||
| } | |||||
| auto M = kern_size_param.M; | |||||
| bool is_dtype_ok = | |||||
| kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | |||||
| (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||||
| kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
| (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||||
| kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||||
| return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
| kern_size_param.format == param::MatrixMul::Format::N32K4_DOT && | |||||
| is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && M == 1; | |||||
| } | |||||
| bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::preferred( | |||||
| const KernSizeParam&) const { | |||||
| return true; | |||||
| } | |||||
| MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::get_kern( | |||||
| const KernSizeParam&) const { | |||||
| return int8x8x32_gevm_n32k4_dot_kern; | |||||
| } | |||||
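A note on the layout this algo consumes: N32K4_DOT means B arrives pre-packed as (N/32, K/4, 32(n), 4(k)); the kernel itself never repacks. A minimal sketch of repacking a row-major (K, N) B into that layout — the helper name pack_b_n32k4 is ours, not part of this diff:

#include <cstddef>
#include <cstdint>

// Hypothetical repacking of row-major B[K][N] into the N32K4_DOT layout
// (N/32, K/4, 32(n), 4(k)); N % 32 == 0 and K % 4 == 0 are assumed, matching
// the asserts this diff adds to MatrixMulForward::check_exec.
void pack_b_n32k4(const int8_t* B, int8_t* out, size_t K, size_t N) {
    for (size_t nb = 0; nb < N / 32; ++nb)      // 32-column strip
        for (size_t kb = 0; kb < K / 4; ++kb)   // 4-deep k block
            for (size_t n = 0; n < 32; ++n)
                for (size_t k = 0; k < 4; ++k)
                    *out++ = B[(kb * 4 + k) * N + nb * 32 + n];
}

This ordering agrees with the naive reference added later in this diff, which reads B[n * K + k * 32 + i * 4 + j].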
| /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | ||||
| namespace { | namespace { | ||||
| void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
| @@ -49,8 +49,69 @@ public: | |||||
| MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | ||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | ||||
| }; | }; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| class MatrixMulImpl::AlgoInt8x8x32GevmDot : public AlgoBase { | |||||
| public: | |||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | |||||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_DOT"; } | |||||
| bool usable(const KernSizeParam&) const override; | |||||
| bool preferred(const KernSizeParam&) const override; | |||||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
| kern_t get_kern(const KernSizeParam&) const override; | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } | |||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
| MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||||
| WorkspaceBundle get_bundle(const KernSizeParam&) const override { | |||||
| return WorkspaceBundle{nullptr, {}}; | |||||
| } | |||||
| kern_naked_t get_kern_naked(const KernSizeParam&) const override { | |||||
| megdnn_assert(0, "naked kern no impl"); | |||||
| } | |||||
| void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) | |||||
| const override { | |||||
| megdnn_assert(0, "pack_A no impl"); | |||||
| } | |||||
| void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) | |||||
| const override { | |||||
| megdnn_assert(0, "pack_B no impl"); | |||||
| } | |||||
| InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_DOT) | |||||
| }; | |||||
| class MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot : public AlgoBase { | |||||
| public: | |||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | |||||
| const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"; } | |||||
| bool usable(const KernSizeParam&) const override; | |||||
| bool preferred(const KernSizeParam&) const override; | |||||
| size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
| kern_t get_kern(const KernSizeParam&) const override; | |||||
| AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } | |||||
| PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
| MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, N32K4_DOT) | |||||
| WorkspaceBundle get_bundle(const KernSizeParam&) const override { | |||||
| return WorkspaceBundle{nullptr, {}}; | |||||
| } | |||||
| kern_naked_t get_kern_naked(const KernSizeParam&) const override { | |||||
| megdnn_assert(0, "naked kern no impl"); | |||||
| } | |||||
| void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) | |||||
| const override { | |||||
| megdnn_assert(0, "pack_A no impl"); | |||||
| } | |||||
| void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) | |||||
| const override { | |||||
| megdnn_assert(0, "pack_B no impl"); | |||||
| } | |||||
| InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT) | |||||
| }; | |||||
| class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| @@ -2,6 +2,7 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | #include "src/arm_common/matrix_mul/int8/gemv.h" | ||||
| #include "src/common/unroll_macro.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -430,5 +431,398 @@ void arm_common::gemv_like_mk4_dot( | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| #if MGB_ENABLE_DOT | |||||
| namespace { | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gevm_naive_dot_impl( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||||
| bool load_c) { | |||||
| constexpr size_t n_block = 32; | |||||
| const size_t n_end = N / n_block * n_block; | |||||
| const size_t n_remain = N - n_end; | |||||
| constexpr size_t k_block = 4; | |||||
| constexpr size_t k_block_x2 = k_block * 2; | |||||
| const size_t k_end = (K / k_block_x2) * k_block_x2; | |||||
| const size_t k_remain = K - k_end; | |||||
| for (size_t n = 0; n < n_end; n += n_block) { | |||||
| if (K < k_block_x2) { | |||||
| if (!load_c) { | |||||
| for (size_t i = 0; i < n_block; ++i) { | |||||
| C[n + i] = 0; | |||||
| } | |||||
| } | |||||
| for (size_t k = 0; k < K; ++k) { | |||||
| for (size_t i = 0; i < n_block; ++i) { | |||||
| C[n + i] += A[k] * B[k * Bstride + n + i]; | |||||
| } | |||||
| } | |||||
| continue; | |||||
| } | |||||
| int32x4_t c[8]; | |||||
| if (load_c) { | |||||
| #define cb(step) c[step] = vld1q_s32(C + n + step * 4); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } else { | |||||
| #define cb(step) c[step] = vdupq_n_s32(0); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| int8x16_t a[2]; | |||||
| a[0] = vld1q_dup_s32(reinterpret_cast<const int32_t*>(A)); | |||||
| int8x16_t b[2][8]; | |||||
| #define cb(step) \ | |||||
| b[0][step * 2 + 0] = vld1q_s8(B + (0 + step) * Bstride + n); \ | |||||
| b[0][step * 2 + 1] = vld1q_s8(B + (0 + step) * Bstride + n + 16); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| size_t k_buffer_end = k_end - k_block_x2; | |||||
| for (size_t k = 0; k < k_buffer_end; k += k_block_x2) { | |||||
| //! double buffer main | |||||
| #define cb(step) \ | |||||
| b[1][step * 2 + 0] = vld1q_s8(B + (k + step + k_block) * Bstride + n); \ | |||||
| b[1][step * 2 + 1] = vld1q_s8(B + (k + step + k_block) * Bstride + n + 16); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| a[1] = vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k + k_block)); | |||||
| int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); | |||||
| int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); | |||||
| int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); | |||||
| int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); | |||||
| int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); | |||||
| int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); | |||||
| int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); | |||||
| int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); | |||||
| c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]); | |||||
| c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]); | |||||
| c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]); | |||||
| c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]); | |||||
| c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]); | |||||
| c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]); | |||||
| c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]); | |||||
| c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]); | |||||
| #define cb(step) \ | |||||
| b[0][step * 2 + 0] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n); \ | |||||
| b[0][step * 2 + 1] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n + 16); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| a[0] = vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k + k_block_x2)); | |||||
| ab0 = vzipq_s8(b[1][0], b[1][2]); | |||||
| cd0 = vzipq_s8(b[1][4], b[1][6]); | |||||
| ab1 = vzipq_s8(b[1][1], b[1][3]); | |||||
| cd1 = vzipq_s8(b[1][5], b[1][7]); | |||||
| abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); | |||||
| abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); | |||||
| abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); | |||||
| abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); | |||||
| c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]); | |||||
| c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]); | |||||
| c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]); | |||||
| c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]); | |||||
| c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]); | |||||
| c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]); | |||||
| c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]); | |||||
| c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]); | |||||
| } | |||||
| //! double buffer remain | |||||
| #define cb(step) \ | |||||
| b[1][step * 2 + 0] = vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n); \ | |||||
| b[1][step * 2 + 1] = \ | |||||
| vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n + 16); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| a[1] = vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k_buffer_end + k_block)); | |||||
| int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); | |||||
| int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); | |||||
| int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); | |||||
| int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); | |||||
| int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); | |||||
| int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); | |||||
| int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); | |||||
| int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); | |||||
| c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]); | |||||
| c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]); | |||||
| c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]); | |||||
| c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]); | |||||
| c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]); | |||||
| c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]); | |||||
| c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]); | |||||
| c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]); | |||||
| ab0 = vzipq_s8(b[1][0], b[1][2]); | |||||
| cd0 = vzipq_s8(b[1][4], b[1][6]); | |||||
| ab1 = vzipq_s8(b[1][1], b[1][3]); | |||||
| cd1 = vzipq_s8(b[1][5], b[1][7]); | |||||
| abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); | |||||
| abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); | |||||
| abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); | |||||
| abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); | |||||
| c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]); | |||||
| c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]); | |||||
| c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]); | |||||
| c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]); | |||||
| c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]); | |||||
| c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]); | |||||
| c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]); | |||||
| c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]); | |||||
| vst1q_s32(C + n + 0 * 4, c[0]); | |||||
| vst1q_s32(C + n + 1 * 4, c[1]); | |||||
| vst1q_s32(C + n + 2 * 4, c[2]); | |||||
| vst1q_s32(C + n + 3 * 4, c[3]); | |||||
| vst1q_s32(C + n + 4 * 4, c[4]); | |||||
| vst1q_s32(C + n + 5 * 4, c[5]); | |||||
| vst1q_s32(C + n + 6 * 4, c[6]); | |||||
| vst1q_s32(C + n + 7 * 4, c[7]); | |||||
| if (k_remain > 0) { | |||||
| for (size_t k = k_end; k < K; ++k) { | |||||
| for (size_t i = 0; i < n_block; ++i) { | |||||
| C[n + i] += A[k] * B[k * Bstride + n + i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (n_remain > 0) { | |||||
| for (size_t n = n_end; n < N; ++n) { | |||||
| if (!load_c) { | |||||
| C[n] = 0; | |||||
| } | |||||
| for (size_t k = 0; k < K; ++k) { | |||||
| C[n] += A[k] * B[k * Bstride + n]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
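The vzipq_s8/vzipq_s16 cascade above is effectively a 4x32 transpose: it gathers the four k-consecutive bytes of each output column into one 32-bit lane, the exact operand shape sdot reduces against the broadcast A word. A scalar model of that re-arrangement (illustrative only; the helper is not part of the diff):

// lane[i] ends up holding B[k+0..k+3][n+i]; vdotq_s32 then accumulates
// dot(lane[i], A[k..k+3]) into the i-th 32-bit accumulator lane, i.e. a
// length-4 slice of C[n+i]'s dot product.
static void zip_stage_model(
        const int8_t* B, size_t Bstride, size_t k, size_t n,
        int8_t lane[32][4]) {
    for (size_t i = 0; i < 32; ++i)
        for (size_t j = 0; j < 4; ++j)
            lane[i][j] = B[(k + j) * Bstride + n + i];
}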
| #if MEGDNN_ARMV7 | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gevm_naive_dot_n32k4_impl( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||||
| bool load_c) { | |||||
| //! input B must already be packed as (N/32, K/4, 32(n), 4(k)) | |||||
| //! TODO: add prefetch | |||||
| //! TODO: add double buffer | |||||
| constexpr size_t n_block = 32; | |||||
| constexpr size_t k_block = 4; | |||||
| for (size_t n = 0; n < N; n += n_block) { | |||||
| int32x4_t c[n_block / 4]; | |||||
| #define cb(step) c[step] = vdupq_n_s32(0); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| const int8_t* b_base = B + n * K; | |||||
| for (size_t k = 0; k < K; k += k_block) { | |||||
| int8x16_t a[1]; | |||||
| int8x16_t b[1][8]; | |||||
| #define cb(step) b[0][step] = vld1q_s8(b_base + k * 32 + 16 * step); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| a[0] = vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k)); | |||||
| c[0] = vdotq_s32(c[0], b[0][0], a[0]); | |||||
| c[1] = vdotq_s32(c[1], b[0][1], a[0]); | |||||
| c[2] = vdotq_s32(c[2], b[0][2], a[0]); | |||||
| c[3] = vdotq_s32(c[3], b[0][3], a[0]); | |||||
| c[4] = vdotq_s32(c[4], b[0][4], a[0]); | |||||
| c[5] = vdotq_s32(c[5], b[0][5], a[0]); | |||||
| c[6] = vdotq_s32(c[6], b[0][6], a[0]); | |||||
| c[7] = vdotq_s32(c[7], b[0][7], a[0]); | |||||
| } | |||||
| vst1q_s32(C + n + 0 * 4, c[0]); | |||||
| vst1q_s32(C + n + 1 * 4, c[1]); | |||||
| vst1q_s32(C + n + 2 * 4, c[2]); | |||||
| vst1q_s32(C + n + 3 * 4, c[3]); | |||||
| vst1q_s32(C + n + 4 * 4, c[4]); | |||||
| vst1q_s32(C + n + 5 * 4, c[5]); | |||||
| vst1q_s32(C + n + 6 * 4, c[6]); | |||||
| vst1q_s32(C + n + 7 * 4, c[7]); | |||||
| } | |||||
| } | |||||
| #else | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| inline void n32k4_dot( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t K) { | |||||
| int oddk = (K & 1); | |||||
| K = ((K + 1) / 2) - 1; | |||||
| //! C accumulators: v0-v7 | |||||
| //! A broadcast: v8-v9 | |||||
| //! B tiles: v10-v25 | |||||
| asm volatile( | |||||
| // zero the accumulators for C | |||||
| "1:\n" | |||||
| "eor v0.16b, v0.16b, v0.16b\n" | |||||
| "eor v1.16b, v1.16b, v1.16b\n" | |||||
| "eor v2.16b, v2.16b, v2.16b\n" | |||||
| "eor v3.16b, v3.16b, v3.16b\n" | |||||
| "eor v4.16b, v4.16b, v4.16b\n" | |||||
| "eor v5.16b, v5.16b, v5.16b\n" | |||||
| "eor v6.16b, v6.16b, v6.16b\n" | |||||
| "eor v7.16b, v7.16b, v7.16b\n" | |||||
| "ld1r {v8.4s}, [%[a_ptr]]\n" | |||||
| "ld1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[b_ptr]], 64\n" | |||||
| "ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[b_ptr]], 64\n" | |||||
| "add %[a_ptr], %[a_ptr], #4\n" | |||||
| "cmp %w[k], #0\n" | |||||
| "beq 4f\n" | |||||
| "2: \n" | |||||
| // Loop proper | |||||
| "3:\n" | |||||
| "ld1r {v9.4s}, [%[a_ptr]]\n" | |||||
| "sdot v0.4s, v10.16b, v8.16b\n" | |||||
| "ldr q18, [%[b_ptr], #0]\n" | |||||
| "sdot v1.4s, v11.16b, v8.16b\n" | |||||
| "ldr q19, [%[b_ptr], #16]\n" | |||||
| "sdot v2.4s, v12.16b, v8.16b\n" | |||||
| "ldr q20, [%[b_ptr], #32]\n" | |||||
| "add %[a_ptr], %[a_ptr], #4\n" | |||||
| "sdot v3.4s, v13.16b, v8.16b\n" | |||||
| "ldr q21, [%[b_ptr], #48]\n" | |||||
| "sdot v4.4s, v14.16b, v8.16b\n" | |||||
| "ldr q22, [%[b_ptr], #64]\n" | |||||
| "sdot v5.4s, v15.16b, v8.16b\n" | |||||
| "ldr q23, [%[b_ptr], #80]\n" | |||||
| "sdot v6.4s, v16.16b, v8.16b\n" | |||||
| "ldr q24, [%[b_ptr], #96]\n" | |||||
| "sdot v7.4s, v17.16b, v8.16b\n" | |||||
| "ldr q25, [%[b_ptr], #112]\n" | |||||
| "ld1r {v8.4s}, [%[a_ptr]]\n" | |||||
| "sdot v0.4s, v18.16b, v9.16b\n" | |||||
| "ldr q10, [%[b_ptr], #128]\n" | |||||
| "sdot v1.4s, v19.16b, v9.16b\n" | |||||
| "ldr q11, [%[b_ptr], #144]\n" | |||||
| "sdot v2.4s, v20.16b, v9.16b\n" | |||||
| "ldr q12, [%[b_ptr], #160]\n" | |||||
| "sdot v3.4s, v21.16b, v9.16b\n" | |||||
| "ldr q13, [%[b_ptr], #176]\n" | |||||
| "sdot v4.4s, v22.16b, v9.16b\n" | |||||
| "ldr q14, [%[b_ptr], #192]\n" | |||||
| "sdot v5.4s, v23.16b, v9.16b\n" | |||||
| "ldr q15, [%[b_ptr], #208]\n" | |||||
| "sdot v6.4s, v24.16b, v9.16b\n" | |||||
| "ldr q16, [%[b_ptr], #224]\n" | |||||
| "sdot v7.4s, v25.16b, v9.16b\n" | |||||
| "ldr q17, [%[b_ptr], #240]\n" | |||||
| "add %[a_ptr], %[a_ptr], #4\n" | |||||
| "add %[b_ptr], %[b_ptr], #256\n" | |||||
| "subs %w[k], %w[k], #1\n" | |||||
| "bne 3b\n" | |||||
| "4:\n" | |||||
| "cmp %w[oddk], #1\n" | |||||
| "beq 5f\n" | |||||
| // Even tail | |||||
| "ld1r {v9.4s}, [%[a_ptr]]\n" | |||||
| "sdot v0.4s, v10.16b, v8.16b\n" | |||||
| "ldr q18, [%[b_ptr], #0]\n" | |||||
| "sdot v1.4s, v11.16b, v8.16b\n" | |||||
| "ldr q19, [%[b_ptr], #16]\n" | |||||
| "sdot v2.4s, v12.16b, v8.16b\n" | |||||
| "ldr q20, [%[b_ptr], #32]\n" | |||||
| "sdot v3.4s, v13.16b, v8.16b\n" | |||||
| "ldr q21, [%[b_ptr], #48]\n" | |||||
| "sdot v4.4s, v14.16b, v8.16b\n" | |||||
| "ldr q22, [%[b_ptr], #64]\n" | |||||
| "sdot v5.4s, v15.16b, v8.16b\n" | |||||
| "ldr q23, [%[b_ptr], #80]\n" | |||||
| "sdot v6.4s, v16.16b, v8.16b\n" | |||||
| "ldr q24, [%[b_ptr], #96]\n" | |||||
| "sdot v7.4s, v17.16b, v8.16b\n" | |||||
| "ldr q25, [%[b_ptr], #112]\n" | |||||
| "sdot v0.4s, v18.16b, v9.16b\n" | |||||
| "sdot v1.4s, v19.16b, v9.16b\n" | |||||
| "sdot v2.4s, v20.16b, v9.16b\n" | |||||
| "sdot v3.4s, v21.16b, v9.16b\n" | |||||
| "sdot v4.4s, v22.16b, v9.16b\n" | |||||
| "sdot v5.4s, v23.16b, v9.16b\n" | |||||
| "sdot v6.4s, v24.16b, v9.16b\n" | |||||
| "sdot v7.4s, v25.16b, v9.16b\n" | |||||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" | |||||
| "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" | |||||
| "b 6f\n" | |||||
| "5:\n" | |||||
| // Odd tail | |||||
| "sdot v0.4s, v10.16b, v8.16b\n" | |||||
| "sdot v1.4s, v11.16b, v8.16b\n" | |||||
| "sdot v2.4s, v12.16b, v8.16b\n" | |||||
| "sdot v3.4s, v13.16b, v8.16b\n" | |||||
| "sdot v4.4s, v14.16b, v8.16b\n" | |||||
| "sdot v5.4s, v15.16b, v8.16b\n" | |||||
| "sdot v6.4s, v16.16b, v8.16b\n" | |||||
| "sdot v7.4s, v17.16b, v8.16b\n" | |||||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" | |||||
| "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" | |||||
| "6:\n" | |||||
| : [a_ptr] "+r"(A), [b_ptr] "+r"(B), [k] "+r"(K), [c_ptr] "+r"(C), | |||||
| [oddk] "+r"(oddk) | |||||
| : | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "cc", "memory"); | |||||
| } | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| void gevm_naive_dot_n32k4_impl( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||||
| bool load_c) { | |||||
| //! input B must already be packed as (N/32, K/4, 32(n), 4(k)) | |||||
| //! TODO: add prefetch | |||||
| //! TODO: add double buffer | |||||
| constexpr size_t n_block = 32; | |||||
| for (size_t n = 0; n < N; n += n_block) { | |||||
| n32k4_dot(A, B + n * K, C + n, K / 4); | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } // namespace | |||||
| void arm_common::gevm_naive_dot( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||||
| megdnn_assert(M == 1); | |||||
| MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot"_hash)) { | |||||
| size_t cache_size = 256 * 1024; | |||||
| size_t k_group = N * K / cache_size; | |||||
| constexpr size_t k_align = 8; | |||||
| if (k_group >= 2) { | |||||
| size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align; | |||||
| for (size_t k = 0; k < K; k += k_per_group) { | |||||
| size_t real_k = std::min(K - k, k_per_group); | |||||
| gevm_naive_dot_impl( | |||||
| A + k, B + k * Bstride, C, M, N, real_k, Astride, Bstride, | |||||
| Cstride, k != 0); | |||||
| } | |||||
| } else { | |||||
| gevm_naive_dot_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); | |||||
| } | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
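The grouping above splits K so that each pass touches a slice of B near the assumed 256 KiB cache budget, accumulating into C via load_c on every group after the first. A worked example of the arithmetic (sizes chosen by us for illustration):

#include <algorithm>
#include <cstddef>
#include <cstdio>

// For N = 1024, K = 2048 (int8 B is 2 MiB): k_group = 2 MiB / 256 KiB = 8,
// k_per_group = round_up(2048 / 8, 8) = 256, so eight passes each stream a
// ~256 KiB slice of B and accumulate into the same C.
int main() {
    size_t N = 1024, K = 2048, cache_size = 256 * 1024, k_align = 8;
    size_t k_group = N * K / cache_size;
    size_t k_per_group = (K / k_group + k_align - 1) / k_align * k_align;
    for (size_t k = 0; k < K; k += k_per_group)
        std::printf("k=%zu real_k=%zu load_c=%d\n", k,
                    std::min(K - k, k_per_group), int(k != 0));
    return 0;
}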
| void arm_common::gevm_naive_n32k4_dot( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||||
| megdnn_assert(M == 1); | |||||
| MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot_nk4"_hash)) { | |||||
| gevm_naive_dot_n32k4_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -22,6 +22,14 @@ void gemv_like_mk4( | |||||
| void gemv_like_mk4_dot( | void gemv_like_mk4_dot( | ||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | ||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | ||||
| void gevm_naive_dot( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||||
| void gevm_naive_n32k4_dot( | |||||
| const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||||
| size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||||
| #endif | #endif | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| @@ -14,6 +14,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | ||||
| AlgoInt8x8x32GevmDot int8x8x32_gevm_dot; | |||||
| AlgoInt8x8x32GevmN32K4Dot int8x8x32_gevm_n32k4_dot; | |||||
| #endif | #endif | ||||
| AlgoGevm gevm; | AlgoGevm gevm; | ||||
| @@ -28,6 +30,8 @@ public: | |||||
| #endif | #endif | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | ||||
| m_all_algos.emplace_back(&int8x8x32_gevm_dot); | |||||
| m_all_algos.emplace_back(&int8x8x32_gevm_n32k4_dot); | |||||
| #endif | #endif | ||||
| m_all_algos.emplace_back(&int8x8x32_gemv); | m_all_algos.emplace_back(&int8x8x32_gemv); | ||||
| m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | ||||
| @@ -31,7 +31,9 @@ protected: | |||||
| class AlgoF16Gemv; | class AlgoF16Gemv; | ||||
| #endif | #endif | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT | |||||
| class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT | |||||
| class AlgoInt8x8x32GevmDot; // Arm_common Int8x8x32 Gevm NCHW DOT | |||||
| class AlgoInt8x8x32GevmN32K4Dot; // Arm_common Int8x8x32 Gevm N32K4_DOT | |||||
| #endif | #endif | ||||
| class AlgoInt8x8x16; // Arm_common Int 8x8x16 | class AlgoInt8x8x16; // Arm_common Int 8x8x16 | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -469,7 +469,7 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { | |||||
| #endif | #endif | ||||
| #if MEGDNN_ARMV7 | #if MEGDNN_ARMV7 | ||||
| __ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) { | |||||
| __ai int8x16_t vqtbl1q_s8(int8x16_t a, uint8x16_t idx) { | |||||
| int8x8_t src_low = vget_low_s8(a); | int8x8_t src_low = vget_low_s8(a); | ||||
| int8x8_t src_high = vget_high_s8(a); | int8x8_t src_high = vget_high_s8(a); | ||||
| return vcombine_s8( | return vcombine_s8( | ||||
| @@ -726,6 +726,13 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { | |||||
| asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :); | asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :); | ||||
| return a; | return a; | ||||
| } | } | ||||
| #if __ARM_ARCH < 8 | |||||
| __ai int32x4_t vcvtaq_s32_f32(float32x4_t val) { | |||||
| float32x4_t vinc0 = vbslq_f32( | |||||
| vcgeq_f32(val, vdupq_n_f32(0.f)), vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f)); | |||||
| return vcvtq_s32_f32(vaddq_f32(val, vinc0)); | |||||
| } | |||||
| #endif | |||||
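The fallback above emulates vcvta semantics (round to nearest, ties away from zero) on pre-ARMv8 cores by biasing with +/-0.5 before the truncating convert. The same rule in scalar form, for reference:

#include <cstdint>

// Scalar equivalent of the armv7 vcvtaq_s32_f32 fallback: bias by the sign's
// half, then truncate toward zero. Like the vector version, it inherits the
// usual float-addition caveats near representability boundaries.
static int32_t round_ties_away_from_zero(float v) {
    return static_cast<int32_t>(v >= 0.f ? v + 0.5f : v - 0.5f);
}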
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #undef __ARM_FEATURE_DOTPROD | #undef __ARM_FEATURE_DOTPROD | ||||
| #endif | #endif | ||||
| @@ -61,6 +61,15 @@ void MatrixMulForward::deduce_layout( | |||||
| "(transposed) B is (%zu,%zu)", | "(transposed) B is (%zu,%zu)", | ||||
| A0, A1, B0, B1); | A0, A1, B0, B1); | ||||
| C = TensorLayout(TensorShape({A0, B1}), C.dtype); | C = TensorLayout(TensorShape({A0, B1}), C.dtype); | ||||
| } else if (param().format == param::MatrixMul::Format::N32K4_DOT) { | |||||
| A0 = A.shape[0]; | |||||
| A1 = A.shape[1]; | |||||
| B0 = B.shape[0]; | |||||
| B1 = B.shape[1]; | |||||
| megdnn_assert(!m_param.transposeA && !m_param.transposeB); | |||||
| megdnn_assert(A0 == 1 && A1 % 4 == 0); | |||||
| megdnn_assert(B.ndim == 4); | |||||
| C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype); | |||||
| } else { | } else { | ||||
| auto do_deduce = [&](size_t pack_size) { | auto do_deduce = [&](size_t pack_size) { | ||||
| megdnn_assert( | megdnn_assert( | ||||
| @@ -132,6 +141,18 @@ void MatrixMulForward::check_exec( | |||||
| megdnn_assert(A0 == C0, "%s", errmsg().c_str()); | megdnn_assert(A0 == C0, "%s", errmsg().c_str()); | ||||
| megdnn_assert(B1 == C1, "%s", errmsg().c_str()); | megdnn_assert(B1 == C1, "%s", errmsg().c_str()); | ||||
| megdnn_assert(A1 == B0, "%s", errmsg().c_str()); | megdnn_assert(A1 == B0, "%s", errmsg().c_str()); | ||||
| } else if (param().format == param::MatrixMul::Format::N32K4_DOT) { | |||||
| size_t A0 = A.shape[0]; | |||||
| size_t A1 = A.shape[1]; | |||||
| size_t B2 = B.shape[2]; | |||||
| size_t B3 = B.shape[3]; | |||||
| megdnn_assert(!m_param.transposeA && !m_param.transposeB); | |||||
| megdnn_assert(A0 == 1 && A1 % 4 == 0); | |||||
| megdnn_assert(B.ndim == 4); | |||||
| megdnn_assert(B2 == 32 && B3 == 4); | |||||
| megdnn_assert_contiguous(A); | |||||
| megdnn_assert_contiguous(B); | |||||
| megdnn_assert_contiguous(C); | |||||
| } else { | } else { | ||||
| megdnn_assert_eq_size_t(A.ndim, 4_z); | megdnn_assert_eq_size_t(A.ndim, 4_z); | ||||
| megdnn_assert_eq_size_t(B.ndim, 3_z); | megdnn_assert_eq_size_t(B.ndim, 3_z); | ||||
| @@ -442,9 +442,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||||
| megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format)); | megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format)); | ||||
| } | } | ||||
| BiasMode bias_mode; | BiasMode bias_mode; | ||||
| //! a bias whose shape equals a dst with 1x1 spatial extent (e.g. dst | |||||
| //! {1, C, 1, 1} with bias {1, C, 1, 1}) is still BROADCAST_CHANNEL_BIAS | |||||
| bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1; | |||||
| if (bias.ndim == 0) { | if (bias.ndim == 0) { | ||||
| bias_mode = BiasMode::NO_BIAS; | bias_mode = BiasMode::NO_BIAS; | ||||
| } else if (bias.eq_shape(dst)) { | |||||
| } else if (bias.eq_shape(dst) && !dst_only_c) { | |||||
| bias_mode = BiasMode::BIAS; | bias_mode = BiasMode::BIAS; | ||||
| } else { | } else { | ||||
| //! just check the ndim, the detail shape check is in check_exec | //! just check the ndim, the detail shape check is in check_exec | ||||
| @@ -258,6 +258,9 @@ public: | |||||
| ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | ||||
| ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | ||||
| ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | ||||
| //! LARGE for large filter | |||||
| ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8, | |||||
| ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8, | |||||
| ARM_COMMON_DIRECT_STRD1_DOT_S8, | ARM_COMMON_DIRECT_STRD1_DOT_S8, | ||||
| ARM_COMMON_DIRECT_STRD2_DOT_S8, | ARM_COMMON_DIRECT_STRD2_DOT_S8, | ||||
| ARM_COMMON_DIRECT_NCHW44_DOT_S8, | ARM_COMMON_DIRECT_NCHW44_DOT_S8, | ||||
| @@ -195,11 +195,11 @@ MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( | |||||
| kern_size_param.trB = param().transposeB; | kern_size_param.trB = param().transposeB; | ||||
| kern_size_param.compute_mode = param().compute_mode; | kern_size_param.compute_mode = param().compute_mode; | ||||
| kern_size_param.format = param().format; | kern_size_param.format = param().format; | ||||
| size_t pack_size = MatrixMulForward::pack_size(param().format); | |||||
| kern_size_param.K *= pack_size; | |||||
| kern_size_param.M *= pack_size; | |||||
| if (param().format != Param::Format::N32K4_DOT) { | |||||
| size_t pack_size = MatrixMulForward::pack_size(param().format); | |||||
| kern_size_param.K *= pack_size; | |||||
| kern_size_param.M *= pack_size; | |||||
| } | |||||
| return kern_size_param; | return kern_size_param; | ||||
| } | } | ||||
| @@ -122,6 +122,8 @@ public: | |||||
| ARM_COMMON_INT8X8X32_GEMV, | ARM_COMMON_INT8X8X32_GEMV, | ||||
| ARM_COMMON_INT8X8X32_GEMV_MK4, | ARM_COMMON_INT8X8X32_GEMV_MK4, | ||||
| ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | ||||
| ARM_COMMON_INT8X8X32_GEVM_DOT, | |||||
| ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT, | |||||
| ARM_COMMON_F16_GEMV, | ARM_COMMON_F16_GEMV, | ||||
| ARM_COMMON_GEVM, | ARM_COMMON_GEVM, | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| @@ -175,6 +177,7 @@ public: | |||||
| enum class AlgoSet : uint32_t { | enum class AlgoSet : uint32_t { | ||||
| ALGO_TYPE_GEMM = 0, | ALGO_TYPE_GEMM = 0, | ||||
| ALGO_TYPE_GEMV = 1, | ALGO_TYPE_GEMV = 1, | ||||
| ALGO_TYPE_GEVM = 2, | |||||
| }; | }; | ||||
| enum class PackMode : uint32_t { | enum class PackMode : uint32_t { | ||||
| @@ -102,6 +102,34 @@ void run_matrix_mul_mk4_dot_tpl( | |||||
| } | } | ||||
| } | } | ||||
| template < | |||||
| typename itype, typename otype, bool transA, bool transB, | |||||
| typename comp_type = otype> | |||||
| void run_matrix_mul_n32k4_dot_tpl( | |||||
| const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K, | |||||
| size_t LDA, size_t, size_t, const DType& A_type, const DType& B_type) { | |||||
| Getter<itype, comp_type> getterA(A_type), getterB(B_type); | |||||
| megdnn_assert(!transA && !transB); | |||||
| for (size_t m = 0; m < M; ++m) { | |||||
| for (size_t n = 0; n < N; n += 32) { | |||||
| comp_type res[32] = {comp_type(0)}; | |||||
| for (size_t k = 0; k < K; k += 4) { | |||||
| for (size_t i = 0; i < 32; i++) { | |||||
| comp_type av, bv; | |||||
| for (size_t j = 0; j < 4; j++) { | |||||
| av = getterA(A[m * LDA + k + j]); | |||||
| bv = getterB(B[n * K + k * 32 + i * 4 + j]); | |||||
| res[i] += av * bv; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < 32; i++) { | |||||
| C[n + i] = res[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template < | template < | ||||
| typename itype, typename otype, bool transA, bool transB, | typename itype, typename otype, bool transA, bool transB, | ||||
| typename comp_type = otype> | typename comp_type = otype> | ||||
| @@ -251,6 +279,10 @@ void dispatch_ta_tb( | |||||
| return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | ||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | ||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ | static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ | ||||
| } else if (format == param::MatrixMul::Format::N32K4_DOT) { \ | |||||
| return run_matrix_mul_n32k4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||||
| static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ | |||||
| } else if (format == param::MatrixMul::Format::MK8) { \ | } else if (format == param::MatrixMul::Format::MK8) { \ | ||||
| return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | ||||
| static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | ||||
| @@ -160,7 +160,7 @@ static void benchmark_convbias( | |||||
| .set_display(false); | .set_display(false); | ||||
| } | } | ||||
| auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; | auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; | ||||
| #if MGB_ENBALE_DOT | |||||
| #if MGB_ENABLE_DOT | |||||
| if (!is_fp32) { | if (!is_fp32) { | ||||
| nchw44_algo_regx = ".*DOT.*"; | nchw44_algo_regx = ".*DOT.*"; | ||||
| } | } | ||||
| @@ -1626,7 +1626,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||||
| #endif | #endif | ||||
| #if MGB_ENBALE_DOT | |||||
| #if MGB_ENABLE_DOT | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | ||||
| // have to remove preferred restrict in usable func before run the benchmark | // have to remove preferred restrict in usable func before run the benchmark | ||||
| @@ -2011,6 +2011,80 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { | |||||
| using namespace conv_bias; | |||||
| std::vector<TestArg> args; | |||||
| auto run = [&](size_t group, size_t w, size_t h, size_t kernel, size_t stride, | |||||
| NonlineMode nonline_mode) { | |||||
| size_t p = kernel / 2; | |||||
| if (w + 2 * p < kernel || h + 2 * p < kernel) | |||||
| return; | |||||
| param::ConvBias param; | |||||
| param.stride_h = stride; | |||||
| param.stride_w = stride; | |||||
| param.pad_h = p; | |||||
| param.pad_w = p; | |||||
| param.nonlineMode = nonline_mode; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.sparse = ConvBiasForward::Param::Sparse::GROUP; | |||||
| //! channel bias | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, group, h, w}, | |||||
| TensorShape{group, 1, 1, kernel, kernel}, TensorShape{1, group, 1, 1}); | |||||
| }; | |||||
| run(64, 64, 64, 9, 1, NonlineMode::RELU); | |||||
| run(64, 40, 40, 9, 2, NonlineMode::RELU); | |||||
| run(64, 20, 20, 9, 1, NonlineMode::RELU); | |||||
| constexpr size_t RUN = 120; | |||||
| Benchmarker<ConvBias> benchmark0(handle()); | |||||
| benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
| .set_dtype(4, dtype::QuantizedS8(60.25f)); | |||||
| benchmark0.set_display(false); | |||||
| benchmark0.set_times(RUN); | |||||
| benchmark0.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
| "ARMDOTS8_DIRECT_CHANWISE_LARGE")); | |||||
| Benchmarker<ConvBias> benchmark1(handle()); | |||||
| benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
| .set_dtype(4, dtype::QuantizedS8(60.25f)); | |||||
| benchmark1.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
| "ARMDOTS8_IM2COL_CHANWISE_LARGE")); | |||||
| benchmark1.set_display(false); | |||||
| benchmark1.set_times(RUN); | |||||
| for (auto&& arg : args) { | |||||
| TensorLayout dst_layout; | |||||
| auto opr = handle()->create_operator<ConvBias>(); | |||||
| opr->param() = arg.param; | |||||
| opr->deduce_layout( | |||||
| {arg.src, dtype::Int8()}, {arg.filter, dtype::Int8()}, | |||||
| {arg.bias, dtype::Int32()}, {}, dst_layout); | |||||
| //! dst.nr_elems * FH * FW * 2 | |||||
| float computations = | |||||
| dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; | |||||
| auto used0 = benchmark0.set_param(arg.param).exec( | |||||
| {arg.src, arg.filter, arg.bias, {}, {}}) / | |||||
| RUN; | |||||
| auto used1 = benchmark1.set_param(arg.param).exec( | |||||
| {arg.src, arg.filter, arg.bias, {}, {}}) / | |||||
| RUN; | |||||
| printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " | |||||
| "speedup: %f\n", | |||||
| arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, | |||||
| computations / used0, used1, computations / used1, used1 / used0); | |||||
| } | |||||
| } | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| @@ -2194,7 +2268,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) { | |||||
| dtype::QuantizedS8 stype(2.5f); | dtype::QuantizedS8 stype(2.5f); | ||||
| dtype::QuantizedS32 dtype(6.25f); | dtype::QuantizedS32 dtype(6.25f); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if MGB_ENBALE_DOT | |||||
| #if MGB_ENABLE_DOT | |||||
| benchmark_conv1x1( | benchmark_conv1x1( | ||||
| "AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | "AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | ||||
| #else | #else | ||||
| @@ -2212,7 +2286,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) { | |||||
| dtype::QuantizedS32 dtype(1.2 * 1.2); | dtype::QuantizedS32 dtype(1.2 * 1.2); | ||||
| #if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
| #if MGB_ENBALE_DOT | |||||
| #if MGB_ENABLE_DOT | |||||
| benchmark_conv1x1( | benchmark_conv1x1( | ||||
| "AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | "AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | ||||
| #else | #else | ||||
| @@ -136,6 +136,84 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||||
| return args; | return args; | ||||
| } | } | ||||
| std::vector<conv_bias::TestArg> get_channel_wise_args( | |||||
| std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | |||||
| bool no_full_bias, bool support_relu) { | |||||
| using namespace conv_bias; | |||||
| using Param = param::ConvBias; | |||||
| using NLMode = param::ConvBias::NonlineMode; | |||||
| std::vector<TestArg> args; | |||||
| auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, | |||||
| size_t stride, NLMode nlmode, bool pad) { | |||||
| Param param; | |||||
| param.stride_h = stride; | |||||
| param.stride_w = stride; | |||||
| if (pad) { | |||||
| param.pad_h = kernel / 2; | |||||
| param.pad_w = kernel / 2; | |||||
| } else { | |||||
| param.pad_h = 0; | |||||
| param.pad_w = 0; | |||||
| } | |||||
| param.nonlineMode = nlmode; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| args.emplace_back( | |||||
| param, TensorShape{n, group, h, w}, | |||||
| TensorShape{group, 1, 1, kernel, kernel}, TensorShape{}); | |||||
| if (!no_bias) { | |||||
| args.emplace_back( | |||||
| param, TensorShape{n, group, h, w}, | |||||
| TensorShape{group, 1, 1, kernel, kernel}, | |||||
| TensorShape{1, group, 1, 1}); | |||||
| } | |||||
| if (!no_full_bias) { | |||||
| args.emplace_back( | |||||
| param, TensorShape{n, group, h, w}, | |||||
| TensorShape{group, 1, 1, kernel, kernel}, | |||||
| TensorShape{ | |||||
| n, group, (h + 2 * param.pad_h - kernel) / stride + 1, | |||||
| (w + 2 * param.pad_w - kernel) / stride + 1}); | |||||
| } | |||||
| }; | |||||
| std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||||
| if (!no_nonlinemode) { | |||||
| nonlinemode.emplace_back(NLMode::RELU); | |||||
| nonlinemode.emplace_back(NLMode::H_SWISH); | |||||
| } else if (support_relu) { | |||||
| nonlinemode.emplace_back(NLMode::RELU); | |||||
| } | |||||
| for (size_t n : {1, 2}) { | |||||
| for (auto nlmode : nonlinemode) { | |||||
| for (bool pad : {true}) { | |||||
| for (size_t group : {1, 3, 7}) { | |||||
| for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) { | |||||
| for (size_t kern : kernel) { | |||||
| pack(n, group, size, size, kern, stride, nlmode, pad); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (bool pad : {false}) { | |||||
| for (size_t group : {7}) { | |||||
| for (size_t size : {37}) { | |||||
| for (size_t kern : kernel) { | |||||
| if (size < kern) | |||||
| continue; | |||||
| pack(n, group, size, size, kern, stride, nlmode, pad); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return args; | |||||
| } | |||||
| std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args( | std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args( | ||||
| std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | ||||
| bool no_full_bias) { | bool no_full_bias) { | ||||
| @@ -226,7 +304,7 @@ void checker_conv_bias_qint8x8x8( | |||||
| .set_rng(1, &rng) | .set_rng(1, &rng) | ||||
| .set_rng(2, &rng); | .set_rng(2, &rng); | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); | |||||
| checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | } | ||||
| } | } | ||||
| void checker_conv_bias_qint8x8x32( | void checker_conv_bias_qint8x8x32( | ||||
| @@ -532,6 +610,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||||
| /****************************dot qint8 direct*************************/ | /****************************dot qint8 direct*************************/ | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { | |||||
| checker_conv_bias_qint8x8x8( | |||||
| get_channel_wise_args({9}, 1, false, true, true, true), handle(), | |||||
| "ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { | |||||
| checker_conv_bias_qint8x8x8( | |||||
| get_channel_wise_args({9}, 2, false, true, true, true), handle(), | |||||
| "ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) { | |||||
| checker_conv_bias_qint8x8x8( | |||||
| get_channel_wise_args({9}, 1, false, true, true, true), handle(), | |||||
| "ARMDOTS8_IM2COL_CHANWISE_LARGE"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) { | |||||
| checker_conv_bias_qint8x8x8( | |||||
| get_channel_wise_args({9}, 2, false, true, true, true), handle(), | |||||
| "ARMDOTS8_IM2COL_CHANWISE_LARGE"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | ||||
| auto args = get_nchw44_conv_bias_args( | auto args = get_nchw44_conv_bias_args( | ||||
| {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true); | {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true); | ||||
| @@ -219,6 +219,113 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) { | |||||
| for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) | for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) | ||||
| run(M, K, 1); | run(M, K, 1); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON, QINT8x8x32_GEVM_DOT) { | |||||
| Checker<MatrixMul> checker(handle()); | |||||
| using Param = MatrixMul::Param; | |||||
| auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_DOT"); | |||||
| checker.set_before_exec_callback(algo_ck); | |||||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||||
| checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||||
| Param param; | |||||
| param.format = Param::Format::DEFAULT; | |||||
| param.transposeA = false; | |||||
| param.transposeB = false; | |||||
| auto run = [&](size_t M, size_t N, size_t K) { | |||||
| TensorShape A, B; | |||||
| A = TensorShape{M, K}; | |||||
| B = TensorShape{K, N}; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, dtype::Int8()) | |||||
| .set_dtype(1, dtype::Int8()) | |||||
| .set_dtype(2, dtype::Int32()) | |||||
| .execs({A, B, {}}); | |||||
| }; | |||||
| run(1, 32, 4); | |||||
| for (int n = 7; n < 43; n += 3) { | |||||
| for (int k = 1; k < 33; k += 3) { | |||||
| run(1, n, k); | |||||
| } | |||||
| } | |||||
| } | |||||
| TEST_F(ARM_COMMON, QINT8x8x32_GEVM_N32K4_DOT) { | |||||
| Checker<MatrixMul> checker(handle()); | |||||
| using Param = MatrixMul::Param; | |||||
| auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); | |||||
| checker.set_before_exec_callback(algo_ck); | |||||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||||
| checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||||
| Param param; | |||||
| param.format = Param::Format::N32K4_DOT; | |||||
| param.transposeA = false; | |||||
| param.transposeB = false; | |||||
| auto run = [&](size_t M, size_t N, size_t K) { | |||||
| TensorShape A, B; | |||||
| A = TensorShape{M, K}; | |||||
| B = TensorShape{N / 32, K / 4, 32, 4}; | |||||
| checker.set_param(param) | |||||
| .set_dtype(0, dtype::Int8()) | |||||
| .set_dtype(1, dtype::Int8()) | |||||
| .set_dtype(2, dtype::Int32()) | |||||
| .execs({A, B, {}}); | |||||
| }; | |||||
| run(1, 32, 4); | |||||
| for (int n = 32; n < 65; n += 32) { | |||||
| for (int k = 4; k < 39; k += 4) { | |||||
| run(1, n, k); | |||||
| } | |||||
| } | |||||
| } | |||||
| #if MEGDNN_WITH_BENCHMARK | |||||
| TEST_F(ARM_COMMON, BENCHMARK_QINT8x8x32_GEVM_N32K4_DOT) { | |||||
| using Param = MatrixMul::Param; | |||||
| auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); | |||||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||||
| Param param; | |||||
| param.format = Param::Format::N32K4_DOT; | |||||
| param.transposeA = false; | |||||
| param.transposeB = false; | |||||
| constexpr size_t RUNS = 2000; | |||||
| Benchmarker<MatrixMul> benchmarker_int(handle()); | |||||
| benchmarker_int.set_times(RUNS) | |||||
| .set_dtype(0, dtype::Int8{}) | |||||
| .set_dtype(1, dtype::Int8{}) | |||||
| .set_dtype(2, dtype::Int32{}) | |||||
| .set_param(param) | |||||
| .set_before_exec_callback(algo_ck) | |||||
| .set_display(false); | |||||
| Benchmarker<MatrixMul> benchmarker_float(handle()); | |||||
| benchmarker_float.set_display(false).set_times(RUNS); | |||||
| auto bench = [&](size_t M, size_t N, size_t K) { | |||||
| auto int_used = | |||||
| benchmarker_int.exec({{M, K}, {N / 32, K / 4, 32, 4}, {}}) / RUNS; | |||||
| auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; | |||||
| float computations = 2.f * M * K * N * 1e-6; | |||||
| float throughput = (M * K + N * K + M * N) * 1e-6; | |||||
| printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " | |||||
| "%f Gflops speedup: %f, throughput %f G\n", | |||||
| M, K, N, float_used, computations / float_used, int_used, | |||||
| computations / int_used, float_used / int_used, throughput / int_used); | |||||
| }; | |||||
| bench(1, 256, 512); | |||||
| bench(1, 256, 1024); | |||||
| bench(1, 512, 512); | |||||
| bench(1, 512, 1024); | |||||
| } | |||||
| #endif | |||||
| #endif | #endif | ||||
| TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | ||||