GitOrigin-RevId: b06fba83eb
tags/v1.0.0-rc1
| @@ -0,0 +1,173 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
| template <> | |||||
| void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin, | |||||
| const int, const int pw, const int pad_right, | |||||
| const int ih, const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, | |||||
| const int ic, const int ic_stride) { | |||||
| constexpr int ic_step = 4; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const float* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| memset(sptr_base, 0, sizeof(float) * pw * ic_step); | |||||
| sptr_base += pw * ic_step; | |||||
| memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step); | |||||
| sptr_base += iw * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| memset(sptr_base, 0, sizeof(float) * pad_right * ic_step); | |||||
| sptr_base += pad_right * ic_step; | |||||
| } | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| namespace { | |||||
| static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, | |||||
| const int odd_start, | |||||
| const int src_idx, | |||||
| const int iw_idx) { | |||||
| constexpr int ic_step = 4; | |||||
| const int src_offset = src_idx * ic_step; | |||||
| const int even_offset = iw_idx / 2 * ic_step; | |||||
| const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||||
| float32x4_t temp[8]; | |||||
| temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
| temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
| temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
| temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
| temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
| temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
| temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
| temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
| vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||||
| vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||||
| vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||||
| vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||||
| vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||||
| vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||||
| vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||||
| vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||||
| } | |||||
| static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, | |||||
| const int odd_start, | |||||
| const int src_idx, const int iw_idx) { | |||||
| constexpr int ic_step = 4; | |||||
| const int src_offset = src_idx * ic_step; | |||||
| const int even_offset = (iw_idx + 1) / 2 * ic_step; | |||||
| const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||||
| float32x4_t temp[8]; | |||||
| temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
| temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
| temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
| temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
| temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
| temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
| temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
| temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
| vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||||
| vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||||
| vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||||
| vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||||
| vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||||
| vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||||
| vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||||
| vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||||
| } | |||||
| } // namespace | |||||
| template <> | |||||
| void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin, | |||||
| const int ph, const int pw, const int pad_right, | |||||
| const int ih, const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, | |||||
| const int ic, const int ic_stride) { | |||||
| constexpr int ic_step = 4; | |||||
| int odd_start = megdnn::div_ceil(iw2, 2); | |||||
| float32x4_t zero_v = vdupq_n_f32(0.f); | |||||
| MEGDNN_MARK_USED_VAR(ph); | |||||
| bool even_start = pw % 2 == 0; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const float* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| int iw_idx = 0; | |||||
| rep(idx, pw) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| zero_v); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| int src_idx = 0; | |||||
| if (even_start) { | |||||
| for (; src_idx + 7 < iw; src_idx += 8) { | |||||
| odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, | |||||
| iw_idx); | |||||
| iw_idx += 8; | |||||
| } | |||||
| } else { | |||||
| for (; src_idx + 7 < iw; src_idx += 8) { | |||||
| odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, | |||||
| iw_idx); | |||||
| iw_idx += 8; | |||||
| } | |||||
| } | |||||
| for (; src_idx < iw; ++src_idx) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, | |||||
| vld1q_f32(sptr + src_idx * ic_step)); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| vld1q_f32(sptr + src_idx * ic_step)); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| rep(idx, pad_right) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| zero_v); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| sptr_base += iw2 * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| } | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
| INSTANTIATION_CONV_S1(2); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_2x2s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
| INSTANTIATION_CONV_S2(2); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
| INSTANTIATION_CONV_S1(3); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_3x3s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
| INSTANTIATION_CONV_S2(3); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
| INSTANTIATION_CONV_S1(5); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_5x5s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
| INSTANTIATION_CONV_S2(5); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h" | |||||
| INSTANTIATION_CONV_S1(7); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_7x7s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h" | |||||
| INSTANTIATION_CONV_S2(7); | |||||
| @@ -1,6 +1,6 @@ | |||||
| /** | /** | ||||
| * \file | * \file | ||||
| * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.cpp | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| @@ -12,7 +12,7 @@ | |||||
| */ | */ | ||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -24,21 +24,21 @@ using namespace megdnn; | |||||
| using namespace arm_common; | using namespace arm_common; | ||||
| namespace { | namespace { | ||||
| template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, | |||||
| typename T, typename T2, typename T3, typename T4> | |||||
| template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8]); \ | |||||
| c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 8]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8], lane); \ | |||||
| c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 8], lane); | |||||
| UNROLL_CALL_RAW(8, cb, 0); | UNROLL_CALL_RAW(8, cb, 0); | ||||
| UNROLL_CALL_RAW(8, cb, 1); | UNROLL_CALL_RAW(8, cb, 1); | ||||
| @@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4]); \ | |||||
| c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 4]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4], lane); \ | |||||
| c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 4], lane); | |||||
| UNROLL_CALL_RAW(4, cb, 0); | UNROLL_CALL_RAW(4, cb, 0); | ||||
| UNROLL_CALL_RAW(4, cb, 1); | UNROLL_CALL_RAW(4, cb, 1); | ||||
| @@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8], lane); | |||||
| UNROLL_CALL_RAW(8, cb, 0); | UNROLL_CALL_RAW(8, cb, 0); | ||||
| UNROLL_CALL_RAW(8, cb, 1); | UNROLL_CALL_RAW(8, cb, 1); | ||||
| @@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4], lane); | |||||
| UNROLL_CALL_RAW(4, cb, 0); | UNROLL_CALL_RAW(4, cb, 0); | ||||
| UNROLL_CALL_RAW(4, cb, 1); | UNROLL_CALL_RAW(4, cb, 1); | ||||
| @@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, | |||||
| typename T, typename T2, typename T3> | |||||
| template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
| typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | ||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3, | |||||
| int>::impl(c, src, weight); | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl( | |||||
| c, src, weight); | |||||
| }; | }; | ||||
| template <int oc> | template <int oc> | ||||
| struct OCHelper { | struct OCHelper { | ||||
| @@ -162,13 +162,11 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
| load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| } | } | ||||
| @@ -209,18 +207,15 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
| load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
| load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| } | } | ||||
| @@ -260,32 +255,27 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
| load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
| load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | ||||
| load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | ||||
| load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<4, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| } | } | ||||
| @@ -326,44 +316,37 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 0, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); | ||||
| load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); | ||||
| load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); | ||||
| load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); | ||||
| load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<4, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<4, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); | ||||
| load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<5, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<5, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); | ||||
| load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<ic_step, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<6, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<6, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| } | } | ||||
| @@ -375,36 +358,14 @@ struct KerNeonXXs1Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
| } // namespace | } // namespace | ||||
| void conv_bias::pack_src_fp32_nchw44_stride1( | |||||
| float* sptr_base, const float* sptr_origin, const int, const int pw, | |||||
| const int pad_right, const int ih, const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, const int ic, | |||||
| const int ic_stride) { | |||||
| constexpr int ic_step = 4; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const float* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| memset(sptr_base, 0, sizeof(float) * pw * ic_step); | |||||
| sptr_base += pw * ic_step; | |||||
| memcpy(sptr_base, sptr, sizeof(float) * iw * ic_step); | |||||
| sptr_base += iw * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| memset(sptr_base, 0, sizeof(float) * pad_right * ic_step); | |||||
| sptr_base += pad_right * ic_step; | |||||
| } | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int filter_size> | |||||
| static void conv_direct_stride1_fp32_nchw44( | |||||
| const float32_t* src, const float32_t* filter, const float32_t* bias, | |||||
| float32_t*, float32_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op, const int, const int) { | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, | |||||
| const float* bias, float*, float* dst, | |||||
| const int oc, const int ic, | |||||
| const int ih, const int iw, | |||||
| const int oh, const int oh_block, | |||||
| const int ow, const Op& op, const int, | |||||
| const int) { | |||||
| constexpr int fh = filter_size; | constexpr int fh = filter_size; | ||||
| constexpr int fw = filter_size; | constexpr int fw = filter_size; | ||||
| constexpr int ic_step = 4; | constexpr int ic_step = 4; | ||||
| @@ -518,55 +479,23 @@ static void conv_direct_stride1_fp32_nchw44( | |||||
| } | } | ||||
| } | } | ||||
| #define CONSTRUCT_FUNC(filter_size) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_bias:: \ | |||||
| conv_direct_stride1_##filter_size##x##filter_size##_fp32_nchw44( \ | |||||
| const float32_t* src, const float32_t* filter, \ | |||||
| const float32_t* bias, float32_t* temp, float32_t* dst, \ | |||||
| const int oc, const int ic, const int ih, const int iw, \ | |||||
| const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw) { \ | |||||
| conv_direct_stride1_fp32_nchw44<bias_mode, Op, filter_size>( \ | |||||
| src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ | |||||
| ow, op, ph, pw); \ | |||||
| } | |||||
| CONSTRUCT_FUNC(2); | |||||
| CONSTRUCT_FUNC(3); | |||||
| CONSTRUCT_FUNC(5); | |||||
| CONSTRUCT_FUNC(7); | |||||
| #undef CONSTRUCT_FUNC | |||||
| #define INSTANTIATION(stride, i, bias, Op) \ | |||||
| template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ | |||||
| bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ | |||||
| float32_t*, float32_t*, const int, const int, const int, \ | |||||
| const int, const int, const int, const int, const Op&, \ | |||||
| const int, const int); | |||||
| #define FOR_OP(stride, i, bias) \ | |||||
| INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>) | |||||
| #define FOR_BIAS(stride, i) \ | |||||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BIAS) | |||||
| #define FOR_FILTER(stride) \ | |||||
| FOR_BIAS(stride, 2) \ | |||||
| FOR_BIAS(stride, 3) \ | |||||
| FOR_BIAS(stride, 5) \ | |||||
| FOR_BIAS(stride, 7) | |||||
| FOR_FILTER(stride1) | |||||
| #undef FOR_STRIDE | |||||
| #undef FOR_FILTER | |||||
| #undef FOR_IC | |||||
| #undef FOR_BIAS | |||||
| #undef FOR_NONLINEAR | |||||
| #undef FOR_REMAIN | |||||
| #undef INSTANTIATION | |||||
| #define INSTANTIATION(filter_size, bias_mode, Op) \ | |||||
| template void \ | |||||
| conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 1>( \ | |||||
| const float* src, const float* filter, const float* bias, float*, \ | |||||
| float* dst, const int oc, const int ic, const int ih, \ | |||||
| const int iw, const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int, const int); | |||||
| #define FOR_OP(filter_size, bias) \ | |||||
| INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>) | |||||
| #define INSTANTIATION_CONV_S1(filter_size) \ | |||||
| FOR_OP(filter_size, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(filter_size, BiasMode::BIAS) | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -1,6 +1,6 @@ | |||||
| /** | /** | ||||
| * \file | * \file | ||||
| * dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.cpp | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
| * | * | ||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
| @@ -12,7 +12,7 @@ | |||||
| */ | */ | ||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | #include "src/arm_common/conv_bias/intrinsic_helper.h" | ||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| @@ -24,21 +24,21 @@ using namespace megdnn; | |||||
| using namespace arm_common; | using namespace arm_common; | ||||
| namespace { | namespace { | ||||
| template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, | |||||
| typename T, typename T2, typename T3, typename T4> | |||||
| template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, 8, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8]); \ | |||||
| c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 8]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8], lane); \ | |||||
| c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 8], lane); | |||||
| UNROLL_CALL_RAW(8, cb, 0); | UNROLL_CALL_RAW(8, cb, 0); | ||||
| UNROLL_CALL_RAW(8, cb, 1); | UNROLL_CALL_RAW(8, cb, 1); | ||||
| @@ -47,15 +47,15 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, 4, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4]); \ | |||||
| c[1][step] = Func::template impl<lane>(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 4]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4], lane); \ | |||||
| c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ | |||||
| src[(step + src_idx) % 4], lane); | |||||
| UNROLL_CALL_RAW(4, cb, 0); | UNROLL_CALL_RAW(4, cb, 0); | ||||
| UNROLL_CALL_RAW(4, cb, 1); | UNROLL_CALL_RAW(4, cb, 1); | ||||
| @@ -64,13 +64,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 4, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, 8, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 8], lane); | |||||
| UNROLL_CALL_RAW(8, cb, 0); | UNROLL_CALL_RAW(8, cb, 0); | ||||
| UNROLL_CALL_RAW(8, cb, 1); | UNROLL_CALL_RAW(8, cb, 1); | ||||
| @@ -79,13 +79,13 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, T, T2, T3, T4> { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { | |||||
| template <int src_idx, int weight_idx, typename T, typename T2, typename T3, | |||||
| typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, 4, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | ||||
| #define cb(step, lane) \ | |||||
| c[0][step] = Func::template impl<lane>(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4]); | |||||
| #define cb(step, lane) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ | |||||
| src[(step + src_idx) % 4], lane); | |||||
| UNROLL_CALL_RAW(4, cb, 0); | UNROLL_CALL_RAW(4, cb, 0); | ||||
| UNROLL_CALL_RAW(4, cb, 1); | UNROLL_CALL_RAW(4, cb, 1); | ||||
| @@ -95,11 +95,11 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 4, T, T2, T3, T4> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, | |||||
| typename T, typename T2, typename T3> | |||||
| template <int src_idx, int weight_idx, int c_dim, int ow_block, typename T, | |||||
| typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | ||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, ow_block, T, T2, T3, | |||||
| int>::impl(c, src, weight); | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, ow_block, T, T2, T3, int>::impl( | |||||
| c, src, weight); | |||||
| }; | }; | ||||
| template <int oc> | template <int oc> | ||||
| struct OCHelper { | struct OCHelper { | ||||
| @@ -163,13 +163,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
| 0); | 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| @@ -177,13 +177,13 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 2, oc_block, ow_block> { | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
| 0); | 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| @@ -224,18 +224,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
| load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
| 0); | 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| @@ -243,17 +243,17 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
| load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
| 0); | 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| @@ -261,18 +261,18 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 3, oc_block, ow_block> { | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
| load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>(src, src_ptr_odd, | ||||
| 0); | 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| weight_ptr += ld_weight_fh; | weight_ptr += ld_weight_fh; | ||||
| @@ -316,30 +316,25 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 5, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
| load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | ||||
| load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| // odd element | // odd element | ||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | ||||
| src, src_ptr_odd, 0); | src, src_ptr_odd, 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | ||||
| load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| @@ -390,40 +385,33 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
| 0); | 0); | ||||
| load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, | ||||
| ld_weight_oc); | ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr + ow_block * simd_len); | ||||
| load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); | ||||
| load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); | ||||
| load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<3, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<3, 0, c_dim, ow_block>(c, src, weight); | |||||
| // odd element | // odd element | ||||
| load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | load_helper<ow_block, 0, simd_len, 0, Vld1q_f32>( | ||||
| src, src_ptr_odd, 0); | src, src_ptr_odd, 0); | ||||
| load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<0, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); | ||||
| load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<1, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<1, 0, c_dim, ow_block>(c, src, weight); | |||||
| src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); | ||||
| load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( | ||||
| weight, weight_ptr, ld_weight_oc); | weight, weight_ptr, ld_weight_oc); | ||||
| cal_helper<2, 0, c_dim, Vfmaq_laneq_f32, ow_block>(c, src, | |||||
| weight); | |||||
| cal_helper<2, 0, c_dim, ow_block>(c, src, weight); | |||||
| src_ptr += ld_src_iw; | src_ptr += ld_src_iw; | ||||
| src_ptr_odd += ld_src_iw; | src_ptr_odd += ld_src_iw; | ||||
| @@ -436,133 +424,15 @@ struct KerNeonXXs2Nchw44FP32<bias_mode, Op, remain_w, 7, oc_block, ow_block> { | |||||
| }; | }; | ||||
| } // namespace | } // namespace | ||||
| namespace { | |||||
| inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, | |||||
| const int odd_start, const int src_idx, | |||||
| const int iw_idx) { | |||||
| constexpr int ic_step = 4; | |||||
| const int src_offset = src_idx * ic_step; | |||||
| const int even_offset = iw_idx / 2 * ic_step; | |||||
| const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||||
| float32x4_t temp[8]; | |||||
| temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
| temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
| temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
| temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
| temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
| temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
| temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
| temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
| vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[0]); | |||||
| vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[2]); | |||||
| vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[4]); | |||||
| vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[6]); | |||||
| vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[1]); | |||||
| vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[3]); | |||||
| vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[5]); | |||||
| vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); | |||||
| } | |||||
| inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, | |||||
| const int odd_start, const int src_idx, | |||||
| const int iw_idx) { | |||||
| constexpr int ic_step = 4; | |||||
| const int src_offset = src_idx * ic_step; | |||||
| const int even_offset = (iw_idx + 1) / 2 * ic_step; | |||||
| const int odd_offset = (odd_start + iw_idx / 2) * ic_step; | |||||
| float32x4_t temp[8]; | |||||
| temp[0] = vld1q_f32(sptr + src_offset + 0 * ic_step); | |||||
| temp[1] = vld1q_f32(sptr + src_offset + 1 * ic_step); | |||||
| temp[2] = vld1q_f32(sptr + src_offset + 2 * ic_step); | |||||
| temp[3] = vld1q_f32(sptr + src_offset + 3 * ic_step); | |||||
| temp[4] = vld1q_f32(sptr + src_offset + 4 * ic_step); | |||||
| temp[5] = vld1q_f32(sptr + src_offset + 5 * ic_step); | |||||
| temp[6] = vld1q_f32(sptr + src_offset + 6 * ic_step); | |||||
| temp[7] = vld1q_f32(sptr + src_offset + 7 * ic_step); | |||||
| vst1q_f32(sptr_base + odd_offset + 0 * ic_step, temp[0]); | |||||
| vst1q_f32(sptr_base + odd_offset + 1 * ic_step, temp[2]); | |||||
| vst1q_f32(sptr_base + odd_offset + 2 * ic_step, temp[4]); | |||||
| vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[6]); | |||||
| vst1q_f32(sptr_base + even_offset + 0 * ic_step, temp[1]); | |||||
| vst1q_f32(sptr_base + even_offset + 1 * ic_step, temp[3]); | |||||
| vst1q_f32(sptr_base + even_offset + 2 * ic_step, temp[5]); | |||||
| vst1q_f32(sptr_base + even_offset + 3 * ic_step, temp[7]); | |||||
| } | |||||
| } // namespace | |||||
| void conv_bias::pack_src_fp32_nchw44_stride2( | |||||
| float* sptr_base, const float* sptr_origin, const int ph, const int pw, | |||||
| const int pad_right, const int ih, const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, const int ic, | |||||
| const int ic_stride) { | |||||
| constexpr int ic_step = 4; | |||||
| int odd_start = megdnn::div_ceil(iw2, 2); | |||||
| float32x4_t zero_v = vdupq_n_f32(0.f); | |||||
| MEGDNN_MARK_USED_VAR(ph); | |||||
| bool even_start = pw % 2 == 0; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const float* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_top * ic_step); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| int iw_idx = 0; | |||||
| rep(idx, pw) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| zero_v); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| int src_idx = 0; | |||||
| if (even_start) { | |||||
| for (; src_idx + 7 < iw; src_idx += 8) { | |||||
| odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, | |||||
| iw_idx); | |||||
| iw_idx += 8; | |||||
| } | |||||
| } else { | |||||
| for (; src_idx + 7 < iw; src_idx += 8) { | |||||
| odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, | |||||
| iw_idx); | |||||
| iw_idx += 8; | |||||
| } | |||||
| } | |||||
| for (; src_idx < iw; ++src_idx) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, | |||||
| vld1q_f32(sptr + src_idx * ic_step)); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| vld1q_f32(sptr + src_idx * ic_step)); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| rep(idx, pad_right) { | |||||
| if (iw_idx % 2 == 0) { | |||||
| vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); | |||||
| } else { | |||||
| vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, | |||||
| zero_v); | |||||
| } | |||||
| ++iw_idx; | |||||
| } | |||||
| sptr_base += iw2 * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| } | |||||
| memset(sptr_base, 0, sizeof(float) * iw2 * pad_bottom * ic_step); | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int filter_size> | |||||
| static void conv_direct_stride2_fp32_nchw44( | |||||
| const float32_t* src, const float32_t* filter, const float32_t* bias, | |||||
| float32_t*, float32_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op, const int, const int) { | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, | |||||
| const float* bias, float*, float* dst, | |||||
| const int oc, const int ic, | |||||
| const int ih, const int iw, | |||||
| const int oh, const int oh_block, | |||||
| const int ow, const Op& op, const int, | |||||
| const int) { | |||||
| constexpr int fh = filter_size; | constexpr int fh = filter_size; | ||||
| constexpr int fw = filter_size; | constexpr int fw = filter_size; | ||||
| constexpr int ic_step = 4; | constexpr int ic_step = 4; | ||||
| @@ -697,55 +567,23 @@ static void conv_direct_stride2_fp32_nchw44( | |||||
| } | } | ||||
| } | } | ||||
| #define CONSTRUCT_FUNC(filter_size) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_bias:: \ | |||||
| conv_direct_stride2_##filter_size##x##filter_size##_fp32_nchw44( \ | |||||
| const float32_t* src, const float32_t* filter, \ | |||||
| const float32_t* bias, float32_t* temp, float32_t* dst, \ | |||||
| const int oc, const int ic, const int ih, const int iw, \ | |||||
| const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw) { \ | |||||
| conv_direct_stride2_fp32_nchw44<bias_mode, Op, filter_size>( \ | |||||
| src, filter, bias, temp, dst, oc, ic, ih, iw, oh, oh_block, \ | |||||
| ow, op, ph, pw); \ | |||||
| } | |||||
| CONSTRUCT_FUNC(2); | |||||
| CONSTRUCT_FUNC(3); | |||||
| CONSTRUCT_FUNC(5); | |||||
| CONSTRUCT_FUNC(7); | |||||
| #undef CONSTRUCT_FUNC | |||||
| #define INSTANTIATION(stride, i, bias, Op) \ | |||||
| template void conv_bias::conv_direct_##stride##_##i##x##i##_fp32_nchw44< \ | |||||
| bias, Op>(const float32_t*, const float32_t*, const float32_t*, \ | |||||
| float32_t*, float32_t*, const int, const int, const int, \ | |||||
| const int, const int, const int, const int, const Op&, \ | |||||
| const int, const int); | |||||
| #define FOR_OP(stride, i, bias) \ | |||||
| INSTANTIATION(stride, i, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, HSwishOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, i, bias, SigmoidOp<dt_float32>) | |||||
| #define FOR_BIAS(stride, i) \ | |||||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BIAS) | |||||
| #define FOR_FILTER(stride) \ | |||||
| FOR_BIAS(stride, 2) \ | |||||
| FOR_BIAS(stride, 3) \ | |||||
| FOR_BIAS(stride, 5) \ | |||||
| FOR_BIAS(stride, 7) | |||||
| FOR_FILTER(stride2) | |||||
| #undef FOR_STRIDE | |||||
| #undef FOR_FILTER | |||||
| #undef FOR_IC | |||||
| #undef FOR_BIAS | |||||
| #undef FOR_NONLINEAR | |||||
| #undef FOR_REMAIN | |||||
| #undef INSTANTIATION | |||||
| #define INSTANTIATION(filter_size, bias_mode, Op) \ | |||||
| template void \ | |||||
| conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter_size, 2>( \ | |||||
| const float* src, const float* filter, const float* bias, float*, \ | |||||
| float* dst, const int oc, const int ic, const int ih, \ | |||||
| const int iw, const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int, const int); | |||||
| #define FOR_OP(filter_size, bias) \ | |||||
| INSTANTIATION(filter_size, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, HSwishOp<dt_float32>) \ | |||||
| INSTANTIATION(filter_size, bias, SigmoidOp<dt_float32>) | |||||
| #define INSTANTIATION_CONV_S2(filter_size) \ | |||||
| FOR_OP(filter_size, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(filter_size, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(filter_size, BiasMode::BIAS) | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(2, 1); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_2x2s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(2, 2); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(3, 1); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_3x3s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(3, 2); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(5, 1); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_5x5s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(5, 2); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(7, 1); | |||||
| @@ -0,0 +1,14 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_7x7s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h" | |||||
| INSTANCE_CONV(7, 2); | |||||
| @@ -0,0 +1,443 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/arch.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| namespace { | |||||
| /** | |||||
| *\brief ShiftCalHelper is core calculate code | |||||
| *\tparam src_idx is offset for src regs | |||||
| *\tparam weight_idx is offset for weight regs | |||||
| *\tparam T is type of output regs | |||||
| *\tparam T2 is type of src regs | |||||
| *\tparam T3 is type of weight regs | |||||
| */ | |||||
| template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
| typename T2, typename T3> | |||||
| struct ShiftCalHelper { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int stride, typename T, typename T2, | |||||
| typename T3> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, stride, T, T2, T3> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4], \ | |||||
| (step * stride + src_idx) % 4); \ | |||||
| c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4], \ | |||||
| (step * stride + src_idx) % 4); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int stride, typename T, typename T2, | |||||
| typename T3> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, stride, T, T2, T3> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4], \ | |||||
| (step * stride + src_idx) % 4); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
| typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3>::impl(c, src, | |||||
| weight); | |||||
| }; | |||||
| template <int oc> | |||||
| struct OCHelper { | |||||
| public: | |||||
| static const int val = -1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<4> { | |||||
| public: | |||||
| static const int val = 1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<8> { | |||||
| public: | |||||
| static const int val = 2; | |||||
| }; | |||||
| /** | |||||
| * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel | |||||
| **/ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||||
| int oc_block, int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32 { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op); | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 7; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<5, 5, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<6, 6, c_dim, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 5; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, stride>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(5, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 3; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| // row 0 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, | |||||
| 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
| // row 1 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
| // row 2 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + 2 * iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, stride>(c, src, weight); | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| // row 0 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, | |||||
| 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
| // row 1 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, stride>(c, src, weight); | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| } // namespace | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( | |||||
| const float32_t* src, const float32_t* filter, const float32_t* bias, | |||||
| float32_t*, float32_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op, const int, const int) { | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int big_oc_step = 8; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = stride; | |||||
| constexpr int stride_w = stride; | |||||
| constexpr int pack_iw_len = 1; | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = std::function<void( | |||||
| const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, stride, ow_step>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, stride, ow_step>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, stride, | |||||
| ow_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, stride, | |||||
| ow_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #define INSTANTIATION(stride, filter_size, bias_mode, Op) \ | |||||
| template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \ | |||||
| bias_mode, Op, filter_size, stride>( \ | |||||
| const float32_t* src, const float32_t* filter, \ | |||||
| const float32_t* bias, float32_t*, float32_t* dst, const int oc, \ | |||||
| const int ic, const int ih, const int iw, const int oh, \ | |||||
| const int oh_block, const int ow, const Op& op, const int, \ | |||||
| const int); | |||||
| #define FOR_OP(stride, filter, bias) \ | |||||
| INSTANTIATION(stride, filter, bias, NoneOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, filter, bias, ReluOp<dt_float32>) \ | |||||
| INSTANTIATION(stride, filter, bias, HSwishOp<dt_float32>) | |||||
| #define INSTANCE_CONV(filter, stride) \ | |||||
| FOR_OP(stride, filter, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| FOR_OP(stride, filter, BiasMode::BIAS) | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -13,8 +13,8 @@ | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/arm_common/conv_bias/block_helper.h" | #include "src/arm_common/conv_bias/block_helper.h" | ||||
| #include "src/arm_common/conv_bias/fp32/algos.h" | #include "src/arm_common/conv_bias/fp32/algos.h" | ||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h" | |||||
| #include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
| #include "midout.h" | #include "midout.h" | ||||
| @@ -112,17 +112,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); | const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); | ||||
| float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) + | float* sptr = reinterpret_cast<float*>((int8_t*)bundle.get(0) + | ||||
| ncb_index.thread_id * src_size); | ncb_index.thread_id * src_size); | ||||
| if (stride == 1) { | |||||
| conv_bias::pack_src_fp32_nchw44_stride1( | |||||
| sptr, origin_sptr, ph, pw, remain_right_pad, | |||||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | |||||
| src_bottom_pad, ic, ih * iw); | |||||
| } else { | |||||
| conv_bias::pack_src_fp32_nchw44_stride2( | |||||
| sptr, origin_sptr, ph, pw, remain_right_pad, | |||||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | |||||
| src_bottom_pad, ic, ih * iw); | |||||
| } | |||||
| conv_bias::pack_src_fp32_nchw44<stride>( | |||||
| sptr, origin_sptr, ph, pw, remain_right_pad, | |||||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | |||||
| src_bottom_pad, ic, ih * iw); | |||||
| const float* fptr = | const float* fptr = | ||||
| kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; | kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; | ||||
| @@ -135,25 +129,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset; | kern_param.bias<dt_float32>(batch_id, group_id) + bias_offset; | ||||
| Op op; | Op op; | ||||
| if (stride == 1) { | |||||
| #define KERN1_NCHW44_CONV(filter) \ | |||||
| conv_bias::conv_direct_stride1_##filter##x##filter##_fp32_nchw44< \ | |||||
| \ | |||||
| bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ | |||||
| ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) | |||||
| DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); | |||||
| #undef KERN1_NCHW44_CONV | |||||
| } else { | |||||
| #define KERN1_NCHW44_CONV(filter) \ | |||||
| conv_bias::conv_direct_stride2_##filter##x##filter##_fp32_nchw44< \ | |||||
| \ | |||||
| bias_mode, Op>(sptr, fptr, bptr, nullptr, dst, oc_block, ic, \ | |||||
| ih_real, iw2, oh, oh_block_real, ow, op, ph, pw) | |||||
| DISPATCH_FILTER(filter, KERN1_NCHW44_CONV); | |||||
| #undef KERN1_NCHW44_CONV | |||||
| } | |||||
| conv_bias::conv_direct_fp32_nchw44<bias_mode, Op, filter, stride>( | |||||
| sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, | |||||
| oh_block_real, ow, op, ph, pw); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void conv_direct_fp32_nchw44(const float* src, const float* filter, | |||||
| const float* bias, float*, float* dst, | |||||
| const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, | |||||
| const int ow, const Op& op, const int, const int); | |||||
| template <int stride> | |||||
| void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int, | |||||
| const int pw, const int pad_right, const int ih, | |||||
| const int iw, const int iw2, const int pad_top, | |||||
| const int pad_bottom, const int ic, | |||||
| const int ic_stride); | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| @@ -120,7 +120,8 @@ static void pack_weight(const WorkspaceBundle& bundle, | |||||
| kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; | kern_param.filter<dt_float32>(group_id) + oc_idx * fh * fw * ic; | ||||
| auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) + | auto packed_weight = reinterpret_cast<float*>(bundle.get(1)) + | ||||
| group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; | group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; | ||||
| pack_weight_fp32_nchw_nchw44(fptr, packed_weight, oc_block, fh, fw, ic); | |||||
| fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, | |||||
| oc_block, fh, fw, ic); | |||||
| } | } | ||||
| template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride> | template <size_t filter_size, BiasMode bias_mode, typename Op, size_t stride> | ||||
| @@ -180,7 +181,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx; | kern_param.bias<dt_float32>(batch_id, group_id) + oc_idx; | ||||
| Op op; | Op op; | ||||
| conv_direct_fp32_nchw_nchw44<bias_mode, Op, filter_size, stride>( | |||||
| fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44<bias_mode, Op, | |||||
| filter_size, stride>( | |||||
| sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, | sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, | ||||
| oh, oh_block_real, ow, op, ph, pw); | oh, oh_block_real, ow, op, ph, pw); | ||||
| } | } | ||||
| @@ -20,295 +20,12 @@ | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace { | |||||
| /** | |||||
| *\brief ShiftCalHelper is core calculate code | |||||
| *\tparam src_idx is offset for src regs | |||||
| *\tparam weight_idx is offset for weight regs | |||||
| *\tparam T is type of output regs | |||||
| *\tparam T2 is type of src regs | |||||
| *\tparam T3 is type of weight regs | |||||
| */ | |||||
| template <int src_idx, int weight_idx, int c_dim, typename Func, int stride, | |||||
| typename T, typename T2, typename T3> | |||||
| struct ShiftCalHelper { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||||
| typename T2, typename T3> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, stride, T, T2, T3> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); \ | |||||
| c[1][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[1][step], weight[1][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||||
| typename T2, typename T3> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, stride, T, T2, T3> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(step * stride + src_idx) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], \ | |||||
| src[(step * stride + src_idx) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, | |||||
| typename T, typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, stride, T, T2, T3>::impl( | |||||
| c, src, weight); | |||||
| }; | |||||
| template <int oc> | |||||
| struct OCHelper { | |||||
| public: | |||||
| static const int val = -1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<4> { | |||||
| public: | |||||
| static const int val = 1; | |||||
| }; | |||||
| template <> | |||||
| struct OCHelper<8> { | |||||
| public: | |||||
| static const int val = 2; | |||||
| }; | |||||
| /** | |||||
| * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel | |||||
| **/ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int filter_size, | |||||
| int oc_block, int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32 { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op); | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 7, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 7; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<5, 5, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<6, 6, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 5, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 5; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| #define KERNEL_CB(step) \ | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( \ | |||||
| src, src_ptr + step * iw, 0); \ | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( \ | |||||
| weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<3, 3, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); \ | |||||
| cal_helper<4, 4, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(5, KERNEL_CB) | |||||
| #undef KERNEL_CB | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 3; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| // row 0 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, | |||||
| 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| namespace fp32_direct_nchw_nchw44 { | |||||
| // row 1 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| // row 2 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + 2 * iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<2, 2, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride, int ow_block> | |||||
| struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block, stride, | |||||
| ow_block> { | |||||
| static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int filter_size = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int simd_len = 4; | |||||
| constexpr int src_reg_size = | |||||
| (ow_block * stride + filter_size - stride + simd_len - 1) / | |||||
| simd_len; | |||||
| constexpr int ld_weight_fw = oc_step * filter_size; | |||||
| const int ld_weight_oc = oc_step * filter_size * filter_size * ic; | |||||
| const int ld_weight_ic = oc_step * filter_size * filter_size; | |||||
| const int ld_src_ic = ih * iw; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| float32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| float32x4_t src[src_reg_size]; | |||||
| float32x4_t weight[c_dim][filter_size]; | |||||
| // row 0 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, | |||||
| 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| // row 1 | |||||
| load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>( | |||||
| src, src_ptr + iw, 0); | |||||
| load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>( | |||||
| weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| cal_helper<1, 1, c_dim, Vfmaq_laneq_f32, stride>(c, src, weight); | |||||
| src_ptr += ld_src_ic; | |||||
| weight_ptr += ld_weight_ic; | |||||
| } | |||||
| store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr, | |||||
| ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr, | |||||
| const int oc, const int kh, const int kw, | |||||
| const int ic) { | |||||
| static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, | |||||
| float32_t* dst_ptr, | |||||
| const int oc, const int kh, | |||||
| const int kw, const int ic) { | |||||
| constexpr int oc_step = 4; | constexpr int oc_step = 4; | ||||
| const int filter_oc_stride = kh * kw * ic; | const int filter_oc_stride = kh * kw * ic; | ||||
| const int filter_ic_stride = kh * kw * oc_step; | const int filter_ic_stride = kh * kw * oc_step; | ||||
| @@ -327,115 +44,15 @@ void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, float32_t* dst_ptr, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | template <BiasMode bias_mode, typename Op, int filter_size, int stride> | ||||
| static void conv_direct_fp32_nchw_nchw44( | |||||
| const float32_t* src, const float32_t* filter, const float32_t* bias, | |||||
| float32_t*, float32_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op, const int, const int) { | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int big_oc_step = 8; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = stride; | |||||
| constexpr int stride_w = stride; | |||||
| constexpr int pack_iw_len = 1; | |||||
| void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter, | |||||
| const float32_t* bias, float32_t*, | |||||
| float32_t* dst, const int oc, const int ic, | |||||
| const int ih, const int iw, const int oh, | |||||
| const int oh_block, const int ow, | |||||
| const Op& op, const int, const int); | |||||
| } // namespace fp32_direct_nchw_nchw44 | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = std::function<void( | |||||
| const float32_t* src_ptr, const float32_t* weight_ptr, | |||||
| const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, stride, ow_step>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, stride, ow_step>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, stride, | |||||
| ow_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, stride, | |||||
| ow_step>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride1_nchw44_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
| #define KERN(stride, i, layout) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ | |||||
| const float* src, const float* filter, const float* bias, \ | |||||
| float* temp, float* dst, const int oc, const int ic, const int ih, \ | |||||
| const int iw, const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw); | |||||
| KERN(stride1, 2, nchw44) | |||||
| KERN(stride1, 3, nchw44) | |||||
| KERN(stride1, 5, nchw44) | |||||
| KERN(stride1, 7, nchw44) | |||||
| #undef KERN | |||||
| void pack_src_fp32_nchw44_stride1(float* sptr_base, const float* sptr_origin, | |||||
| const int ph, const int pw, | |||||
| const int pad_right, const int ih, | |||||
| const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, | |||||
| const int ic, const int ic_stride); | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw44_kern.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace conv_bias { | |||||
| #define KERN(stride, i, layout) \ | |||||
| template <BiasMode bias_mode, typename Op> \ | |||||
| void conv_direct_##stride##_##i##x##i##_fp32_##layout( \ | |||||
| const float* src, const float* filter, const float* bias, \ | |||||
| float* temp, float* dst, const int oc, const int ic, const int ih, \ | |||||
| const int iw, const int oh, const int oh_block, const int ow, \ | |||||
| const Op& op, const int ph, const int pw); | |||||
| KERN(stride2, 2, nchw44) | |||||
| KERN(stride2, 3, nchw44) | |||||
| KERN(stride2, 5, nchw44) | |||||
| KERN(stride2, 7, nchw44) | |||||
| #undef KERN | |||||
| void pack_src_fp32_nchw44_stride2(float* sptr_base, const float* sptr_origin, | |||||
| const int ph, const int pw, | |||||
| const int pad_right, const int ih, | |||||
| const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, | |||||
| const int ic, const int ic_stride); | |||||
| } // namespace conv_bias | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #ifdef __ARM_FEATURE_DOTPROD | #ifdef __ARM_FEATURE_DOTPROD | ||||
| @@ -17,7 +18,7 @@ | |||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" | ||||
| #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace arm_common { | namespace arm_common { | ||||
| namespace direct_dotprod_nchw44 { | namespace direct_dotprod_nchw44 { | ||||
| @@ -139,234 +140,9 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, | |||||
| } | } | ||||
| } | } | ||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int filter_size> | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | |||||
| const int8_t* src, const int ih, const int iw, | |||||
| const int8_t* filter, const int32_t* bias, | |||||
| const int oh_size, const int oc, const int ic, | |||||
| const Op& op) { | |||||
| constexpr int FH = filter_size; | |||||
| constexpr int FW = filter_size; | |||||
| constexpr int IC_PACK_SIZE = 4; | |||||
| constexpr int OC_PACK_SIZE = 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int OC_BIG_INTERVAL = 12; | |||||
| constexpr int OC_MID_INTERVAL = 8; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #else | |||||
| constexpr int OC_BIG_INTERVAL = 4; | |||||
| constexpr int OC_MID_INTERVAL = 4; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #endif | |||||
| constexpr int OW_INTERVAL = 8; | |||||
| constexpr int SH = stride; | |||||
| const int dst_numbers_per_channel = oh * ow; | |||||
| const int ow_remain = ow % OW_INTERVAL; | |||||
| const int ow_end_idx = ow - ow_remain; | |||||
| const int oc_remain = | |||||
| oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 | |||||
| const int oc_end_idx = oc - oc_remain; | |||||
| const int dst_numbers_4channel_packed = | |||||
| dst_numbers_per_channel * OC_PACK_SIZE; | |||||
| using remain_fun = std::function<void( | |||||
| dst_type * dst, const int dst_step, const int8_t* src, const int ih, | |||||
| const int iw, const int8_t* filter, const int32_t* bias, | |||||
| const int ic, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_mid_oc_remain = nullptr; | |||||
| remain_fun kern_sma_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_BIG_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_mid_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_MID_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_sma_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_SMA_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, | |||||
| //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates | |||||
| //! [OW_INTERVAL, 1, OC_INTERVAL] each time | |||||
| for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_BIG_INTERVAL, OW_INTERVAL>:: | |||||
| impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| kern_big_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| #ifdef MEGDNN_AARCH64 | |||||
| //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 | |||||
| if (oc_remain) { | |||||
| int oc_idx = oc_end_idx; | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_MID_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } else { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_SMA_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| kern_mid_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } else { | |||||
| kern_sma_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| #define CONSTRUCT_FUNC(filter_size) \ | |||||
| template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \ | |||||
| void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ | |||||
| dst_type* dst, const int oh, const int ow, const int8_t* src, \ | |||||
| const int ih, const int iw, const int8_t* weight, \ | |||||
| const int32_t* bias, const int oh_size, const int oc, \ | |||||
| const int ic, const Op& op) { \ | |||||
| conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, Op, \ | |||||
| filter_size>( \ | |||||
| dst, oh, ow, src, ih, iw, weight, bias, oh_size, oc, ic, op); \ | |||||
| } | |||||
| CONSTRUCT_FUNC(2); | |||||
| CONSTRUCT_FUNC(3); | |||||
| CONSTRUCT_FUNC(5); | |||||
| CONSTRUCT_FUNC(7); | |||||
| #undef CONSTRUCT_FUNC | |||||
| #define INSTANTIATION(dst_type, stride, i, bias_mode, Op) \ | |||||
| template void conv_direct_##i##x##i##_int8_nchw44<dst_type, bias_mode, Op, \ | |||||
| stride>( \ | |||||
| dst_type * dst, const int oh, const int ow, const int8_t* src, \ | |||||
| const int ih, const int iw, const int8_t* weight, \ | |||||
| const int32_t* bias, const int oh_size, const int oc, \ | |||||
| const int ic, const Op& op); | |||||
| #define FOR_OP(stride, i, bias_mode) \ | |||||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| INSTANTIATION(dt_int32, stride, i, bias_mode, \ | |||||
| NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | |||||
| INSTANTIATION(dt_int8, stride, i, bias_mode, \ | |||||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) | |||||
| #define FOR_BIAS(stride, i) \ | |||||
| FOR_OP(stride, i, BiasMode::NO_BIAS) \ | |||||
| FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) | |||||
| #define FOR_FILTER(stride) \ | |||||
| FOR_BIAS(stride, 2) \ | |||||
| FOR_BIAS(stride, 3) \ | |||||
| FOR_BIAS(stride, 5) \ | |||||
| FOR_BIAS(stride, 7) | |||||
| FOR_FILTER(1) | |||||
| FOR_FILTER(2) | |||||
| #undef FOR_STRIDE | |||||
| #undef FOR_FILTER | |||||
| #undef FOR_IC | |||||
| #undef FOR_BIAS | |||||
| #undef FOR_NONLINEAR | |||||
| #undef FOR_REMAIN | |||||
| #undef INSTANTIATION | |||||
| } // namespace direct_dotprod_nchw44 | } // namespace direct_dotprod_nchw44 | ||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| //vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
| @@ -42,20 +43,13 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
| * @return none | * @return none | ||||
| */ | */ | ||||
| #define KERN(filter_size) \ | |||||
| template <typename dst_type, BiasMode bias_mode, typename Op, int stride> \ | |||||
| void conv_direct_##filter_size##x##filter_size##_int8_nchw44( \ | |||||
| dst_type* dst, const int oh, const int ow, const int8_t* src, \ | |||||
| const int ih, const int iw, const int8_t* weight, \ | |||||
| const int32_t* bias, const int oh_size, const int oc, \ | |||||
| const int ic, const Op& op) | |||||
| KERN(2); | |||||
| KERN(3); | |||||
| KERN(5); | |||||
| KERN(7); | |||||
| #undef KERN | |||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int filter_size> | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | |||||
| const int8_t* src, const int ih, const int iw, | |||||
| const int8_t* filter, const int32_t* bias, | |||||
| const int oh_size, const int oc, const int ic, | |||||
| const Op& op); | |||||
| /** | /** | ||||
| * @brief : copy data from src to dst for direct conv with no side effect | * @brief : copy data from src to dst for direct conv with no side effect | ||||
| * @param : [output ptr] dst | * @param : [output ptr] dst | ||||
| @@ -84,4 +78,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, | |||||
| #endif | #endif | ||||
| //vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -148,14 +148,10 @@ static void conv_kern(const WorkspaceBundle& bundle, | |||||
| float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale; | float scale_dst = ncb_param.dst_type.param<dtype::QuantizedS8>().scale; | ||||
| op = Op(scale_bias, scale_dst); | op = Op(scale_bias, scale_dst); | ||||
| } | } | ||||
| #define KERN1_NCHW44_CONV(filter) \ | |||||
| direct_dotprod_nchw44::conv_direct_##filter##x##filter##_int8_nchw44< \ | |||||
| dst_type, bias_mode, Op, stride>(dst, OH, OW, copy_dst, \ | |||||
| ih_real_size, iw2, weights, bias, \ | |||||
| oh_real_size, OC, IC, op); | |||||
| DISPATCH_FILTER(filter_size, KERN1_NCHW44_CONV); | |||||
| #undef KERN1_NCHW44_CONV | |||||
| direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< | |||||
| dst_type, stride, bias_mode, Op, filter_size>( | |||||
| dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias, | |||||
| oh_real_size, OC, IC, op); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -342,4 +338,4 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | |||||
| #endif | #endif | ||||
| //vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -1,435 +0,0 @@ | |||||
| /** | |||||
| * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_kern.h | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #ifdef __ARM_FEATURE_DOTPROD | |||||
| #include "megdnn/arch.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/intrinsic_helper.h" | |||||
| #include "src/arm_common/neon_struct.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace direct_dotprod_nchw44 { | |||||
| constexpr int SIMD_LEN = 16; | |||||
| constexpr int IC_PACK_SIZE = 4; | |||||
| constexpr int OC_PACK_SIZE = 4; | |||||
| constexpr int filter_next_col = | |||||
| IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| template <int row, BiasMode bias_mode> | |||||
| MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], | |||||
| const int32_t* bias_ptr, int oc_step) { | |||||
| static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); | |||||
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| #define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); | |||||
| switch (row) { | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||||
| default: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||||
| } | |||||
| #undef BIAS_INIT | |||||
| } else { | |||||
| #define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); | |||||
| switch (row) { | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||||
| default: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||||
| } | |||||
| #undef BIAS_INIT | |||||
| } | |||||
| } | |||||
| #define cb11(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); | |||||
| #define cb21(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||||
| op(res[1][col], \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); | |||||
| #define cb31(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||||
| op(res[1][col], \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \ | |||||
| op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \ | |||||
| ld_dst_oc + col / 2 * 8)); | |||||
| #define cb12(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); | |||||
| #define cb22(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); | |||||
| #define cb32(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \ | |||||
| op({{res[2][2 * step], res[2][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8)); | |||||
| template <int row, int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc); | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<1, ow_remain, Op, T> { | |||||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||||
| const int ld_dst_oc) { | |||||
| MEGDNN_MARK_USED_VAR(ld_dst_oc); | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb12); | |||||
| break; | |||||
| case 7: | |||||
| cb11(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb12); | |||||
| break; | |||||
| case 5: | |||||
| cb11(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb12); | |||||
| break; | |||||
| case 3: | |||||
| cb11(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb12); | |||||
| break; | |||||
| case 1: | |||||
| cb11(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<2, ow_remain, Op, T> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc) { | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb22); | |||||
| break; | |||||
| case 7: | |||||
| cb21(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb22); | |||||
| break; | |||||
| case 5: | |||||
| cb21(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb22); | |||||
| break; | |||||
| case 3: | |||||
| cb21(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb22); | |||||
| break; | |||||
| case 1: | |||||
| cb21(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<3, ow_remain, Op, T> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc) { | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb32); | |||||
| break; | |||||
| case 7: | |||||
| cb31(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb32); | |||||
| break; | |||||
| case 5: | |||||
| cb31(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb32); | |||||
| break; | |||||
| case 3: | |||||
| cb31(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb32); | |||||
| break; | |||||
| case 1: | |||||
| cb31(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| #undef cb11 | |||||
| #undef cb21 | |||||
| #undef cb31 | |||||
| #undef cb12 | |||||
| #undef cb22 | |||||
| #undef cb32 | |||||
| template <int row, int ow_remain, typename Op, typename T> | |||||
| MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], | |||||
| const Op& op, T* dst_ptr, | |||||
| const int ld_dst_oc) { | |||||
| StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||||
| typename FUNC, typename T, typename T2, typename T3> | |||||
| struct ShiftCalHelper { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| res[res_row][step] = FUNC::template impl<((src_start_idx + step) % 4)>( \ | |||||
| res[res_row][step], weight[weight_idx], \ | |||||
| src[src_row][(src_start_idx + step) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||||
| typename FUNC, typename T, typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { | |||||
| ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, FUNC, T, T2, | |||||
| T3>::impl(res, src, weight); | |||||
| }; | |||||
| /** | |||||
| * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) | |||||
| * gemm like kernel | |||||
| * */ | |||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int ow_remain, int filter_size, int oc_interval, int ow_interval> | |||||
| struct KernNeonSdotNCHW44 { | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||||
| const int ih, const int iw, const int8_t* filter, | |||||
| const int32_t* bias, const int ic, const Op& op); | |||||
| }; | |||||
| template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| int filter_size, int oc_interval, int ow_interval> | |||||
| struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size, | |||||
| oc_interval, ow_interval> { | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||||
| const int ih, const int iw, const int8_t* filter, | |||||
| const int32_t* bias, const int ic, const Op& op) { | |||||
| constexpr int FH = filter_size; | |||||
| constexpr int FW = filter_size; | |||||
| constexpr int filter_next_row = | |||||
| FW * OC_PACK_SIZE * | |||||
| IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| const int filter_next_4oc = | |||||
| FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| const int src_next_ic = ih * iw; | |||||
| const int src_next_row = iw * IC_PACK_SIZE; | |||||
| constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1; | |||||
| constexpr int LOOP = oc_interval / 4; | |||||
| int32x4_t res[3][ow_interval]; | |||||
| init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { | |||||
| const int8_t* i_src = src + ic_idx * src_next_ic; | |||||
| const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; | |||||
| for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { | |||||
| int8x16_t src[1][4]; | |||||
| int8x16_t weight[3]; | |||||
| load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0); | |||||
| //! do not use switch order 3,2,1 because it will slow the speed. | |||||
| #define CALC_PART(step) \ | |||||
| switch (LOOP) { \ | |||||
| case 1: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, 0, step, 0, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<1, 0, step, 1, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<2, 0, step, 2, Vdotq_laneq_s32>(res, src, weight); \ | |||||
| break; \ | |||||
| default: \ | |||||
| break; \ | |||||
| } | |||||
| switch (filter_size) { | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(2, CALC_PART); | |||||
| break; | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(3, CALC_PART); | |||||
| break; | |||||
| case 5: | |||||
| UNROLL_CALL_RAW(5, CALC_PART); | |||||
| break; | |||||
| case 7: | |||||
| UNROLL_CALL_RAW(7, CALC_PART); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| #undef CALC_PART | |||||
| i_filter += filter_next_row; | |||||
| i_src += src_next_row; | |||||
| } | |||||
| } | |||||
| store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst, | |||||
| dst_step); | |||||
| } | |||||
| }; | |||||
| template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| int filter_size, int oc_interval, int ow_interval> | |||||
| struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size, | |||||
| oc_interval, ow_interval> { | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||||
| const int ih, const int iw, const int8_t* filter, | |||||
| const int32_t* bias, const int ic, const Op& op) { | |||||
| constexpr int FH = filter_size; | |||||
| constexpr int FW = filter_size; | |||||
| constexpr int filter_next_row = | |||||
| FW * OC_PACK_SIZE * | |||||
| IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| const int filter_next_4oc = | |||||
| FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| const int src_next_ic = ih * iw; | |||||
| const int src_next_row = iw * IC_PACK_SIZE; | |||||
| constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1; | |||||
| constexpr int LOOP = oc_interval / 4; | |||||
| int32x4_t res[3][ow_interval]; | |||||
| init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) { | |||||
| const int8_t* i_src = src + ic_idx * src_next_ic; | |||||
| const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE; | |||||
| for (int fh_idx = 0; fh_idx < FH; ++fh_idx) { | |||||
| int8x16_t src[2][3]; | |||||
| int8x16_t weight[3]; | |||||
| const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE; | |||||
| load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset); | |||||
| //! do not use switch order 3,2,1 because it will slow the speed. | |||||
| #define CALC_PART(step) \ | |||||
| switch (LOOP) { \ | |||||
| case 1: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<0, step % 2, step / 2, 0, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<1, step % 2, step / 2, 1, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ | |||||
| filter_next_col * step); \ | |||||
| cal_helper<2, step % 2, step / 2, 2, Vdotq_laneq_s32>(res, src, \ | |||||
| weight); \ | |||||
| break; \ | |||||
| default: \ | |||||
| break; \ | |||||
| } | |||||
| switch (filter_size) { | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(2, CALC_PART); | |||||
| break; | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(3, CALC_PART); | |||||
| break; | |||||
| case 5: | |||||
| UNROLL_CALL_RAW(5, CALC_PART); | |||||
| break; | |||||
| case 7: | |||||
| UNROLL_CALL_RAW(7, CALC_PART); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| #undef CALC_PART | |||||
| i_filter += filter_next_row; | |||||
| i_src += src_next_row; | |||||
| } | |||||
| } | |||||
| store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst, | |||||
| dst_step); | |||||
| } | |||||
| }; | |||||
| } // namespace direct_dotprod_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,245 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #pragma once | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "megdnn/arch.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/intrinsic_helper.h" | |||||
| #include "src/arm_common/neon_struct.h" | |||||
| #include "src/common/unroll_macro.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace direct_dotprod_nchw44 { | |||||
| constexpr int SIMD_LEN = 16; | |||||
| constexpr int IC_PACK_SIZE = 4; | |||||
| constexpr int OC_PACK_SIZE = 4; | |||||
| constexpr int filter_next_col = | |||||
| IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| template <int row, BiasMode bias_mode> | |||||
| MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], | |||||
| const int32_t* bias_ptr, int oc_step) { | |||||
| static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); | |||||
| if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
| #define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); | |||||
| switch (row) { | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||||
| default: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||||
| } | |||||
| #undef BIAS_INIT | |||||
| } else { | |||||
| #define BIAS_INIT(step, i) c[i][step] = vdupq_n_s32(0); | |||||
| switch (row) { | |||||
| case 3: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 1); | |||||
| default: | |||||
| UNROLL_CALL_RAW(8, BIAS_INIT, 0); | |||||
| } | |||||
| #undef BIAS_INIT | |||||
| } | |||||
| } | |||||
| #define cb11(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); | |||||
| #define cb21(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||||
| op(res[1][col], \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); | |||||
| #define cb31(col) \ | |||||
| op(res[0][col], reinterpret_cast<dt_qint8*>(dst_ptr + col / 2 * 8)); \ | |||||
| op(res[1][col], \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + col / 2 * 8)); \ | |||||
| op(res[2][col], reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + \ | |||||
| ld_dst_oc + col / 2 * 8)); | |||||
| #define cb12(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); | |||||
| #define cb22(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); | |||||
| #define cb32(step) \ | |||||
| op({{res[0][2 * step], res[0][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + step * 8)); \ | |||||
| op({{res[1][2 * step], res[1][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + ld_dst_oc + step * 8)); \ | |||||
| op({{res[2][2 * step], res[2][2 * step + 1]}}, \ | |||||
| reinterpret_cast<dt_qint8*>(dst_ptr + 2 * ld_dst_oc + step * 8)); | |||||
| template <int row, int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc); | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<1, ow_remain, Op, T> { | |||||
| static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, | |||||
| const int ld_dst_oc) { | |||||
| MEGDNN_MARK_USED_VAR(ld_dst_oc); | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb12); | |||||
| break; | |||||
| case 7: | |||||
| cb11(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb12); | |||||
| break; | |||||
| case 5: | |||||
| cb11(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb12); | |||||
| break; | |||||
| case 3: | |||||
| cb11(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb12); | |||||
| break; | |||||
| case 1: | |||||
| cb11(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<2, ow_remain, Op, T> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc) { | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb22); | |||||
| break; | |||||
| case 7: | |||||
| cb21(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb22); | |||||
| break; | |||||
| case 5: | |||||
| cb21(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb22); | |||||
| break; | |||||
| case 3: | |||||
| cb21(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb22); | |||||
| break; | |||||
| case 1: | |||||
| cb21(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| template <int ow_remain, typename Op, typename T> | |||||
| struct StoreOCxOWx<3, ow_remain, Op, T> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, | |||||
| T* dst_ptr, const int ld_dst_oc) { | |||||
| switch (ow_remain) { | |||||
| case 8: | |||||
| UNROLL_CALL_RAW(4, cb32); | |||||
| break; | |||||
| case 7: | |||||
| cb31(6); | |||||
| case 6: | |||||
| UNROLL_CALL_RAW(3, cb32); | |||||
| break; | |||||
| case 5: | |||||
| cb31(4); | |||||
| case 4: | |||||
| UNROLL_CALL_RAW(2, cb32); | |||||
| break; | |||||
| case 3: | |||||
| cb31(2); | |||||
| case 2: | |||||
| UNROLL_CALL_RAW(1, cb32); | |||||
| break; | |||||
| case 1: | |||||
| cb31(0); | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| }; | |||||
| #undef cb11 | |||||
| #undef cb21 | |||||
| #undef cb31 | |||||
| #undef cb12 | |||||
| #undef cb22 | |||||
| #undef cb32 | |||||
| template <int row, int ow_remain, typename Op, typename T> | |||||
| MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], | |||||
| const Op& op, T* dst_ptr, | |||||
| const int ld_dst_oc) { | |||||
| StoreOCxOWx<row, ow_remain, Op, T>::impl(res, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||||
| typename T, typename T2, typename T3> | |||||
| struct ShiftCalHelper { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| res[res_row][step] = \ | |||||
| vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \ | |||||
| src[src_row][(src_start_idx + step) / 4], \ | |||||
| (src_start_idx + step) % 4); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int res_row, int src_row, int src_start_idx, int weight_idx, | |||||
| typename T, typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { | |||||
| ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2, | |||||
| T3>::impl(res, src, weight); | |||||
| }; | |||||
| /** | |||||
| * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) | |||||
| * gemm like kernel | |||||
| * */ | |||||
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int ow_remain, int filter_size, int oc_interval, int ow_interval> | |||||
| struct KernNeonSdotNCHW44 { | |||||
| static void impl(dst_type* dst, const int dst_step, const int8_t* src, | |||||
| const int ih, const int iw, const int8_t* filter, | |||||
| const int32_t* bias, const int ic, const Op& op); | |||||
| }; | |||||
| } // namespace direct_dotprod_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,320 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace direct_dotprod_nchw44 { | |||||
//! Stride-1 specialization: accumulates one [oc_interval x ow_interval]
//! output tile with int8 dot-product (sdot) instructions.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        //! element distance between two filter rows
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        //! element distance between two 4-OC weight groups
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        //! per-4IC-group src stride; with the x4 step on ic_idx below this
        //! equals ih * iw * IC_PACK_SIZE bytes per group
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! number of 16-byte src vectors covering one output row plus the
        //! filter taps (each vector holds 4 packed 4-channel pixels)
        constexpr int NSRC = (ow_interval + filter_size - 1) / 4 + 1;
        //! number of 4-OC groups this kernel produces (1..3)
        constexpr int LOOP = oc_interval / 4;
        //! accumulators; first dim sized for the maximum LOOP of 3
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                int8x16_t src[1][4];
                int8x16_t weight[3];
                load_helper<NSRC, 0, SIMD_LEN, 1, Vld1q_s8>(src, i_src, 0);
                //! CALC_PART(step) processes filter column `step`: for each
                //! active 4-OC group, load its weight vector and dot it
                //! against the shifted src window via cal_helper.
                //! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step)                                           \
    switch (LOOP) {                                               \
        case 1:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            break;                                                \
        case 2:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, 0, step, 1>(res, src, weight);          \
            break;                                                \
        case 3:                                                   \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \
                                 filter_next_col * step);         \
            cal_helper<0, 0, step, 0>(res, src, weight);          \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \
                                 filter_next_col * step);         \
            cal_helper<1, 0, step, 1>(res, src, weight);          \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \
                                 filter_next_col * step);         \
            cal_helper<2, 0, step, 2>(res, src, weight);          \
            break;                                                \
        default:                                                  \
            break;                                                \
    }
                //! unroll over the FW filter columns; filter_size is a
                //! compile-time constant so only one case survives
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                //! advance to the next filter / input row
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply op (requant/activation) and write the tile back
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int filter_size> | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | |||||
| const int8_t* src, const int ih, const int iw, | |||||
| const int8_t* filter, const int32_t* bias, | |||||
| const int oh_size, const int oc, const int ic, | |||||
| const Op& op) { | |||||
| constexpr int FH = filter_size; | |||||
| constexpr int FW = filter_size; | |||||
| constexpr int IC_PACK_SIZE = 4; | |||||
| constexpr int OC_PACK_SIZE = 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int OC_BIG_INTERVAL = 12; | |||||
| constexpr int OC_MID_INTERVAL = 8; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #else | |||||
| constexpr int OC_BIG_INTERVAL = 4; | |||||
| constexpr int OC_MID_INTERVAL = 4; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #endif | |||||
| constexpr int OW_INTERVAL = 8; | |||||
| constexpr int SH = stride; | |||||
| const int dst_numbers_per_channel = oh * ow; | |||||
| const int ow_remain = ow % OW_INTERVAL; | |||||
| const int ow_end_idx = ow - ow_remain; | |||||
| const int oc_remain = | |||||
| oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 | |||||
| const int oc_end_idx = oc - oc_remain; | |||||
| const int dst_numbers_4channel_packed = | |||||
| dst_numbers_per_channel * OC_PACK_SIZE; | |||||
| using remain_fun = std::function<void( | |||||
| dst_type * dst, const int dst_step, const int8_t* src, const int ih, | |||||
| const int iw, const int8_t* filter, const int32_t* bias, | |||||
| const int ic, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_mid_oc_remain = nullptr; | |||||
| remain_fun kern_sma_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_BIG_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_mid_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_MID_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_sma_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_SMA_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, | |||||
| //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates | |||||
| //! [OW_INTERVAL, 1, OC_INTERVAL] each time | |||||
| for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_BIG_INTERVAL, OW_INTERVAL>:: | |||||
| impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| kern_big_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| #ifdef MEGDNN_AARCH64 | |||||
| //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 | |||||
| if (oc_remain) { | |||||
| int oc_idx = oc_end_idx; | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_MID_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } else { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_SMA_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| kern_mid_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } else { | |||||
| kern_sma_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
//! Explicit instantiations for stride 1: one kernel per (dst type, filter
//! size, bias mode, nonlinear op) combination.
#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op)         \
    template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
                                               Op, filter_size>(            \
            dst_type * dst, const int oh, const int ow, const int8_t* src,  \
            const int ih, const int iw, const int8_t* weight,               \
            const int32_t* bias, const int oh_size, const int oc,           \
            const int ic, const Op& op);
//! dt_int32 output pairs with NoneOp (raw accumulator, no requantize);
//! dt_int8 outputs cover the typecvt/relu/hswish requantizing ops.
#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i)                  \
    FOR_OP(stride, i, BiasMode::NO_BIAS)     \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
//! supported filter sizes: 2x2, 3x3, 5x5, 7x7
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(1)
//! NOTE(review): FOR_STRIDE/FOR_IC/FOR_NONLINEAR/FOR_REMAIN are never defined
//! in this file; the extra #undefs below are harmless boilerplate.
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
| } // namespace direct_dotprod_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,322 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace direct_dotprod_nchw44 { | |||||
//! Stride-2 specialization: accumulates one [oc_interval x ow_interval]
//! output tile with int8 dot-product (sdot) instructions. The src here is
//! expected pre-split into even/odd column planes (see the `offset` passed to
//! load_helper below) so that stride-2 access stays contiguous.
template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
          int filter_size, int oc_interval, int ow_interval>
struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
                          oc_interval, ow_interval> {
    static void impl(dst_type* dst, const int dst_step, const int8_t* src,
                     const int ih, const int iw, const int8_t* filter,
                     const int32_t* bias, const int ic, const Op& op) {
        constexpr int FH = filter_size;
        constexpr int FW = filter_size;
        //! element distance between two filter rows
        constexpr int filter_next_row =
                FW * OC_PACK_SIZE *
                IC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        //! element distance between two 4-OC weight groups
        const int filter_next_4oc =
                FH * FW * ic * OC_PACK_SIZE;  //! [OC/4, IC/4, FH, FW, 4OC, 4IC]
        const int src_next_ic = ih * iw;
        const int src_next_row = iw * IC_PACK_SIZE;
        //! number of 16-byte src vectors per even/odd plane for one row
        constexpr int NSRC = (ow_interval * 2 + filter_size - 3) / 8 + 1;
        //! number of 4-OC groups this kernel produces (1..3)
        constexpr int LOOP = oc_interval / 4;
        //! accumulators; first dim sized for the maximum LOOP of 3
        int32x4_t res[3][ow_interval];
        init_ocx_ow8<LOOP, bias_mode>(res, bias, OC_PACK_SIZE);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += IC_PACK_SIZE) {
            const int8_t* i_src = src + ic_idx * src_next_ic;
            const int8_t* i_filter = filter + ic_idx * FH * FW * OC_PACK_SIZE;
            for (int fh_idx = 0; fh_idx < FH; ++fh_idx) {
                //! src[0]/src[1]: even/odd column planes of the input row
                int8x16_t src[2][3];
                int8x16_t weight[3];
                //! distance between the two planes
                const int offset = megdnn::div_ceil(iw, 2) * IC_PACK_SIZE;
                load_helper<NSRC, 0, SIMD_LEN, 2, Vld1q_s8>(src, i_src, offset);
                //! CALC_PART(step) processes filter column `step`; the even
                //! columns come from plane step % 2, at position step / 2.
                //! do not use switch order 3,2,1 because it will slow the speed.
#define CALC_PART(step)                                               \
    switch (LOOP) {                                                   \
        case 1:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight);   \
            break;                                                    \
        case 2:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight);   \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +     \
                                 filter_next_col * step);             \
            cal_helper<1, step % 2, step / 2, 1>(res, src, weight);   \
            break;                                                    \
        case 3:                                                       \
            weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 +     \
                                 filter_next_col * step);             \
            cal_helper<0, step % 2, step / 2, 0>(res, src, weight);   \
            weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 +     \
                                 filter_next_col * step);             \
            cal_helper<1, step % 2, step / 2, 1>(res, src, weight);   \
            weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 +     \
                                 filter_next_col * step);             \
            cal_helper<2, step % 2, step / 2, 2>(res, src, weight);   \
            break;                                                    \
        default:                                                      \
            break;                                                    \
    }
                //! unroll over the FW filter columns; filter_size is a
                //! compile-time constant so only one case survives
                switch (filter_size) {
                    case 2:
                        UNROLL_CALL_RAW(2, CALC_PART);
                        break;
                    case 3:
                        UNROLL_CALL_RAW(3, CALC_PART);
                        break;
                    case 5:
                        UNROLL_CALL_RAW(5, CALC_PART);
                        break;
                    case 7:
                        UNROLL_CALL_RAW(7, CALC_PART);
                        break;
                    default:
                        break;
                }
