Browse Source

!10012 [MS][LITE][Develop]avx fp32 matmul kernel support for deconv

From: @lx0095
Reviewed-by: @zhanghaibo5,@zhang_xue_tong,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong,@zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
58cb834733
6 changed files with 378 additions and 1 deletions
  1. +273
    -0
      mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S
  2. +49
    -0
      mindspore/lite/nnacl/common_func.c
  3. +6
    -0
      mindspore/lite/nnacl/common_func.h
  4. +42
    -0
      mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c
  5. +2
    -0
      mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h
  6. +6
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc

+ 273
- 0
mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S View File

@@ -0,0 +1,273 @@
#ifdef ENABLE_AVX
#ifndef WIN32

.text
.align 4
.global ConvDwFp32Avx3x3
#ifndef __APPLE__
.type ConvDwFp32Avx3x3, %function
#endif

// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
//                       int output_width, size_t input_stride, size_t relu, size_t relu6)
//
// Depthwise 3x3 convolution over one output row. For every output pixel the nine
// pointers input[0..8] are read, each is multiplied channel-wise with the matching
// weight vector and accumulated on top of the bias, eight channels at a time.
//
// rdi: output
// rsi: input        (advanced by input_stride bytes after every pixel)
// rdx: weights      (consumed as 9 consecutive 8-float vectors per 8-channel group)
// rcx: bias
// r8: channels      (32-bit int)
// r9: output_width  (32-bit int)
// 8(%rsp): input_stride
// 16(%rsp): relu
// 24(%rsp): relu6
//
// NOTE(review): the tail path still issues full 8-float loads from input, weights
// and bias even when fewer than 8 channels remain - presumably those buffers are
// padded to a multiple of 8 floats; confirm against the caller.

ConvDwFp32Avx3x3:
// channels/output_width are ints: the ABI leaves the upper 32 bits of r8/r9
// undefined, but they are used as 64-bit values below, so zero-extend them first.
movl %r8d, %r8d
movl %r9d, %r9d
pushq %r15
pushq %r14
pushq %r13
pushq %r12
pushq %rbx
pushq %rbp
pushq %r9
pushq %r8
pushq %rcx
pushq %rdx
pushq %rsi
pushq %rdi
// Move rsp back to the return address: stack arguments are then at 8/16/24(%rsp)
// and the twelve saved registers sit at -8 .. -96(%rsp), inside the 128-byte red zone.
//   -56(%rsp): output_width   -64(%rsp): channels
//   -72(%rsp): bias           -80(%rsp): weights
addq $96, %rsp

// ymm15 = 6.0f in all eight lanes (upper bound for relu6)
movq $6, %rax
vcvtsi2ss %rax, %xmm15, %xmm15
vshufps $0, %xmm15, %xmm15, %xmm15
vinsertf128 $1, %xmm15, %ymm15, %ymm15
// ymm14 = 0.0f in all eight lanes (lower bound for relu/relu6)
vxorps %ymm14, %ymm14, %ymm14

LoopPixel:
// Every pixel restarts from the beginning of weights/bias with the full channel count.
movq -80(%rsp), %rdx
movq -72(%rsp), %rcx
movq -64(%rsp), %r8
// Load the nine input pointers of the current 3x3 window.
movq (%rsi), %r9
movq 8(%rsi), %r10
movq 16(%rsi), %r11
movq 24(%rsi), %r12
movq 32(%rsi), %r13
movq 40(%rsi), %r14
movq 48(%rsi), %r15
movq 56(%rsi), %rbp
movq 64(%rsi), %rbx

// Preload the first three inputs ...
vmovups (%r9), %ymm0
addq $32, %r9
vmovups (%r10), %ymm1
addq $32, %r10
vmovups (%r11), %ymm2
addq $32, %r11

// ... the first three weight vectors ...
vmovups (%rdx), %ymm11
addq $32, %rdx
vmovups (%rdx), %ymm12
addq $32, %rdx
vmovups (%rdx), %ymm13
addq $32, %rdx

// ... and the bias, which seeds the accumulator ymm10.
vmovups (%rcx), %ymm10
addq $32, %rcx

// 8 or fewer channels: only the tail path is needed.
cmpq $8, %r8
jbe LeftLoop
LoopC8:
// Nine FMAs for one group of 8 channels; loads for later taps are interleaved
// with the FMAs, and ymm11-13 are reused in rotation for the weight vectors.
vfmadd231ps %ymm11, %ymm0, %ymm10
vmovups (%r12), %ymm3
addq $32, %r12
vmovups (%rdx), %ymm11
addq $32, %rdx
vfmadd231ps %ymm12, %ymm1, %ymm10
vmovups (%r13), %ymm4
addq $32, %r13
vmovups (%rdx), %ymm12
addq $32, %rdx
vfmadd231ps %ymm13, %ymm2, %ymm10
vmovups (%r14), %ymm5
addq $32, %r14
vmovups (%rdx), %ymm13
addq $32, %rdx
vfmadd231ps %ymm11, %ymm3, %ymm10
vmovups (%r15), %ymm6
addq $32, %r15
vmovups (%rdx), %ymm11
addq $32, %rdx
vfmadd231ps %ymm12, %ymm4, %ymm10
vmovups (%rbp), %ymm7
addq $32, %rbp
vmovups (%rdx), %ymm12
addq $32, %rdx
vfmadd231ps %ymm13, %ymm5, %ymm10
vmovups (%rbx), %ymm8
addq $32, %rbx
vmovups (%rdx), %ymm13
addq $32, %rdx
// The loads below already fetch inputs 0-2 and weights 0-2 of the NEXT channel group.
vfmadd231ps %ymm11, %ymm6, %ymm10
vmovups (%r9), %ymm0
addq $32, %r9
vmovups (%rdx), %ymm11
addq $32, %rdx
vfmadd231ps %ymm12, %ymm7, %ymm10
vmovups (%r10), %ymm1
addq $32, %r10
vmovups (%rdx), %ymm12
addq $32, %rdx
vfmadd231ps %ymm13, %ymm8, %ymm10
vmovups (%r11), %ymm2
addq $32, %r11
vmovups (%rdx), %ymm13
addq $32, %rdx

// Activation: relu6 takes priority and falls through into the relu clamp.
movq 24(%rsp), %rax
cmpq $0, %rax
jne Relu6
movq 16(%rsp), %rax
cmpq $0, %rax
jne Relu
jmp Write
Relu6:
vminps %ymm15, %ymm10, %ymm10
Relu:
vmaxps %ymm14, %ymm10, %ymm10
Write:
vmovups %ymm10, (%rdi)
addq $32, %rdi

// Seed the accumulator with the next 8 bias values and loop while more than
// 8 channels remain; the last 1..8 channels are handled by LeftLoop.
vmovups (%rcx), %ymm10
addq $32, %rcx
subq $8, %r8
cmpq $8, %r8
ja LoopC8

LeftLoop:
// Final group: same nine FMAs, but nothing is prefetched for a following group.
vfmadd231ps %ymm11, %ymm0, %ymm10
vmovups (%r12), %ymm3
addq $32, %r12
vmovups (%rdx), %ymm11
addq $32, %rdx
vfmadd231ps %ymm12, %ymm1, %ymm10
vmovups (%r13), %ymm4
addq $32, %r13
vmovups (%rdx), %ymm12
addq $32, %rdx
vfmadd231ps %ymm13, %ymm2, %ymm10
vmovups (%r14), %ymm5
addq $32, %r14
vmovups (%rdx), %ymm13
addq $32, %rdx
vfmadd231ps %ymm11, %ymm3, %ymm10
vmovups (%r15), %ymm6
addq $32, %r15
vmovups (%rdx), %ymm11
addq $32, %rdx
vfmadd231ps %ymm12, %ymm4, %ymm10
vmovups (%rbp), %ymm7
addq $32, %rbp
vmovups (%rdx), %ymm12
addq $32, %rdx
vfmadd231ps %ymm13, %ymm5, %ymm10
vmovups (%rbx), %ymm8
addq $32, %rbx
vmovups (%rdx), %ymm13
addq $32, %rdx
vfmadd231ps %ymm11, %ymm6, %ymm10
vfmadd231ps %ymm12, %ymm7, %ymm10
vfmadd231ps %ymm13, %ymm8, %ymm10

movq 24(%rsp), %rax
cmpq $0, %rax
jne LeftRelu6
movq 16(%rsp), %rax
cmpq $0, %rax
jne LeftRelu
jmp LeftWrite
LeftRelu6:
vminps %ymm15, %ymm10, %ymm10
LeftRelu:
vmaxps %ymm14, %ymm10, %ymm10
LeftWrite:
// Store exactly r8 floats when 1..7 channels remain; any other count stores all 8.
cmpq $1, %r8
je Write1
cmpq $2, %r8
je Write2
cmpq $3, %r8
je Write3
cmpq $4, %r8
je Write4
cmpq $5, %r8
je Write5
cmpq $6, %r8
je Write6
cmpq $7, %r8
je Write7
jmp Write8
Write1:
vmovss %xmm10, (%rdi)
addq $4, %rdi
jmp NextPixel
Write2:
vmovsd %xmm10, (%rdi)
addq $8, %rdi
jmp NextPixel
Write3:
vmovsd %xmm10, (%rdi)
// VEX form keeps the kernel free of legacy-SSE encodings (no AVX/SSE transition).
vmovhlps %xmm10, %xmm10, %xmm10
vmovss %xmm10, 8(%rdi)
addq $12, %rdi
jmp NextPixel
Write4:
vmovups %xmm10, (%rdi)
addq $16, %rdi
jmp NextPixel
Write5:
vmovups %xmm10, (%rdi)
vextractf128 $1, %ymm10, %xmm9
vmovss %xmm9, 16(%rdi)
addq $20, %rdi
jmp NextPixel
Write6:
vmovups %xmm10, (%rdi)
vextractf128 $1, %ymm10, %xmm9
vmovsd %xmm9, 16(%rdi)
addq $24, %rdi
jmp NextPixel
Write7:
vmovups %xmm10, (%rdi)
vextractf128 $1, %ymm10, %xmm9
vmovsd %xmm9, 16(%rdi)
// VEX form, see Write3.
vmovhlps %xmm9, %xmm9, %xmm9
vmovss %xmm9, 24(%rdi)
addq $28, %rdi
jmp NextPixel
Write8:
vmovups %ymm10, (%rdi)
add $32, %rdi

NextPixel:
// Step the indirection pointer by input_stride (added as a byte offset) and
// count down the remaining output pixels kept in the saved-r9 slot.
movq 8(%rsp), %rbp
addq %rbp, %rsi
movq -56(%rsp), %rax
subq $1, %rax
movq %rax, -56(%rsp)
cmpq $0, %rax
ja LoopPixel
End:
// Point rsp back at the saved-register area and restore in reverse push order.
subq $96, %rsp
popq %rdi
popq %rsi
popq %rdx
popq %rcx
popq %r8
popq %r9
popq %rbp
popq %rbx
popq %r12
popq %r13
popq %r14
popq %r15
// Clear the upper ymm halves so SSE code in the caller pays no transition penalty.
vzeroupper
retq
#endif
#endif

+ 49
- 0
mindspore/lite/nnacl/common_func.c View File

@@ -78,3 +78,52 @@ void Relu6Fp32(float *data, float *dst, int ele_num) {
data[j] = data[j] > 6 ? 6 : data[j];
}
}

#ifdef ENABLE_AVX
#ifdef WIN32
// Applies ReLU in place: every negative value among the first ele_num floats of
// `data` is replaced by 0. NaN and -0.0f are left untouched (neither compares < 0).
//
// data:    buffer that is read and overwritten.
// dst:     unused; the result is written back to `data`, matching how the
//          function has always behaved.
// ele_num: number of floats to process; values <= 0 make the call a no-op.
//
// A single bounded loop is used instead of hand-unrolled blocks of eight with a
// tail starting at (UP_DIV(ele_num, 8) - 1) * 8: that start index is -8 when
// ele_num is 0, which wrote in front of the buffer.
void ReluFp32C8(float *data, float *dst, int ele_num) {
  (void)dst;
  for (int i = 0; i < ele_num; ++i) {
    data[i] = data[i] < 0 ? 0 : data[i];
  }
}

// Applies ReLU6 in place: each of the first ele_num floats of `data` is clamped
// to the range [0, 6]. NaN values are left untouched (all comparisons are false).
//
// data:    buffer that is read and overwritten.
// dst:     unused; the result is written back to `data`, matching how the
//          function has always behaved.
// ele_num: number of floats to process; values <= 0 make the call a no-op.
//
// A single bounded loop is used instead of hand-unrolled blocks of eight with a
// tail starting at (UP_DIV(ele_num, 8) - 1) * 8: that start index is -8 when
// ele_num is 0, which wrote in front of the buffer.
void Relu6Fp32C8(float *data, float *dst, int ele_num) {
  (void)dst;
  for (int i = 0; i < ele_num; ++i) {
    float value = data[i];
    value = value < 0 ? 0 : value;
    value = value > 6 ? 6 : value;
    data[i] = value;
  }
}
#endif
#endif

