| @@ -189,6 +189,28 @@ public: | |||
| fallback::ConvBiasImpl* opr, | |||
| const NCBKernSizeParam& param) const override; | |||
| }; | |||
| class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | |||
| public: | |||
| AlgoDotS8Direct_NCHW44() {} | |||
| bool is_reproducible() const override { return true; } | |||
| const char* name() const override { | |||
| return "ARMDOTS8DIRECT_NCHW44"; | |||
| } | |||
| bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(FallbackConvBiasImpl*, | |||
| const NCBKernSizeParam&) const override; | |||
| SmallVector<NCBKern> dispatch_kerns( | |||
| fallback::ConvBiasImpl* opr, | |||
| const NCBKernSizeParam& param) const override; | |||
| bool is_preferred(megdnn::fallback::ConvBiasImpl*, | |||
| const NCBKernSizeParam& param) const override; | |||
| }; | |||
| #endif | |||
| class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase { | |||
| @@ -0,0 +1,370 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #ifdef __ARM_FEATURE_DOTPROD | |||
| #include "src/arm_common/elemwise_helper/kimpl/typecvt.h" | |||
| #include "src/common/unroll_macro.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace direct_dotprod_nchw44 { | |||
| template <> | |||
| void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, | |||
| const int8_t* src, const int src_step, | |||
| const int ic, const int ic_step, | |||
| const int ih, const int pad_left, | |||
| const int pad_right, const int pad_top, | |||
| const int pad_bottom) { | |||
| constexpr int IC_PACK_SIZE = 4; | |||
| rep_step(ic_idx, ic, IC_PACK_SIZE) { | |||
| const int8_t* i_src = src + ic_idx * ic_step; | |||
| //! pad top | |||
| int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_pad_top); | |||
| dst += bytes_pad_top / sizeof(int8_t); | |||
| rep(ih_idx, ih) { | |||
| int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_row_in_dst); | |||
| //! left elements | |||
| int pad_left_elements = pad_left * IC_PACK_SIZE; | |||
| //! copy row [ih_idx, x] | |||
| int bytes_copy = src_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memcpy(dst + pad_left_elements, i_src, bytes_copy); | |||
| //! dst move to next row | |||
| dst += bytes_row_in_dst / sizeof(int8_t); | |||
| //! src move to next row | |||
| i_src += bytes_copy / sizeof(int8_t); | |||
| } | |||
| //! pad bottom | |||
| int bytes_pad_bottom = | |||
| pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_pad_bottom); | |||
| dst += bytes_pad_bottom / sizeof(int8_t); | |||
| } | |||
| } | |||
| template <> | |||
| void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, | |||
| const int8_t* src, const int src_step, | |||
| const int ic, const int ic_step, | |||
| const int ih, const int pad_left, | |||
| const int pad_right, const int pad_top, | |||
| const int pad_bottom) { | |||
| constexpr int IC_PACK_SIZE = 4; | |||
| int odd_start = megdnn::div_ceil(dst_step, 2); | |||
| bool nochange = pad_left % 2 == 0; | |||
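| //! for stride == 2 each padded row is split into two halves: columns at | |||
| //! even positions are packed contiguously into dst[0, odd_start) and | |||
| //! columns at odd positions into dst[odd_start, dst_step), so the stride-2 | |||
| //! kernel can read either phase contiguously; when pad_left is odd the two | |||
| //! vld2q_s32 lanes swap roles (tracked by `nochange`) | |||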
| rep_step(ic_idx, ic, IC_PACK_SIZE) { | |||
| const int32_t* i_src = | |||
| reinterpret_cast<const int32_t*>(src + ic_idx * ic_step); | |||
| int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_pad_top); | |||
| dst += bytes_pad_top / sizeof(int8_t); | |||
| rep(ih_idx, ih) { | |||
| int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_row_in_dst); | |||
| int32_t* dst_even = reinterpret_cast<int32_t*>(dst) + pad_left / 2 + | |||
| pad_left % 2; | |||
| int32_t* dst_odd = | |||
| reinterpret_cast<int32_t*>(dst) + odd_start + pad_left / 2; | |||
| int i_src_idx = 0; | |||
| if (nochange) { | |||
| for (; i_src_idx + 7 < src_step; i_src_idx += 8) { | |||
| int32x4x2_t tmp; | |||
| tmp = vld2q_s32(i_src + i_src_idx); | |||
| vst1q_s32(dst_even, tmp.val[0]); | |||
| vst1q_s32(dst_odd, tmp.val[1]); | |||
| dst_even += 4; | |||
| dst_odd += 4; | |||
| } | |||
| } else { | |||
| for (; i_src_idx + 7 < src_step; i_src_idx += 8) { | |||
| int32x4x2_t tmp; | |||
| tmp = vld2q_s32(i_src + i_src_idx); | |||
| vst1q_s32(dst_even, tmp.val[1]); | |||
| vst1q_s32(dst_odd, tmp.val[0]); | |||
| dst_even += 4; | |||
| dst_odd += 4; | |||
| } | |||
| } | |||
| for (; i_src_idx < src_step; ++i_src_idx) { | |||
| if (nochange) { | |||
| if (i_src_idx % 2 == 0) { | |||
| *dst_even = *(i_src + i_src_idx); | |||
| dst_even++; | |||
| } else { | |||
| *dst_odd = *(i_src + i_src_idx); | |||
| dst_odd++; | |||
| } | |||
| } else { | |||
| if (i_src_idx % 2 == 0) { | |||
| *dst_odd = *(i_src + i_src_idx); | |||
| dst_odd++; | |||
| } else { | |||
| *dst_even = *(i_src + i_src_idx); | |||
| dst_even++; | |||
| } | |||
| } | |||
| } | |||
| //! dst move to next row | |||
| dst += bytes_row_in_dst / sizeof(int8_t); | |||
| //! src move to next row | |||
| i_src += src_step; | |||
| } | |||
| //! pad bottom | |||
| int bytes_pad_bottom = | |||
| pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); | |||
| memset(dst, 0, bytes_pad_bottom); | |||
| dst += bytes_pad_bottom / sizeof(int8_t); | |||
| } | |||
| } | |||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||
| int filter_size> | |||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | |||
| const int8_t* src, const int ih, const int iw, | |||
| const int8_t* filter, const int32_t* bias, | |||
| const int oh_size, const int oc, const int ic, | |||
| const Op& op) { | |||
| constexpr int FH = filter_size; | |||
| constexpr int FW = filter_size; | |||
| constexpr int IC_PACK_SIZE = 4; | |||
| constexpr int OC_PACK_SIZE = 4; | |||
| #if MEGDNN_AARCH64 | |||
| constexpr int OC_BIG_INTERVAL = 12; | |||
| constexpr int OC_MID_INTERVAL = 8; | |||
| constexpr int OC_SMA_INTERVAL = 4; | |||
| #else | |||
| constexpr int OC_BIG_INTERVAL = 4; | |||
| constexpr int OC_MID_INTERVAL = 4; | |||
| constexpr int OC_SMA_INTERVAL = 4; | |||
| #endif | |||
| constexpr int OW_INTERVAL = 8; | |||
| constexpr int SH = stride; | |||
| const int dst_numbers_per_channel = oh * ow; | |||
| const int ow_remain = ow % OW_INTERVAL; | |||
| const int ow_end_idx = ow - ow_remain; | |||
| const int oc_remain = | |||
| oc % OC_BIG_INTERVAL; //! in NCHW44, oc is a multiple of 4, so oc_remain is 0, 4 or 8 | |||
| const int oc_end_idx = oc - oc_remain; | |||
| const int dst_numbers_4channel_packed = | |||
| dst_numbers_per_channel * OC_PACK_SIZE; | |||
| using remain_fun = std::function<void( | |||
| dst_type * dst, const int dst_step, const int8_t* src, const int ih, | |||
| const int iw, const int8_t* filter, const int32_t* bias, | |||
| const int ic, const Op& op)>; | |||
| remain_fun kern_big_oc_remain = nullptr; | |||
| remain_fun kern_mid_oc_remain = nullptr; | |||
| remain_fun kern_sma_oc_remain = nullptr; | |||
| switch (ow_remain) { | |||
| #define cb(step) \ | |||
| case step: \ | |||
| kern_big_oc_remain = \ | |||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||
| filter_size, OC_BIG_INTERVAL, \ | |||
| OW_INTERVAL>::impl; \ | |||
| kern_mid_oc_remain = \ | |||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||
| filter_size, OC_MID_INTERVAL, \ | |||
| OW_INTERVAL>::impl; \ | |||
| kern_sma_oc_remain = \ | |||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||
| filter_size, OC_SMA_INTERVAL, \ | |||
| OW_INTERVAL>::impl; \ | |||
| break; | |||
| UNROLL_CALL_RAW(8, cb); | |||
| #undef cb | |||
| default: | |||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||
| } | |||
| //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| //! the output [oc, oh, ow] is cut into tiles of shape | |||
| //! [OC_INTERVAL, 1, OW_INTERVAL]; each call of KernNeonSdotNCHW44 | |||
| //! computes one such [OW_INTERVAL, 1, OC_INTERVAL] tile | |||
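| //! e.g. (illustrative numbers only): on aarch64 with oc = 32 and ow = 20, | |||
| //! oc_end_idx = 24 (two OC_BIG_INTERVAL = 12 blocks) and oc_remain = 8 is | |||
| //! handled by the OC_MID_INTERVAL kernel; ow_end_idx = 16 (two OW_INTERVAL | |||
| //! = 8 tiles) and ow_remain = 4 is handled by the kern_*_oc_remain functors | |||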
| for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { | |||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||
| const int src_offset_in_element = | |||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||
| const int dst_offset_in_element = | |||
| oc_idx * dst_numbers_per_channel + | |||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||
| const int bias_offset_in_element = oc_idx; | |||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||
| filter_size, OC_BIG_INTERVAL, OW_INTERVAL>:: | |||
| impl(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, iw, | |||
| filter + filter_offset_in_element, | |||
| bias + bias_offset_in_element, ic, op); | |||
| } | |||
| if (ow_remain) { | |||
| const int src_offset_in_element = | |||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||
| const int dst_offset_in_element = | |||
| oc_idx * dst_numbers_per_channel + | |||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||
| const int bias_offset_in_element = oc_idx; | |||
| kern_big_oc_remain(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, iw, | |||
| filter + filter_offset_in_element, | |||
| bias + bias_offset_in_element, ic, op); | |||
| } | |||
| } | |||
| } | |||
| #ifdef MEGDNN_AARCH64 | |||
| //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 | |||
| if (oc_remain) { | |||
| int oc_idx = oc_end_idx; | |||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||
| const int src_offset_in_element = | |||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||
| const int dst_offset_in_element = | |||
| oc_idx * dst_numbers_per_channel + | |||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||
| const int bias_offset_in_element = oc_idx; | |||
| if (oc_remain == 8) { | |||
| KernNeonSdotNCHW44< | |||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||
| filter_size, OC_MID_INTERVAL, | |||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, | |||
| iw, | |||
| filter + | |||
| filter_offset_in_element, | |||
| bias + bias_offset_in_element, | |||
| ic, op); | |||
| } else { | |||
| KernNeonSdotNCHW44< | |||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||
| filter_size, OC_SMA_INTERVAL, | |||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, | |||
| iw, | |||
| filter + | |||
| filter_offset_in_element, | |||
| bias + bias_offset_in_element, | |||
| ic, op); | |||
| } | |||
| } | |||
| if (ow_remain) { | |||
| const int src_offset_in_element = | |||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||
| const int dst_offset_in_element = | |||
| oc_idx * dst_numbers_per_channel + | |||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||
| const int bias_offset_in_element = oc_idx; | |||
| if (oc_remain == 8) { | |||
| kern_mid_oc_remain(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, iw, | |||
| filter + filter_offset_in_element, | |||
| bias + bias_offset_in_element, ic, op); | |||
| } else { | |||
| kern_sma_oc_remain(dst + dst_offset_in_element, | |||
| dst_numbers_4channel_packed, | |||
| src + src_offset_in_element, ih, iw, | |||
| filter + filter_offset_in_element, | |||
| bias + bias_offset_in_element, ic, op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| #define CONSTRUCT_FUNC(filter_size) \ | |||
| template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \ | |||
| void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ | |||
| dst_type* dst, const int oh, const int ow, const int8_t* src, \ | |||
| const int ih, const int iw, const int8_t* weight, \ | |||
| const int32_t* bias, const int oh_size, const int oc, \ | |||
| const int ic, const Op& op) { \ | |||
| conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, \ | |||
| filter_size>( \ | |||
| dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \ | |||
| } | |||
| CONSTRUCT_FUNC(2); | |||
| CONSTRUCT_FUNC(3); | |||
| CONSTRUCT_FUNC(5); | |||
| CONSTRUCT_FUNC(7); | |||
| #undef CONSTRUCT_FUNC | |||
| #define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \ | |||
| template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \ | |||
| stride>( \ | |||
| dst_type * dst, const int oh, const int ow, const int8_t* src, \ | |||
| const int ih, const int iw, const int8_t* weight, \ | |||
| const int32_t* bias, const int oh_size, const int oc, \ | |||
| const int ic, const Op& op); | |||
| #define FOR_OP(stride, i, bias_mode) \ | |||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| INSTANTIATION(dt_int32, stride, i, bias_mode, \ | |||
| NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) | |||
| #define FOR_BIAS(stride, i) \ | |||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
| #define FOR_FILTER(stride) \ | |||
| FOR_BIAS(stride, 2) \ | |||
| FOR_BIAS(stride, 3) \ | |||
| FOR_BIAS(stride, 5) \ | |||
| FOR_BIAS(stride, 7) | |||
| FOR_FILTER(1) | |||
| FOR_FILTER(2) | |||
| #undef FOR_STRIDE | |||
| #undef FOR_FILTER | |||
| #undef FOR_IC | |||
| #undef FOR_BIAS | |||
| #undef FOR_NONLINEAR | |||
| #undef FOR_REMAIN | |||
| #undef INSTANTIATION | |||
| } // namespace direct_dotprod_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| #endif | |||
| //vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,87 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #if __ARM_FEATURE_DOTPROD | |||
| #pragma once | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace direct_dotprod_nchw44 { | |||
| using BiasMode = ConvBiasForward::BiasMode; | |||
| /** | |||
| * @brief : do direct convolution with no side effects | |||
| * input buffer's size is [ih, iw] | |||
| * output buffer's size is [oh, ow] | |||
| * filter layout is [OC/4, IC/4, FH, FW, 4, 4] | |||
| * | |||
| * @param : [output ptr] dst | |||
| * [input] oh -> dst rows | |||
| * [input] ow -> dst cols | |||
| * [input ptr] src | |||
| * [input] ih -> rows of src used by this kernel | |||
| * [input] iw -> src step in elements [iw2] | |||
| * [input ptr] filter | |||
| * [input ptr] bias | |||
| * [input] oh_size -> rows of result generated by this kernel | |||
| * [input] oc -> output channels | |||
| * [input] ic -> input channels | |||
| * [input] op -> post process operator | |||
| * @return none | |||
| */ | |||
| #define KERN(filter_size) \ | |||
| template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \ | |||
| void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ | |||
| dst_type* dst, const int oh, const int ow, const int8_t* src, \ | |||
| const int ih, const int iw, const int8_t* weight, \ | |||
| const int32_t* bias, const int oh_size, const int oc, \ | |||
| const int ic, const Op& op) | |||
| KERN(2); | |||
| KERN(3); | |||
| KERN(5); | |||
| KERN(7); | |||
| #undef KERN | |||
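| //! Illustrative usage (a sketch, not part of the original docs; variable | |||
| //! names are placeholders): a quantized 3x3 stride-1 kernel with channel | |||
| //! bias and int8 output could be instantiated as | |||
| //!     conv_direct_3x3_int8_nchw44<dt_int8, BiasMode::BROADCAST_CHANNEL_BIAS, | |||
| //!                                 TypeCvtOp<dt_qint32, dt_qint8>, 1>( | |||
| //!             dst, OH, OW, packed_src, ih_real, iw2, weights, bias, | |||
| //!             oh_real, OC, IC, op); | |||
| //! where packed_src was prepared by copy_packed_src_int8_nchw44 below. | |||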
| /** | |||
| * @brief : copy data from src to dst for direct conv with no side effects | |||
| * @param : [output ptr] dst | |||
| * [input] dst_step -> step of dst in numbers of elements | |||
| * [input ptr] src | |||
| * [input] src_step -> step of src in numbers of elements | |||
| * [input] ic -> input channels | |||
| * [input] ic_step -> step of ic in numbers of elements | |||
| * [input] ih -> total rows to copy | |||
| * [input] pad_left -> cols padding at left | |||
| * [input] pad_right -> cols padding at right | |||
| * [input] pad_top -> rows padding at top | |||
| * [input] pad_bottom -> rows padding at bottom | |||
| * @return none | |||
| */ | |||
| template <int stride> | |||
| void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, | |||
| const int8_t* src, const int src_step, | |||
| const int ic, const int ic_step, const int ih, | |||
| const int pad_left, const int pad_right, | |||
| const int pad_top, const int pad_bottom); | |||
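| //! Illustrative call (a sketch; variable names are placeholders): a | |||
| //! stride-1 caller would typically pack the rows it needs into the | |||
| //! per-thread workspace as | |||
| //!     copy_packed_src_int8_nchw44<1>(workspace, iw2, src_row_ptr, IW, IC, | |||
| //!                                    IH * IW, rows_from_src, PW, pad_right, | |||
| //!                                    pad_top, pad_bottom); | |||
| //! producing a zero-padded buffer of layout [IC/4, ih2, iw2, 4]. | |||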
| } // namespace direct_dotprod_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| #endif | |||
| //vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,341 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #if __ARM_FEATURE_DOTPROD | |||
| #include "src/arm_common/conv_bias/block_helper.h" | |||
| #include "src/arm_common/conv_bias/int8/algos.h" | |||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "midout.h" | |||
| using namespace megdnn; | |||
| using namespace arm_common; | |||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) | |||
| using direct_fun = std::function<void( | |||
| WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& ncb_param, | |||
| const ConvBiasImpl::NCBKernIndex& ncb_index)>; | |||
| namespace { | |||
| static void get_rectified_size( | |||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih, | |||
| int& iw, int& oh, int& ow) { | |||
| int IC = param.filter_meta.icpg; | |||
| int IW = param.isz[1]; | |||
| int OH = param.osz[0]; | |||
| int OW = param.osz[1]; | |||
| oh = OH; | |||
| ow = OW; | |||
| constexpr int cacheline = 64 / sizeof(int8_t); | |||
| int oh_tile_size = | |||
| l2_block_helper(param.nr_threads, OH, IC * IW * sizeof(int8_t) * 2); | |||
| auto&& fm = param.filter_meta; | |||
| const int SH = static_cast<int>(fm.stride[0]); | |||
| const int FH = static_cast<int>(fm.spatial[0]); | |||
| const int PW = static_cast<int>(fm.padding[1]); | |||
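| //! ih is the number of input rows needed to produce one tile of | |||
| //! oh_tile_size output rows; iw is the padded input row length, rounded up | |||
| //! so each packed row spans a whole number of cachelines | |||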
| ih = oh_tile_size * SH + FH - SH; | |||
| iw = round_up(IW + 2 * PW, cacheline); | |||
| } | |||
| static inline int get_perthread_cache_bytes(const int ic, const int ih, | |||
| const int iw) { | |||
| // border_size is used to avoid reading illegal memory | |||
| int border_size = 64 * 2; | |||
| return ic * ih * iw * sizeof(int8_t) + border_size; | |||
| } | |||
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
| auto&& fm = param.filter_meta; | |||
| int IC = fm.icpg; | |||
| int ih2, iw2, oh2, ow2; | |||
| get_rectified_size(param, ih2, iw2, oh2, ow2); | |||
| int bytes_of_copy_per_thread = get_perthread_cache_bytes(IC, ih2, iw2); | |||
| return {nullptr, {bytes_of_copy_per_thread * param.nr_threads}}; | |||
| } | |||
| template <typename dst_type, size_t filter_size, BiasMode bias_mode, | |||
| typename Op, int stride> | |||
| static void conv_kern(WorkspaceBundle bundle, | |||
| const ConvBiasImpl::NCBKernParam& ncb_param, | |||
| const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
| const int OH = ncb_param.osz[0]; | |||
| const int OW = ncb_param.osz[1]; | |||
| const int FH = ncb_param.filter_meta.spatial[0]; | |||
| const int IC = ncb_param.filter_meta.icpg; | |||
| const int OC = ncb_param.filter_meta.ocpg; | |||
| const int IH = ncb_param.isz[0]; | |||
| const int IW = ncb_param.isz[1]; | |||
| const int SH = ncb_param.filter_meta.stride[0]; | |||
| const int PH = ncb_param.filter_meta.padding[0]; | |||
| const int PW = ncb_param.filter_meta.padding[1]; | |||
| int ih2 = 0; | |||
| int iw2 = 0; | |||
| int oh2 = 0; | |||
| int ow2 = 0; | |||
| get_rectified_size(ncb_param, ih2, iw2, oh2, ow2); | |||
| constexpr int IC_PACK_SIZE = 4; | |||
| constexpr int OC_PACK_SIZE = 4; | |||
| bundle.set(ncb_param.workspace_ptr); | |||
| const int batch_id = ncb_index.ndrange_id[0]; | |||
| const int group_id = ncb_index.ndrange_id[1]; | |||
| const int oh_tile_id = ncb_index.ndrange_id[2]; | |||
| const int thread_id = ncb_index.thread_id; | |||
| const int oh_tile_size = l2_block_helper(ncb_param.nr_threads, OH, | |||
| IC * IW * sizeof(int8_t) * 2); | |||
| const int oh_start_row = oh_tile_id * oh_tile_size; | |||
| const int ih_start_row = std::max(oh_start_row * SH - PH, 0); | |||
| const int oh_real_size = std::min(OH - oh_start_row, oh_tile_size); | |||
| const int ih_real_size = oh_real_size * SH + FH - SH; | |||
| const int rows_padding_at_top = std::max(PH - oh_start_row * SH, 0); | |||
| const int rows_padding_at_bottom = | |||
| std::max((oh_start_row + oh_real_size - 1) * SH + FH - IH - PH, 0); | |||
| const int cols_padding_at_left = PW; | |||
| const int cols_padding_at_right = std::max(iw2 - IW - PW, 0); | |||
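| //! top/bottom padding is non-zero only for tiles touching the image | |||
| //! border; interior oh tiles copy all of their rows directly from src | |||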
| //! src layout{IC/4, IH, IW, 4} | |||
| const int bytes_of_src_offset = | |||
| ih_start_row * IW * IC_PACK_SIZE * sizeof(int8_t); | |||
| const int8_t* copy_src = static_cast<const int8_t*>( | |||
| ncb_param.src<int8_t>(batch_id, group_id) + bytes_of_src_offset); | |||
| const int bytes_of_copy_per_thread = | |||
| get_perthread_cache_bytes(IC, ih2, iw2); | |||
| int8_t* copy_dst = reinterpret_cast<int8_t*>(bundle.get(0)) + | |||
| thread_id * bytes_of_copy_per_thread; | |||
| const int rows_copy_from_src = | |||
| ih_real_size - rows_padding_at_top - rows_padding_at_bottom; | |||
| direct_dotprod_nchw44::copy_packed_src_int8_nchw44<stride>( | |||
| copy_dst, iw2, copy_src, IW, IC, IH * IW, rows_copy_from_src, | |||
| cols_padding_at_left, cols_padding_at_right, rows_padding_at_top, | |||
| rows_padding_at_bottom); | |||
| const int8_t* weights = ncb_param.filter<int8_t>(group_id); | |||
| dst_type* dst = ncb_param.dst<dst_type>(batch_id, group_id) + | |||
| oh_start_row * OW * OC_PACK_SIZE; | |||
| //! only broadcast or no_bias | |||
| const int32_t* bias = ncb_param.bias<int32_t>(batch_id, group_id); | |||
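| //! the (1.0f, 4.0f) scales below are placeholders; for a QuantizedS8 dst | |||
| //! the op is re-created with the real bias/dst scales | |||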
| Op op = Op(1.0f, 4.0f); | |||
| if (ncb_param.dst_type.enumv() == DTypeEnum::QuantizedS8) { | |||
| float scale_bias = | |||
| ncb_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
| float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
| op = Op(scale_bias, scale_dst); | |||
| } | |||
| #define KERN1_NCHW44_CONV(filter) \ | |||
| direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \ | |||
| dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \ | |||
| ih_real_size, iw2, weights, bias, \ | |||
| oh_real_size, OC, IC, op); | |||
| DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV); | |||
| #undef KERN1_NCHW44_CONV | |||
| } | |||
| } // namespace | |||
| bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( | |||
| FallbackConvBiasImpl*, const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const { | |||
| auto&& fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| auto FW = fm.spatial[1]; | |||
| auto SH = fm.stride[0]; | |||
| auto SW = fm.stride[1]; | |||
| auto OC = fm.ocpg; | |||
| auto IC = fm.icpg; | |||
| //! src and filter are qint8, dst is qint8. | |||
| bool data_type_ok = param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
| (param.dst_type.enumv() == DTypeEnum::QuantizedS8 || | |||
| param.dst_type.enumv() == DTypeEnum::QuantizedS32); | |||
| if (param.bias_type.valid()) { | |||
| data_type_ok &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; | |||
| } | |||
| data_type_ok |= param.src_type.enumv() == DTypeEnum::Int8 && | |||
| param.filter_type.enumv() == DTypeEnum::Int8 && | |||
| param.dst_type.enumv() == DTypeEnum::Int32; | |||
| bool layout_ok = fm.format == param::Convolution::Format::NCHW44_DOT && | |||
| IC % 4 == 0 && OC % 4 == 0; | |||
| bool param_ok = !fm.should_flip && fm.spatial_ndim == 2 && | |||
| fm.dilation[0] == 1 && fm.dilation[1] == 1 && FH == FW && | |||
| (FH >= 2 && FH <= 7); | |||
| bool stride_ok = SH == SW && (SH == 1 || SH == 2); | |||
| return data_type_ok && layout_ok && param_ok && stride_ok; | |||
| } | |||
| bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | |||
| megdnn::fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| MEGDNN_MARK_USED_VAR(param); | |||
| return true; | |||
| } | |||
| size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( | |||
| FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| return get_bundle(param).total_size_in_bytes(); | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | |||
| FallbackConvBiasImpl*, const NCBKernSizeParam& param) const { | |||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
| midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) { | |||
| auto fm = param.filter_meta; | |||
| size_t BATCH = param.n; | |||
| size_t GROUP = fm.group; | |||
| WorkspaceBundle wbundle = get_bundle(param); | |||
| direct_fun kernel = nullptr; | |||
| bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
| #define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride) \ | |||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, \ | |||
| midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \ | |||
| kernel = conv_kern<dst_type, filter, bias_mode, op, stride>; \ | |||
| } \ | |||
| MIDOUT_END(); | |||
| #define GET_OP_PARAM(i, bias_mode, stride) \ | |||
| switch (param.nonlineMode) { \ | |||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||
| if (quantized) { \ | |||
| DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ | |||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \ | |||
| stride) \ | |||
| } else { \ | |||
| DO_CONV_KERN_FUN(dt_int32, i, bias_mode, \ | |||
| NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \ | |||
| stride) \ | |||
| } \ | |||
| break; \ | |||
| case param::ConvBias::NonlineMode::RELU: \ | |||
| if (quantized) { \ | |||
| DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ | |||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \ | |||
| stride) \ | |||
| } else { \ | |||
| megdnn_assert(0, "no support for non-quantized RELU"); \ | |||
| } \ | |||
| break; \ | |||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||
| if (quantized) { \ | |||
| DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ | |||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>, \ | |||
| stride) \ | |||
| } else { \ | |||
| megdnn_assert(0, "no support for non-quantized H_SWISH"); \ | |||
| } \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| #define GET_STRIDE_PARAM(filter, bias_mode) \ | |||
| switch (fm.stride[0]) { \ | |||
| case 1: \ | |||
| GET_OP_PARAM(filter, bias_mode, 1); \ | |||
| break; \ | |||
| case 2: \ | |||
| GET_OP_PARAM(filter, bias_mode, 2); \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| } | |||
| #define GET_BIAS_MODE_PARAM(filter) \ | |||
| switch (param.bias_mode) { \ | |||
| case BiasMode::NO_BIAS: \ | |||
| GET_STRIDE_PARAM(filter, BiasMode::NO_BIAS) \ | |||
| break; \ | |||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||
| GET_STRIDE_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| #define SELECT_CONV_KERN() \ | |||
| switch (param.filter_meta.spatial[0]) { \ | |||
| case 2: \ | |||
| GET_BIAS_MODE_PARAM(2) \ | |||
| break; \ | |||
| case 3: \ | |||
| GET_BIAS_MODE_PARAM(3) \ | |||
| break; \ | |||
| case 5: \ | |||
| GET_BIAS_MODE_PARAM(5) \ | |||
| break; \ | |||
| case 7: \ | |||
| GET_BIAS_MODE_PARAM(7) \ | |||
| break; \ | |||
| default: \ | |||
| megdnn_assert(0); \ | |||
| break; \ | |||
| } | |||
| SELECT_CONV_KERN() | |||
| #undef DO_CONV_KERN_FUN | |||
| #undef GET_OP_PARAM | |||
| #undef GET_STRIDE_PARAM | |||
| #undef GET_BIAS_MODE_PARAM | |||
| #undef SELECT_CONV_KERN | |||
| megdnn_assert(kernel); | |||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
| int OH = param.osz[0]; | |||
| int IC = param.filter_meta.icpg; | |||
| int IW = param.isz[1]; | |||
| int oh_tile_size = l2_block_helper(param.nr_threads, OH, | |||
| IC * IW * sizeof(int8_t) * 2); | |||
| size_t oh_tiles = static_cast<size_t>(div_ceil(OH, oh_tile_size)); | |||
| auto do_conv = [wbundle, kernel](const NCBKernParam& ncb_param, | |||
| const NCBKernIndex& ncb_index) { | |||
| kernel(wbundle, ncb_param, std::move(ncb_index)); | |||
| }; | |||
| ret_kerns.push_back({do_conv, {BATCH, GROUP, oh_tiles}}); | |||
| return ret_kerns; | |||
| } | |||
| MIDOUT_END(); | |||
| return {}; | |||
| } | |||
| #endif | |||
| //vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,430 @@ | |||
| /** | |||
| * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #ifdef __ARM_FEATURE_DOTPROD | |||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||
| #include "src/arm_common/elemwise_op.h" | |||
| #include "src/arm_common/intrinsic_helper.h" | |||
| #include "src/arm_common/neon_struct.h" | |||
| #include "src/common/unroll_macro.h" | |||
| namespace megdnn { | |||
| namespace arm_common { | |||
| namespace direct_dotprod_nchw44 { | |||
| constexpr int SIMD_LEN = 16; | |||
| constexpr int IC_PACK_SIZE = 4; | |||
| constexpr int OC_PACK_SIZE = 4; | |||
| constexpr int filter_next_col = | |||
| IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| template <int row, BiasMode bias_mode> | |||
| inline void init_ocx_ow8(int32x4_t c[][8], const int32_t* bias_ptr, | |||
| int oc_step) { | |||
| static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); | |||
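| //! the switches below fall through on purpose: row == 3 initialises result | |||
| //! rows 2, 1 and 0; row == 2 initialises rows 1 and 0; row == 1 only row 0 | |||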
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||
| #define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); | |||
| switch (row) { | |||
| case 3: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||
| case 2: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||
| default: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||
| } | |||
| #undef BIAS_INIT | |||
| } else { | |||
| #define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); | |||
| switch (row) { | |||
| case 3: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||
| case 2: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||
| default: | |||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||
| } | |||
| #undef BIAS_INIT | |||
| } | |||
| } | |||
| #define cb11(col) \ | |||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); | |||
| #define cb21(col) \ | |||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||
| op(res[1][col], \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); | |||
| #define cb31(col) \ | |||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||
| op(res[1][col], \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \ | |||
| op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \ | |||
| ld_dst_oc + col / 2 * 8)); | |||
| #define cb12(step) \ | |||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); | |||
| #define cb22(step) \ | |||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); | |||
| #define cb32(step) \ | |||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \ | |||
| op({{res[2][2 * step], res[2][2 * step + 1]}}, \ | |||
| reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8)); | |||
| template <int row, int ow_remain, typename Op, typename T> | |||
| struct StoreOCxOWx { | |||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||
| const int ld_dst_oc); | |||
| }; | |||
| template <int ow_remain, typename Op, typename T> | |||
| struct StoreOCxOWx<1, ow_remain, Op, T> { | |||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||
| const int ld_dst_oc) { | |||
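| //! the cases fall through on purpose: an odd ow_remain first stores the | |||
| //! trailing single column (cb11), then reuses the paired stores (cb12) for | |||
| //! the remaining even part; the OCx2/OCx3 specialisations below do the same | |||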
| switch (ow_remain) { | |||
| case 8: | |||
| UNROLL_CALL_RAW(4, cb12); | |||
| break; | |||
| case 7: | |||
| cb11(6); | |||
| case 6: | |||
| UNROLL_CALL_RAW(3, cb12); | |||
| break; | |||
| case 5: | |||
| cb11(4); | |||
| case 4: | |||
| UNROLL_CALL_RAW(2, cb12); | |||
| break; | |||
| case 3: | |||
| cb11(2); | |||
| case 2: | |||
| UNROLL_CALL_RAW(1, cb12); | |||
| break; | |||
| case 1: | |||
| cb11(0); | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| }; | |||
| template <int ow_remain, typename Op, typename T> | |||
| struct StoreOCxOWx<2, ow_remain, Op, T> { | |||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||
| const int ld_dst_oc) { | |||
| switch (ow_remain) { | |||
| case 8: | |||
| UNROLL_CALL_RAW(4, cb22); | |||
| break; | |||
| case 7: | |||
| cb21(6); | |||
| case 6: | |||
| UNROLL_CALL_RAW(3, cb22); | |||
| break; | |||
| case 5: | |||
| cb21(4); | |||
| case 4: | |||
| UNROLL_CALL_RAW(2, cb22); | |||
| break; | |||
| case 3: | |||
| cb21(2); | |||
| case 2: | |||
| UNROLL_CALL_RAW(1, cb22); | |||
| break; | |||
| case 1: | |||
| cb21(0); | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| }; | |||
| template <int ow_remain, typename Op, typename T> | |||
| struct StoreOCxOWx<3, ow_remain, Op, T> { | |||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||
| const int ld_dst_oc) { | |||
| switch (ow_remain) { | |||
| case 8: | |||
| UNROLL_CALL_RAW(4, cb32); | |||
| break; | |||
| case 7: | |||
| cb31(6); | |||
| case 6: | |||
| UNROLL_CALL_RAW(3, cb32); | |||
| break; | |||
| case 5: | |||
| cb31(4); | |||
| case 4: | |||
| UNROLL_CALL_RAW(2, cb32); | |||
| break; | |||
| case 3: | |||
| cb31(2); | |||
| case 2: | |||
| UNROLL_CALL_RAW(1, cb32); | |||
| break; | |||
| case 1: | |||
| cb31(0); | |||
| default: | |||
| break; | |||
| } | |||
| } | |||
| }; | |||
| #undef cb11 | |||
| #undef cb21 | |||
| #undef cb31 | |||
| #undef cb12 | |||
| #undef cb22 | |||
| #undef cb32 | |||
| template <int row, int ow_remain, typename Op, typename T> | |||
| inline void store_ocx_owx_remain_static(int32x4_t res[][8], const Op& op, | |||
| T* dst_ptr, const int ld_dst_oc) { | |||
| StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc); | |||
| } | |||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||
| typename FUNC, typename T, typename T2, typename T3> | |||
| struct ShiftCalHelper { | |||
| static void impl(T& res, T2& src, T3& weight) { | |||
| #define cb(step) \ | |||
| res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \ | |||
| res[res_row][step], weight[weight_idx], \ | |||
| src[src_row][(src_start_idx + step) / 4]); | |||
| UNROLL_CALL_RAW(8, cb); | |||
| #undef cb | |||
| } | |||
| }; | |||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||
| typename FUNC, typename T, typename T2, typename T3> | |||
| inline void cal_helper(T& res, T2& src, T3& weight) { | |||
| ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2, | |||
| T3>::impl(res, src, weight); | |||
| }; | |||
| /** | |||
| * GEMM-like kernels: oc12_owx (m = 12, n = x), oc8_owx (m = 8, n = x) and | |||
| * oc4_owx (m = 4, n = x) | |||
| * */ | |||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||
| int ow_remain, int filter_size, int oc_interval, int ow_interval> | |||
| struct KernNeonSdotNCHW44 { | |||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||
| const int ih, const int iw, const int8_t* filter, | |||
| const int32_t* bias, const int ic, const Op& op); | |||
| }; | |||
| template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||
| int filter_size, int oc_interval, int ow_interval> | |||
| struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size, | |||
| oc_interval, ow_interval> { | |||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||
| const int ih, const int iw, const int8_t* filter, | |||
| const int32_t* bias, const int ic, const Op& op) { | |||
| constexpr int FH = filter_size; | |||
| constexpr int FW = filter_size; | |||
| constexpr int filter_next_row = | |||
| FW * OC_PACK_SIZE * | |||
| IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| const int filter_next_4oc = | |||
| FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| const int src_next_ic = ih * iw; | |||
| const int src_next_row = iw * IC_PACK_SIZE; | |||
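| //! NSRC: number of 16-byte vectors (4 packed columns each) loaded per | |||
| //! input row, enough to cover the ow_interval + filter_size - 1 columns | |||
| //! read when computing one output tile | |||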
| constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1; | |||
| constexpr int LOOP = oc_interval / 4; | |||
| int32x4_t res[3][ow_interval]; | |||
| init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { | |||
| const int8_t* i_src = src + ic_idx * src_next_ic; | |||
| const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; | |||
| for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { | |||
| int8x16_t src[1][4]; | |||
| int8x16_t weight[3]; | |||
| load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0); | |||
| //! do not reorder the switch cases to 3,2,1: it makes the kernel slower. | |||
| #define CALC_PART(step) \ | |||
| switch (LOOP) { \ | |||
| case 1: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||
| break; \ | |||
| case 2: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ | |||
| break; \ | |||
| case 3: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ | |||
| weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \ | |||
| break; \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| switch (filter_size) { | |||
| case 2: | |||
| UNROLL_CALL_RAW(2, CALC_PART); | |||
| break; | |||
| case 3: | |||
| UNROLL_CALL_RAW(3, CALC_PART); | |||
| break; | |||
| case 5: | |||
| UNROLL_CALL_RAW(5, CALC_PART); | |||
| break; | |||
| case 7: | |||
| UNROLL_CALL_RAW(7, CALC_PART); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| #undef CALC_PART | |||
| i_filter += filter_next_row; | |||
| i_src += src_next_row; | |||
| } | |||
| } | |||
| store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst, | |||
| dst_step); | |||
| } | |||
| }; | |||
| template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||
| int filter_size, int oc_interval, int ow_interval> | |||
| struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size, | |||
| oc_interval, ow_interval> { | |||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||
| const int ih, const int iw, const int8_t* filter, | |||
| const int32_t* bias, const int ic, const Op& op) { | |||
| constexpr int FH = filter_size; | |||
| constexpr int FW = filter_size; | |||
| constexpr int filter_next_row = | |||
| FW * OC_PACK_SIZE * | |||
| IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| const int filter_next_4oc = | |||
| FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||
| const int src_next_ic = ih * iw; | |||
| const int src_next_row = iw * IC_PACK_SIZE; | |||
| constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1; | |||
| constexpr int LOOP = oc_interval / 4; | |||
| int32x4_t res[3][ow_interval]; | |||
| init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE); | |||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { | |||
| const int8_t* i_src = src + ic_idx * src_next_ic; | |||
| const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; | |||
| for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { | |||
| int8x16_t src[2][3]; | |||
| int8x16_t weight[3]; | |||
| const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE; | |||
| load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset); | |||
| //! do not reorder the switch cases to 3,2,1: it makes the kernel slower. | |||
| #define CALC_PART(step) \ | |||
| switch (LOOP) { \ | |||
| case 1: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| break; \ | |||
| case 2: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| break; \ | |||
| case 3: \ | |||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ | |||
| filter_next_col * step); \ | |||
| cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \ | |||
| weight); \ | |||
| break; \ | |||
| default: \ | |||
| break; \ | |||
| } | |||
| switch (filter_size) { | |||
| case 2: | |||
| UNROLL_CALL_RAW(2, CALC_PART); | |||
| break; | |||
| case 3: | |||
| UNROLL_CALL_RAW(3, CALC_PART); | |||
| break; | |||
| case 5: | |||
| UNROLL_CALL_RAW(5, CALC_PART); | |||
| break; | |||
| case 7: | |||
| UNROLL_CALL_RAW(7, CALC_PART); | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| #undef CALC_PART | |||
| i_filter += filter_next_row; | |||
| i_src += src_next_row; | |||
| } | |||
| } | |||
| store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst, | |||
| dst_step); | |||
| } | |||
| }; | |||
| } // namespace direct_dotprod_nchw44 | |||
| } // namespace arm_common | |||
| } // namespace megdnn | |||
| #endif | |||
| //vim: syntax=cpp.doxygen | |||
| @@ -536,6 +536,7 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, | |||
| #undef BAIS_INIT | |||
| } | |||
| } | |||
| /////////////////////////init_ocx_ow8//////////////////// | |||
| inline float32x4_t neon_vdupq_n(float val) { | |||
| @@ -64,6 +64,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
| AlgoDotU8DirectStride1 du8_direct_stride1_small_group{false}; | |||
| AlgoDotU8DirectStride2 du8_direct_stride2_large_group{true}; | |||
| AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false}; | |||
| AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | |||
| #endif | |||
| AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||
| @@ -103,6 +105,8 @@ public: | |||
| direct_algos.emplace_back(&du8_direct_stride1_small_group); | |||
| direct_algos.emplace_back(&du8_direct_stride2_large_group); | |||
| direct_algos.emplace_back(&du8_direct_stride2_small_group); | |||
| direct_algos.emplace_back(&ds8_direct_nchw44); | |||
| #endif | |||
| direct_algos.emplace_back(&qu8_direct_stride2_large_group); | |||
| direct_algos.emplace_back(&qu8_direct_stride2_small_group); | |||
| @@ -67,6 +67,8 @@ private: | |||
| class AlgoDotS8DirectStride2; | |||
| class AlgoDotU8DirectStride1; | |||
| class AlgoDotU8DirectStride2; | |||
| class AlgoDotS8Direct_NCHW44; | |||
| #endif | |||
| class AlgoF32Direct; | |||
| class AlgoF32DirectStride1; | |||
| @@ -1809,6 +1809,81 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { | |||
| used1 / used0); | |||
| } | |||
| } | |||
| TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||
| using namespace conv_bias; | |||
| std::vector<TestArg> args; | |||
| auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel, | |||
| size_t p, size_t stride, NonlineMode nonline_mode) { | |||
| if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
| return; | |||
| param::ConvBias param; | |||
| param.stride_h = stride; | |||
| param.stride_w = stride; | |||
| param.pad_h = p; | |||
| param.pad_w = p; | |||
| param.nonlineMode = nonline_mode; | |||
| param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| //! channel bias | |||
| args.emplace_back(param, TensorShape{1, ic/4, h, w, 4}, | |||
| TensorShape{oc/4, ic/4, kernel, kernel, 4, 4}, | |||
| TensorShape{1, oc/4, 1, 1, 4}); | |||
| }; | |||
| for (size_t stride : {1, 2}) | |||
| for (size_t kernel : {2, 3, 5, 7}) | |||
| for(size_t oc : {64}) | |||
| for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { | |||
| run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode); | |||
| } | |||
| constexpr size_t RUN = 50; | |||
| Benchmarker<ConvBias> benchmark0(handle()); | |||
| benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .set_dtype(4, dtype::QuantizedS8(60.25f)); | |||
| benchmark0.set_display(false); | |||
| benchmark0.set_times(RUN); | |||
| benchmark0.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8DIRECT_NCHW44")); | |||
| Benchmarker<ConvBias> benchmark1(handle()); | |||
| benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
| .set_dtype(4, dtype::QuantizedS8(60.25f)); | |||
| benchmark1.set_display(false); | |||
| benchmark1.set_times(RUN); | |||
| for (auto&& arg : args) { | |||
| TensorLayout dst_layout; | |||
| auto opr = handle()->create_operator<ConvBias>(); | |||
| opr->param() = arg.param; | |||
| opr->deduce_layout({arg.src, dtype::Int8()}, | |||
| {arg.filter, dtype::Int8()}, | |||
| {arg.bias, dtype::Int32()}, {}, dst_layout); | |||
| //! dst.nr_elems * IC * FH * FW * 2 (arg.filter[1] == IC / 4, hence the factor 8) | |||
| float computations = dst_layout.total_nr_elems() * arg.filter[1] * | |||
| arg.filter[2] * arg.filter[3] * 8.0 / | |||
| (1024 * 1024 * 1024) * 1e3; | |||
| auto used0 = benchmark0.set_param(arg.param).exec( | |||
| {arg.src, arg.filter, arg.bias, {}, {}}) / | |||
| RUN; | |||
| auto used1 = benchmark1.set_param(arg.param).exec( | |||
| {arg.src, arg.filter, arg.bias, {}, {}}) / | |||
| RUN; | |||
| printf("%s %s: Direct use: %f ms %f Gflops normal: %f ms %f GFlops " | |||
| "speedup: %f\n", | |||
| arg.src.to_string().c_str(), arg.filter.to_string().c_str(), | |||
| used0, computations / used0, used1, computations / used1, | |||
| used1 / used0); | |||
| } | |||
| } | |||
| #endif | |||
| #endif | |||
| @@ -155,7 +155,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||
| if (support_sigmoid) { | |||
| nonlinemode.emplace_back(NLMode::SIGMOID); | |||
| } | |||
| std::vector<megdnn::BiasMode> bias_mode = { | |||
| megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; | |||
| if (no_bias) { | |||
| @@ -672,6 +672,63 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2_SMALL_GROUP) { | |||
| get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(), | |||
| "ARMDOTU8STRD2_SMALL_GROUP"); | |||
| } | |||
| /******************************dot int8x8x8 nchw44 ***********************/ | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) { | |||
| using namespace conv_bias; | |||
| std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) { | |||
| using namespace conv_bias; | |||
| std::vector<TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) { | |||
| using namespace conv_bias; | |||
| std::vector<TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, true, true); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) { | |||
| using namespace conv_bias; | |||
| //! test qint8x8x8 | |||
| std::vector<TestArg> args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) { | |||
| using namespace conv_bias; | |||
| //! test qint8x8x32 | |||
| std::vector<TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) { | |||
| using namespace conv_bias; | |||
| //! test int8x8x32 | |||
| std::vector<TestArg> args = | |||
| get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, true, true); | |||
| for (auto&& arg : args) | |||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
| checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44"); | |||
| } | |||
| #endif | |||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { | |||