Browse Source

[MSLITE][Develop]Conv1x1 preTrans neon code -> .S

tags/v1.0.0
ling 5 years ago
parent
commit
96d01f17ec
10 changed files with 477 additions and 189 deletions
  1. +130
    -0
      mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S
  2. +81
    -0
      mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S
  3. +129
    -0
      mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S
  4. +70
    -0
      mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S
  5. +8
    -0
      mindspore/lite/nnacl/int8/conv_int8.c
  6. +22
    -2
      mindspore/lite/nnacl/int8/matmul_int8.c
  7. +10
    -172
      mindspore/lite/nnacl/pack.c
  8. +7
    -0
      mindspore/lite/nnacl/pack.h
  9. +19
    -14
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
  10. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h

+ 130
- 0
mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S View File

@@ -0,0 +1,130 @@

.text
.align 5
.global PreSum4x16Int8Peroc
#ifndef __APPLE__
.type PreSum4x16Int8Peroc, %function
#endif


//------------------------------------------------------------------------------
// void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp,
//                          size_t hw4, size_t ic16, int32_t oc_div2,
//                          size_t oc_res2, size_t stride);
//
// Per-output-channel input-sum pre-computation for int8 1x1 convolution.
// For each block of 4 rows of the 4x16-packed src, accumulate the int8
// element sum of every row, then multiply the row sums by the per-channel
// filter zero points in zp and scatter the products into sum.
//
// r0: src      packed int8 input (4x16 tiles)
// r1: sum      int32 output, advanced 32 bytes per row block
// r2: zp       int32 per-output-channel filter zero points
// r3: hw4      number of rows (multiple of 4)
// stack args (at sp+96 after the saves below):
//   ic16       packed depth (multiple of 16)
//   oc_div2    output channels rounded down to a multiple of 2
//   oc_res2    remaining output channels (0 or 1)
//   stride     extra byte stride between oc2 column blocks of sum
// Clobbers: r4-r12, q0-q7, q13
//------------------------------------------------------------------------------

