| @@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | |||||
| } | } | ||||
| }; | }; | ||||
| ////////////////////stride 1/////////////////// | ////////////////////stride 1/////////////////// | ||||
| /* KerNeonDotXXs2Nchw44Int8 specialisation with the fourth and the last template | |||||
| argument fixed to 1 (presumably filter size 1 and stride 1, which is what the | |||||
| local constants filter_hight and stride say). impl() initialises the | |||||
| accumulators c from bias_ptr, runs one load + cal_helper step per input | |||||
| channel, and finally passes c, op, dst_ptr and ld_dst_oc to the store helper. | |||||
| */ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||||
| 1> { | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_hight = 1; | |||||
| constexpr int filter_width = 4; /* NOTE(review): 4 although the height is 1; presumably the weight row is stored padded to 4 values for the dot instruction - confirm */ | |||||
| constexpr int weight_reg = 2; | |||||
| constexpr int src_reg = 2; | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int ic_step = 1; | |||||
| constexpr int pack_iw_len = 4; | |||||
| constexpr int simd_len = 16; | |||||
| const int ld_bias = oc_step; | |||||
| const int ic_stride = ih * iw * pack_iw_len; /* src_ptr advance per input channel, see the loop below */ | |||||
| const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; /* handed to load_helper as the last argument of the weight load */ | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; /* accumulators, filled by init_ocx_ow8 and consumed by the store helper */ | |||||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t weight[c_dim][weight_reg]; | |||||
| // row 0 | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| weight, weight_ptr, ld_weight_oc); | |||||
| cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
| weight); | |||||
| src_ptr += ic_stride; /* move on to the next input channel */ | |||||
| weight_ptr += filter_hight * filter_width * oc_step; /* 16 int8 values, i.e. one simd_len-sized step, while weight_reg is 2 - NOTE(review): confirm the second weight vector is really needed */ | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | ||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
| @@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
| GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
| #define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 3) \ | GET_BIAS_MODE_PARAM(stride, 3) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 5) \ | GET_BIAS_MODE_PARAM(stride, 5) \ | ||||
| @@ -58,6 +58,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> { | |||||
| } | } | ||||
| }; | }; | ||||
| /* Placeholder KerNeonDotXXs2Nchw44Int8 specialisation with the fourth template | |||||
| argument 1 and the last one 2 (presumably filter size 1 with stride 2 - | |||||
| confirm). impl() ignores every argument and always trips megdnn_assert(0, ...), | |||||
| so it only exists to make the template instantiable for this combination. | |||||
| */ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int ow_block> | |||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||||
| 2> { | |||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
| static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||||
| int, int, int, const Op&) { | |||||
| megdnn_assert(0, "not impl"); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | ||||
| int ow_block> | int ow_block> | ||||
| struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
| @@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
| GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
| #define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
| GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 3) \ | GET_BIAS_MODE_PARAM(stride, 3) \ | ||||
| GET_BIAS_MODE_PARAM(stride, 5) \ | GET_BIAS_MODE_PARAM(stride, 5) \ | ||||
| @@ -112,6 +112,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> { | |||||
| static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); | static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); | ||||
| }; | }; | ||||
| /* KerNeonXXs2NchwNchw44 specialisation with the fourth and the last template | |||||
| argument fixed to 1 (presumably filter size 1 and stride 1, matching the local | |||||
| constants filter_height and stride). impl() initialises the accumulators c | |||||
| from bias_ptr, performs one weight load, one source load and one cal_helper | |||||
| call per input channel, then passes c, op, dst_ptr and ld_dst_oc to the store | |||||
| helper. | |||||
| */ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> { | |||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
| const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
| int iw, int ld_dst_oc, const Op& op) { | |||||
| constexpr int stride = 1; | |||||
| constexpr int filter_height = 1; | |||||
| constexpr int filter_width = 4; /* NOTE(review): 4 although the height is 1; presumably the weight row is stored padded - confirm */ | |||||
| constexpr int oc_step = 4; | |||||
| constexpr int loop_ic_step = 1; | |||||
| constexpr int simd_len = 16; | |||||
| constexpr int pack_iw_len = 16; | |||||
| constexpr int src_reg = 8; | |||||
| constexpr int weight_reg = 1; | |||||
| const int ic_stride = ih * iw * pack_iw_len; /* offset between consecutive input channels, applied via ic_idx below */ | |||||
| const int ld_weight_oc = oc_step * filter_height * filter_width * ic; /* handed to load_helper as the last argument of the weight load */ | |||||
| constexpr int c_dim = OCHelper<oc_block>::val; | |||||
| int32x4_t c[c_dim][8]; /* accumulators, filled by init_ocx_ow8 and consumed by the store helper */ | |||||
| init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
| for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
| const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; /* recomputed from the base pointer each iteration; src_ptr itself is never advanced */ | |||||
| int8x16_t src[src_reg]; | |||||
| int8x16_t dot4_weight[c_dim][weight_reg]; | |||||
| int16x8_t temp_c[4]; /* scratch handed to cal_helper; not read afterwards in this function */ | |||||
| load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
| dot4_weight, weight_ptr, ld_weight_oc); | |||||
| load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
| src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||||
| cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||||
| weight_ptr += oc_step * filter_height * filter_width; /* 16 int8 values, the same amount as one simd_len-sized weight load */ | |||||
| } | |||||
| store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
| c, op, dst_ptr, ld_dst_oc); | |||||
| } | |||||
| }; | |||||
| template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | ||||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | ||||
| static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
| @@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||||
| INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
| #define INSTANCE_CONV_KERN(stride) \ | #define INSTANCE_CONV_KERN(stride) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||||
| INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | ||||
| @@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> { | |||||
| } | } | ||||
| }; | }; | ||||
| /* Placeholder KerNeonXXs2NchwNchw44 specialisation for the fourth template | |||||
| argument 1 with the stride left as a template parameter. impl() ignores every | |||||
| argument and always trips megdnn_assert(0, ...); the assertion text names the | |||||
| 1x1 s2 case as not implemented. | |||||
| */ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
| int stride> | |||||
| struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> { | |||||
| static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||||
| int, int, int, const Op&) { | |||||
| megdnn_assert(0, "not impl nchw_nchw44 1x1 s2"); | |||||
| } | |||||
| }; | |||||
| enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | ||||
| template <PACK_MODE mode> | template <PACK_MODE mode> | ||||
| MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, | MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, | ||||
| @@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> { | |||||
| INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
| #define INSTANCE_CONV_KERN(stride) \ | #define INSTANCE_CONV_KERN(stride) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||||
| INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | ||||
| INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | ||||
| @@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( | |||||
| #define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
| switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
| case 1: \ | |||||
| GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
| break; \ | |||||
| case 2: \ | case 2: \ | ||||
| GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
| break; \ | break; \ | ||||
| @@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( | |||||
| #define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
| switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
| case 1: \ | |||||
| GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
| break; \ | |||||
| case 2: \ | case 2: \ | ||||
| GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
| break; \ | break; \ | ||||
| @@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>( | |||||
| nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | ||||
| bool ok_src_dst = | bool ok_src_dst = | ||||
| fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | ||||
| bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
| (fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||||
| fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||||
| bool ok_filter = | |||||
| fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
| (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||||
| fm.spatial[0] == 7 || | |||||
| (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | ||||
| fm.stride[0] == fm.stride[1] && | fm.stride[0] == fm.stride[1] && | ||||
| (fm.stride[0] == 1 || fm.stride[1] == 2); | (fm.stride[0] == 1 || fm.stride[1] == 2); | ||||
| @@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>( | |||||
| nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | ||||
| bool ok_src_dst = | bool ok_src_dst = | ||||
| fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | ||||
| bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
| (fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||||
| fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||||
| bool ok_filter = | |||||
| fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
| (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||||
| fm.spatial[0] == 7 || | |||||
| (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||||
| bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | ||||
| fm.stride[0] == fm.stride[1] && | fm.stride[0] == fm.stride[1] && | ||||
| (fm.stride[0] == 1 || fm.stride[1] == 2); | (fm.stride[0] == 1 || fm.stride[1] == 2); | ||||
| @@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { | |||||
| handle(), "S8_CONV_NCHW_NCHW44"); | handle(), "S8_CONV_NCHW_NCHW44"); | ||||
| } | } | ||||
| /* Runs checker_conv_bias_qint8x8x8 against the algorithm named | |||||
| S8_CONV_NCHW_NCHW44 using the cases produced by get_nchw44_conv_bias_args for | |||||
| the kernel list {1}. NOTE(review): the literal 1 after BR_AND_NO_BIASMODE is | |||||
| presumably the stride, as the S1_F1 suffix suggests - confirm against the | |||||
| helper's signature. | |||||
| */ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) { | |||||
| checker_conv_bias_qint8x8x8( | |||||
| get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, | |||||
| false, true), | |||||
| handle(), "S8_CONV_NCHW_NCHW44"); | |||||
| } | |||||
| /*****************************quint8 direct****************************/ | /*****************************quint8 direct****************************/ | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | ||||
| checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | ||||
| @@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||||
| checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | ||||
| } | } | ||||
| /* Same case generation as the non-dot S1_F1 test (kernel list {1}), but every | |||||
| generated argument has its format switched to NCHW44_DOT before the checker | |||||
| is run against the algorithm named ARMDOTS8_NCHW_NCHW44. | |||||
| */ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) { | |||||
| auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, | |||||
| 1, false, true); | |||||
| for (auto&& arg : args) { | |||||
| arg.param.format = param::ConvBias::Format::NCHW44_DOT; /* rewrite the generated format in place */ | |||||
| } | |||||
| checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | ||||
| checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
| {2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
| @@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { | |||||
| benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | ||||
| {1, {7}}, data_type); | {1, {7}}, data_type); | ||||
| }; | }; | ||||
| bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true); | |||||
| bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); | bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); | ||||
| bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | ||||
| bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | ||||
| @@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
| if (param.format == Param::Format::NCHW44 || | if (param.format == Param::Format::NCHW44 || | ||||
| param.format == Param::Format::NCHW44_DOT) { | param.format == Param::Format::NCHW44_DOT) { | ||||
| //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | ||||
| if (filter_shape[1] == 1 && filter_shape[2] == 1) { | |||||
| if (filter_shape[1] == 1 && filter_shape[2] == 1 && | |||||
| filter_shape.ndim == 6) { | |||||
| group *= 4; | group *= 4; | ||||
| } | } | ||||
| size_t computation = dst_shape.total_nr_elems() * fh * fw * | size_t computation = dst_shape.total_nr_elems() * fh * fw * | ||||