@@ -20,7 +20,8 @@ option(SUPPORT_GPU "if support gpu" off)
 option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
 option(BUILD_MINDDATA_EXAMPLE "" on)
 option(ENABLE_VERBOSE "" off)
-option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_AVX "if x86_64 support AVX instruction set" off)
 set(DIR_PREFIX mindspore-lite)
 set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@@ -187,7 +188,13 @@ endif()
 if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
     if ("${X86_64_SIMD}" STREQUAL "sse")
-        add_compile_definitions(ENABLE_X86_64_SSE)
+        add_compile_definitions(ENABLE_SSE)
+    endif ()
+    if ("${X86_64_SIMD}" STREQUAL "avx")
+        add_compile_definitions(ENABLE_SSE)
+        add_compile_definitions(ENABLE_AVX)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
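+        # The AVX path also defines ENABLE_SSE because the AVX build keeps using the SSE kernels.
+        # Assumed usage: this branch is selected by configuring with -DX86_64_SIMD=avx.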
     endif ()
 endif ()
@@ -37,6 +37,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
     set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
 endif()
+if ("${X86_64_SIMD}" STREQUAL "avx")
+    file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c
+            ${NNACL_DIR}/assembly/avx/*.S)
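+    # The AVX build reuses the SSE intrinsic kernels (x86_64_sse/*.c) and adds the AVX assembly sources.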
+    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
+endif()
 ########################### build nnacl static library ########################
 string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
 add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})
@@ -0,0 +1,941 @@
+#ifdef ENABLE_AVX
+#ifndef WIN32
+    .text
+    .align 4
+    .global MatmulFloatAvxOpt
+#ifndef __APPLE__
+    .type MatmulFloatAvxOpt, %function
+#endif
+
+// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+//                        int row, int col, size_t stride, size_t writeMode)
+// rdi: a
+// rsi: b
+// rdx: c
+// rcx: bias
+// r8: act_type
+// r9: depth
+// 8(%rsp): row
+// 16(%rsp): col
+// 24(%rsp): stride
+// 32(%rsp): writeMode (0: writeC8, 2: writeWino, otherwise writeNhwc)
+
+MatmulFloatAvxOpt:
+    // rbx, rsp, rbp and r12-r15 must be saved according to the x86-64 calling convention
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rbx
+    pushq %rbp
+    pushq %r9
+    pushq %r8
+    pushq %rcx
+    pushq %rdx
+    pushq %rsi
+    pushq %rdi
+    addq $96, %rsp
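+    // After the 12 pushes (96 bytes), rsp is moved back to its value at function entry,
+    // so the stack arguments are reachable at 8/16/24/32(%rsp) and the saved registers
+    // sit in the red zone at -8(%rsp)..-96(%rsp) (SysV only, hence the #ifndef WIN32).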
+    movq 8(%rsp), %rbp
+    movq 16(%rsp), %rbx
+    movq 24(%rsp), %r10
+    movq 32(%rsp), %r14
+    movq $24, %r11
+    imul %r9, %r11
+    cmpq $0, %r14
+    jne NoC8Steps
+    movq $48, %r13
+    imul %rbp, %r13
+NoC8Steps:
+    cmpq $2, %r14
+    jne NoWinoSteps
+    movq $4, %r12
+    imul %r10, %r12
+    imul %rbx, %r12
+    movq $48, %r13
+    imul %r10, %r13
+NoWinoSteps:
+    movq $4, %rax
+    imul %rax, %r10
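+    // r10 = stride * sizeof(float): row stride of the output in bytes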
+LoopRow:
+    movq -88(%rsp), %rsi
+    movq 16(%rsp), %rbx
+    movq -72(%rsp), %rcx
+LoopCol:
+    cmpq $0, %r14
+    je NoReloadDst
+    movq -80(%rsp), %rdx
+NoReloadDst:
+    movq -96(%rsp), %rdi
+    movq -56(%rsp), %r9
+    vmovups (%rsi), %ymm0
+    vmovups 32(%rsi), %ymm1
+    vbroadcastss (%rdi), %ymm10
+    vbroadcastss 4(%rdi), %ymm11
+    vbroadcastss 8(%rdi), %ymm12
+    vbroadcastss 12(%rdi), %ymm13
+    vbroadcastss 16(%rdi), %ymm2
+    vbroadcastss 20(%rdi), %ymm3
+    addq $64, %rsi
+    vmulps %ymm0, %ymm10, %ymm4
+    vmulps %ymm1, %ymm10, %ymm5
+    vmulps %ymm0, %ymm11, %ymm6
+    vmulps %ymm1, %ymm11, %ymm7
+    vmulps %ymm0, %ymm12, %ymm8
+    vmulps %ymm1, %ymm12, %ymm9
+    vmulps %ymm0, %ymm13, %ymm10
+    vmulps %ymm1, %ymm13, %ymm11
+    addq $24, %rdi
+    vmulps %ymm0, %ymm2, %ymm12
+    vmulps %ymm1, %ymm2, %ymm13
+    vmulps %ymm0, %ymm3, %ymm14
+    vmulps %ymm1, %ymm3, %ymm15
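+    // ymm4..ymm15 hold the 6x16 fp32 output tile, one register pair per row;
+    // the first depth step initializes them with plain multiplies instead of FMA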
+    subq $1, %r9
+    cmpq $0, %r9
+    je Bias
+LoopDepth:
+    vmovups (%rsi), %ymm0
+    vmovups 32(%rsi), %ymm1
+    vbroadcastss (%rdi), %ymm2
+    vbroadcastss 4(%rdi), %ymm3
+    vfmadd231ps %ymm0, %ymm2, %ymm4
+    addq $64, %rsi
+    vfmadd231ps %ymm1, %ymm2, %ymm5
+    vbroadcastss 8(%rdi), %ymm2
+    vfmadd231ps %ymm0, %ymm3, %ymm6
+    vfmadd231ps %ymm1, %ymm3, %ymm7
+    vbroadcastss 12(%rdi), %ymm3
+    vfmadd231ps %ymm0, %ymm2, %ymm8
+    prefetcht0 384(%rsi)
+    vfmadd231ps %ymm1, %ymm2, %ymm9
+    vbroadcastss 16(%rdi), %ymm2
+    vfmadd231ps %ymm0, %ymm3, %ymm10
+    vfmadd231ps %ymm1, %ymm3, %ymm11
+    vbroadcastss 20(%rdi), %ymm3
+    vfmadd231ps %ymm0, %ymm2, %ymm12
+    vfmadd231ps %ymm1, %ymm2, %ymm13
+    addq $24, %rdi
+    vfmadd231ps %ymm0, %ymm3, %ymm14
+    vfmadd231ps %ymm1, %ymm3, %ymm15
+    subq $1, %r9
+    cmpq $0, %r9
+    ja LoopDepth
+Bias:
+    cmpq $0, %rcx
+    je Activation
+    vmovups (%rcx), %ymm0
+    vmovups 32(%rcx), %ymm1
+    addq $64, %rcx
+    vaddps %ymm0, %ymm4, %ymm4
+    vaddps %ymm1, %ymm5, %ymm5
+    vaddps %ymm0, %ymm6, %ymm6
+    vaddps %ymm1, %ymm7, %ymm7
+    vaddps %ymm0, %ymm8, %ymm8
+    vaddps %ymm1, %ymm9, %ymm9
+    vaddps %ymm0, %ymm10, %ymm10
+    vaddps %ymm1, %ymm11, %ymm11
+    vaddps %ymm0, %ymm12, %ymm12
+    vaddps %ymm1, %ymm13, %ymm13
+    vaddps %ymm0, %ymm14, %ymm14
+    vaddps %ymm1, %ymm15, %ymm15
+Activation:
+    cmpq $3, %r8
+    je Relu6
+    cmpq $1, %r8
+    je Relu
+    jmp Write
+Relu6:
+    movq $6, %rax
+    vcvtsi2ss %rax, %xmm0, %xmm0
+    vshufps $0, %xmm0, %xmm0, %xmm0
+    vinsertf128 $1, %xmm0, %ymm0, %ymm0
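+    // ymm0 now holds 6.0f broadcast across all 8 lanes; Relu6 clamps to 6 and
+    // then falls through to Relu for the lower bound at 0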
+    vminps %ymm0, %ymm4, %ymm4
+    vminps %ymm0, %ymm5, %ymm5
+    vminps %ymm0, %ymm6, %ymm6
+    vminps %ymm0, %ymm7, %ymm7
+    vminps %ymm0, %ymm8, %ymm8
+    vminps %ymm0, %ymm9, %ymm9
+    vminps %ymm0, %ymm10, %ymm10
+    vminps %ymm0, %ymm11, %ymm11
+    vminps %ymm0, %ymm12, %ymm12
+    vminps %ymm0, %ymm13, %ymm13
+    vminps %ymm0, %ymm14, %ymm14
+    vminps %ymm0, %ymm15, %ymm15
+Relu:
+    vxorps %ymm1, %ymm1, %ymm1
+    vmaxps %ymm1, %ymm4, %ymm4
+    vmaxps %ymm1, %ymm5, %ymm5
+    vmaxps %ymm1, %ymm6, %ymm6
+    vmaxps %ymm1, %ymm7, %ymm7
+    vmaxps %ymm1, %ymm8, %ymm8
+    vmaxps %ymm1, %ymm9, %ymm9
+    vmaxps %ymm1, %ymm10, %ymm10
+    vmaxps %ymm1, %ymm11, %ymm11
+    vmaxps %ymm1, %ymm12, %ymm12
+    vmaxps %ymm1, %ymm13, %ymm13
+    vmaxps %ymm1, %ymm14, %ymm14
+    vmaxps %ymm1, %ymm15, %ymm15
+Write:
+    cmpq $2, %r14
+    je WriteWino
+    cmpq $0, %r14
+    je WriteC8
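+    // NHWC tail dispatch: rbx holds the remaining column count (1..16)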
+    cmpq $1, %rbx
+    je Write1
+    cmpq $2, %rbx
+    je Write2
+    cmpq $3, %rbx
+    je Write3
+    cmpq $4, %rbx
+    je Write4
+    cmpq $5, %rbx
+    je Write5
+    cmpq $6, %rbx
+    je Write6
+    cmpq $7, %rbx
+    je Write7
+    cmpq $8, %rbx
+    je Write8
+    cmpq $9, %rbx
+    je Write9
+    cmpq $10, %rbx
+    je Write10
+    cmpq $11, %rbx
+    je Write11
+    cmpq $12, %rbx
+    je Write12
+    cmpq $13, %rbx
+    je Write13
+    cmpq $14, %rbx
+    je Write14
+    cmpq $15, %rbx
+    je Write15
+    jmp Write16
+Write1:
+    movq %rdx, %rax
+    addq $4, %rax
+    movq %rax, -80(%rsp)
+    vmovss %xmm4, (%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovss %xmm6, (%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovss %xmm8, (%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovss %xmm10, (%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovss %xmm12, (%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovss %xmm14, (%rdx)
+    addq %r10, %rdx
+    addq $4, %rdx
+    jmp WriteEnd
+Write2:
+    movq %rdx, %rax
+    addq $8, %rax
+    movq %rax, -80(%rsp)
+    vmovsd %xmm4, (%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm6, (%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm8, (%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm10, (%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm12, (%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm14, (%rdx)
+    addq %r10, %rdx
+    addq $8, %rdx
+    jmp WriteEnd
+Write3:
+    movq %rdx, %rax
+    addq $12, %rax
+    movq %rax, -80(%rsp)
+    vmovsd %xmm4, (%rdx)
+    movhlps %xmm4, %xmm4
+    vmovss %xmm4, 8(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm6, (%rdx)
+    movhlps %xmm6, %xmm6
+    vmovss %xmm6, 8(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm8, (%rdx)
+    movhlps %xmm8, %xmm8
+    vmovss %xmm8, 8(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm10, (%rdx)
+    movhlps %xmm10, %xmm10
+    vmovss %xmm10, 8(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm12, (%rdx)
+    movhlps %xmm12, %xmm12
+    vmovss %xmm12, 8(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovsd %xmm14, (%rdx)
+    movhlps %xmm14, %xmm14
+    vmovss %xmm14, 8(%rdx)
+    addq %r10, %rdx
+    addq $12, %rdx
+    jmp WriteEnd
+Write4:
+    movq %rdx, %rax
+    addq $16, %rax
+    movq %rax, -80(%rsp)
+    vmovups %xmm4, (%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm6, (%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm8, (%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm10, (%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm12, (%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm14, (%rdx)
+    addq %r10, %rdx
+    addq $16, %rdx
+    jmp WriteEnd
+Write5:
+    movq %rdx, %rax
+    addq $20, %rax
+    movq %rax, -80(%rsp)
+    vmovups %xmm4, (%rdx)
+    vextractf128 $1, %ymm4, %xmm4
+    vmovss %xmm4, 16(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm6, (%rdx)
+    vextractf128 $1, %ymm6, %xmm6
+    vmovss %xmm6, 16(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm8, (%rdx)
+    vextractf128 $1, %ymm8, %xmm8
+    vmovss %xmm8, 16(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm10, (%rdx)
+    vextractf128 $1, %ymm10, %xmm10
+    vmovss %xmm10, 16(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm12, (%rdx)
+    vextractf128 $1, %ymm12, %xmm12
+    vmovss %xmm12, 16(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm14, (%rdx)
+    vextractf128 $1, %ymm14, %xmm14
+    vmovss %xmm14, 16(%rdx)
+    addq %r10, %rdx
+    addq $20, %rdx
+    jmp WriteEnd
+Write6:
+    movq %rdx, %rax
+    addq $24, %rax
+    movq %rax, -80(%rsp)
+    vmovups %xmm4, (%rdx)
+    vextractf128 $1, %ymm4, %xmm4
+    vmovsd %xmm4, 16(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm6, (%rdx)
+    vextractf128 $1, %ymm6, %xmm6
+    vmovsd %xmm6, 16(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm8, (%rdx)
+    vextractf128 $1, %ymm8, %xmm8
+    vmovsd %xmm8, 16(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm10, (%rdx)
+    vextractf128 $1, %ymm10, %xmm10
+    vmovsd %xmm10, 16(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm12, (%rdx)
+    vextractf128 $1, %ymm12, %xmm12
+    vmovsd %xmm12, 16(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm14, (%rdx)
+    vextractf128 $1, %ymm14, %xmm14
+    vmovsd %xmm14, 16(%rdx)
+    addq %r10, %rdx
+    addq $24, %rdx
+    jmp WriteEnd
+Write7:
+    movq %rdx, %rax
+    addq $28, %rax
+    movq %rax, -80(%rsp)
+    vmovups %xmm4, (%rdx)
+    vextractf128 $1, %ymm4, %xmm4
+    vmovsd %xmm4, 16(%rdx)
+    movhlps %xmm4, %xmm4
+    vmovss %xmm4, 24(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm6, (%rdx)
+    vextractf128 $1, %ymm6, %xmm6
+    vmovsd %xmm6, 16(%rdx)
+    movhlps %xmm6, %xmm6
+    vmovss %xmm6, 24(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm8, (%rdx)
+    vextractf128 $1, %ymm8, %xmm8
+    vmovsd %xmm8, 16(%rdx)
+    movhlps %xmm8, %xmm8
+    vmovss %xmm8, 24(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm10, (%rdx)
+    vextractf128 $1, %ymm10, %xmm10
+    vmovsd %xmm10, 16(%rdx)
+    movhlps %xmm10, %xmm10
+    vmovss %xmm10, 24(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm12, (%rdx)
+    vextractf128 $1, %ymm12, %xmm12
+    vmovsd %xmm12, 16(%rdx)
+    movhlps %xmm12, %xmm12
+    vmovss %xmm12, 24(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %xmm14, (%rdx)
+    vextractf128 $1, %ymm14, %xmm14
+    vmovsd %xmm14, 16(%rdx)
+    movhlps %xmm14, %xmm14
+    vmovss %xmm14, 24(%rdx)
+    addq %r10, %rdx
+    addq $28, %rdx
+    jmp WriteEnd
+Write8:
+    movq %rdx, %rax
+    addq $32, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    addq %r10, %rdx
+    addq $32, %rdx
+    jmp WriteEnd
+Write9:
+    movq %rdx, %rax
+    addq $36, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovss %xmm5, 32(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovss %xmm7, 32(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovss %xmm9, 32(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovss %xmm11, 32(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovss %xmm13, 32(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovss %xmm15, 32(%rdx)
+    addq %r10, %rdx
+    addq $36, %rdx
+    jmp WriteEnd
+Write10:
+    movq %rdx, %rax
+    addq $40, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovsd %xmm5, 32(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovsd %xmm7, 32(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovsd %xmm9, 32(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovsd %xmm11, 32(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovsd %xmm13, 32(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovsd %xmm15, 32(%rdx)
+    addq %r10, %rdx
+    addq $40, %rdx
+    jmp WriteEnd
+Write11:
+    movq %rdx, %rax
+    addq $44, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovsd %xmm5, 32(%rdx)
+    movhlps %xmm5, %xmm5
+    vmovss %xmm5, 40(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovsd %xmm7, 32(%rdx)
+    movhlps %xmm7, %xmm7
+    vmovss %xmm7, 40(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovsd %xmm9, 32(%rdx)
+    movhlps %xmm9, %xmm9
+    vmovss %xmm9, 40(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovsd %xmm11, 32(%rdx)
+    movhlps %xmm11, %xmm11
+    vmovss %xmm11, 40(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovsd %xmm13, 32(%rdx)
+    movhlps %xmm13, %xmm13
+    vmovss %xmm13, 40(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovsd %xmm15, 32(%rdx)
+    movhlps %xmm15, %xmm15
+    vmovss %xmm15, 40(%rdx)
+    addq %r10, %rdx
+    addq $44, %rdx
+    jmp WriteEnd
+Write12:
+    movq %rdx, %rax
+    addq $48, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovups %xmm5, 32(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovups %xmm7, 32(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovups %xmm9, 32(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovups %xmm11, 32(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovups %xmm13, 32(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovups %xmm15, 32(%rdx)
+    addq %r10, %rdx
+    addq $48, %rdx
+    jmp WriteEnd
+Write13:
+    movq %rdx, %rax
+    addq $52, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovups %xmm5, 32(%rdx)
+    vextractf128 $1, %ymm5, %xmm5
+    vmovss %xmm5, 48(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovups %xmm7, 32(%rdx)
+    vextractf128 $1, %ymm7, %xmm7
+    vmovss %xmm7, 48(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovups %xmm9, 32(%rdx)
+    vextractf128 $1, %ymm9, %xmm9
+    vmovss %xmm9, 48(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovups %xmm11, 32(%rdx)
+    vextractf128 $1, %ymm11, %xmm11
+    vmovss %xmm11, 48(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovups %xmm13, 32(%rdx)
+    vextractf128 $1, %ymm13, %xmm13
+    vmovss %xmm13, 48(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovups %xmm15, 32(%rdx)
+    vextractf128 $1, %ymm15, %xmm15
+    vmovss %xmm15, 48(%rdx)
+    addq %r10, %rdx
+    addq $52, %rdx
+    jmp WriteEnd
+Write14:
+    movq %rdx, %rax
+    addq $56, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovups %xmm5, 32(%rdx)
+    vextractf128 $1, %ymm5, %xmm5
+    vmovsd %xmm5, 48(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovups %xmm7, 32(%rdx)
+    vextractf128 $1, %ymm7, %xmm7
+    vmovsd %xmm7, 48(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovups %xmm9, 32(%rdx)
+    vextractf128 $1, %ymm9, %xmm9
+    vmovsd %xmm9, 48(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovups %xmm11, 32(%rdx)
+    vextractf128 $1, %ymm11, %xmm11
+    vmovsd %xmm11, 48(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovups %xmm13, 32(%rdx)
+    vextractf128 $1, %ymm13, %xmm13
+    vmovsd %xmm13, 48(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovups %xmm15, 32(%rdx)
+    vextractf128 $1, %ymm15, %xmm15
+    vmovsd %xmm15, 48(%rdx)
+    addq %r10, %rdx
+    addq $56, %rdx
+    jmp WriteEnd
+Write15:
+    movq %rdx, %rax
+    addq $60, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovups %xmm5, 32(%rdx)
+    vextractf128 $1, %ymm5, %xmm5
+    vmovsd %xmm5, 48(%rdx)
+    movhlps %xmm5, %xmm5
+    vmovss %xmm5, 56(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovups %xmm7, 32(%rdx)
+    vextractf128 $1, %ymm7, %xmm7
+    vmovsd %xmm7, 48(%rdx)
+    movhlps %xmm7, %xmm7
+    vmovss %xmm7, 56(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovups %xmm9, 32(%rdx)
+    vextractf128 $1, %ymm9, %xmm9
+    vmovsd %xmm9, 48(%rdx)
+    movhlps %xmm9, %xmm9
+    vmovss %xmm9, 56(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovups %xmm11, 32(%rdx)
+    vextractf128 $1, %ymm11, %xmm11
+    vmovsd %xmm11, 48(%rdx)
+    movhlps %xmm11, %xmm11
+    vmovss %xmm11, 56(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovups %xmm13, 32(%rdx)
+    vextractf128 $1, %ymm13, %xmm13
+    vmovsd %xmm13, 48(%rdx)
+    movhlps %xmm13, %xmm13
+    vmovss %xmm13, 56(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovups %xmm15, 32(%rdx)
+    vextractf128 $1, %ymm15, %xmm15
+    vmovsd %xmm15, 48(%rdx)
+    movhlps %xmm15, %xmm15
+    vmovss %xmm15, 56(%rdx)
+    addq %r10, %rdx
+    addq $60, %rdx
+    jmp WriteEnd
+WriteC8:
+    movq %rdx, %rax
+    addq %r11, %rdx
+    movq %rdx, %r15
+    addq %r11, %rdx
+    movq %rdx, -80(%rsp)
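+    // C8 layout: even accumulators (columns 0-7) go to the first 8-channel block,
+    // odd accumulators (columns 8-15) to the next one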
+    vmovups %ymm4, (%rax)
+    vmovups %ymm6, 32(%rax)
+    vmovups %ymm8, 64(%rax)
+    vmovups %ymm10, 96(%rax)
+    vmovups %ymm12, 128(%rax)
+    vmovups %ymm14, 160(%rax)
+    vmovups %ymm5, (%r15)
+    vmovups %ymm7, 32(%r15)
+    vmovups %ymm9, 64(%r15)
+    vmovups %ymm11, 96(%r15)
+    vmovups %ymm13, 128(%r15)
+    vmovups %ymm15, 160(%r15)
+    jmp WriteEnd
+WriteWino:
+    movq %rdx, %rax
+    addq %r13, %rdx
+    movq %rdx, %r15
+    addq %r13, %rdx
+    movq %rdx, -80(%rsp)
+    vmovups %ymm4, (%rax)
+    vmovups %ymm5, (%r15)
+    addq %r12, %rax
+    addq %r12, %r15
+    vmovups %ymm6, (%rax)
+    vmovups %ymm7, (%r15)
+    addq %r12, %rax
+    addq %r12, %r15
+    vmovups %ymm8, (%rax)
+    vmovups %ymm9, (%r15)
+    addq %r12, %rax
+    addq %r12, %r15
+    vmovups %ymm10, (%rax)
+    vmovups %ymm11, (%r15)
+    addq %r12, %rax
+    addq %r12, %r15
+    vmovups %ymm12, (%rax)
+    vmovups %ymm13, (%r15)
+    addq %r12, %rax
+    addq %r12, %r15
+    vmovups %ymm14, (%rax)
+    vmovups %ymm15, (%r15)
+    jmp WriteEnd
+Write16:
+    movq %rdx, %rax
+    addq $64, %rax
+    movq %rax, -80(%rsp)
+    vmovups %ymm4, (%rdx)
+    vmovups %ymm5, 32(%rdx)
+    cmpq $1, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm6, (%rdx)
+    vmovups %ymm7, 32(%rdx)
+    cmpq $2, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm8, (%rdx)
+    vmovups %ymm9, 32(%rdx)
+    cmpq $3, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm10, (%rdx)
+    vmovups %ymm11, 32(%rdx)
+    cmpq $4, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm12, (%rdx)
+    vmovups %ymm13, 32(%rdx)
+    cmpq $5, %rbp
+    je WriteEnd
+    addq %r10, %rdx
+    vmovups %ymm14, (%rdx)
+    vmovups %ymm15, 32(%rdx)
+    addq %r10, %rdx
+    addq $64, %rdx
+WriteEnd:
+    cmpq $16, %rbx
+    jbe LoopColEnd
+    subq $16, %rbx
+    jmp LoopCol
+LoopColEnd:
+    movq -96(%rsp), %rdi
+    addq %r11, %rdi
+    movq %rdi, -96(%rsp)
+    cmpq $0, %r14
+    je C8DstStep
+    cmpq $2, %r14
+    je WinoDstStep
+    movq $4, %rax
+    movq 16(%rsp), %rbx
+    imul %rbx, %rax
+    subq %rax, %rdx
+    movq %rdx, -80(%rsp)
+    jmp NoDstStep
+C8DstStep:
+    movq -80(%rsp), %rax
+    addq $384, %rax
+    movq %rax, -80(%rsp)
+    jmp NoDstStep
+WinoDstStep:
+    addq %r13, %rdx
+    movq %rdx, -80(%rsp)
+NoDstStep:
+    cmpq $6, %rbp
+    jbe LoopRowEnd
+    subq $6, %rbp
+    jmp LoopRow
+LoopRowEnd:
+    subq $96, %rsp
+    popq %rdi
+    popq %rsi
+    popq %rdx
+    popq %rcx
+    popq %r8
+    popq %r9
+    popq %rbp
+    popq %rbx
+    popq %r12
+    popq %r13
+    popq %r14
+    popq %r15
+    retq
+#endif
+#endif
@@ -56,7 +56,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
   int out_channel = conv_param->output_channel_;
   int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
   int output_count = conv_param->output_h_ * conv_param->output_w_;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   const int cal_num = C4NUM;
 #else
   const int cal_num = C12NUM;
@@ -78,7 +78,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
       int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
       float *gemm_output = output_data + out_offset;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
       RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
 #else
       RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
@@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
 void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
                         size_t plane_size, size_t stride, size_t relu_type) {
-#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
+#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
   PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
 #else
   size_t oc8mod = output_channel % C8NUM;
@@ -68,7 +68,7 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi
   return;
 }
-#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
+#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
 void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
   const int unitStep = 4 * length;
   for (int y = 0; y < h; ++y) {
@@ -39,7 +39,7 @@ float ShortToFloat32(uint16_t src_value);
 uint16_t Float32ToShort(float src_value);
-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                       size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
@@ -202,7 +202,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
       const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
       ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                        conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
@@ -285,7 +285,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
      int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
      const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
      float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
      ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                       sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -839,7 +839,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
      float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
      const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
      DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@@ -26,7 +26,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
   int out_channel = conv_param->output_channel_;
   int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
   int output_count = conv_param->output_h_ * conv_param->output_w_;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  const int cal_num = C6NUM;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   const int cal_num = C4NUM;
 #else
   const int cal_num = C12NUM;
@@ -48,7 +50,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
       int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
      float *gemm_output = output_data + out_offset;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+      RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
       RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
 #else
       RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
@@ -97,7 +101,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
     float *dst_ptr = gemm_out + task_id * gemm_out_offset;
     float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
     for (int i = 0; i < input_unit_square; ++i) {
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
       RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
 #else
       RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
@@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds
   size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
   size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
   int oc8 = UP_ROUND(output_channel, C8NUM);
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   const int tile_num = 4;
 #else
   const int tile_num = 12;
@@ -28,9 +28,21 @@ void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col)
   for (int r = 0; r < row; r++) {
     const float *src = src_ptr + r * col;
     for (int c = 0; c < col; c++) {
-      int cd8 = c / 4;
-      int cm8 = c % 4;
-      dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
+      int cd4 = c / C4NUM;
+      int cm4 = c % C4NUM;
+      dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
+    }
+  }
+  return;
+}
+
+void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
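+  /* pack row-major data into 6-column blocks: element (r, c) goes into
+   * block c / 6 at offset r * 6 + c % 6 */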
+  for (int r = 0; r < row; r++) {
+    const float *src = src_ptr + r * col;
+    for (int c = 0; c < col; c++) {
+      int cd6 = c / C6NUM;
+      int cm6 = c % C6NUM;
+      dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
     }
   }
   return;
@@ -40,9 +52,9 @@ void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col)
   for (int r = 0; r < row; r++) {
     const float *src = src_ptr + r * col;
     for (int c = 0; c < col; c++) {
-      int cd8 = c / 8;
-      int cm8 = c % 8;
-      dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c];
+      int cd8 = c / C8NUM;
+      int cm8 = c % C8NUM;
+      dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
     }
   }
   return;
@@ -52,9 +64,21 @@ void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col)
   for (int r = 0; r < row; r++) {
     const float *src = src_ptr + r * col;
     for (int c = 0; c < col; c++) {
-      int cd8 = c / C12NUM;
-      int cm8 = c % C12NUM;
-      dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
+      int cd12 = c / C12NUM;
+      int cm12 = c % C12NUM;
+      dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
+    }
+  }
+  return;
+}
+
+void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
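+  /* same packing with 16-column blocks (the B-matrix layout consumed by the 6x16 kernel) */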
+  for (int r = 0; r < row; r++) {
+    const float *src = src_ptr + r * col;
+    for (int c = 0; c < col; c++) {
+      int cd16 = c / C16NUM;
+      int cm16 = c % C16NUM;
+      dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
     }
   }
   return;
@@ -190,7 +214,7 @@ void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_SSE
      __m128 src1 = _mm_loadu_ps(src_c);
      __m128 src2 = _mm_loadu_ps(src_c + col);
      __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
@@ -421,7 +445,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_SSE
      /* 8x4 row-major to col-major */
      __m128 src1 = _mm_loadu_ps(src_c);
      __m128 src2 = _mm_loadu_ps(src_c + col);
@@ -478,6 +502,145 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
   return;
 }
+
+void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
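+  /* row-major [row x col] -> blocks of 16 rows, column-major inside each block;
+   * row and column tails are handled element by element */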
+  size_t row16 = row / C16NUM * C16NUM;
+  size_t col_skip = col / C4NUM * C4NUM;
+  int skip_size = C4NUM;
+  const float *src_r = src_ptr;
+  float *dst_r = dst_ptr;
+  size_t ri = 0;
+  for (; ri < row16; ri += C16NUM) {
+    size_t ci = 0;
+    for (; ci < col_skip; ci += skip_size) {
+      const float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C16NUM;
+      for (int tr = 0; tr < C16NUM; tr++) {
+        for (int tc = 0; tc < C4NUM; tc++) {
+          dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
+        }
+      }
+    }
+    for (; ci < col; ci++) {
+      const float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C16NUM;
+      for (size_t i = 0; i < C16NUM; i++) {
+        dst_c[i] = src_c[i * col];
+      }
+    }
+    src_r += C16NUM * col;
+    dst_r += C16NUM * col;
+  }
+  for (; ri < row; ri++) {
+    for (size_t i = 0; i < col; i++) {
+      dst_r[i * C16NUM] = src_r[i];
+    }
+    src_r += col;
+    dst_r += 1;
+  }
+  return;
+}
+
+void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
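+  /* row-major [row x col] -> blocks of 6 rows (the A-tile layout of the 6x16 kernel);
+   * rows are zero-padded up to a multiple of 6 */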
+  size_t totalRow = UP_ROUND(row, C6NUM);
+  size_t row6 = row / C6NUM * C6NUM;
+  size_t col8 = col / C8NUM * C8NUM;
+  const float *src_r = src_ptr;
+  float *dst_r = dst_ptr;
+  size_t ri = 0;
+  for (; ri < row6; ri += C6NUM) {
+    size_t ci = 0;
+    for (; ci < col8; ci += C8NUM) {
+      const float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C6NUM;
+      /* 6x8 row-major to col-major */
+#ifdef ENABLE_AVX
+      __m256 src0 = _mm256_loadu_ps(src_c);
+      __m256 src1 = _mm256_loadu_ps(src_c + col);
+      __m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
+      __m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
+      __m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
+      __m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
+      __m256 trans0 = _mm256_unpacklo_ps(src0, src1);
+      __m256 trans1 = _mm256_unpacklo_ps(src2, src3);
+      __m256 trans2 = _mm256_unpacklo_ps(src4, src5);
+      __m256 trans3 = _mm256_unpackhi_ps(src0, src1);
+      __m256 trans4 = _mm256_unpackhi_ps(src2, src3);
+      __m256 trans5 = _mm256_unpackhi_ps(src4, src5);
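+      // AVX unpack interleaves within each 128-bit lane, so the interleaved
+      // registers are split into low/high lanes below and recombined with
+      // shuffles into 4-float column groups before storing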
+      __m128 lo0 = _mm256_castps256_ps128(trans0);
+      __m128 lo1 = _mm256_castps256_ps128(trans1);
+      __m128 lo2 = _mm256_castps256_ps128(trans2);
+      __m128 lo3 = _mm256_castps256_ps128(trans3);
+      __m128 lo4 = _mm256_castps256_ps128(trans4);
+      __m128 lo5 = _mm256_castps256_ps128(trans5);
+      __m128 hi0 = _mm256_extractf128_ps(trans0, 1);
+      __m128 hi1 = _mm256_extractf128_ps(trans1, 1);
+      __m128 hi2 = _mm256_extractf128_ps(trans2, 1);
+      __m128 hi3 = _mm256_extractf128_ps(trans3, 1);
+      __m128 hi4 = _mm256_extractf128_ps(trans4, 1);
+      __m128 hi5 = _mm256_extractf128_ps(trans5, 1);
+      __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
+      __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
+      __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
+      __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
+      __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
+      __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
+      __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
+      __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
+      __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
+      __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
+      __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
+      __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
+      _mm_storeu_ps(dst_c, res0);
+      _mm_storeu_ps(dst_c + 4, res1);
+      _mm_storeu_ps(dst_c + 8, res2);
+      _mm_storeu_ps(dst_c + 12, res3);
+      _mm_storeu_ps(dst_c + 16, res4);
+      _mm_storeu_ps(dst_c + 20, res5);
+      _mm_storeu_ps(dst_c + 24, res6);
+      _mm_storeu_ps(dst_c + 28, res7);
+      _mm_storeu_ps(dst_c + 32, res8);
+      _mm_storeu_ps(dst_c + 36, res9);
+      _mm_storeu_ps(dst_c + 40, res10);
+      _mm_storeu_ps(dst_c + 44, res11);
+#else
+      for (int tr = 0; tr < C6NUM; tr++) {
+        for (int tc = 0; tc < C8NUM; tc++) {
+          dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
+        }
+      }
+#endif
+    }
+    for (; ci < col; ci++) {
+      const float *src_c = src_r + ci;
+      float *dst_c = dst_r + ci * C6NUM;
+      for (size_t i = 0; i < C6NUM; i++) {
+        dst_c[i] = src_c[i * col];
+      }
+    }
+    src_r += C6NUM * col;
+    dst_r += C6NUM * col;
+  }
+  for (; ri < row; ri++) {
+    for (size_t i = 0; i < col; i++) {
+      dst_r[i * C6NUM] = src_r[i];
+    }
+    src_r += col;
+    dst_r += 1;
+  }
+  for (; ri < totalRow; ri++) {
+    for (size_t i = 0; i < col; i++) {
+      dst_r[i * C6NUM] = 0;
+    }
+    dst_r += 1;
+  }
+  return;
+}
 void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
   size_t row8 = row / C4NUM * C4NUM;
   size_t col4 = col / C4NUM * C4NUM;
@@ -519,7 +682,7 @@ void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r12", "q0", "q1", "q2", "q3");
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_SSE
      __m128 src1 = _mm_loadu_ps(src_c);
      __m128 src2 = _mm_loadu_ps(src_c + col);
      __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
@@ -630,6 +793,34 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
   return;
 }
+#ifdef ENABLE_AVX
+#ifdef WIN32
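+// Scalar reference implementation: the AT&T-syntax AVX kernel is not assembled
+// on Windows (see the #ifndef WIN32 guard in the .S file), so NHWC output falls
+// back to this loop there.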
+void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type) {
+  if (out_type == OutType_Nhwc) {
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r6div = r / C6NUM, r6mod = r % C6NUM;
+        int c16div = c / C16NUM, c16mod = c % C16NUM;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod;
+          size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
+  return;
+}
+#endif
+#endif
 void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
   if (out_type == OutType_C8) {
@@ -670,7 +861,19 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
   } else {
     MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
   }
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_AVX
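+  // AVX path: the 6x16 kernel only handles NHWC writes; C8 and Winograd
+  // outputs still go through the SSE kernels.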
| if (out_type == OutType_Nhwc) { | |||||
| #ifdef WIN32 | |||||
| MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type); | |||||
| #else | |||||
| MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||||
| #endif | |||||
| } else if (out_type == OutType_C8) { | |||||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | |||||
| } else { | |||||
| MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); | |||||
| } | |||||
| #elif ENABLE_SSE | |||||
| if (out_type == OutType_C8) { | if (out_type == OutType_C8) { | ||||
| MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); | ||||
| } else { | } else { | ||||
| @@ -31,11 +31,15 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||||
| void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col); | void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col); | ||||
| void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||||
| void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); | |||||
| void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||||
| void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); | void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); | ||||
| #endif | #endif | ||||
| @@ -49,11 +53,16 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi | |||||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | int col, int stride, size_t writeNhwc, size_t WriteWino); | ||||
| void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, int stride, int write_mode); | int col, int stride, int write_mode); | ||||
| #elif ENABLE_X86_64_SSE | |||||
| #elif ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, int stride, size_t writeNhwc, size_t WriteWino); | int col, int stride, size_t writeNhwc, size_t WriteWino); | ||||
| void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, int stride, int write_mode); | int col, int stride, int write_mode); | ||||
| #ifdef ENABLE_AVX | |||||
| void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||||
| int col, int stride, int write_mode); | |||||
| #endif | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_NNACL_INFER_SHAPE | #ifdef ENABLE_NNACL_INFER_SHAPE | ||||
| @@ -42,6 +42,7 @@ typedef struct MatMulParameter { | |||||
| int row_; | int row_; | ||||
| int col_; | int col_; | ||||
| int row_4_; | int row_4_; | ||||
| int row_6_; | |||||
| int row_8_; | int row_8_; | ||||
| int row_12_; | int row_12_; | ||||
| int row_16_; | int row_16_; | ||||
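The new `row_6_` field caches the row count padded for the AVX kernel's 6-row tile, alongside the existing 4/8/12/16 variants. A quick illustration of the rounding, assuming the usual nnacl `UP_ROUND` (round up to the next multiple):

```c
#define UP_ROUND(x, y) ((((x) + (y) - 1) / (y)) * (y))
// For row_ = 50:
//   row_4_  = UP_ROUND(50, 4)  = 52
//   row_6_  = UP_ROUND(50, 6)  = 54   <- new, AVX 6-row tile
//   row_8_  = UP_ROUND(50, 8)  = 56
//   row_12_ = UP_ROUND(50, 12) = 60
```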
| @@ -125,7 +125,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) { | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) | |||||
| #if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) | |||||
| void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, | void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, | ||||
| int in_channel, int c4_channel) { | int in_channel, int c4_channel) { | ||||
| int cnt = 0; | int cnt = 0; | ||||
| @@ -228,7 +228,7 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | ||||
| const float *bias, int m, int k, int n) { | const float *bias, int m, int k, int n) { | ||||
| int count = 0; | int count = 0; | ||||
| @@ -52,7 +52,7 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float * | |||||
| int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, | int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, | ||||
| int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack); | int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack); | ||||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||||
| void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, | ||||
| const float *bias, int m, int k, int n); | const float *bias, int m, int k, int n); | ||||
| #endif | #endif | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #endif | #endif | ||||
| #include <stdint.h> | #include <stdint.h> | ||||
| @@ -31,6 +31,7 @@ | |||||
| #define C2NUM 2 | #define C2NUM 2 | ||||
| #define C4NUM 4 | #define C4NUM 4 | ||||
| #define C6NUM 6 | |||||
| #define C8NUM 8 | #define C8NUM 8 | ||||
| #define C12NUM 12 | #define C12NUM 12 | ||||
| #define C16NUM 16 | #define C16NUM 16 | ||||
| @@ -91,7 +92,7 @@ typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, | |||||
| #define MS_MAXQ_F32 vmaxq_f32 | #define MS_MAXQ_F32 vmaxq_f32 | ||||
| #define MS_MINQ_F32 vminq_f32 | #define MS_MINQ_F32 vminq_f32 | ||||
| #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) | #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) | ||||
| #elif defined(ENABLE_X86_64_SSE) | |||||
| #elif defined(ENABLE_SSE) | |||||
| #define MS_FLOAT32X4 __m128 | #define MS_FLOAT32X4 __m128 | ||||
| #define MS_LDQ_F32 _mm_loadu_ps | #define MS_LDQ_F32 _mm_loadu_ps | ||||
| #define MS_ADDQ_F32 _mm_add_ps | #define MS_ADDQ_F32 _mm_add_ps | ||||
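These `MS_*` mappings let one kernel body compile against either NEON or SSE. A hedged sketch using only the macros visible above (and assuming the SSE side maps `MS_MAXQ_F32`/`MS_MINQ_F32` to `_mm_max_ps`/`_mm_min_ps`, as the pattern suggests), e.g. a four-lane ReLU6:

```c
// Portable across the NEON and SSE mappings: clamp four floats to [0, 6]
// (ReLU6) using only MS_LDQ_F32 / MS_MAXQ_F32 / MS_MINQ_F32.
static const float kZero4[4] = {0.0f, 0.0f, 0.0f, 0.0f};
static const float kSix4[4] = {6.0f, 6.0f, 6.0f, 6.0f};
static inline MS_FLOAT32X4 Relu6Vec(MS_FLOAT32X4 v) {
  v = MS_MAXQ_F32(v, MS_LDQ_F32(kZero4));
  return MS_MINQ_F32(v, MS_LDQ_F32(kSix4));
}
```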
| @@ -756,7 +756,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch | |||||
| return; | return; | ||||
| } | } | ||||
| #ifndef ENABLE_X86_64_SSE | |||||
| #ifndef ENABLE_SSE | |||||
| void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { | void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) { | ||||
| int hw8 = plane / C8NUM * C8NUM; | int hw8 = plane / C8NUM * C8NUM; | ||||
| int c8 = channel / C8NUM * C8NUM; | int c8 = channel / C8NUM * C8NUM; | ||||
| @@ -79,7 +79,7 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo | |||||
| int src_step, int dst_step, int in_unit) { | int src_step, int dst_step, int in_unit) { | ||||
| int len = in_unit * in_unit; | int len = in_unit * in_unit; | ||||
| if (len > MAX_LEN) return; | if (len > MAX_LEN) return; | ||||
| #if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM) || defined(ENABLE_SSE) | |||||
| MS_FLOAT32X4 src[MAX_LEN]; | MS_FLOAT32X4 src[MAX_LEN]; | ||||
| MS_FLOAT32X4 t[MAX_LEN]; | MS_FLOAT32X4 t[MAX_LEN]; | ||||
| MS_FLOAT32X4 m[MAX_LEN]; | MS_FLOAT32X4 m[MAX_LEN]; | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/conv_depthwise_fp32.h" | #include "nnacl/fp32/conv_depthwise_fp32.h" | ||||
| void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, | ||||
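The repeated `<nmmintrin.h>` to `<x86intrin.h>` swaps matter because `nmmintrin.h` stops at SSE4.2, while `x86intrin.h` exposes the full x86 intrinsic set, so these shared SSE sources still compile when the AVX configuration adds `-mavx -mfma`. A minimal sketch of what the wider header unlocks (illustrative, not code from the patch):

```c
#include <x86intrin.h>  // umbrella header: SSE through AVX/FMA intrinsics
#ifdef ENABLE_AVX
// Fused multiply-add on 8 lanes; requires the -mfma flag that the AVX
// build configuration appends to CMAKE_C_FLAGS.
static inline __m256 Fma8(__m256 acc, __m256 a, __m256 b) {
  return _mm256_fmadd_ps(a, b, acc);  // acc + a * b with a single rounding
}
#endif
```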
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/minimal_filtering_generator.h" | #include "nnacl/minimal_filtering_generator.h" | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/pack.h" | #include "nnacl/pack.h" | ||||
| #include "nnacl/int8/conv_int8.h" | #include "nnacl/int8/conv_int8.h" | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | #include "nnacl/fp32/common_func_fp32.h" | ||||
| void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | ||||
| @@ -13,8 +13,8 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_X86_64_SSE | |||||
| #include <nmmintrin.h> | |||||
| #ifdef ENABLE_SSE | |||||
| #include <x86intrin.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | #include "nnacl/fp32/common_func_fp32.h" | ||||
| void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { | void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { | ||||
| @@ -60,6 +60,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { | |||||
| matmul_param_->col_ = conv_param_->output_channel_; | matmul_param_->col_ = conv_param_->output_channel_; | ||||
| matmul_param_->deep_ = conv_param_->input_channel_; | matmul_param_->deep_ = conv_param_->input_channel_; | ||||
| matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); | matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); | ||||
| matmul_param_->row_6_ = UP_ROUND(matmul_param_->row_, C6NUM); | |||||
| matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | ||||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | ||||
| matmul_param_->act_type_ = conv_param_->act_type_; | matmul_param_->act_type_ = conv_param_->act_type_; | ||||
| @@ -71,8 +72,13 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||||
| auto input_channel = filter_tensor->Channel(); | auto input_channel = filter_tensor->Channel(); | ||||
| auto output_channel = filter_tensor->Batch(); | auto output_channel = filter_tensor->Batch(); | ||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| if (in_tensors_.size() == 3) { | if (in_tensors_.size() == 3) { | ||||
| int size = UP_ROUND(output_channel, C8NUM) * sizeof(float); | |||||
| int size = UP_ROUND(output_channel, col_tile) * sizeof(float); | |||||
| int weight_size = output_channel * sizeof(float); | int weight_size = output_channel * sizeof(float); | ||||
| bias_data_ = malloc(size); | bias_data_ = malloc(size); | ||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| @@ -83,22 +89,29 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() { | |||||
| memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); | memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); | ||||
| } | } | ||||
| int size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float); | |||||
| int down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float); | |||||
| int size = input_channel * UP_ROUND(output_channel, col_tile) * sizeof(float); | |||||
| int down_size = input_channel * DOWN_DIV(output_channel, col_tile) * col_tile * sizeof(float); | |||||
| weight_ptr_ = reinterpret_cast<float *>(malloc(size)); | weight_ptr_ = reinterpret_cast<float *>(malloc(size)); | ||||
| if (weight_ptr_ == nullptr) { | if (weight_ptr_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size); | memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size); | ||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | |||||
| input_channel); | |||||
| #else | |||||
| RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel, | ||||
| input_channel); | input_channel); | ||||
| #endif | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Convolution1x1CPUKernel::InitConv1x1Param() { | int Convolution1x1CPUKernel::InitConv1x1Param() { | ||||
| int hw_tile = C12NUM; | int hw_tile = C12NUM; | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| hw_tile = C6NUM; | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| hw_tile = C4NUM; | hw_tile = C4NUM; | ||||
| #endif | #endif | ||||
| if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | ||||
| @@ -106,9 +119,14 @@ int Convolution1x1CPUKernel::InitConv1x1Param() { | |||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile)); | thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile)); | ||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile; | thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile; | ||||
| } else { | } else { | ||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| multi_thread_by_hw_ = false; | multi_thread_by_hw_ = false; | ||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; | |||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile), thread_count_) * col_tile; | |||||
| } | } | ||||
| pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | ||||
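When threading by output channel, the stride computed above is a whole number of `col_tile` blocks per thread; the last thread picks up the ragged tail. A worked example with hypothetical values (col_ = 100, col_tile = 16 under AVX, 4 threads):

```c
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))
// col_ = 100, col_tile = 16, thread_num_ = 4 (illustrative values):
//   UP_DIV(100, 16)                    = 7 tiles
//   thread_count_  = MSMIN(4, 7)       = 4
//   thread_stride_ = UP_DIV(7, 4) * 16 = 32 columns per thread
// Threads 0-2 each take 32 columns; thread 3 takes the 4-column tail.
```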
| @@ -175,7 +193,9 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) { | |||||
| float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | ||||
| float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | ||||
| #else | #else | ||||
| RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | ||||
| @@ -202,7 +222,10 @@ int Convolution1x1CPUKernel::Run() { | |||||
| auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | ||||
| auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| pack_input_ = | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float))); | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| pack_input_ = | pack_input_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | ||||
| #else | #else | ||||
| @@ -226,7 +249,9 @@ int Convolution1x1CPUKernel::Run() { | |||||
| if (multi_thread_by_hw_) { | if (multi_thread_by_hw_) { | ||||
| ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_); | ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_); | ||||
| } else { | } else { | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| #else | #else | ||||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| @@ -42,7 +42,12 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||||
| conv_param_->input_channel_ = in_channel; | conv_param_->input_channel_ = in_channel; | ||||
| conv_param_->output_channel_ = out_channel; | conv_param_->output_channel_ = out_channel; | ||||
| int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | ||||
| int oc_block_num = UP_ROUND(out_channel, C8NUM); | |||||
| #ifdef ENABLE_AVX | |||||
| const int oc_block = C16NUM; | |||||
| #else | |||||
| const int oc_block = C8NUM; | |||||
| #endif | |||||
| int oc_block_num = UP_ROUND(out_channel, oc_block); | |||||
| int pack_weight_size = oc_block_num * in_channel * kernel_plane; | int pack_weight_size = oc_block_num * in_channel * kernel_plane; | ||||
| auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c()); | auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c()); | ||||
| @@ -52,7 +57,11 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float)); | memset(packed_weight_, 0, pack_weight_size * sizeof(float)); | ||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||||
| #else | |||||
| RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | ||||
| #endif | |||||
| bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float))); | bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float))); | ||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| @@ -72,7 +81,10 @@ int ConvolutionCPUKernel::InitWeightBias() { | |||||
| int ConvolutionCPUKernel::InitTmpBuffer() { | int ConvolutionCPUKernel::InitTmpBuffer() { | ||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| #ifdef ENABLE_ARM32 | |||||
| #ifdef ENABLE_AVX | |||||
| int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_; | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_; | int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_; | ||||
| #else | #else | ||||
| int unit_size = | int unit_size = | ||||
| @@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; | auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; | ||||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | ||||
| tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, | tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, | ||||
| @@ -174,7 +174,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| tmp_buffer_ = | tmp_buffer_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); | ||||
| #else | #else | ||||
| @@ -186,7 +186,7 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| pack_input_ = | pack_input_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | ||||
| #else | #else | ||||
| @@ -215,7 +215,7 @@ int DeConvolutionCPUKernel::Run() { | |||||
| input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_; | input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_; | ||||
| output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_; | output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_; | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| #else | #else | ||||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| @@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() { | |||||
| fc_param_->col_ = out_tensors_.at(0)->shape().back(); | fc_param_->col_ = out_tensors_.at(0)->shape().back(); | ||||
| fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1); | fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1); | ||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | ||||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); | |||||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile); | |||||
| fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM); | |||||
| fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | ||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); | |||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_); | |||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| if (fc_param_->row_ == 1) { | if (fc_param_->row_ == 1) { | ||||
| @@ -75,7 +81,9 @@ int FullconnectionCPUKernel::ReSize() { | |||||
| memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | ||||
| } | } | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_; | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_; | int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_; | ||||
| #else | #else | ||||
| int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_; | int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_; | ||||
| @@ -120,7 +128,9 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) | |||||
| return; | return; | ||||
| } | } | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | ||||
| #else | #else | ||||
| RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_); | ||||
| @@ -132,8 +142,11 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) | |||||
| memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float)); | memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float)); | ||||
| return; | return; | ||||
| } | } | ||||
| #ifdef ENABLE_AVX | |||||
| RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); | |||||
| #else | |||||
| RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); | RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_); | ||||
| #endif | |||||
| } | } | ||||
| int FcFp32MatmulRun(void *cdata, int task_id) { | int FcFp32MatmulRun(void *cdata, int task_id) { | ||||
| @@ -147,14 +160,19 @@ int FcFp32MatmulRun(void *cdata, int task_id) { | |||||
| } | } | ||||
| int FullconnectionCPUKernel::DoMatmul(int task_id) { | int FullconnectionCPUKernel::DoMatmul(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM); | |||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| int cur_oc = MSMIN(thread_stride_ * col_tile, fc_param_->col_ - task_id * thread_stride_ * col_tile); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| auto b = b_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_; | |||||
| auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * C8NUM; | |||||
| auto c = c_ptr_ + task_id * thread_stride_ * C8NUM; | |||||
| auto b = b_ptr_ + task_id * thread_stride_ * col_tile * fc_param_->deep_; | |||||
| auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile; | |||||
| auto c = c_ptr_ + task_id * thread_stride_ * col_tile; | |||||
| if (is_vector_input_) { | if (is_vector_input_) { | ||||
| MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc); | MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc); | ||||
| } else { | } else { | ||||
| @@ -75,9 +75,12 @@ int MatmulCPUKernel::MallocMatrixABuffer() { | |||||
| #endif | #endif | ||||
| params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | ||||
| params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | ||||
| params_->row_6_ = UP_ROUND(params_->row_, C6NUM); | |||||
| params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| int row_tmp = is_vector_a_ ? 1 : params_->row_6_; | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| int row_tmp = is_vector_a_ ? 1 : params_->row_4_; | int row_tmp = is_vector_a_ ? 1 : params_->row_4_; | ||||
| #else | #else | ||||
| int row_tmp = is_vector_a_ ? 1 : params_->row_12_; | int row_tmp = is_vector_a_ ? 1 : params_->row_12_; | ||||
| @@ -106,9 +109,14 @@ int MatmulCPUKernel::MallocMatrixBBuffer() { | |||||
| for (size_t i = 0; i < b_shape.size() - 2; ++i) { | for (size_t i = 0; i < b_shape.size() - 2; ++i) { | ||||
| batch *= b_shape[i]; | batch *= b_shape[i]; | ||||
| } | } | ||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| params_->batch = batch; | params_->batch = batch; | ||||
| params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1]; | params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1]; | ||||
| params_->col_8_ = UP_ROUND(params_->col_, 8); | |||||
| params_->col_8_ = UP_ROUND(params_->col_, col_tile); | |||||
| params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2]; | params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2]; | ||||
| int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | ||||
| @@ -123,8 +131,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() { | |||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| } | } | ||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); | |||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -134,7 +142,12 @@ int MatmulCPUKernel::InitBias() { | |||||
| params_->col_ = params_->b_const_ | params_->col_ = params_->b_const_ | ||||
| ? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1)) | ? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1)) | ||||
| : (c_shape.at(c_shape.size() - 1)); | : (c_shape.at(c_shape.size() - 1)); | ||||
| params_->col_8_ = UP_ROUND(params_->col_, 8); | |||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| params_->col_8_ = UP_ROUND(params_->col_, col_tile); | |||||
| auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_; | ||||
| if (bias_ptr_ == nullptr) { | if (bias_ptr_ == nullptr) { | ||||
| bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float))); | bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float))); | ||||
| @@ -171,7 +184,14 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) { | |||||
| for (int i = 0; i < params_->batch; i++) { | for (int i = 0; i < params_->batch; i++) { | ||||
| const float *src = src_ptr + i * params_->deep_ * params_->row_; | const float *src = src_ptr + i * params_->deep_ * params_->row_; | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_6_; | |||||
| if (params_->a_transpose_) { | |||||
| RowMajor2Row6Major(src, dst, params_->deep_, params_->row_); | |||||
| } else { | |||||
| RowMajor2Col6Major(src, dst, params_->row_, params_->deep_); | |||||
| } | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_4_; | float *dst = dst_ptr + i * params_->deep_ * params_->row_4_; | ||||
| if (params_->a_transpose_) { | if (params_->a_transpose_) { | ||||
| RowMajor2Row4Major(src, dst, params_->deep_, params_->row_); | RowMajor2Row4Major(src, dst, params_->deep_, params_->row_); | ||||
| @@ -207,11 +227,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) { | |||||
| for (int i = 0; i < params_->batch; i++) { | for (int i = 0; i < params_->batch; i++) { | ||||
| const float *src = src_ptr + i * params_->deep_ * params_->col_; | const float *src = src_ptr + i * params_->deep_ * params_->col_; | ||||
| float *dst = dst_ptr + i * params_->deep_ * params_->col_8_; | float *dst = dst_ptr + i * params_->deep_ * params_->col_8_; | ||||
| #ifdef ENABLE_AVX | |||||
| if (params_->b_transpose_) { | |||||
| RowMajor2Col16Major(src, dst, params_->col_, params_->deep_); | |||||
| } else { | |||||
| RowMajor2Row16Major(src, dst, params_->deep_, params_->col_); | |||||
| } | |||||
| #else | |||||
| if (params_->b_transpose_) { | if (params_->b_transpose_) { | ||||
| RowMajor2Col8Major(src, dst, params_->col_, params_->deep_); | RowMajor2Col8Major(src, dst, params_->col_, params_->deep_); | ||||
| } else { | } else { | ||||
| RowMajor2Row8Major(src, dst, params_->deep_, params_->col_); | RowMajor2Row8Major(src, dst, params_->deep_, params_->col_); | ||||
| } | } | ||||
| #endif | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
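`InitMatrixB` now selects the 16-wide B-side packers under AVX; the transpose flag decides whether B arrives as (col x deep) or (deep x col). A scalar sketch of the assumed `Row16` layout, mirroring the existing `RowMajor2Row8Major` convention:

```c
// Assumed layout for RowMajor2Row16Major: columns of the row-major
// (row x col) source are grouped into 16-wide blocks; within a block the
// walk is row-major, and the last block is zero-padded.
void RowMajor2Row16MajorRef(const float *src, float *dst, int row, int col) {
  int col16 = (col + 15) / 16 * 16;  // columns rounded up to a multiple of 16
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col16; ++c) {
      int block = c / 16, in_block = c % 16;
      dst[block * row * 16 + r * 16 + in_block] =
          (c < col) ? src[r * col + c] : 0.0f;
    }
  }
}
```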
| @@ -247,13 +275,18 @@ int MatmulCPUKernel::Init() { | |||||
| } | } | ||||
| int MatmulCPUKernel::RunImpl(int task_id) { | int MatmulCPUKernel::RunImpl(int task_id) { | ||||
| int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM); | |||||
| #ifdef ENABLE_AVX | |||||
| int col_tile = C16NUM; | |||||
| #else | |||||
| int col_tile = C8NUM; | |||||
| #endif | |||||
| int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile); | |||||
| if (cur_oc <= 0) { | if (cur_oc <= 0) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| auto b = cur_b_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_; | |||||
| auto c = cur_c_ptr_ + task_id * thread_stride_ * C8NUM; | |||||
| auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * C8NUM : NULL; | |||||
| auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_; | |||||
| auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile; | |||||
| auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL; | |||||
| MS_ASSERT(cur_a_ptr_); | MS_ASSERT(cur_a_ptr_); | ||||
| MS_ASSERT(b); | MS_ASSERT(b); | ||||
| MS_ASSERT(c); | MS_ASSERT(c); | ||||
| @@ -323,7 +356,9 @@ int MatmulCPUKernel::Run() { | |||||
| cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_; | cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_; | ||||
| cur_c_ptr_ = c_src + i * params_->row_ * params_->col_; | cur_c_ptr_ = c_src + i * params_->row_ * params_->col_; | ||||
| } else { | } else { | ||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| #ifdef ENABLE_AVX | |||||
| cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_; | |||||
| #elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) | |||||
| cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_; | cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_; | ||||
| #else | #else | ||||
| cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_; | cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_; | ||||
| @@ -75,6 +75,16 @@ if ("${X86_64_SIMD}" STREQUAL "sse") | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| if ("${X86_64_SIMD}" STREQUAL "avx") | |||||
| file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c | |||||
| ${LITE_DIR}/nnacl/assembly/avx/*.S) | |||||
| set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) | |||||
| set(KERNEL_OP_SRC | |||||
| ${KERNEL_OP_SRC} | |||||
| ${TEST_ASSEMBLY_SRC} | |||||
| ) | |||||
| endif() | |||||
| ### gpu kernel | ### gpu kernel | ||||
| if (SUPPORT_GPU) | if (SUPPORT_GPU) | ||||
| file(GLOB GPU_KERNEL_OP_SRC | file(GLOB GPU_KERNEL_OP_SRC | ||||