PreSum4x16Int8Peroc:
    push {r4-r8, r10, r11, lr}
    vpush {q4-q7}
    // 96 bytes of saves above: caller stack args now live at [sp, #96].
    // Read them by offset instead of temporarily moving sp above the saved
    // registers, so the saved area never sits below sp (data below sp is
    // not protected from signal handlers).
    ldr r4, [sp, #96]               // ic16
    ldr r5, [sp, #100]              // oc_div2
    ldr r6, [sp, #104]              // oc_res2
    ldr r7, [sp, #108]              // stride

    mov r8, #0                      // r8 = current row
    mov r10, #8                     // post-increment for 8-byte d-reg stores

RowLoop:
    cmp r8, r3
    beq End
    add r8, r8, #4
    vmov.s32 q13, #0                // q13 = four per-row accumulators
    mov r9, #0                      // r9 = current depth position
    mov r11, r2                     // r11 = zp cursor, rewound per row block

Sum:                                // reduce one 4x16 int8 tile to 4 sums
    cmp r9, r4
    beq Mul
    add r9, r9, #16

    vld1.8 {q0, q1}, [r0]!
    vld1.8 {q2, q3}, [r0]!

    // widen pairwise: 16 x s8 -> 8 x s16 -> 4 x s32 -> 2 x s64, narrow back,
    // then one more pairwise pass leaves one s32 sum per row in q0
    vpaddl.s8 q4, q0
    vpaddl.s8 q5, q1
    vpaddl.s8 q6, q2
    vpaddl.s8 q7, q3

    vpaddl.s16 q0, q4
    vpaddl.s16 q1, q5
    vpaddl.s16 q2, q6
    vpaddl.s16 q3, q7

    vpaddl.s32 q4, q0
    vpaddl.s32 q5, q1
    vpaddl.s32 q6, q2
    vpaddl.s32 q7, q3

    vqmovn.s64 d0, q4
    vqmovn.s64 d1, q5
    vqmovn.s64 d2, q6
    vqmovn.s64 d3, q7

    vpaddl.s32 q4, q0
    vpaddl.s32 q5, q1

    vqmovn.s64 d0, q4
    vqmovn.s64 d1, q5

    vadd.i32 q13, q13, q0           // accumulate the four row sums
    b Sum

Mul:                                // broadcast each row sum for the mul pass
    mov r12, r1                     // r12 = write cursor into sum
    add r1, r1, #32                 // next row block starts 32 bytes on
    mov r9, #0                      // r9 = current output channel

    vdup.32 d1, d26[0]              // row 0 sum
    vdup.32 d2, d26[1]              // row 1 sum
    vdup.32 d3, d27[0]              // row 2 sum
    vdup.32 d4, d27[1]              // row 3 sum

Write:                              // full oc2 column blocks
    cmp r9, r5
    beq OcRes
    add r9, r9, #2
    vld1.32 {d9}, [r11]!            // two per-channel zero points

    vmul.i32 d5, d1, d9
    vmul.i32 d6, d2, d9
    vmul.i32 d7, d3, d9
    vmul.i32 d8, d4, d9

    vst1.32 d5, [r12], r10
    vst1.32 d6, [r12], r10
    vst1.32 d7, [r12], r10
    vst1.32 d8, [r12], r10
    add r12, r12, r7                // hop to the next oc2 column block
    b Write

OcRes:                              // trailing single output channel, if any
    cmp r6, #0
    beq RowLoop

    vmov.s32 d9, #0
    // BUGFIX: the residual zero point is an int32; the previous vld1.8
    // loaded only its low byte, corrupting any zp outside [0, 255].
    vld1.32 {d9[0]}, [r11]

    vmul.i32 d5, d1, d9
    vmul.i32 d6, d2, d9
    vmul.i32 d7, d3, d9
    vmul.i32 d8, d4, d9

    vst1.32 d5, [r12], r10
    vst1.32 d6, [r12], r10
    vst1.32 d7, [r12], r10
    vst1.32 d8, [r12], r10
    b RowLoop

End:
    vpop {q4-q7}
    pop {r4-r8, r10, r11, pc}

+ 81
- 0
mindspore/lite/nnacl/assembly/arm32/PreSum4x16Int8Pert.S View File

@@ -0,0 +1,81 @@

.text
.align 5
.global PreSum4x16Int8Pert
#ifndef __APPLE__
.type PreSum4x16Int8Pert, %function
#endif


//------------------------------------------------------------------------------
// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4,
//                         size_t col16, int32_t filter_zp);
//
// Per-tensor input-sum pre-computation for int8 1x1 convolution: for each
// block of 4 rows of the 4x16-packed src, compute the int8 element sum of
// every row, multiply by the single filter zero point, store 4 x int32.
//
// r0: src        packed int8 input (4x16 tiles)
// r1: sum        int32 output (4 values per row block)
// r2: row4       number of rows (multiple of 4)
// r3: col16      packed depth (multiple of 16)
// stack arg (at sp+96 after the saves below):
//   filter_zp    per-tensor filter zero point
// Clobbers: r4-r7, q0-q7, q10, q13
//------------------------------------------------------------------------------

PreSum4x16Int8Pert:
    push {r4-r8, r10, r11, lr}
    vpush {q4-q7}
    // 96 bytes of saves above: read the stack argument by offset rather than
    // temporarily moving sp above the saved registers.
    ldr r4, [sp, #96]               // filter_zp

    vdup.32 q10, r4                 // broadcast zero point
    mov r5, #0                      // r5 = current row
    mov r7, #16                     // post-increment for 16-byte q stores

RowLoop:
    cmp r5, r2
    beq End
    add r5, r5, #4
    vmov.s32 q13, #0                // q13 = four per-row accumulators
    mov r6, #0                      // r6 = current depth position

CalLoop:                            // reduce one 4x16 int8 tile to 4 sums
    cmp r6, r3
    beq Write
    add r6, r6, #16

    vld1.8 {q0, q1}, [r0]!
    vld1.8 {q2, q3}, [r0]!

    // widen pairwise: 16 x s8 -> 8 x s16 -> 4 x s32 -> 2 x s64, narrow back,
    // then one more pairwise pass leaves one s32 sum per row in q0
    vpaddl.s8 q4, q0
    vpaddl.s8 q5, q1
    vpaddl.s8 q6, q2
    vpaddl.s8 q7, q3

    vpaddl.s16 q0, q4
    vpaddl.s16 q1, q5
    vpaddl.s16 q2, q6
    vpaddl.s16 q3, q7

    vpaddl.s32 q4, q0
    vpaddl.s32 q5, q1
    vpaddl.s32 q6, q2
    vpaddl.s32 q7, q3

    vqmovn.s64 d0, q4
    vqmovn.s64 d1, q5
    vqmovn.s64 d2, q6
    vqmovn.s64 d3, q7

    vpaddl.s32 q4, q0
    vpaddl.s32 q5, q1

    vqmovn.s64 d0, q4
    vqmovn.s64 d1, q5

    vadd.i32 q13, q13, q0           // accumulate the four row sums
    b CalLoop

Write:
    vmul.i32 q13, q13, q10          // sums * filter_zp
    vst1.32 q13, [r1], r7
    b RowLoop                       // BUGFIX: was beq, relying on stale Z flag
                                    // from the CalLoop exit compare

End:
    vpop {q4-q7}
    pop {r4-r8, r10, r11, pc}

+ 129
- 0
mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S View File

@@ -0,0 +1,129 @@

#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PreSum4x16Int8Peroc
#ifndef __APPLE__
.type PreSum4x16Int8Peroc, %function
#endif

//------------------------------------------------------------------------------
// void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp,
//                          size_t hw4, size_t ic16, int32_t oc_div4,
//                          size_t oc_res4, size_t stride);
//
// Per-output-channel input-sum pre-computation for int8 1x1 convolution.
// For each block of 4 rows of the 4x16-packed src, accumulate the int8
// element sum of every row, then multiply the row sums by the per-channel
// filter zero points in zp and scatter the products into sum.
//
// x0: src      packed int8 input (4x16 tiles)
// x1: sum      int32 output, advanced 64 bytes per row block
// x2: zp       int32 per-output-channel filter zero points
// w3: hw4      number of rows (multiple of 4)
// w4: ic16     packed depth (multiple of 16)
// w5: oc_div4  output channels rounded down to a multiple of 4
// w6: oc_res4  remaining output channels (0..3)
// x7: stride   extra byte stride between oc4 column blocks of sum
// Clobbers: w8, w9, x12, x16, v0-v7, v15-v19
//------------------------------------------------------------------------------

PreSum4x16Int8Peroc:
    mov w8, #0                      // w8 = current row

RowLoop:
    cmp w8, w3
    beq End
    add w8, w8, #4
    dup v16.4s, wzr                 // v16 = four per-row accumulators
    mov w9, #0                      // w9 = current depth position
    mov x16, x2                     // x16 = zp cursor, rewound per row block

Sum:                                // reduce one 4x16 int8 tile to 4 sums
    cmp w9, w4
    beq Mul
    add w9, w9, #16

    ld1 {v0.16b}, [x0], #16
    ld1 {v1.16b}, [x0], #16
    ld1 {v2.16b}, [x0], #16
    ld1 {v3.16b}, [x0], #16

    // widen and reduce: 16 x s8 -> 8 x s16 -> 4 x s32 -> one s32 per row
    saddlp v4.8h, v0.16b
    saddlp v5.8h, v1.16b
    saddlp v6.8h, v2.16b
    saddlp v7.8h, v3.16b
    saddlp v0.4s, v4.8h
    saddlp v1.4s, v5.8h
    saddlp v2.4s, v6.8h
    saddlp v3.4s, v7.8h
    addv s4, v0.4s
    addv s5, v1.4s
    addv s6, v2.4s
    addv s7, v3.4s
    mov v0.s[0], v4.s[0]
    mov v0.s[1], v5.s[0]
    mov v0.s[2], v6.s[0]
    mov v0.s[3], v7.s[0]

    add v16.4s, v16.4s, v0.4s
    b Sum

Mul:                                // broadcast each row sum for the mul pass
    mov x12, x1                     // x12 = write cursor into sum
    add x1, x1, #64                 // next row block starts 64 bytes on
    mov w9, #0                      // w9 = current output channel

    dup v1.4s, v16.s[0]             // row 0 sum
    dup v2.4s, v16.s[1]             // row 1 sum
    dup v3.4s, v16.s[2]             // row 2 sum
    dup v4.4s, v16.s[3]             // row 3 sum

WriteOc4:                           // full oc4 column blocks
    cmp w9, w5
    beq OcRes4
    add w9, w9, #4
    ld1 {v5.4s}, [x16], #16         // four per-channel zero points

    mul v16.4s, v5.4s, v1.4s
    mul v17.4s, v5.4s, v2.4s
    mul v18.4s, v5.4s, v3.4s
    mul v19.4s, v5.4s, v4.4s
    st1 {v16.4s}, [x12], #16
    st1 {v17.4s}, [x12], #16
    st1 {v18.4s}, [x12], #16
    st1 {v19.4s}, [x12], #16
    add x12, x12, x7                // hop to the next oc4 column block
    b WriteOc4

OcRes4:                             // 1-3 trailing output channels, if any
    cmp w6, #0
    beq RowLoop
    dup v15.4s, wzr
    cmp w6, #1
    beq OcRes4_1
    cmp w6, #2
    beq OcRes4_2
    b OcRes4_3                      // w6 == 3 on fall-through

OcRes4_1:
    ld1 {v15.s}[0], [x16]           // one int32 zero point
    b OcRes4End

OcRes4_2:
    // BUGFIX: two residual int32 zero points occupy 8 bytes; the earlier
    // ld1 {v15.h}[0] fetched only 2 bytes and left lane 1 zero.
    ld1 {v15.d}[0], [x16]
    b OcRes4End

OcRes4_3:
    // BUGFIX: load the first two zero points as 8 bytes (was a 2-byte load),
    // then the third into lane 2.
    ld1 {v15.d}[0], [x16]
    add x16, x16, #8
    ld1 {v15.s}[2], [x16]
    b OcRes4End

OcRes4End:
    mul v16.4s, v15.4s, v1.4s
    mul v17.4s, v15.4s, v2.4s
    mul v18.4s, v15.4s, v3.4s
    mul v19.4s, v15.4s, v4.4s
    st1 {v16.4s}, [x12], #16
    st1 {v17.4s}, [x12], #16
    st1 {v18.4s}, [x12], #16
    st1 {v19.4s}, [x12], #16
    b RowLoop

End:
    ret
#endif

+ 70
- 0
mindspore/lite/nnacl/assembly/arm64/PreSum4x16Int8Pert.S View File

@@ -0,0 +1,70 @@

#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PreSum4x16Int8Pert
#ifndef __APPLE__
.type PreSum4x16Int8Pert, %function
#endif

//------------------------------------------------------------------------------
// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4,
//                         size_t col16, int32_t filter_zp);
//
// Per-tensor input-sum pre-computation for int8 1x1 convolution: for each
// block of 4 rows of the 4x16-packed src, compute the int8 element sum of
// every row, multiply by the single filter zero point, store 4 x int32.
//
// x0: src        packed int8 input (4x16 tiles)
// x1: dst        int32 output (4 values per row block)
// w2: row4       number of rows (multiple of 4)
// w3: col16      packed depth (multiple of 16)
// w4: filter_zp  per-tensor filter zero point
// Clobbers: w5, w6, v0-v7, v16, v17
//------------------------------------------------------------------------------

PreSum4x16Int8Pert:
    dup v17.4s, w4                  // broadcast zero point
    mov w5, #0                      // w5 = current row

RowLoop:
    cmp w5, w2
    beq End
    add w5, w5, #4
    dup v16.4s, wzr                 // v16 = four per-row accumulators
    mov w6, #0                      // w6 = current depth position

CalLoop:                            // reduce one 4x16 int8 tile to 4 sums
    cmp w6, w3
    beq Write
    add w6, w6, #16

    ld1 {v0.16b}, [x0], #16
    ld1 {v1.16b}, [x0], #16
    ld1 {v2.16b}, [x0], #16
    ld1 {v3.16b}, [x0], #16

    // widen and reduce: 16 x s8 -> 8 x s16 -> 4 x s32 -> one s32 per row
    saddlp v4.8h, v0.16b
    saddlp v5.8h, v1.16b
    saddlp v6.8h, v2.16b
    saddlp v7.8h, v3.16b

    saddlp v0.4s, v4.8h
    saddlp v1.4s, v5.8h
    saddlp v2.4s, v6.8h
    saddlp v3.4s, v7.8h

    addv s4, v0.4s
    addv s5, v1.4s
    addv s6, v2.4s
    addv s7, v3.4s

    mov v0.s[0], v4.s[0]
    mov v0.s[1], v5.s[0]
    mov v0.s[2], v6.s[0]
    mov v0.s[3], v7.s[0]

    add v16.4s, v16.4s, v0.4s
    b CalLoop

Write:
    mul v16.4s, v16.4s, v17.4s      // sums * filter_zp
    st1 {v16.4s}, [x1], #16
    b RowLoop                       // BUGFIX: was beq, relying on stale Z flag
                                    // from the CalLoop exit compare

End:
    ret
#endif

+ 8
- 0
mindspore/lite/nnacl/int8/conv_int8.c View File

@@ -1029,6 +1029,14 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false;

if (is_per_channel == 1) {
return MatMulInt8_4x2_r(
packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift,
right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], true);
}

#ifdef ENABLE_ARM32
MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],


+ 22
- 2
mindspore/lite/nnacl/int8/matmul_int8.c View File

@@ -117,10 +117,10 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {

for (int ri = 0; ri < row_4div; ri += C4NUM) {
for (int ci = 0; ci < col_16div; ci += C16NUM) {
#ifdef ENABLE_ARM64
size_t col_offset = col;
int8_t *src_c = src_r + ci;
int8_t *dst_c = dst_r + ci * C4NUM;
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src_c] \n"
"mov x11, %[dst_c] \n"
@@ -138,8 +138,28 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
: "x10", "x11", "v0", "v1", "v2", "v3");
#elif ENABLE_ARM32
asm volatile(
"mov r0, %[src_c] \n"
"mov r1, %[dst_c] \n"
"mov r2, %[col_offset] \n"
"mov r3, #16 \n"

"vld1.8 {q0}, [r0], r2 \n"
"vld1.8 {q1}, [r0], r2 \n"
"vld1.8 {q2}, [r0], r2 \n"
"vld1.8 {q3}, [r0], r2 \n"

"vst1.32 q0, [r1], r3 \n"
"vst1.32 q1, [r1], r3 \n"
"vst1.32 q2, [r1], r3 \n"
"vst1.32 q3, [r1], r3 \n"

:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
: "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3");
#else
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col);
MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset);
#endif
}



+ 10
- 172
mindspore/lite/nnacl/pack.c View File

@@ -189,63 +189,8 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam

void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* normal matmul : 4x16 * 16x4 -> 4x4 */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
"mov x11, %[dst] \n"
"dup v15.4s, %w[filter_zp] \n"

"mov x0, #0 \n"
"1: \n"
"cmp x0, %[row4] \n"
"beq 4f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"

"2: \n"
"cmp x2, %[col16] \n"
"beq 3f \n"
"add x2, x2, #16\n"

"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"

"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"

"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"

"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"

"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"

"add v10.4s, v10.4s, v0.4s \n"
"b 2b\n"

"3: \n"
"mul v10.4s, v10.4s, v15.4s \n"
"st1 {v10.4s}, [x11], #16 \n"
"beq 1b \n"

"4: \n"

:
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
#ifdef ENABLE_ARM
PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp);
#else
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
@@ -268,121 +213,7 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
asm volatile(
"mov x10, %[input_value] \n"
"mov x11, %[input_sum] \n"
"mov x15, %[filter_zp_ptr] \n"

"mov x0, #0 \n"
"1: \n"
"cmp x0, %[hw4] \n"
"beq 11f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"mov x16, x15 \n"

"2: \n"
"cmp x2, %[ic16] \n"
"beq 3f \n"
"add x2, x2, #16 \n"

"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"

"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"

"add v10.4s, v10.4s, v0.4s \n"
"b 2b \n"

"3: \n"
"mov x12, x11 \n"
"add x11, x11, #64 \n"
"mov x4, #0 \n"

"dup v1.4s, v10.s[0] \n"
"dup v2.4s, v10.s[1] \n"
"dup v3.4s, v10.s[2] \n"
"dup v4.4s, v10.s[3] \n"

"4: \n"
"cmp x4, %[oc_div4] \n"
"beq 6f \n"
"add x4, x4, #4\n"
"ld1 {v15.4s}, [x16], #16\n"

"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"add x12, x12, %[inputsun_stride] \n"
"b 4b \n"

"6: \n"
"cmp %[oc_res4], #0\n"
"beq 1b \n"
"dup v15.4s, wzr \n"
"cmp %[oc_res4], #1\n"
"beq 7f \n"
"cmp %[oc_res4], #2\n"
"beq 8f \n"
"cmp %[oc_res4], #3\n"
"beq 9f \n"

"7: \n"
"ld1 {v15.s}[0], [x16] \n"
"b 10f \n"

"8: \n"
"ld1 {v15.h}[0], [x16] \n"
"b 10f \n"

"9: \n"
"ld1 {v15.h}[0], [x16] \n"
"add x16, x16, #8 \n"
"ld1 {v15.s}[2], [x16] \n"
"b 10f \n"

"10: \n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"b 1b \n"

"11: \n"

:
: [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr),
[ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4),
[ inputsun_stride ] "r"(inputsun_stride)
: "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15",
"v16", "v17", "v18", "v19");
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride);
#else