+ 6
- 0
mindspore/lite/nnacl/common_func.h View File

@@ -31,6 +31,12 @@ int8_t MinInt8(int8_t a, int8_t b);
int8_t MaxInt8(int8_t a, int8_t b);
void ReluFp32(float *data, float *dst, int ele_num);
void Relu6Fp32(float *data, float *dst, int ele_num);
#ifdef ENABLE_AVX
#ifdef WIN32
// Plain-C activation helpers for the WIN32 AVX build, where they are called by the
// C implementation of ConvDwFp32IndirectRow.
// Both clamp the first ele_num floats of `data` in place (ReluFp32C8 to >= 0,
// Relu6Fp32C8 to [0, 6]); `dst` is accepted but not written by the implementation.
void ReluFp32C8(float *data, float *dst, int ele_num);
void Relu6Fp32C8(float *data, float *dst, int ele_num);
#endif
#endif
int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
int offset4d(const int *shape, const int *dims);


+ 42
- 0
mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c View File

@@ -681,6 +681,47 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
#endif

#ifdef ENABLE_AVX
#ifdef WIN32
// C implementation of one output row of an indirect-buffer depthwise convolution
// (used for the WIN32 AVX build).
//
// output:       receives output_width * channels floats, written pixel after pixel.
// input:        indirection buffer; for each output pixel input[0..kernel-1] are the
//               pointers to the `kernel` source pixels. The pointer array is advanced
//               by input_stride entries after every output pixel.
// weights:      packed in groups of C8NUM channels: within a group, tap k of channel i
//               is at index i + k * C8NUM, and consecutive groups are kernel * C8NUM
//               floats apart.
// bias:         `channels` floats used to initialise every output pixel.
// relu / relu6: activation applied to each finished pixel (both are applied if both
//               are set).
// kernel:       number of taps per pixel.
//
// Non-positive output_width or channels make the call a no-op; previously the
// do/while row loop would run away for output_width <= 0 and a negative channel
// count became a huge memcpy size.
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (output_width <= 0 || channels <= 0) {
    return;
  }
  // First channel that does not belong to a complete group of C8NUM.
  const int tail_start = channels - channels % C8NUM;
  for (int ow = 0; ow < output_width; ow++) {
    memcpy(output, bias, (size_t)channels * sizeof(float));
    const float *w = weights;
    // Complete groups of C8NUM channels. The input pointers are indexed with the
    // channel offset directly, so no per-pixel copy of the pointer table is needed.
    for (int c = 0; c < tail_start; c += C8NUM) {
      for (int i = 0; i < C8NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          output[c + i] += input[k][c + i] * w[i + k * C8NUM];
        }
      }
      w += kernel * C8NUM;
    }
    // Remaining channels use the same C8NUM-strided weight layout as a full group.
    for (int i = tail_start; i < channels; i++) {
      for (int k = 0; k < kernel; k++) {
        output[i] += input[k][i] * w[(i - tail_start) + k * C8NUM];
      }
    }
    if (relu) {
      ReluFp32C8(output, output, channels);
    }
    if (relu6) {
      Relu6Fp32C8(output, output, channels);
    }
    output += channels;
    input += input_stride;
  }
}
#else
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
if (kernel == 9) {
@@ -688,6 +729,7 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
}
}
#endif
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
float *zero_ptr, const ConvParameter *conv_param, int task_id) {


+ 2
- 0
mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h View File

@@ -67,9 +67,11 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c
#endif

#ifdef ENABLE_AVX
#ifndef WIN32
// AVX assembly kernel (ConvDwFp32Avx3x3.S) computing one output row of a 3x3
// depthwise convolution. For each of the output_width pixels it reads nine row
// pointers from `input`, accumulates input * weights on top of `bias` eight channels
// at a time, applies relu6 or relu when the respective flag is non-zero, and then
// advances `input` by input_stride (added to the pointer as a byte offset).
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, size_t input_stride, size_t relu, size_t relu6);
#endif
#endif

void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel);


+ 6
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc View File

@@ -147,7 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
#ifdef ENABLE_AVX
if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
#elif defined(ENABLE_ARM64)
if (CheckConvDwUseIndirectBuffer(conv_param)) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);


Loading…
Cancel
Save