| @@ -20,7 +20,8 @@ option(SUPPORT_GPU "if support gpu" off) | |||
| option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off) | |||
| option(BUILD_MINDDATA_EXAMPLE "" on) | |||
| option(ENABLE_VERBOSE "" off) | |||
| option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off) | |||
| option(ENABLE_SSE "if x86_64 support SSE instruction set" off) | |||
| option(ENABLE_AVX "if x86_64 support AVX instruction set" off) | |||
| set(DIR_PREFIX mindspore-lite) | |||
| set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION}) | |||
| @@ -187,7 +188,13 @@ endif() | |||
| if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) | |||
| if ("${X86_64_SIMD}" STREQUAL "sse") | |||
| add_compile_definitions(ENABLE_X86_64_SSE) | |||
| add_compile_definitions(ENABLE_SSE) | |||
| endif () | |||
| if ("${X86_64_SIMD}" STREQUAL "avx") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| add_compile_definitions(ENABLE_AVX) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma") | |||
| endif () | |||
| endif () | |||
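| For reference, the compile definitions set above are consumed by the C sources in this patch through a tile-selection pattern like the sketch below (illustrative only, not part of the patch; the C4NUM/C6NUM/C12NUM constants are assumed to come from nnacl/op_base.h): | |||
| #ifdef ENABLE_AVX   /* X86_64_SIMD=avx also defines ENABLE_SSE */ | |||
| const int row_tile = C6NUM;   /* 6-row tiles feed the 6x16 AVX kernel */ | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| const int row_tile = C4NUM;   /* 4-row tiles for the SSE / ARM32 kernels */ | |||
| #else | |||
| const int row_tile = C12NUM;  /* 12-row tiles for the generic / NEON64 path */ | |||
| #endif | |||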
| @@ -37,6 +37,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse") | |||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| endif() | |||
| if ("${X86_64_SIMD}" STREQUAL "avx") | |||
| file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c | |||
| ${NNACL_DIR}/assembly/avx/*.S) | |||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| endif() | |||
| ########################### build nnacl static library ######################## | |||
| string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") | |||
| add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC}) | |||
| @@ -0,0 +1,941 @@ | |||
| #ifdef ENABLE_AVX | |||
| #ifndef WIN32 | |||
| .text | |||
| .align 4 | |||
| .global MatmulFloatAvxOpt | |||
| #ifndef __APPLE__ | |||
| .type MatmulFloatAvxOpt, %function | |||
| #endif | |||
| // void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, | |||
| // int row, int col, size_t stride, size_t writeMode) | |||
| // rdi: a | |||
| // rsi: b | |||
| // rdx: c | |||
| // rcx: bias | |||
| // r8: act_type | |||
| // r9: depth | |||
| // 8: row | |||
| // 16: col | |||
| // 24: stride | |||
| // 32: writeNhwc/writeWino | |||
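| // Reference reading of the kernel below (an editorial summary, not from the original sources): | |||
| //   it computes a row-major C[row][col] = A * B, with A packed six-rows-major (C6NUM) and | |||
| //   B packed sixteen-cols-major (C16NUM); each iteration holds a 6x16 output tile in | |||
| //   ymm4-ymm15, adds bias, applies Relu/Relu6 per act_type, and writeMode selects the | |||
| //   write-back layout: 0 -> packed C8 blocks, 2 -> Winograd layout, otherwise plain NHWC. | |||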
| MatmulFloatAvxOpt: | |||
| // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention | |||
| pushq %r15 | |||
| pushq %r14 | |||
| pushq %r13 | |||
| pushq %r12 | |||
| pushq %rbx | |||
| pushq %rbp | |||
| pushq %r9 | |||
| pushq %r8 | |||
| pushq %rcx | |||
| pushq %rdx | |||
| pushq %rsi | |||
| pushq %rdi | |||
| addq $96, %rsp | |||
| movq 8(%rsp), %rbp | |||
| movq 16(%rsp), %rbx | |||
| movq 24(%rsp), %r10 | |||
| movq 32(%rsp), %r14 | |||
| movq $24, %r11 | |||
| imul %r9, %r11 | |||
| cmpq $0, %r14 | |||
| jne NoC8Steps | |||
| movq $48, %r13 | |||
| imul %rbp, %r13 | |||
| NoC8Steps: | |||
| cmpq $2, %r14 | |||
| jne NoWinoSteps | |||
| movq $4, %r12 | |||
| imul %r10, %r12 | |||
| imul %rbx, %r12 | |||
| movq $48, %r13 | |||
| imul %r10, %r13 | |||
| NoWinoSteps: | |||
| movq $4, %rax | |||
| imul %rax, %r10 | |||
| LoopRow: | |||
| movq -88(%rsp), %rsi | |||
| movq 16(%rsp), %rbx | |||
| movq -72(%rsp), %rcx | |||
| LoopCol: | |||
| cmpq $0, %r14 | |||
| je NoReloadDst | |||
| movq -80(%rsp), %rdx | |||
| NoReloadDst: | |||
| movq -96(%rsp), %rdi | |||
| movq -56(%rsp), %r9 | |||
| vmovups (%rsi), %ymm0 | |||
| vmovups 32(%rsi), %ymm1 | |||
| vbroadcastss (%rdi), %ymm10 | |||
| vbroadcastss 4(%rdi), %ymm11 | |||
| vbroadcastss 8(%rdi), %ymm12 | |||
| vbroadcastss 12(%rdi), %ymm13 | |||
| vbroadcastss 16(%rdi), %ymm2 | |||
| vbroadcastss 20(%rdi), %ymm3 | |||
| addq $64, %rsi | |||
| vmulps %ymm0, %ymm10, %ymm4 | |||
| vmulps %ymm1, %ymm10, %ymm5 | |||
| vmulps %ymm0, %ymm11, %ymm6 | |||
| vmulps %ymm1, %ymm11, %ymm7 | |||
| vmulps %ymm0, %ymm12, %ymm8 | |||
| vmulps %ymm1, %ymm12, %ymm9 | |||
| vmulps %ymm0, %ymm13, %ymm10 | |||
| vmulps %ymm1, %ymm13, %ymm11 | |||
| add $24, %rdi | |||
| vmulps %ymm0, %ymm2, %ymm12 | |||
| vmulps %ymm1, %ymm2, %ymm13 | |||
| vmulps %ymm0, %ymm3, %ymm14 | |||
| vmulps %ymm1, %ymm3, %ymm15 | |||
| subq $1, %r9 | |||
| cmpq $0, %r9 | |||
| je Bias | |||
| LoopDepth: | |||
| vmovups (%rsi), %ymm0 | |||
| vmovups 32(%rsi), %ymm1 | |||
| vbroadcastss (%rdi), %ymm2 | |||
| vbroadcastss 4(%rdi), %ymm3 | |||
| vfmadd231ps %ymm0, %ymm2, %ymm4 | |||
| addq $64, %rsi | |||
| vfmadd231ps %ymm1, %ymm2, %ymm5 | |||
| vbroadcastss 8(%rdi), %ymm2 | |||
| vfmadd231ps %ymm0, %ymm3, %ymm6 | |||
| vfmadd231ps %ymm1, %ymm3, %ymm7 | |||
| vbroadcastss 12(%rdi), %ymm3 | |||
| vfmadd231ps %ymm0, %ymm2, %ymm8 | |||
| prefetcht0 384(%rsi) | |||
| vfmadd231ps %ymm1, %ymm2, %ymm9 | |||
| vbroadcastss 16(%rdi), %ymm2 | |||
| vfmadd231ps %ymm0, %ymm3, %ymm10 | |||
| vfmadd231ps %ymm1, %ymm3, %ymm11 | |||
| vbroadcastss 20(%rdi), %ymm3 | |||
| vfmadd231ps %ymm0, %ymm2, %ymm12 | |||
| vfmadd231ps %ymm1, %ymm2, %ymm13 | |||
| addq $24, %rdi | |||
| vfmadd231ps %ymm0, %ymm3, %ymm14 | |||
| vfmadd231ps %ymm1, %ymm3, %ymm15 | |||
| subq $1, %r9 | |||
| cmpq $0, %r9 | |||
| ja LoopDepth | |||
| Bias: | |||
| cmpq $0, %rcx | |||
| je Activation | |||
| vmovups (%rcx), %ymm0 | |||
| vmovups 32(%rcx), %ymm1 | |||
| add $64, %rcx | |||
| vaddps %ymm0, %ymm4, %ymm4 | |||
| vaddps %ymm1, %ymm5, %ymm5 | |||
| vaddps %ymm0, %ymm6, %ymm6 | |||
| vaddps %ymm1, %ymm7, %ymm7 | |||
| vaddps %ymm0, %ymm8, %ymm8 | |||
| vaddps %ymm1, %ymm9, %ymm9 | |||
| vaddps %ymm0, %ymm10, %ymm10 | |||
| vaddps %ymm1, %ymm11, %ymm11 | |||
| vaddps %ymm0, %ymm12, %ymm12 | |||
| vaddps %ymm1, %ymm13, %ymm13 | |||
| vaddps %ymm0, %ymm14, %ymm14 | |||
| vaddps %ymm1, %ymm15, %ymm15 | |||
| Activation: | |||
| cmpq $3, %r8 | |||
| je Relu6 | |||
| cmpq $1, %r8 | |||
| je Relu | |||
| jmp Write | |||
| Relu6: | |||
| movq $6, %rax | |||
| vcvtsi2ss %rax, %xmm0, %xmm0 | |||
| vshufps $0, %xmm0, %xmm0, %xmm0 | |||
| vinsertf128 $1, %xmm0, %ymm0, %ymm0 | |||
| vminps %ymm0, %ymm4, %ymm4 | |||
| vminps %ymm0, %ymm5, %ymm5 | |||
| vminps %ymm0, %ymm6, %ymm6 | |||
| vminps %ymm0, %ymm7, %ymm7 | |||
| vminps %ymm0, %ymm8, %ymm8 | |||
| vminps %ymm0, %ymm9, %ymm9 | |||
| vminps %ymm0, %ymm10, %ymm10 | |||
| vminps %ymm0, %ymm11, %ymm11 | |||
| vminps %ymm0, %ymm12, %ymm12 | |||
| vminps %ymm0, %ymm13, %ymm13 | |||
| vminps %ymm0, %ymm14, %ymm14 | |||
| vminps %ymm0, %ymm15, %ymm15 | |||
| Relu: | |||
| vxorps %ymm1, %ymm1, %ymm1 | |||
| vmaxps %ymm1, %ymm4, %ymm4 | |||
| vmaxps %ymm1, %ymm5, %ymm5 | |||
| vmaxps %ymm1, %ymm6, %ymm6 | |||
| vmaxps %ymm1, %ymm7, %ymm7 | |||
| vmaxps %ymm1, %ymm8, %ymm8 | |||
| vmaxps %ymm1, %ymm9, %ymm9 | |||
| vmaxps %ymm1, %ymm10, %ymm10 | |||
| vmaxps %ymm1, %ymm11, %ymm11 | |||
| vmaxps %ymm1, %ymm12, %ymm12 | |||
| vmaxps %ymm1, %ymm13, %ymm13 | |||
| vmaxps %ymm1, %ymm14, %ymm14 | |||
| vmaxps %ymm1, %ymm15, %ymm15 | |||
| Write: | |||
| cmpq $2, %r14 | |||
| je WriteWino | |||
| cmpq $0, %r14 | |||
| je WriteC8 | |||
| cmpq $1, %rbx | |||
| je Write1 | |||
| cmpq $2, %rbx | |||
| je Write2 | |||
| cmpq $3, %rbx | |||
| je Write3 | |||
| cmpq $4, %rbx | |||
| je Write4 | |||
| cmpq $5, %rbx | |||
| je Write5 | |||
| cmpq $6, %rbx | |||
| je Write6 | |||
| cmpq $7, %rbx | |||
| je Write7 | |||
| cmpq $8, %rbx | |||
| je Write8 | |||
| cmpq $9, %rbx | |||
| je Write9 | |||
| cmpq $10, %rbx | |||
| je Write10 | |||
| cmpq $11, %rbx | |||
| je Write11 | |||
| cmpq $12, %rbx | |||
| je Write12 | |||
| cmpq $13, %rbx | |||
| je Write13 | |||
| cmpq $14, %rbx | |||
| je Write14 | |||
| cmpq $15, %rbx | |||
| je Write15 | |||
| jmp Write16 | |||
| Write1: | |||
| movq %rdx, %rax | |||
| addq $4, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovss %xmm4, (%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovss %xmm6, (%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovss %xmm8, (%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovss %xmm10, (%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovss %xmm12, (%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovss %xmm14, (%rdx) | |||
| addq %r10, %rdx | |||
| addq $4, %rdx | |||
| jmp WriteEnd | |||
| Write2: | |||
| movq %rdx, %rax | |||
| addq $8, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovsd %xmm4, (%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm6, (%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm8, (%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm10, (%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm12, (%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm14, (%rdx) | |||
| addq %r10, %rdx | |||
| addq $8, %rdx | |||
| jmp WriteEnd | |||
| Write3: | |||
| movq %rdx, %rax | |||
| addq $12, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovsd %xmm4, (%rdx) | |||
| movhlps %xmm4, %xmm4 | |||
| vmovss %xmm4, 8(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm6, (%rdx) | |||
| movhlps %xmm6, %xmm6 | |||
| vmovss %xmm6, 8(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm8, (%rdx) | |||
| movhlps %xmm8, %xmm8 | |||
| vmovss %xmm8, 8(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm10, (%rdx) | |||
| movhlps %xmm10, %xmm10 | |||
| vmovss %xmm10, 8(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm12, (%rdx) | |||
| movhlps %xmm12, %xmm12 | |||
| vmovss %xmm12, 8(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovsd %xmm14, (%rdx) | |||
| movhlps %xmm14, %xmm14 | |||
| vmovss %xmm14, 8(%rdx) | |||
| addq %r10, %rdx | |||
| addq $12, %rdx | |||
| jmp WriteEnd | |||
| Write4: | |||
| movq %rdx, %rax | |||
| addq $16, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %xmm4, (%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm6, (%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm8, (%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm10, (%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm12, (%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm14, (%rdx) | |||
| addq %r10, %rdx | |||
| addq $16, %rdx | |||
| jmp WriteEnd | |||
| Write5: | |||
| movq %rdx, %rax | |||
| addq $20, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %xmm4, (%rdx) | |||
| vextractf128 $1, %ymm4, %xmm4 | |||
| vmovss %xmm4, 16(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm6, (%rdx) | |||
| vextractf128 $1, %ymm6, %xmm6 | |||
| vmovss %xmm6, 16(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm8, (%rdx) | |||
| vextractf128 $1, %ymm8, %xmm8 | |||
| vmovss %xmm8, 16(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm10, (%rdx) | |||
| vextractf128 $1, %ymm10, %xmm10 | |||
| vmovss %xmm10, 16(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm12, (%rdx) | |||
| vextractf128 $1, %ymm12, %xmm12 | |||
| vmovss %xmm12, 16(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm14, (%rdx) | |||
| vextractf128 $1, %ymm14, %xmm14 | |||
| vmovss %xmm14, 16(%rdx) | |||
| addq %r10, %rdx | |||
| addq $20, %rdx | |||
| jmp WriteEnd | |||
| Write6: | |||
| movq %rdx, %rax | |||
| addq $24, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %xmm4, (%rdx) | |||
| vextractf128 $1, %ymm4, %xmm4 | |||
| vmovsd %xmm4, 16(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm6, (%rdx) | |||
| vextractf128 $1, %ymm6, %xmm6 | |||
| vmovsd %xmm6, 16(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm8, (%rdx) | |||
| vextractf128 $1, %ymm8, %xmm8 | |||
| vmovsd %xmm8, 16(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm10, (%rdx) | |||
| vextractf128 $1, %ymm10, %xmm10 | |||
| vmovsd %xmm10, 16(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm12, (%rdx) | |||
| vextractf128 $1, %ymm12, %xmm12 | |||
| vmovsd %xmm12, 16(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm14, (%rdx) | |||
| vextractf128 $1, %ymm14, %xmm14 | |||
| vmovsd %xmm14, 16(%rdx) | |||
| addq %r10, %rdx | |||
| addq $24, %rdx | |||
| jmp WriteEnd | |||
| Write7: | |||
| movq %rdx, %rax | |||
| addq $28, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %xmm4, (%rdx) | |||
| vextractf128 $1, %ymm4, %xmm4 | |||
| vmovsd %xmm4, 16(%rdx) | |||
| movhlps %xmm4, %xmm4 | |||
| vmovss %xmm4, 24(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm6, (%rdx) | |||
| vextractf128 $1, %ymm6, %xmm6 | |||
| vmovsd %xmm6, 16(%rdx) | |||
| movhlps %xmm6, %xmm6 | |||
| vmovss %xmm6, 24(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm8, (%rdx) | |||
| vextractf128 $1, %ymm8, %xmm8 | |||
| vmovsd %xmm8, 16(%rdx) | |||
| movhlps %xmm8, %xmm8 | |||
| vmovss %xmm8, 24(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm10, (%rdx) | |||
| vextractf128 $1, %ymm10, %xmm10 | |||
| vmovsd %xmm10, 16(%rdx) | |||
| movhlps %xmm10, %xmm10 | |||
| vmovss %xmm10, 24(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm12, (%rdx) | |||
| vextractf128 $1, %ymm12, %xmm12 | |||
| vmovsd %xmm12, 16(%rdx) | |||
| movhlps %xmm12, %xmm12 | |||
| vmovss %xmm12, 24(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %xmm14, (%rdx) | |||
| vextractf128 $1, %ymm14, %xmm14 | |||
| vmovsd %xmm14, 16(%rdx) | |||
| movhlps %xmm14, %xmm14 | |||
| vmovss %xmm14, 24(%rdx) | |||
| addq %r10, %rdx | |||
| addq $28, %rdx | |||
| jmp WriteEnd | |||
| Write8: | |||
| movq %rdx, %rax | |||
| addq $32, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| addq %r10, %rdx | |||
| addq $32, %rdx | |||
| jmp WriteEnd | |||
| Write9: | |||
| movq %rdx, %rax | |||
| addq $36, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovss %xmm5, 32(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovss %xmm7, 32(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovss %xmm9, 32(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovss %xmm11, 32(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovss %xmm13, 32(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovss %xmm15, 32(%rdx) | |||
| addq %r10, %rdx | |||
| addq $36, %rdx | |||
| jmp WriteEnd | |||
| Write10: | |||
| movq %rdx, %rax | |||
| addq $40, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovsd %xmm5, 32(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovsd %xmm7, 32(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovsd %xmm9, 32(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovsd %xmm11, 32(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovsd %xmm13, 32(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovsd %xmm15, 32(%rdx) | |||
| addq %r10, %rdx | |||
| addq $40, %rdx | |||
| jmp WriteEnd | |||
| Write11: | |||
| movq %rdx, %rax | |||
| addq $44, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovsd %xmm5, 32(%rdx) | |||
| movhlps %xmm5, %xmm5 | |||
| vmovss %xmm5, 40(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovsd %xmm7, 32(%rdx) | |||
| movhlps %xmm7, %xmm7 | |||
| vmovss %xmm7, 40(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovsd %xmm9, 32(%rdx) | |||
| movhlps %xmm9, %xmm9 | |||
| vmovss %xmm9, 40(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovsd %xmm11, 32(%rdx) | |||
| movhlps %xmm11, %xmm11 | |||
| vmovss %xmm11, 40(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovsd %xmm13, 32(%rdx) | |||
| movhlps %xmm13, %xmm13 | |||
| vmovss %xmm13, 40(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovsd %xmm15, 32(%rdx) | |||
| movhlps %xmm15, %xmm15 | |||
| vmovss %xmm15, 40(%rdx) | |||
| addq %r10, %rdx | |||
| addq $44, %rdx | |||
| jmp WriteEnd | |||
| Write12: | |||
| movq %rdx, %rax | |||
| addq $48, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovups %xmm5, 32(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovups %xmm7, 32(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovups %xmm9, 32(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovups %xmm11, 32(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovups %xmm13, 32(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovups %xmm15, 32(%rdx) | |||
| addq %r10, %rdx | |||
| addq $48, %rdx | |||
| jmp WriteEnd | |||
| Write13: | |||
| movq %rdx, %rax | |||
| addq $52, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovups %xmm5, 32(%rdx) | |||
| vextractf128 $1, %ymm5, %xmm5 | |||
| vmovss %xmm5, 48(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovups %xmm7, 32(%rdx) | |||
| vextractf128 $1, %ymm7, %xmm7 | |||
| vmovss %xmm7, 48(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovups %xmm9, 32(%rdx) | |||
| vextractf128 $1, %ymm9, %xmm9 | |||
| vmovss %xmm9, 48(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovups %xmm11, 32(%rdx) | |||
| vextractf128 $1, %ymm11, %xmm11 | |||
| vmovss %xmm11, 48(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovups %xmm13, 32(%rdx) | |||
| vextractf128 $1, %ymm13, %xmm13 | |||
| vmovss %xmm13, 48(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovups %xmm15, 32(%rdx) | |||
| vextractf128 $1, %ymm15, %xmm15 | |||
| vmovss %xmm15, 48(%rdx) | |||
| addq %r10, %rdx | |||
| addq $52, %rdx | |||
| jmp WriteEnd | |||
| Write14: | |||
| movq %rdx, %rax | |||
| addq $56, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovups %xmm5, 32(%rdx) | |||
| vextractf128 $1, %ymm5, %xmm5 | |||
| vmovsd %xmm5, 48(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovups %xmm7, 32(%rdx) | |||
| vextractf128 $1, %ymm7, %xmm7 | |||
| vmovsd %xmm7, 48(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovups %xmm9, 32(%rdx) | |||
| vextractf128 $1, %ymm9, %xmm9 | |||
| vmovsd %xmm9, 48(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovups %xmm11, 32(%rdx) | |||
| vextractf128 $1, %ymm11, %xmm11 | |||
| vmovsd %xmm11, 48(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovups %xmm13, 32(%rdx) | |||
| vextractf128 $1, %ymm13, %xmm13 | |||
| vmovsd %xmm13, 48(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovups %xmm15, 32(%rdx) | |||
| vextractf128 $1, %ymm15, %xmm15 | |||
| vmovsd %xmm15, 48(%rdx) | |||
| addq %r10, %rdx | |||
| addq $56, %rdx | |||
| jmp WriteEnd | |||
| Write15: | |||
| movq %rdx, %rax | |||
| addq $60, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovups %xmm5, 32(%rdx) | |||
| vextractf128 $1, %ymm5, %xmm5 | |||
| vmovsd %xmm5, 48(%rdx) | |||
| movhlps %xmm5, %xmm5 | |||
| vmovss %xmm5, 56(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovups %xmm7, 32(%rdx) | |||
| vextractf128 $1, %ymm7, %xmm7 | |||
| vmovsd %xmm7, 48(%rdx) | |||
| movhlps %xmm7, %xmm7 | |||
| vmovss %xmm7, 56(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovups %xmm9, 32(%rdx) | |||
| vextractf128 $1, %ymm9, %xmm9 | |||
| vmovsd %xmm9, 48(%rdx) | |||
| movhlps %xmm9, %xmm9 | |||
| vmovss %xmm9, 56(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovups %xmm11, 32(%rdx) | |||
| vextractf128 $1, %ymm11, %xmm11 | |||
| vmovsd %xmm11, 48(%rdx) | |||
| movhlps %xmm11, %xmm11 | |||
| vmovss %xmm11, 56(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovups %xmm13, 32(%rdx) | |||
| vextractf128 $1, %ymm13, %xmm13 | |||
| vmovsd %xmm13, 48(%rdx) | |||
| movhlps %xmm13, %xmm13 | |||
| vmovss %xmm13, 56(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovups %xmm15, 32(%rdx) | |||
| vextractf128 $1, %ymm15, %xmm15 | |||
| vmovsd %xmm15, 48(%rdx) | |||
| movhlps %xmm15, %xmm15 | |||
| vmovss %xmm15, 56(%rdx) | |||
| addq %r10, %rdx | |||
| addq $60, %rdx | |||
| jmp WriteEnd | |||
| WriteC8: | |||
| movq %rdx, %rax | |||
| addq %r11, %rdx | |||
| movq %rdx, %r15 | |||
| addq %r11, %rdx | |||
| movq %rdx, -80(%rsp) | |||
| vmovups %ymm4, (%rax) | |||
| vmovups %ymm6, 32(%rax) | |||
| vmovups %ymm8, 64(%rax) | |||
| vmovups %ymm10, 96(%rax) | |||
| vmovups %ymm12, 128(%rax) | |||
| vmovups %ymm14, 160(%rax) | |||
| vmovups %ymm5, (%r15) | |||
| vmovups %ymm7, 32(%r15) | |||
| vmovups %ymm9, 64(%r15) | |||
| vmovups %ymm11, 96(%r15) | |||
| vmovups %ymm13, 128(%r15) | |||
| vmovups %ymm15, 160(%r15) | |||
| jmp WriteEnd | |||
| WriteWino: | |||
| movq %rdx, %rax | |||
| addq %r13, %rdx | |||
| movq %rdx, %r15 | |||
| addq %r13, %rdx | |||
| movq %rdx, -80(%rsp) | |||
| vmovups %ymm4, (%rax) | |||
| vmovups %ymm5, (%r15) | |||
| addq %r12, %rax | |||
| addq %r12, %r15 | |||
| vmovups %ymm6, (%rax) | |||
| vmovups %ymm7, (%r15) | |||
| addq %r12, %rax | |||
| addq %r12, %r15 | |||
| vmovups %ymm8, (%rax) | |||
| vmovups %ymm9, (%r15) | |||
| addq %r12, %rax | |||
| addq %r12, %r15 | |||
| vmovups %ymm10, (%rax) | |||
| vmovups %ymm11, (%r15) | |||
| addq %r12, %rax | |||
| addq %r12, %r15 | |||
| vmovups %ymm12, (%rax) | |||
| vmovups %ymm13, (%r15) | |||
| addq %r12, %rax | |||
| addq %r12, %r15 | |||
| vmovups %ymm14, (%rax) | |||
| vmovups %ymm15, (%r15) | |||
| jmp WriteEnd | |||
| Write16: | |||
| movq %rdx, %rax | |||
| addq $64, %rax | |||
| movq %rax, -80(%rsp) | |||
| vmovups %ymm4, (%rdx) | |||
| vmovups %ymm5, 32(%rdx) | |||
| cmpq $1, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm6, (%rdx) | |||
| vmovups %ymm7, 32(%rdx) | |||
| cmpq $2, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm8, (%rdx) | |||
| vmovups %ymm9, 32(%rdx) | |||
| cmpq $3, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm10, (%rdx) | |||
| vmovups %ymm11, 32(%rdx) | |||
| cmpq $4, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm12, (%rdx) | |||
| vmovups %ymm13, 32(%rdx) | |||
| cmpq $5, %rbp | |||
| je WriteEnd | |||
| addq %r10, %rdx | |||
| vmovups %ymm14, (%rdx) | |||
| vmovups %ymm15, 32(%rdx) | |||
| addq %r10, %rdx | |||
| addq $64, %rdx | |||
| WriteEnd: | |||
| cmpq $16, %rbx | |||
| jbe LoopColEnd | |||
| subq $16, %rbx | |||
| jmp LoopCol | |||
| LoopColEnd: | |||
| movq -96(%rsp), %rdi | |||
| addq %r11, %rdi | |||
| movq %rdi, -96(%rsp) | |||
| cmpq $0, %r14 | |||
| je C8DstStep | |||
| cmpq $2, %r14 | |||
| je WinoDstStep | |||
| movq $4, %rax | |||
| movq 16(%rsp), %rbx | |||
| imul %rbx, %rax | |||
| subq %rax, %rdx | |||
| movq %rdx, -80(%rsp) | |||
| jmp NoDstStep | |||
| C8DstStep: | |||
| movq -80(%rsp), %rax | |||
| addq $384, %rax | |||
| movq %rax, -80(%rsp) | |||
| jmp NoDstStep | |||
| WinoDstStep: | |||
| addq %r13, %rdx | |||
| movq %rdx, -80(%rsp) | |||
| NoDstStep: | |||
| cmpq $6, %rbp | |||
| jbe LoopRowEnd | |||
| subq $6, %rbp | |||
| jmp LoopRow | |||
| LoopRowEnd: | |||
| subq $96, %rsp | |||
| popq %rdi | |||
| popq %rsi | |||
| popq %rdx | |||
| popq %rcx | |||
| popq %r8 | |||
| popq %r9 | |||
| popq %rbp | |||
| popq %rbx | |||
| popq %r12 | |||
| popq %r13 | |||
| popq %r14 | |||
| popq %r15 | |||
| retq | |||
| #endif | |||
| #endif | |||
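| A hedged usage sketch of the new kernel, matching the declaration added later in this patch; the include path and helper name below are assumptions for illustration: | |||
| #include "nnacl/fp32/matmul_fp32.h"  /* assumed location of the MatmulFloatAvxOpt declaration */ | |||
| /* Hypothetical helper: one NHWC matmul through the AVX kernel. a_pack must be | |||
|    RowMajor2Col6Major-packed activations, b_pack RowMajor2Col16Major-packed weights, | |||
|    and stride is the number of floats between consecutive output rows of c. */ | |||
| void MatMulAvxNhwcExample(const float *a_pack, const float *b_pack, float *c, const float *bias, | |||
|                           int depth, int row, int col, int stride) { | |||
|   MatmulFloatAvxOpt(a_pack, b_pack, c, bias, (int)ActType_Relu, depth, row, col, stride, | |||
|                     (int)OutType_Nhwc); | |||
| } | |||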
| @@ -56,7 +56,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed | |||
| int out_channel = conv_param->output_channel_; | |||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| const int cal_num = C4NUM; | |||
| #else | |||
| const int cal_num = C12NUM; | |||
| @@ -78,7 +78,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed | |||
| int out_offset = thread_id * cal_num * out_channel + out_batch_offset; | |||
| float *gemm_output = output_data + out_offset; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| #else | |||
| RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| @@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p | |||
| void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | |||
| size_t plane_size, size_t stride, size_t relu_type) { | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||
| PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); | |||
| #else | |||
| size_t oc8mod = output_channel % C8NUM; | |||
| @@ -68,7 +68,7 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi | |||
| return; | |||
| } | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||
| void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { | |||
| const int unitStep = 4 * length; | |||
| for (int y = 0; y < h; ++y) { | |||
| @@ -39,7 +39,7 @@ float ShortToFloat32(uint16_t src_value); | |||
| uint16_t Float32ToShort(float src_value); | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, | |||
| size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); | |||
| @@ -202,7 +202,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float | |||
| const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; | |||
| const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, | |||
| sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), | |||
| conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); | |||
| @@ -285,7 +285,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig | |||
| int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; | |||
| const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; | |||
| float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, | |||
| conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), | |||
| sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), | |||
| @@ -839,7 +839,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we | |||
| float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; | |||
| const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, | |||
| conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), | |||
| sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), | |||
| @@ -26,7 +26,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| int out_channel = conv_param->output_channel_; | |||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| const int cal_num = C6NUM; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| const int cal_num = C4NUM; | |||
| #else | |||
| const int cal_num = C12NUM; | |||
| @@ -48,7 +50,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ | |||
| int out_offset = thread_id * cal_num * out_channel + out_batch_offset; | |||
| float *gemm_output = output_data + out_offset; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| #else | |||
| RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||
| @@ -97,7 +101,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const | |||
| float *dst_ptr = gemm_out + task_id * gemm_out_offset; | |||
| float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | |||
| for (int i = 0; i < input_unit_square; ++i) { | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); | |||
| #else | |||
| RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); | |||
| @@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds | |||
| size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; | |||
| size_t output_plane = conv_param->output_w_ * conv_param->output_h_; | |||
| int oc8 = UP_ROUND(output_channel, C8NUM); | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| const int tile_num = 4; | |||
| #else | |||
| const int tile_num = 12; | |||
| @@ -28,9 +28,21 @@ void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) | |||
| for (int r = 0; r < row; r++) { | |||
| const float *src = src_ptr + r * col; | |||
| for (int c = 0; c < col; c++) { | |||
| int cd8 = c / 4; | |||
| int cm8 = c % 4; | |||
| dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c]; | |||
| int cd4 = c / C4NUM; | |||
| int cm4 = c % C4NUM; | |||
| dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
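| /* Worked example (editorial note, not part of the patch): for the N-major row packing | |||
|  * above with N = 4, row = 2, col = 6, source element (r, c) lands at | |||
|  * dst[(c / 4) * 4 * row + r * 4 + (c % 4)], i.e. columns are grouped into blocks of 4 | |||
|  * and each block stores its rows contiguously; lanes of the last partial block are left | |||
|  * untouched. RowMajor2Row6Major below applies the same layout with N = 6 to feed the | |||
|  * 6-row AVX tiles. */ | |||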
| void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) { | |||
| for (int r = 0; r < row; r++) { | |||
| const float *src = src_ptr + r * col; | |||
| for (int c = 0; c < col; c++) { | |||
| int cd6 = c / C6NUM; | |||
| int cm6 = c % C6NUM; | |||
| dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c]; | |||
| } | |||
| } | |||
| return; | |||
| @@ -40,9 +52,9 @@ void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) | |||
| for (int r = 0; r < row; r++) { | |||
| const float *src = src_ptr + r * col; | |||
| for (int c = 0; c < col; c++) { | |||
| int cd8 = c / 8; | |||
| int cm8 = c % 8; | |||
| dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; | |||
| int cd8 = c / C8NUM; | |||
| int cm8 = c % C8NUM; | |||
| dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; | |||
| } | |||
| } | |||
| return; | |||
| @@ -52,9 +64,21 @@ void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) | |||
| for (int r = 0; r < row; r++) { | |||
| const float *src = src_ptr + r * col; | |||
| for (int c = 0; c < col; c++) { | |||
| int cd8 = c / C12NUM; | |||
| int cm8 = c % C12NUM; | |||
| dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c]; | |||
| int cd12 = c / C12NUM; | |||
| int cm12 = c % C12NUM; | |||
| dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) { | |||
| for (int r = 0; r < row; r++) { | |||
| const float *src = src_ptr + r * col; | |||
| for (int c = 0; c < col; c++) { | |||
| int cd16 = c / C16NUM; | |||
| int cm16 = c % C16NUM; | |||
| dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; | |||
| } | |||
| } | |||
| return; | |||
| @@ -190,7 +214,7 @@ void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_ | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); | |||
| #elif ENABLE_X86_64_SSE | |||
| #elif ENABLE_SSE | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| __m128 src3 = _mm_loadu_ps(src_c + 2 * col); | |||
| @@ -421,7 +445,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); | |||
| #elif ENABLE_X86_64_SSE | |||
| #elif ENABLE_SSE | |||
| /* 8x4 row-major to col-major */ | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| @@ -478,6 +502,145 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t | |||
| return; | |||
| } | |||
| void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { | |||
| size_t row16 = row / C16NUM * C16NUM; | |||
| size_t col_skip = col / C4NUM * C4NUM; | |||
| int skip_size = C4NUM; | |||
| const float *src_r = src_ptr; | |||
| float *dst_r = dst_ptr; | |||
| size_t ri = 0; | |||
| for (; ri < row16; ri += C16NUM) { | |||
| size_t ci = 0; | |||
| for (; ci < col_skip; ci += skip_size) { | |||
| const float *src_c = src_r + ci; | |||
| float *dst_c = dst_r + ci * C16NUM; | |||
| for (int tr = 0; tr < C16NUM; tr++) { | |||
| for (int tc = 0; tc < C4NUM; tc++) { | |||
| dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; | |||
| } | |||
| } | |||
| } | |||
| for (; ci < col; ci++) { | |||
| const float *src_c = src_r + ci; | |||
| float *dst_c = dst_r + ci * C16NUM; | |||
| for (size_t i = 0; i < C16NUM; i++) { | |||
| dst_c[i] = src_c[i * col]; | |||
| } | |||
| } | |||
| src_r += C16NUM * col; | |||
| dst_r += C16NUM * col; | |||
| } | |||
| for (; ri < row; ri++) { | |||
| for (size_t i = 0; i < col; i++) { | |||
| dst_r[i * C16NUM] = src_r[i]; | |||
| } | |||
| src_r += col; | |||
| dst_r += 1; | |||
| } | |||
| return; | |||
| } | |||
| void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { | |||
| size_t totalRow = UP_ROUND(row, C6NUM); | |||
| size_t row6 = row / C6NUM * C6NUM; | |||
| size_t col8 = col / C8NUM * C8NUM; | |||
| const float *src_r = src_ptr; | |||
| float *dst_r = dst_ptr; | |||
| size_t ri = 0; | |||
| for (; ri < row6; ri += C6NUM) { | |||
| size_t ci = 0; | |||
| for (; ci < col8; ci += C8NUM) { | |||
| const float *src_c = src_r + ci; | |||
| float *dst_c = dst_r + ci * C6NUM; | |||
| /* 6x8 row-major to col-major */ | |||
| #ifdef ENABLE_AVX | |||
| __m256 src0 = _mm256_loadu_ps(src_c); | |||
| __m256 src1 = _mm256_loadu_ps(src_c + col); | |||
| __m256 src2 = _mm256_loadu_ps(src_c + 2 * col); | |||
| __m256 src3 = _mm256_loadu_ps(src_c + 3 * col); | |||
| __m256 src4 = _mm256_loadu_ps(src_c + 4 * col); | |||
| __m256 src5 = _mm256_loadu_ps(src_c + 5 * col); | |||
| __m256 trans0 = _mm256_unpacklo_ps(src0, src1); | |||
| __m256 trans1 = _mm256_unpacklo_ps(src2, src3); | |||
| __m256 trans2 = _mm256_unpacklo_ps(src4, src5); | |||
| __m256 trans3 = _mm256_unpackhi_ps(src0, src1); | |||
| __m256 trans4 = _mm256_unpackhi_ps(src2, src3); | |||
| __m256 trans5 = _mm256_unpackhi_ps(src4, src5); | |||
| __m128 lo0 = _mm256_castps256_ps128(trans0); | |||
| __m128 lo1 = _mm256_castps256_ps128(trans1); | |||
| __m128 lo2 = _mm256_castps256_ps128(trans2); | |||
| __m128 lo3 = _mm256_castps256_ps128(trans3); | |||
| __m128 lo4 = _mm256_castps256_ps128(trans4); | |||
| __m128 lo5 = _mm256_castps256_ps128(trans5); | |||
| __m128 hi0 = _mm256_extractf128_ps(trans0, 1); | |||
| __m128 hi1 = _mm256_extractf128_ps(trans1, 1); | |||
| __m128 hi2 = _mm256_extractf128_ps(trans2, 1); | |||
| __m128 hi3 = _mm256_extractf128_ps(trans3, 1); | |||
| __m128 hi4 = _mm256_extractf128_ps(trans4, 1); | |||
| __m128 hi5 = _mm256_extractf128_ps(trans5, 1); | |||
| __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); | |||
| __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); | |||
| __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); | |||
| __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); | |||
| __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); | |||
| __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); | |||
| __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); | |||
| __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); | |||
| __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); | |||
| __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); | |||
| __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); | |||
| __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); | |||
| _mm_storeu_ps(dst_c, res0); | |||
| _mm_storeu_ps(dst_c + 4, res1); | |||
| _mm_storeu_ps(dst_c + 8, res2); | |||
| _mm_storeu_ps(dst_c + 12, res3); | |||
| _mm_storeu_ps(dst_c + 16, res4); | |||
| _mm_storeu_ps(dst_c + 20, res5); | |||
| _mm_storeu_ps(dst_c + 24, res6); | |||
| _mm_storeu_ps(dst_c + 28, res7); | |||
| _mm_storeu_ps(dst_c + 32, res8); | |||
| _mm_storeu_ps(dst_c + 36, res9); | |||
| _mm_storeu_ps(dst_c + 40, res10); | |||
| _mm_storeu_ps(dst_c + 44, res11); | |||
| #else | |||
| for (int tr = 0; tr < C6NUM; tr++) { | |||
| for (int tc = 0; tc < C8NUM; tc++) { | |||
| dst_c[tc * C6NUM + tr] = src_c[tr * col + tc]; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| for (; ci < col; ci++) { | |||
| const float *src_c = src_r + ci; | |||
| float *dst_c = dst_r + ci * C6NUM; | |||
| for (size_t i = 0; i < C6NUM; i++) { | |||
| dst_c[i] = src_c[i * col]; | |||
| } | |||
| } | |||
| src_r += C6NUM * col; | |||
| dst_r += C6NUM * col; | |||
| } | |||
| for (; ri < row; ri++) { | |||
| for (size_t i = 0; i < col; i++) { | |||
| dst_r[i * C6NUM] = src_r[i]; | |||
| } | |||
| src_r += col; | |||
| dst_r += 1; | |||
| } | |||
| for (; ri < totalRow; ri++) { | |||
| for (size_t i = 0; i < col; i++) { | |||
| dst_r[i * C6NUM] = 0; | |||
| } | |||
| dst_r += 1; | |||
| } | |||
| return; | |||
| } | |||
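| /* Layout note (editorial reading of the code above): RowMajor2Col6Major places source | |||
|  * element (r, c) at dst[(r / 6) * 6 * col + c * 6 + (r % 6)], i.e. the matrix is split | |||
|  * into 6-row strips stored column-major, and rows between `row` and UP_ROUND(row, 6) are | |||
|  * zero-filled so the 6x16 AVX matmul kernel can always load full 6-float columns. */ | |||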
| void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { | |||
| size_t row8 = row / C4NUM * C4NUM; | |||
| size_t col4 = col / C4NUM * C4NUM; | |||
| @@ -519,7 +682,7 @@ void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t | |||
| : | |||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||
| : "r10", "r12", "q0", "q1", "q2", "q3"); | |||
| #elif ENABLE_X86_64_SSE | |||
| #elif ENABLE_SSE | |||
| __m128 src1 = _mm_loadu_ps(src_c); | |||
| __m128 src2 = _mm_loadu_ps(src_c + col); | |||
| __m128 src3 = _mm_loadu_ps(src_c + 2 * col); | |||
| @@ -630,6 +793,34 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A | |||
| return; | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| #ifdef WIN32 | |||
| void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | |||
| int col, int stride, int out_type) { | |||
| if (out_type == OutType_Nhwc) { | |||
| for (int r = 0; r < row; r++) { | |||
| for (int c = 0; c < col; c++) { | |||
| int r6div = r / C6NUM, r6mod = r % C6NUM; | |||
| int c16div = c / C16NUM, c16mod = c % C16NUM; | |||
| size_t ci = r * stride + c; | |||
| float value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod; | |||
| size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| if (bias != NULL) value += bias[c]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| #endif | |||
| #endif | |||
| void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | |||
| int col, int stride, int out_type) { | |||
| if (out_type == OutType_C8) { | |||
| @@ -670,7 +861,19 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||
| } else { | |||
| MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| } | |||
| #elif ENABLE_X86_64_SSE | |||
| #elif ENABLE_AVX | |||
| if (out_type == OutType_Nhwc) { | |||
| #ifdef WIN32 | |||
| MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type); | |||
| #else | |||
| MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| #endif | |||
| } else if (out_type == OutType_C8) { | |||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | |||
| } else { | |||
| MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||
| } | |||
| #elif ENABLE_SSE | |||
| if (out_type == OutType_C8) { | |||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | |||
| } else { | |||
| @@ -31,11 +31,15 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||
| void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col); | |||
| void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||
| void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||
| #ifdef ENABLE_ARM | |||
| void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); | |||
| #endif | |||
| @@ -49,11 +53,16 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | |||
| void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #elif ENABLE_X86_64_SSE | |||
| #elif ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | |||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #ifdef ENABLE_AVX | |||
| void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||
| int col, int stride, int write_mode); | |||
| #endif | |||
| #endif | |||
| #ifdef ENABLE_NNACL_INFER_SHAPE | |||
| @@ -42,6 +42,7 @@ typedef struct MatMulParameter { | |||
| int row_; | |||
| int col_; | |||
| int row_4_; | |||
| int row_6_; | |||
| int row_8_; | |||
| int row_12_; | |||
| int row_16_; | |||
| @@ -125,7 +125,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) { | |||
| return NNACL_OK; | |||
| } | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) | |||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||
| void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, | |||
| int in_channel, int c4_channel) { | |||
| int cnt = 0; | |||
| @@ -228,7 +228,7 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma | |||
| return NNACL_OK; | |||
| } | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | |||
| const float *bias, int m, int k, int n) { | |||
| int count = 0; | |||
| @@ -52,7 +52,7 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float * | |||
| int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, | |||
| int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack); | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | |||
| const float *bias, int m, int k, int n); | |||
| #endif | |||
| @@ -21,8 +21,8 @@ | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #endif | |||
| #include <stdint.h> | |||
| @@ -31,6 +31,7 @@ | |||
| #define C2NUM 2 | |||
| #define C4NUM 4 | |||
| #define C6NUM 6 | |||
| #define C8NUM 8 | |||
| #define C12NUM 12 | |||
| #define C16NUM 16 | |||
| @@ -91,7 +92,7 @@ typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, | |||
| #define MS_MAXQ_F32 vmaxq_f32 | |||
| #define MS_MINQ_F32 vminq_f32 | |||
| #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) | |||
| #elif defined(ENABLE_X86_64_SSE) | |||
| #elif defined(ENABLE_SSE) | |||
| #define MS_FLOAT32X4 __m128 | |||
| #define MS_LDQ_F32 _mm_loadu_ps | |||
| #define MS_ADDQ_F32 _mm_add_ps | |||
| @@ -756,7 +756,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch | |||
| return; | |||
| } | |||
| #ifndef ENABLE_X86_64_SSE | |||
| #ifndef ENABLE_SSE | |||
| void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { | |||
| int hw8 = plane / C8NUM * C8NUM; | |||
| int c8 = channel / C8NUM * C8NUM; | |||
| @@ -79,7 +79,7 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo | |||
| int src_step, int dst_step, int in_unit) { | |||
| int len = in_unit * in_unit; | |||
| if (len > MAX_LEN) return; | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||
| MS_FLOAT32X4 src[MAX_LEN]; | |||
| MS_FLOAT32X4 t[MAX_LEN]; | |||
| MS_FLOAT32X4 m[MAX_LEN]; | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | |||
| void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/minimal_filtering_generator.h" | |||
| #include "nnacl/op_base.h" | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/pack.h" | |||
| #include "nnacl/int8/conv_int8.h" | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_X86_64_SSE | |||
| #include <nmmintrin.h> | |||
| #ifdef ENABLE_SSE | |||
| #include <x86intrin.h> | |||
| #include "nnacl/fp32/common_func_fp32.h" | |||
| void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { | |||
| @@ -60,6 +60,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { | |||
| matmul_param_->col_ = conv_param_->output_channel_; | |||
| matmul_param_->deep_ = conv_param_->input_channel_; | |||
| matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); | |||
| matmul_param_->row_6_ = UP_ROUND(matmul_param_->row_, C6NUM); | |||
| matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | |||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | |||
| matmul_param_->act_type_ = conv_param_->act_type_; | |||
| @@ -71,8 +72,13 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||
| auto input_channel = filter_tensor->Channel(); | |||
| auto output_channel = filter_tensor->Batch(); | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| if (in_tensors_.size() == 3) { | |||
| int size = UP_ROUND(output_channel, C8NUM) * sizeof(float); | |||
| int size = UP_ROUND(output_channel, col_tile) * sizeof(float); | |||
| int weight_size = output_channel * sizeof(float); | |||
| bias_data_ = malloc(size); | |||
| if (bias_data_ == nullptr) { | |||
| @@ -83,22 +89,29 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||
| memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); | |||
| } | |||
| int size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float); | |||
| int down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float); | |||
| int size = input_channel * UP_ROUND(output_channel, col_tile) * sizeof(float); | |||
| int down_size = input_channel * DOWN_DIV(output_channel, col_tile) * col_tile * sizeof(float); | |||
| weight_ptr_ = reinterpret_cast<float *>(malloc(size)); | |||
| if (weight_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size); | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||
| input_channel); | |||
| #else | |||
| RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||
| input_channel); | |||
| #endif | |||
| return RET_OK; | |||
| } | |||
| int Convolution1x1CPUKernel::InitConv1x1Param() { | |||
| int hw_tile = C12NUM; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| hw_tile = C6NUM; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| hw_tile = C4NUM; | |||
| #endif | |||
| if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||
| @@ -106,9 +119,14 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile; | |||
| } else { | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| multi_thread_by_hw_ = false; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile), thread_count_) * col_tile; | |||
| } | |||
| pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | |||
| @@ -175,7 +193,9 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) { | |||
| float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| @@ -202,7 +222,10 @@ int Convolution1x1CPUKernel::Run() { | |||
| auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | |||
| auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| pack_input_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float))); | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| pack_input_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||
| #else | |||
| @@ -226,7 +249,9 @@ int Convolution1x1CPUKernel::Run() { | |||
| if (multi_thread_by_hw_) { | |||
| ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_); | |||
| } else { | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
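The AVX path packs the activation matrix into 6-row panels before calling the matmul, which is why the pack buffer is sized row_6_ * deep_. The sketch below is a simplified, local illustration of such a packing; the exact element interleaving used by nnacl's RowMajor2Col6Major may differ.

```cpp
#include <vector>

// Simplified 6-row panel packing: the row-major input is cut into panels of
// 6 rows, and within a panel the 6 values of one depth column are stored
// contiguously. Padded tail rows stay zero.
std::vector<float> PackRowMajorToCol6(const float *src, int row, int deep) {
  const int row_tile = 6;  // C6NUM
  const int row_6 = ((row + row_tile - 1) / row_tile) * row_tile;
  std::vector<float> dst(static_cast<size_t>(row_6) * deep, 0.0f);
  for (int r = 0; r < row; ++r) {
    int panel = r / row_tile;  // which 6-row panel
    int lane = r % row_tile;   // position inside the panel
    for (int d = 0; d < deep; ++d) {
      dst[(static_cast<size_t>(panel) * deep + d) * row_tile + lane] =
          src[static_cast<size_t>(r) * deep + d];
    }
  }
  return dst;
}
```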
| @@ -42,7 +42,12 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||
| conv_param_->input_channel_ = in_channel; | |||
| conv_param_->output_channel_ = out_channel; | |||
| int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | |||
| int oc_block_num = UP_ROUND(out_channel, C8NUM); | |||
| #ifdef ENABLE_AVX | |||
| const int oc_block = C16NUM; | |||
| #else | |||
| const int oc_block = C8NUM; | |||
| #endif | |||
| int oc_block_num = UP_ROUND(out_channel, oc_block); | |||
| int pack_weight_size = oc_block_num * in_channel * kernel_plane; | |||
| auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c()); | |||
| @@ -52,7 +57,11 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||
| return RET_ERROR; | |||
| } | |||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float)); | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||
| #else | |||
| RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||
| #endif | |||
| bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float))); | |||
| if (bias_data_ == nullptr) { | |||
| @@ -72,7 +81,10 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||
| int ConvolutionCPUKernel::InitTmpBuffer() { | |||
| MS_ASSERT(ctx_->allocator != nullptr); | |||
| #ifdef ENABLE_ARM32 | |||
| #ifdef ENABLE_AVX | |||
| int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_; | |||
| #else | |||
| int unit_size = | |||
| @@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||
| return RET_OK; | |||
| } | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; | |||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | |||
| tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, | |||
| @@ -174,7 +174,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| tmp_buffer_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); | |||
| #else | |||
| @@ -186,7 +186,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||
| return RET_NULL_PTR; | |||
| } | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| pack_input_ = | |||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||
| #else | |||
| @@ -215,7 +215,7 @@ int DeConvolutionCPUKernel::Run() { | |||
| input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_; | |||
| output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| @@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() { | |||
| fc_param_->col_ = out_tensors_.at(0)->shape().back(); | |||
| fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1); | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | |||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); | |||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile); | |||
| fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM); | |||
| fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | |||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); | |||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); | |||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile)); | |||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_); | |||
| #ifdef ENABLE_ARM | |||
| if (fc_param_->row_ == 1) { | |||
| @@ -75,7 +81,9 @@ int FullconnectionCPUKernel::ReSize() { | |||
| memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | |||
| } | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_; | |||
| #else | |||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_; | |||
| @@ -120,7 +128,9 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) | |||
| return; | |||
| } | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||
| #else | |||
| RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||
| @@ -132,8 +142,11 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) | |||
| memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float)); | |||
| return; | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); | |||
| #else | |||
| RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); | |||
| #endif | |||
| } | |||
| int FcFp32MatmulRun(void *cdata, int task_id) { | |||
| @@ -147,14 +160,19 @@ int FcFp32MatmulRun(void *cdata, int task_id) { | |||
| } | |||
| int FullconnectionCPUKernel::DoMatmul(int task_id) { | |||
| int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM); | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| int cur_oc = MSMIN(thread_stride_ * col_tile, fc_param_->col_ - task_id * thread_stride_ * col_tile); | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| auto b = b_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_; | |||
| auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * C8NUM; | |||
| auto c = c_ptr_ + task_id * thread_stride_ * C8NUM; | |||
| auto b = b_ptr_ + task_id * thread_stride_ * col_tile * fc_param_->deep_; | |||
| auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile; | |||
| auto c = c_ptr_ + task_id * thread_stride_ * col_tile; | |||
| if (is_vector_input_) { | |||
| MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc); | |||
| } else { | |||
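Each task then owns a contiguous band of col_tile-aligned output columns; the pointer arithmetic in DoMatmul reduces to the sketch below. Names and values are local and hypothetical; thread_stride is counted in col_tile blocks, as in the fullconnection kernel.

```cpp
#include <algorithm>
#include <cstddef>

struct TaskSlice {
  std::ptrdiff_t b_offset;     // offset into the packed B buffer
  std::ptrdiff_t bias_offset;  // offset into the bias vector
  std::ptrdiff_t c_offset;     // offset into the output matrix
  int cur_oc;                  // valid output columns for this task (may be <= 0 for tail tasks)
};

TaskSlice ComputeSlice(int task_id, int thread_stride, int col_tile, int col, int deep) {
  TaskSlice s;
  int start_col = task_id * thread_stride * col_tile;
  s.cur_oc = std::min(thread_stride * col_tile, col - start_col);
  s.b_offset = static_cast<std::ptrdiff_t>(start_col) * deep;  // B is packed column-panel first
  s.bias_offset = start_col;
  s.c_offset = start_col;
  return s;
}
```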
| @@ -75,9 +75,12 @@ int MatmulCPUKernel::MallocMatrixABuffer() { | |||
| #endif | |||
| params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | |||
| params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | |||
| params_->row_6_ = UP_ROUND(params_->row_, C6NUM); | |||
| params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| int row_tmp = is_vector_a_ ? 1 : params_->row_6_; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| int row_tmp = is_vector_a_ ? 1 : params_->row_4_; | |||
| #else | |||
| int row_tmp = is_vector_a_ ? 1 : params_->row_12_; | |||
| @@ -106,9 +109,14 @@ int MatmulCPUKernel::MallocMatrixBBuffer() { | |||
| for (size_t i = 0; i < b_shape.size() - 2; ++i) { | |||
| batch *= b_shape[i]; | |||
| } | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| params_->batch = batch; | |||
| params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1]; | |||
| params_->col_8_ = UP_ROUND(params_->col_, 8); | |||
| params_->col_8_ = UP_ROUND(params_->col_, col_tile); | |||
| params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2]; | |||
| int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | |||
| @@ -123,8 +131,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() { | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); | |||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); | |||
| thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile)); | |||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_); | |||
| return RET_OK; | |||
| } | |||
| @@ -134,7 +142,12 @@ int MatmulCPUKernel::InitBias() { | |||
| params_->col_ = params_->b_const_ | |||
| ? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1)) | |||
| : (c_shape.at(c_shape.size() - 1)); | |||
| params_->col_8_ = UP_ROUND(params_->col_, 8); | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| params_->col_8_ = UP_ROUND(params_->col_, col_tile); | |||
| auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | |||
| if (bias_ptr_ == nullptr) { | |||
| bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float))); | |||
| @@ -171,7 +184,14 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) { | |||
| for (int i = 0; i < params_->batch; i++) { | |||
| const float *src = src_ptr + i * params_->deep_ * params_->row_; | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_6_; | |||
| if (params_->a_transpose_) { | |||
| RowMajor2Row6Major(src, dst, params_->deep_, params_->row_); | |||
| } else { | |||
| RowMajor2Col6Major(src, dst, params_->row_, params_->deep_); | |||
| } | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_4_; | |||
| if (params_->a_transpose_) { | |||
| RowMajor2Row4Major(src, dst, params_->deep_, params_->row_); | |||
| @@ -207,11 +227,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) { | |||
| for (int i = 0; i < params_->batch; i++) { | |||
| const float *src = src_ptr + i * params_->deep_ * params_->col_; | |||
| float *dst = dst_ptr + i * params_->deep_ * params_->col_8_; | |||
| #ifdef ENABLE_AVX | |||
| if (params_->b_transpose_) { | |||
| RowMajor2Col16Major(src, dst, params_->col_, params_->deep_); | |||
| } else { | |||
| RowMajor2Row16Major(src, dst, params_->deep_, params_->col_); | |||
| } | |||
| #else | |||
| if (params_->b_transpose_) { | |||
| RowMajor2Col8Major(src, dst, params_->col_, params_->deep_); | |||
| } else { | |||
| RowMajor2Row8Major(src, dst, params_->deep_, params_->col_); | |||
| } | |||
| #endif | |||
| } | |||
| return; | |||
| } | |||
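Under ENABLE_AVX the weight matrix B is packed into 16-wide column panels, column-major (RowMajor2Col16Major) when B is transposed and row-major (RowMajor2Row16Major) otherwise. The sketch below illustrates the non-transposed case with a local helper; the exact interleaving used by nnacl may differ, the point being the col_8_ = UP_ROUND(col_, C16NUM) padding that MatmulFloatAvxOpt consumes.

```cpp
#include <vector>

// Simplified 16-wide panel packing of B (deep x col, row-major, not transposed):
// columns are cut into panels of 16; within a panel, the 16 values of one depth
// row are contiguous, and panels follow one another. Padded columns stay zero.
std::vector<float> PackRowMajorToRow16(const float *src, int deep, int col) {
  const int col_tile = 16;  // C16NUM
  const int col_16 = ((col + col_tile - 1) / col_tile) * col_tile;
  std::vector<float> dst(static_cast<size_t>(col_16) * deep, 0.0f);
  for (int d = 0; d < deep; ++d) {
    for (int c = 0; c < col; ++c) {
      int panel = c / col_tile;
      int lane = c % col_tile;
      dst[(static_cast<size_t>(panel) * deep + d) * col_tile + lane] =
          src[static_cast<size_t>(d) * col + c];
    }
  }
  return dst;
}
```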
| @@ -247,13 +275,18 @@ int MatmulCPUKernel::Init() { | |||
| } | |||
| int MatmulCPUKernel::RunImpl(int task_id) { | |||
| int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM); | |||
| #ifdef ENABLE_AVX | |||
| int col_tile = C16NUM; | |||
| #else | |||
| int col_tile = C8NUM; | |||
| #endif | |||
| int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile); | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| auto b = cur_b_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; | |||
| auto c = cur_c_ptr_ + task_id * thread_stride_ * C8NUM; | |||
| auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * C8NUM : NULL; | |||
| auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_; | |||
| auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile; | |||
| auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL; | |||
| MS_ASSERT(cur_a_ptr_); | |||
| MS_ASSERT(b); | |||
| MS_ASSERT(c); | |||
| @@ -323,7 +356,9 @@ int MatmulCPUKernel::Run() { | |||
| cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_; | |||
| cur_c_ptr_ = c_src + i * params_->row_ * params_->col_; | |||
| } else { | |||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||
| #ifdef ENABLE_AVX | |||
| cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_; | |||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||
| cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_; | |||
| #else | |||
| cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_; | |||
| @@ -75,6 +75,16 @@ if ("${X86_64_SIMD}" STREQUAL "sse") | |||
| ) | |||
| endif() | |||
| if ("${X86_64_SIMD}" STREQUAL "avx") | |||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c | |||
| ${LITE_DIR}/nnacl/assembly/avx/*.S) | |||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||
| set(KERNEL_OP_SRC | |||
| ${KERNEL_OP_SRC} | |||
| ${TEST_ASSEMBLY_SRC} | |||
| ) | |||
| endif() | |||
| ### gpu kernel | |||
| if (SUPPORT_GPU) | |||
| file(GLOB GPU_KERNEL_OP_SRC | |||