|
|
|
@@ -21,24 +21,27 @@ MatmulInt8Neon32: |
|
|
|
add sp, sp, #116 |
|
|
|
|
|
|
|
ldr r4, [sp] // col |
|
|
|
mov r7, #2 |
|
|
|
ldr r8, [sp, #4] // deep16 |
|
|
|
mul r9, r7, r8 // the sride of b |
|
|
|
ldr r7, [sp, #40] // output stride |
|
|
|
|
|
|
|
mov r8, #0 // output channels offset |
|
|
|
ldr r10, [sp, #44] |
|
|
|
cmp r10, #0 |
|
|
|
beq L1 |
|
|
|
ldr r6, [sp, #8] // load intpu_sums ptr if per_channel |
|
|
|
L1: |
|
|
|
cmp r4, #0 // if at the end of col |
|
|
|
ble End1 |
|
|
|
|
|
|
|
ldr r0, [sp, #-52] // reload a ptr |
|
|
|
ldr r3, [sp, #-40] // reset row counter |
|
|
|
ldr r6, [sp, #8] // reload intpu_sums ptr |
|
|
|
ldr r10, [sp, #44] |
|
|
|
cmp r10, #0 |
|
|
|
bne L2 |
|
|
|
ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor |
|
|
|
L2: |
|
|
|
cmp r3, #0 // if at the end of row |
|
|
|
ble End2 |
|
|
|
|
|
|
|
ldr r1, [sp, #-48] // reload b ptr |
|
|
|
ldr r8, [sp, #12] // reload weight_bias ptr |
|
|
|
ldr r5, [sp, #4] // reset deep16 |
|
|
|
vmov.i32 q6, #0 |
|
|
|
vmov.i32 q7, #0 |
|
|
|
@@ -101,7 +104,9 @@ End3: |
|
|
|
vpadd.i32 d31, d6, d7 |
|
|
|
|
|
|
|
// Add weight_bias |
|
|
|
vld1.32 {d26}, [r8]! |
|
|
|
ldr r9, [sp, #12] // reload weight_bias ptr |
|
|
|
add r9, r9, r8 |
|
|
|
vld1.32 {d26}, [r9]! |
|
|
|
vadd.i32 d28, d28, d26 |
|
|
|
vadd.i32 d29, d29, d26 |
|
|
|
vadd.i32 d30, d30, d26 |
|
|
|
@@ -111,6 +116,7 @@ End3: |
|
|
|
cmp r10, #0 |
|
|
|
bgt PerChannel |
|
|
|
|
|
|
|
PerTensor: |
|
|
|
// Substract input_sums |
|
|
|
vld1.32 {d24, d25}, [r6]! |
|
|
|
vdup.32 d20, d24[0] |
|
|
|
@@ -124,7 +130,7 @@ End3: |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
ldr r10, [sp, #32] |
|
|
|
ldr r11, [r10] |
|
|
|
ldr r11, [r10]! |
|
|
|
vdup.32 q9, r11 |
|
|
|
vshl.s32 q14, q14, q9 |
|
|
|
vshl.s32 q15, q15, q9 |
|
|
|
@@ -151,7 +157,51 @@ End3: |
|
|
|
b AddDstZP |
|
|
|
|
|
|
|
PerChannel: |
|
|
|
// Substract input_sums |
|
|
|
vld1.32 {d24, d25, d26, d27}, [r6]! |
|
|
|
vsub.s32 d28, d28, d24 |
|
|
|
vsub.s32 d29, d29, d25 |
|
|
|
vsub.s32 d30, d30, d26 |
|
|
|
vsub.s32 d31, d31, d27 |
|
|
|
|
|
|
|
// Apply left shift |
|
|
|
ldr r10, [sp, #32] |
|
|
|
add r10, r10, r8 |
|
|
|
vld1.32 {d23}, [r10] |
|
|
|
vshl.s32 d28, d28, d23 |
|
|
|
vshl.s32 d29, d29, d23 |
|
|
|
vshl.s32 d30, d30, d23 |
|
|
|
vshl.s32 d31, d31, d23 |
|
|
|
|
|
|
|
// Apply the fixed-point part of the multiplier |
|
|
|
ldr r10, [sp, #28] |
|
|
|
add r10, r10, r8 |
|
|
|
vld1.32 {d22}, [r10] |
|
|
|
vqrdmulh.s32 d28, d28, d22 |
|
|
|
vqrdmulh.s32 d29, d29, d22 |
|
|
|
vqrdmulh.s32 d30, d30, d22 |
|
|
|
vqrdmulh.s32 d31, d31, d22 |
|
|
|
|
|
|
|
// Apply right shift |
|
|
|
ldr r10, [sp, #36] |
|
|
|
add r10, r10, r8 |
|
|
|
vld1.32 {d21}, [r10] |
|
|
|
vand d20, d21, d28 |
|
|
|
vshr.s32 d20, d20, #31 |
|
|
|
vqadd.s32 d28, d28, d20 |
|
|
|
vrshl.s32 d28, d28, d21 |
|
|
|
vand d19, d21, d29 |
|
|
|
vshr.s32 d19, d19, #31 |
|
|
|
vqadd.s32 d29, d29, d19 |
|
|
|
vrshl.s32 d29, d29, d21 |
|
|
|
vand d18, d21, d30 |
|
|
|
vshr.s32 d18, d18, #31 |
|
|
|
vqadd.s32 d30, d30, d18 |
|
|
|
vrshl.s32 d30, d30, d21 |
|
|
|
vand d17, d21, d31 |
|
|
|
vshr.s32 d17, d17, #31 |
|
|
|
vqadd.s32 d31, d31, d17 |
|
|
|
vrshl.s32 d31, d31, d21 |
|
|
|
|
|
|
|
AddDstZP: |
|
|
|
// Add the destination zero point |
|
|
|
@@ -218,15 +268,16 @@ EndWrite: |
|
|
|
|
|
|
|
End2: |
|
|
|
sub r4, r4, #2 // b col counter -= 2 |
|
|
|
ldr r1, [sp, #-48] // b ptr + stride |
|
|
|
add r1, r1, r9 |
|
|
|
str r1, [sp, #-48] |
|
|
|
ldr r8, [sp, #12] // weight_bias + stride |
|
|
|
add r8, r8, #8 |
|
|
|
str r8, [sp, #12] |
|
|
|
ldr r2, [sp, #-44] // dst ptr + offset |
|
|
|
add r2, r2, #2 |
|
|
|
ldr r1, [sp, #-48] // load b ptr |
|
|
|
ldr r9, [sp, #4] |
|
|
|
mov r10, #2 |
|
|
|
mul r9, r9, r10 // the stride of b |
|
|
|
add r1, r1, r9 // b ptr + stride |
|
|
|
str r1, [sp, #-48] |
|
|
|
ldr r2, [sp, #-44] // load dst ptr |
|
|
|
add r2, r2, #2 // dst ptr + offset |
|
|
|
str r2, [sp, #-44] |
|
|
|
add r8, r8, #8 // output channels offset + 2*sizeof(int) |
|
|
|
b L1 |
|
|
|
|
|
|
|
End1: |
|
|
|
|