| @@ -8,10 +8,10 @@ | |||
| .type MatmulInt8Neon32Opt, %function | |||
| #endif | |||
| //void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, | |||
| // const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, | |||
| // int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, | |||
| // int *filter_zp); | |||
| //void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, | |||
| // const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, | |||
| // int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, | |||
| // int *filter_zp); | |||
| // #-52: a, #-48: b, #-44: dst, #-40: row | |||
| // #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp | |||
| // #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp | |||
| @@ -21,6 +21,12 @@ MatmulInt8Neon32Opt: | |||
| vpush {q4-q7} | |||
| add sp, sp, #116 | |||
| ldr r0, [sp, #-52] // load a ptr | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| ldr r1, [sp, #-48] // load b ptr | |||
| vld1.8 {d8, d9, d10, d11}, [r1]! | |||
| ldr r4, [sp] // col | |||
| ldr r7, [sp, #40] // output stride | |||
| mov r8, #0 // output channels offset | |||
| @@ -32,15 +38,14 @@ L1: | |||
| cmp r4, #0 // if at the end of col | |||
| ble End1 | |||
| ldr r0, [sp, #-52] // reload a ptr | |||
| ldr r3, [sp, #-40] // reset row counter | |||
| ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor | |||
| L2: | |||
| cmp r3, #0 // if at the end of row | |||
| ble End2 | |||
| ldr r1, [sp, #-48] // reload b ptr | |||
| ldr r5, [sp, #4] // reset deep16 | |||
| sub r5, r5, #16 | |||
| vmov.i32 q6, #0 | |||
| vmov.i32 q7, #0 | |||
| vmov.i32 q8, #0 | |||
| @@ -49,12 +54,44 @@ L2: | |||
| vmov.i32 q11, #0 | |||
| vmov.i32 q12, #0 | |||
| vmov.i32 q13, #0 | |||
| L3: | |||
| cmp r5, #0 | |||
| beq End3 | |||
| beq L3Tail | |||
| L3: | |||
| vmull.s8 q14, d0, d8 | |||
| vmull.s8 q2, d0, d10 | |||
| vmull.s8 q15, d2, d8 | |||
| vmull.s8 q3, d2, d10 | |||
| vmlal.s8 q14, d1, d9 | |||
| vmlal.s8 q2, d1, d11 | |||
| vmlal.s8 q15, d3, d9 | |||
| vmlal.s8 q3, d3, d11 | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| vpadal.s16 q6, q14 | |||
| vpadal.s16 q7, q2 | |||
| vpadal.s16 q8, q15 | |||
| vpadal.s16 q9, q3 | |||
| vmull.s8 q14, d0, d8 | |||
| vmull.s8 q2, d0, d10 | |||
| vmull.s8 q15, d2, d8 | |||
| vmull.s8 q3, d2, d10 | |||
| vmlal.s8 q14, d1, d9 | |||
| vmlal.s8 q2, d1, d11 | |||
| vmlal.s8 q15, d3, d9 | |||
| vmlal.s8 q3, d3, d11 | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| vpadal.s16 q10, q14 | |||
| vld1.8 {d8, d9, d10, d11}, [r1]! | |||
| vpadal.s16 q11, q2 | |||
| vpadal.s16 q12, q15 | |||
| vpadal.s16 q13, q3 | |||
| sub r5, r5, #16 // deep16 -= 16 | |||
| cmp r5, #0 | |||
| bgt L3 | |||
| L3Tail: | |||
| vmull.s8 q14, d0, d8 | |||
| vmull.s8 q2, d0, d10 | |||
| vmull.s8 q15, d2, d8 | |||
| @@ -63,13 +100,13 @@ L3: | |||
| vmlal.s8 q2, d1, d11 | |||
| vmlal.s8 q15, d3, d9 | |||
| vmlal.s8 q3, d3, d11 | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| vpadal.s16 q6, q14 | |||
| vpadal.s16 q7, q2 | |||
| vpadal.s16 q8, q15 | |||
| vpadal.s16 q9, q3 | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| vmull.s8 q14, d0, d8 | |||
| vmull.s8 q2, d0, d10 | |||
| vmull.s8 q15, d2, d8 | |||
| @@ -83,10 +120,7 @@ L3: | |||
| vpadal.s16 q11, q2 | |||
| vpadal.s16 q12, q15 | |||
| vpadal.s16 q13, q3 | |||
| sub r5, r5, #16 // deep16 -= 16 | |||
| b L3 | |||
| End3: | |||
| vpadd.i32 d0, d12, d13 | |||
| vpadd.i32 d1, d14, d15 | |||
| vpadd.i32 d2, d16, d17 | |||
| @@ -101,7 +135,26 @@ End3: | |||
| vpadd.i32 d30, d4, d5 | |||
| vpadd.i32 d31, d6, d7 | |||
| // Add weight_bias | |||
| cmp r3, #4 | |||
| ble LAST_ROW | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| ldr r1, [sp, #-48] // reload b ptr | |||
| vld1.8 {d8, d9, d10, d11}, [r1]! | |||
| b AddWeightBias | |||
| LAST_ROW: | |||
| ldr r0, [sp, #-52] // reload a ptr | |||
| vld1.8 {d0, d1, d2, d3}, [r0]! | |||
| ldr r1, [sp, #-48] // reload b ptr | |||
| ldr r9, [sp, #4] | |||
| mov r10, #2 | |||
| mul r9, r9, r10 // the stride of b | |||
| add r1, r1, r9 // b ptr + stride | |||
| str r1, [sp, #-48] | |||
| vld1.8 {d8, d9, d10, d11}, [r1]! | |||
| AddWeightBias: | |||
| ldr r9, [sp, #12] // reload weight_bias ptr | |||
| add r9, r9, r8 | |||
| vld1.32 {d26}, [r9]! | |||
| @@ -148,9 +201,9 @@ PerTensor: | |||
| vshr.s32 q6, q6, #31 | |||
| vqadd.s32 q14, q14, q6 | |||
| vrshl.s32 q14, q14, q7 | |||
| vand q5, q7, q15 | |||
| vshr.s32 q5, q5, #31 | |||
| vqadd.s32 q15, q15, q5 | |||
| vand q3, q7, q15 | |||
| vshr.s32 q3, q3, #31 | |||
| vqadd.s32 q15, q15, q3 | |||
| vrshl.s32 q15, q15, q7 | |||
| b AddDstZP | |||
| @@ -214,9 +267,9 @@ PerChannel: | |||
| AddDstZP: | |||
| // Add the destination zero point | |||
| ldr r10, [sp, #24] | |||
| vdup.32 q4, r10 | |||
| vadd.i32 q14, q14, q4 | |||
| vadd.i32 q15, q15, q4 | |||
| vdup.32 q2, r10 | |||
| vadd.i32 q14, q14, q2 | |||
| vadd.i32 q15, q15, q2 | |||
| // Apply the act_min bound | |||
| ldr r10, [sp, #16] | |||
| @@ -276,12 +329,6 @@ EndWrite: | |||
| End2: | |||
| sub r4, r4, #2 // b col counter -= 2 | |||
| ldr r1, [sp, #-48] // load b ptr | |||
| ldr r9, [sp, #4] | |||
| mov r10, #2 | |||
| mul r9, r9, r10 // the stride of b | |||
| add r1, r1, r9 // b ptr + stride | |||
| str r1, [sp, #-48] | |||
| ldr r2, [sp, #-44] // load dst ptr | |||
| add r2, r2, #2 // dst ptr + offset | |||
| str r2, [sp, #-44] | |||