nchw44 float 3x3 s2 fuse packb speed up about 10%
GitOrigin-RevId: 3f864cef1d
tags/v0.6.0
| @@ -227,28 +227,28 @@ public: | |||||
| "DefaultStrategyType::FLOAT"_hash); | "DefaultStrategyType::FLOAT"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44) { | } else if (format == param::ConvBias::Format::NCHW44) { | ||||
| #if MEGDNN_AARCH64 | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| auto matmul_block = matmul_algo->get_inner_block_size(); | auto matmul_block = matmul_algo->get_inner_block_size(); | ||||
| //! Optimize NCHW44 3x3s2 8X12X1 im2col+pack fuse | |||||
| if (matmul_block.m == 8 && matmul_block.n == 12 && | |||||
| matmul_block.k == 1 && | |||||
| param.filter_meta.spatial[0] == 3 && | |||||
| param.filter_meta.spatial[1] == 3 && | |||||
| param.filter_meta.stride[0] == 2 && | |||||
| param.filter_meta.stride[1] == 2 && | |||||
| !param.filter_meta.should_flip) { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_fallback_im2col_factory_make_strategy, | |||||
| midout_iv( | |||||
| "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) { | |||||
| return std::make_unique< | |||||
| StrategyFuse8x12x1Nchw44K3x3S2< | |||||
| float, float, | |||||
| PostprocessMode::FLOAT>>(); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return {}; | |||||
| //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse | |||||
| if ((matmul_block.m == 8 || matmul_block.m == 4) && | |||||
| matmul_block.n == 12 && matmul_block.k == 1 && | |||||
| param.filter_meta.spatial[0] == 3 && | |||||
| param.filter_meta.spatial[1] == 3 && | |||||
| param.filter_meta.stride[0] == 2 && | |||||
| param.filter_meta.stride[1] == 2 && | |||||
| !param.filter_meta.should_flip) { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_fallback_im2col_factory_make_strategy, | |||||
| midout_iv( | |||||
| "DefaultStrategyType::8x12x1_fuse_packb_s2_nchw44"_hash)) { | |||||
| return std::make_unique< | |||||
| StrategyFuseXx12x1Nchw44K3x3S2< | |||||
| float, float, | |||||
| PostprocessMode::FLOAT>>(); | |||||
| } | } | ||||
| MIDOUT_END(); | |||||
| return {}; | |||||
| } | |||||
| #endif | #endif | ||||
| cb1(NCHW44, DEFAULT, dt_float32, dt_float32, | cb1(NCHW44, DEFAULT, dt_float32, dt_float32, | ||||
| @@ -345,10 +345,10 @@ public: | |||||
| "DefaultStrategyType::QINT8x8x32x8"_hash); | "DefaultStrategyType::QINT8x8x32x8"_hash); | ||||
| } else if (format == param::ConvBias::Format::NCHW44 || | } else if (format == param::ConvBias::Format::NCHW44 || | ||||
| format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
| #if MEGDNN_AARCH64 | |||||
| auto matmul_block = matmul_algo->get_inner_block_size(); | |||||
| if (format == param::ConvBias::Format::NCHW44) { | if (format == param::ConvBias::Format::NCHW44) { | ||||
| //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse | //! Optimize NCHW44 3x3s1 4X4X16 im2col+pack fuse | ||||
| #if MEGDNN_AARCH64 | |||||
| auto matmul_block = matmul_algo->get_inner_block_size(); | |||||
| if (matmul_block.m == 4 && matmul_block.n == 4 && | if (matmul_block.m == 4 && matmul_block.n == 4 && | ||||
| matmul_block.k == 16 && | matmul_block.k == 16 && | ||||
| param.filter_meta.spatial[0] == 3 && | param.filter_meta.spatial[0] == 3 && | ||||
| @@ -368,7 +368,10 @@ public: | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| #endif | |||||
| } else { | } else { | ||||
| #if MEGDNN_AARCH64 | |||||
| auto matmul_block = matmul_algo->get_inner_block_size(); | |||||
| //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse | //! Optimize NCHW44_DOT 3x3s1 8X12X4 im2col+pack fuse | ||||
| if (matmul_block.m == 8 && matmul_block.n == 12 && | if (matmul_block.m == 8 && matmul_block.n == 12 && | ||||
| matmul_block.k == 4 && | matmul_block.k == 4 && | ||||
| @@ -389,8 +392,30 @@ public: | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| } | |||||
| #endif | #endif | ||||
| #if MEGDNN_ARMV7 | |||||
| auto matmul_block = matmul_algo->get_inner_block_size(); | |||||
| if (matmul_block.m == 8 && matmul_block.n == 4 && | |||||
| matmul_block.k == 4 && | |||||
| param.filter_meta.spatial[0] == 3 && | |||||
| param.filter_meta.spatial[1] == 3 && | |||||
| param.filter_meta.stride[0] == 2 && | |||||
| param.filter_meta.stride[1] == 2 && | |||||
| !param.filter_meta.should_flip) { | |||||
| MIDOUT_BEGIN( | |||||
| megdnn_fallback_im2col_factory_make_strategy, | |||||
| midout_iv( | |||||
| "DefaultStrategyType::INT8x8x32_8x4x4_s2"_hash)) { | |||||
| return std::make_unique< | |||||
| StrategyFuse8x4x4Nchw44DotK3x3S2< | |||||
| dt_qint32, dt_qint8, | |||||
| PostprocessMode::QUANTIZED>>(); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return {}; | |||||
| } | |||||
| #endif | |||||
| } | |||||
| cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | cb2(NCHW44, DEFAULT, dtype::QuantizedS8, | ||||
| dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, | ||||
| dt_int32, dt_int8, PostprocessMode::QUANTIZED, | dt_int32, dt_int8, PostprocessMode::QUANTIZED, | ||||
| @@ -488,12 +488,12 @@ public: | |||||
| template <typename op_ctype, typename op_dtype, | template <typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| class StrategyFuse8x12x1Nchw44K3x3S2 | |||||
| : public Strategy<float, float, float, op_ctype, op_dtype, | |||||
| class StrategyFuse8x12x4Nchw44Dot | |||||
| : public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT, | postprocess_mode, PackMode::DEFAULT, | ||||
| FormatMode::NCHW44> { | FormatMode::NCHW44> { | ||||
| public: | public: | ||||
| StrategyFuse8x12x1Nchw44K3x3S2() = default; | |||||
| StrategyFuse8x12x4Nchw44Dot() = default; | |||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
| @@ -508,16 +508,15 @@ public: | |||||
| fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
| const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | ||||
| }; | }; | ||||
| #else | |||||
| template <typename op_ctype, typename op_dtype, | template <typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| class StrategyFuse8x12x4Nchw44Dot | |||||
| class StrategyFuse8x4x4Nchw44DotK3x3S2 | |||||
| : public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype, | : public Strategy<dt_int8, dt_int32, dt_int8, op_ctype, op_dtype, | ||||
| postprocess_mode, PackMode::DEFAULT, | postprocess_mode, PackMode::DEFAULT, | ||||
| FormatMode::NCHW44> { | FormatMode::NCHW44> { | ||||
| public: | public: | ||||
| StrategyFuse8x12x4Nchw44Dot() = default; | |||||
| StrategyFuse8x4x4Nchw44DotK3x3S2() = default; | |||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | constexpr static size_t BUNDLE_PADDING_INDEX = 0; | ||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | constexpr static size_t BUNDLE_PACKA_INDEX = 1; | ||||
| @@ -534,6 +533,30 @@ public: | |||||
| }; | }; | ||||
| #endif | #endif | ||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| template <typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| class StrategyFuseXx12x1Nchw44K3x3S2 | |||||
| : public Strategy<float, float, float, op_ctype, op_dtype, | |||||
| postprocess_mode, PackMode::DEFAULT, | |||||
| FormatMode::NCHW44> { | |||||
| public: | |||||
| StrategyFuseXx12x1Nchw44K3x3S2() = default; | |||||
| constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||||
| constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||||
| constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||||
| constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||||
| constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||||
| void exec_im2col( | |||||
| const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| fallback::MatrixMulImpl::KernParam matmul_param, | |||||
| const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
| }; | |||||
| #endif | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -0,0 +1,208 @@ | |||||
| /** | |||||
| * \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot_s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | |||||
| #if MEGDNN_ARMV7 | |||||
| #include <arm_neon.h> | |||||
| using namespace megdnn; | |||||
| namespace { | |||||
| #define PACKB_ONELINE() \ | |||||
| int out_index = 0; \ | |||||
| outptr = output_base; \ | |||||
| for (; out_index + 3 < block_size; out_index += 4) { \ | |||||
| std::memcpy(outptr, tmp_output, 16); \ | |||||
| outptr += ksize4; \ | |||||
| tmp_output += 4; \ | |||||
| } \ | |||||
| \ | |||||
| if (out_index < block_size) { \ | |||||
| uint32_t zerobuffer[4] = {0}; \ | |||||
| size_t out_remain = std::min(block_size - out_index, 4); \ | |||||
| std::memcpy(outptr, tmp_output, out_remain * sizeof(uint32_t)); \ | |||||
| outptr += out_remain; \ | |||||
| std::memcpy(outptr, zerobuffer, (4 - out_remain) * sizeof(uint32_t)); \ | |||||
| } \ | |||||
| output_base += 4; | |||||
| #define STOR_IM2COL_DST() \ | |||||
| output0[count] = uint32_src[index]; \ | |||||
| output1[count] = uint32_src[index + 1]; \ | |||||
| output2[count] = uint32_src[index + 2]; \ | |||||
| count++; \ | |||||
| index += SW; | |||||
| #define LOAD_AND_STOR_IM2COL_DST() \ | |||||
| uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]); \ | |||||
| index += 8; \ | |||||
| uint32x4_t val_index8 = vdupq_n_u32(uint32_src[index]); \ | |||||
| uint32x4_t val_2 = vextq_u32(val_01.val[0], val_index8, 1); \ | |||||
| vst1q_u32(&output0[count], val_01.val[0]); \ | |||||
| vst1q_u32(&output1[count], val_01.val[1]); \ | |||||
| vst1q_u32(&output2[count], val_2); \ | |||||
| count += 4; | |||||
| void fuse_packb(const dt_int8* __restrict src, dt_int8* __restrict dst, | |||||
| dt_int8* __restrict b_panel, const int OW, const int IC, | |||||
| const int IH, const int IW, const int cur_index, | |||||
| const int block_size) { | |||||
| int start_h = cur_index / OW; | |||||
| int cur_remain_w = cur_index % OW; | |||||
| int end_h = (cur_index + block_size) / OW; | |||||
| int end_remain_w = (cur_index + block_size) % OW; | |||||
| bool same_line = start_h == end_h ? true : false; | |||||
| size_t newIC = IC / 4; | |||||
| const uint32_t* uint32_src = | |||||
| static_cast<const uint32_t*>(static_cast<const void*>(src)); | |||||
| uint32_t* output = static_cast<uint32_t*>(static_cast<void*>(dst)); | |||||
| uint32_t* b_output = static_cast<uint32_t*>(static_cast<void*>(b_panel)); | |||||
| const int packed_k = newIC * 3 * 3; | |||||
| const int ksize4 = packed_k * 4; | |||||
| uint32_t* outptr = b_output; | |||||
| uint32_t* output_base = b_output; | |||||
| constexpr int FH = 3; | |||||
| constexpr int SH = 2; | |||||
| constexpr int SW = 2; | |||||
| if (same_line) { | |||||
| rep(ic, newIC) { | |||||
| rep(fh, FH) { | |||||
| uint32_t* output02 = output; | |||||
| uint32_t* output1 = output + block_size + 1; | |||||
| size_t count = 0; | |||||
| size_t index = 0; | |||||
| int w = cur_remain_w; | |||||
| index = (ic * IH + (start_h * SH + fh)) * IW + w * SW; | |||||
| for (; w + 3 < end_remain_w; w += 4) { | |||||
| uint32x4x2_t val_01 = vld2q_u32(&uint32_src[index]); | |||||
| vst1q_u32(&output02[count], val_01.val[0]); | |||||
| vst1q_u32(&output1[count], val_01.val[1]); | |||||
| count += 4; | |||||
| index += 8; | |||||
| } | |||||
| for (; w < end_remain_w; w++) { | |||||
| output02[count] = uint32_src[index + 0]; | |||||
| output1[count] = uint32_src[index + 1]; | |||||
| count++; | |||||
| index += SW; | |||||
| } | |||||
| output02[count] = uint32_src[index]; | |||||
| const uint32_t* output_ptr[3]; | |||||
| output_ptr[0] = output02; | |||||
| output_ptr[1] = output1; | |||||
| output_ptr[2] = output02 + 1; | |||||
| for (int i = 0; i < 3; i++) { | |||||
| const uint32_t* tmp_output = output_ptr[i]; | |||||
| PACKB_ONELINE(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | |||||
| rep(ic, newIC) { | |||||
| rep(fh, FH) { | |||||
| size_t count = 0; | |||||
| size_t index = 0; | |||||
| uint32_t* output0 = output; | |||||
| uint32_t* output1 = output + block_size; | |||||
| uint32_t* output2 = output1 + block_size; | |||||
| int w = cur_remain_w; | |||||
| index = (ic * IH + (SH * start_h + fh)) * IW + SW * w; | |||||
| for (; w + 3 < OW; w += 4) { | |||||
| LOAD_AND_STOR_IM2COL_DST() | |||||
| } | |||||
| for (; w < OW; w++) { | |||||
| STOR_IM2COL_DST() | |||||
| } | |||||
| for (int h = start_h + 1; h < end_h; h++) { | |||||
| int ow = 0; | |||||
| index = (ic * IH + (SH * h + fh)) * IW; | |||||
| for (; ow + 3 < OW; ow += 4) { | |||||
| LOAD_AND_STOR_IM2COL_DST() | |||||
| } | |||||
| for (; ow < OW; ow++) { | |||||
| STOR_IM2COL_DST() | |||||
| } | |||||
| } | |||||
| index = (ic * IH + (SH * end_h + fh)) * IW; | |||||
| w = 0; | |||||
| for (; w + 3 < end_remain_w; w += 4) { | |||||
| LOAD_AND_STOR_IM2COL_DST() | |||||
| } | |||||
| for (; w < end_remain_w; w++) { | |||||
| STOR_IM2COL_DST() | |||||
| } | |||||
| for (int k = 0; k < 3; k++) { | |||||
| const uint32_t* tmp_output = output + k * block_size; | |||||
| PACKB_ONELINE(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #undef PACKB_ONELINE | |||||
| #undef STOR_IM2COL_DST | |||||
| #undef LOAD_AND_STOR_IM2COL_DST | |||||
| } // namespace | |||||
| template <typename op_ctype, typename op_dtype, | |||||
| megdnn::PostprocessMode postprocess_mode> | |||||
| void StrategyFuse8x4x4Nchw44DotK3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||||
| exec_im2col(const WorkspaceBundle& bundle, | |||||
| const WorkspaceBundle& bundle_thread, | |||||
| const StrategyParam& sparam, | |||||
| const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
| fallback::MatrixMulImpl::KernParam /*matmul_param*/, | |||||
| const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||||
| size_t ow = param.osz[1]; | |||||
| size_t ic = param.filter_meta.icpg; | |||||
| size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
| size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
| size_t input_offset = | |||||
| ih * iw * ic * | |||||
| (sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||||
| sizeof(dt_int8); | |||||
| dt_int8* src2 = reinterpret_cast<dt_int8*>( | |||||
| reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
| input_offset); | |||||
| bool is_phpwzero = param.filter_meta.padding[0] == 0 && | |||||
| param.filter_meta.padding[1] == 0; | |||||
| if (is_phpwzero) { | |||||
| src2 = const_cast<dt_int8*>( | |||||
| param.src<dt_int8>(sparam.batch_id, sparam.group_id)); | |||||
| } | |||||
| dt_int8* b_panel = reinterpret_cast<dt_int8*>(reinterpret_cast<uintptr_t>( | |||||
| bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX))); | |||||
| megdnn_assert(ic % 4 == 0, "nchw44dot_dot with ic is not of time 4"); | |||||
| int8_t* im2col_dst = | |||||
| static_cast<int8_t*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||||
| fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, | |||||
| sparam.output_block_size); | |||||
| } | |||||
| namespace megdnn { | |||||
| template class StrategyFuse8x4x4Nchw44DotK3x3S2<dt_qint32, dt_qint8, | |||||
| megdnn::PostprocessMode::QUANTIZED>; | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -10,9 +10,8 @@ | |||||
| */ | */ | ||||
| #include "src/fallback/conv_bias/im2col/strategy_base.h" | #include "src/fallback/conv_bias/im2col/strategy_base.h" | ||||
| #include "src/fallback/convolution/img2col_helper.h" | |||||
| #if MEGDNN_AARCH64 | |||||
| #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -163,7 +162,7 @@ void fuse_packb(const float* __restrict src, float* __restrict dst, | |||||
| template <typename op_ctype, typename op_dtype, | template <typename op_ctype, typename op_dtype, | ||||
| megdnn::PostprocessMode postprocess_mode> | megdnn::PostprocessMode postprocess_mode> | ||||
| void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||||
| void StrategyFuseXx12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||||
| exec_im2col(const WorkspaceBundle& bundle, | exec_im2col(const WorkspaceBundle& bundle, | ||||
| const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
| const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
| @@ -194,14 +193,13 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||||
| float* im2col_dst = | float* im2col_dst = | ||||
| static_cast<float*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | static_cast<float*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | ||||
| fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, | fuse_packb(src2, im2col_dst, b_panel, ow, ic, ih, iw, sparam.ohw_cur_index, | ||||
| sparam.output_block_size); | sparam.output_block_size); | ||||
| } | } | ||||
| namespace megdnn { | namespace megdnn { | ||||
| template class StrategyFuse8x12x1Nchw44K3x3S2<float, float, | |||||
| template class StrategyFuseXx12x1Nchw44K3x3S2<float, float, | |||||
| megdnn::PostprocessMode::FLOAT>; | megdnn::PostprocessMode::FLOAT>; | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -1461,6 +1461,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT_S2_FUSE) { | |||||
| UniformIntRNG rng{-50, 50}; | |||||
| #define cb(name) \ | |||||
| checker_conv_bias(get_nchw44_conv_bias_args({3}, 2, false, \ | |||||
| false, false, false, true), \ | |||||
| handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
| dtype::QuantizedS8(60.25f), name); \ | |||||
| float epsilon = 0.001; | |||||
| #if MEGDNN_AARCH64 | |||||
| cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96"); | |||||
| #elif MEGDNN_ARMV7 | |||||
| cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96"); | |||||
| #endif | |||||
| #undef cb | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) { | ||||
| UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
| @@ -655,6 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { | |||||
| bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); | bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2); | ||||
| } | } | ||||
| #if __ARM_FEATURE_DOTPROD | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | ||||
| constexpr size_t RUNS = 40; | constexpr size_t RUNS = 40; | ||||
| std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
| @@ -708,6 +709,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { | |||||
| constexpr size_t RUNS = 40; | |||||
| std::vector<DType> data_type = { | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; | |||||
| auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | |||||
| size_t FS, size_t group, size_t P, size_t S, | |||||
| bool is_nchw = false) { | |||||
| param::ConvBias param; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| param.pad_h = P; | |||||
| param.pad_w = P; | |||||
| param.stride_h = S; | |||||
| param.stride_w = S; | |||||
| param.sparse = param::ConvBias::Sparse::DENSE; | |||||
| param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
| auto OH = (H + 2 * P - FS) / static_cast<size_t>(S) + 1; | |||||
| auto OW = (W + 2 * P - FS) / static_cast<size_t>(S) + 1; | |||||
| TensorShape src = {N, IC / 4, H, W, 4}; | |||||
| TensorShape filter = {OC / 4, IC / 4, FS, FS, 4, 4}; | |||||
| if (group > 1) { | |||||
| filter = {group, OC / group / 4, IC / group / 4, FS, FS, 4, 4}; | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| } | |||||
| if (is_nchw) { | |||||
| src = {N, IC, H, W}; | |||||
| filter = {OC / 4, FS, FS, IC, 4}; | |||||
| } | |||||
| TensorShape bias = {1, OC / 4, 1, 1, 4}; | |||||
| TensorShape dst = {N, OC / 4, OH, OW, 4}; | |||||
| SmallVector<TensorShape> shapes{src, filter, bias, {}, dst}; | |||||
| float computations = | |||||
| (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + | |||||
| dst.total_nr_elems()) * | |||||
| 1e-6; | |||||
| std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = { | |||||
| std::make_pair(shapes, computations)}; | |||||
| benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | |||||
| {1, {7}}, data_type); | |||||
| }; | |||||
| bench_case(1, 64, 64, 56, 56, 3, 1, 1, 2); | |||||
| bench_case(1, 64, 64, 128, 128, 3, 1, 1, 2); | |||||
| bench_case(1, 64, 64, 256, 256, 3, 1, 1, 2); | |||||
| bench_case(1, 64, 64, 156, 156, 3, 1, 1, 2); | |||||
| bench_case(1, 128, 128, 28, 28, 3, 1, 1, 2); | |||||
| bench_case(1, 256, 256, 14, 14, 3, 1, 1, 2); | |||||
| bench_case(1, 512, 512, 7, 7, 3, 1, 1, 2); | |||||
| bench_case(1, 64, 64, 56, 56, 3, 4, 1, 2); | |||||
| bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); | |||||
| bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); | |||||
| bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); | |||||
| } | |||||
| #endif | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { | ||||
| constexpr size_t RUNS = 40; | constexpr size_t RUNS = 40; | ||||
| std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||