| @@ -153,6 +153,27 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16) | ||||
| }; | }; | ||||
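//! Direct convolution for fp16 NCHW88 layout: dense and group conv,
//! stride 1/2, filter size 1/2/3/5/7.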
| class ConvBiasImpl::AlgoF16DirectNCHW88 final : public AlgoBase { | |||||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||||
| public: | |||||
| AlgoF16DirectNCHW88() {} | |||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| const char* name() const override { return "F16_CONV_NCHW88_DIRECT"; } | |||||
| bool usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(const NCBKernSizeParam& param) const override; | |||||
SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
| ConvAlgoTypePack get_algo_type() const override { | |||||
| return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||||
| } | |||||
| MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW88_FP16) | |||||
| }; | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,296 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express | |||||
| * or implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/arm_common/conv_bias/block_helper.h" | |||||
| #include "src/arm_common/conv_bias/f16/algos.h" | |||||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "midout.h" | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| using conv_fun = | |||||
| std::function<void(const WorkspaceBundle& bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange& workspace_ids)>; | |||||
| MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88) | |||||
| namespace { | |||||
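//! The only workspace is a per-thread padded copy of one group's input:
//! nr_threads * (icpg/8) * IH2 * IW2 * 8 halves; no workspace is needed
//! when there is no padding.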
| static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||||
| auto&& fm = param.filter_meta; | |||||
| size_t nr_threads = param.nr_threads; | |||||
| size_t IC = fm.icpg / 8; | |||||
| size_t PH = fm.padding[0]; | |||||
| size_t PW = fm.padding[1]; | |||||
| size_t IH2 = param.isz[0] + 2 * PH; | |||||
| size_t IW2 = param.isz[1] + 2 * PW; | |||||
| if (PH == 0 && PW == 0) { | |||||
| return {nullptr, {}}; | |||||
| } | |||||
| size_t s = (nr_threads * IC * IH2 * IW2 * 8) * sizeof(dt_float16); | |||||
| return {nullptr, {s}}; | |||||
| } | |||||
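//! Zero-fill one 8-channel block of the workspace and copy the source rows
//! into it, so the conv kernel can read a padded image without bounds checks.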
| void copy_padding_kern(const WorkspaceBundle& bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange& workspace_ids) { | |||||
| auto fm = kern_param.filter_meta; | |||||
| size_t group = fm.group; | |||||
| size_t IH = kern_param.isz[0]; | |||||
| size_t IW = kern_param.isz[1]; | |||||
| size_t IC = fm.icpg / 8; | |||||
| size_t PH = fm.padding[0]; | |||||
| size_t PW = fm.padding[1]; | |||||
| size_t IH2 = IH + 2 * PH; | |||||
| size_t IW2 = IW + 2 * PW; | |||||
| if (PH == 0 && PW == 0) { | |||||
| return; | |||||
| } | |||||
//! used to compute the workspace offset
| size_t workspace_group_id = workspace_ids[0]; | |||||
| size_t workspace_batch_id = workspace_ids[1]; | |||||
| size_t channel_id = workspace_ids[2]; | |||||
| size_t group_id = ncb_index.ndrange_id[0]; | |||||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||||
| const dt_float16* sptr = | |||||
| kern_param.src<dt_float16>(batch_id, group_id, channel_id, 1, 8); | |||||
//! copy src into sptr_base so the conv kernel does not need to handle padding
| dt_float16* sptr_base = static_cast<dt_float16*>(bundle.get(0)) + | |||||
| workspace_batch_id * group * IC * IH2 * IW2 * 8 + | |||||
| workspace_group_id * IC * IH2 * IW2 * 8 + | |||||
| channel_id * IH2 * IW2 * 8; | |||||
| std::memset(sptr_base, 0, IH2 * IW2 * 8 * sizeof(dt_float16)); | |||||
| rep(ih, IH) { | |||||
| std::memcpy(sptr_base + (ih + PH) * IW2 * 8 + PW * 8, | |||||
| sptr + ih * IW * 8, IW * 8 * sizeof(dt_float16)); | |||||
| } | |||||
| }; | |||||
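//! Compute one 8-channel output block: pick the (possibly padded) source,
//! the filter/bias/dst pointers for this block and call the NCHW88 kernel.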
| template <size_t FH, size_t SH, BiasMode bias_mode, typename Op> | |||||
| static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| const ConvBiasImpl::NCBKernParam& kern_param, | |||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
| const CpuNDRange& workspace_ids) { | |||||
| auto fm = kern_param.filter_meta; | |||||
| size_t group = fm.group; | |||||
| size_t OH = kern_param.osz[0]; | |||||
| size_t OW = kern_param.osz[1]; | |||||
| size_t FW = FH; | |||||
| size_t IC = fm.icpg / 8; | |||||
| size_t PH = fm.padding[0]; | |||||
| size_t PW = fm.padding[1]; | |||||
| size_t IH2 = kern_param.isz[0] + 2 * PH; | |||||
| size_t IW2 = kern_param.isz[1] + 2 * PW; | |||||
| size_t group_id = ncb_index.ndrange_id[0]; | |||||
| size_t batch_id = ncb_index.ndrange_id[1]; | |||||
| size_t channel_id = workspace_ids[2]; | |||||
//! used to compute the workspace offset
| size_t workspace_batch_id = workspace_ids[1]; | |||||
| size_t workspace_group_id = workspace_ids[0]; | |||||
| const __fp16* sptr = nullptr; | |||||
| if (PH == 0 && PW == 0) { | |||||
| sptr = reinterpret_cast<const __fp16*>( | |||||
| kern_param.src<dt_float16>(batch_id, group_id)); | |||||
| } else { | |||||
| sptr = reinterpret_cast<const __fp16*>( | |||||
| static_cast<const dt_float16*>(bundle.get(0))) + | |||||
| workspace_batch_id * group * IC * IH2 * IW2 * 8 + | |||||
| workspace_group_id * IC * IH2 * IW2 * 8; | |||||
| } | |||||
| const __fp16* filter = reinterpret_cast<const __fp16*>( | |||||
| kern_param.filter<dt_float16>(group_id, 1)) + | |||||
| channel_id * IC * FH * FW * 8 * 8; | |||||
| const __fp16* bias_ptr = reinterpret_cast<const __fp16*>( | |||||
| kern_param.bias<dt_float16>(batch_id, group_id, channel_id, 1, 8)); | |||||
| __fp16* dptr = reinterpret_cast<__fp16*>( | |||||
| kern_param.dst<dt_float16>(batch_id, group_id, channel_id, 1, 8)); | |||||
| conv_bias::conv_direct_fp16_nchw88<FH, SH, bias_mode, Op>( | |||||
| sptr, filter, bias_ptr, dptr, IC, IH2, IW2, OH, OW); | |||||
| } | |||||
| } // namespace | |||||
/* ===================== direct NCHW88 algo ===================== */
| bool ConvBiasImpl::AlgoF16DirectNCHW88::usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy) const { | |||||
| auto&& fm = param.filter_meta; | |||||
| auto fh = fm.spatial[0]; | |||||
| int oc = fm.ocpg; | |||||
| int ic = fm.icpg; | |||||
| bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float16 && | |||||
| param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
| (param.dst_type.enumv() == DTypeEnum::Float16))) && | |||||
| (fm.format == param::Convolution::Format::NCHW88); | |||||
| bool ok_src_dst = (oc % 8 == 0 && oc >= 8 && ic % 8 == 0 && ic >= 8); | |||||
| bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && | |||||
| (fh == 1 || fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| ((fm.stride[0] == 1 && fm.stride[1] == 1) || | |||||
| (fm.stride[0] == 2 && fm.stride[1] == 2)); | |||||
| bool ok_conv = !fm.should_flip; | |||||
| bool ok_comp = param.compute_mode == Param::ComputeMode::DEFAULT; | |||||
| return ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv && ok_comp; | |||||
| } | |||||
| size_t ConvBiasImpl::AlgoF16DirectNCHW88::get_workspace( | |||||
| const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88,
midout_iv("AlgoF16DirectNCHW88::get_workspace"_hash)) {
| return get_bundle(param).total_size_in_bytes(); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| return 0; | |||||
| } | |||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const { | |||||
| auto fm = param.filter_meta; | |||||
| size_t batch = param.n; | |||||
| size_t group = fm.group; | |||||
| WorkspaceBundle wbundle = get_bundle(param); | |||||
| conv_fun do_conv_fun = nullptr; | |||||
// NOTE: remain_w is not used to generate the midout hash, so the hash
// stays compatible across runtime shape changes
| #define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88, \ | |||||
| midout_iv(#filter #bias_mode #stride #op##_hash)) { \ | |||||
| do_conv_fun = do_conv_kern<filter, stride, bias_mode, op>; \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| #define GET_STRIDE_PARAM(filter, bias_mode, op) \ | |||||
| switch (fm.stride[0]) { \ | |||||
| case 1: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, op, 1); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| DO_CONV_KERN_FUN(filter, bias_mode, op, 2); \ | |||||
| break; \ | |||||
| \ | |||||
| default: \ | |||||
| megdnn_assert(0, "stride not supported"); \ | |||||
| } | |||||
| #define GET_OP_PARAM(filter, bias_mode) \ | |||||
| switch (param.nonlineMode) { \ | |||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | |||||
| GET_STRIDE_PARAM(filter, bias_mode, NoneOp<__fp16>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::RELU: \ | |||||
| GET_STRIDE_PARAM(filter, bias_mode, ReluOp<__fp16>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::H_SWISH: \ | |||||
| GET_STRIDE_PARAM(filter, bias_mode, HSwishOp<__fp16>) \ | |||||
| break; \ | |||||
| case param::ConvBias::NonlineMode::SIGMOID: \ | |||||
| GET_STRIDE_PARAM(filter, bias_mode, SigmoidOp<__fp16>) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0, "nonline not supported"); \ | |||||
| break; \ | |||||
| } | |||||
| #define GET_BIAS_MODE_PARAM(filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_OP_PARAM(filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BIAS: \ | |||||
| GET_OP_PARAM(filter, BiasMode::BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0, "bias_mode not supported"); \ | |||||
| break; \ | |||||
| } | |||||
| #define DISPATCH_CONV_KERN() \ | |||||
| switch (param.filter_meta.spatial[0]) { \ | |||||
| case 1: \ | |||||
| GET_BIAS_MODE_PARAM(1) \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| GET_BIAS_MODE_PARAM(2) \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| GET_BIAS_MODE_PARAM(3) \ | |||||
| break; \ | |||||
| case 5: \ | |||||
| GET_BIAS_MODE_PARAM(5) \ | |||||
| break; \ | |||||
| case 7: \ | |||||
| GET_BIAS_MODE_PARAM(7) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0, "filter not supported"); \ | |||||
| break; \ | |||||
| } | |||||
| DISPATCH_CONV_KERN(); | |||||
| #undef DO_CONV_KERN_FUN | |||||
#undef GET_STRIDE_PARAM
| #undef GET_OP_PARAM | |||||
| #undef GET_BIAS_MODE_PARAM | |||||
| #undef DISPATCH_CONV_KERN | |||||
| megdnn_assert(do_conv_fun); | |||||
| WorkspaceBundle bundle = get_bundle(param); | |||||
| SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||||
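// Each kernel instance handles one (group, batch) pair: pad every input
// channel block first, then compute every output channel block.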
| auto exec_one_group = [bundle, do_conv_fun]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| auto fm = kern_param.filter_meta; | |||||
| size_t IC = fm.icpg / 8; | |||||
| size_t OC = fm.ocpg / 8; | |||||
| bundle.set(kern_param.workspace_ptr); | |||||
| for (size_t ic = 0; ic < IC; ic++) { | |||||
| copy_padding_kern(bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, ic}); | |||||
| } | |||||
| for (size_t oc = 0; oc < OC; oc++) { | |||||
| do_conv_fun(bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, oc}); | |||||
| } | |||||
| }; | |||||
// TODO: currently a single kernel handles one (group, batch) pair; further multithread optimization is required
| ret_kerns.push_back({exec_one_group, {group, batch, 1_z}}); | |||||
| return ret_kerns; | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,307 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/f16/direct_nchw88_kern.h" | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
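// The helper structs below unroll the innermost loops at compile time via
// template recursion: `pc` walks the 8 packed channels and `bw` the output
// block width, with partial specializations terminating the recursion.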
| template <int PC, int BW, int pc, int bw> | |||||
| struct compute_fma { | |||||
| static inline void call(const float16x8_t* ri, const float16x8_t* rf, | |||||
| float16x8_t* rdst) { | |||||
| #if defined(__aarch64__) | |||||
| rdst[bw] = vfmaq_laneq_f16(rdst[bw], rf[pc], ri[bw], pc); | |||||
| #else | |||||
| rdst[bw] = vfmaq_f16(rdst[bw], rf[pc], | |||||
| vdupq_n_f16(vgetq_lane_f16(ri[bw], pc))); | |||||
| #endif | |||||
| compute_fma<PC, BW, pc, bw + 1>::call(ri, rf, rdst); | |||||
| } | |||||
| }; | |||||
| template <int PC, int BW, int pc> | |||||
| struct compute_fma<PC, BW, pc, BW> { | |||||
| static inline void call(const float16x8_t* ri, const float16x8_t* rf, | |||||
| float16x8_t* rdst) { | |||||
| compute_fma<PC, BW, pc + 1, 0>::call(ri, rf, rdst); | |||||
| } | |||||
| }; | |||||
| template <int PC, int BW> | |||||
| struct compute_fma<PC, BW, PC, 0> { | |||||
| static inline void call(const float16x8_t* ri, const float16x8_t* rf, | |||||
| float16x8_t* rdst) {} | |||||
| }; | |||||
| template <int PC, int BW, int bw> | |||||
| struct load_dst { | |||||
| static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) { | |||||
| rdst[bw] = vld1q_f16(dst_ptr + bw * PC); | |||||
| load_dst<PC, BW, bw + 1>::call(rdst, dst_ptr); | |||||
| } | |||||
| }; | |||||
| template <int PC, int BW> | |||||
| struct load_dst<PC, BW, BW> { | |||||
| static inline void call(float16x8_t* rdst, const float16_t* dst_ptr) {} | |||||
| }; | |||||
| template <int PC, int SW, int BW, int bw> | |||||
| struct load_src { | |||||
| static inline void call(float16x8_t* ri, const float16_t* src_ptr) { | |||||
| ri[bw] = vld1q_f16(src_ptr + bw * SW * PC); | |||||
| load_src<PC, SW, BW, bw + 1>::call(ri, src_ptr); | |||||
| } | |||||
| }; | |||||
| template <int PC, int SW, int BW> | |||||
| struct load_src<PC, SW, BW, BW> { | |||||
| static inline void call(float16x8_t* ri, const float16_t* src_ptr) {} | |||||
| }; | |||||
| template <int PC, int pc> | |||||
| struct load_filter { | |||||
| static inline void call(float16x8_t* rf, const float16_t* filter_ptr) { | |||||
| rf[pc] = vld1q_f16(filter_ptr + pc * PC); | |||||
| load_filter<PC, pc + 1>::call(rf, filter_ptr); | |||||
| } | |||||
| }; | |||||
| template <int PC> | |||||
| struct load_filter<PC, PC> { | |||||
| static inline void call(float16x8_t* rf, const float16_t* filter_ptr) {} | |||||
| }; | |||||
| template <int PC, int BW, int bw> | |||||
| struct store_dst { | |||||
| static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) { | |||||
| vst1q_f16(dst_ptr + bw * PC, rdst[bw]); | |||||
| store_dst<PC, BW, bw + 1>::call(rdst, dst_ptr); | |||||
| } | |||||
| }; | |||||
| template <int PC, int BW> | |||||
| struct store_dst<PC, BW, BW> { | |||||
| static inline void call(const float16x8_t* rdst, float16_t* dst_ptr) {} | |||||
| }; | |||||
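// Compute BW output pixels of one output-channel block against a single
// 8-channel input slice. For 1x1 filters the 8x8 filter block is loaded once
// before the loop; otherwise it is reloaded for every (fh, fw) tap.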
| template <int FH, int SH, int BW> | |||||
| static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst, | |||||
| const float16_t* filter, int IW, int OW, | |||||
| int& ow) { | |||||
| constexpr int PC = 8; | |||||
| constexpr int FW = FH; | |||||
| constexpr int SW = SH; | |||||
| float16x8_t rf[PC]; | |||||
| if (FH == 1 && FW == 1) { | |||||
| load_filter<PC, 0>::call(rf, filter); | |||||
| } | |||||
| for (; ow + BW - 1 < OW; ow += BW) { | |||||
| float16x8_t rdst[BW]; | |||||
| load_dst<PC, BW, 0>::call(rdst, dst); | |||||
| for (int fh = 0; fh < FH; ++fh) { | |||||
| for (int fw = 0; fw < FW; ++fw) { | |||||
| float16x8_t ri[BW]; | |||||
| load_src<PC, SW, BW, 0>::call(ri, src + (fh * IW + fw) * PC); | |||||
| if (FH > 1 || FW > 1) { | |||||
| load_filter<PC, 0>::call(rf, | |||||
| filter + (fh * FW + fw) * PC * PC); | |||||
| } | |||||
| compute_fma<PC, BW, 0, 0>::call(ri, rf, rdst); | |||||
| } | |||||
| } | |||||
| store_dst<PC, BW, 0>::call(rdst, dst); | |||||
| src += SW * BW * PC; | |||||
| dst += BW * PC; | |||||
| } | |||||
| } | |||||
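// Initialize dst with the bias (or zero) so the convolution kernels can
// simply accumulate into it.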
| template <BiasMode bias_mode> | |||||
| static void do_load_bias_kern(float16_t* dst, const float16_t* bias, int OH, | |||||
| int OW) { | |||||
| constexpr int PC = 8; | |||||
| if (bias_mode == BiasMode::NO_BIAS) { | |||||
| memset(dst, 0, OH * OW * PC * sizeof(float16_t)); | |||||
| } else if (bias_mode == BiasMode::BIAS) { | |||||
| memcpy(dst, bias, OH * OW * PC * sizeof(float16_t)); | |||||
| } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| float16x8_t bias_v = vld1q_f16(bias); | |||||
| int i = 0; | |||||
| for (; i + 3 < OH * OW; i += 4) { | |||||
| vst1q_f16(dst + PC * 0, bias_v); | |||||
| vst1q_f16(dst + PC * 1, bias_v); | |||||
| vst1q_f16(dst + PC * 2, bias_v); | |||||
| vst1q_f16(dst + PC * 3, bias_v); | |||||
| dst += PC * 4; | |||||
| } | |||||
| for (; i < OH * OW; i += 1) { | |||||
| vst1q_f16(dst, bias_v); | |||||
| dst += PC; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename Op> | |||||
| static void do_op_kern(float16_t* dst, int OH, int OW) { | |||||
| constexpr int PC = 8; | |||||
| Op op; | |||||
| int i = 0; | |||||
| for (; i + 3 < OH * OW; i += 4) { | |||||
| float16x8_t dst0 = vld1q_f16(dst + PC * 0); | |||||
| float16x8_t dst1 = vld1q_f16(dst + PC * 1); | |||||
| float16x8_t dst2 = vld1q_f16(dst + PC * 2); | |||||
| float16x8_t dst3 = vld1q_f16(dst + PC * 3); | |||||
| dst0 = op(dst0); | |||||
| dst1 = op(dst1); | |||||
| dst2 = op(dst2); | |||||
| dst3 = op(dst3); | |||||
| vst1q_f16(dst + PC * 0, dst0); | |||||
| vst1q_f16(dst + PC * 1, dst1); | |||||
| vst1q_f16(dst + PC * 2, dst2); | |||||
| vst1q_f16(dst + PC * 3, dst3); | |||||
| dst += PC * 4; | |||||
| } | |||||
| for (; i < OH * OW; i += 1) { | |||||
| vst1q_f16(dst, op(vld1q_f16(dst))); | |||||
| dst += PC; | |||||
| } | |||||
| } | |||||
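// Accumulate over all input-channel blocks; each iteration advances src by
// one 8-channel plane and filter by one FH*FW*8*8 block.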
| template <int FH, int SH> | |||||
| static void do_conv_kern(const float16_t* src, float16_t* dst, | |||||
| const float16_t* filter, int IC, int IH, int IW, | |||||
| int OH, int OW) { | |||||
| constexpr int PC = 8; | |||||
| constexpr int FW = FH; | |||||
| for (int ic = 0; ic < IC; ic += 1) { | |||||
| const float16_t* src_ptr_h = src; | |||||
| float16_t* dst_ptr_h = dst; | |||||
| for (int oh = 0; oh < OH; oh += 1) { | |||||
| const float16_t* src_ptr_w = src_ptr_h; | |||||
| float16_t* dst_ptr_w = dst_ptr_h; | |||||
| int ow = 0; | |||||
| do_conv_kern_1xBW<FH, SH, 4>(src_ptr_w, dst_ptr_w, filter, IW, OW, | |||||
| ow); | |||||
| if (OW & 3) { | |||||
| do_conv_kern_1xBW<FH, SH, 2>(src_ptr_w, dst_ptr_w, filter, IW, | |||||
| OW, ow); | |||||
| do_conv_kern_1xBW<FH, SH, 1>(src_ptr_w, dst_ptr_w, filter, IW, | |||||
| OW, ow); | |||||
| } | |||||
| src_ptr_h += SH * IW * PC; | |||||
| dst_ptr_h += OW * PC; | |||||
| } | |||||
| src += IH * IW * PC; | |||||
| filter += FH * FW * PC * PC; | |||||
| } | |||||
| } | |||||
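// 1x1 stride-1 case: H and W are fused into a single dimension so wider
// blocks (8/4/1 pixels) can be used over OH*OW.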
| static void do_conv_kern_1x1(const float16_t* src, float16_t* dst, | |||||
| const float16_t* filter, int IC, int OH, int OW) { | |||||
| constexpr int PC = 8; | |||||
| const int IH = OH; | |||||
| const int IW = OW; | |||||
| const int IHW = IH * IW; | |||||
| const int OHW = OH * OW; | |||||
| for (int ic = 0; ic < IC; ic += 1) { | |||||
| const float16_t* src_ptr_hw = src; | |||||
| float16_t* dst_ptr_hw = dst; | |||||
| int ohw = 0; | |||||
| do_conv_kern_1xBW<1, 1, 8>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, | |||||
| ohw); | |||||
| do_conv_kern_1xBW<1, 1, 4>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, | |||||
| ohw); | |||||
| do_conv_kern_1xBW<1, 1, 1>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, | |||||
| ohw); | |||||
| src += IHW * PC; | |||||
| filter += PC * PC; | |||||
| } | |||||
| } | |||||
| template <size_t FH, size_t SH, BiasMode bias_mode, typename Op> | |||||
| void conv_bias::conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, | |||||
| const __fp16* bias, __fp16* dst, int IC, | |||||
| int IH, int IW, int OH, int OW) { | |||||
| do_load_bias_kern<bias_mode>(dst, bias, OH, OW); | |||||
| if (FH == 1 && SH == 1 && IH == OH && IW == OW) { | |||||
| do_conv_kern_1x1(src, dst, filter, IC, OH, OW); | |||||
| } else { | |||||
| do_conv_kern<FH, SH>(src, dst, filter, IC, IH, IW, OH, OW); | |||||
| } | |||||
| do_op_kern<Op>(dst, OH, OW); | |||||
| } | |||||
| #define INSTANTIATION(stride, filter, bias, Op) \ | |||||
| template void \ | |||||
| conv_bias::conv_direct_fp16_nchw88<filter, stride, bias, Op>( \ | |||||
| const __fp16*, const __fp16*, const __fp16*, __fp16*, int, int, \ | |||||
| int, int, int); | |||||
| #define FOR_OP(stride, filter, bias) \ | |||||
| INSTANTIATION(stride, filter, bias, SigmoidOp<__fp16>) \ | |||||
| INSTANTIATION(stride, filter, bias, ReluOp<__fp16>) \ | |||||
| INSTANTIATION(stride, filter, bias, HSwishOp<__fp16>) \ | |||||
| INSTANTIATION(stride, filter, bias, NoneOp<__fp16>) | |||||
| #define FOR_BIAS(stride, filter) \ | |||||
| FOR_OP(stride, filter, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(stride, filter, BiasMode::BIAS) | |||||
| #define FOR_FILTER(stride) \ | |||||
| FOR_BIAS(stride, 1) \ | |||||
| FOR_BIAS(stride, 2) \ | |||||
| FOR_BIAS(stride, 3) \ | |||||
| FOR_BIAS(stride, 5) \ | |||||
| FOR_BIAS(stride, 7) | |||||
| #define FOR_STRIDE \ | |||||
| FOR_FILTER(1) \ | |||||
| FOR_FILTER(2) | |||||
| FOR_STRIDE | |||||
| #undef FOR_STRIDE | |||||
| #undef FOR_FILTER | |||||
| #undef FOR_BIAS | |||||
| #undef FOR_OP | |||||
| #undef INSTANTIATION | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
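//! Direct fp16 NCHW88 convolution of one 8-channel output block: src points
//! to the (already padded) input of the whole group, filter to this block's
//! IC*FH*FW*8*8 weights, and dst to the block's OH*OW*8 output.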
| template <size_t FH, size_t SH, BiasMode bias_mode, typename Op> | |||||
| void conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, | |||||
| const __fp16* bias, __fp16* dst, int IC, int IH, | |||||
| int IW, int OH, int OW); | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| @@ -86,6 +86,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoF16Direct f16_direct; | AlgoF16Direct f16_direct; | ||||
| AlgoF16DirectStride1 f16_direct_stride1; | AlgoF16DirectStride1 f16_direct_stride1; | ||||
| AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88; | AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88; | ||||
| AlgoF16DirectNCHW88 f16_direct_nchw88; | |||||
| #endif | #endif | ||||
| SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
| @@ -121,6 +122,7 @@ public: | |||||
| m_direct_algos.emplace_back(&f16_direct_stride1); | m_direct_algos.emplace_back(&f16_direct_stride1); | ||||
| m_direct_algos.emplace_back(&f16_direct); | m_direct_algos.emplace_back(&f16_direct); | ||||
| m_direct_algos.emplace_back(&f16_channel_wise_nchw88); | m_direct_algos.emplace_back(&f16_channel_wise_nchw88); | ||||
| m_direct_algos.emplace_back(&f16_direct_nchw88); | |||||
| #endif | #endif | ||||
| m_direct_algos.emplace_back(&i8x8x16_direct); | m_direct_algos.emplace_back(&i8x8x16_direct); | ||||
| m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); | m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); | ||||
| @@ -252,7 +254,6 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| for (auto&& algo : m_direct_algos) { | for (auto&& algo : m_direct_algos) { | ||||
| m_all_algos_map.emplace(algo->info().desc, algo); | m_all_algos_map.emplace(algo->info().desc, algo); | ||||
| } | } | ||||
| @@ -261,8 +262,7 @@ public: | |||||
| } | } | ||||
| } | } | ||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() | |||||
| const { | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||||
| return m_direct_algos; | return m_direct_algos; | ||||
| } | } | ||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() | const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& winograd_algos() | ||||
| @@ -10,9 +10,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/common/algo_base.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
| #include "src/common/algo_base.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| @@ -28,7 +28,8 @@ public: | |||||
| } | } | ||||
| }; | }; | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() override; | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> get_all_packed_algo() | |||||
| override; | |||||
| bool is_matmul_quantized_prefer( | bool is_matmul_quantized_prefer( | ||||
| const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) | ||||
| @@ -97,6 +98,7 @@ private: | |||||
| class AlgoF16Direct; | class AlgoF16Direct; | ||||
| class AlgoF16DirectStride1; | class AlgoF16DirectStride1; | ||||
| class AlgoF16ChannelWiseNCHW88; | class AlgoF16ChannelWiseNCHW88; | ||||
| class AlgoF16DirectNCHW88; | |||||
| #endif | #endif | ||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -56,8 +56,7 @@ public: | |||||
| bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
| void exec_preprocess(const TensorLayout& src_layout, | void exec_preprocess(const TensorLayout& src_layout, | ||||
| _megdnn_tensor_in filter, | |||||
| _megdnn_tensor_in bias, | |||||
| _megdnn_tensor_in filter, _megdnn_tensor_in bias, | |||||
| const TensorLayout& z_layout, | const TensorLayout& z_layout, | ||||
| const TensorLayout& dst_layout, | const TensorLayout& dst_layout, | ||||
| PreprocessedFilter* preprocessed_filter, | PreprocessedFilter* preprocessed_filter, | ||||
| @@ -243,6 +242,7 @@ public: | |||||
| ARM_COMMON_DIRECT_FP16, | ARM_COMMON_DIRECT_FP16, | ||||
| ARM_COMMON_DIRECT_STRD1_FP16, | ARM_COMMON_DIRECT_STRD1_FP16, | ||||
| ARM_COMMON_CHWNWISE_NCHW88_F16, | ARM_COMMON_CHWNWISE_NCHW88_F16, | ||||
| ARM_COMMON_DIRECT_NCHW88_FP16, | |||||
| ARM_COMMON_WINOGRAD_F23_4X4_FP32, | ARM_COMMON_WINOGRAD_F23_4X4_FP32, | ||||
| ARM_COMMON_WINOGRAD_F63_FP32, | ARM_COMMON_WINOGRAD_F63_FP32, | ||||
| ARM_COMMON_WINOGRAD_F63_4X4_FP32, | ARM_COMMON_WINOGRAD_F63_4X4_FP32, | ||||
| @@ -288,7 +288,7 @@ public: | |||||
| #else | #else | ||||
| ARMV7_MATMUL_S8, | ARMV7_MATMUL_S8, | ||||
| ARMV7_MATMUL_QU8, | ARMV7_MATMUL_QU8, | ||||
| #endif // MEGDNN_AARCH64 | |||||
| #endif // MEGDNN_AARCH64 | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -124,8 +124,8 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||||
| for (size_t n : {1, 2}) { | for (size_t n : {1, 2}) { | ||||
| for (auto nlmode : nonlinemode) { | for (auto nlmode : nonlinemode) { | ||||
| for (bool pad : {true}) { | for (bool pad : {true}) { | ||||
| for (size_t group : {1, 2, 4, 7, 128}) { | |||||
| for (size_t size : {4, 6, 7, 9, 15, 40}) { | |||||
| for (size_t group : {1, 2, 4, 7, 16}) { | |||||
| for (size_t size : {4, 6, 7, 9, 20}) { | |||||
| for (size_t kern : kernel) { | for (size_t kern : kernel) { | ||||
| pack(n, group, size, size, kern, stride, nlmode, | pack(n, group, size, size, kern, stride, nlmode, | ||||
| pad); | pad); | ||||
| @@ -134,8 +134,8 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||||
| } | } | ||||
| } | } | ||||
| for (bool pad : {false}) { | for (bool pad : {false}) { | ||||
| for (size_t group : {1, 2, 7, 128}) { | |||||
| for (size_t size : {7, 9, 15, 40}) { | |||||
| for (size_t group : {1, 2, 7, 16}) { | |||||
| for (size_t size : {7, 9, 20}) { | |||||
| for (size_t kern : kernel) { | for (size_t kern : kernel) { | ||||
| pack(n, group, size, size, kern, stride, nlmode, | pack(n, group, size, size, kern, stride, nlmode, | ||||
| pad); | pad); | ||||
| @@ -199,8 +199,8 @@ std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args( | |||||
| for (size_t n : {1, 2}) { | for (size_t n : {1, 2}) { | ||||
| for (auto nlmode : nonlinemode) { | for (auto nlmode : nonlinemode) { | ||||
| for (bool pad : {true}) { | for (bool pad : {true}) { | ||||
| for (size_t group : {1, 2, 4, 7, 8, 128}) { | |||||
| for (size_t size : {4, 6, 7, 9, 15, 40}) { | |||||
| for (size_t group : {1, 2, 4, 7, 8, 16}) { | |||||
| for (size_t size : {4, 6, 7, 9, 20}) { | |||||
| for (size_t kern : kernel) { | for (size_t kern : kernel) { | ||||
| pack(n, group, size, size, kern, stride, nlmode, | pack(n, group, size, size, kern, stride, nlmode, | ||||
| pad); | pad); | ||||
| @@ -209,8 +209,8 @@ std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args( | |||||
| } | } | ||||
| } | } | ||||
| for (bool pad : {false}) { | for (bool pad : {false}) { | ||||
| for (size_t group : {1, 2, 7, 128}) { | |||||
| for (size_t size : {7, 9, 15, 40}) { | |||||
| for (size_t group : {1, 2, 7, 16}) { | |||||
| for (size_t size : {7, 9, 20}) { | |||||
| for (size_t kern : kernel) { | for (size_t kern : kernel) { | ||||
| pack(n, group, size, size, kern, stride, nlmode, | pack(n, group, size, size, kern, stride, nlmode, | ||||
| pad); | pad); | ||||
| @@ -412,6 +412,23 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) { | |||||
| get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), | get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), | ||||
| handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03); | handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S1) { | |||||
| NormalRNG rng(1); | |||||
| checker_conv_bias_f16( | |||||
| get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, | |||||
| ALL_BIASMODE, 1), | |||||
| handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_NCHW88_S2) { | |||||
| NormalRNG rng(1); | |||||
| checker_conv_bias_f16( | |||||
| get_nchw88_conv_bias_args({1, 2, 3, 5, 7}, FULL_NLMODE, | |||||
| ALL_BIASMODE, 2), | |||||
| handle(), rng, "F16_CONV_NCHW88_DIRECT", 0.03); | |||||
| } | |||||
| #endif | #endif | ||||
| /**********************************algo 8816 direct************************/ | /**********************************algo 8816 direct************************/ | ||||
| @@ -794,8 +811,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { | |||||
| check_winograd("1:6:32", checker, args); | check_winograd("1:6:32", checker, args); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<TestArg> args = get_winograd_mk_packed_args(); | std::vector<TestArg> args = get_winograd_mk_packed_args(); | ||||
| @@ -804,19 +819,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { | |||||
| check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<TestArg> args = | std::vector<TestArg> args = | ||||
| get_nchw44_conv_bias_args({3},QUAN_NLMODE,BR_AND_NO_BIASMODE,1); | |||||
| get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
| Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
| check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4, | check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4, | ||||
| param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
| } | } | ||||
| //! uncomment it when low precision mode is ok | //! uncomment it when low precision mode is ok | ||||
| #if 0 | #if 0 | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { | ||||
| @@ -847,8 +858,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { | |||||
| check_winograd("1:5:32", checker, args); | check_winograd("1:5:32", checker, args); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| std::vector<TestArg> args = get_winograd_args(5); | std::vector<TestArg> args = get_winograd_args(5); | ||||
| @@ -971,18 +980,17 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| Checker<ConvBiasForward> checker(handle()); | Checker<ConvBiasForward> checker(handle()); | ||||
| auto run = [&checker](const std::vector<TestArg>& args, | |||||
| DType A_dtype, | |||||
| auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype, | |||||
| DType B_dtype, DType C_dtype, DType D_dtype, | DType B_dtype, DType C_dtype, DType D_dtype, | ||||
| float eps) { | float eps) { | ||||
| for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| checker.set_dtype(0, A_dtype) | |||||
| .set_dtype(1, B_dtype) | |||||
| .set_dtype(2, C_dtype) | |||||
| .set_dtype(4, D_dtype) | |||||
| .set_epsilon(eps) | |||||
| .set_param(arg.param) | |||||
| .execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -997,9 +1005,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) { | |||||
| std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4); | std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4); | ||||
| UniformIntRNG int_rng{-50, 50}; | UniformIntRNG int_rng{-50, 50}; | ||||
| checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng); | ||||
| run(quantized_args, dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), | |||||
| dtype::QuantizedS8(60.25f),1e-3); | |||||
| run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
| @@ -400,7 +400,8 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) { | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CHANNEL_WISE_FP16_NCHW88) { | |||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| std::string algo_name = "F16_CHANNEL_WISE_NCHW88"; | std::string algo_name = "F16_CHANNEL_WISE_NCHW88"; | ||||
| @@ -462,6 +463,64 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_F16_NCHW88) { | |||||
| bench_case(1, 64, 28, 28, 2, 0, 2); | bench_case(1, 64, 28, 28, 2, 0, 2); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FP16_NCHW88) { | |||||
| constexpr size_t RUNS = 40; | |||||
| std::vector<DType> data_type = {dtype::Float16(), dtype::Float16(), | |||||
| dtype::Float16(), dtype::Float16()}; | |||||
| auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | |||||
| size_t FS, size_t group, size_t P, size_t S) { | |||||
| param::ConvBias param; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| param.pad_h = P; | |||||
| param.pad_w = P; | |||||
| param.stride_h = S; | |||||
| param.stride_w = S; | |||||
| param.sparse = param::ConvBias::Sparse::DENSE; | |||||
| param.format = param::ConvBias::Format::NCHW88; | |||||
| auto OH = (H + 2 * P - FS) / static_cast<size_t>(S) + 1; | |||||
| auto OW = (W + 2 * P - FS) / static_cast<size_t>(S) + 1; | |||||
| TensorShape src = {N, IC / 8, H, W, 8}; | |||||
| TensorShape filter = {OC / 8, IC / 8, FS, FS, 8, 8}; | |||||
| if (group > 1) { | |||||
| filter = {group, OC / group / 8, IC / group / 8, FS, FS, 8, 8}; | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| } | |||||
| TensorShape bias = {1, OC / 8, 1, 1, 8}; | |||||
| TensorShape dst = {N, OC / 8, OH, OW, 8}; | |||||
| SmallVector<TensorShape> shapes{src, filter, bias, {}, dst}; | |||||
| float computations = | |||||
| (((IC / group) * FS * FS + 1) * dst.total_nr_elems() * 2 + | |||||
| dst.total_nr_elems()) * | |||||
| 1e-6; | |||||
| std::vector<std::pair<SmallVector<TensorShape>, float>> shape_arg = { | |||||
| std::make_pair(shapes, computations)}; | |||||
| benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | |||||
| {1, {7}}, data_type); | |||||
| }; | |||||
| bench_case(1, 64, 64, 28, 28, 3, 1, 1, 1); | |||||
| bench_case(1, 64, 64, 28, 28, 5, 1, 2, 1); | |||||
| bench_case(1, 64, 64, 28, 28, 7, 1, 3, 1); | |||||
| bench_case(1, 64, 64, 28, 28, 3, 1, 1, 2); | |||||
| bench_case(1, 64, 64, 28, 28, 5, 1, 2, 2); | |||||
| bench_case(1, 64, 64, 28, 28, 7, 1, 3, 2); | |||||
| bench_case(1, 64, 64, 28, 28, 3, 2, 1, 1); | |||||
| bench_case(1, 64, 64, 28, 28, 3, 4, 1, 1); | |||||
| bench_case(1, 64, 64, 28, 28, 3, 8, 1, 1); | |||||
| bench_case(1, 16, 16, 28, 28, 3, 1, 1, 1); | |||||
| bench_case(1, 32, 32, 28, 28, 3, 1, 1, 1); | |||||
| bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | |||||
| bench_case(1, 256, 256, 28, 28, 3, 1, 1, 1); | |||||
| bench_case(1, 64, 64, 7, 7, 3, 1, 1, 1); | |||||
| bench_case(1, 64, 64, 14, 14, 3, 1, 1, 1); | |||||
| bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | |||||
| bench_case(1, 64, 64, 112, 112, 3, 1, 1, 1); | |||||
| } | |||||
| #endif | #endif | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) { | BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) { | ||||
| @@ -769,10 +828,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) { | |||||
| bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1); | bench_case(1, 128, 128, 28, 28, 3, 4, 1, 1); | ||||
| bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1); | bench_case(1, 256, 256, 14, 14, 3, 4, 1, 1); | ||||
| bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1); | bench_case(1, 512, 512, 7, 7, 3, 4, 1, 1); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2) { | |||||
| constexpr size_t RUNS = 40; | constexpr size_t RUNS = 40; | ||||
| std::vector<DType> data_type = { | std::vector<DType> data_type = { | ||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | ||||
| @@ -825,16 +884,13 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT_S2 | |||||
| bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); | bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); | ||||
| bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); | bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); | ||||
| bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); | bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); | ||||
| } | } | ||||
| #endif | #endif | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { | ||||
| constexpr size_t RUNS = 40; | constexpr size_t RUNS = 40; | ||||
| std::vector<DType> data_type = { | |||||
| dtype::Float32(), dtype::Float32(), | |||||
| dtype::Float32(), dtype::Float32()}; | |||||
| std::vector<DType> data_type = {dtype::Float32(), dtype::Float32(), | |||||
| dtype::Float32(), dtype::Float32()}; | |||||
| auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | ||||
| size_t FS, size_t group, size_t P, size_t S, | size_t FS, size_t group, size_t P, size_t S, | ||||
| bool is_nchw = false) { | bool is_nchw = false) { | ||||
| @@ -880,15 +936,12 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_FLOAT_NCHW44) { | |||||
| bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); | bench_case(1, 128, 128, 28, 28, 3, 4, 1, 2); | ||||
| bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); | bench_case(1, 256, 256, 14, 14, 3, 4, 1, 2); | ||||
| bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); | bench_case(1, 512, 512, 7, 7, 3, 4, 1, 2); | ||||
| bench_case(1, 64, 64, 56*2, 56*2, 3, 4, 1, 2); | |||||
| bench_case(1, 128, 128, 28*2, 28*2, 3, 4, 1, 2); | |||||
| bench_case(1, 256, 256, 14*2, 14*2, 3, 4, 1, 2); | |||||
| bench_case(1, 512, 512, 7*2, 7*2, 3, 4, 1, 2); | |||||
| } | |||||
| bench_case(1, 64, 64, 56 * 2, 56 * 2, 3, 4, 1, 2); | |||||
| bench_case(1, 128, 128, 28 * 2, 28 * 2, 3, 4, 1, 2); | |||||
| bench_case(1, 256, 256, 14 * 2, 14 * 2, 3, 4, 1, 2); | |||||
| bench_case(1, 512, 512, 7 * 2, 7 * 2, 3, 4, 1, 2); | |||||
| } | |||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { | BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE2) { | ||||
| @@ -1473,9 +1526,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_WINOGRAD_INT8) { | |||||
| algo_name = "WINOGRAD:ARMV7_INT16X16X32_MK8_4X8:8:2:32"; | algo_name = "WINOGRAD:ARMV7_INT16X16X32_MK8_4X8:8:2:32"; | ||||
| #endif | #endif | ||||
| std::vector<DType> data_type = {dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f) ,dtype::QuantizedS8(60.25f) }; | |||||
| std::vector<DType> data_type = { | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f)}; | |||||
| printf("Benchmark WINOGRAD_IN8_MK8 algo\n"); | printf("Benchmark WINOGRAD_IN8_MK8 algo\n"); | ||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
| @@ -1839,7 +1892,6 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
| BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) { | BENCHMARK_IM2COL_NCHW44_INT8x8x32_STRIDE1) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| @@ -1852,18 +1904,17 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| param.stride_w = 1; | param.stride_w = 1; | ||||
| param.sparse = param::ConvBias::Sparse::DENSE; | param.sparse = param::ConvBias::Sparse::DENSE; | ||||
| param.format = param::ConvBias::Format::NCHW44; | param.format = param::ConvBias::Format::NCHW44; | ||||
| std::vector<std::pair<SmallVector<TensorShape>, float>> | std::vector<std::pair<SmallVector<TensorShape>, float>> | ||||
| shapes_and_computation; | shapes_and_computation; | ||||
| auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, | ||||
| size_t FS, size_t group=1) { | |||||
| SmallVector<TensorShape> shapes{{N, IC, H, W,4}, | |||||
| {OC, IC / group, FS, FS,4,4}, | |||||
| size_t FS, size_t group = 1) { | |||||
| SmallVector<TensorShape> shapes{{N, IC, H, W, 4}, | |||||
| {OC, IC / group, FS, FS, 4, 4}, | |||||
| {/*1, OC, 1, 1*/}, | {/*1, OC, 1, 1*/}, | ||||
| {}, | {}, | ||||
| {N, OC, H, W,4}}; | |||||
| TensorShape dst{N, OC, H, W,4}; | |||||
| {N, OC, H, W, 4}}; | |||||
| TensorShape dst{N, OC, H, W, 4}; | |||||
| float computations = | float computations = | ||||
| ((4 * IC / group) * FS * FS * dst.total_nr_elems() * 2 + | ((4 * IC / group) * FS * FS * dst.total_nr_elems() * 2 + | ||||
| dst.total_nr_elems()) * | dst.total_nr_elems()) * | ||||
| @@ -1907,9 +1958,10 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| #endif | #endif | ||||
| std::string algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"; | std::string algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96"; | ||||
| printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96 algo\n"); | printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96 algo\n"); | ||||
| std::vector<DType> data_type = { | |||||
| dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f), {}}; | |||||
| std::vector<DType> data_type = {dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS8(2.5f), | |||||
| dtype::QuantizedS32(6.25f), | |||||
| {}}; | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| @@ -1917,10 +1969,9 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192"; | algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192"; | ||||
| printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 algo\n"); | |||||
| printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:192 " | |||||
| "algo\n"); | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| @@ -1929,14 +1980,14 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384"; | algo_name = "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384"; | ||||
| printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 algo\n"); | |||||
| printf("Benchmarker IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:384 " | |||||
| "algo\n"); | |||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | {4, {4, 5, 6, 7}}, {1, {4}}, data_type); | ||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, | ||||
| {4, {4, 5, 6, 7}}, {1, {7}}, data_type); | {4, {4, 5, 6, 7}}, {1, {7}}, data_type); | ||||
| benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, | ||||
| {1, {4}}, data_type); | {1, {4}}, data_type); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -1185,9 +1185,10 @@ void check_conv_bias_preprocess(std::vector<conv_bias::TestArg> args, | |||||
| } | } | ||||
| } | } | ||||
| void checker_conv_bias_common(std::vector<conv_bias::TestArg> args, Handle* handle, | |||||
| RNG* rng, float epsilon, DType type0, DType type1, | |||||
| DType type2, DType type3, const char* algo_name) { | |||||
| void checker_conv_bias_common(std::vector<conv_bias::TestArg> args, | |||||
| Handle* handle, RNG* rng, float epsilon, | |||||
| DType type0, DType type1, DType type2, | |||||
| DType type3, const char* algo_name) { | |||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| Checker<ConvBias> checker(handle); | Checker<ConvBias> checker(handle); | ||||
| @@ -1377,6 +1378,88 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
| } | } | ||||
| return args; | return args; | ||||
| } | } | ||||
| std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | |||||
| std::vector<size_t> kernel_vec, | |||||
| std::vector<param::ConvBias::NonlineMode> nlmode_vec, | |||||
| std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) { | |||||
| using namespace conv_bias; | |||||
| using NLMode = param::ConvBias::NonlineMode; | |||||
| std::vector<TestArg> args; | |||||
| auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, | |||||
| size_t kernel, size_t stride, size_t group, NLMode nlmode, | |||||
| megdnn::BiasMode bias_mode) { | |||||
| constexpr int pack_c = 8; | |||||
| const size_t pad = kernel / 2; | |||||
| auto oc_per_group = oc / group; | |||||
| auto ic_per_group = ic / group; | |||||
| megdnn_assert(oc_per_group % pack_c == 0 && ic_per_group % pack_c == 0, | |||||
| "ocpg/icpg not divided by 8"); | |||||
| size_t kernel_h = kernel; | |||||
| size_t kernel_w = kernel; | |||||
| param::ConvBias param; | |||||
| param.format = param::ConvBias::Format::NCHW88; | |||||
| param.stride_h = stride; | |||||
| param.stride_w = stride; | |||||
| param.pad_h = pad; | |||||
| param.pad_w = pad; | |||||
| param.nonlineMode = nlmode; | |||||
| auto src_tensor_shape = TensorShape{n, ic / pack_c, h, w, pack_c}; | |||||
| auto weight_tensor_shape = TensorShape{ | |||||
| oc / pack_c, ic / pack_c, kernel_h, kernel_w, pack_c, pack_c}; | |||||
| auto bias_tensor_shape = TensorShape{}; | |||||
| if (bias_mode == megdnn::BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| bias_tensor_shape = {1, oc / pack_c, 1, 1, pack_c}; | |||||
| } else if (bias_mode == megdnn::BiasMode::BIAS) { | |||||
| bias_tensor_shape = {n, oc / pack_c, | |||||
| (h + 2 * pad - kernel) / stride + 1, | |||||
| (w + 2 * pad - kernel) / stride + 1, pack_c}; | |||||
| } | |||||
| if (group == 1) { | |||||
| param.sparse = param::ConvBias::Sparse::DENSE; | |||||
| } else { | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| weight_tensor_shape = TensorShape{group, | |||||
| oc_per_group / pack_c, | |||||
| ic_per_group / pack_c, | |||||
| kernel_h, | |||||
| kernel_w, | |||||
| pack_c, | |||||
| pack_c}; | |||||
| } | |||||
| args.emplace_back(param, src_tensor_shape, weight_tensor_shape, | |||||
| bias_tensor_shape); | |||||
| }; | |||||
| for (auto bias : biasmode_vec) | |||||
| for (auto nlmode : nlmode_vec) | |||||
| for (size_t n : {1, 2}) | |||||
| for (size_t kernel : kernel_vec) | |||||
| for (size_t oc : {8, 16}) | |||||
| for (size_t ic : {8, 16, 24}) | |||||
| for (size_t h : {1, 3, 12}) | |||||
| for (size_t w : {1, 8, 13}) { | |||||
| for (size_t group = 1; group < oc / 8; | |||||
| ++group) { | |||||
| if (ic % (group * 8) || | |||||
| oc % (group * 8)) { | |||||
| continue; | |||||
| } | |||||
| if (kernel < h || kernel < w) { | |||||
| continue; | |||||
| } | |||||
| pack(n, oc, ic, h, w, kernel, stride, | |||||
| group, nlmode, bias); | |||||
| } | |||||
| } | |||||
| return args; | |||||
| } | |||||
| } // namespace conv_bias | } // namespace conv_bias | ||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||