#undef CALC_PART
                //! advance to the next filter / input row
                i_filter += filter_next_row;
                i_src += src_next_row;
            }
        }
        //! apply op (requant/activation) and write the tile back
        store_ocx_owx_remain_static<LOOP, ow_remain, Op>(res, op, dst,
                                                         dst_step);
    }
};
| template <typename dst_type, int stride, BiasMode bias_mode, typename Op, | |||||
| int filter_size> | |||||
| void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, | |||||
| const int8_t* src, const int ih, const int iw, | |||||
| const int8_t* filter, const int32_t* bias, | |||||
| const int oh_size, const int oc, const int ic, | |||||
| const Op& op) { | |||||
| constexpr int FH = filter_size; | |||||
| constexpr int FW = filter_size; | |||||
| constexpr int IC_PACK_SIZE = 4; | |||||
| constexpr int OC_PACK_SIZE = 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int OC_BIG_INTERVAL = 12; | |||||
| constexpr int OC_MID_INTERVAL = 8; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #else | |||||
| constexpr int OC_BIG_INTERVAL = 4; | |||||
| constexpr int OC_MID_INTERVAL = 4; | |||||
| constexpr int OC_SMA_INTERVAL = 4; | |||||
| #endif | |||||
| constexpr int OW_INTERVAL = 8; | |||||
| constexpr int SH = stride; | |||||
| const int dst_numbers_per_channel = oh * ow; | |||||
| const int ow_remain = ow % OW_INTERVAL; | |||||
| const int ow_end_idx = ow - ow_remain; | |||||
| const int oc_remain = | |||||
| oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 | |||||
| const int oc_end_idx = oc - oc_remain; | |||||
| const int dst_numbers_4channel_packed = | |||||
| dst_numbers_per_channel * OC_PACK_SIZE; | |||||
| using remain_fun = std::function<void( | |||||
| dst_type * dst, const int dst_step, const int8_t* src, const int ih, | |||||
| const int iw, const int8_t* filter, const int32_t* bias, | |||||
| const int ic, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_mid_oc_remain = nullptr; | |||||
| remain_fun kern_sma_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_BIG_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_mid_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_MID_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| kern_sma_oc_remain = \ | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, step, \ | |||||
| filter_size, OC_SMA_INTERVAL, \ | |||||
| OW_INTERVAL>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| //! filter layout is [OC/4, IC/4, FH, FW, 4OC, 4IC] | |||||
| //! cut [oc, oh, ow] into [oc/OC_INTERVAL, 1, ow/OW_INTERVAL, OW_INTERVAL, | |||||
| //! oh, OC_INTERVAL] to calculate KernNeonSdotNCHW44 calculates | |||||
| //! [OW_INTERVAL, 1, OC_INTERVAL] each time | |||||
| for (int oc_idx = 0; oc_idx < oc_end_idx; oc_idx += OC_BIG_INTERVAL) { | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| KernNeonSdotNCHW44<dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_BIG_INTERVAL, OW_INTERVAL>:: | |||||
| impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| kern_big_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| #ifdef MEGDNN_AARCH64 | |||||
| //! oc_remain must be 4 or 8 on aarch64 and must be 0 on aarch32 | |||||
| if (oc_remain) { | |||||
| int oc_idx = oc_end_idx; | |||||
| const int filter_offset_in_element = oc_idx * ic * FH * FW; | |||||
| for (int oh_idx = 0; oh_idx < oh_size; ++oh_idx) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_MID_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } else { | |||||
| KernNeonSdotNCHW44< | |||||
| dst_type, stride, bias_mode, Op, OW_INTERVAL, | |||||
| filter_size, OC_SMA_INTERVAL, | |||||
| OW_INTERVAL>::impl(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, | |||||
| iw, | |||||
| filter + | |||||
| filter_offset_in_element, | |||||
| bias + bias_offset_in_element, | |||||
| ic, op); | |||||
| } | |||||
| } | |||||
| if (ow_remain) { | |||||
| const int src_offset_in_element = | |||||
| (oh_idx * SH * iw + ow_end_idx) * IC_PACK_SIZE; | |||||
| const int dst_offset_in_element = | |||||
| oc_idx * dst_numbers_per_channel + | |||||
| (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; | |||||
| const int bias_offset_in_element = oc_idx; | |||||
| if (oc_remain == 8) { | |||||
| kern_mid_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } else { | |||||
| kern_sma_oc_remain(dst + dst_offset_in_element, | |||||
| dst_numbers_4channel_packed, | |||||
| src + src_offset_in_element, ih, iw, | |||||
| filter + filter_offset_in_element, | |||||
| bias + bias_offset_in_element, ic, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
//! Explicit instantiations for stride 2: one kernel per (dst type, filter
//! size, bias mode, nonlinear op) combination.
#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op)         \
    template void conv_direct_sdot_int8_nchw44<dst_type, stride, bias_mode, \
                                               Op, filter_size>(            \
            dst_type * dst, const int oh, const int ow, const int8_t* src,  \
            const int ih, const int iw, const int8_t* weight,               \
            const int32_t* bias, const int oh_size, const int oc,           \
            const int ic, const Op& op);
//! dt_int32 output pairs with NoneOp (raw accumulator, no requantize);
//! dt_int8 outputs cover the typecvt/relu/hswish requantizing ops.
#define FOR_OP(stride, i, bias_mode)                          \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANTIATION(dt_int32, stride, i, bias_mode,             \
                  NoneOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANTIATION(dt_int8, stride, i, bias_mode,              \
                  HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define FOR_BIAS(stride, i)                  \
    FOR_OP(stride, i, BiasMode::NO_BIAS)     \
    FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)
//! supported filter sizes: 2x2, 3x3, 5x5, 7x7
#define FOR_FILTER(stride) \
    FOR_BIAS(stride, 2)    \
    FOR_BIAS(stride, 3)    \
    FOR_BIAS(stride, 5)    \
    FOR_BIAS(stride, 7)
FOR_FILTER(2)
//! NOTE(review): FOR_STRIDE/FOR_IC/FOR_NONLINEAR/FOR_REMAIN are never defined
//! in this file; the extra #undefs below are harmless boilerplate.
#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION
| } // namespace direct_dotprod_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,448 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace dot_direct_nchw_nchw44 { | |||||
//! Specialization updating TWO accumulator rows (c[0] and c[1]) over 8 output
//! elements: each step dots lane (src_idx + step) % 4 of src vector
//! (src_idx + step) / 4 with the selected weight, via Func.
//! NOTE(review): the meaning of the trailing `1` template argument is defined
//! in the primary template elsewhere — presumably the stride; confirm there.
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                           \
    c[0][step] = Func::template impl<(src_idx + step) % 4>(                \
            c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \
    c[1][step] = Func::template impl<(src_idx + step) % 4>(                \
            c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]);
        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};
//! Specialization updating a SINGLE accumulator row (c[0]) over 8 output
//! elements — same lane arithmetic as the two-row variant above, but only
//! weight[0] is consumed.
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                            \
    c[0][step] = Func::template impl<(src_idx + step) % 4>( \
            c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]);
        UNROLL_CALL_RAW(8, cb);
#undef cb
    }
};
| ////////////////////stride 1/////////////////// | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 2; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 3; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 3; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 2 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 2 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
//! 5-tap filter, stride 1.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 5;
        //! NOTE(review): width is 8 while the filter is 5-tap — the packed
        //! weight layout apparently stores 8 columns per row (two 16-byte
        //! vectors, see `step * 2 * simd_len` below); confirm against the
        //! weight-packing routine.
        constexpr int filter_width = 8;
        constexpr int src_reg = 3;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;
        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        //! accumulators, seeded with bias according to bias_mode
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
            //! cb(step) handles filter row `step`: load that input row, load
            //! the row's two weight vectors, then issue two dot passes
            //! (src offsets 0 and 4 against weight vectors 0 and 1).
#define cb(step)                                                         \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                      \
            src, src_ptr + step * iw * pack_iw_len, 0);                  \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(               \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);     \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,   \
                                                               weight);  \
    cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        //! requantize/activate via op and write the tile out
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * \brief int8 direct conv kernel: 7x7 filter, stride 1, src in packed nchw
 *        (pack_iw_len == 4 bytes per input pixel), dst in nchw44, using NEON
 *        dot-product helpers.
 *
 * The 7-wide filter row is zero-padded to filter_width == 8 and processed as
 * two 4-lane dot products per row.  Structure is identical to the 5x5
 * specialization except for the 7-iteration row unroll.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
                                1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_hight = 7;
        // row padded from 7 to 8 so it splits into two 4-element dots
        constexpr int filter_width = 8;
        constexpr int src_reg = 3;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        // bytes occupied by one input pixel after the stride-1 packing pass
        constexpr int pack_iw_len = 4;
        constexpr int simd_len = 16;
        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        // distance between consecutive oc4 groups in the packed filter
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // cb(step): one filter row -- load srcs/weights, accumulate the
            // two filter halves (second half starts at src lane 4).
#define cb(step)                                                        \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                     \
            src, src_ptr + step * iw * pack_iw_len, 0);                 \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(              \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);    \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,  \
                                                               weight); \
    cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(7, cb);
#undef cb
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        // apply activation `op`, quantize and store the accumulators
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * \brief stride-1 source packing for the int8 nchw -> nchw44 dot kernels.
 *
 * Zero-pads each plane (pw columns left/right, pad_top/pad_bottom rows) and
 * expands every output pixel i to 4 consecutive bytes
 * {row[i], row[i+1], row[i+2], row[i+3]} (pack_iw_len == 4) -- the sliding
 * window layout consumed directly by the 4-lane dot-product kernels.
 * temp_ptr is a caller-provided scratch row holding one horizontally padded
 * input row (iw + 2 * pw bytes of meaningful data).
 *
 * NOTE(review): both the vector path (vld1q_s8 at iw_idx + 12 reads up to
 * 11 bytes past iw_with_pad) and the scalar tail (temp_ptr + iw_idx + 3)
 * read slightly beyond the padded row; the scratch buffer is presumably
 * over-allocated -- confirm at the call site.
 */
template <>
void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
                                      const int8_t* sptr_origin, const int,
                                      const int pw, const int, const int ih,
                                      const int iw, const int iw2,
                                      const int pad_top, const int pad_bottom,
                                      const int ic, const int ic_stride,
                                      int8_t* temp_ptr) {
    // tbl gather pattern: 16 output bytes = 4 pixels x 4-byte sliding windows
    static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                      2, 3, 4, 5, 3, 4, 5, 6};
    uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
    constexpr int iw_step = 16;
    constexpr int pack_iw_len = 4;
    const int iw_with_pad = iw + 2 * pw;
    const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
    rep(ic_idx, ic) {
        const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
        // zero the whole padded plane up front; valid rows are filled below
        memset(sptr_base, 0,
               sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
                       pack_iw_len);
        sptr_base += iw2 * pad_top * pack_iw_len;
        rep(ih_idx, ih) {
            // build one horizontally padded row in scratch space
            memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
            memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
            // vector path: 16 input pixels -> 64 packed bytes per step
            for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
                int8x16_t src[4];
                int8x16_t dst[4];
                src[0] = vld1q_s8(temp_ptr + iw_idx);
                src[1] = vld1q_s8(temp_ptr + iw_idx + 4);
                src[2] = vld1q_s8(temp_ptr + iw_idx + 8);
                src[3] = vld1q_s8(temp_ptr + iw_idx + 12);
                dst[0] = vqtbl1q_s8(src[0], tbl_idx);
                dst[1] = vqtbl1q_s8(src[1], tbl_idx);
                dst[2] = vqtbl1q_s8(src[2], tbl_idx);
                dst[3] = vqtbl1q_s8(src[3], tbl_idx);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
            }
            // scalar tail for the last iw_with_pad % 16 pixels
            for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
                *(sptr_base + iw_idx * pack_iw_len + 0) =
                        *(temp_ptr + iw_idx + 0);
                *(sptr_base + iw_idx * pack_iw_len + 1) =
                        *(temp_ptr + iw_idx + 1);
                *(sptr_base + iw_idx * pack_iw_len + 2) =
                        *(temp_ptr + iw_idx + 2);
                *(sptr_base + iw_idx * pack_iw_len + 3) =
                        *(temp_ptr + iw_idx + 3);
            }
            sptr_base += iw2 * pack_iw_len;
            sptr += iw;
        }
        sptr_base += iw2 * pad_bottom * pack_iw_len;
    }
}
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const int oc, const int ic, | |||||
| const int ih, const int iw, const int oh, | |||||
| const int oh_block, const int ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = (filter_size + 3) / 4 * 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int big_oc_step = 8; | |||||
| #else | |||||
| constexpr int big_oc_step = 4; | |||||
| #endif | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = stride; | |||||
| constexpr int stride_w = stride; | |||||
| constexpr int pack_iw_len = stride == 2 ? 1 : 4; | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = | |||||
| std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, ow_step, stride>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, ow_step, stride>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
// Explicit instantiation of conv_direct_int8_nchw_nchw44_dot for one
// (stride, filter_size, bias_mode, Op) combination; the kernel bodies live
// only in this translation unit.
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op)               \
    template void                                                          \
    conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>(  \
            const int8_t* src, const int8_t* filter, const int32_t* bias,  \
            int32_t* temp, int8_t* dst, const int oc, const int ic,        \
            const int ih, const int iw, const int oh, const int oh_block,  \
            const int ow, const Op& op);
