diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S index d5c66121f3..5756ff5bf5 100644 --- a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S +++ b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S @@ -21,24 +21,27 @@ MatmulInt8Neon32: add sp, sp, #116 ldr r4, [sp] // col - mov r7, #2 - ldr r8, [sp, #4] // deep16 - mul r9, r7, r8 // the sride of b ldr r7, [sp, #40] // output stride - + mov r8, #0 // output channels offset + ldr r10, [sp, #44] + cmp r10, #0 + beq L1 + ldr r6, [sp, #8] // load intpu_sums ptr if per_channel L1: cmp r4, #0 // if at the end of col ble End1 ldr r0, [sp, #-52] // reload a ptr ldr r3, [sp, #-40] // reset row counter - ldr r6, [sp, #8] // reload intpu_sums ptr + ldr r10, [sp, #44] + cmp r10, #0 + bne L2 + ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor L2: cmp r3, #0 // if at the end of row ble End2 ldr r1, [sp, #-48] // reload b ptr - ldr r8, [sp, #12] // reload weight_bias ptr ldr r5, [sp, #4] // reset deep16 vmov.i32 q6, #0 vmov.i32 q7, #0 @@ -101,7 +104,9 @@ End3: vpadd.i32 d31, d6, d7 // Add weight_bias - vld1.32 {d26}, [r8]! + ldr r9, [sp, #12] // reload weight_bias ptr + add r9, r9, r8 + vld1.32 {d26}, [r9]! vadd.i32 d28, d28, d26 vadd.i32 d29, d29, d26 vadd.i32 d30, d30, d26 @@ -111,6 +116,7 @@ End3: cmp r10, #0 bgt PerChannel +PerTensor: // Substract input_sums vld1.32 {d24, d25}, [r6]! vdup.32 d20, d24[0] @@ -124,7 +130,7 @@ End3: // Apply left shift ldr r10, [sp, #32] - ldr r11, [r10] + ldr r11, [r10]! vdup.32 q9, r11 vshl.s32 q14, q14, q9 vshl.s32 q15, q15, q9 @@ -151,7 +157,51 @@ End3: b AddDstZP PerChannel: + // Substract input_sums + vld1.32 {d24, d25, d26, d27}, [r6]! + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + // Apply left shift + ldr r10, [sp, #32] + add r10, r10, r8 + vld1.32 {d23}, [r10] + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + add r10, r10, r8 + vld1.32 {d22}, [r10] + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + // Apply right shift + ldr r10, [sp, #36] + add r10, r10, r8 + vld1.32 {d21}, [r10] + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 AddDstZP: // Add the destination zero point @@ -218,15 +268,16 @@ EndWrite: End2: sub r4, r4, #2 // b col counter -= 2 - ldr r1, [sp, #-48] // b ptr + stride - add r1, r1, r9 - str r1, [sp, #-48] - ldr r8, [sp, #12] // weight_bias + stride - add r8, r8, #8 - str r8, [sp, #12] - ldr r2, [sp, #-44] // dst ptr + offset - add r2, r2, #2 + ldr r1, [sp, #-48] // load b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + ldr r2, [sp, #-44] // load dst ptr + add r2, r2, #2 // dst ptr + offset str r2, [sp, #-44] + add r8, r8, #8 // output channels offset + 2*sizeof(int) b L1 End1: diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index e660d5d4f8..f1e87c8682 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -1029,14 +1029,6 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param) { int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false; - - if (is_per_channel == 1) { - return MatMulInt8_4x2_r( - packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift, - right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], true); - } - #ifdef ENABLE_ARM32 MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],