From: @lx0095 Reviewed-by: @zhanghaibo5,@zhang_xue_tong,@zhang_xue_tong Signed-off-by: @zhang_xue_tong,@zhang_xue_tong tags/v1.1.0
| @@ -0,0 +1,273 @@ | |||||
| #ifdef ENABLE_AVX | |||||
| #ifndef WIN32 | |||||
| .text | |||||
| .align 4 | |||||
| .global ConvDwFp32Avx3x3 | |||||
| #ifndef __APPLE__ | |||||
| .type ConvDwFp32Avx3x3, %function | |||||
| #endif | |||||
// Depthwise convolution row kernel with nine taps per output value, written with AVX/FMA instructions.
// For every output pixel it accumulates bias + sum over the nine taps of input * weight, eight float
// channels per ymm register, optionally clamps the result (relu / relu6) and stores `channels` floats.
| // void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, | |||||
| // size_t input_stride, size_t relu, size_t relu6) | |||||
| // rdi: output | |||||
| // rsi: input | |||||
| // rdx: weights | |||||
| // rcx: bias | |||||
| // r8: channels | |||||
| // r9: output_width | |||||
| // 8(%rsp): input_stride (stack argument; offsets are relative to %rsp at function entry) | |||||
| // 16(%rsp): relu | |||||
| // 24(%rsp): relu6 | |||||
| ConvDwFp32Avx3x3: | |||||
// Save r12-r15, rbx, rbp (all overwritten below) and spill the six register arguments.
| pushq %r15 | |||||
| pushq %r14 | |||||
| pushq %r13 | |||||
| pushq %r12 | |||||
| pushq %rbx | |||||
| pushq %rbp | |||||
| pushq %r9 | |||||
| pushq %r8 | |||||
| pushq %rcx | |||||
| pushq %rdx | |||||
| pushq %rsi | |||||
| pushq %rdi | |||||
// Undo the 12 * 8 = 96 bytes of pushes so %rsp is back at its entry value. The stack arguments are then at
// 8/16/24(%rsp) and the values pushed above stay addressable at negative offsets:
//   -56: output_width, -64: channels, -72: bias, -80: weights, -88: input, -96: output.
// NOTE(review): this relies on the memory below %rsp staying intact for the whole function (no call is
// made in this routine) - confirm this is acceptable on every target that builds this file.
| addq $96, %rsp | |||||
// ymm15 = 6.0f in all eight lanes (upper clamp for relu6); ymm14 = 0.0f (lower clamp for relu and relu6).
| movq $6, %rax | |||||
| vcvtsi2ss %rax, %xmm15, %xmm15 | |||||
| vshufps $0, %xmm15, %xmm15, %xmm15 | |||||
| vinsertf128 $1, %xmm15, %ymm15, %ymm15 | |||||
| vxorps %ymm14, %ymm14, %ymm14 | |||||
| LoopPixel: | |||||
// Per output pixel: reload weights, bias and the channel count from the spilled copies.
// NOTE(review): channels and output_width are `int` in the prototype but are used as 64-bit values here;
// this assumes the upper halves of r8/r9 were zero at entry - confirm.
| movq -80(%rsp), %rdx | |||||
| movq -72(%rsp), %rcx | |||||
| movq -64(%rsp), %r8 | |||||
// Fetch the nine input row pointers input[0..8] for this pixel.
| movq (%rsi), %r9 | |||||
| movq 8(%rsi), %r10 | |||||
| movq 16(%rsi), %r11 | |||||
| movq 24(%rsi), %r12 | |||||
| movq 32(%rsi), %r13 | |||||
| movq 40(%rsi), %r14 | |||||
| movq 48(%rsi), %r15 | |||||
| movq 56(%rsi), %rbp | |||||
| movq 64(%rsi), %rbx | |||||
// Preload the first three input vectors (ymm0-2), the first three weight vectors (ymm11-13) and the
// bias vector, which becomes the accumulator ymm10.
| vmovups (%r9), %ymm0 | |||||
| addq $32, %r9 | |||||
| vmovups (%r10), %ymm1 | |||||
| addq $32, %r10 | |||||
| vmovups (%r11), %ymm2 | |||||
| addq $32, %r11 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
| vmovups (%rcx), %ymm10 | |||||
| addq $32, %rcx | |||||
// With at most 8 channels only the final block below is needed.
| cmpq $8, %r8 | |||||
| jbe LeftLoop | |||||
| LoopC8: | |||||
// Full 8-channel block: nine FMAs into ymm10, interleaved with the loads of the remaining six inputs
// (ymm3-8) and of the weight vectors for the following taps. The last three steps already preload
// inputs 0-2 and weights 0-2 of the next block, so the statement order here matters.
| vfmadd231ps %ymm11, %ymm0, %ymm10 | |||||
| vmovups (%r12), %ymm3 | |||||
| addq $32, %r12 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm12, %ymm1, %ymm10 | |||||
| vmovups (%r13), %ymm4 | |||||
| addq $32, %r13 | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm13, %ymm2, %ymm10 | |||||
| vmovups (%r14), %ymm5 | |||||
| addq $32, %r14 | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm11, %ymm3, %ymm10 | |||||
| vmovups (%r15), %ymm6 | |||||
| addq $32, %r15 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm12, %ymm4, %ymm10 | |||||
| vmovups (%rbp), %ymm7 | |||||
| addq $32, %rbp | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm13, %ymm5, %ymm10 | |||||
| vmovups (%rbx), %ymm8 | |||||
| addq $32, %rbx | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm11, %ymm6, %ymm10 | |||||
| vmovups (%r9), %ymm0 | |||||
| addq $32, %r9 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm12, %ymm7, %ymm10 | |||||
| vmovups (%r10), %ymm1 | |||||
| addq $32, %r10 | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm13, %ymm8, %ymm10 | |||||
| vmovups (%r11), %ymm2 | |||||
| addq $32, %r11 | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
// Activation: relu6 != 0 -> min with 6, then fall through to the max with 0; relu != 0 -> max with 0 only.
| movq 24(%rsp), %rax | |||||
| cmpq $0, %rax | |||||
| jne Relu6 | |||||
| movq 16(%rsp), %rax | |||||
| cmpq $0, %rax | |||||
| jne Relu | |||||
| jmp Write | |||||
| Relu6: | |||||
| vminps %ymm15, %ymm10, %ymm10 | |||||
| Relu: | |||||
| vmaxps %ymm14, %ymm10, %ymm10 | |||||
| Write: | |||||
// Store 8 results, load the bias for the next block and repeat while more than 8 channels remain.
| vmovups %ymm10, (%rdi) | |||||
| addq $32, %rdi | |||||
| vmovups (%rcx), %ymm10 | |||||
| addq $32, %rcx | |||||
| subq $8, %r8 | |||||
| cmpq $8, %r8 | |||||
| ja LoopC8 | |||||
| LeftLoop: | |||||
// Final block (at most 8 channels left in r8): the same nine FMAs, but nothing is preloaded for a
// further block. Full 32-byte vectors are still read from the inputs, weights and bias.
| vfmadd231ps %ymm11, %ymm0, %ymm10 | |||||
| vmovups (%r12), %ymm3 | |||||
| addq $32, %r12 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm12, %ymm1, %ymm10 | |||||
| vmovups (%r13), %ymm4 | |||||
| addq $32, %r13 | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm13, %ymm2, %ymm10 | |||||
| vmovups (%r14), %ymm5 | |||||
| addq $32, %r14 | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm11, %ymm3, %ymm10 | |||||
| vmovups (%r15), %ymm6 | |||||
| addq $32, %r15 | |||||
| vmovups (%rdx), %ymm11 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm12, %ymm4, %ymm10 | |||||
| vmovups (%rbp), %ymm7 | |||||
| addq $32, %rbp | |||||
| vmovups (%rdx), %ymm12 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm13, %ymm5, %ymm10 | |||||
| vmovups (%rbx), %ymm8 | |||||
| addq $32, %rbx | |||||
| vmovups (%rdx), %ymm13 | |||||
| addq $32, %rdx | |||||
| vfmadd231ps %ymm11, %ymm6, %ymm10 | |||||
| vfmadd231ps %ymm12, %ymm7, %ymm10 | |||||
| vfmadd231ps %ymm13, %ymm8, %ymm10 | |||||
// Same activation selection as in the full-block path.
| movq 24(%rsp), %rax | |||||
| cmpq $0, %rax | |||||
| jne LeftRelu6 | |||||
| movq 16(%rsp), %rax | |||||
| cmpq $0, %rax | |||||
| jne LeftRelu | |||||
| jmp LeftWrite | |||||
| LeftRelu6: | |||||
| vminps %ymm15, %ymm10, %ymm10 | |||||
| LeftRelu: | |||||
| vmaxps %ymm14, %ymm10, %ymm10 | |||||
| LeftWrite: | |||||
// Store exactly r8 floats; any count other than 1..7 takes the full 8-float store at Write8.
| cmpq $1, %r8 | |||||
| je Write1 | |||||
| cmpq $2, %r8 | |||||
| je Write2 | |||||
| cmpq $3, %r8 | |||||
| je Write3 | |||||
| cmpq $4, %r8 | |||||
| je Write4 | |||||
| cmpq $5, %r8 | |||||
| je Write5 | |||||
| cmpq $6, %r8 | |||||
| je Write6 | |||||
| cmpq $7, %r8 | |||||
| je Write7 | |||||
| jmp Write8 | |||||
| Write1: | |||||
| vmovss %xmm10, (%rdi) | |||||
| addq $4, %rdi | |||||
| jmp NextPixel | |||||
| Write2: | |||||
| vmovsd %xmm10, (%rdi) | |||||
| addq $8, %rdi | |||||
| jmp NextPixel | |||||
| Write3: | |||||
// Two floats via the 64-bit store, then the third lane is moved down to lane 0 and stored.
// NOTE(review): movhlps is the non-VEX encoding inside otherwise VEX-encoded code; vmovhlps may be
// preferable - confirm.
| vmovsd %xmm10, (%rdi) | |||||
| movhlps %xmm10, %xmm10 | |||||
| vmovss %xmm10, 8(%rdi) | |||||
| addq $12, %rdi | |||||
| jmp NextPixel | |||||
| Write4: | |||||
| vmovups %xmm10, (%rdi) | |||||
| addq $16, %rdi | |||||
| jmp NextPixel | |||||
| Write5: | |||||
// Lower four lanes directly; the upper 128 bits are extracted into xmm9 for the remaining lanes.
| vmovups %xmm10, (%rdi) | |||||
| vextractf128 $1, %ymm10, %xmm9 | |||||
| vmovss %xmm9, 16(%rdi) | |||||
| addq $20, %rdi | |||||
| jmp NextPixel | |||||
| Write6: | |||||
| vmovups %xmm10, (%rdi) | |||||
| vextractf128 $1, %ymm10, %xmm9 | |||||
| vmovsd %xmm9, 16(%rdi) | |||||
| addq $24, %rdi | |||||
| jmp NextPixel | |||||
| Write7: | |||||
| vmovups %xmm10, (%rdi) | |||||
| vextractf128 $1, %ymm10, %xmm9 | |||||
| vmovsd %xmm9, 16(%rdi) | |||||
| movhlps %xmm9, %xmm9 | |||||
| vmovss %xmm9, 24(%rdi) | |||||
| addq $28, %rdi | |||||
| jmp NextPixel | |||||
| Write8: | |||||
| vmovups %ymm10, (%rdi) | |||||
| add $32, %rdi | |||||
| NextPixel: | |||||
// Advance the row-pointer table by input_stride (added to the pointer as-is, i.e. a byte count) and
// decrement the spilled output_width; continue while it is non-zero.
| movq 8(%rsp), %rbp | |||||
| addq %rbp, %rsi | |||||
| movq -56(%rsp), %rax | |||||
| subq $1, %rax | |||||
| movq %rax, -56(%rsp) | |||||
| cmpq $0, %rax | |||||
| ja LoopPixel | |||||
| End: | |||||
// Move %rsp back below the saved area and pop everything in reverse push order.
// NOTE(review): no vzeroupper is issued before returning - confirm whether callers expect one.
| subq $96, %rsp | |||||
| popq %rdi | |||||
| popq %rsi | |||||
| popq %rdx | |||||
| popq %rcx | |||||
| popq %r8 | |||||
| popq %r9 | |||||
| popq %rbp | |||||
| popq %rbx | |||||
| popq %r12 | |||||
| popq %r13 | |||||
| popq %r14 | |||||
| popq %r15 | |||||
| retq | |||||
| #endif | |||||
| #endif | |||||
| @@ -78,3 +78,52 @@ void Relu6Fp32(float *data, float *dst, int ele_num) { | |||||
| data[j] = data[j] > 6 ? 6 : data[j]; | data[j] = data[j] > 6 ? 6 : data[j]; | ||||
| } | } | ||||
| } | } | ||||
| #ifdef ENABLE_AVX | |||||
| #ifdef WIN32 | |||||
// Applies ReLU in place: every negative element of data[0, ele_num) is replaced by 0.
// data:    buffer to clamp; modified in place.
// dst:     unused; kept so the signature stays the same as ReluFp32.
// ele_num: number of floats to process; a value <= 0 leaves the buffer untouched.
void ReluFp32C8(float *data, float *dst, int ele_num) {
  (void)dst;
  // A single bounded loop replaces the unrolled-by-eight body plus tail. The old tail started at
  // (block_count - 1) * 8, which is negative for ele_num == 0 and wrote in front of the buffer.
  for (int i = 0; i < ele_num; ++i) {
    data[i] = data[i] < 0 ? 0 : data[i];
  }
}
// Applies ReLU6 in place: every element of data[0, ele_num) is clamped to the range [0, 6].
// data:    buffer to clamp; modified in place.
// dst:     unused; kept so the signature stays the same as Relu6Fp32.
// ele_num: number of floats to process; a value <= 0 leaves the buffer untouched.
void Relu6Fp32C8(float *data, float *dst, int ele_num) {
  (void)dst;
  // A single bounded loop replaces the unrolled-by-eight body plus tail. The old tail started at
  // (block_count - 1) * 8, which is negative for ele_num == 0 and wrote in front of the buffer.
  for (int i = 0; i < ele_num; ++i) {
    float value = data[i];
    value = value < 0 ? 0 : value;  // lower bound first, same order as before
    value = value > 6 ? 6 : value;
    data[i] = value;
  }
}
| #endif | |||||
| #endif | |||||
| @@ -31,6 +31,12 @@ int8_t MinInt8(int8_t a, int8_t b); | |||||
| int8_t MaxInt8(int8_t a, int8_t b); | int8_t MaxInt8(int8_t a, int8_t b); | ||||
| void ReluFp32(float *data, float *dst, int ele_num); | void ReluFp32(float *data, float *dst, int ele_num); | ||||
| void Relu6Fp32(float *data, float *dst, int ele_num); | void Relu6Fp32(float *data, float *dst, int ele_num); | ||||
// In-place activation helpers taking the same arguments as ReluFp32 / Relu6Fp32; they are only
// declared when both ENABLE_AVX and WIN32 are defined.
| #ifdef ENABLE_AVX | |||||
| #ifdef WIN32 | |||||
| void ReluFp32C8(float *data, float *dst, int ele_num); | |||||
| void Relu6Fp32C8(float *data, float *dst, int ele_num); | |||||
| #endif | |||||
| #endif | |||||
| int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); | int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); | ||||
| int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2); | int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2); | ||||
| int offset4d(const int *shape, const int *dims); | int offset4d(const int *shape, const int *dims); | ||||
| @@ -681,6 +681,47 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_AVX | #ifdef ENABLE_AVX | ||||
| #ifdef WIN32 | |||||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | |||||
| do { | |||||
| float *in[kernel]; | |||||
| for (int k = 0; k < kernel; k++) { | |||||
| in[k] = input[k]; | |||||
| } | |||||
| input = input + input_stride; | |||||
| size_t c = channels; | |||||
| const float *w = weights; | |||||
| float *out = output; | |||||
| memcpy(out, bias, channels * sizeof(float)); | |||||
| for (; c >= C8NUM; c -= C8NUM) { | |||||
| for (int i = 0; i < C8NUM; i++) { | |||||
| for (int k = 0; k < kernel; k++) { | |||||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||||
| } | |||||
| } | |||||
| w += kernel * C8NUM; | |||||
| out += C8NUM; | |||||
| for (int k = 0; k < kernel; k++) { | |||||
| in[k] += C8NUM; | |||||
| } | |||||
| } | |||||
| for (int i = 0; i < c; i++) { | |||||
| for (int k = 0; k < kernel; k++) { | |||||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||||
| } | |||||
| } | |||||
| if (relu) { | |||||
| ReluFp32C8(output, output, channels); | |||||
| } | |||||
| if (relu6) { | |||||
| Relu6Fp32C8(output, output, channels); | |||||
| } | |||||
| output += channels; | |||||
| } while (--output_width != 0); | |||||
| } | |||||
| #else | |||||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | ||||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | int output_width, int input_stride, bool relu, bool relu6, int kernel) { | ||||
| if (kernel == 9) { | if (kernel == 9) { | ||||
| @@ -688,6 +729,7 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| #endif | |||||
| void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, | void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, | ||||
| float *zero_ptr, const ConvParameter *conv_param, int task_id) { | float *zero_ptr, const ConvParameter *conv_param, int task_id) { | ||||
| @@ -67,9 +67,11 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_AVX | #ifdef ENABLE_AVX | ||||
| #ifndef WIN32 | |||||
| void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, | void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, | ||||
| int output_width, size_t input_stride, size_t relu, size_t relu6); | int output_width, size_t input_stride, size_t relu, size_t relu6); | ||||
| #endif | #endif | ||||
| #endif | |||||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | ||||
| int output_width, int input_stride, bool relu, bool relu6, int kernel); | int output_width, int input_stride, bool relu, bool relu6, int kernel); | ||||
| @@ -147,7 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *> | |||||
| conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | conv_param->input_channel_ = inputs[kInputIndex]->Channel(); | ||||
| conv_param->output_h_ = outputs[kOutputIndex]->Height(); | conv_param->output_h_ = outputs[kOutputIndex]->Height(); | ||||
| conv_param->output_w_ = outputs[kOutputIndex]->Width(); | conv_param->output_w_ = outputs[kOutputIndex]->Width(); | ||||
| #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) | |||||
| #ifdef ENABLE_AVX | |||||
| if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) { | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| #elif defined(ENABLE_ARM64) | |||||
| if (CheckConvDwUseIndirectBuffer(conv_param)) { | if (CheckConvDwUseIndirectBuffer(conv_param)) { | ||||
| kernel = | kernel = | ||||
| new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||