// Expand over the three supported activations (identity / relu / h-swish),
// all quantizing dt_qint32 accumulators to dt_qint8.
#define GET_OP_PARAM(stride, filter, bias_mode)      \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// Expand over the two supported bias modes.
#define GET_BIAS_MODE_PARAM(stride, filter)         \
    GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
// Expand over the supported filter sizes.
#define DISPATCH_CONV_KERN(stride)  \
    GET_BIAS_MODE_PARAM(stride, 2)  \
    GET_BIAS_MODE_PARAM(stride, 3)  \
    GET_BIAS_MODE_PARAM(stride, 5)  \
    GET_BIAS_MODE_PARAM(stride, 7)
// This TU provides the stride-1 instantiations.
DISPATCH_CONV_KERN(1);
| } // namespace dot_direct_nchw_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,437 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #if __ARM_FEATURE_DOTPROD | |||||
| #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace dot_direct_nchw_nchw44 { | |||||
/**
 * \brief ShiftCalHelper specialization: two oc4 groups (c_dim == 2),
 *        ow_block == 8, stride == 2.
 *
 * Each of the 4 unrolled steps issues Func (a lane-indexed dot product)
 * selecting lane (src_idx + step) % 4 of src register (src_idx + step) / 4.
 * Even output columns (step * 2) consume src[0] and odd columns src[1] --
 * src[0]/src[1] hold the two source streams loaded per row by the stride-2
 * kernels' load_helper<..., 2, ...> calls (presumably the two stride
 * phases).  Both weight groups weight[0]/weight[1] are applied per step.
 */
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                  \
    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>(   \
            c[0][step * 2], weight[0][weight_idx],                \
            src[0][(src_idx + step) / 4]);                        \
    c[1][step * 2] = Func::template impl<(src_idx + step) % 4>(   \
            c[1][step * 2], weight[1][weight_idx],                \
            src[0][(src_idx + step) / 4]);                        \
    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
            c[0][step * 2 + 1], weight[0][weight_idx],            \
            src[1][(src_idx + step) / 4]);                        \
    c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
            c[1][step * 2 + 1], weight[1][weight_idx],            \
            src[1][(src_idx + step) / 4]);
        UNROLL_CALL_RAW(4, cb);
#undef cb
    }
};
/**
 * \brief ShiftCalHelper specialization: single oc4 group (c_dim == 1),
 *        ow_block == 8, stride == 2.
 *
 * Same accumulation pattern as the c_dim == 2 specialization, but only
 * weight[0] exists: even output columns consume src[0], odd columns src[1],
 * with the dot lane chosen by (src_idx + step) % 4.
 */
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
          typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
    static void impl(T& c, T2& src, T3& weight) {
#define cb(step)                                                  \
    c[0][step * 2] = Func::template impl<(src_idx + step) % 4>(   \
            c[0][step * 2], weight[0][weight_idx],                \
            src[0][(src_idx + step) / 4]);                        \
    c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \
            c[0][step * 2 + 1], weight[0][weight_idx],            \
            src[1][(src_idx + step) / 4]);
        UNROLL_CALL_RAW(4, cb);
#undef cb
    }
};
/**
 * \brief int8 direct conv kernel: 2x2 filter (row padded to width 4),
 *        stride 2, src in plain nchw (pack_iw_len == 1), dst in nchw44,
 *        NEON dot-product helpers.
 *
 * The two filter rows are processed explicitly; each row loads two
 * interleaved source streams (load_helper<..., 2, ...>) and one weight
 * vector per oc group, offset by row * simd_len within the packed filter.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
                                2> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 2;
        constexpr int filter_hight = 2;
        // row padded from 2 to 4 to fill one 4-element dot
        constexpr int filter_width = 4;
        constexpr int weight_reg = 1;
        constexpr int src_reg = 1;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        // stride-2 source is not pre-expanded
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;
        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        // distance between consecutive oc4 groups in the packed filter
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 0 * iw, stride);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 1 * iw, stride);
            load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        // apply activation `op`, quantize and store the accumulators
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * \brief int8 direct conv kernel: 3x3 filter (row padded to width 4),
 *        stride 2, src in plain nchw (pack_iw_len == 1), dst in nchw44,
 *        NEON dot-product helpers.
 *
 * Identical structure to the 2x2 stride-2 specialization, with a third
 * filter row whose weights sit at offset 2 * simd_len.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
                                2> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 2;
        constexpr int filter_hight = 3;
        // row padded from 3 to 4 to fill one 4-element dot
        constexpr int filter_width = 4;
        constexpr int weight_reg = 1;
        constexpr int src_reg = 1;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        // stride-2 source is not pre-expanded
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;
        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        // distance between consecutive oc4 groups in the packed filter
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // row 0
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 0 * iw, stride);
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 1
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 1 * iw, stride);
            load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            // row 2
            load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(
                    src, src_ptr + 2 * iw, stride);
            load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>(
                    weight, weight_ptr, ld_weight_oc);
            cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
                                                                       weight);
            src_ptr += ic_stride;
            weight_ptr += filter_hight * filter_width * oc_step;
        }
        // apply activation `op`, quantize and store the accumulators
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
/**
 * \brief int8 direct conv kernel: 5x5 filter (row padded to width 8),
 *        stride 2, src in plain nchw (pack_iw_len == 1), dst in nchw44,
 *        NEON dot-product helpers.
 *
 * Each filter row is covered by two 4-lane dot products; cb(step) handles
 * one row via an unrolled macro instead of the explicit per-row code used
 * by the smaller filters.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
          int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
                                2> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 2;
        constexpr int filter_hight = 5;
        // row padded from 5 to 8 so it splits into two 4-element dots
        constexpr int filter_width = 8;
        constexpr int src_reg = 2;
        constexpr int weight_reg = 2;
        constexpr int oc_step = 4;
        constexpr int ic_step = 1;
        // stride-2 source is not pre-expanded
        constexpr int pack_iw_len = 1;
        constexpr int simd_len = 16;
        const int ld_bias = oc_step;
        const int ic_stride = ih * iw * pack_iw_len;
        // distance between consecutive oc4 groups in the packed filter
        const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
            int8x16_t src[2][src_reg];
            int8x16_t weight[c_dim][weight_reg];
            // cb(step): filter row `step` -- load two source streams and two
            // weight vectors per oc group, then accumulate both row halves.
#define cb(step)                                                            \
    load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \
                                                   stride);                 \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                  \
            weight, weight_ptr + step * 2 * simd_len, ld_weight_oc);        \
    cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,      \
                                                               weight);     \
    cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight);
            UNROLL_CALL_RAW(5, cb);
#undef cb
            src_ptr += ic_stride;
            // 5 * 32 == filter_hight * filter_width * oc_step
            weight_ptr += 5 * 32;
        }
        // apply activation `op`, quantize and store the accumulators
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
| /** | |||||
| * oc = 8, ow = 8 | |||||
| * dot 4 element, pad last filter and do twice dot every row filter, filter like | |||||
| * below | |||||
| * -------------------------- | |||||
| * |x, x, x, x,| x, x, x, 0 | | |||||
| * -------------------------- | |||||
| **/ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||||
| 2> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 2; | |||||
| constexpr int filter_hight = 7; | |||||
| constexpr int filter_width = 8; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 1; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[2][src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| #define cb(step) \ | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \ | |||||
| stride); \ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||||
| weight); \ | |||||
| cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, cb); | |||||
| #undef cb | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += 7 * 32; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <> | |||||
| void pack_src_int8_nchw_nchw44_dot<2>( | |||||
| int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw, | |||||
| const int, const int ih, const int iw, const int iw2, const int pad_top, | |||||
| const int pad_bottom, const int ic, const int ic_stride, int8_t*) { | |||||
| constexpr int ic_step = 1; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const int8_t* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, | |||||
| sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| memcpy(sptr_base + pw * ic_step, sptr, | |||||
| sizeof(int8_t) * iw * ic_step); | |||||
| sptr_base += iw2 * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| } | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | |||||
| void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const int oc, const int ic, | |||||
| const int ih, const int iw, const int oh, | |||||
| const int oh_block, const int ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = (filter_size + 3) / 4 * 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int big_oc_step = 8; | |||||
| #else | |||||
| constexpr int big_oc_step = 4; | |||||
| #endif | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = stride; | |||||
| constexpr int stride_w = stride; | |||||
| constexpr int pack_iw_len = stride == 2 ? 1 : 4; | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = | |||||
| std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, ow_step, stride>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, ow_step, stride>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
// Explicit instantiation of conv_direct_int8_nchw_nchw44_dot for one
// (stride, filter_size, bias_mode, Op) combination; the kernel bodies live
// only in this translation unit.
#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op)               \
    template void                                                          \
    conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter_size, stride>(  \
            const int8_t* src, const int8_t* filter, const int32_t* bias,  \
            int32_t* temp, int8_t* dst, const int oc, const int ic,        \
            const int ih, const int iw, const int oh, const int oh_block,  \
            const int ow, const Op& op);
// Expand over the three supported activations (identity / relu / h-swish),
// all quantizing dt_qint32 accumulators to dt_qint8.
#define GET_OP_PARAM(stride, filter, bias_mode)      \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    DO_CONV_KERN_FUN(stride, filter, bias_mode,      \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
// Expand over the two supported bias modes.
#define GET_BIAS_MODE_PARAM(stride, filter)         \
    GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
// Expand over the supported filter sizes.
#define DISPATCH_CONV_KERN(stride)  \
    GET_BIAS_MODE_PARAM(stride, 2)  \
    GET_BIAS_MODE_PARAM(stride, 3)  \
    GET_BIAS_MODE_PARAM(stride, 5)  \
    GET_BIAS_MODE_PARAM(stride, 7)
// This TU provides the stride-2 instantiations.
DISPATCH_CONV_KERN(2);
| } // namespace dot_direct_nchw_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,743 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace { | |||||
/**
 * Direct-conv 2x2 stride-1 micro kernel, NCHW44 int8: computes 8 output
 * columns (ow8) for 8 output channels (two oc4 groups, c_dim == 2) at once.
 *
 * \param src_ptr    packed input, 4 input channels interleaved per 4 pixels
 *                   (pack_iw_len == 4, so one 16-byte vector == one column)
 * \param weight_ptr packed filter, 16 bytes (4ic x 4oc) per filter tap
 * \param bias_ptr   per-channel int32 bias, read by init_ocx_ow8
 * \param dst_ptr    output base; the second oc4 group is written at
 *                   dst_ptr + ld_dst_oc
 * \param remain_w   number of valid output columns for the tail store
 *                   (template param, consumed by store_ocx_ow8_remain_static_dt)
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    // one filter tap holds 4ic x 4oc int8 values == 16 bytes
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    // c[g][ow]: int32 accumulators, g = oc4 group (0/1), ow = output column
    int32x4_t c[2][8];
    // weight[g][fw_idx]: one 16-byte tap per filter column, per oc4 group
    int8x16_t weight[2][2];
    // 8 output columns at stride 1 with fw==2 need 9 input columns
    int8x16_t src[8 + 1];
    // scratch int16 registers for the widening multiply inside vdotq_s32_h
    int16x8_t temp_c[4];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // c[g][ow] += dot(weight[g][k], src[ow + k]) for k in {0, 1};
            // statements are interleaved across the 4 temp registers to hide
            // multiply latency — do not reorder.
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[1], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[2], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[2], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[3], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[3], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[4], c[1][3], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[4], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[5], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[5], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[6], c[1][5], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[6], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[7], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[7], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[8], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // apply bias mode / activation op and write out; remain_w limits how many
    // of the 8 columns are stored
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
/**
 * Direct-conv 2x2 stride-1 micro kernel, NCHW44 int8: computes 8 output
 * columns (ow8) for a single oc4 group (c_dim == 1). Same data layout and
 * scheduling idea as the oc8 variant, with half the accumulators.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int oc_step = 4;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    // one filter tap holds 4ic x 4oc int8 values == 16 bytes
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    // c[0][ow]: int32 accumulators for the single oc4 group
    int32x4_t c[1][8];
    int8x16_t weight[1][2];
    // 8 output columns at stride 1 with fw==2 need 9 input columns
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            // c[0][ow] += dot(weight[0][k], src[ow + k]) for k in {0, 1};
            // interleaved over two temp registers — do not reorder.
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[1], c[0][1], temp_c[1]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[2], c[0][1], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[2], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[3], c[0][3], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[3], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[4], c[0][3], temp_c[1]);
            c[0][4] = vdotq_s32_h(weight[0][0], src[4], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[5], c[0][5], temp_c[1]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[5], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[6], c[0][5], temp_c[1]);
            c[0][6] = vdotq_s32_h(weight[0][0], src[6], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[7], c[0][7], temp_c[1]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[7], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[8], c[0][7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // apply bias mode / activation op and write out
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
/**
 * Primary template for the stride-1 odd-filter kernels; only declared here.
 * Specializations below implement filter_size 3/5/7. remain_w selects the
 * tail-store variant, c_dim is the number of oc4 groups.
 */
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
struct KerNeonDirectStride1Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc);
};
/**
dot-like impl: dot-product 4 input channels into 1 output channel, and
accumulate into c<ow, oc>
example: (format like weight<oc, ic>)
packed weight
low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3>
---------------------------------------------------------------------
high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0>
dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0>
**/
| //! TODO: can try oh = 2 impl, oc = 8 impl | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int c_dim, | |||||
| typename DstType> | |||||
| struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, | |||||
| int iw, const Op& op, int ld_dst_oc) { | |||||
| constexpr int filter_size = 3; | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 4; | |||||
| constexpr int loop_ic_step = 4; | |||||
| constexpr int ld_weight_ic4 = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| int32x4_t c[c_dim][8]; | |||||
| int8x16_t weight[3]; | |||||
| int8x16_t src[8 + 2]; | |||||
| int16x8_t temp_c[2]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
| const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
| fh_idx * iw * ic_step * pack_iw_len; | |||||
| src[0] = vld1q_s8(src_ic_0_3); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); | |||||
| src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); | |||||
| // oc == 0 | |||||
| const int8_t* read_weight_ptr = | |||||
| weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
| weight[0] = vld1q_s8(read_weight_ptr); | |||||
| weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
| weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
| c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); | |||||
| } | |||||
| weight_ptr += fh * fw * ld_weight_ic4; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int c_dim, | |||||
| typename DstType> | |||||
| struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, | |||||
| int iw, const Op& op, int ld_dst_oc) { | |||||
| constexpr int filter_size = 5; | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 4; | |||||
| constexpr int loop_ic_step = 4; | |||||
| constexpr int ld_weight_ic4 = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| int32x4_t c[c_dim][8]; | |||||
| int8x16_t weight[5]; | |||||
| int8x16_t src[8 + 2]; | |||||
| int16x8_t temp_c[2]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
| const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
| fh_idx * iw * ic_step * pack_iw_len; | |||||
| src[0] = vld1q_s8(src_ic_0_3); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); | |||||
| src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); | |||||
| // oc == 0 | |||||
| const int8_t* read_weight_ptr = | |||||
| weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
| weight[0] = vld1q_s8(read_weight_ptr); | |||||
| weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
| weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
| weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
| weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
| c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); | |||||
| src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); | |||||
| c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); | |||||
| } | |||||
| weight_ptr += fh * fw * ld_weight_ic4; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int c_dim, | |||||
| typename DstType> | |||||
| struct KerNeonDirectStride1Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, | |||||
| int iw, const Op& op, int ld_dst_oc) { | |||||
| constexpr int filter_size = 7; | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = filter_size; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 4; | |||||
| constexpr int loop_ic_step = 4; | |||||
| constexpr int ld_weight_ic4 = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| int32x4_t c[c_dim][8]; | |||||
| int8x16_t weight[7]; | |||||
| int8x16_t src[8 + 2]; | |||||
| int16x8_t temp_c[2]; | |||||
| init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { | |||||
| const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + | |||||
| fh_idx * iw * ic_step * pack_iw_len; | |||||
| src[0] = vld1q_s8(src_ic_0_3); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 16)); | |||||
| src[2] = vld1q_s8((src_ic_0_3 + 2 * 16)); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 3 * 16)); | |||||
| src[4] = vld1q_s8((src_ic_0_3 + 4 * 16)); | |||||
| src[5] = vld1q_s8((src_ic_0_3 + 5 * 16)); | |||||
| src[6] = vld1q_s8((src_ic_0_3 + 6 * 16)); | |||||
| src[7] = vld1q_s8((src_ic_0_3 + 7 * 16)); | |||||
| src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); | |||||
| src[9] = vld1q_s8((src_ic_0_3 + 9 * 16)); | |||||
| // oc == 0 | |||||
| const int8_t* read_weight_ptr = | |||||
| weight_ptr + fh_idx * fw * ld_weight_ic4; | |||||
| weight[0] = vld1q_s8(read_weight_ptr); | |||||
| weight[1] = vld1q_s8(read_weight_ptr + 16); | |||||
| weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); | |||||
| weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); | |||||
| weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); | |||||
| weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); | |||||
| weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); | |||||
| c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[0], src[1], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[1], src[2], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[2], src[3], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[3], src[4], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[4], src[5], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[5], src[6], c[0][1], temp_c[1]); | |||||
| c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]); | |||||
| c[0][1] = vdotq_s32_h(weight[6], src[7], c[0][1], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[0], src[2], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[0], src[3], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[1], src[3], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[1], src[4], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[2], src[4], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[2], src[5], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[3], src[5], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[3], src[6], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[4], src[6], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[4], src[7], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[5], src[7], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[5], src[8], c[0][3], temp_c[1]); | |||||
| c[0][2] = vdotq_s32_h(weight[6], src[8], c[0][2], temp_c[0]); | |||||
| c[0][3] = vdotq_s32_h(weight[6], src[9], c[0][3], temp_c[1]); | |||||
| src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); | |||||
| src[1] = vld1q_s8((src_ic_0_3 + 11 * 16)); | |||||
| c[0][4] = vdotq_s32_h(weight[0], src[4], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[0], src[5], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[1], src[5], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[1], src[6], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[2], src[6], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[2], src[7], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[3], src[7], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[3], src[8], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[4], src[8], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[4], src[9], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[5], src[9], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[5], src[0], c[0][5], temp_c[1]); | |||||
| c[0][4] = vdotq_s32_h(weight[6], src[0], c[0][4], temp_c[0]); | |||||
| c[0][5] = vdotq_s32_h(weight[6], src[1], c[0][5], temp_c[1]); | |||||
| src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); | |||||
| src[3] = vld1q_s8((src_ic_0_3 + 13 * 16)); | |||||
| c[0][6] = vdotq_s32_h(weight[0], src[6], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[0], src[7], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[1], src[7], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[1], src[8], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[2], src[8], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[2], src[9], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[3], src[9], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[3], src[0], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[4], src[0], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[4], src[1], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[5], src[1], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[5], src[2], c[0][7], temp_c[1]); | |||||
| c[0][6] = vdotq_s32_h(weight[6], src[2], c[0][6], temp_c[0]); | |||||
| c[0][7] = vdotq_s32_h(weight[6], src[3], c[0][7], temp_c[1]); | |||||
| } | |||||
| weight_ptr += fh * fw * ld_weight_ic4; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<1, remain_w, Op, DstType*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
/**
 * Driver for the 2x2 stride-1 NCHW44 int8 direct convolution.
 *
 * Tiles the output into 8-wide column strips; processes 8 output channels
 * per iteration via the oc8 kernel, then a trailing oc4 group (if any) via
 * the oc4 kernel. Column remainders use a kernel instantiated with the
 * matching remain_w template argument, selected once up front.
 * `temp` is unused by this stride-1 path.
 */
template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src,
                                         const int8_t* filter,
                                         const int32_t* bias, int32_t* temp,
                                         DstType* dst, const size_t oc,
                                         const size_t ic, const size_t ih,
                                         const size_t iw, const size_t oh,
                                         const size_t ow, const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;
    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_oc = oh * ow * oc_step;
    using remain_fun = std::function<void(
            const int8_t* src_ptr, const int8_t* weight_ptr,
            const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
            int ld_dst_oc, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_small_oc_remain = nullptr;
    // bind the column-remainder kernels once; remain_w is a template
    // parameter, so each possible remainder (0..7) needs its own instantiation
    switch (ow_remain) {
#define cb(step)                                                 \
    case step:                                                   \
        kern_big_oc_remain =                                     \
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, step,  \
                                                 filter_size, 2, DstType>; \
        kern_small_oc_remain =                                   \
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, step,  \
                                                 filter_size, 1, DstType>; \
        break;
        UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %zu for kern", ow_remain);
    }
#undef cb
    // main pass: 8 output channels (two oc4 groups) per iteration
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                // remain_w == 0 here: full 8-column store
                // (presumably store_ocx_ow8_remain_static_dt treats 0 as
                // "store all" — confirm against intrinsic_helper.h)
                ker_neon_dirctconv_2x2s1_oc8_ow8<bias_mode, Op, 0, filter_size,
                                                 2, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_big_oc_remain(src + src_offset, filter + weight_offset,
                                   bias + oc_idx, dst + dst_offset, ic, ih, iw,
                                   ld_oc, op);
            }
        }
    }
    // tail pass: remaining oc4 group (oc % 8 != 0)
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s1_oc4_ow8<bias_mode, Op, 0, filter_size,
                                                 1, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, ld_oc, op);
            }
        }
    }
}
/**
 * Driver for the odd-filter (3/5/7) stride-1 NCHW44 int8 direct convolution.
 *
 * Tiles the output into 8-wide column strips and iterates 4 output channels
 * (one oc4 group) at a time, dispatching to the matching
 * KerNeonDirectStride1Int8 specialization. Column remainders use the
 * specialization with remain_w == ow_remain, selected once up front.
 * `temp` is unused by this stride-1 path.
 */
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride1_int8_nchw44_kern(const int8_t* src,
                                          const int8_t* filter,
                                          const int32_t* bias, int32_t* temp,
                                          DstType* dst, const size_t oc,
                                          const size_t ic, const size_t ih,
                                          const size_t iw, const size_t oh,
                                          const size_t ow, const Op& op) {
    MEGDNN_MARK_USED_VAR(temp);
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr int pack_iw_len = 4;
    const size_t img_stride = oh * ow;
    const int ld_dst_oc = oh * ow * oc_step;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    using remain_fun = std::function<void(
            const int8_t* src_ptr, const int8_t* weight_ptr,
            const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
            const Op& op, int ld_dst_oc)>;
    remain_fun kern_small_oc_remain = nullptr;
    // bind the column-remainder kernel once; remain_w is a template
    // parameter, so each possible remainder (0..7) needs its own instantiation
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_small_oc_remain =                                            \
                KerNeonDirectStride1Int8<bias_mode, Op, step, filter_size, 1, \
                                         DstType>::impl;                  \
        break;
        UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %zu for kern", ow_remain);
    }
#undef cb
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * iw + ow_idx) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                // full strip: remain_w == ow_step (all 8 columns stored)
                KerNeonDirectStride1Int8<bias_mode, Op, ow_step, filter_size, 1,
                                         DstType>::impl(src + src_offset,
                                                        filter + weight_offset,
                                                        bias + oc_idx,
                                                        dst + dst_offset, ic,
                                                        ih, iw, op, ld_dst_oc);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * iw + ow_end) * ic_step * pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, op, ld_dst_oc);
            }
        }
    }
}
| } // namespace | |||||
namespace int8_direct_nchw44 {
// Public stride-1 entry point for odd filter sizes (3/5/7); forwards to the
// odd-filter driver above.
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int32_t* bias, int32_t* temp, DstType* dst,
                     const size_t oc, const size_t ic, const size_t ih,
                     const size_t iw, const size_t oh, const size_t ow,
                     const Op& op) {
        conv_direct_stride1_int8_nchw44_kern<bias_mode, Op, filter_size,
                                             DstType>(
                src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
    }
};
// Specialization for filter_size == 2, which has its own oc8/oc4 driver.
template <BiasMode bias_mode, typename Op, typename DstType>
struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 1> {
    static void impl(const int8_t* src, const int8_t* filter,
                     const int32_t* bias, int32_t* temp, DstType* dst,
                     const size_t oc, const size_t ic, const size_t ih,
                     const size_t iw, const size_t oh, const size_t ow,
                     const Op& op) {
        conv_direct_stride1_2x2_int8_nchw44<bias_mode, Op, DstType>(
                src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op);
    }
};
// Explicitly instantiate every (filter, bias mode, activation op, dst type)
// combination used by the stride-1 dispatcher, so this TU emits the code.
#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \
    template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
                                               DstType, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode)           \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                     \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                     \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                     \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)
#define GET_BIAS_MODE_PARAM(stride, filter)        \
    GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride)  \
    GET_BIAS_MODE_PARAM(stride, 2) \
    GET_BIAS_MODE_PARAM(stride, 3) \
    GET_BIAS_MODE_PARAM(stride, 5) \
    GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(1);
}  // namespace int8_direct_nchw44
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,778 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_nchw44_kern.h" | |||||
| #include "src/arm_common/conv_bias/intrinsic_helper.h" | |||||
| #include "src/arm_common/elemwise_op.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/fallback/conv_bias/common.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace { | |||||
// Stride-2 int8 NCHW44 direct-conv micro-kernel computing one output row of
// up to 8 columns for one oc4 group. Declared here; the filter_size = 3/5/7
// specializations below provide the implementations. `remain_w` is the
// number of valid output columns (< 8) for the row tail, `ld_dst_oc` the
// dst stride between oc4 groups.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
struct KerNeonDirectStride2Int8 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc);
};
// Stride-2 2x2 micro-kernel: 8 output channels (two oc4 groups) x 8 output
// columns of one output row. Each packed input pixel occupies
// ic_step * pack_iw_len = 16 int8 values, i.e. exactly one int8x16 load, so
// consecutive loads step one input pixel and stride-2 outputs consume every
// second pixel (src[0]/src[2]/src[4]... feed columns 0/1/2...).
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int ic_step = 4;
    constexpr int oc_step = 4;
    constexpr int loop_ic_step = 4;
    // bytes of packed weight per (fh, fw) position of one oc4 group
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    // byte distance between the two oc4 groups inside the packed filter
    const int ld_weight_oc4 = oc_step * fh * fw * ic;
    // c[oc4 group][output column]: int32 accumulators (4 oc lanes each)
    int32x4_t c[2][8];
    // weight[oc4 group][fw position]
    int8x16_t weight[2][2];
    // sliding window of packed input pixels; registers are recycled below
    int8x16_t src[8 + 1];
    int16x8_t temp_c[4];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8(src_ic_0_3 + 16);
            src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
            src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
            src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
            src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0][0] = vld1q_s8(read_weight_ptr);
            weight[0][1] = vld1q_s8(read_weight_ptr + 16);
            weight[1][0] = vld1q_s8(read_weight_ptr + ld_weight_oc4);
            weight[1][1] = vld1q_s8(read_weight_ptr + ld_weight_oc4 + 16);
            // columns 0..3: fw position 0 uses even pixels, position 1 odd
            c[0][0] = vdotq_s32_h(weight[0][0], src[0], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][0], src[0], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][0], src[2], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][0], src[2], c[1][1], temp_c[3]);
            c[0][0] = vdotq_s32_h(weight[0][1], src[1], c[0][0], temp_c[0]);
            c[1][0] = vdotq_s32_h(weight[1][1], src[1], c[1][0], temp_c[1]);
            c[0][1] = vdotq_s32_h(weight[0][1], src[3], c[0][1], temp_c[2]);
            c[1][1] = vdotq_s32_h(weight[1][1], src[3], c[1][1], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][0], src[4], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][0], src[4], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][0], src[6], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][0], src[6], c[1][3], temp_c[3]);
            c[0][2] = vdotq_s32_h(weight[0][1], src[5], c[0][2], temp_c[0]);
            c[1][2] = vdotq_s32_h(weight[1][1], src[5], c[1][2], temp_c[1]);
            c[0][3] = vdotq_s32_h(weight[0][1], src[7], c[0][3], temp_c[2]);
            c[1][3] = vdotq_s32_h(weight[1][1], src[7], c[1][3], temp_c[3]);
            // recycle registers: load pixels 9..11 while columns 4..5
            // accumulate (loads interleaved with the dot products)
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0][0], src[8], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][0], src[8], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][0], src[1], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][0], src[1], c[1][5], temp_c[3]);
            c[0][4] = vdotq_s32_h(weight[0][1], src[0], c[0][4], temp_c[0]);
            c[1][4] = vdotq_s32_h(weight[1][1], src[0], c[1][4], temp_c[1]);
            c[0][5] = vdotq_s32_h(weight[0][1], src[2], c[0][5], temp_c[2]);
            c[1][5] = vdotq_s32_h(weight[1][1], src[2], c[1][5], temp_c[3]);
            // pixels 12..15 feed the last two columns
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0][0], src[3], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][0], src[3], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][0], src[5], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][0], src[5], c[1][7], temp_c[3]);
            c[0][6] = vdotq_s32_h(weight[0][1], src[4], c[0][6], temp_c[0]);
            c[1][6] = vdotq_s32_h(weight[1][1], src[4], c[1][6], temp_c[1]);
            c[0][7] = vdotq_s32_h(weight[0][1], src[6], c[0][7], temp_c[2]);
            c[1][7] = vdotq_s32_h(weight[1][1], src[6], c[1][7], temp_c[3]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // apply op (bias/activation/store) to the 8 columns, honoring remain_w
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
// Stride-2 2x2 micro-kernel: a single oc4 group x 8 output columns of one
// output row. Same layout as the oc8 variant above (one int8x16 load per
// packed input pixel; stride-2 outputs consume every second pixel), but with
// only one weight group, so fewer accumulators and temporaries are live.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int c_dim, typename DstType>
static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr,
                                             const int8_t* weight_ptr,
                                             const int32_t* bias_ptr,
                                             DstType* dst_ptr, int ic, int ih,
                                             int iw, int ld_dst_oc,
                                             const Op& op) {
    constexpr int fh = filter_size;
    constexpr int fw = filter_size;
    constexpr int oc_step = 4;
    constexpr int ic_step = 4;
    constexpr int loop_ic_step = 4;
    // bytes of packed weight per (fh, fw) position of the oc4 group
    constexpr int ld_weight_ic4 = 16;
    constexpr int pack_iw_len = 4;
    const int ic_stride = ih * iw * pack_iw_len;
    // c[0][output column]: int32 accumulators (4 oc lanes each)
    int32x4_t c[c_dim][8];
    // weight[fw position]
    int8x16_t weight[2];
    // sliding window of packed input pixels; registers are recycled below
    int8x16_t src[8 + 1];
    int16x8_t temp_c[2];
    init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
    for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
        for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
            const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                       fh_idx * iw * ic_step * pack_iw_len;
            src[0] = vld1q_s8(src_ic_0_3);
            src[1] = vld1q_s8((src_ic_0_3 + 16));
            src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
            src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
            src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
            src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
            src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
            src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
            src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
            // oc == 0
            const int8_t* read_weight_ptr =
                    weight_ptr + fh_idx * fw * ld_weight_ic4;
            weight[0] = vld1q_s8(read_weight_ptr);
            weight[1] = vld1q_s8(read_weight_ptr + 16);
            // columns 0..3: fw position 0 uses even pixels, position 1 odd
            c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
            c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[0]);
            c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[1]);
            c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
            c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
            // recycle registers: load pixels 9..11 for columns 4..5
            src[0] = vld1q_s8(src_ic_0_3 + 9 * 16);
            src[1] = vld1q_s8(src_ic_0_3 + 10 * 16);
            src[2] = vld1q_s8(src_ic_0_3 + 11 * 16);
            c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[0], src[1], c[0][5], temp_c[1]);
            c[0][4] = vdotq_s32_h(weight[1], src[0], c[0][4], temp_c[0]);
            c[0][5] = vdotq_s32_h(weight[1], src[2], c[0][5], temp_c[1]);
            // pixels 12..15 feed the last two columns
            src[3] = vld1q_s8(src_ic_0_3 + 12 * 16);
            src[4] = vld1q_s8(src_ic_0_3 + 13 * 16);
            src[5] = vld1q_s8(src_ic_0_3 + 14 * 16);
            src[6] = vld1q_s8(src_ic_0_3 + 15 * 16);
            c[0][6] = vdotq_s32_h(weight[0], src[3], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[0], src[5], c[0][7], temp_c[1]);
            c[0][6] = vdotq_s32_h(weight[1], src[4], c[0][6], temp_c[0]);
            c[0][7] = vdotq_s32_h(weight[1], src[6], c[0][7], temp_c[1]);
        }
        weight_ptr += fh * fw * ld_weight_ic4;
    }
    // apply op (bias/activation/store) to the 8 columns, honoring remain_w
    store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
            c, op, dst_ptr, ld_dst_oc);
}
| /** | |||||
| dot like impl. dot 4 ic to 1 oc, accumale to c <ow, oc> | |||||
| example: (format like weight<oc, ic>) | |||||
| packed weight | |||||
| low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> | |||||
| --------------------------------------------------------------------- | |||||
| high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> | |||||
| dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> | |||||
| **/ | |||||
| // TODO: can try oh = 2 impl, oc = 8 impl | |||||
// 3x3 stride-2 specialization: one oc4 group x 8 output columns per call.
// With fw = 3 and stride 2, output column j reads packed pixels 2j .. 2j+2,
// so a window of 10 pixels covers columns 0..3; registers are then recycled
// for the remaining columns.
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 3, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 3;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        // bytes of packed weight per (fh, fw) position
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;
        // c[0][output column]: int32 accumulators (4 oc lanes each)
        int32x4_t c[c_dim][8];
        // weight[fw position]
        int8x16_t weight[3];
        // sliding window of packed input pixels (16 bytes each)
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                           fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8((src_ic_0_3 + 16));
                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                // columns 0..3 (pixels 0..8)
                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
                // recycle registers: pixels 10..12 for columns 4..5
                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
                src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
                // pixels 13..16 for columns 6..7
                src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        // apply op (bias/activation/store), honoring remain_w partial columns
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// 5x5 stride-2 specialization: one oc4 group x 8 output columns per call.
// With fw = 5 and stride 2, output column j reads packed pixels 2j .. 2j+4;
// a window of 10 pixels covers columns 0..2, then registers are recycled.
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 5, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 5;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        // bytes of packed weight per (fh, fw) position
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;
        // c[0][output column]: int32 accumulators (4 oc lanes each)
        int32x4_t c[c_dim][8];
        // weight[fw position]
        int8x16_t weight[5];
        // sliding window of packed input pixels (16 bytes each)
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                           fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8((src_ic_0_3 + 16));
                src[2] = vld1q_s8((src_ic_0_3 + 2 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 3 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 4 * 16));
                src[5] = vld1q_s8((src_ic_0_3 + 5 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 6 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 7 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 8 * 16));
                src[9] = vld1q_s8((src_ic_0_3 + 9 * 16));
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
                // columns 0..1 (pixels 0..6)
                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
                // recycle src[0] with pixel 10; columns 2..3 (pixels 4..10)
                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
                // pixels 11..14 for columns 4..5
                src[1] = vld1q_s8((src_ic_0_3 + 11 * 16));
                src[2] = vld1q_s8((src_ic_0_3 + 12 * 16));
                src[3] = vld1q_s8((src_ic_0_3 + 13 * 16));
                src[4] = vld1q_s8((src_ic_0_3 + 14 * 16));
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
                // pixels 15..18 for columns 6..7
                src[5] = vld1q_s8((src_ic_0_3 + 15 * 16));
                src[6] = vld1q_s8((src_ic_0_3 + 16 * 16));
                src[7] = vld1q_s8((src_ic_0_3 + 17 * 16));
                src[8] = vld1q_s8((src_ic_0_3 + 18 * 16));
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        // apply op (bias/activation/store), honoring remain_w partial columns
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// 7x7 stride-2 specialization: one oc4 group x 8 output columns per call.
// With fw = 7 and stride 2, output column j reads packed pixels 2j .. 2j+6;
// the 10-register window covers columns 0..1 and is recycled repeatedly for
// the remaining columns.
template <BiasMode bias_mode, typename Op, int remain_w, int c_dim,
          typename DstType>
