From 96d01f17ec4bd9162c029913e4a520b5a79cbea8 Mon Sep 17 00:00:00 2001 From: ling Date: Mon, 14 Sep 2020 15:09:38 +0800 Subject: [PATCH] [MSLITE][Develop]Conv1x1 preTrasn neon code -> .S --- .../assembly/arm32/PreSum4x16Int8Peroc.S | 130 +++++++++++++ .../nnacl/assembly/arm32/PreSum4x16Int8Pert.S | 81 ++++++++ .../assembly/arm64/PreSum4x16Int8Peroc.S | 129 +++++++++++++ .../nnacl/assembly/arm64/PreSum4x16Int8Pert.S | 70 +++++++ mindspore/lite/nnacl/int8/conv_int8.c | 8 + mindspore/lite/nnacl/int8/matmul_int8.c | 24 ++- mindspore/lite/nnacl/pack.c | 182 +----------------- mindspore/lite/nnacl/pack.h | 7 + .../kernel/arm/int8/convolution_1x1_int8.cc | 33 ++-- .../kernel/arm/int8/convolution_1x1_int8.h | 2 +- 10 files changed, 477 insertions(+), 189 deletions(-) create mode 100644 mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S create mode 100644 mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S create mode 100644 mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S create mode 100644 mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S diff --git a/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S b/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000..569cca56d5 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S @@ -0,0 +1,130 @@ + +.text +.align 5 +.global PreSum4x16Int8Peroc +#ifndef __APPLE__ +.type PreSum4x16Int8Peroc, %function +#endif + + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div2, +// size_t oc_res2, size_t stride); + +// r0 src +// r1 sum +// r2 zp +// r3 hw4 +// r4 ic16 +// r5 oc_div2 +// r6 oc_res2 +// r7 stride + +PreSum4x16Int8Peroc: + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mov r8, #0 + mov r10, #8 + +RowLoop: + cmp r8, r3 + beq End + add r8, r8, #4 + vmov.s32 q13, #0 + mov r9, #0 + mov r11, r2 + +Sum: + cmp r9, r4 + beq Mul + add r9, r9, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! + + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b Sum + +Mul: + mov r12, r1 + add r1, r1, #32 + mov r9, #0 + + vdup.32 d1, d26[0] + vdup.32 d2, d26[1] + vdup.32 d3, d27[0] + vdup.32 d4, d27[1] + +Write: + + cmp r9, r5 + beq OcRes + add r9, r9, #2 + vld1.32 {d9}, [r11]! + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + add r12, r12, r7 + b Write + +OcRes: + cmp r6, #0 + beq RowLoop + + vmov.s32 d9, #0 + vld1.8 {d9[0]}, [r11] + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + b RowLoop + +End: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} diff --git a/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S b/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S new file mode 100644 index 0000000000..052931fa2f --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S @@ -0,0 +1,81 @@ + +.text +.align 5 +.global PreSum4x16Int8Pert +#ifndef __APPLE__ +.type PreSum4x16Int8Pert, %function +#endif + + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); + +// r0 src +// r1 sum +// r2 row4 +// r3 co16 +// r4 filter_zp + +PreSum4x16Int8Pert: + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + + vdup.32 q10, r4 + mov r5, #0 + mov r7, #16 + +RowLoop: + cmp r5, r2 + beq End + add r5, r5, #4 + vmov.s32 q13, #0 + mov r6, #0 + +CalLoop: + cmp r6, r3 + beq Write + add r6, r6, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! + + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b CalLoop + +Write: + vmul.i32 q13, q13, q10 + vst1.32 q13, [r1], r7 + beq RowLoop + +End: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} diff --git a/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S b/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000..a48e1d823b --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S @@ -0,0 +1,129 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global PreSum4x16Int8Peroc +#ifndef __APPLE__ + .type PreSum4x16Int8Peroc, %function +#endif + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div4, +// size_t oc_res4, size_t stride); + +// x0 src +// x1 sum +// x2 zp +// w3 hw4 +// w4 ic16 +// w5 oc_div4 +// w6 oc_res4 +// w7 stride + +PreSum4x16Int8Peroc: + mov w8, #0 + +RowLoop: + cmp w8, w3 + beq End + add w8, w8, #4 + dup v16.4s, wzr + mov w9, #0 + mov x16, x2 + +Sum: + cmp w9, w4 + beq Mul + add w9, w9, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b Sum + +Mul: + mov x12, x1 + add x1, x1, #64 + mov w9, #0 + + dup v1.4s, v16.s[0] + dup v2.4s, v16.s[1] + dup v3.4s, v16.s[2] + dup v4.4s, v16.s[3] + +WriteOc4: + cmp w9, w5 + beq OcRes4 + add w9, w9, #4 + ld1 {v5.4s}, [x16], #16 + + mul v16.4s, v5.4s, v1.4s + mul v17.4s, v5.4s, v2.4s + mul v18.4s, v5.4s, v3.4s + mul v19.4s, v5.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + add x12, x12, x7 + b WriteOc4 + +OcRes4: + cmp w6, #0 + beq RowLoop + dup v15.4s, wzr + cmp w6, #1 + beq OcRes4_1 + cmp w6, #2 + beq OcRes4_2 + cmp w6, #3 + beq OcRes4_3 + +OcRes4_1: + ld1 {v15.s}[0], [x16] + b OcRes4End + +OcRes4_2: + ld1 {v15.h}[0], [x16] + b OcRes4End + +OcRes4_3: + ld1 {v15.h}[0], [x16] + add x16, x16, #8 + ld1 {v15.s}[2], [x16] + b OcRes4End + +OcRes4End: + mul v16.4s, v15.4s, v1.4s + mul v17.4s, v15.4s, v2.4s + mul v18.4s, v15.4s, v3.4s + mul v19.4s, v15.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + b RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S b/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S new file mode 100644 index 0000000000..d4c61a2242 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S @@ -0,0 +1,70 @@ + +#ifdef __aarch64__ + .text + .align 5 + //.p2align 5,,15 + .global PreSum4x16Int8Pert +#ifndef __APPLE__ + .type PreSum4x16Int8Pert, %function +#endif + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp); + +// x0 src +// x1 dst +// w2 row4 +// w3 co16 +// w4 filter_zp + +PreSum4x16Int8Pert: + dup v17.4s, w4 + mov w5, #0 + +RowLoop: + cmp w5, w2 + beq End + add w5, w5, #4 + dup v16.4s, wzr + mov w6, #0 + +CalLoop: + cmp w6, w3 + beq Write + add w6, w6, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b CalLoop + +Write: + mul v16.4s, v16.4s, v17.4s + st1 {v16.4s}, [x1], #16 + beq RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index f1e87c8682..e660d5d4f8 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -1029,6 +1029,14 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param) { int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false; + + if (is_per_channel == 1) { + return MatMulInt8_4x2_r( + packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift, + right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], true); + } + #ifdef ENABLE_ARM32 MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 06545b77b5..b212511d9e 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -117,10 +117,10 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { for (int ri = 0; ri < row_4div; ri += C4NUM) { for (int ci = 0; ci < col_16div; ci += C16NUM) { -#ifdef ENABLE_ARM64 size_t col_offset = col; int8_t *src_c = src_r + ci; int8_t *dst_c = dst_r + ci * C4NUM; +#ifdef ENABLE_ARM64 asm volatile( "mov x10, %[src_c] \n" "mov x11, %[dst_c] \n" @@ -138,8 +138,28 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { : : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset) : "x10", "x11", "v0", "v1", "v2", "v3"); +#elif ENABLE_ARM32 + asm volatile( + "mov r0, %[src_c] \n" + "mov r1, %[dst_c] \n" + "mov r2, %[col_offset] \n" + "mov r3, #16 \n" + + "vld1.8 {q0}, [r0], r2 \n" + "vld1.8 {q1}, [r0], r2 \n" + "vld1.8 {q2}, [r0], r2 \n" + "vld1.8 {q3}, [r0], r2 \n" + + "vst1.32 q0, [r1], r3 \n" + "vst1.32 q1, [r1], r3 \n" + "vst1.32 q2, [r1], r3 \n" + "vst1.32 q3, [r1], r3 \n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset) + : "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3"); #else - MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col); + MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset); #endif } diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index e142a176f5..d4571439b0 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -189,63 +189,8 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { /* normal matmul : 4x16 * 16x4 -> 4x4 */ -#ifdef ENABLE_ARM64 - asm volatile( - "mov x10, %[src] \n" - "mov x11, %[dst] \n" - "dup v15.4s, %w[filter_zp] \n" - - "mov x0, #0 \n" - "1: \n" - "cmp x0, %[row4] \n" - "beq 4f \n" - "add x0, x0, #4\n" - "dup v10.4s, wzr \n" - "mov x2, #0 \n" - - "2: \n" - "cmp x2, %[col16] \n" - "beq 3f \n" - "add x2, x2, #16\n" - - "ld1 {v0.16b}, [x10], #16\n" - "ld1 {v1.16b}, [x10], #16\n" - "ld1 {v2.16b}, [x10], #16\n" - "ld1 {v3.16b}, [x10], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v6.8h, v2.16b \n" - "saddlp v7.8h, v3.16b \n" - - "saddlp v0.4S, v4.8h \n" - "saddlp v1.4S, v5.8h \n" - "saddlp v2.4S, v6.8h \n" - "saddlp v3.4S, v7.8h \n" - - "addv s4, v0.4S \n" - "addv s5, v1.4S \n" - "addv s6, v2.4S \n" - "addv s7, v3.4S \n" - - "mov v0.s[0], v4.s[0] \n" - "mov v0.s[1], v5.s[0] \n" - "mov v0.s[2], v6.s[0] \n" - "mov v0.s[3], v7.s[0] \n" - - "add v10.4s, v10.4s, v0.4s \n" - "b 2b\n" - - "3: \n" - "mul v10.4s, v10.4s, v15.4s \n" - "st1 {v10.4s}, [x11], #16 \n" - "beq 1b \n" - - "4: \n" - - : - : [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp) - : "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15"); +#ifdef ENABLE_ARM + PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); #else for (int r = 0; r < row4; r++) { int32_t tmp_value = 0; @@ -268,121 +213,7 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i size_t oc_div4 = output_channel / C4NUM * C4NUM; size_t oc_res4 = output_channel - oc_div4; size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; - asm volatile( - "mov x10, %[input_value] \n" - "mov x11, %[input_sum] \n" - "mov x15, %[filter_zp_ptr] \n" - - "mov x0, #0 \n" - "1: \n" - "cmp x0, %[hw4] \n" - "beq 11f \n" - "add x0, x0, #4\n" - "dup v10.4s, wzr \n" - "mov x2, #0 \n" - "mov x16, x15 \n" - - "2: \n" - "cmp x2, %[ic16] \n" - "beq 3f \n" - "add x2, x2, #16 \n" - - "ld1 {v0.16b}, [x10], #16\n" - "ld1 {v1.16b}, [x10], #16\n" - "ld1 {v2.16b}, [x10], #16\n" - "ld1 {v3.16b}, [x10], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v6.8h, v2.16b \n" - "saddlp v7.8h, v3.16b \n" - "saddlp v0.4S, v4.8h \n" - "saddlp v1.4S, v5.8h \n" - "saddlp v2.4S, v6.8h \n" - "saddlp v3.4S, v7.8h \n" - "addv s4, v0.4S \n" - "addv s5, v1.4S \n" - "addv s6, v2.4S \n" - "addv s7, v3.4S \n" - "mov v0.s[0], v4.s[0] \n" - "mov v0.s[1], v5.s[0] \n" - "mov v0.s[2], v6.s[0] \n" - "mov v0.s[3], v7.s[0] \n" - - "add v10.4s, v10.4s, v0.4s \n" - "b 2b \n" - - "3: \n" - "mov x12, x11 \n" - "add x11, x11, #64 \n" - "mov x4, #0 \n" - - "dup v1.4s, v10.s[0] \n" - "dup v2.4s, v10.s[1] \n" - "dup v3.4s, v10.s[2] \n" - "dup v4.4s, v10.s[3] \n" - - "4: \n" - "cmp x4, %[oc_div4] \n" - "beq 6f \n" - "add x4, x4, #4\n" - "ld1 {v15.4s}, [x16], #16\n" - - "mul v16.4s, v15.4s, v1.4s \n" - "mul v17.4s, v15.4s, v2.4s \n" - "mul v18.4s, v15.4s, v3.4s \n" - "mul v19.4s, v15.4s, v4.4s \n" - "st1 {v16.4s}, [x12], #16 \n" - "st1 {v17.4s}, [x12], #16 \n" - "st1 {v18.4s}, [x12], #16 \n" - "st1 {v19.4s}, [x12], #16 \n" - "add x12, x12, %[inputsun_stride] \n" - "b 4b \n" - - "6: \n" - "cmp %[oc_res4], #0\n" - "beq 1b \n" - "dup v15.4s, wzr \n" - "cmp %[oc_res4], #1\n" - "beq 7f \n" - "cmp %[oc_res4], #2\n" - "beq 8f \n" - "cmp %[oc_res4], #3\n" - "beq 9f \n" - - "7: \n" - "ld1 {v15.s}[0], [x16] \n" - "b 10f \n" - - "8: \n" - "ld1 {v15.h}[0], [x16] \n" - "b 10f \n" - - "9: \n" - "ld1 {v15.h}[0], [x16] \n" - "add x16, x16, #8 \n" - "ld1 {v15.s}[2], [x16] \n" - "b 10f \n" - - "10: \n" - "mul v16.4s, v15.4s, v1.4s \n" - "mul v17.4s, v15.4s, v2.4s \n" - "mul v18.4s, v15.4s, v3.4s \n" - "mul v19.4s, v15.4s, v4.4s \n" - "st1 {v16.4s}, [x12], #16 \n" - "st1 {v17.4s}, [x12], #16 \n" - "st1 {v18.4s}, [x12], #16 \n" - "st1 {v19.4s}, [x12], #16 \n" - "b 1b \n" - - "11: \n" - - : - : [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr), - [ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4), - [ inputsun_stride ] "r"(inputsun_stride) - : "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15", - "v16", "v17", "v18", "v19"); + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride); #else for (int ri = 0; ri < plane_size; ri++) { @@ -409,6 +240,12 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s size_t hw4 = UP_ROUND(plane_size, C4NUM); size_t ic16 = UP_ROUND(input_channel, C16NUM); +#ifdef ENABLE_ARM32 + size_t oc_div2 = output_channel / C2NUM * C2NUM; + size_t oc_res2 = output_channel - oc_div2; + size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride); +#else for (int ri = 0; ri < plane_size; ri++) { int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; for (int ci = 0; ci < output_channel; ci++) { @@ -424,6 +261,7 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s input_sum[dst_index] = tmp_sum_value * filter_zp; } } +#endif return; } diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index 75384b5247..9d30e426f2 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -121,6 +121,13 @@ void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, ConvQuantArg *quant_qrg); + +#ifdef ENABLE_ARM +void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); +void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, + size_t oc_res, size_t stride); +#endif + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc index 2781e2e6f4..dfab880c38 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc @@ -71,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { } void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { - support_optimize_ = true; + support_optimize_ = false; matmul_func_ = MatMulInt8_8x8_r; #ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; @@ -94,7 +94,7 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { return; } -int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel) { +int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc) { /* bias = bias - v2 x zp1 + zp1 x zp2 */ int32_t *bias_data = reinterpret_cast(bias_data_); int8_t *weight = reinterpret_cast(src_weight); @@ -118,24 +118,23 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; } - int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM); - left_shift_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + left_shift_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (left_shift_ == nullptr) { return RET_ERROR; } - memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t)); + memset(left_shift_, 0, round_oc * sizeof(int32_t)); memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t)); - right_shift_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + right_shift_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (right_shift_ == nullptr) { return RET_ERROR; } - memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t)); + memset(right_shift_, 0, round_oc * sizeof(int32_t)); memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t)); - multiplier_ = reinterpret_cast(malloc(up_round_oc_size * sizeof(int32_t))); + multiplier_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (multiplier_ == nullptr) { return RET_ERROR; } - memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t)); + memset(multiplier_, 0, round_oc * sizeof(int32_t)); memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t)); } return RET_OK; @@ -165,18 +164,18 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { int col4 = UP_ROUND(output_channel, C4NUM); int col8 = UP_ROUND(output_channel, C8NUM); - size = support_optimize_ ? col8 * sizeof(int32_t) : col4 * sizeof(int32_t); - bias_data_ = malloc(size); + size = support_optimize_ ? col8 : col4; + bias_data_ = malloc(size * sizeof(int32_t)); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!"; return RET_ERROR; } - memset(bias_data_, 0, size); + memset(bias_data_, 0, size * sizeof(int32_t)); if (in_tensors_.size() == 3) { memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); } - InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel); + InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, size); return RET_OK; } @@ -208,7 +207,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() { memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); } - InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel); + InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, UP_ROUND(output_channel, C2NUM)); return RET_OK; } @@ -342,6 +341,12 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { if (cur_oc <= 0) { return RET_OK; } + if (filter_peroc_) { + cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C2NUM; + cur_left_shift = left_shift_ + task_id * thread_stride_ * C2NUM; + cur_right_shift = right_shift_ + task_id * thread_stride_ * C2NUM; + cur_multiplier = multiplier_ + task_id * thread_stride_ * C2NUM; + } Conv1x1Int8Arm32(packed_input_, packed_weight_ + task_id * thread_stride_ * C2NUM * matmul_param_->deep_16_, output_ptr_ + task_id * thread_stride_ * C2NUM, cur_input_sum, reinterpret_cast(bias_data_) + task_id * thread_stride_ * C2NUM, matmul_param_->row_, diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index 96f6a11f09..2144344f78 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -55,7 +55,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int InitWeightBiasArm32(); void Pre1x1Trans(int8_t *src_input, int8_t *src_output); void CheckSupportOptimize(); - int InitBiasByzp(void *src_weight, int input_channel, int output_channel); + int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc); private: int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */