Browse Source

!7640 [MSLITE] Optimize depthwise convolution 3x3 for arm64 platform

Merge pull request !7640 from zhanyuan/tmp
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f42353d0d3
3 changed files with 294 additions and 0 deletions
  1. +281
    -0
      mindspore/lite/nnacl/assembly/arm64/ConvDwInt8_3x3.S
  2. +5
    -0
      mindspore/lite/nnacl/int8/conv_depthwise_int8.c
  3. +8
    -0
      mindspore/lite/nnacl/int8/conv_depthwise_int8.h

+ 281
- 0
mindspore/lite/nnacl/assembly/arm64/ConvDwInt8_3x3.S View File

@@ -0,0 +1,281 @@
#ifdef __aarch64__

.text
.align 5
.global ConvDw3x3Int8Neon64
#ifndef __APPLE__
.type ConvDw3x3Int8Neon64, %function
#endif


// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size,
// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
// int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max)
//
// x0: output
// x1: input
// x2: weight
// x3: bias
// w4: col_size
// w5: row_size
// w6: channel
// w7: output_h
// w8: output_w
// w9: in_zp
// w10: out_zp
// w11: out_multiplier
// w12: left_shift
// w13: right_shift
// w14: acc_min
// w15: acc_max

ConvDw3x3Int8Neon64:
sub sp, sp, #160
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16

ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]

add x19, x3, #16
add w20, w6, w6 // channel * 2
add w21, w4, w4 // col_size * 2
dup v25.8b, w9
dup v26.4s, w12
dup v27.4s, w11
dup v28.4s, w13
dup v29.4s, w10
dup v30.4s, w14
dup v31.4s, w15

// Load weights
ld1 {v0.8h}, [x2], x20
ld1 {v1.8h}, [x2], x20
ld1 {v2.8h}, [x2], x20
ld1 {v3.8h}, [x2], x20
ld1 {v4.8h}, [x2], x20
ld1 {v5.8h}, [x2], x20
ld1 {v6.8h}, [x2], x20
ld1 {v7.8h}, [x2], x20
ld1 {v8.8h}, [x2], x20

Loop:
mov x16, x1
add x17, x16, x5
add x18, x17, x5
ld1 {v9.8b}, [x16], x4
ld1 {v10.8b}, [x16], x4
ld1 {v11.8b}, [x16], x4
ld1 {v13.8b}, [x17], x4
ld1 {v14.8b}, [x17], x4
ld1 {v15.8b}, [x17], x4
ld1 {v17.8b}, [x18], x4
ld1 {v18.8b}, [x18], x4
ld1 {v19.8b}, [x18], x4

ld1 {v21.4s}, [x3]
ld1 {v22.4s}, [x19]

// subtract input zp
ssubl v9.8h, v9.8b, v25.8b
ssubl v10.8h, v10.8b, v25.8b
ssubl v11.8h, v11.8b, v25.8b
ssubl v13.8h, v13.8b, v25.8b
ssubl v14.8h, v14.8b, v25.8b
ssubl v15.8h, v15.8b, v25.8b
ssubl v17.8h, v17.8b, v25.8b
ssubl v18.8h, v18.8b, v25.8b
ssubl v19.8h, v19.8b, v25.8b

cmp w8, #1
beq Width1

Width2:
ld1 {v12.8b}, [x16]
ld1 {v16.8b}, [x17]
ld1 {v20.8b}, [x18]
ld1 {v23.4s}, [x3]
ld1 {v24.4s}, [x19]

ssubl v12.8h, v12.8b, v25.8b
ssubl v16.8h, v16.8b, v25.8b
ssubl v20.8h, v20.8b, v25.8b

smlal v21.4s, v0.4h, v9.4h
smlal2 v22.4s, v0.8h, v9.8h
smlal v23.4s, v0.4h, v10.4h
smlal2 v24.4s, v0.8h, v10.8h
smlal v21.4s, v1.4h, v10.4h
smlal2 v22.4s, v1.8h, v10.8h
smlal v23.4s, v1.4h, v11.4h
smlal2 v24.4s, v1.8h, v11.8h
smlal v21.4s, v2.4h, v11.4h
smlal2 v22.4s, v2.8h, v11.8h
smlal v23.4s, v2.4h, v12.4h
smlal2 v24.4s, v2.8h, v12.8h
smlal v21.4s, v3.4h, v13.4h
smlal2 v22.4s, v3.8h, v13.8h
smlal v23.4s, v3.4h, v14.4h
smlal2 v24.4s, v3.8h, v14.8h
smlal v21.4s, v4.4h, v14.4h
smlal2 v22.4s, v4.8h, v14.8h
smlal v23.4s, v4.4h, v15.4h
smlal2 v24.4s, v4.8h, v15.8h
smlal v21.4s, v5.4h, v15.4h
smlal2 v22.4s, v5.8h, v15.8h
smlal v23.4s, v5.4h, v16.4h
smlal2 v24.4s, v5.8h, v16.8h
smlal v21.4s, v6.4h, v17.4h
smlal2 v22.4s, v6.8h, v17.8h
smlal v23.4s, v6.4h, v18.4h
smlal2 v24.4s, v6.8h, v18.8h
smlal v21.4s, v7.4h, v18.4h
smlal2 v22.4s, v7.8h, v18.8h
smlal v23.4s, v7.4h, v19.4h
smlal2 v24.4s, v7.8h, v19.8h
smlal v21.4s, v8.4h, v19.4h
smlal2 v22.4s, v8.8h, v19.8h
smlal v23.4s, v8.4h, v20.4h
smlal2 v24.4s, v8.8h, v20.8h

// Apply left shfit
sqshl v21.4s, v21.4s, v26.4s
sqshl v22.4s, v22.4s, v26.4s
sqshl v23.4s, v23.4s, v26.4s
sqshl v24.4s, v24.4s, v26.4s

// Apply the fixed-point part of the multiplier.
sqrdmulh v21.4s, v21.4s, v27.4s
sqrdmulh v22.4s, v22.4s, v27.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s

// Apply right shfit
and v9.16b, v28.16b, v21.16b
sshr v9.4s, v9.4s, #31
sqadd v21.4s, v21.4s, v9.4s
srshl v21.4s, v21.4s, v28.4s
and v10.16b, v28.16b, v22.16b
sshr v10.4s, v10.4s, #31
sqadd v22.4s, v22.4s, v10.4s
srshl v22.4s, v22.4s, v28.4s
and v11.16b, v28.16b, v23.16b
sshr v11.4s, v11.4s, #31
sqadd v23.4s, v23.4s, v11.4s
srshl v23.4s, v23.4s, v28.4s
and v12.16b, v28.16b, v24.16b
sshr v12.4s, v12.4s, #31
sqadd v24.4s, v24.4s, v12.4s
srshl v24.4s, v24.4s, v28.4s

// Add output zero point
sqadd v21.4s, v21.4s, v29.4s
sqadd v22.4s, v22.4s, v29.4s
sqadd v23.4s, v23.4s, v29.4s
sqadd v24.4s, v24.4s, v29.4s

// Apply min bound
smax v21.4s, v21.4s, v30.4s
smax v22.4s, v22.4s, v30.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s

// Apply max bound
smin v21.4s, v21.4s, v31.4s
smin v22.4s, v22.4s, v31.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s

sqxtn v21.4h, v21.4s
sqxtn2 v21.8h, v22.4s
sqxtn v23.4h, v23.4s
sqxtn2 v23.8h, v24.4s
sqxtn v21.8b, v21.8h
sqxtn2 v21.16b, v23.8h

st1 {v21.8b}, [x0], x6
mov v23.d[0], v21.d[1]
st1 {v23.8b}, [x0], x6
sub w8, w8, #2
cbz w8, End
add x1, x1, x21
b Loop

Width1:
smlal v21.4s, v0.4h, v9.4h
smlal2 v22.4s, v0.8h, v9.8h
smlal v21.4s, v1.4h, v10.4h
smlal2 v22.4s, v1.8h, v10.8h
smlal v21.4s, v2.4h, v11.4h
smlal2 v22.4s, v2.8h, v11.8h
smlal v21.4s, v3.4h, v13.4h
smlal2 v22.4s, v3.8h, v13.8h
smlal v21.4s, v4.4h, v14.4h
smlal2 v22.4s, v4.8h, v14.8h
smlal v21.4s, v5.4h, v15.4h
smlal2 v22.4s, v5.8h, v15.8h
smlal v21.4s, v6.4h, v17.4h
smlal2 v22.4s, v6.8h, v17.8h
smlal v21.4s, v7.4h, v18.4h
smlal2 v22.4s, v7.8h, v18.8h
smlal v21.4s, v8.4h, v19.4h
smlal2 v22.4s, v8.8h, v19.8h

// Apply left shfit
sqshl v21.4s, v21.4s, v26.4s
sqshl v22.4s, v22.4s, v26.4s

// Apply the fixed-point part of the multiplier.
sqrdmulh v21.4s, v21.4s, v27.4s
sqrdmulh v22.4s, v22.4s, v27.4s

// Apply right shfit
and v9.16b, v28.16b, v21.16b
sshr v9.4s, v9.4s, #31
sqadd v21.4s, v21.4s, v9.4s
srshl v21.4s, v21.4s, v28.4s
and v10.16b, v28.16b, v22.16b
sshr v10.4s, v10.4s, #31
sqadd v22.4s, v22.4s, v10.4s
srshl v22.4s, v22.4s, v28.4s

// Add output zero point
sqadd v21.4s, v21.4s, v29.4s
sqadd v22.4s, v22.4s, v29.4s

// Apply min bound
smax v21.4s, v21.4s, v30.4s
smax v22.4s, v22.4s, v30.4s

// Apply max bound
smin v21.4s, v21.4s, v31.4s
smin v22.4s, v22.4s, v31.4s

sqxtn v21.4h, v21.4s
sqxtn2 v21.8h, v22.4s
sqxtn v21.8b, v21.8h

st1 {v21.8b}, [x0], x6
sub w8, w8, #1
cbz w8, End
add x1, x1, x4
b Loop

End:
sub sp, sp, #160
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ret

#endif

+ 5
- 0
mindspore/lite/nnacl/int8/conv_depthwise_int8.c View File

@@ -208,8 +208,13 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max, int stride) {
for (; start_c <= end_c - 8; start_c += 8) {
#ifdef ENABLE_ARM64
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
#else
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);
#endif
output += 8;
buffer += 8;
weight += 8;


+ 8
- 0
mindspore/lite/nnacl/int8/conv_depthwise_int8.h View File

@@ -43,6 +43,14 @@ void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *
void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);

#ifdef ENABLE_ARM64
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
#endif

#ifdef __cplusplus
}
#endif


Loading…
Cancel
Save