| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -156,7 +157,6 @@ private: | |||||
| uint32_t m_tile_size; | uint32_t m_tile_size; | ||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | ||||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
| bool m_large_group; | bool m_large_group; | ||||
| @@ -217,6 +217,24 @@ public: | |||||
| fallback::ConvBiasImpl* opr, | fallback::ConvBiasImpl* opr, | ||||
| const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
| }; | }; | ||||
| class ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44 final : public AlgoBase { | |||||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||||
| public: | |||||
| AlgoF32DirectStride2NCHWNCHW44() {} | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "F32_CONV_NCHW_NCHW44"; } | |||||
| bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(fallback::ConvBiasImpl* opr, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| virtual SmallVector<NCBKern> dispatch_kerns( | |||||
| fallback::ConvBiasImpl* opr, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| }; | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -0,0 +1,317 @@ | |||||
| /** | |||||
| * \file | |||||
| dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express | |||||
| * or implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/arm_common/conv_bias/fp32/algos.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/fp32/strategy.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/common/opr_delegate.h" | |||||
| #include "midout.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| using conv_fun = std::function<void( | |||||
| WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>; | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2) | |||||
| namespace { | |||||
| static inline int block_helper(const int nthread, const int amount, | |||||
| const int per_unit_bytes) { | |||||
| MEGDNN_MARK_USED_VAR(per_unit_bytes); | |||||
| const int block_per_thread = div_ceil(amount, nthread); | |||||
| const int best_block = 16; | |||||
| const int max_block_num = div_ceil(block_per_thread, best_block); | |||||
| const int min_block_num = std::max(max_block_num - 1, 1); | |||||
| const int max_block = div_ceil(block_per_thread, max_block_num); | |||||
| const int min_block = div_ceil(block_per_thread, min_block_num); | |||||
| const int max_loss = std::abs(max_block_num * max_block - block_per_thread); | |||||
| const int min_loss = std::abs(min_block_num * min_block - block_per_thread); | |||||
| int block = max_loss > min_loss ? min_block : max_block; | |||||
| return block; | |||||
| } | |||||
| static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, | |||||
| const int iw2) { | |||||
| // border_size is used to avoid read illegal memory | |||||
| int border_size = 64 * 2; | |||||
| return ic * ih2 * iw2 * sizeof(float) + border_size; | |||||
| } | |||||
| static void get_rectified_size( | |||||
| const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2, | |||||
| int& iw2, int& oh2, int& ow2) { | |||||
| int iw = param.isz[1]; | |||||
| int oh = param.osz[0]; | |||||
| int ow = param.osz[1]; | |||||
| oh2 = oh; | |||||
| ow2 = ow; | |||||
| constexpr int cacheline = 64 / sizeof(float); | |||||
| int block_oh = block_helper(param.nr_threads, oh, 0); | |||||
| auto&& fm = param.filter_meta; | |||||
| const int stride_h = static_cast<int>(fm.stride[0]); | |||||
| const int filter_h = static_cast<int>(fm.spatial[0]); | |||||
| ih2 = block_oh * stride_h + filter_h - stride_h; | |||||
| iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline); | |||||
| } | |||||
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| auto&& fm = param.filter_meta; | |||||
| int group = fm.group; | |||||
| int ic = fm.icpg; | |||||
| int oc = fm.ocpg; | |||||
| int fh = fm.spatial[0]; | |||||
| int fw = fm.spatial[1]; | |||||
| int ih2, iw2, oh2, ow2; | |||||
| get_rectified_size(param, ih2, iw2, oh2, ow2); | |||||
| int oh_block = block_helper(param.nr_threads, oh2, 0); | |||||
| megdnn_assert(oh_block != 0, "oh_block!=0"); | |||||
| size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); | |||||
| size_t weight_size = group * oc * ic * fh * fw * sizeof(float); | |||||
| return {nullptr, {src_size * param.nr_threads, weight_size}}; | |||||
| }; | |||||
| static inline void copy_pad_src(float* sptr_base, const float* sptr_origin, | |||||
| int ph, int pw, int pad_right, int ih, int iw, | |||||
| int iw2, int pad_top, int pad_bottom, int ic, | |||||
| int ic_stride) { | |||||
| MEGDNN_MARK_USED_VAR(ph); | |||||
| rep(ic_idx, ic) { | |||||
| const float* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_top); | |||||
| sptr_base += iw2 * pad_top; | |||||
| rep(ih_idx, ih) { | |||||
| memset(sptr_base, 0, sizeof(float) * pw); | |||||
| sptr_base += pw; | |||||
| memcpy(sptr_base, sptr, sizeof(float) * iw); | |||||
| sptr_base += iw; | |||||
| sptr += iw; | |||||
| memset(sptr_base, 0, sizeof(float) * pad_right); | |||||
| sptr_base += pad_right; | |||||
| } | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom); | |||||
| sptr_base += iw2 * pad_bottom; | |||||
| } | |||||
| } | |||||
| static void pack_weight(WorkspaceBundle bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
| bundle.set(kern_param.workspace_ptr); | |||||
| const int group_id = ncb_index.ndrange_id[0]; | |||||
| int fh = kern_param.filter_meta.spatial[0]; | |||||
| int fw = kern_param.filter_meta.spatial[1]; | |||||
| int oc = kern_param.filter_meta.ocpg; | |||||
| int ic = kern_param.filter_meta.icpg; | |||||
| int oc_block = oc; | |||||
| int oc_idx = 0; | |||||
| const float* fptr = | |||||
| kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; | |||||
| auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) + | |||||
| group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; | |||||
| conv_bias::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, | |||||
| fw, ic); | |||||
| } | |||||
| template <size_t filter, BiasMode bias_mode, typename Op> | |||||
| static void do_conv_kern(WorkspaceBundle bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange&, const CpuNDRange&) { | |||||
| const int oh = kern_param.osz[0]; | |||||
| const int ow = kern_param.osz[1]; | |||||
| const int fh = kern_param.filter_meta.spatial[0]; | |||||
| const int fw = kern_param.filter_meta.spatial[1]; | |||||
| const int ic = kern_param.filter_meta.icpg; | |||||
| const int oc = kern_param.filter_meta.ocpg; | |||||
| const int ih = kern_param.isz[0]; | |||||
| const int iw = kern_param.isz[1]; | |||||
| const int stride_h = kern_param.filter_meta.stride[0]; | |||||
| const int ph = kern_param.filter_meta.padding[0]; | |||||
| const int pw = kern_param.filter_meta.padding[1]; | |||||
| int ih2 = 0; | |||||
| int iw2 = 0; | |||||
| int oh2 = 0; | |||||
| int ow2 = 0; | |||||
| get_rectified_size(kern_param, ih2, iw2, oh2, ow2); | |||||
| bundle.set(kern_param.workspace_ptr); | |||||
| constexpr int pack_c = 4; | |||||
| const int batch_id = ncb_index.ndrange_id[0]; | |||||
| const int group_id = ncb_index.ndrange_id[1]; | |||||
| int oc_idx = 0; | |||||
| int oc_block = oc; | |||||
| int oh_block = block_helper(kern_param.nr_threads, oh2, 0); | |||||
| const int oh_idx = ncb_index.ndrange_id[2]; | |||||
| const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); | |||||
| const int ih_real = oh_block_real * stride_h + fh - stride_h; | |||||
| const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); | |||||
| const int src_bottom_pad = std::max( | |||||
| (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, | |||||
| 0); | |||||
| const int remain_right_pad = std::max(iw2 - iw - pw, 0); | |||||
| const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; | |||||
| const float* origin_sptr = static_cast<const float*>(kern_param.src<float>( | |||||
| batch_id, group_id, 0, 1, 1)) + | |||||
| src_offset; | |||||
| const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); | |||||
| float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) + | |||||
| ncb_index.thread_id * src_size); | |||||
| copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad, | |||||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | |||||
| src_bottom_pad, ic, ih * iw); | |||||
| // pack weight | |||||
| auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) + | |||||
| group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; | |||||
| // get param | |||||
| float_t* dst = kern_param.dst<float_t>(batch_id, group_id) + | |||||
| oh_idx * oh_block * ow * pack_c; | |||||
| const float* bptr = | |||||
| kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx; | |||||
| Op op; | |||||
| #define KERN1_NCHW44_CONV(filter) \ | |||||
| conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw_nchw44< \ | |||||
| \ | |||||
| bias_mode, Op>(sptr, packed_weight, bptr, nullptr, dst, oc_block, \ | |||||
| ic, ih_real, iw2, oh, oh_block_real, ow, op, ph, \ | |||||
| pw) | |||||
| DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); | |||||
| #undef KERN1_NCHW44_CONV | |||||
| } | |||||
| } // namespace | |||||
| /* ===================== stride2 algo ===================== */ | |||||
| bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto fh = fm.spatial[0]; | |||||
| int oc = fm.ocpg; | |||||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 && | |||||
| param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
| (param.dst_type.enumv() == DTypeEnum::Float32))) && | |||||
| (fm.format == param::Convolution::Format::NCHW44); | |||||
| bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; | |||||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||||
| (fh == 3 || fh == 5 || fh == 7); | |||||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| fm.stride[0] == 2 && fm.stride[1] == 2; | |||||
| bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; | |||||
| bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv; | |||||
| return avaible; | |||||
| } | |||||
| size_t ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::get_workspace( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||||
| return get_bundle(param).total_size_in_bytes(); | |||||
| } | |||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns( | |||||
| fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||||
| auto fm = param.filter_meta; | |||||
| const int batch = param.n; | |||||
| const int group = fm.group; | |||||
| WorkspaceBundle wbundle = get_bundle(param); | |||||
| conv_fun do_conv_fun = nullptr; | |||||
| // NOTE: remain_w is not used to gen hash of midout for compatible with | |||||
| // shape runtime | |||||
| #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44_stride2, \ | |||||
| midout_iv(#filter #bias_mode #op##_hash)) { \ | |||||
| do_conv_fun = do_conv_kern<filter, bias_mode, op>; \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| #define GET_OP_PARAM(filter, bias_mode) \ | |||||
| switch (param.nonlineMode) { \ | |||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, NoneOp<dt_float32>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::RELU: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, ReluOp<dt_float32>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, HSwishOp<dt_float32>) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| } | |||||
| #define GET_BIAS_MODE_PARAM(filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| } | |||||
| #define DISPATCH_CONV_KERN() \ | |||||
| switch (param.filter_meta.spatial[0]) { \ | |||||
| case 3: \ | |||||
| GET_BIAS_MODE_PARAM(3) \ | |||||
| break; \ | |||||
| case 5: \ | |||||
| GET_BIAS_MODE_PARAM(5) \ | |||||
| break; \ | |||||
| case 7: \ | |||||
| GET_BIAS_MODE_PARAM(7) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| } | |||||
| DISPATCH_CONV_KERN(); | |||||
| #undef DO_CONV_KERN_FUN | |||||
| #undef GET_REMAIN_W_PARAM | |||||
| #undef GET_OP_PARAM | |||||
| #undef GET_BIAS_MODE_PARAM | |||||
| #undef DISPATCH_CONV_KERN | |||||
| megdnn_assert(do_conv_fun); | |||||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||||
| WorkspaceBundle bundle = wbundle; | |||||
| int oh = param.osz[0]; | |||||
| int oh_block = block_helper(param.nr_threads, oh, 0); | |||||
| auto do_pack_weight = [bundle](const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) { | |||||
| pack_weight(bundle, kern_param, ncb_index); | |||||
| }; | |||||
| ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}}); | |||||
| CpuNDRange ncb_range = {static_cast<size_t>(batch), | |||||
| static_cast<size_t>(group), | |||||
| static_cast<size_t>(div_ceil(oh, oh_block))}; | |||||
| auto do_conv = [bundle, do_conv_fun, ncb_range]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) { | |||||
| do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, | |||||
| ncb_range); | |||||
| }; | |||||
| ret_kerns.push_back({do_conv, ncb_range}); | |||||
| return ret_kerns; | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,430 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| namespace { | |||||
| template <int src_idx, int weight_idx, int c_dim, typename Func, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| struct ShiftCalHelper { | |||||
| static void impl(T& c, T2& src, T3& weight); | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| constexpr int stride = 2; | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); \ | |||||
| c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[1][step], weight[1][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| constexpr int stride = 2; | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T, | |||||
| typename T2, typename T3> | |||||
| inline void cal_helper(T& c, T2& src, T3& weight) { | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl( | |||||
| c, src, weight); | |||||
| }; | |||||
| template <int oc> | |||||
| struct OCHelper { | |||||
| public: | |||||
| static const int val = -1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<4> { | |||||
| public: | |||||
| static const int val = 1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<8> { | |||||
| public: | |||||
| static const int val = 2; | |||||
| }; | |||||
| /** | |||||
| * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel | |||||
| * */ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||||
| int oc_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32 { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op); | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 7; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = 6; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<5, 5, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<6, 6, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 5; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = 5; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, Vfmaq_laneq_f32>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| UNROLL_CALL_RAW(5, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 3; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = 5; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| // row 0 | |||||
| load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | |||||
| load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | |||||
| ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| // row 1 | |||||
| load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); | |||||
| load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| // row 2 | |||||
| load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0); | |||||
| load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| } // namespace | |||||
| void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, | |||||
| float32_t* dst_ptr, const int oc, | |||||
| const int kh, const int kw, | |||||
| const int ic) { | |||||
| constexpr int oc_step = 4; | |||||
| const int filter_oc_stride = kh * kw * ic; | |||||
| const int filter_ic_stride = kh * kw * oc_step; | |||||
| for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) { | |||||
| const float32_t* in_ptr_oc = in_ptr + oc_idx * filter_oc_stride; | |||||
| float32_t* dst_ptr_oc = dst_ptr + oc_idx * filter_oc_stride; | |||||
| for (int kh_idx = 0; kh_idx < kh; ++kh_idx) { | |||||
| for (int kw_idx = 0; kw_idx < kw; ++kw_idx) { | |||||
| for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { | |||||
| float32x4_t vsrc = vld1q_f32(in_ptr_oc); | |||||
| vst1q_f32(dst_ptr_oc + ic_idx * filter_ic_stride, vsrc); | |||||
| in_ptr_oc += oc_step; | |||||
| } | |||||
| dst_ptr_oc += oc_step; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int filter_size> | |||||
| static void conv_direct_stride2_fp32_nchw_nchw44( | |||||
| const float32_t* src, const float32_t* filter, const float32_t* bias, | |||||
| float32_t*, float32_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op, const int, const int) { | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int big_oc_step = 8; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = 2; | |||||
| constexpr int stride_w = 2; | |||||
| constexpr int pack_iw_len = 1; | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = std::function<void( | |||||
| const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| oc_step>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32< | |||||
| bias_mode, Op, 0, filter_size, | |||||
| big_oc_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 0, filter_size, | |||||
| oc_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #define CONSTRUCT_FUNC(filter_size) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_bias:: \ | |||||
| conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw_nchw44( \ | |||||
| const float32_t* src, const float32_t* filter, \ | |||||
| const float32_t* bias, float32_t* temp, float32_t* dst, \ | |||||
| const int oc, const int ic, const int ih, const int iw, \ | |||||
| const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw) { \ | |||||
| conv_direct_stride2_fp32_nchw_nchw44<bias_mode, Op, filter_size>( \ | |||||
| src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ | |||||
| ow, op, ph, pw); \ | |||||
| } | |||||
| CONSTRUCT_FUNC(3); | |||||
| CONSTRUCT_FUNC(5); | |||||
| CONSTRUCT_FUNC(7); | |||||
| #undef CONSTRUCT_FUNC | |||||
| template <BiasMode bias_mode, typename Op> | |||||
| void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44( | |||||
| const float32_t*, const float32_t*, const float32_t*, float32_t*, | |||||
| float32_t*, const int, const int, const int, const int, const int, | |||||
| const int, const int, const Op&, const int, const int) { | |||||
| megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); | |||||
| } | |||||
| #define INSTANTIATION(stride, i, bias, Op) \ | |||||
| template void conv_bias:: \ | |||||
| conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \ | |||||
| const float32_t*, const float32_t*, const float32_t*, \ | |||||
| float32_t*, float32_t*, const int, const int, const int, \ | |||||
| const int, const int, const int, const int, const Op&, \ | |||||
| const int, const int); | |||||
| #define FOR_OP(stride, i, bias) \ | |||||
| INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) | |||||
| #define FOR_BIAS(stride, i) \ | |||||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||||
| #define FOR_FILTER(stride) \ | |||||
| FOR_BIAS(stride, 2) \ | |||||
| FOR_BIAS(stride, 3) \ | |||||
| FOR_BIAS(stride, 5) \ | |||||
| FOR_BIAS(stride, 7) | |||||
| FOR_FILTER(stride2) | |||||
| #undef FOR_STRIDE | |||||
| #undef FOR_FILTER | |||||
| #undef FOR_IC | |||||
| #undef FOR_BIAS | |||||
| #undef FOR_NONLINEAR | |||||
| #undef FOR_REMAIN | |||||
| #undef INSTANTIATION | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
| #define KERN(stride, i, layout) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_direct_##stride##_##i##x##i##_fp32_nchw_##layout( \ | |||||
| const float* src, const float* filter, const float* bias, \ | |||||
| float* temp, float* dst, const int oc, const int ic, const int ih, \ | |||||
| const int iw, const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw); | |||||
| KERN(stride2, 2, nchw44) | |||||
| KERN(stride2, 3, nchw44) | |||||
| KERN(stride2, 5, nchw44) | |||||
| KERN(stride2, 7, nchw44) | |||||
| #undef KERN | |||||
| void pack_weight_fp32_nchw_nchw44(const float_t* in_ptr, float_t* dst_ptr, | |||||
| const int oc, const int kh, const int kw, | |||||
| const int ic); | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| @@ -174,7 +174,167 @@ inline void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, | |||||
| int ld_dst_oc) { | int ld_dst_oc) { | ||||
| StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); | StoreOcxOw4Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); | ||||
| } | } | ||||
| ////////////////////Store_OCX_OW8_Remain///////////////////////// | |||||
| template <int c_dim, int ow_remain, typename Op, typename T> | |||||
| struct StoreOcxOw8Remain { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc); | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 0, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||||
| op({{c[1][6], c[1][7]}}, dst_ptr + ld_dst_oc + 24); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 7, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| op(c[0][6], dst_ptr + 24); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||||
| op(c[1][6], dst_ptr + ld_dst_oc + 24); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 6, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||||
| op({{c[1][4], c[1][5]}}, dst_ptr + ld_dst_oc + 16); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 5, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op(c[0][4], dst_ptr + 16); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||||
| op(c[1][4], dst_ptr + ld_dst_oc + 16); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 4, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op({{c[1][2], c[1][3]}}, dst_ptr + ld_dst_oc + 8); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 3, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op(c[0][2], dst_ptr + 8); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| op(c[1][2], dst_ptr + ld_dst_oc + 8); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 2, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[1][0], c[1][1]}}, dst_ptr + ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<2, 1, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int ld_dst_oc) { | |||||
| op(c[0][0], dst_ptr); | |||||
| op(c[1][0], dst_ptr + ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 0, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| op({{c[0][6], c[0][7]}}, dst_ptr + 24); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 7, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| op(c[0][6], dst_ptr + 24); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 6, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op({{c[0][4], c[0][5]}}, dst_ptr + 16); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 5, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| op(c[0][4], dst_ptr + 16); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 4, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op({{c[0][2], c[0][3]}}, dst_ptr + 8); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 3, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| op(c[0][2], dst_ptr + 8); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 2, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op({{c[0][0], c[0][1]}}, dst_ptr); | |||||
| } | |||||
| }; | |||||
| template <typename Op, typename T> | |||||
| struct StoreOcxOw8Remain<1, 1, Op, T> { | |||||
| static void impl(T& c, const Op& op, float32_t* dst_ptr, int) { | |||||
| op(c[0][0], dst_ptr); | |||||
| } | |||||
| }; | |||||
| template <int c_dim, int ow_remain, typename Op, typename T> | |||||
| inline void store_ocx_ow8_remain_static(T& c, const Op& op, float32_t* dst_ptr, | |||||
| int ld_dst_oc) { | |||||
| StoreOcxOw8Remain<c_dim, ow_remain, Op, T>::impl(c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| ////////////////////Store_OC8_OW8_Remain///////////////////////// | ////////////////////Store_OC8_OW8_Remain///////////////////////// | ||||
| template <int ow_remain, typename Op> | template <int ow_remain, typename Op> | ||||
| @@ -299,14 +459,15 @@ struct Store_OC8_OW8_Remain<1, Op> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <int ow_remain, typename Op> | |||||
| inline void store_oc8_ow8_remain_static(int32x4_t c[2][8], const Op& op, | |||||
| int8_t* dst_ptr, int ld_dst_oc) { | |||||
| /////////// | |||||
| template <int ow_remain, typename Op, typename T, typename T2> | |||||
| inline void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, | |||||
| int ld_dst_oc) { | |||||
| Store_OC8_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr, ld_dst_oc); | Store_OC8_OW8_Remain<ow_remain, Op>::impl(c, op, dst_ptr, ld_dst_oc); | ||||
| } | } | ||||
| /////////////////////////////////////////////////////// | |||||
| ////////////////////////////////////// | |||||
| template <BiasMode bias_mode> | template <BiasMode bias_mode> | ||||
| inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) { | inline void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) { | ||||
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||
| @@ -337,6 +498,49 @@ inline void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, | |||||
| #undef BAIS_INIT | #undef BAIS_INIT | ||||
| } | } | ||||
| } | } | ||||
| /////////////////////////init_ocx_ow8//////////////////// | |||||
| template <int c_dim, BiasMode bias_mode, typename T, typename T2> | |||||
| struct InitOcxOw8 { | |||||
| static void impl(T& c, T2 bias_ptr, int oc_step); | |||||
| }; | |||||
| template <BiasMode bias_mode, typename T, typename T2> | |||||
| struct InitOcxOw8<2, bias_mode, T, T2> { | |||||
| static void impl(T& c, const float32_t* bias_ptr, int oc_step) { | |||||
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| #define BAIS_INIT(step) \ | |||||
| c[0][step] = vld1q_f32(bias_ptr); \ | |||||
| c[1][step] = vld1q_f32(bias_ptr + oc_step); | |||||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
| #undef BAIS_INIT | |||||
| } else { | |||||
| #define BAIS_INIT(step) \ | |||||
| c[0][step] = vdupq_n_f32(0); \ | |||||
| c[1][step] = vdupq_n_f32(0); | |||||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
| #undef BAIS_INIT | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename T, typename T2> | |||||
| struct InitOcxOw8<1, bias_mode, T, T2> { | |||||
| static void impl(T& c, const float32_t* bias_ptr, int) { | |||||
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| #define BAIS_INIT(step) c[0][step] = vld1q_f32(bias_ptr); | |||||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
| #undef BAIS_INIT | |||||
| } else { | |||||
| #define BAIS_INIT(step) c[0][step] = vdupq_n_f32(0); | |||||
| UNROLL_CALL_RAW(8, BAIS_INIT); | |||||
| #undef BAIS_INIT | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <int c_dim, BiasMode bias_mode, typename T, typename T2> | |||||
| inline void init_ocx_ow8(T& c, T2 bias_ptr, int oc_step) { | |||||
| InitOcxOw8<c_dim, bias_mode, T, T2>::impl(c, bias_ptr, oc_step); | |||||
| } | |||||
| /////////////////////init_ocx_ow4///////////////////// | |||||
| template <int c_dim, BiasMode bias_mode, typename T> | template <int c_dim, BiasMode bias_mode, typename T> | ||||
| struct InitOcxOw4 { | struct InitOcxOw4 { | ||||
| static void impl(T& c, const int32_t* bias_ptr, int oc_step); | static void impl(T& c, const int32_t* bias_ptr, int oc_step); | ||||
| @@ -383,57 +587,54 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { | |||||
| } | } | ||||
| /////////////////////////////////////// | /////////////////////////////////////// | ||||
| template <int weight_number, int base_offset, int ptr_step, int oc_block, | template <int weight_number, int base_offset, int ptr_step, int oc_block, | ||||
| typename Func, typename T, typename... XT> | |||||
| typename Func, typename T, typename T2, typename... XT> | |||||
| struct LoadHelper { | struct LoadHelper { | ||||
| static void impl(T& weight, const int8_t* ptr, int oc_offset, XT... args); | |||||
| static void impl(T& weight, T2 ptr, int oc_offset, XT... args); | |||||
| }; | }; | ||||
| #define WEIGHT_CB(step) \ | #define WEIGHT_CB(step) \ | ||||
| src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); | src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(1, WEIGHT_CB); | UNROLL_CALL_RAW(1, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(2, WEIGHT_CB); | UNROLL_CALL_RAW(2, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(3, WEIGHT_CB); | UNROLL_CALL_RAW(3, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(4, WEIGHT_CB); | UNROLL_CALL_RAW(4, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(5, WEIGHT_CB); | UNROLL_CALL_RAW(5, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2, | |||||
| typename... XT> | typename... XT> | ||||
| struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset, XT... args) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, T2, XT...> { | |||||
| static void impl(T& src, T2 ptr, int, XT... args) { | |||||
| UNROLL_CALL_RAW(6, WEIGHT_CB); | UNROLL_CALL_RAW(6, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -441,27 +642,36 @@ struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, XT...> { | |||||
| #define WEIGHT_CB(step) \ | #define WEIGHT_CB(step) \ | ||||
| src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); | src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<1, base_offset, ptr_step, 1, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| UNROLL_CALL_RAW(1, WEIGHT_CB); | |||||
| } | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<1, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(1, WEIGHT_CB); } | |||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<2, base_offset, ptr_step, 1, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| UNROLL_CALL_RAW(2, WEIGHT_CB); | |||||
| } | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<2, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(2, WEIGHT_CB); } | |||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| MEGDNN_MARK_USED_VAR(oc_offset); | |||||
| UNROLL_CALL_RAW(3, WEIGHT_CB); | |||||
| } | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<3, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(3, WEIGHT_CB); } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<4, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(4, WEIGHT_CB); } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<5, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(5, WEIGHT_CB); } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<6, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(6, WEIGHT_CB); } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<7, base_offset, ptr_step, 1, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(7, WEIGHT_CB); } | |||||
| }; | }; | ||||
| #undef WEIGHT_CB | #undef WEIGHT_CB | ||||
| @@ -470,40 +680,63 @@ struct LoadHelper<3, base_offset, ptr_step, 1, Func, T> { | |||||
| src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ | src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ | ||||
| src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); | src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<1, base_offset, ptr_step, 2, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<1, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(1, WEIGHT_CB); | UNROLL_CALL_RAW(1, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<2, base_offset, ptr_step, 2, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<2, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(2, WEIGHT_CB); | UNROLL_CALL_RAW(2, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T> | |||||
| struct LoadHelper<3, base_offset, ptr_step, 2, Func, T> { | |||||
| static void impl(T& src, const int8_t* ptr, int oc_offset) { | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<3, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(3, WEIGHT_CB); | UNROLL_CALL_RAW(3, WEIGHT_CB); | ||||
| } | } | ||||
| }; | }; | ||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<4, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(4, WEIGHT_CB); | |||||
| } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<5, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(5, WEIGHT_CB); | |||||
| } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<6, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(6, WEIGHT_CB); | |||||
| } | |||||
| }; | |||||
| template <int base_offset, int ptr_step, typename Func, typename T, typename T2> | |||||
| struct LoadHelper<7, base_offset, ptr_step, 2, Func, T, T2> { | |||||
| static void impl(T& src, T2 ptr, int oc_offset) { | |||||
| UNROLL_CALL_RAW(7, WEIGHT_CB); | |||||
| } | |||||
| }; | |||||
| #undef WEIGHT_CB | #undef WEIGHT_CB | ||||
| template <int weight_number, int base_offset, int ptr_step, int c_dim, | template <int weight_number, int base_offset, int ptr_step, int c_dim, | ||||
| typename Func, typename T> | |||||
| inline void load_helper(T& weight, const int8_t* ptr, int oc_offset) { | |||||
| LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T>::impl( | |||||
| typename Func, typename T, typename T2> | |||||
| inline void load_helper(T& weight, T2 ptr, int oc_offset) { | |||||
| LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2>::impl( | |||||
| weight, ptr, oc_offset); | weight, ptr, oc_offset); | ||||
| } | } | ||||
| template <int weight_number, int base_offset, int ptr_step, int c_dim, | template <int weight_number, int base_offset, int ptr_step, int c_dim, | ||||
| typename Func, typename T, typename... XT> | |||||
| inline void load_helper_x(T& weight, const int8_t* ptr, int oc_offset, | |||||
| XT... args) { | |||||
| LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, | |||||
| typename Func, typename T, typename T2, typename... XT> | |||||
| inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { | |||||
| LoadHelper<weight_number, base_offset, ptr_step, c_dim, Func, T, T2, | |||||
| XT...>::impl(weight, ptr, oc_offset, args...); | XT...>::impl(weight, ptr, oc_offset, args...); | ||||
| } | } | ||||
| @@ -34,6 +34,9 @@ struct Vmlal_s16 { | |||||
| struct Vld1q_s8 { | struct Vld1q_s8 { | ||||
| static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); } | static int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); } | ||||
| }; | }; | ||||
| struct Vld1q_f32 { | |||||
| static float32x4_t impl(const float32_t* ptr) { return vld1q_f32(ptr); } | |||||
| }; | |||||
| struct Vld1_s8 { | struct Vld1_s8 { | ||||
| static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } | static int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } | ||||
| }; | }; | ||||
| @@ -50,5 +53,13 @@ struct Vldq_tbl_low_s8 { | |||||
| struct Vld1_dup_s8_s16 { | struct Vld1_dup_s8_s16 { | ||||
| static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); } | static int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); } | ||||
| }; | }; | ||||
| struct Vfmaq_laneq_f32 { | |||||
| template <const int lane> | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
| return vfmaq_laneq_f32(a, b, v, lane); | |||||
| } | |||||
| }; | |||||
| } // namespace | } // namespace | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -71,6 +71,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | AlgoF32DirectStride2 f32_direct_stride2_small_group{false}; | ||||
| AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; | AlgoF32DirectStride1 f32_direct_stride1_large_group{true}; | ||||
| AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; | AlgoF32DirectStride1 f32_direct_stride1_small_group{false}; | ||||
| AlgoF32DirectStride2NCHWNCHW44 f32_direct_stride2_nchw_nchw44; | |||||
| AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; | AlgoI8x8x16Direct i8x8x16_direct_large_group{true}; | ||||
| AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; | AlgoI8x8x16Direct i8x8x16_direct_small_group{false}; | ||||
| AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; | AlgoI8x8x16Stride2 i8x8x16_stride2_large_group{true}; | ||||
| @@ -123,6 +124,7 @@ public: | |||||
| direct_algos.emplace_back(&i8x8x16_stride2_filter2); | direct_algos.emplace_back(&i8x8x16_stride2_filter2); | ||||
| direct_algos.emplace_back(&i8x8x16_stride2_large_group); | direct_algos.emplace_back(&i8x8x16_stride2_large_group); | ||||
| direct_algos.emplace_back(&i8x8x16_stride2_small_group); | direct_algos.emplace_back(&i8x8x16_stride2_small_group); | ||||
| direct_algos.emplace_back(&f32_direct_stride2_nchw_nchw44); | |||||
| direct_algos.emplace_back(&f32_direct_stride1_large_group); | direct_algos.emplace_back(&f32_direct_stride1_large_group); | ||||
| direct_algos.emplace_back(&f32_direct_stride1_small_group); | direct_algos.emplace_back(&f32_direct_stride1_small_group); | ||||
| direct_algos.emplace_back(&f32_direct_stride2_large_group); | direct_algos.emplace_back(&f32_direct_stride2_large_group); | ||||
| @@ -67,6 +67,7 @@ private: | |||||
| class AlgoF32Direct; | class AlgoF32Direct; | ||||
| class AlgoF32DirectStride1; | class AlgoF32DirectStride1; | ||||
| class AlgoF32DirectStride2; | class AlgoF32DirectStride2; | ||||
| class AlgoF32DirectStride2NCHWNCHW44; | |||||
| class AlgoI8x8x16Direct; | class AlgoI8x8x16Direct; | ||||
| class AlgoI8x8x16Stride2; | class AlgoI8x8x16Stride2; | ||||
| class AlgoI8x8x16Stride2Filter2; | class AlgoI8x8x16Stride2Filter2; | ||||
| @@ -45,13 +45,17 @@ struct HSwishOp; | |||||
| vst1q_##_func_suffix(dst, vitem.val[0]); \ | vst1q_##_func_suffix(dst, vitem.val[0]); \ | ||||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ | ||||
| } \ | } \ | ||||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem); \ | |||||
| } \ | |||||
| _neon_type2 operator()(const _neon_type2& src) const { \ | _neon_type2 operator()(const _neon_type2& src) const { \ | ||||
| auto val1 = src.val[0]; \ | auto val1 = src.val[0]; \ | ||||
| auto val2 = src.val[1]; \ | auto val2 = src.val[1]; \ | ||||
| H_SWISH_KERN(_func_suffix, val1, val2); \ | H_SWISH_KERN(_func_suffix, val1, val2); \ | ||||
| return {{val1, val2}}; \ | return {{val1, val2}}; \ | ||||
| } \ | } \ | ||||
| _neon_type operator()(const _neon_type& src) { \ | |||||
| _neon_type operator()(const _neon_type& src) const { \ | |||||
| auto val_zero = vdupq_n_##_func_suffix(0.f); \ | auto val_zero = vdupq_n_##_func_suffix(0.f); \ | ||||
| auto val_six = vdupq_n_##_func_suffix(6.f); \ | auto val_six = vdupq_n_##_func_suffix(6.f); \ | ||||
| auto val_three = vdupq_n_##_func_suffix(3.f); \ | auto val_three = vdupq_n_##_func_suffix(3.f); \ | ||||
| @@ -64,6 +68,7 @@ struct HSwishOp; | |||||
| val_rec_six); \ | val_rec_six); \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) | OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -30,6 +31,13 @@ struct NoneOp; | |||||
| using NoneOpBase::operator(); \ | using NoneOpBase::operator(); \ | ||||
| constexpr static size_t SIMD_WIDTH = _simd_width; \ | constexpr static size_t SIMD_WIDTH = _simd_width; \ | ||||
| _neon_type2 operator()(const _neon_type2& src) const { return src; } \ | _neon_type2 operator()(const _neon_type2& src) const { return src; } \ | ||||
| void operator()(const _neon_type2& src, _ctype* dst) const { \ | |||||
| vst1q_##_func_suffix(dst, src.val[0]); \ | |||||
| vst1q_##_func_suffix(dst + SIMD_WIDTH, src.val[1]); \ | |||||
| } \ | |||||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||||
| vst1q_##_func_suffix(dst, src); \ | |||||
| } \ | |||||
| _neon_type operator()(const _neon_type& src) const { return src; } \ | _neon_type operator()(const _neon_type& src) const { return src; } \ | ||||
| }; | }; | ||||
| @@ -47,11 +47,16 @@ struct ReluOp; | |||||
| auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \ | auto vitem1 = vmaxq_##_func_suffix(src.val[1], vzero); \ | ||||
| return {{vitem0, vitem1}}; \ | return {{vitem0, vitem1}}; \ | ||||
| } \ | } \ | ||||
| void operator()(const _neon_type& src, _ctype* dst) const { \ | |||||
| auto vitem = operator()(src); \ | |||||
| vst1q_##_func_suffix(dst, vitem); \ | |||||
| } \ | |||||
| _neon_type operator()(const _neon_type& src) const { \ | _neon_type operator()(const _neon_type& src) const { \ | ||||
| auto vzero = vdupq_n_##_func_suffix(0); \ | auto vzero = vdupq_n_##_func_suffix(0); \ | ||||
| return vmaxq_##_func_suffix(src, vzero); \ | return vmaxq_##_func_suffix(src, vzero); \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) | OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) | ||||
| @@ -479,6 +479,39 @@ UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | #undef cb | ||||
| } // namespace | } // namespace | ||||
| #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec) | #define vdup_laneq_s16(vec, lane) Vdup_laneq_s16_armv7<lane>::impl(vec) | ||||
| namespace { | |||||
| template <int lane> | |||||
| struct Vfmap_laneq_f32_armv7 { | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); | |||||
| }; | |||||
| template <> | |||||
| struct Vfmap_laneq_f32_armv7<0> { | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
| return vmlaq_lane_f32(a, b, vget_low_f32(v), 0); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct Vfmap_laneq_f32_armv7<1> { | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
| return vmlaq_lane_f32(a, b, vget_low_f32(v), 1); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct Vfmap_laneq_f32_armv7<2> { | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
| return vmlaq_lane_f32(a, b, vget_high_f32(v), 0); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| struct Vfmap_laneq_f32_armv7<3> { | |||||
| static float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { | |||||
| return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); | |||||
| } | |||||
| }; | |||||
| } // namespace | |||||
| #define vfmaq_laneq_f32(a, b, v, lane) \ | |||||
| Vfmap_laneq_f32_armv7<lane>::impl(a, b, v) | |||||
| #endif | #endif | ||||
| @@ -85,7 +85,7 @@ TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QU8) { | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| static void benchmark_convbias(Handle* handle) { | |||||
| static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { | |||||
| constexpr size_t RUNS = 30; | constexpr size_t RUNS = 30; | ||||
| Benchmarker<ConvBias> benchmarker_int(handle); | Benchmarker<ConvBias> benchmarker_int(handle); | ||||
| @@ -102,15 +102,25 @@ static void benchmark_convbias(Handle* handle) { | |||||
| Benchmarker<ConvBias> benchmarker_float(handle); | Benchmarker<ConvBias> benchmarker_float(handle); | ||||
| benchmarker_float.set_display(false).set_times(RUNS); | benchmarker_float.set_display(false).set_times(RUNS); | ||||
| benchmarker_float.set_before_exec_callback( | benchmarker_float.set_before_exec_callback( | ||||
| conv_bias::ConvBiasAlgoChecker<ConvBias>(".+")); | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||||
| "IM2COLMATMUL:AARCH64_F32K8X12X1:192")); | |||||
| Benchmarker<ConvBias> benchmarker_int_nchw44(handle); | Benchmarker<ConvBias> benchmarker_int_nchw44(handle); | ||||
| benchmarker_int_nchw44.set_times(RUNS) | |||||
| .set_dtype(0, dtype::QuantizedS8(2.5)) | |||||
| .set_dtype(1, dtype::QuantizedS8(2.5)) | |||||
| .set_dtype(2, dtype::QuantizedS32(6.25)) | |||||
| .set_dtype(4, dtype::QuantizedS8(60.25)) | |||||
| .set_display(false); | |||||
| if (is_fp32) { | |||||
| benchmarker_int_nchw44.set_times(RUNS) | |||||
| .set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Float32()) | |||||
| .set_dtype(4, dtype::Float32()) | |||||
| .set_display(false); | |||||
| } else { | |||||
| benchmarker_int_nchw44.set_times(RUNS) | |||||
| .set_dtype(0, dtype::QuantizedS8(2.5)) | |||||
| .set_dtype(1, dtype::QuantizedS8(2.5)) | |||||
| .set_dtype(2, dtype::QuantizedS32(6.25)) | |||||
| .set_dtype(4, dtype::QuantizedS8(60.25)) | |||||
| .set_display(false); | |||||
| } | |||||
| benchmarker_int_nchw44.set_before_exec_callback( | benchmarker_int_nchw44.set_before_exec_callback( | ||||
| conv_bias::ConvBiasAlgoChecker<ConvBias>(".+")); | conv_bias::ConvBiasAlgoChecker<ConvBias>(".+")); | ||||
| @@ -151,7 +161,6 @@ static void benchmark_convbias(Handle* handle) { | |||||
| auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec( | auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec( | ||||
| {src, filter, bias, {}, dst}) / | {src, filter, bias, {}, dst}) / | ||||
| RUNS; | RUNS; | ||||
| float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6; | float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6; | ||||
| printf("run: %s %s %s->%s \n", src.to_string().c_str(), | printf("run: %s %s %s->%s \n", src.to_string().c_str(), | ||||
| filter.to_string().c_str(), bias.to_string().c_str(), | filter.to_string().c_str(), bias.to_string().c_str(), | ||||
| @@ -160,32 +169,42 @@ static void benchmark_convbias(Handle* handle) { | |||||
| computations / float_used); | computations / float_used); | ||||
| printf("int_nchw: %f ms %f Gflops, ", int_used, | printf("int_nchw: %f ms %f Gflops, ", int_used, | ||||
| computations / int_used); | computations / int_used); | ||||
| printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, | |||||
| computations / int_nchw44_used, int_used / int_nchw44_used); | |||||
| auto speed_up = int_used / int_nchw44_used; | |||||
| if (is_fp32) { | |||||
| speed_up = float_used / int_nchw44_used; | |||||
| printf("fp32_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, | |||||
| computations / int_nchw44_used, speed_up); | |||||
| } else { | |||||
| printf("int_nchw44: %f ms %f Gflops %f speedup, ", int_nchw44_used, | |||||
| computations / int_nchw44_used, speed_up); | |||||
| } | |||||
| printf("\n"); | printf("\n"); | ||||
| }; | }; | ||||
| run(1, 3, 32, 224, 224, 3, 2, true); | |||||
| run(1, 3, 64, 224, 224, 5, 2, true); | |||||
| run(1, 3, 64, 224, 224, 7, 2, true); | |||||
| run(1, 3, 32, 224, 224, 7, 2, true); | |||||
| for (size_t stride : {1, 2}) { | |||||
| printf("stride %zu\n", stride); | |||||
| for (size_t filter_size : {2, 3, 5, 7}) { | |||||
| for (size_t img_size : {32}) { | |||||
| for (size_t channel : {8, 16, 32, 64, 128, 256}) { | |||||
| run(1, channel, channel, img_size, img_size, filter_size, | |||||
| stride, false); | |||||
| if (is_fp32) { | |||||
| run(1, 3, 32, 224, 224, 3, 2, true); | |||||
| run(1, 3, 64, 224, 224, 7, 2, true); | |||||
| } else { | |||||
| for (size_t stride : {1, 2}) { | |||||
| printf("stride %zu\n", stride); | |||||
| for (size_t filter_size : {2, 3, 5, 7}) { | |||||
| for (size_t img_size : {32}) { | |||||
| for (size_t channel : {8, 16, 32, 64, 128, 256}) { | |||||
| run(1, channel, channel, img_size, img_size, | |||||
| filter_size, stride, false); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { | TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) { | ||||
| benchmark_convbias(handle()); | |||||
| benchmark_convbias(handle(), true); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { | ||||
| benchmark_convbias(handle()); | |||||
| benchmark_convbias(handle(), true); | |||||
| } | } | ||||
| #endif | #endif | ||||
| TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QS8) { | TEST_F(ARM_COMMON, CONV_BIAS_MATMUL_QS8) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| @@ -1464,7 +1483,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) { | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| namespace { | namespace { | ||||
| std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_size = 1) { | |||||
| std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args( | |||||
| size_t pack_size = 1) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| param::ConvBias param; | param::ConvBias param; | ||||
| @@ -1474,15 +1494,17 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args(size_t pack_siz | |||||
| param.pad_w = 0; | param.pad_w = 0; | ||||
| param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | ||||
| auto bench_case = [&](size_t OC, size_t IC, size_t H, size_t W) { | auto bench_case = [&](size_t OC, size_t IC, size_t H, size_t W) { | ||||
| if(pack_size == 1) | |||||
| if (pack_size == 1) | |||||
| args.emplace_back(param, TensorShape{1, IC, H, W}, | args.emplace_back(param, TensorShape{1, IC, H, W}, | ||||
| TensorShape{OC, IC, 1, 1}, TensorShape{}); | |||||
| TensorShape{OC, IC, 1, 1}, TensorShape{}); | |||||
| else { | else { | ||||
| if(pack_size == 4) | |||||
| if (pack_size == 4) | |||||
| param.format = param::ConvBias::Format::NCHW44; | param.format = param::ConvBias::Format::NCHW44; | ||||
| args.emplace_back(param, TensorShape{1, IC / pack_size, H, W, pack_size}, | |||||
| TensorShape{OC / pack_size, IC / pack_size, 1, 1, pack_size, pack_size}, | |||||
| TensorShape{}); | |||||
| args.emplace_back(param, | |||||
| TensorShape{1, IC / pack_size, H, W, pack_size}, | |||||
| TensorShape{OC / pack_size, IC / pack_size, 1, 1, | |||||
| pack_size, pack_size}, | |||||
| TensorShape{}); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -78,9 +78,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | ||||
| size_t kernel, size_t stride, size_t group, NLMode nlmode) { | |||||
| size_t kernel, size_t stride, size_t group, NLMode nlmode, | |||||
| int any_pad = -1) { | |||||
| constexpr int pack_c = 4; | constexpr int pack_c = 4; | ||||
| const size_t pad = no_pad ? 0 : kernel / 2; | |||||
| const size_t pad = any_pad >= 0 ? any_pad : kernel / 2; | |||||
| auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS | auto bias_mode = no_bias ? megdnn::BiasMode::NO_BIAS | ||||
| : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS; | : megdnn::BiasMode::BROADCAST_CHANNEL_BIAS; | ||||
| auto oc_per_group = oc / group; | auto oc_per_group = oc / group; | ||||
| @@ -90,7 +91,8 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| ic_per_group > 0; | ic_per_group > 0; | ||||
| bool nchw_disable = group > 1 || ic_per_group >= 4; | bool nchw_disable = group > 1 || ic_per_group >= 4; | ||||
| bool nchw44_disable = ic_per_group % pack_c != 0; | bool nchw44_disable = ic_per_group % pack_c != 0; | ||||
| if (!(ok_group)) { | |||||
| bool invalid_pad = (w + 2 * pad < kernel) || (h + 2 * pad < kernel); | |||||
| if (!(ok_group) || invalid_pad) { | |||||
| return; | return; | ||||
| } | } | ||||
| if ((is_input_nchw && nchw_disable) || | if ((is_input_nchw && nchw_disable) || | ||||
| @@ -107,6 +109,7 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| param.pad_h = pad; | param.pad_h = pad; | ||||
| param.pad_w = pad; | param.pad_w = pad; | ||||
| param.nonlineMode = nlmode; | param.nonlineMode = nlmode; | ||||
| auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; | auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; | ||||
| auto weight_tensor_shape = TensorShape{ | auto weight_tensor_shape = TensorShape{ | ||||
| oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; | oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; | ||||
| @@ -338,6 +341,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { | |||||
| check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false), | ||||
| handle(), "F32STRD2_SMALL_GROUP"); | handle(), "F32STRD2_SMALL_GROUP"); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { | |||||
| check_conv_bias( | |||||
| get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), | |||||
| handle(), "F32_CONV_NCHW_NCHW44"); | |||||
| } | |||||
| /**********************************F16 direct************************/ | /**********************************F16 direct************************/ | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_LARGE_GROUP) { | ||||