for (int ri = 0; ri < plane_size; ri++) {
@@ -409,6 +240,12 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);

#ifdef ENABLE_ARM32
size_t oc_div2 = output_channel / C2NUM * C2NUM;
size_t oc_res2 = output_channel - oc_div2;
size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
@@ -424,6 +261,7 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}



+ 7
- 0
mindspore/lite/nnacl/pack.h View File

@@ -121,6 +121,13 @@ void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight

void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);

#ifdef ENABLE_ARM
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
size_t oc_res, size_t stride);
#endif

#ifdef __cplusplus
}
#endif


+ 19
- 14
mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc View File

@@ -71,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
}

void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = true;
support_optimize_ = false;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@@ -94,7 +94,7 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
return;
}

int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel) {
int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc) {
/* bias = bias - v2 x zp1 + zp1 x zp2 */
int32_t *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int8_t *weight = reinterpret_cast<int8_t *>(src_weight);
@@ -118,24 +118,23 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}

int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM);
left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
left_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (left_shift_ == nullptr) {
return RET_ERROR;
}
memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t));
memset(left_shift_, 0, round_oc * sizeof(int32_t));
memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
right_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (right_shift_ == nullptr) {
return RET_ERROR;
}
memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t));
memset(right_shift_, 0, round_oc * sizeof(int32_t));
memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
multiplier_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (multiplier_ == nullptr) {
return RET_ERROR;
}
memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t));
memset(multiplier_, 0, round_oc * sizeof(int32_t));
memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
}
return RET_OK;
@@ -165,18 +164,18 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {

int col4 = UP_ROUND(output_channel, C4NUM);
int col8 = UP_ROUND(output_channel, C8NUM);
size = support_optimize_ ? col8 * sizeof(int32_t) : col4 * sizeof(int32_t);
bias_data_ = malloc(size);
size = support_optimize_ ? col8 : col4;
bias_data_ = malloc(size * sizeof(int32_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!";
return RET_ERROR;
}
memset(bias_data_, 0, size);
memset(bias_data_, 0, size * sizeof(int32_t));
if (in_tensors_.size() == 3) {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
}

InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel);
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, size);
return RET_OK;
}

@@ -208,7 +207,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
}

InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel);
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, UP_ROUND(output_channel, C2NUM));
return RET_OK;
}

@@ -342,6 +341,12 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C2NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C2NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C2NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C2NUM;
}
Conv1x1Int8Arm32(packed_input_, packed_weight_ + task_id * thread_stride_ * C2NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C2NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C2NUM, matmul_param_->row_,


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h View File

@@ -55,7 +55,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBiasArm32();
void Pre1x1Trans(int8_t *src_input, int8_t *src_output);
void CheckSupportOptimize();
int InitBiasByzp(void *src_weight, int input_channel, int output_channel);
int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc);

private:
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */


Loading…
Cancel
Save