| @@ -1,15 +1,16 @@ | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| .text | |||
| .align 4 | |||
| .global ConvDwFp32Avx3x3 | |||
| #ifndef __APPLE__ | |||
| #ifndef WIN32 | |||
| .type ConvDwFp32Avx3x3, %function | |||
| #endif | |||
| #endif | |||
| // void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, | |||
| // size_t input_stride, size_t relu) | |||
| // void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width, | |||
| // size_t input_stride, size_t relu, size_t relu6) | |||
| // in linux x64 platform: | |||
| // rdi: output | |||
| // rsi: input | |||
| // rdx: weights | |||
| @@ -20,6 +21,16 @@ | |||
| // 16: relu | |||
| // 24: relu6 | |||
| // in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes | |||
| // rcx: output | |||
| // rdx: input | |||
| // r8: weights | |||
| // r9: bias | |||
| // 40: channels | |||
| // 48: output_width | |||
| // 56: input_stride | |||
| // 64: relu | |||
| // 72: relu6 | |||
| ConvDwFp32Avx3x3: | |||
| pushq %r15 | |||
| pushq %r14 | |||
| @@ -27,14 +38,34 @@ ConvDwFp32Avx3x3: | |||
| pushq %r12 | |||
| pushq %rbx | |||
| pushq %rbp | |||
| pushq %r9 | |||
| pushq %r8 | |||
| pushq %rcx | |||
| pushq %rdx | |||
| pushq %rsi | |||
| pushq %rdi | |||
| pushq %r9 // -56 | |||
| pushq %r8 // -64 | |||
| pushq %rcx // -72 | |||
| pushq %rdx // -80 | |||
| pushq %rsi // -88 | |||
| pushq %rdi // -96 | |||
| addq $96, %rsp | |||
| #ifdef WIN32 | |||
| movq %rcx, %rdi | |||
| movq %rdx, %rsi | |||
| movq %r8, %rdx | |||
| movq %r9, %rcx | |||
| movq 40(%rsp), %r8 // channels | |||
| movq 48(%rsp), %r9 // output_width | |||
| mov %rdx, -80(%rsp) | |||
| mov %rcx, -72(%rsp) | |||
| mov %r9, -56(%rsp) | |||
| mov %r8, -64(%rsp) | |||
| movq 56(%rsp), %rbp // input_stride | |||
| movq %rbp, 8(%rsp) | |||
| movq 64(%rsp), %rbp // relu | |||
| movq %rbp, 16(%rsp) | |||
| movq 72(%rsp), %rbp // relu6 | |||
| movq %rbp, 24(%rsp) | |||
| #endif | |||
| movq $6, %rax | |||
| vcvtsi2ss %rax, %xmm15, %xmm15 | |||
| vshufps $0, %xmm15, %xmm15, %xmm15 | |||
| @@ -270,4 +301,3 @@ End: | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| #endif | |||
| @@ -1,14 +1,16 @@ | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| .text | |||
| .align 4 | |||
| .global MatmulFloatAvxOpt | |||
| #ifndef __APPLE__ | |||
| #ifndef WIN32 | |||
| .type MatmulFloatAvxOpt, %function | |||
| #endif | |||
| #endif | |||
| // void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth | |||
| // void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth | |||
| // int row, int col, size_t stride, size_t writeMode) | |||
| // parameters pass in Linux x86 platform: | |||
| // rdi: a | |||
| // rsi: b | |||
| // rdx: c | |||
| @@ -20,6 +22,18 @@ | |||
| // 24: stride | |||
| // 32: writeNhwc/writeWino | |||
| // parameters pass in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bytes | |||
| // rcx: a | |||
| // rdx: b | |||
| // r8: c | |||
| // r9: bias | |||
| // 40: act_type | |||
| // 48: depth | |||
| // 56: row | |||
| // 64: col | |||
| // 72: stride | |||
| // 80: writeMode | |||
| MatmulFloatAvxOpt: | |||
| // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention | |||
| pushq %r15 | |||
| @@ -28,14 +42,37 @@ MatmulFloatAvxOpt: | |||
| pushq %r12 | |||
| pushq %rbx | |||
| pushq %rbp | |||
| pushq %r9 | |||
| pushq %r8 | |||
| pushq %rcx | |||
| pushq %rdx | |||
| pushq %rsi | |||
| pushq %rdi | |||
| addq $96, %rsp | |||
| pushq %r9 // -56 | |||
| pushq %r8 // -64 | |||
| pushq %rcx // -72 | |||
| pushq %rdx // -80 | |||
| pushq %rsi // -88 | |||
| pushq %rdi // -96 | |||
| pushq %rsi // -104 rsi | |||
| pushq %rdi // -112 rdi | |||
| addq $112, %rsp | |||
| #ifdef WIN32 | |||
| movq %rcx, %rdi | |||
| movq %rdx, %rsi | |||
| movq %r8, %rdx | |||
| movq %r9, %rcx | |||
| movq 40(%rsp), %r8 // act_type | |||
| movq 48(%rsp), %r9 // depth | |||
| movq %r9, -56(%rsp) // r9 | |||
| movq %rcx, -72(%rsp) // rcx | |||
| movq %rdx, -80(%rsp) // rdx | |||
| movq %rsi, -88(%rsp) // rsi | |||
| movq %rdi, -96(%rsp) // rdi | |||
| movq 56(%rsp), %rbp // row | |||
| movq %rbp, 8(%rsp) | |||
| movq 64(%rsp), %rbp // col | |||
| movq %rbp, 16(%rsp) | |||
| movq 72(%rsp), %rbp // stride | |||
| movq %rbp, 24(%rsp) | |||
| movq 80(%rsp), %rbp // writeMode | |||
| movq %rbp, 32(%rsp) | |||
| #endif | |||
| movq 8(%rsp), %rbp | |||
| movq 16(%rsp), %rbx | |||
| movq 24(%rsp), %r10 | |||
| @@ -926,10 +963,12 @@ LoopRow: | |||
| jmp LoopRow | |||
| LoopRowEnd: | |||
| subq $96, %rsp | |||
| subq $112, %rsp | |||
| popq %rdi | |||
| popq %rsi | |||
| popq %rdx | |||
| popq %rdx | |||
| popq %rdx | |||
| popq %rcx | |||
| popq %r8 | |||
| popq %r9 | |||
| @@ -941,4 +980,3 @@ LoopRowEnd: | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| #endif | |||
| @@ -681,47 +681,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| #ifdef WIN32 | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | |||
| do { | |||
| float *in[kernel]; | |||
| for (int k = 0; k < kernel; k++) { | |||
| in[k] = input[k]; | |||
| } | |||
| input = input + input_stride; | |||
| size_t c = channels; | |||
| const float *w = weights; | |||
| float *out = output; | |||
| memcpy(out, bias, channels * sizeof(float)); | |||
| for (; c >= C8NUM; c -= C8NUM) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| for (int k = 0; k < kernel; k++) { | |||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||
| } | |||
| } | |||
| w += kernel * C8NUM; | |||
| out += C8NUM; | |||
| for (int k = 0; k < kernel; k++) { | |||
| in[k] += C8NUM; | |||
| } | |||
| } | |||
| for (int i = 0; i < c; i++) { | |||
| for (int k = 0; k < kernel; k++) { | |||
| out[i] += in[k][i] * w[i + k * C8NUM]; | |||
| } | |||
| } | |||
| if (relu) { | |||
| ReluFp32C8(output, output, channels); | |||
| } | |||
| if (relu6) { | |||
| Relu6Fp32C8(output, output, channels); | |||
| } | |||
| output += channels; | |||
| } while (--output_width != 0); | |||
| } | |||
| #else | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, int input_stride, bool relu, bool relu6, int kernel) { | |||
| if (kernel == 9) { | |||
| @@ -729,7 +688,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c | |||
| } | |||
| } | |||
| #endif | |||
| #endif | |||
| void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, | |||
| float *zero_ptr, const ConvParameter *conv_param, int task_id) { | |||
| @@ -67,10 +67,8 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c | |||
| #endif | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| int output_width, size_t input_stride, size_t relu, size_t relu6); | |||
| #endif | |||
| void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, | |||
| size_t output_width, size_t input_stride, size_t relu, size_t relu6); | |||
| #endif | |||
| void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, | |||
| @@ -883,11 +883,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||
| if (out_type == OutType_C8) { | |||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | |||
| } else { | |||
| #ifdef WIN32 | |||
| MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type); | |||
| #else | |||
| MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| #endif | |||
| MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); | |||
| } | |||
| #elif ENABLE_SSE | |||
| if (out_type == OutType_C8) { | |||
| @@ -62,8 +62,8 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia | |||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #ifdef ENABLE_AVX | |||
| void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth, | |||
| size_t row, size_t col, size_t stride, size_t write_mode); | |||
| #endif | |||
| #endif | |||