struct KerNeonDirectStride2Int8<bias_mode, Op, remain_w, 7, c_dim, DstType> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih,
                     int iw, const Op& op, int ld_dst_oc) {
        constexpr int filter_size = 7;
        constexpr int fh = filter_size;
        constexpr int fw = filter_size;
        constexpr int oc_step = 4;
        constexpr int ic_step = 4;
        constexpr int loop_ic_step = 4;
        // bytes of packed weight per (fh, fw) position
        constexpr int ld_weight_ic4 = 16;
        constexpr int pack_iw_len = 4;
        const int ic_stride = ih * iw * pack_iw_len;
        // c[0][output column]: int32 accumulators (4 oc lanes each)
        int32x4_t c[c_dim][8];
        // weight[fw position]
        int8x16_t weight[7];
        // sliding window of packed input pixels (16 bytes each)
        int8x16_t src[8 + 2];
        int16x8_t temp_c[4];
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            for (int fh_idx = 0; fh_idx < fh; ++fh_idx) {
                const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride +
                                           fh_idx * iw * ic_step * pack_iw_len;
                src[0] = vld1q_s8(src_ic_0_3);
                src[1] = vld1q_s8(src_ic_0_3 + 1 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 2 * 16);
                src[3] = vld1q_s8(src_ic_0_3 + 3 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 4 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 5 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 6 * 16);
                src[7] = vld1q_s8(src_ic_0_3 + 7 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 8 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 9 * 16);
                // oc == 0
                const int8_t* read_weight_ptr =
                        weight_ptr + fh_idx * fw * ld_weight_ic4;
                weight[0] = vld1q_s8(read_weight_ptr);
                weight[1] = vld1q_s8(read_weight_ptr + 16);
                weight[2] = vld1q_s8(read_weight_ptr + 2 * 16);
                weight[3] = vld1q_s8(read_weight_ptr + 3 * 16);
                weight[4] = vld1q_s8(read_weight_ptr + 4 * 16);
                weight[5] = vld1q_s8(read_weight_ptr + 5 * 16);
                weight[6] = vld1q_s8(read_weight_ptr + 6 * 16);
                // columns 0..1 (pixels 0..8)
                c[0][0] = vdotq_s32_h(weight[0], src[0], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[0], src[2], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[1], src[1], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[1], src[3], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[2], src[2], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[2], src[4], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[3], src[3], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[3], src[5], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[4], src[4], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[4], src[6], c[0][1], temp_c[1]);
                c[0][0] = vdotq_s32_h(weight[5], src[5], c[0][0], temp_c[2]);
                c[0][1] = vdotq_s32_h(weight[5], src[7], c[0][1], temp_c[3]);
                c[0][0] = vdotq_s32_h(weight[6], src[6], c[0][0], temp_c[0]);
                c[0][1] = vdotq_s32_h(weight[6], src[8], c[0][1], temp_c[1]);
                // recycle registers: pixels 10..12 for columns 2..3
                src[0] = vld1q_s8(src_ic_0_3 + 10 * 16);
                src[1] = vld1q_s8(src_ic_0_3 + 11 * 16);
                src[2] = vld1q_s8(src_ic_0_3 + 12 * 16);
                c[0][2] = vdotq_s32_h(weight[0], src[4], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[0], src[6], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[1], src[5], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[1], src[7], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[2], src[6], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[2], src[8], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[3], src[7], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[3], src[9], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[4], src[8], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[4], src[0], c[0][3], temp_c[3]);
                c[0][2] = vdotq_s32_h(weight[5], src[9], c[0][2], temp_c[0]);
                c[0][3] = vdotq_s32_h(weight[5], src[1], c[0][3], temp_c[1]);
                c[0][2] = vdotq_s32_h(weight[6], src[0], c[0][2], temp_c[2]);
                c[0][3] = vdotq_s32_h(weight[6], src[2], c[0][3], temp_c[3]);
                // pixels 13..16 for columns 4..5
                src[3] = vld1q_s8(src_ic_0_3 + 13 * 16);
                src[4] = vld1q_s8(src_ic_0_3 + 14 * 16);
                src[5] = vld1q_s8(src_ic_0_3 + 15 * 16);
                src[6] = vld1q_s8(src_ic_0_3 + 16 * 16);
                c[0][4] = vdotq_s32_h(weight[0], src[8], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[0], src[0], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[1], src[9], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[1], src[1], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[2], src[0], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[2], src[2], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[3], src[1], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[3], src[3], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[4], src[2], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[4], src[4], c[0][5], temp_c[1]);
                c[0][4] = vdotq_s32_h(weight[5], src[3], c[0][4], temp_c[2]);
                c[0][5] = vdotq_s32_h(weight[5], src[5], c[0][5], temp_c[3]);
                c[0][4] = vdotq_s32_h(weight[6], src[4], c[0][4], temp_c[0]);
                c[0][5] = vdotq_s32_h(weight[6], src[6], c[0][5], temp_c[1]);
                // pixels 17..20 for columns 6..7
                src[7] = vld1q_s8(src_ic_0_3 + 17 * 16);
                src[8] = vld1q_s8(src_ic_0_3 + 18 * 16);
                src[9] = vld1q_s8(src_ic_0_3 + 19 * 16);
                src[0] = vld1q_s8(src_ic_0_3 + 20 * 16);
                c[0][6] = vdotq_s32_h(weight[0], src[2], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[0], src[4], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[1], src[3], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[1], src[5], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[2], src[4], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[2], src[6], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[3], src[5], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[3], src[7], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[4], src[6], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[4], src[8], c[0][7], temp_c[3]);
                c[0][6] = vdotq_s32_h(weight[5], src[7], c[0][6], temp_c[0]);
                c[0][7] = vdotq_s32_h(weight[5], src[9], c[0][7], temp_c[1]);
                c[0][6] = vdotq_s32_h(weight[6], src[8], c[0][6], temp_c[2]);
                c[0][7] = vdotq_s32_h(weight[6], src[0], c[0][7], temp_c[3]);
            }
            weight_ptr += fh * fw * ld_weight_ic4;
        }
        // apply op (bias/activation/store), honoring remain_w partial columns
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, DstType*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// Driver for the 2x2 stride-2 int8 NCHW44 direct convolution. Tiles the
// output as (oc8 | oc4 tail) x rows x (ow8 | ow tail), dispatching each tile
// to the oc8/oc4 micro-kernels above. The fourth parameter (int32 scratch)
// is unused by this filter size. Assumes oc is a multiple of 4 (tail after
// the oc8 loop is handled as a single oc4 group).
template <BiasMode bias_mode, typename Op, typename DstType>
void conv_direct_stride2_2x2_int8_nchw44(
        const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
        DstType* dst, const size_t oc, const size_t ic, const size_t ih,
        const size_t iw, const size_t oh, const size_t ow, const Op& op) {
    constexpr size_t filter_size = 2;
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;
    constexpr size_t oc_step = 4;
    constexpr size_t big_oc_step = 8;
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;
    const size_t out_img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;
    const size_t ow_remain = ow - ow_end;
    const size_t oc_end = oc / big_oc_step * big_oc_step;
    const size_t oc_remain = oc - oc_end;
    const int ld_dst_oc = oh * ow * oc_step;
    using remain_fun = std::function<void(
            const int8_t* src_ptr, const int8_t* weight_ptr,
            const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
            int ld_dst_oc, const Op& op)>;
    remain_fun kern_big_oc_remain = nullptr;
    remain_fun kern_small_oc_remain = nullptr;
    // pick the right remain_w instantiation once, outside the hot loops
    switch (ow_remain) {
#define cb(step)                                                          \
    case step:                                                            \
        kern_big_oc_remain =                                              \
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, step,     \
                                                 filter_size, 2, DstType>; \
        kern_small_oc_remain =                                            \
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, step,     \
                                                 filter_size, 1, DstType>; \
        break;
        UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %zu for kern", ow_remain);
    }
#undef cb
    // main pass: 8 output channels at a time
    for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                // src offset walks stride_h/stride_w in packed-pixel units
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc8_ow8<bias_mode, Op, ow_step,
                                                 filter_size, 2, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                kern_big_oc_remain(src + src_offset, filter + weight_offset,
                                   bias + oc_idx, dst + dst_offset, ic, ih, iw,
                                   ld_dst_oc, op);
            }
        }
    }
    // tail pass: the last oc4 group when oc is not a multiple of 8
    if (oc_remain > 0) {
        const size_t oc_idx = oc_end;
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_idx) * oc_step;
                ker_neon_dirctconv_2x2s2_oc4_ow8<bias_mode, Op, ow_step,
                                                 filter_size, 1, DstType>(
                        src + src_offset, filter + weight_offset, bias + oc_idx,
                        dst + dst_offset, ic, ih, iw, ld_dst_oc, op);
            }
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset = oc_idx * out_img_stride +
                                          (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, ld_dst_oc, op);
            }
        }
    }
}
// Direct stride-2 int8 convolution over NCHW44-packed data.
// Walks the output row-major: for each 4-channel output group (oc_step),
// computes 8 output columns (ow_step) per kernel call; a pre-selected
// std::function kernel handles the <8-column tail of every row.
template <BiasMode bias_mode, typename Op, int filter_size, typename DstType>
void conv_direct_stride2_int8_nchw44_kern(
        const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t*,
        DstType* dst, const size_t oc, const size_t ic, const size_t ih,
        const size_t iw, const size_t oh, const size_t ow, const Op& op) {
    constexpr size_t fh = filter_size;
    constexpr size_t fw = filter_size;
    constexpr size_t ic_step = 4;  // channels per NCHW44 pack
    constexpr size_t oc_step = 4;  // output channels computed per pass
    constexpr size_t oh_step = 1;
    constexpr size_t ow_step = 8;  // output columns per kernel call
    constexpr size_t stride_h = 2;
    constexpr size_t stride_w = 2;
    constexpr int pack_iw_len = 4;  // src was pre-packed 4x along iw
    const size_t img_stride = oh * ow;
    const size_t ow_end = ow / ow_step * ow_step;  // last full 8-col block
    const size_t ow_remain = ow - ow_end;          // tail columns (0..7)
    const int ld_dst_oc = oh * ow * oc_step;
    using remain_fun = std::function<void(
            const int8_t* src_ptr, const int8_t* weight_ptr,
            const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, int iw,
            const Op& op, int ld_dst_oc)>;
    remain_fun kern_small_oc_remain = nullptr;
    // Bind the tail kernel once, specialized on the exact remaining width.
    switch (ow_remain) {
#define cb(step)                                                             \
    case step:                                                               \
        kern_small_oc_remain =                                               \
                KerNeonDirectStride2Int8<bias_mode, Op, step, filter_size, 1, \
                                         DstType>::impl;                     \
        break;

        UNROLL_CALL_RAW(8, cb);
        default:
            megdnn_assert(0, "no remain %zu for kern", ow_remain);
    }
#undef cb
    for (size_t oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
        const size_t weight_offset = oc_idx * ic * fh * fw;
        for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
            // Full 8-column blocks of this output row.
            for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
                KerNeonDirectStride2Int8<bias_mode, Op, ow_step, filter_size, 1,
                                         DstType>::impl(src + src_offset,
                                                        filter + weight_offset,
                                                        bias + oc_idx,
                                                        dst + dst_offset, ic,
                                                        ih, iw, op, ld_dst_oc);
            }
            // Row tail via the pre-selected remain kernel.
            if (ow_remain > 0) {
                const size_t src_offset =
                        (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step *
                        pack_iw_len;
                const size_t dst_offset =
                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
                kern_small_oc_remain(src + src_offset, filter + weight_offset,
                                     bias + oc_idx, dst + dst_offset, ic, ih,
                                     iw, op, ld_dst_oc);
            }
        }
    }
}
| } // namespace | |||||
| namespace int8_direct_nchw44 { | |||||
| template <BiasMode bias_mode, typename Op, int filter_size, typename DstType> | |||||
| struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, DstType, 2> { | |||||
| static void impl(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, DstType* dst, | |||||
| const size_t oc, const size_t ic, const size_t ih, | |||||
| const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| conv_direct_stride2_int8_nchw44_kern<bias_mode, Op, filter_size, | |||||
| DstType>( | |||||
| src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, typename DstType> | |||||
| struct ConvDirectInt8Nchw44Choose<bias_mode, Op, 2, DstType, 2> { | |||||
| static void impl(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, DstType* dst, | |||||
| const size_t oc, const size_t ic, const size_t ih, | |||||
| const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| conv_direct_stride2_2x2_int8_nchw44<bias_mode, Op, DstType>( | |||||
| src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); | |||||
| } | |||||
| }; | |||||
// Explicit template instantiations for the stride-2 NCHW44 int8 kernels:
// filter sizes {2, 3, 5, 7} x bias modes {NO_BIAS, BROADCAST_CHANNEL_BIAS}
// x activations (TypeCvt / Relu / HSwish producing qint8, NoneOp producing
// raw int32 output).
#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op)      \
    template struct ConvDirectInt8Nchw44Choose<bias_mode, Op, filter_size, \
                                               DstType, stride>;
#define GET_OP_PARAM(stride, filter, bias_mode)           \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                                                          \
                     TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                                                          \
                     ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \
                                                          \
                     HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp<dt_int32>)
#define GET_BIAS_MODE_PARAM(stride, filter)         \
    GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define DISPATCH_CONV_KERN(stride)   \
    GET_BIAS_MODE_PARAM(stride, 2)   \
    GET_BIAS_MODE_PARAM(stride, 3)   \
    GET_BIAS_MODE_PARAM(stride, 5)   \
    GET_BIAS_MODE_PARAM(stride, 7)
DISPATCH_CONV_KERN(2);
| } // namespace int8_direct_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace { | |||||
// Primary template: declared only, never defined. Each supported
// (filter_size, stride) pair provides a full specialization with the
// actual NEON kernel body; instantiating an unsupported combination is a
// link-time error.
template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
          int oc_block, int stride>
struct KerNeonXXs2NchwNchw44 {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op);
};
//! Maps an output-channel block size to the number of 4-channel register
//! groups the kernels use: 4 -> 1 group, 8 -> 2 groups; any other block
//! size yields 0 (unsupported).
template <int oc>
struct OCHelper {
    static constexpr int val = 0;
};

template <>
struct OCHelper<4> {
    static constexpr int val = 1;
};

template <>
struct OCHelper<8> {
    static constexpr int val = 2;
};
| } // namespace | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,561 @@ | |||||
| /** | |||||
| * \file | |||||
| * dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
| * implied. | |||||
| */ | |||||
| #include "src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h" | |||||
| #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h" | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace { | |||||
/**
 * @brief core code for the calculation pattern
 *
 * @tparam src_idx offset of the src reg
 * @tparam weight_idx offset of the weight reg
 * @tparam c_dim number of output-channel register groups
 * @tparam Func mla operation function
 * @tparam stride
 * @tparam T output regs type
 * @tparam T2 src regs type
 * @tparam T3 weight regs type
 * @tparam T4 temp regs type
 */
// Primary template: declared only. The (c_dim, stride) specializations
// below define impl(); the 4-operand form uses int16 scratch registers,
// the 3-operand form is for paths that need no scratch.
template <int src_idx, int weight_idx, int c_dim, int stride, typename T,
          typename T2, typename T3, typename T4>
