
Support per-channel quantization of int8 matmul for arm32 platform

tags/v1.0.0
zhanyuan, 5 years ago
commit cae7a22b51

2 changed files, 67 additions and 24 deletions:
  1. mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S  (+67 / -16)
  2. mindspore/lite/nnacl/int8/conv_int8.c  (+0 / -8)

mindspore/lite/nnacl/assembly/arm32/MatmulInt8.S  (+67 / -16)

@@ -21,24 +21,27 @@ MatmulInt8Neon32:
     add sp, sp, #116

     ldr r4, [sp]        // col
-    mov r7, #2
-    ldr r8, [sp, #4]    // deep16
-    mul r9, r7, r8      // the stride of b
     ldr r7, [sp, #40]   // output stride
-
+    mov r8, #0          // output channels offset
+    ldr r10, [sp, #44]  // per_channel flag
+    cmp r10, #0
+    beq L1
+    ldr r6, [sp, #8]    // load input_sums ptr if per_channel
 L1:
     cmp r4, #0          // if at the end of col
     ble End1

     ldr r0, [sp, #-52]  // reload a ptr
     ldr r3, [sp, #-40]  // reset row counter
-    ldr r6, [sp, #8]    // reload input_sums ptr
+    ldr r10, [sp, #44]  // per_channel flag
+    cmp r10, #0
+    bne L2
+    ldr r6, [sp, #8]    // reload input_sums ptr if per_tensor
 L2:
     cmp r3, #0          // if at the end of row
     ble End2

     ldr r1, [sp, #-48]  // reload b ptr
-    ldr r8, [sp, #12]   // reload weight_bias ptr
     ldr r5, [sp, #4]    // reset deep16
     vmov.i32 q6, #0
     vmov.i32 q7, #0
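The branches around L1/L2 encode when the input_sums cursor is loaded: in per-channel mode the sums advance monotonically across the whole matrix, so r6 is loaded once up front and never reset; in per-tensor mode the same per-row sums are reused for every column block, so r6 is reset at the top of each column iteration. A rough C sketch of that policy (loop shape and names are my reading of the code, not nnacl API):

    #include <stdint.h>

    /* Sketch of the input_sums cursor policy implied by the L1/L2 branches
       (illustrative only; names invented, not the full kernel). */
    static void input_sums_policy(const int32_t *input_sums, int row, int col,
                                  int per_channel) {
      const int32_t *sums = per_channel ? input_sums : 0;  /* load once iff per_channel */
      for (int c = col; c > 0; c -= 2) {      /* L1: two output channels per pass */
        if (!per_channel) sums = input_sums;  /* reset iff per_tensor (bne L2)    */
        for (int r = row; r > 0; r -= 4) {    /* L2: four rows per pass           */
          /* ...accumulate, subtract sums from accumulators, requantize... */
          sums += per_channel ? 8 : 4;        /* mirrors the vld1 post-increments */
        }
      }
    }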
@@ -101,7 +104,9 @@ End3:
     vpadd.i32 d31, d6, d7

     // Add weight_bias
-    vld1.32 {d26}, [r8]!
+    ldr r9, [sp, #12]   // reload weight_bias ptr
+    add r9, r9, r8      // advance by the output channels offset
+    vld1.32 {d26}, [r9]!
     vadd.i32 d28, d28, d26
     vadd.i32 d29, d29, d26
     vadd.i32 d30, d30, d26
@@ -111,6 +116,7 @@ End3:
     cmp r10, #0
     bgt PerChannel

+PerTensor:
     // Subtract input_sums
     vld1.32 {d24, d25}, [r6]!
     vdup.32 d20, d24[0]
@@ -124,7 +130,7 @@ End3:

     // Apply left shift
     ldr r10, [sp, #32]
-    ldr r11, [r10]
+    ldr r11, [r10]!
     vdup.32 q9, r11
     vshl.s32 q14, q14, q9
     vshl.s32 q15, q15, q9
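In the per-tensor path a single shift count is broadcast to every lane (vdup.32 q9, r11) before the vector shifts. A scalar sketch of that broadcast step (names are illustrative, not from the source):

    #include <stdint.h>

    /* Per-tensor left shift: one exponent for all eight accumulator lanes,
       mirroring vdup.32 q9, r11 followed by vshl.s32 q14/q15 (illustrative). */
    static void left_shift_per_tensor(int32_t acc[8], int32_t left_shift) {
      for (int i = 0; i < 8; ++i) {
        acc[i] <<= left_shift;  /* vshl.s32 with a broadcast, non-negative count */
      }
    }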
@@ -151,7 +157,51 @@ End3:
     b AddDstZP

 PerChannel:
+    // Subtract input_sums
+    vld1.32 {d24, d25, d26, d27}, [r6]!
+    vsub.s32 d28, d28, d24
+    vsub.s32 d29, d29, d25
+    vsub.s32 d30, d30, d26
+    vsub.s32 d31, d31, d27
+
+    // Apply left shift
+    ldr r10, [sp, #32]
+    add r10, r10, r8    // per-channel left shifts at the channel offset
+    vld1.32 {d23}, [r10]
+    vshl.s32 d28, d28, d23
+    vshl.s32 d29, d29, d23
+    vshl.s32 d30, d30, d23
+    vshl.s32 d31, d31, d23
+
+    // Apply the fixed-point part of the multiplier
+    ldr r10, [sp, #28]
+    add r10, r10, r8    // per-channel multipliers at the channel offset
+    vld1.32 {d22}, [r10]
+    vqrdmulh.s32 d28, d28, d22
+    vqrdmulh.s32 d29, d29, d22
+    vqrdmulh.s32 d30, d30, d22
+    vqrdmulh.s32 d31, d31, d22
+
+    // Apply right shift (rounding)
+    ldr r10, [sp, #36]
+    add r10, r10, r8    // per-channel right shifts at the channel offset
+    vld1.32 {d21}, [r10]
+    vand d20, d21, d28
+    vshr.s32 d20, d20, #31
+    vqadd.s32 d28, d28, d20
+    vrshl.s32 d28, d28, d21
+    vand d19, d21, d29
+    vshr.s32 d19, d19, #31
+    vqadd.s32 d29, d29, d19
+    vrshl.s32 d29, d29, d21
+    vand d18, d21, d30
+    vshr.s32 d18, d18, #31
+    vqadd.s32 d30, d30, d18
+    vrshl.s32 d30, d30, d21
+    vand d17, d21, d31
+    vshr.s32 d17, d17, #31
+    vqadd.s32 d31, d31, d17
+    vrshl.s32 d31, d31, d21
+
 AddDstZP:
     // Add the destination zero point
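The right-shift block is the familiar gemmlowp-style rounding divide by a power of two: the shift is stored negated (so vrshl shifts right), and the vand/vshr/vqadd triple builds a -1 fixup for negative accumulators before the rounding shift. A scalar sketch of the two per-channel requantization steps (function names are mine, not from nnacl; arithmetic right shift of negative values assumed):

    #include <stdint.h>

    /* Scalar counterpart of vqrdmulh.s32: saturating rounding doubling high
       multiply, i.e. high 32 bits of 2*a*b with rounding and saturation. */
    static int32_t sat_rounding_doubling_high_mul(int32_t a, int32_t b) {
      if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  /* lone overflow case */
      int64_t ab = (int64_t)a * (int64_t)b;
      int64_t nudge = ab >= 0 ? (1LL << 30) : 1 - (1LL << 30);
      return (int32_t)((ab + nudge) / (1LL << 31));
    }

    /* Scalar mirror of the vand/vshr/vqadd/vrshl sequence: rounding right
       shift by -neg_shift, where neg_shift <= 0 as stored in d21. */
    static int32_t rounding_right_shift(int32_t x, int32_t neg_shift) {
      int32_t fixup = (x & neg_shift) >> 31;  /* -1 iff x < 0 and neg_shift != 0 */
      int64_t fixed = (int64_t)x + fixup;     /* vqadd.s32                       */
      int n = -neg_shift;
      return n == 0 ? (int32_t)fixed
                    : (int32_t)((fixed + ((int64_t)1 << (n - 1))) >> n); /* vrshl */
    }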
@@ -218,15 +268,16 @@ EndWrite:

 End2:
     sub r4, r4, #2      // b col counter -= 2
-    ldr r1, [sp, #-48]  // b ptr + stride
-    add r1, r1, r9
-    str r1, [sp, #-48]
-    ldr r8, [sp, #12]   // weight_bias + stride
-    add r8, r8, #8
-    str r8, [sp, #12]
-    ldr r2, [sp, #-44]  // dst ptr + offset
-    add r2, r2, #2
+    ldr r1, [sp, #-48]  // load b ptr
+    ldr r9, [sp, #4]    // deep16
+    mov r10, #2
+    mul r9, r9, r10     // the stride of b
+    add r1, r1, r9      // b ptr + stride
+    str r1, [sp, #-48]
+    ldr r2, [sp, #-44]  // load dst ptr
+    add r2, r2, #2      // dst ptr + offset
     str r2, [sp, #-44]
+    add r8, r8, #8      // output channels offset + 2*sizeof(int)
     b L1

 End1:
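The rewritten End2 epilogue recomputes the b stride locally (r9 and r7 now serve other duties) and advances the channel offset in r8 instead of mutating the saved weight_bias pointer. Per two-column block, the bookkeeping amounts to the following (a sketch; names invented):

    #include <stddef.h>
    #include <stdint.h>

    /* Sketch of the End2 bookkeeping for one 2-column block (names invented). */
    static void advance_col_block(const int8_t **b, int8_t **dst,
                                  size_t *ch_offset, int deep16) {
      *b += 2 * deep16;                   /* r9 = 2 * deep16: the stride of b  */
      *dst += 2;                          /* dst ptr + offset: 2 int8 per row  */
      *ch_offset += 2 * sizeof(int32_t);  /* r8: byte offset into per-channel
                                             bias/shift/multiplier arrays     */
    }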


mindspore/lite/nnacl/int8/conv_int8.c  (+0 / -8)

@@ -1029,14 +1029,6 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
                       const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                       int32_t *multiplier, ConvParameter *conv_param) {
   int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false;
-
-  if (is_per_channel == 1) {
-    return MatMulInt8_4x2_r(
-        packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift,
-        right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-        conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], true);
-  }
-
 #ifdef ENABLE_ARM32
   MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                    conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
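Net effect on the C driver: the per-channel early return into the generic MatMulInt8_4x2_r kernel is gone, and MatmulInt8Neon32 now serves both modes (the trailing arguments of the call, presumably including is_per_channel, are cut off in the view above). Incidentally, the surviving ternary is redundant, since the comparison already yields 0 or 1; an equivalent form would be (a suggestion, not part of the commit):

    int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1;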

