| @@ -0,0 +1,177 @@ | |||
| #ifdef ENABLE_AVX | |||
| .text | |||
| .align 4 | |||
| .global ConvDwFp32Border | |||
| #ifndef __APPLE__ | |||
| #ifndef WIN32 | |||
| .type ConvDwFp32Border, %function | |||
| #endif | |||
| #endif | |||
| // void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, | |||
| // size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, | |||
| // size_t relu6); | |||
| ConvDwFp32Border: | |||
| pushq %r15 | |||
| pushq %r14 | |||
| pushq %r13 | |||
| pushq %r12 | |||
| pushq %rbx | |||
| pushq %rbp | |||
| pushq %r9 | |||
| pushq %r8 // -64 | |||
| pushq %rcx // -72 | |||
| pushq %rdx // -80 | |||
| pushq %rsi | |||
| pushq %rdi | |||
| addq $96, %rsp | |||
| movq %rdi, %rdx | |||
| #ifdef WIN32 | |||
| movq %rcx, %rdx | |||
| #endif | |||
| movq 8(%rdx), %r12 // src | |||
| movq 16(%rdx), %r13 // weight | |||
| movq 24(%rdx), %rbp // bias | |||
| movq 32(%rdx), %r11 // height | |||
| movq 40(%rdx), %r10 | |||
| movq %r10, -72(%rsp) // width | |||
| movq 48(%rdx), %r10 | |||
| movq %r10, -80(%rsp) // in_kh_step | |||
| movq 56(%rdx), %r10 // in_kw_step | |||
| movq 64(%rdx), %rax // kernel_w | |||
| movq 72(%rdx), %rcx // relu | |||
| movq 80(%rdx), %rbx // reul6 | |||
| movq $6, -64(%rsp) | |||
| movq (%rdx), %rdx | |||
| cmpq $0, %r11 | |||
| je End | |||
| xorps %xmm8, %xmm8 | |||
| LoopHeight: | |||
| movq %r12, %rsi // src_kh, src_kw | |||
| movq %r13, %rdi // weight_kh, weight_kw | |||
| movq -72(%rsp), %r8 // width | |||
| cmpq $6, %r8 | |||
| jae LoopWidth6 | |||
| cmpq $4, %r8 | |||
| jae LoopWidth4 | |||
| cmpq $1, %r8 | |||
| jae LoopWidth1 | |||
| jmp LoopWidthEnd | |||
| LoopWidth6: | |||
| xorps %xmm6, %xmm6 | |||
| xorps %xmm7, %xmm7 | |||
| imul $3, %r10, %r9 | |||
| addq %rsi, %r9 | |||
| vmovups (%rsi), %xmm0 // src_kw | |||
| vmovups (%rsi, %r10), %xmm1 | |||
| vmovups (%rsi, %r10, 2), %xmm2 | |||
| vmovups (%r9), %xmm3 | |||
| vmovups (%rsi, %r10, 4), %xmm4 | |||
| vmovups (%r9, %r10, 2), %xmm5 | |||
| vfmadd231ps (%rdi), %xmm0, %xmm6 | |||
| vfmadd231ps 16(%rdi), %xmm1, %xmm7 | |||
| vfmadd231ps 32(%rdi), %xmm2, %xmm8 | |||
| vfmadd231ps 48(%rdi), %xmm3, %xmm6 | |||
| vfmadd231ps 64(%rdi), %xmm4, %xmm7 | |||
| vfmadd231ps 80(%rdi), %xmm5, %xmm8 | |||
| addps %xmm6, %xmm7 | |||
| imul $6, %r10, %r15 | |||
| addq $96, %rdi | |||
| addps %xmm7, %xmm8 | |||
| addq %r15, %rsi | |||
| subq $6, %r8 | |||
| cmpq $6, %r8 | |||
| jae LoopWidth6 | |||
| cmpq $4, %r8 | |||
| jae LoopWidth4 | |||
| cmpq $0, %r8 | |||
| je LoopWidthEnd | |||
| jmp LoopWidth1 | |||
| LoopWidth4: | |||
| xorps %xmm6, %xmm6 | |||
| xorps %xmm7, %xmm7 | |||
| imul $3, %r10, %r9 | |||
| addq %rsi, %r9 | |||
| vmovups (%rsi), %xmm0 // src_kw | |||
| vmovups (%rsi, %r10, 1), %xmm1 | |||
| vmovups (%rsi, %r10, 2), %xmm2 | |||
| vmovups (%r9), %xmm3 | |||
| vfmadd231ps (%rdi), %xmm0, %xmm6 | |||
| vfmadd231ps 16(%rdi), %xmm1, %xmm7 | |||
| vfmadd231ps 32(%rdi), %xmm2, %xmm8 | |||
| vfmadd231ps 48(%rdi), %xmm3, %xmm6 | |||
| addps %xmm6, %xmm7 | |||
| imul $4, %r10, %r15 | |||
| addq $64, %rdi | |||
| addps %xmm7, %xmm8 | |||
| addq %r15, %rsi | |||
| subq $4, %r8 | |||
| cmpq $4, %r8 | |||
| jae LoopWidth4 | |||
| cmpq $0, %r8 | |||
| je LoopWidthEnd | |||
| jmp LoopWidth1 | |||
| LoopWidth1: | |||
| vmovups (%rsi), %xmm0 // input_tmp | |||
| addq %r10, %rsi | |||
| vfmadd231ps (%rdi), %xmm0, %xmm8 | |||
| addq $16, %rdi | |||
| subq $1, %r8 | |||
| cmpq $0, %r8 | |||
| ja LoopWidth1 | |||
| jmp LoopWidthEnd | |||
| LoopWidthEnd: | |||
| subq $1, %r11 | |||
| cmpq $0, %r11 | |||
| je LoopHeightEnd | |||
| addq -80(%rsp), %r12 // in_kh_step | |||
| addq %rax, %r13 // kernel_w_step | |||
| jmp LoopHeight | |||
| LoopHeightEnd: | |||
| xorps %xmm10, %xmm10 | |||
| vbroadcastss -64(%rsp), %xmm9 | |||
| addps (%rbp), %xmm8 | |||
| cmpq $1, %rbx | |||
| je Relu6 | |||
| cmpq $1, %rcx | |||
| je Relu | |||
| jmp Write | |||
| Relu6: | |||
| minps %xmm9, %xmm8 | |||
| Relu: | |||
| maxps %xmm10, %xmm8 | |||
| Write: | |||
| movups %xmm8, (%rdx) | |||
| End: | |||
| subq $96, %rsp | |||
| popq %rdi | |||
| popq %rsi | |||
| popq %rdx | |||
| popq %rcx | |||
| popq %r8 | |||
| popq %r9 | |||
| popq %rbp | |||
| popq %rbx | |||
| popq %r12 | |||
| popq %r13 | |||
| popq %r14 | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| @@ -0,0 +1,178 @@ | |||
| #ifdef ENABLE_AVX | |||
| .text | |||
| .align 4 | |||
| .global ConvDwFp32Row | |||
| #ifndef __APPLE__ | |||
| #ifndef WIN32 | |||
| .type ConvDwFp32Row, %function | |||
| #endif | |||
| #endif | |||
| // void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, | |||
| // size_t output_channel, size_t input_step); | |||
| // in linux x64 platform: | |||
| // rdi: output_ptr | |||
| // rsi: input_ptr | |||
| // rdx: weight_ptr | |||
| // rcx: num_pixels | |||
| // r8: output_channel | |||
| // r9: input_step | |||
| // in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites | |||
| // rcx: output_ptr | |||
| // rdx: input_ptr | |||
| // r8: weight_ptr | |||
| // r9: num_pixels | |||
| // 40: output_channel | |||
| // 48: input_step | |||
| ConvDwFp32Row: | |||
| pushq %r15 | |||
| pushq %r14 | |||
| pushq %r13 | |||
| pushq %r12 | |||
| pushq %rsi | |||
| pushq %rdi | |||
| addq $48, %rsp | |||
| #ifdef WIN32 | |||
| movq %rcx, %rdi // output_ptr | |||
| movq %rdx, %rsi // input_ptr | |||
| movq %r8, %rdx // weight_ptr | |||
| movq %r9, %rcx // num_pixels | |||
| movq 40(%rsp), %r8 // output_channel | |||
| movq 48(%rsp), %r9 // input_step | |||
| #endif | |||
| movq $4, %r13 | |||
| imul %r13, %r9 | |||
| movq %rsi, %r13 // input_ptr | |||
| movq %rdx, %r14 // weight_ptr | |||
| movq %r8, %r15 // output_channel | |||
| cmpq $0, %rcx | |||
| je End | |||
| LoopPixel: | |||
| movq %r13, %rsi // input_tmp | |||
| movq %r14, %rdx // weight_tmp | |||
| movq %r15, %r8 // channel_tmp | |||
| cmpq $32, %r8 | |||
| jae LoopC32 | |||
| cmpq $16, %r8 | |||
| jae LoopC16 | |||
| cmpq $8, %r8 | |||
| jae LoopC8 | |||
| cmpq $0, %r8 | |||
| ja LoopC | |||
| jmp LoopCEnd | |||
| LoopC32: | |||
| vmovups (%rsi), %ymm0 // input_tmp | |||
| vmovups 32(%rsi), %ymm1 | |||
| vmovups 64(%rsi), %ymm2 | |||
| vmovups 96(%rsi), %ymm3 | |||
| vmovups (%rdi), %ymm8 // output_tmp | |||
| vmovups 32(%rdi), %ymm9 | |||
| vmovups 64(%rdi), %ymm10 | |||
| vmovups 96(%rdi), %ymm11 | |||
| addq $128, %rsi | |||
| vfmadd231ps (%rdx), %ymm0, %ymm8 | |||
| vfmadd231ps 32(%rdx), %ymm1, %ymm9 | |||
| vfmadd231ps 64(%rdx), %ymm2, %ymm10 | |||
| vfmadd231ps 96(%rdx), %ymm3, %ymm11 | |||
| vmovups %ymm8, (%rdi) // output_ptr | |||
| vmovups %ymm9, 32(%rdi) | |||
| vmovups %ymm10, 64(%rdi) | |||
| vmovups %ymm11, 96(%rdi) | |||
| addq $128, %rdi | |||
| addq $128, %rdx | |||
| subq $32, %r8 | |||
| cmpq $32, %r8 | |||
| jae LoopC32 | |||
| cmpq $16, %r8 | |||
| jae LoopC16 | |||
| cmpq $8, %r8 | |||
| jae LoopC8 | |||
| cmpq $0, %r8 | |||
| ja LoopC | |||
| jmp LoopCEnd | |||
| LoopC16: | |||
| vmovups (%rsi), %ymm0 // input_tmp | |||
| vmovups (%rdi), %ymm8 // output_tmp | |||
| vmovups 32(%rsi), %ymm1 | |||
| vmovups 32(%rdi), %ymm9 | |||
| addq $64, %rsi | |||
| vfmadd231ps (%rdx), %ymm0, %ymm8 | |||
| vfmadd231ps 32(%rdx), %ymm1, %ymm9 | |||
| vmovups %ymm8, (%rdi) // output_ptr | |||
| addq $64, %rdx | |||
| vmovups %ymm9, 32(%rdi) | |||
| addq $64, %rdi | |||
| subq $16, %r8 | |||
| cmpq $16, %r8 | |||
| jae LoopC16 | |||
| cmpq $8, %r8 | |||
| jae LoopC8 | |||
| cmpq $0, %r8 | |||
| ja LoopC | |||
| jmp LoopCEnd | |||
| LoopC8: | |||
| vmovups (%rsi), %ymm0 // input_tmp | |||
| vmovups (%rdi), %ymm8 // output_tmp | |||
| addq $32, %rsi | |||
| vfmadd231ps (%rdx), %ymm0, %ymm8 | |||
| addq $32, %rdx | |||
| vmovups %ymm8, (%rdi) | |||
| addq $32, %rdi | |||
| subq $8, %r8 | |||
| cmpq $8, %r8 | |||
| jae LoopC8 | |||
| cmpq $0, %r8 | |||
| ja LoopC | |||
| jmp LoopCEnd | |||
| LoopC: | |||
| vmovss (%rsi), %xmm0 // input_tmp | |||
| vmovss (%rdi), %xmm8 // output_ptr | |||
| vfmadd231ss (%rdx), %xmm0, %xmm8 | |||
| addq $4, %rsi | |||
| addq $4, %rdx | |||
| vmovss %xmm8, (%rdi) | |||
| addq $4, %rdi | |||
| subq $1, %r8 | |||
| cmpq $0, %r8 | |||
| ja LoopC | |||
| jmp LoopCEnd | |||
| LoopCEnd: | |||
| subq $1, %rcx // num_pixel -= 1 | |||
| cmpq $0, %rcx | |||
| je End | |||
| addq %r9, %r13 | |||
| jmp LoopPixel | |||
| End: | |||
| subq $48, %rsp | |||
| popq %rdi | |||
| popq %rsi | |||
| popq %r12 | |||
| popq %r13 | |||
| popq %r14 | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| @@ -21,6 +21,20 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/conv_parameter.h" | |||
| typedef struct ConvDwFp32BorderParam { | |||
| float *dst; | |||
| const float *src; | |||
| const float *weight; | |||
| const float *bias; | |||
| size_t height; | |||
| size_t width; | |||
| size_t in_kh_step; | |||
| size_t in_kw_step; | |||
| size_t kernel_w; | |||
| size_t relu; | |||
| size_t relu6; | |||
| } ConvDwFp32BorderParam; | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -37,8 +51,12 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size | |||
| void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); | |||
| #ifdef ENABLE_AVX | |||
| void ConvDwFp32Border(ConvDwFp32BorderParam *param); | |||
| #else | |||
| void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); | |||
| #endif | |||
| void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, | |||
| size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, | |||
| size_t in_kh_step, size_t in_kw_step); | |||
| @@ -202,8 +202,21 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float | |||
| const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; | |||
| const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| #ifdef ENABLE_AVX | |||
| ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam)); | |||
| param->dst = dst_kernel; | |||
| param->src = src_kernel; | |||
| param->weight = weight_kernel; | |||
| param->bias = bias; | |||
| param->height = end_kh - start_kh; | |||
| param->width = end_kw - start_kw; | |||
| param->in_kh_step = sliding->in_kh_step_ * sizeof(float); | |||
| param->in_kw_step = sliding->in_kw_step_ * sizeof(float); | |||
| param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float); | |||
| param->relu = relu; | |||
| param->relu6 = relu6; | |||
| ConvDwFp32Border(param); | |||
| #elif defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, | |||
| sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), | |||
| conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_SSE | |||
| #if defined(ENABLE_SSE) && !defined(ENABLE_AVX) | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| @@ -19,6 +19,7 @@ | |||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||
| #include "nnacl/intrinsics/sse/sse_common.h" | |||
| #ifndef ENABLE_AVX | |||
| void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { | |||
| in_kh_step /= sizeof(float); | |||
| @@ -104,6 +105,7 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f | |||
| } | |||
| _mm_storeu_ps(dst, dst_ma); | |||
| } | |||
| #endif | |||
| void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||