From: @lx0095 Reviewed-by: @zhanghaibo5, @zhang_xue_tong Signed-off-by: @zhang_xue_tong Tags: v1.1.0
| @@ -0,0 +1,273 @@ | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| .text | |||
| .align 4 | |||
| .global ConvDwFp32Avx3x3 | |||
| #ifndef __APPLE__ | |||
| .type ConvDwFp32Avx3x3, %function | |||
| #endif | |||
| // void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| //                       int output_width, size_t input_stride, size_t relu, size_t relu6) | |||
| // rdi: output | |||
| // rsi: input | |||
| // rdx: weights | |||
| // rcx: bias | |||
| // r8: channels | |||
| // r9: output_width | |||
| // 8(%rsp): input_stride | |||
| // 16(%rsp): relu | |||
| // 24(%rsp): relu6 | |||
| ConvDwFp32Avx3x3: | |||
| pushq %r15 | |||
| pushq %r14 | |||
| pushq %r13 | |||
| pushq %r12 | |||
| pushq %rbx | |||
| pushq %rbp | |||
| pushq %r9 | |||
| pushq %r8 | |||
| pushq %rcx | |||
| pushq %rdx | |||
| pushq %rsi | |||
| pushq %rdi | |||
| addq $96, %rsp | |||
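| // addq $96, %rsp undoes the 12 pushes, so the stack arguments are again at 8/16/24(%rsp); | |||
| // the saved registers stay below %rsp in the SysV red zone and are reloaded via negative offsets (hence the #ifndef WIN32 guard). | |||
| // ymm15 = broadcast 6.0f (relu6 upper bound), ymm14 = 0.0f (relu lower bound). | |||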
| movq $6, %rax | |||
| vcvtsi2ss %rax, %xmm15, %xmm15 | |||
| vshufps $0, %xmm15, %xmm15, %xmm15 | |||
| vinsertf128 $1, %xmm15, %ymm15, %ymm15 | |||
| vxorps %ymm14, %ymm14, %ymm14 | |||
| LoopPixel: | |||
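| // Per output pixel: reload weights/bias/channel count from the spill slots and fetch the 9 input-row pointers. | |||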
| movq -80(%rsp), %rdx | |||
| movq -72(%rsp), %rcx | |||
| movq -64(%rsp), %r8 | |||
| movq (%rsi), %r9 | |||
| movq 8(%rsi), %r10 | |||
| movq 16(%rsi), %r11 | |||
| movq 24(%rsi), %r12 | |||
| movq 32(%rsi), %r13 | |||
| movq 40(%rsi), %r14 | |||
| movq 48(%rsi), %r15 | |||
| movq 56(%rsi), %rbp | |||
| movq 64(%rsi), %rbx | |||
| vmovups (%r9), %ymm0 | |||
| addq $32, %r9 | |||
| vmovups (%r10), %ymm1 | |||
| addq $32, %r10 | |||
| vmovups (%r11), %ymm2 | |||
| addq $32, %r11 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
| vmovups (%rcx), %ymm10 | |||
| addq $32, %rcx | |||
| cmpq $8, %r8 | |||
| jbe LeftLoop | |||
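| // Main loop: 8 channels per iteration. The 9 fmadds accumulate into ymm10 on top of the bias; | |||
| // loads for the next channel block are interleaved with the arithmetic. | |||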
| LoopC8: | |||
| vfmadd231ps %ymm11, %ymm0, %ymm10 | |||
| vmovups (%r12), %ymm3 | |||
| addq $32, %r12 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm12, %ymm1, %ymm10 | |||
| vmovups (%r13), %ymm4 | |||
| addq $32, %r13 | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm13, %ymm2, %ymm10 | |||
| vmovups (%r14), %ymm5 | |||
| addq $32, %r14 | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm11, %ymm3, %ymm10 | |||
| vmovups (%r15), %ymm6 | |||
| addq $32, %r15 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm12, %ymm4, %ymm10 | |||
| vmovups (%rbp), %ymm7 | |||
| addq $32, %rbp | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm13, %ymm5, %ymm10 | |||
| vmovups (%rbx), %ymm8 | |||
| addq $32, %rbx | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm11, %ymm6, %ymm10 | |||
| vmovups (%r9), %ymm0 | |||
| addq $32, %r9 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm12, %ymm7, %ymm10 | |||
| vmovups (%r10), %ymm1 | |||
| addq $32, %r10 | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm13, %ymm8, %ymm10 | |||
| vmovups (%r11), %ymm2 | |||
| addq $32, %r11 | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
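| // Optional activation: relu6 clamps to [0, 6] (vminps, then falls through to vmaxps), relu clamps to [0, +inf). | |||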
| movq 24(%rsp), %rax | |||
| cmpq $0, %rax | |||
| jne Relu6 | |||
| movq 16(%rsp), %rax | |||
| cmpq $0, %rax | |||
| jne Relu | |||
| jmp Write | |||
| Relu6: | |||
| vminps %ymm15, %ymm10, %ymm10 | |||
| Relu: | |||
| vmaxps %ymm14, %ymm10, %ymm10 | |||
| Write: | |||
| vmovups %ymm10, (%rdi) | |||
| addq $32, %rdi | |||
| vmovups (%rcx), %ymm10 | |||
| addq $32, %rcx | |||
| subq $8, %r8 | |||
| cmpq $8, %r8 | |||
| ja LoopC8 | |||
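| // Tail: 1-8 channels remain; full 8-wide vectors are computed, and only the valid lanes are stored below. | |||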
| LeftLoop: | |||
| vfmadd231ps %ymm11, %ymm0, %ymm10 | |||
| vmovups (%r12), %ymm3 | |||
| addq $32, %r12 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm12, %ymm1, %ymm10 | |||
| vmovups (%r13), %ymm4 | |||
| addq $32, %r13 | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm13, %ymm2, %ymm10 | |||
| vmovups (%r14), %ymm5 | |||
| addq $32, %r14 | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm11, %ymm3, %ymm10 | |||
| vmovups (%r15), %ymm6 | |||
| addq $32, %r15 | |||
| vmovups (%rdx), %ymm11 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm12, %ymm4, %ymm10 | |||
| vmovups (%rbp), %ymm7 | |||
| addq $32, %rbp | |||
| vmovups (%rdx), %ymm12 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm13, %ymm5, %ymm10 | |||
| vmovups (%rbx), %ymm8 | |||
| addq $32, %rbx | |||
| vmovups (%rdx), %ymm13 | |||
| addq $32, %rdx | |||
| vfmadd231ps %ymm11, %ymm6, %ymm10 | |||
| vfmadd231ps %ymm12, %ymm7, %ymm10 | |||
| vfmadd231ps %ymm13, %ymm8, %ymm10 | |||
| movq 24(%rsp), %rax | |||
| cmpq $0, %rax | |||
| jne LeftRelu6 | |||
| movq 16(%rsp), %rax | |||
| cmpq $0, %rax | |||
| jne LeftRelu | |||
| jmp LeftWrite | |||
| LeftRelu6: | |||
| vminps %ymm15, %ymm10, %ymm10 | |||
| LeftRelu: | |||
| vmaxps %ymm14, %ymm10, %ymm10 | |||
| LeftWrite: | |||
| cmpq $1, %r8 | |||
| je Write1 | |||
| cmpq $2, %r8 | |||
| je Write2 | |||
| cmpq $3, %r8 | |||
| je Write3 | |||
| cmpq $4, %r8 | |||
| je Write4 | |||
| cmpq $5, %r8 | |||
| je Write5 | |||
| cmpq $6, %r8 | |||
| je Write6 | |||
| cmpq $7, %r8 | |||
| je Write7 | |||
| jmp Write8 | |||
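| // Partial stores: write exactly r8 floats (1-8) from ymm10. | |||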
| Write1: | |||
| vmovss %xmm10, (%rdi) | |||
| addq $4, %rdi | |||
| jmp NextPixel | |||
| Write2: | |||
| vmovsd %xmm10, (%rdi) | |||
| addq $8, %rdi | |||
| jmp NextPixel | |||
| Write3: | |||
| vmovsd %xmm10, (%rdi) | |||
| movhlps %xmm10, %xmm10 | |||
| vmovss %xmm10, 8(%rdi) | |||
| addq $12, %rdi | |||
| jmp NextPixel | |||
| Write4: | |||
| vmovups %xmm10, (%rdi) | |||
| addq $16, %rdi | |||
| jmp NextPixel | |||
| Write5: | |||
| vmovups %xmm10, (%rdi) | |||
| vextractf128 $1, %ymm10, %xmm9 | |||
| vmovss %xmm9, 16(%rdi) | |||
| addq $20, %rdi | |||
| jmp NextPixel | |||
| Write6: | |||
| vmovups %xmm10, (%rdi) | |||
| vextractf128 $1, %ymm10, %xmm9 | |||
| vmovsd %xmm9, 16(%rdi) | |||
| addq $24, %rdi | |||
| jmp NextPixel | |||
| Write7: | |||
| vmovups %xmm10, (%rdi) | |||
| vextractf128 $1, %ymm10, %xmm9 | |||
| vmovsd %xmm9, 16(%rdi) | |||
| movhlps %xmm9, %xmm9 | |||
| vmovss %xmm9, 24(%rdi) | |||
| addq $28, %rdi | |||
| jmp NextPixel | |||
| Write8: | |||
| vmovups %ymm10, (%rdi) | |||
| add $32, %rdi | |||
| NextPixel: | |||
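| // Advance the input pointer array by input_stride bytes and decrement the remaining output_width. | |||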
| movq 8(%rsp), %rbp | |||
| addq %rbp, %rsi | |||
| movq -56(%rsp), %rax | |||
| subq $1, %rax | |||
| movq %rax, -56(%rsp) | |||
| cmpq $0, %rax | |||
| ja LoopPixel | |||
| End: | |||
| subq $96, %rsp | |||
| popq %rdi | |||
| popq %rsi | |||
| popq %rdx | |||
| popq %rcx | |||
| popq %r8 | |||
| popq %r9 | |||
| popq %rbp | |||
| popq %rbx | |||
| popq %r12 | |||
| popq %r13 | |||
| popq %r14 | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| #endif | |||
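For reference, here is a scalar C sketch of what the routine above computes for one output pixel, as read from the assembly itself: the weight layout (per 8-channel block, the 9 taps are stored as 9 consecutive 8-float groups) and the activation order (relu6 applies both the upper and the lower clamp) are inferred from the code, and the function name is illustrative, not part of this PR.

void ConvDwAvx3x3PixelRef(float *output, float *const *input, const float *weights, const float *bias,
                          int channels, int relu, int relu6) {
  for (int c = 0; c < channels; ++c) {
    int block = c / 8, lane = c % 8;
    float acc = bias[c];
    for (int k = 0; k < 9; ++k) {
      /* per 8-channel block, the 9 taps sit in 9 consecutive 8-float groups */
      acc += input[k][c] * weights[(block * 9 + k) * 8 + lane];
    }
    if (relu6 != 0) acc = acc > 6.0f ? 6.0f : acc;              /* vminps with the broadcast 6.0f */
    if (relu6 != 0 || relu != 0) acc = acc < 0.0f ? 0.0f : acc; /* vmaxps with zero */
    output[c] = acc;
  }
}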
| @@ -78,3 +78,52 @@ void Relu6Fp32(float *data, float *dst, int ele_num) { | |||
| data[j] = data[j] > 6 ? 6 : data[j]; | |||
| } | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| #ifdef WIN32 | |||
| void ReluFp32C8(float *data, float *dst, int ele_num) { | |||
| int four_block = UP_DIV(ele_num, C8NUM); | |||
| for (int i = 0; i < four_block - 1; i++) { | |||
| int index = i * C8NUM; | |||
| data[index] = data[index] < 0 ? 0 : data[index]; | |||
| data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; | |||
| data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; | |||
| data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; | |||
| data[index + 4] = data[index + 4] < 0 ? 0 : data[index + 4]; | |||
| data[index + 5] = data[index + 5] < 0 ? 0 : data[index + 5]; | |||
| data[index + 6] = data[index + 6] < 0 ? 0 : data[index + 6]; | |||
| data[index + 7] = data[index + 7] < 0 ? 0 : data[index + 7]; | |||
| } | |||
| for (int j = (four_block - 1) * C8NUM; j < ele_num; ++j) { | |||
| data[j] = data[j] < 0 ? 0 : data[j]; | |||
| } | |||
| } | |||
| void Relu6Fp32C8(float *data, float *dst, int ele_num) { | |||
| int four_block = UP_DIV(ele_num, C8NUM); | |||
| for (int i = 0; i < four_block - 1; i++) { | |||
| int index = i * C8NUM; | |||
| data[index] = data[index] < 0 ? 0 : data[index]; | |||
| data[index] = data[index] > 6 ? 6 : data[index]; | |||
| data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1]; | |||
| data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1]; | |||
| data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2]; | |||
| data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2]; | |||
| data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3]; | |||
| data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3]; | |||
| data[index + 4] = data[index + 4] < 0 ? 0 : data[index + 4]; | |||
| data[index + 4] = data[index + 4] > 6 ? 6 : data[index + 4]; | |||
| data[index + 5] = data[index + 5] < 0 ? 0 : data[index + 5]; | |||
| data[index + 5] = data[index + 5] > 6 ? 6 : data[index + 5]; | |||
| data[index + 6] = data[index + 6] < 0 ? 0 : data[index + 6]; | |||
| data[index + 6] = data[index + 6] > 6 ? 6 : data[index + 6]; | |||
| data[index + 7] = data[index + 7] < 0 ? 0 : data[index + 7]; | |||
| data[index + 7] = data[index + 7] > 6 ? 6 : data[index + 7]; | |||
| } | |||
| for (int j = (four_block - 1) * C8NUM; j < ele_num; ++j) { | |||
| data[j] = data[j] < 0 ? 0 : data[j]; | |||
| data[j] = data[j] > 6 ? 6 : data[j]; | |||
| } | |||
| } | |||
| #endif | |||
| #endif | |||
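These two WIN32 fallbacks clamp the activation in blocks of C8NUM elements plus a scalar tail. A minimal sketch of the helpers they depend on, with assumed definitions (C8NUM and UP_DIV come from nnacl's op_base.h and are not shown in this diff), followed by a usage example:

/* Assumed definitions, matching the usual nnacl conventions (not part of this diff): */
#define C8NUM 8
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

void Relu6Fp32C8(float *data, float *dst, int ele_num); /* defined above; dst is unused */

/* Usage sketch: clamp 20 values in place to [0, 6].
 * UP_DIV(20, 8) == 3, so the unrolled loop covers blocks 0 and 1 (16 values)
 * and the scalar tail handles the remaining 4 elements. */
static void Clamp20Example(float *buf20) { Relu6Fp32C8(buf20, buf20, 20); }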
| @@ -31,6 +31,12 @@ int8_t MinInt8(int8_t a, int8_t b); | |||
| int8_t MaxInt8(int8_t a, int8_t b); | |||
| void ReluFp32(float *data, float *dst, int ele_num); | |||
| void Relu6Fp32(float *data, float *dst, int ele_num); | |||
| #ifdef ENABLE_AVX | |||
| #ifdef WIN32 | |||
| void ReluFp32C8(float *data, float *dst, int ele_num); | |||
| void Relu6Fp32C8(float *data, float *dst, int ele_num); | |||
| #endif | |||
| #endif | |||
| int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); | |||
| int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2); | |||
| int offset4d(const int *shape, const int *dims); | |||
| @@ -681,6 +681,47 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| #ifdef WIN32 | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | |||
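| // WIN32 reference path: each do-while iteration writes one output pixel. For channel i and tap k, | |||
| // the update is out[i] += in[k][i] * w[i + k * C8NUM], processed in C8NUM-wide blocks plus a scalar tail. | |||
| // Note: here input_stride steps the pointer array in (float *) elements; the assembly kernel above takes a byte stride. | |||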
| do { | |||
| float *in[kernel]; | |||
| for (int k = 0; k < kernel; k++) { | |||
| in[k] = input[k]; | |||
| } | |||
| input = input + input_stride; | |||
| size_t c = channels; | |||
| const float *w = weights; | |||
| float *out = output; | |||
| memcpy(out, bias, channels * sizeof(float)); | |||
| for (; c >= C8NUM; c -= C8NUM) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| for (int k = 0; k < kernel; k++) { | |||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||
| } | |||
| } | |||
| w += kernel * C8NUM; | |||
| out += C8NUM; | |||
| for (int k = 0; k < kernel; k++) { | |||
| in[k] += C8NUM; | |||
| } | |||
| } | |||
| for (int i = 0; i < c; i++) { | |||
| for (int k = 0; k < kernel; k++) { | |||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||
| } | |||
| } | |||
| if (relu) { | |||
| ReluFp32C8(output, output, channels); | |||
| } | |||
| if (relu6) { | |||
| Relu6Fp32C8(output, output, channels); | |||
| } | |||
| output += channels; | |||
| } while (--output_width != 0); | |||
| } | |||
| #else | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | |||
| if (kernel == 9) { | |||
| @@ -688,6 +729,7 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||
| } | |||
| } | |||
| #endif | |||
| #endif | |||
| void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, | |||
| float *zero_ptr, const ConvParameter *conv_param, int task_id) { | |||
| @@ -67,9 +67,11 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, size_t input_stride, size_t relu, size_t relu6); | |||
| #endif | |||
| #endif | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, int input_stride, bool relu, bool relu6, int kernel); | |||
| @@ -147,7 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||
| conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | |||
| conv_param->output_h_ = outputs[kOutputIndex]->Height(); | |||
| conv_param->output_w_ = outputs[kOutputIndex]->Width(); | |||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||
| #ifdef ENABLE_AVX | |||
| if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) { | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } | |||
| #elif defined(ENABLE_ARM64) | |||
| if (CheckConvDwUseIndirectBuffer(conv_param)) { | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||