struct ShiftCalHelper {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp);
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight);
};
| template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, T4>::impl( | |||||
| c, src, weight, temp); | |||||
| } | |||||
| template <int src_idx, int weight_idx, int c_dim, int stride, typename T, | |||||
| typename T2, typename T3> | |||||
| MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { | |||||
| ShiftCalHelper<src_idx, weight_idx, c_dim, stride, T, T2, T3, int>::impl( | |||||
| c, src, weight); | |||||
| }; | |||||
// c_dim == 2 (8 output channels as two 4-channel groups), stride == 1:
// multiply-accumulate one weight register into 8 output columns for both
// channel groups. src acts as a ring of 8 registers, hence the "% 8"
// indexing; temp[0..3] rotate as int16 scratch for vdotq_s32_h.
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
          typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, 1, T, T2, T3, T4> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
                              c[0][0], temp[0]);
        c[1][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[1][weight_idx],
                              c[1][0], temp[1]);
        c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
                              c[0][1], temp[2]);
        c[1][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[1][weight_idx],
                              c[1][1], temp[3]);
        c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
                              c[0][2], temp[0]);
        c[1][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[1][weight_idx],
                              c[1][2], temp[1]);
        c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
                              c[0][3], temp[2]);
        c[1][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[1][weight_idx],
                              c[1][3], temp[3]);
        c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
                              c[0][4], temp[0]);
        c[1][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[1][weight_idx],
                              c[1][4], temp[1]);
        c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
                              c[0][5], temp[2]);
        c[1][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[1][weight_idx],
                              c[1][5], temp[3]);
        c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
                              c[0][6], temp[0]);
        c[1][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[1][weight_idx],
                              c[1][6], temp[1]);
        c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
                              c[0][7], temp[2]);
        c[1][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[1][weight_idx],
                              c[1][7], temp[3]);
    }
    // 3-operand form declared but intentionally left undefined here.
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
// c_dim == 1 (single 4-channel output group), stride == 1: same ring-buffer
// pattern as the c_dim == 2 case, but only weight group 0 is used.
template <int src_idx, int weight_idx, typename T, typename T2, typename T3,
          typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
    static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) {
        c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx],
                              c[0][0], temp[0]);
        c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx],
                              c[0][1], temp[1]);
        c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx],
                              c[0][2], temp[2]);
        c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx],
                              c[0][3], temp[3]);
        c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx],
                              c[0][4], temp[0]);
        c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx],
                              c[0][5], temp[1]);
        c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx],
                              c[0][6], temp[2]);
        c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx],
                              c[0][7], temp[3]);
    }
    // 3-operand form declared but intentionally left undefined here.
    static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&);
};
// 2x2 filter (width padded to 4 by the weight packer), stride 1:
// one row of 8 output pixels for oc_block (4 or 8) output channels.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 2;
        constexpr int filter_width = 4;  // fw rounded up to 4 by the packer
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // src pre-packed 16x along iw
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        // Seed accumulators with bias (or zero, per bias_mode).
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
            // Filter row 0.
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            // Filter row 1.
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step,
                    ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            weight_ptr += oc_step * filter_height * filter_width;
        }
        // Apply activation op and store remain_w valid columns.
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// 3x3 filter (width padded to 4 by the weight packer), stride 1:
// one row of 8 output pixels for oc_block (4 or 8) output channels.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 3;
        constexpr int filter_width = 4;  // fw rounded up to 4 by the packer
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // src pre-packed 16x along iw
        constexpr int src_reg = 8;
        constexpr int weight_reg = 1;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        // Seed accumulators with bias (or zero, per bias_mode).
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
            // Filter row 0.
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr, ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            // Filter row 1.
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 1 * filter_width * oc_step,
                    ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 1 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            // Filter row 2.
            load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
                    dot4_weight, weight_ptr + 2 * filter_width * oc_step,
                    ld_weight_oc);
            load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
                    src, nchw_src_ptr + 2 * iw * pack_iw_len, 0);
            cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);
            weight_ptr += oc_step * filter_height * filter_width;
        }
        // Apply activation op and store remain_w valid columns.
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// 5x5 filter (width padded to 8 by the weight packer, two weight regs),
// stride 1: one row of 8 output pixels for oc_block output channels.
// Each filter row is handled in two halves: columns 0-3 via weight reg 0,
// columns 4-7 via weight reg 1 with the src ring shifted by 4.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 5;
        constexpr int filter_width = 8;  // fw rounded up to 8 by the packer
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // src pre-packed 16x along iw
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        // Seed accumulators with bias (or zero, per bias_mode).
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// One filter row per step: low half then high half of the padded width.
#define cb(step)                                                            \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                  \
            dot4_weight, weight_ptr + step * filter_width * oc_step,        \
            ld_weight_oc);                                                  \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                         \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);           \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                               \
            src,                                                            \
            nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
            0);                                                             \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

            UNROLL_CALL_RAW(5, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        // Apply activation op and store remain_w valid columns.
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
// 7x7 filter (width padded to 8 by the weight packer, two weight regs),
// stride 1: same two-half-per-row scheme as the 5x5 kernel, unrolled over
// 7 filter rows.
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block, 1> {
    static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
                     const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
                     int iw, int ld_dst_oc, const Op& op) {
        constexpr int stride = 1;
        constexpr int filter_height = 7;
        constexpr int filter_width = 8;  // fw rounded up to 8 by the packer
        constexpr int oc_step = 4;
        constexpr int loop_ic_step = 1;
        constexpr int simd_len = 16;
        constexpr int pack_iw_len = 16;  // src pre-packed 16x along iw
        constexpr int src_reg = 8;
        constexpr int weight_reg = 2;
        const int ic_stride = ih * iw * pack_iw_len;
        const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
        constexpr int c_dim = OCHelper<oc_block>::val;
        int32x4_t c[c_dim][8];
        // Seed accumulators with bias (or zero, per bias_mode).
        init_ocx_ow8<c_dim, bias_mode, 8>(c, bias_ptr, oc_step);
        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
            const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
            int8x16_t src[src_reg];
            int8x16_t dot4_weight[c_dim][weight_reg];
            int16x8_t temp_c[4];
// One filter row per step: low half then high half of the padded width.
#define cb(step)                                                            \
    load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(                  \
            dot4_weight, weight_ptr + step * filter_width * oc_step,        \
            ld_weight_oc);                                                  \
    load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(                         \
            src, nchw_src_ptr + step * iw * pack_iw_len, 0);                \
    cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);           \
    load_helper<4, 0, simd_len, 0, Vld1q_s8>(                               \
            src,                                                            \
            nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \
            0);                                                             \
    cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c);

            UNROLL_CALL_RAW(7, cb);
#undef cb
            weight_ptr += oc_step * filter_height * filter_width;
        }
        // Apply activation op and store remain_w valid columns.
        store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
                c, op, dst_ptr, ld_dst_oc);
    }
};
| } // namespace | |||||
| namespace int8_direct_nchw_nchw44 { | |||||
/**
 * pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh, fw/4, 4(oc)*4(fw)}
 * packs two adjacent rows of the filter, interleaved, into one row
 * */
| template <> | |||||
| void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, | |||||
| const int ic, const int fh, | |||||
| const int fw, const int oc) { | |||||
| constexpr int oc_step = 4; | |||||
| const int fw2 = round_up(fw, 4); | |||||
| const int fw_remain = fw2 - fw; | |||||
| const int dst_ic_stride = fh * fw2; | |||||
| const int oc_step_stride = fh * fw2 * ic * oc_step; | |||||
| static const uint8_t transpose_4x4_idx[16] = {0, 4, 1, 5, 2, 6, 3, 7, | |||||
| 8, 12, 9, 13, 10, 14, 11, 15}; | |||||
| uint8x16_t tbl_transpose_4x4 = vld1q_u8(&transpose_4x4_idx[0]); | |||||
| rep_step(oc_idx, oc, oc_step) { | |||||
| int32_t* dst_temp_ptr = | |||||
| reinterpret_cast<int32_t*>(dst_ptr + oc_idx * ic * fh * fw2); | |||||
| const int32_t* src_temp_ptr = reinterpret_cast<const int32_t*>( | |||||
| src_ptr + oc_idx * ic * fh * fw); | |||||
| // transpose ic and pad | |||||
| rep(fh_idx, fh) { | |||||
| rep(fw_idx, fw) { | |||||
| rep(ic_idx, ic) { | |||||
| *(dst_temp_ptr + ic_idx * dst_ic_stride) = *src_temp_ptr; | |||||
| src_temp_ptr++; | |||||
| } | |||||
| dst_temp_ptr++; | |||||
| } | |||||
| rep(ic_idx, ic) { | |||||
| memset(dst_temp_ptr + ic_idx * dst_ic_stride, 0, | |||||
| sizeof(int8_t) * oc_step * fw_remain); | |||||
| } | |||||
| dst_temp_ptr += fw_remain; | |||||
| } | |||||
| // transpose fw oc | |||||
| int8_t* trans_dst_temp_ptr = | |||||
| reinterpret_cast<int8_t*>(dst_ptr + oc_idx * ic * fh * fw2); | |||||
| rep_step(idx, oc_step_stride, 16) { | |||||
| int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); | |||||
| vst1q_s8(trans_dst_temp_ptr + idx, | |||||
| vqtbl1q_s8(temp, tbl_transpose_4x4)); | |||||
| } | |||||
| } | |||||
| }; | |||||
/**
 * pack (ic, h, w) to (ic, h, w * 16)
 * packs two adjacent elements of each src row, repeated 4 times, into one row
 * */
