diff --git a/mindspore/lite/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore/lite/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S index 7dc621c7d5..c3cf470ab0 100644 --- a/mindspore/lite/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S +++ b/mindspore/lite/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S @@ -9,8 +9,8 @@ #endif // void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, -// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier, -// size_t shift_before, size_t shift_after); +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, +// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); // r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset // r8: input_sum, r10: act_min, r11: act_max, r10: out_zp, r11: out_multiplier, r10: shift_before, r11: shift_after IndirectGemmInt8_2x4: @@ -24,7 +24,7 @@ IndirectGemmInt8_2x4: veor q15, q15, q15 .endm - // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" // according to https://stackoverflow.com/questions/53625807 // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway // clang's rule seems more simple, though there are no subroutine calls here @@ -127,10 +127,6 @@ IndirectGemmInt8_2x4: vpadal.s16 q14, q6 vpadal.s16 q15, q7 - // load sum - ldr r10, [sp, #16] - vld1.32 q0[], [r10]! - vld1.32 q1[], [r10]! // pairwise add vpadd.i32 d16, d16, d17 vpadd.i32 d18, d18, d19 @@ -145,8 +141,27 @@ IndirectGemmInt8_2x4: vpadd.i32 d17, d20, d22 vpadd.i32 d24, d24, d26 vpadd.i32 d25, d28, d30 + + // load sum + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSum + ldr r10, [sp, #16] + ldr lr, [sp, #48] + cmp lr, #0 + beq SymSum + ldr lr, [sp, #52] + vld1.32 q0, [r10] + add r10, r10, lr + vld1.32 q1, [r10] + b AddSum + SymSum: + vld1.32 q0[], [r10]! + vld1.32 q1[], [r10]! + AddSum: vsub.i32 q8, q8, q0 vsub.i32 q12, q12, q1 + NoSum: cmp r3, #0 beq NoBias vld1.32 q2, [r3] @@ -154,18 +169,30 @@ IndirectGemmInt8_2x4: vadd.i32 q12, q12, q2 NoBias: - ldr r10, [sp, #36] - vdup.32 q3, r10 + ldr lr, [sp, #48] + cmp lr, #0 + bne PerChannel + ldr lr, [sp, #36] + vld1.32 q3[], [lr] + ldr lr, [sp, #32] + vld1.32 q4[], [lr] + ldr lr, [sp, #40] + vld1.32 q5[], [lr] + b QuantizeStart + PerChannel: + ldr lr, [sp, #36] + vld1.32 q3, [lr] + ldr lr, [sp, #32] + vld1.32 q4, [lr] + ldr lr, [sp, #40] + vld1.32 q5, [lr] + QuantizeStart: vshl.s32 q8, q8, q3 vshl.s32 q12, q12, q3 - ldr r10, [sp, #32] - vdup.32 q4, r10 vqrdmulh.s32 q8, q8, q4 vqrdmulh.s32 q12, q12, q4 - ldr r10, [sp, #40] - vdup.32 q5, r10 vand q3, q5, q8 vshr.s32 q3, q3, #31 vqadd.s32 q8, q8, q3 @@ -192,7 +219,7 @@ IndirectGemmInt8_2x4: vqmovn.s32 d30, q8 vqmovn.s32 d31, q12 - vqmovn.s16 d0, q14 + vqmovn.s16 d0, q15 // prefetching is not prefered while writing results in spite of cache missings // you could try prfm pstl2strm @@ -234,6 +261,26 @@ IndirectGemmInt8_2x4: cmp r6, #4 ble LoopOcEnd + ldr lr, [sp, #48] + cmp lr, #0 + beq NoChannelForward + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSumForward + ldr lr, [sp, #16] + add lr, lr, #16 + str lr, [sp, #16] + NoSumForward: + ldr lr, [sp, #36] + add lr, lr, #16 + str lr, [sp, #36] + ldr lr, [sp, #32] + add lr, lr, #16 + str lr, [sp, #32] + ldr lr, [sp, #40] + add lr, lr, #16 + str lr, [sp, #40] + NoChannelForward: sub r6, r6, #4 cmp r3, #0 beq NoStepFowrard diff --git a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S b/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S index 37711436d9..bdbfa738b2 100644 --- a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S +++ b/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S @@ -8,8 +8,8 @@ #endif // void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, -// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, -// int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel); +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, +// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset IndirectGemmInt8_4x4: @@ -52,10 +52,7 @@ IndirectGemmInt8_4x4: ldr x19, [sp, #48] ldr x20, [sp, #56] ldr x21, [sp, #64] - - add x24, x6, #3 - mov x23, #4 - sdiv x23, x24, x23 + ldr x23, [sp, #72] mul x5, x4, x5 mov x4, #1 @@ -218,10 +215,10 @@ IndirectGemmInt8_4x4: // load sum mov x22, x15 cbz x21, SymSum - ld1r {v8.4s}, [x22], x23 - ld1r {v9.4s}, [x22], x23 - ld1r {v10.4s}, [x22], x23 - ld1r {v11.4s}, [x22] + ld1 {v8.4s}, [x22], x23 + ld1 {v9.4s}, [x22], x23 + ld1 {v10.4s}, [x22], x23 + ld1 {v11.4s}, [x22] b AddSum SymSum: ld1r {v8.4s}, [x22], #4 diff --git a/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S b/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S index be79622f64..2c43efb982 100644 --- a/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S +++ b/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S @@ -9,7 +9,7 @@ // void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, // size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, -// int32_t *shift_before, int32_t *shift_after); +// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); // x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset // we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later) // mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A) @@ -148,10 +148,7 @@ IndirectGemmInt8_24x4_dp: ldr x19, [sp, #48] ldr x20, [sp, #56] ldr x21, [sp, #64] - - add x24, x6, #3 - mov x23, #4 - sdiv x23, x24, x23 + ldr x23, [sp, #72] mul x5, x4, x5 mov x4, #1 diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h index 56f2bb9e42..3d35e6fc37 100644 --- a/mindspore/lite/nnacl/int8/common_func.h +++ b/mindspore/lite/nnacl/int8/common_func.h @@ -37,18 +37,25 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * int output_channel, int input_step, int8_t input_zp); void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, + size_t oc4, size_t offset); +#endif + +#ifdef ENABLE_ARM32 +void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, + size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, + size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before, + int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); #endif #ifdef ENABLE_ARM64 void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); -void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, - size_t oc4, size_t offset); void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before, - int32_t *shift_after, size_t asymmetric, size_t per_channel); + int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index f1e87c8682..6f8a2a5c60 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -28,15 +28,21 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0]; int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0]; + int oc4 = UP_DIV(output_channel, C4NUM); #ifdef ENABLE_ARM64 size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC; size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel, output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier, - shift_before, shift_after, asymmetric, per_channel); + shift_before, shift_after, asymmetric, per_channel, oc4 * C4NUM * sizeof(int32_t)); +#elif ENABLE_ARM32 + size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC; + size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel, + output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier, + shift_before, shift_after, asymmetric, per_channel, oc4 * C4NUM * sizeof(int32_t)); #else - int oc4 = UP_DIV(output_channel, C4NUM); int tile_num = conv_param->tile_num_; int plane_c4 = UP_DIV(kernel_plane, C4NUM); for (int oc = 0; oc < output_channel; oc++) { @@ -201,7 +207,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { int oc4 = UP_DIV(oc, C4NUM); -#ifdef ENABLE_ARM64 +#ifdef ENABLE_ARM IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t)); #else const int input_unit_square = 16; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 3b4bf822e3..e7cf131d95 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -33,10 +33,10 @@ using mindspore::schema::PrimitiveType_Conv2D; namespace mindspore::kernel { void ConvolutionInt8CPUKernel::CheckSupportOptimize() { tile_num_ = 24; - // #ifdef ENABLE_ARM32 - // tile_num_ = 2; - // support_optimize_ = false; - // #endif +#ifdef ENABLE_ARM32 + tile_num_ = 2; + support_optimize_ = false; +#endif #ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; @@ -380,7 +380,11 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector & int dilation_w = conv_param->dilation_w_; kernel::LiteKernel *kernel; if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { +#ifdef ENABLE_ARM32 + kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); +#else kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); +#endif } else if (kernel_h == 1 && kernel_w == 1) { kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else {