// Stride-1 source repack; see the layout comment above. Each input row is
// first copied into a zero-padded scratch row (temp_ptr), then every
// 4-element group is expanded 16x via table lookup: pairs (0,1) and (2,3)
// each repeated four times per 16-byte lane.
template <>
void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin,
                                      int8_t* sptr_base, const int ic,
                                      const int pad_top, const int pad_bottom,
                                      const int, const int, const int ih,
                                      const int iw, const int iw2, const int pw,
                                      int8_t* temp_ptr) {
    static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1,
                                      2, 3, 2, 3, 2, 3, 2, 3};
    uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]);
    constexpr int iw_step = 4;       // input elements consumed per iteration
    constexpr int pack_iw_len = 16;  // output bytes produced per input element
    const int ic_stride = ih * iw;
    const int iw_with_pad = iw + 2 * pw;
    const int iw_with_pad_end = iw_with_pad / iw_step * iw_step;
    rep(ic_idx, ic) {
        const int8_t* sptr = sptr_origin + ic_idx * ic_stride;
        // Zero the whole destination plane first (covers top/bottom pad
        // rows and any bytes the tail loop does not overwrite).
        memset(sptr_base, 0,
               sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) *
                       pack_iw_len);
        sptr_base += iw2 * pad_top * pack_iw_len;
        rep(ih_idx, ih) {
            // Build the padded scratch row: [pw zeros][iw data][...].
            memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t));
            memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw);
            for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) {
                int8x16_t src[4];
                int8x16_t dst[4];
                // NOTE(review): these loads read 16 bytes starting at
                // iw_idx, past iw_with_pad for the last iterations —
                // assumes temp_ptr's buffer is allocated with enough
                // slack by the caller; confirm at the workspace setup.
                src[0] = vld1q_s8(temp_ptr + iw_idx);
                src[1] = vld1q_s8(temp_ptr + iw_idx + 1);
                src[2] = vld1q_s8(temp_ptr + iw_idx + 2);
                src[3] = vld1q_s8(temp_ptr + iw_idx + 3);
                dst[0] = vqtbl1q_s8(src[0], tbl_idx);
                dst[1] = vqtbl1q_s8(src[1], tbl_idx);
                dst[2] = vqtbl1q_s8(src[2], tbl_idx);
                dst[3] = vqtbl1q_s8(src[3], tbl_idx);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]);
            }
            // Tail: one element (one 16-byte lane) at a time.
            for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) {
                int8x16_t src = vld1q_s8(temp_ptr + iw_idx);
                int8x16_t dst = vqtbl1q_s8(src, tbl_idx);
                vst1q_s8(sptr_base + iw_idx * pack_iw_len, dst);
            }
            sptr_base += iw2 * pack_iw_len;
            sptr += iw;
        }
        sptr_base += iw2 * pad_bottom * pack_iw_len;
    }
}
| template <BiasMode bias_mode, typename Op, size_t filter_size> | |||||
| struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||||
| static void impl(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, int8_t* dst, | |||||
| const size_t oc, const size_t ic, const size_t ih, | |||||
| const size_t iw, const size_t oh, const size_t ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr int stride = 1; | |||||
| constexpr size_t fh = filter_size; | |||||
| constexpr size_t fw = (filter_size + 3) / 4 * 4; | |||||
| constexpr size_t ic_step = 1; | |||||
| constexpr size_t big_oc_step = 8; | |||||
| constexpr size_t oc_step = 4; | |||||
| constexpr size_t ih_step = 1; | |||||
| constexpr size_t oh_step = 1; | |||||
| constexpr size_t ow_step = 8; | |||||
| constexpr size_t stride_h = stride; | |||||
| constexpr size_t stride_w = stride; | |||||
| constexpr int pack_iw_len = 16; | |||||
| const size_t img_stride = oh * ow; | |||||
| const size_t ow_end = ow / ow_step * ow_step; | |||||
| const size_t ow_remain = ow - ow_end; | |||||
| const size_t oc_end = oc / big_oc_step * big_oc_step; | |||||
| const size_t oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = std::function<void( | |||||
| const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, stride>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, stride>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %zu for kern", ow_remain); | |||||
| } | |||||
| for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = (oh_idx * stride_h * iw + | |||||
| ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * img_stride + | |||||
| (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = (oh_idx * stride_h * iw + | |||||
| ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * img_stride + | |||||
| (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| size_t oc_idx = oc_end; | |||||
| const size_t weight_offset = oc_idx * ic * fh * fw; | |||||
| for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { | |||||
| for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const size_t src_offset = (oh_idx * stride_h * iw + | |||||
| ow_idx * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * img_stride + | |||||
| (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonXXs2NchwNchw44<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, | |||||
| ih, iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const size_t src_offset = (oh_idx * stride_h * iw + | |||||
| ow_end * stride_w * ih_step) * | |||||
| ic_step * pack_iw_len; | |||||
| const size_t dst_offset = oc_idx * img_stride + | |||||
| (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, | |||||
| filter + weight_offset, bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
// Explicit template instantiations for the stride-1 nchw_nchw44 kernels:
// filter sizes {2, 3, 5, 7} x bias modes {NO_BIAS, BROADCAST_CHANNEL_BIAS}
// x activations {TypeCvt (identity), Relu, HSwish}.
#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op)         \
    template struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, \
                                                  stride>;
#define INSTANCE_OP_PARAM(stride, filter, bias_mode)       \
    INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode,      \
                           TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
    INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode,      \
                           ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>)    \
    INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode,      \
                           HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)
#define INSTANCE_BIAS_MODE_PARAM(stride, filter)         \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
    INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)
#define INSTANCE_CONV_KERN(stride)       \
    INSTANCE_BIAS_MODE_PARAM(stride, 2)  \
    INSTANCE_BIAS_MODE_PARAM(stride, 3)  \
    INSTANCE_BIAS_MODE_PARAM(stride, 5)  \
    INSTANCE_BIAS_MODE_PARAM(stride, 7)
INSTANCE_CONV_KERN(1);
| } // namespace int8_direct_nchw_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -114,7 +114,7 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, | |||||
| rep(ih_idx, IH) { | rep(ih_idx, IH) { | ||||
| std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | std::memset(sptr_base, 0, nr_pad_w * sizeof(int8_t)); | ||||
| sptr_base += nr_pad_w; | sptr_base += nr_pad_w; | ||||
| nchw44_pack_src(sptr, sptr_base, IW); | |||||
| int8_direct_nchw44::nchw44_pack_src(sptr, sptr_base, IW); | |||||
| sptr_base += IW * pack_ic * expend_element; | sptr_base += IW * pack_ic * expend_element; | ||||
| sptr += IW * pack_ic; | sptr += IW * pack_ic; | ||||
| std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | std::memset(sptr_base, 0, row_last_pad * sizeof(int8_t)); | ||||
| @@ -125,8 +125,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, | |||||
| } | } | ||||
| } | } | ||||
| template <size_t filter, BiasMode bias_mode, typename Op, int ow_remain, | |||||
| typename DstType, int stride> | |||||
| template <size_t filter, BiasMode bias_mode, typename Op, typename DstType, | |||||
| int stride> | |||||
| static void do_conv_kern(const WorkspaceBundle& bundle, | static void do_conv_kern(const WorkspaceBundle& bundle, | ||||
| const ConvBiasImpl::NCBKernParam& kern_param, | const ConvBiasImpl::NCBKernParam& kern_param, | ||||
| const ConvBiasImpl::NCBKernIndex& ncb_index, | const ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
| @@ -182,8 +182,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx; | ||||
| auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | ||||
| group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; | ||||
| nchw44_pack_filter(fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); | |||||
| conv_direct_int8_nchw44<bias_mode, Op, ow_remain, filter, DstType, stride>( | |||||
| int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight, | |||||
| oc_block / 4 * IC / 4 * FH * FW); | |||||
| int8_direct_nchw44::conv_direct_int8_nchw44<bias_mode, Op, filter, DstType, | |||||
| stride>( | |||||
| sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst), | sptr, packed_weight, bptr, nullptr, static_cast<DstType*>(dst), | ||||
| oc_block, IC, IH2, IW2, OH, OW, op); | oc_block, IC, IH2, IW2, OH, OW, op); | ||||
| } | } | ||||
| @@ -233,40 +235,38 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||||
| size_t N = param.n; | size_t N = param.n; | ||||
| size_t IC = fm.icpg; | size_t IC = fm.icpg; | ||||
| size_t OC = fm.ocpg; | size_t OC = fm.ocpg; | ||||
| size_t OW = param.osz[1]; | |||||
| size_t group = fm.group; | size_t group = fm.group; | ||||
| size_t fh = fm.spatial[0]; | size_t fh = fm.spatial[0]; | ||||
| size_t fw = fm.spatial[1]; | size_t fw = fm.spatial[1]; | ||||
| WorkspaceBundle wbundle = get_bundle(param); | WorkspaceBundle wbundle = get_bundle(param); | ||||
| conv_fun do_conv_fun = nullptr; | conv_fun do_conv_fun = nullptr; | ||||
| int ow_remain = OW % 8; | |||||
| bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | bool need_post_process = param.dst_type.enumv() == DTypeEnum::QuantizedS8; | ||||
| // NOTE: remain_w is not used to gen hash of midout for compatible with changing | // NOTE: remain_w is not used to gen hash of midout for compatible with changing | ||||
| // shape runtime | // shape runtime | ||||
| #define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, remain_w, op) \ | |||||
| #define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \ | |||||
| MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ | ||||
| midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ | midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ | ||||
| do_conv_fun = do_conv_kern<filter, bias_mode, op, remain_w, dst_type, \ | |||||
| stride>; \ | |||||
| do_conv_fun = do_conv_kern<filter, bias_mode, op, dst_type, stride>; \ | |||||
| } \ | } \ | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| #define GET_OP_PARAM(stride, filter, bias_mode, remain_w) \ | |||||
| #define GET_OP_PARAM(stride, filter, bias_mode) \ | |||||
| if (need_post_process) { \ | if (need_post_process) { \ | ||||
| switch (param.nonlineMode) { \ | switch (param.nonlineMode) { \ | ||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | case param::ConvBias::NonlineMode::IDENTITY: \ | ||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | ||||
| remain_w, \ | |||||
| \ | |||||
| TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | ||||
| break; \ | break; \ | ||||
| case param::ConvBias::NonlineMode::RELU: \ | case param::ConvBias::NonlineMode::RELU: \ | ||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | ||||
| remain_w, \ | |||||
| \ | |||||
| ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | ||||
| break; \ | break; \ | ||||
| case param::ConvBias::NonlineMode::H_SWISH: \ | case param::ConvBias::NonlineMode::H_SWISH: \ | ||||
| DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ | ||||
| remain_w, \ | |||||
| \ | |||||
| HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \ | ||||
| break; \ | break; \ | ||||
| default: \ | default: \ | ||||
| @@ -277,7 +277,7 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||||
| switch (param.nonlineMode) { \ | switch (param.nonlineMode) { \ | ||||
| case param::ConvBias::NonlineMode::IDENTITY: \ | case param::ConvBias::NonlineMode::IDENTITY: \ | ||||
| DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ | DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ | ||||
| remain_w, NoneOp<dt_int32>) \ | |||||
| NoneOp<dt_int32>) \ | |||||
| break; \ | break; \ | ||||
| default: \ | default: \ | ||||
| megdnn_assert( \ | megdnn_assert( \ | ||||
| @@ -287,48 +287,17 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( | |||||
| } \ | } \ | ||||
| } | } | ||||
| #define GET_REMAIN_W_PARAM(stride, filter, bias_mode) \ | |||||
| switch (ow_remain) { \ | |||||
| case 0: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 0); \ | |||||
| break; \ | |||||
| case 1: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 1); \ | |||||
| break; \ | |||||
| case 2: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 2); \ | |||||
| break; \ | |||||
| case 3: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 3); \ | |||||
| break; \ | |||||
| case 4: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 4); \ | |||||
| break; \ | |||||
| case 5: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 5); \ | |||||
| break; \ | |||||
| case 6: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 6); \ | |||||
| break; \ | |||||
| case 7: \ | |||||
| GET_OP_PARAM(stride, filter, bias_mode, 7); \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| } | |||||
| #define GET_BIAS_MODE_PARAM(stride, filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_REMAIN_W_PARAM(stride, filter, \ | |||||
| BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| #define GET_BIAS_MODE_PARAM(stride, filter) \ | |||||
| switch (param.bias_mode) { \ | |||||
| case BiasMode::NO_BIAS: \ | |||||
| GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ | |||||
| break; \ | |||||
| case BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
| GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0); \ | |||||
| break; \ | |||||
| } | } | ||||
| #define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
| @@ -117,11 +117,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, | |||||
| const size_t tmp_size = get_temp_bytes(iw, pw); | const size_t tmp_size = get_temp_bytes(iw, pw); | ||||
| int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | ||||
| ncb_index.thread_id * tmp_size; | ncb_index.thread_id * tmp_size; | ||||
| pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, | |||||
| iw, iw2, pw, tmp_ptr); | |||||
| int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<1>( | |||||
| sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, tmp_ptr); | |||||
| } else { | } else { | ||||
| pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih, | |||||
| iw, iw2, pw, nullptr); | |||||
| int8_direct_nchw_nchw44::pack_nchw_src_for_nchw44_conv<2>( | |||||
| sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr); | |||||
| } | } | ||||
| } | } | ||||
| static void pack_weight(const WorkspaceBundle& bundle, | static void pack_weight(const WorkspaceBundle& bundle, | ||||
| @@ -142,11 +142,11 @@ static void pack_weight(const WorkspaceBundle& bundle, | |||||
| group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; | group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; | ||||
| if (stride_h == 1) { | if (stride_h == 1) { | ||||
| pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw, | |||||
| oc_block); | |||||
| int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<1>( | |||||
| fptr, packed_weight, ic, fh, fw, oc_block); | |||||
| } else { | } else { | ||||
| pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw, | |||||
| oc_block); | |||||
| int8_direct_nchw_nchw44::pack_nchw44_weight_for_nchw_conv<2>( | |||||
| fptr, packed_weight, ic, fh, fw, oc_block); | |||||
| } | } | ||||
| } | } | ||||
| template <size_t filter, BiasMode bias_mode, typename Op, int stride> | template <size_t filter, BiasMode bias_mode, typename Op, int stride> | ||||
| @@ -208,7 +208,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) + | ||||
| group_id * oc * ic * fh * fw2 + | group_id * oc * ic * fh * fw2 + | ||||
| oc_idx * ic * fh * fw2; | oc_idx * ic * fh * fw2; | ||||
| conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>( | |||||
| int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, | |||||
| stride>( | |||||
| sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, | sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, | ||||
| ow, op); | ow, op); | ||||
| } | } | ||||
| @@ -93,8 +93,8 @@ void do_weight_trans(const WorkspaceBundle& bundle, | |||||
| const int fw2 = round_up(fw, 4); | const int fw2 = round_up(fw, 4); | ||||
| auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)); | auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)); | ||||
| auto origin_weight = kern_param.filter<dt_int8>(); | auto origin_weight = kern_param.filter<dt_int8>(); | ||||
| pack_weight_int8_nchw_nchw44_dot(packed_weight, origin_weight, oc, ic, fh, | |||||
| fw, fw2); | |||||
| dot_direct_nchw_nchw44::pack_weight_int8_nchw_nchw44_dot( | |||||
| packed_weight, origin_weight, oc, ic, fh, fw, fw2); | |||||
| } | } | ||||
| template <size_t filter, BiasMode bias_mode, typename Op, int stride> | template <size_t filter, BiasMode bias_mode, typename Op, int stride> | ||||
| @@ -147,7 +147,7 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) + | ||||
| ncb_index.thread_id * tmp_size; | ncb_index.thread_id * tmp_size; | ||||
| } | } | ||||
| pack_src_int8_nchw_nchw44_dot<stride>( | |||||
| dot_direct_nchw_nchw44::pack_src_int8_nchw_nchw44_dot<stride>( | |||||
| sptr, origin_sptr, ph, pw, remain_right_pad, | sptr, origin_sptr, ph, pw, remain_right_pad, | ||||
| ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, | ||||
| src_bottom_pad, ic, ih * iw, tmp_ptr); | src_bottom_pad, ic, ih * iw, tmp_ptr); | ||||
| @@ -164,7 +164,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, | |||||
| float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | ||||
| float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | ||||
| Op op(scale_bias, scale_dst); | Op op(scale_bias, scale_dst); | ||||
| conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, filter, stride>( | |||||
| dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot<bias_mode, Op, | |||||
| filter, stride>( | |||||
| sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, | sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, | ||||
| oh_block_real, ow, op); | oh_block_real, ow, op); | ||||
| } | } | ||||
| @@ -20,83 +20,15 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| using namespace megdnn; | |||||
| using namespace arm_common; | |||||
| namespace { | |||||
| namespace megdnn { | |||||
| namespace arm_common { | |||||
| namespace dot_direct_nchw_nchw44 { | |||||
| template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, | template <int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, | ||||
| int stride, typename T, typename T2, typename T3, typename T4> | int stride, typename T, typename T2, typename T3, typename T4> | ||||
| struct ShiftCalHelper { | struct ShiftCalHelper { | ||||
| static void impl(T& c, T2& src, T3& weight); | static void impl(T& c, T2& src, T3& weight); | ||||
| }; | }; | ||||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, stride, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step * 2], weight[0][weight_idx], \ | |||||
| src[0][(src_idx + step) / 4]); \ | |||||
| c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[1][step * 2], weight[1][weight_idx], \ | |||||
| src[0][(src_idx + step) / 4]); \ | |||||
| c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step * 2 + 1], weight[0][weight_idx], \ | |||||
| src[1][(src_idx + step) / 4]); \ | |||||
| c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[1][step * 2 + 1], weight[1][weight_idx], \ | |||||
| src[1][(src_idx + step) / 4]); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, int stride, typename T, | |||||
| typename T2, typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, stride, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step * 2], weight[0][weight_idx], \ | |||||
| src[0][(src_idx + step) / 4]); \ | |||||
| c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step * 2 + 1], weight[0][weight_idx], \ | |||||
| src[1][(src_idx + step) / 4]); | |||||
| UNROLL_CALL_RAW(4, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); \ | |||||
| c[1][step] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[1][step], weight[1][weight_idx], src[(src_idx + step) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, typename Func, typename T, typename T2, | |||||
| typename T3, typename T4> | |||||
| struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | |||||
| static void impl(T& c, T2& src, T3& weight) { | |||||
| #define cb(step) \ | |||||
| c[0][step] = Func::template impl<(src_idx + step) % 4>( \ | |||||
| c[0][step], weight[0][weight_idx], src[(src_idx + step) / 4]); | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| #undef cb | |||||
| } | |||||
| }; | |||||
| template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, | template <int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, | ||||
| int stride, typename T, typename T2, typename T3> | int stride, typename T, typename T2, typename T3> | ||||
| inline void cal_helper(T& c, T2& src, T3& weight) { | inline void cal_helper(T& c, T2& src, T3& weight) { | ||||
| @@ -133,490 +65,12 @@ struct KerNeonDotXXs2Nchw44Int8 { | |||||
| int iw, int ld_dst_oc, const Op& op); | int iw, int ld_dst_oc, const Op& op); | ||||
| }; | }; | ||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block, int stride> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||||
| stride> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int filter_hight = 2; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 1; | |||||
| constexpr int src_reg = 1; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 1; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[2][src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw, stride); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw, stride); | |||||
| load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block, int stride> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||||
| stride> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int filter_hight = 3; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 1; | |||||
| constexpr int src_reg = 1; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 1; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[2][src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw, stride); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw, stride); | |||||
| load_helper<weight_reg, 1 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 2 | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>( | |||||
| src, src_ptr + 2 * iw, stride); | |||||
| load_helper<weight_reg, 2 * simd_len, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block, int stride> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||||
| stride> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int filter_hight = 5; | |||||
| constexpr int filter_width = 8; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 1; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[2][src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| #define cb(step) \ | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \ | |||||
| stride); \ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||||
| weight); \ | |||||
| cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(5, cb); | |||||
| #undef cb | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += 5 * 32; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| /** | |||||
| * oc = 8, ow = 8 | |||||
| * dot 4 element, pad last filter and do twice dot every row filter, filter like | |||||
| * below | |||||
| * -------------------------- | |||||
| * |x, x, x, x,| x, x, x, 0 | | |||||
| * -------------------------- | |||||
| **/ | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block, int stride> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||||
| stride> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int filter_hight = 7; | |||||
| constexpr int filter_width = 8; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 1; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[2][src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| #define cb(step) \ | |||||
| load_helper<src_reg, 0, simd_len, 2, Vld1q_s8>(src, src_ptr + step * iw, \ | |||||
| stride); \ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||||
| weight); \ | |||||
| cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, cb); | |||||
| #undef cb | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += 7 * 32; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| ////////////////////stride 1/////////////////// | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 2; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 3; | |||||
| constexpr int filter_width = 4; | |||||
| constexpr int weight_reg = 3; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 1 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 1 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| // row 2 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 2 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 5; | |||||
| constexpr int filter_width = 8; | |||||
| constexpr int src_reg = 3; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| #define cb(step) \ | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||||
| src, src_ptr + step * iw * pack_iw_len, 0); \ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||||
| weight); \ | |||||
| cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(5, cb); | |||||
| #undef cb | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block, | |||||
| 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 7; | |||||
| constexpr int filter_width = 8; | |||||
| constexpr int src_reg = 3; | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; | |||||
| init_ocx_ow8<c_dim, bias_mode, ow_block>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| #define cb(step) \ | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( \ | |||||
| src, src_ptr + step * iw * pack_iw_len, 0); \ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( \ | |||||
| weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ | |||||
| weight); \ | |||||
| cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); | |||||
| UNROLL_CALL_RAW(7, cb); | |||||
| #undef cb | |||||
| src_ptr += ic_stride; | |||||
| weight_ptr += filter_hight * filter_width * oc_step; | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <int stride> | template <int stride> | ||||
| void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, | void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, | ||||
| const int, const int pw, const int, | const int, const int pw, const int, | ||||
| const int ih, const int iw, const int iw2, | const int ih, const int iw, const int iw2, | ||||
| const int pad_top, const int pad_bottom, | const int pad_top, const int pad_bottom, | ||||
| const int ic, const int ic_stride, int8_t*) { | |||||
| constexpr int ic_step = 1; | |||||
| rep_step(ic_idx, ic, ic_step) { | |||||
| const int8_t* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, | |||||
| sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); | |||||
| sptr_base += iw2 * pad_top * ic_step; | |||||
| rep(ih_idx, ih) { | |||||
| memcpy(sptr_base + pw * ic_step, sptr, | |||||
| sizeof(int8_t) * iw * ic_step); | |||||
| sptr_base += iw2 * ic_step; | |||||
| sptr += iw * ic_step; | |||||
| } | |||||
| sptr_base += iw2 * pad_bottom * ic_step; | |||||
| } | |||||
| } | |||||
| template <> | |||||
| void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, | |||||
| const int8_t* sptr_origin, const int, | |||||
| const int pw, const int, const int ih, | |||||
| const int iw, const int iw2, | |||||
| const int pad_top, const int pad_bottom, | |||||
| const int ic, const int ic_stride, | |||||
| int8_t* temp_ptr) { | |||||
| static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||||
| 2, 3, 4, 5, 3, 4, 5, 6}; | |||||
| uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); | |||||
| constexpr int iw_step = 16; | |||||
| constexpr int pack_iw_len = 4; | |||||
| const int iw_with_pad = iw + 2 * pw; | |||||
| const int iw_with_pad_end = iw_with_pad / iw_step * iw_step; | |||||
| rep(ic_idx, ic) { | |||||
| const int8_t* sptr = sptr_origin + ic_idx * ic_stride; | |||||
| memset(sptr_base, 0, | |||||
| sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * | |||||
| pack_iw_len); | |||||
| sptr_base += iw2 * pad_top * pack_iw_len; | |||||
| rep(ih_idx, ih) { | |||||
| memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); | |||||
| memcpy(temp_ptr + pw, sptr, sizeof(int8_t) * iw); | |||||
| for (int iw_idx = 0; iw_idx < iw_with_pad_end; iw_idx += iw_step) { | |||||
| int8x16_t src[4]; | |||||
| int8x16_t dst[4]; | |||||
| src[0] = vld1q_s8(temp_ptr + iw_idx); | |||||
| src[1] = vld1q_s8(temp_ptr + iw_idx + 4); | |||||
| src[2] = vld1q_s8(temp_ptr + iw_idx + 8); | |||||
| src[3] = vld1q_s8(temp_ptr + iw_idx + 12); | |||||
| dst[0] = vqtbl1q_s8(src[0], tbl_idx); | |||||
| dst[1] = vqtbl1q_s8(src[1], tbl_idx); | |||||
| dst[2] = vqtbl1q_s8(src[2], tbl_idx); | |||||
| dst[3] = vqtbl1q_s8(src[3], tbl_idx); | |||||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 0, dst[0]); | |||||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 16, dst[1]); | |||||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 32, dst[2]); | |||||
| vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); | |||||
| } | |||||
| for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { | |||||
| *(sptr_base + iw_idx * pack_iw_len + 0) = | |||||
| *(temp_ptr + iw_idx + 0); | |||||
| *(sptr_base + iw_idx * pack_iw_len + 1) = | |||||
| *(temp_ptr + iw_idx + 1); | |||||
| *(sptr_base + iw_idx * pack_iw_len + 2) = | |||||
| *(temp_ptr + iw_idx + 2); | |||||
| *(sptr_base + iw_idx * pack_iw_len + 3) = | |||||
| *(temp_ptr + iw_idx + 3); | |||||
| } | |||||
| sptr_base += iw2 * pack_iw_len; | |||||
| sptr += iw; | |||||
| } | |||||
| sptr_base += iw2 * pad_bottom * pack_iw_len; | |||||
| } | |||||
| } | |||||
| const int ic, const int ic_stride, int8_t*); | |||||
| static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, | static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, | ||||
| const int8_t* src_ptr, | const int8_t* src_ptr, | ||||
| @@ -663,117 +117,15 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, | |||||
| } | } | ||||
| template <BiasMode bias_mode, typename Op, int filter_size, int stride> | template <BiasMode bias_mode, typename Op, int filter_size, int stride> | ||||
| static void conv_direct_int8_nchw_nchw44_dot( | |||||
| const int8_t* src, const int8_t* filter, const int32_t* bias, | |||||
| int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, | |||||
| const int iw, const int oh, const int oh_block, const int ow, | |||||
| const Op& op) { | |||||
| MEGDNN_MARK_USED_VAR(temp); | |||||
| constexpr int fh = filter_size; | |||||
| constexpr int fw = (filter_size + 3) / 4 * 4; | |||||
| #if MEGDNN_AARCH64 | |||||
| constexpr int big_oc_step = 8; | |||||
| #else | |||||
| constexpr int big_oc_step = 4; | |||||
| #endif | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ih_step = 1; | |||||
| constexpr int oh_step = 1; | |||||
| constexpr int ow_step = 8; | |||||
| constexpr int stride_h = stride; | |||||
| constexpr int stride_w = stride; | |||||
| constexpr int pack_iw_len = stride == 2 ? 1 : 4; | |||||
| const int img_stride = oh * ow; | |||||
| const int ow_end = ow / ow_step * ow_step; | |||||
| const int ow_remain = ow - ow_end; | |||||
| const int oc_end = oc / big_oc_step * big_oc_step; | |||||
| const int oc_remain = oc - oc_end; | |||||
| const int ld_dst_oc = oc_step * img_stride; | |||||
| using remain_fun = | |||||
| std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, | |||||
| int ih, int iw, int ld_dst_oc, const Op& op)>; | |||||
| remain_fun kern_big_oc_remain = nullptr; | |||||
| remain_fun kern_small_oc_remain = nullptr; | |||||
| switch (ow_remain) { | |||||
| #define cb(step) \ | |||||
| case step: \ | |||||
| kern_big_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| big_oc_step, ow_step, stride>::impl; \ | |||||
| kern_small_oc_remain = \ | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, step, filter_size, \ | |||||
| oc_step, ow_step, stride>::impl; \ | |||||
| break; | |||||
| UNROLL_CALL_RAW(8, cb); | |||||
| default: | |||||
| megdnn_assert(0, "no remain %d for kern", ow_remain); | |||||
| } | |||||
| for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) { | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| big_oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_big_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, iw, | |||||
| ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (oc_remain > 0) { | |||||
| int oc_idx = oc_end; | |||||
| const int weight_offset = oc_idx * ic * fh * fw; | |||||
| for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { | |||||
| for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; | |||||
| KerNeonDotXXs2Nchw44Int8<bias_mode, Op, ow_step, filter_size, | |||||
| oc_step, ow_step, | |||||
| stride>::impl(src + src_offset, | |||||
| filter + weight_offset, | |||||
| bias + oc_idx, | |||||
| dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| if (ow_remain > 0) { | |||||
| const int src_offset = | |||||
| (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * | |||||
| pack_iw_len; | |||||
| const int dst_offset = | |||||
| oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; | |||||
| kern_small_oc_remain(src + src_offset, filter + weight_offset, | |||||
| bias + oc_idx, dst + dst_offset, ic, ih, | |||||
| iw, ld_dst_oc, op); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
| const int32_t* bias, int32_t* temp, | |||||
| int8_t* dst, const int oc, const int ic, | |||||
| const int ih, const int iw, const int oh, | |||||
| const int oh_block, const int ow, | |||||
| const Op& op); | |||||
| } // namespace dot_direct_nchw_nchw44 | |||||
| } // namespace arm_common | |||||
| } // namespace megdnn | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -2344,7 +2344,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||||
| #endif | #endif | ||||
| std::vector<conv_bias::TestArg> gemv_args; | std::vector<conv_bias::TestArg> gemv_args; | ||||
| for (auto&& arg : args) | for (auto&& arg : args) | ||||
| if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
| if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
| gemv_args.emplace_back(arg); | gemv_args.emplace_back(arg); | ||||
| } | } | ||||
| check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | ||||
| @@ -2361,7 +2361,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | |||||
| #endif | #endif | ||||
| std::vector<conv_bias::TestArg> gemv_args; | std::vector<conv_bias::TestArg> gemv_args; | ||||
| for (auto&& arg : args) | for (auto&& arg : args) | ||||
| if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
| if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
| gemv_args.emplace_back(arg); | gemv_args.emplace_back(arg); | ||||
| } | } | ||||
| check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "test/arm_common/fixture.h" | #include "test/arm_common/fixture.h" | ||||
| @@ -30,8 +31,7 @@ TEST_F(ARM_COMMON, MATRIX_MUL_INT8x8x16) { | |||||
| TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) { | TEST_F(ARM_COMMON, MATRIX_MUL_QUINT8) { | ||||
| matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127), | matrix_mul::check_matrix_mul(dtype::Quantized8Asymm(1.2f, (uint8_t)127), | ||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), | |||||
| {}, | |||||
| dtype::Quantized8Asymm(1.3f, (uint8_t)129), {}, | |||||
| handle()); | handle()); | ||||
| } | } | ||||
| @@ -232,8 +232,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | |||||
| Checker<MatrixMul> checker(handle()); | Checker<MatrixMul> checker(handle()); | ||||
| using Param = MatrixMul::Param; | using Param = MatrixMul::Param; | ||||
| checker.set_before_exec_callback( | |||||
| AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||||
| checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||||
| std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | ||||
| checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | ||||
| @@ -251,7 +250,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | |||||
| .set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
| .execs({A, B, {}}); | .execs({A, B, {}}); | ||||
| }; | }; | ||||
| // M = 1 | // M = 1 | ||||
| for (size_t N : {1, 10, 16, 33, 64}) | for (size_t N : {1, 10, 16, 33, 64}) | ||||
| for (size_t K : {7, 512, 1024}) | for (size_t K : {7, 512, 1024}) | ||||
| @@ -263,8 +262,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) { | |||||
| Checker<MatrixMul> checker(handle()); | Checker<MatrixMul> checker(handle()); | ||||
| using Param = MatrixMul::Param; | using Param = MatrixMul::Param; | ||||
| checker.set_before_exec_callback( | |||||
| AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||||
| checker.set_before_exec_callback(AlgoChecker<MatrixMul>("ARM_COMMON_GEVM")); | |||||
| checker.set_epsilon(1e-2); | checker.set_epsilon(1e-2); | ||||
| auto run = [&](size_t M, size_t K, size_t N) { | auto run = [&](size_t M, size_t K, size_t N) { | ||||
| @@ -276,7 +274,7 @@ TEST_F(ARM_COMMON, FP32_GEVM) { | |||||
| B = TensorShape{N, K}; | B = TensorShape{N, K}; | ||||
| checker.set_param(param).execs({A, B, {}}); | checker.set_param(param).execs({A, B, {}}); | ||||
| }; | }; | ||||
| // M = 1 | // M = 1 | ||||
| for (size_t M : {1}) | for (size_t M : {1}) | ||||
| for (size_t K : {1000, 4096, 25088}) | for (size_t K : {1000, 4096, 25088}) | ||||
| @@ -298,15 +296,15 @@ TEST_F(ARM_COMMON, FP32_GEMV_MK4) { | |||||
| param.transposeA = false; | param.transposeA = false; | ||||
| param.transposeB = false; | param.transposeB = false; | ||||
| TensorShape A, B; | TensorShape A, B; | ||||
| A = TensorShape{M/4, K/4, 4, 4}; | |||||
| B = TensorShape{K/4, 1, 4}; | |||||
| A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
| B = TensorShape{K / 4, 1, 4}; | |||||
| checker.set_param(param).execs({A, B, {}}); | checker.set_param(param).execs({A, B, {}}); | ||||
| }; | }; | ||||
| // N = 1 | // N = 1 | ||||
| for (size_t M : {4, 16, 128, 1024}) | for (size_t M : {4, 16, 128, 1024}) | ||||
| for (size_t K : {4, 8, 12, 128, 256, 4096}) | for (size_t K : {4, 8, 12, 128, 256, 4096}) | ||||
| run(M, K); | |||||
| run(M, K); | |||||
| } | } | ||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| @@ -343,7 +341,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||||
| for (size_t M : {4, 64, 1024, 4096}) | for (size_t M : {4, 64, 1024, 4096}) | ||||
| for (size_t K : {128, 256, 1024, 4096}) | for (size_t K : {128, 256, 1024, 4096}) | ||||
| run(M, K, 1); | |||||
| run(M, K, 1); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | ||||
| @@ -372,7 +370,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | |||||
| .exec({{2, 1024}, {1024, 512}, {}}); | .exec({{2, 1024}, {1024, 512}, {}}); | ||||
| benchmarker.set_display(true); | benchmarker.set_display(true); | ||||
| } | } | ||||
| // run gemv | // run gemv | ||||
| run(12, 48, 1); | run(12, 48, 1); | ||||
| run(48, 12, 1); | run(48, 12, 1); | ||||
| @@ -396,14 +394,14 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { | |||||
| Benchmarker<MatrixMul> benchmarker(handle()); | Benchmarker<MatrixMul> benchmarker(handle()); | ||||
| benchmarker.set_times(exec_times); | benchmarker.set_times(exec_times); | ||||
| benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_param(param); | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_param(param); | |||||
| auto run = [&](size_t M, size_t K) { | auto run = [&](size_t M, size_t K) { | ||||
| printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N); | |||||
| printf("SGEMV_MK4: (%zu, %zu)\n", M, K); | |||||
| TensorShape A, B; | TensorShape A, B; | ||||
| A = TensorShape{M/4, K/4, 4, 4}; | |||||
| B = TensorShape{K/4, 1, 4}; | |||||
| A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
| B = TensorShape{K / 4, 1, 4}; | |||||
| auto time = benchmarker.exec({A, B, {}}) / exec_times; | auto time = benchmarker.exec({A, B, {}}) / exec_times; | ||||
| auto computations = 2.f * M * K * 1e-6; | auto computations = 2.f * M * K * 1e-6; | ||||
| auto perf = computations / time; | auto perf = computations / time; | ||||
| @@ -422,7 +420,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { | |||||
| // run gemv mk4 | // run gemv mk4 | ||||
| for (size_t M : {4, 64, 1024, 4096}) | for (size_t M : {4, 64, 1024, 4096}) | ||||
| for (size_t K : {128, 1024, 4096}) | for (size_t K : {128, 1024, 4096}) | ||||
| run(M, K); | |||||
| run(M, K); | |||||
| } | } | ||||
| TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | ||||
| @@ -490,7 +488,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { | |||||
| //////////////////////// gemv ////////////////////////// | //////////////////////// gemv ////////////////////////// | ||||
| for (size_t M : {8, 64, 112, 256}) { | for (size_t M : {8, 64, 112, 256}) { | ||||
| for (size_t K : {8, 64, 112, 256}) { | for (size_t K : {8, 64, 112, 256}) { | ||||
| run (M, 1, K); | |||||
| run(M, 1, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -502,10 +500,8 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { | TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { | ||||
| constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
| param::MatrixMul param; | param::MatrixMul param; | ||||
| @@ -514,7 +510,8 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { | |||||
| .set_dtype(0, dtype::Int8{}) | .set_dtype(0, dtype::Int8{}) | ||||
| .set_dtype(1, dtype::Int8{}) | .set_dtype(1, dtype::Int8{}) | ||||
| .set_dtype(2, dtype::Int32{}) | .set_dtype(2, dtype::Int32{}) | ||||
| .set_param(param).set_display(false); | |||||
| .set_param(param) | |||||
| .set_display(false); | |||||
| Benchmarker<MatrixMul> benchmarker_float(handle()); | Benchmarker<MatrixMul> benchmarker_float(handle()); | ||||
| benchmarker_float.set_display(false).set_times(RUNS); | benchmarker_float.set_display(false).set_times(RUNS); | ||||
| @@ -533,7 +530,7 @@ TEST_F(ARM_COMMON, BENCHMARK_MATRIX_MUL_INT8x8x32) { | |||||
| //////////////////////// gemv ////////////////////////// | //////////////////////// gemv ////////////////////////// | ||||
| for (size_t M : {8, 64, 112, 256}) { | for (size_t M : {8, 64, 112, 256}) { | ||||
| for (size_t K : {8, 64, 112, 256}) { | for (size_t K : {8, 64, 112, 256}) { | ||||
| run (M, 1, K); | |||||
| run(M, 1, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -618,5 +615,4 @@ TEST_F(ARM_COMMON, BENCHMARK_TRANSPOSED_MATRIX_MUL_QUINT8) { | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||