From 0aa289b0fd816acd74b7229b6d16b62f7d90d631 Mon Sep 17 00:00:00 2001
From: lixian
Date: Tue, 8 Dec 2020 21:42:16 +0800
Subject: [PATCH] add avx fp32 matmul kernel

---
 mindspore/lite/CMakeLists.txt                 |  11 +-
 mindspore/lite/nnacl/CMakeLists.txt           |   6 +
 mindspore/lite/nnacl/assembly/avx/MatmulAvx.S | 941 ++++++++++++++++++
 mindspore/lite/nnacl/fp32/adder_fp32.c        |   4 +-
 mindspore/lite/nnacl/fp32/common_func_fp32.c  |   4 +-
 mindspore/lite/nnacl/fp32/common_func_fp32.h  |   2 +-
 .../lite/nnacl/fp32/conv_depthwise_fp32.c     |   6 +-
 mindspore/lite/nnacl/fp32/conv_fp32.c         |  10 +-
 mindspore/lite/nnacl/fp32/deconv_fp32.c       |   2 +-
 mindspore/lite/nnacl/fp32/matmul_fp32.c       | 229 ++++-
 mindspore/lite/nnacl/fp32/matmul_fp32.h       |  11 +-
 mindspore/lite/nnacl/matmul_parameter.h       |   1 +
 .../lite/nnacl/minimal_filtering_generator.c  |   4 +-
 .../lite/nnacl/minimal_filtering_generator.h  |   2 +-
 mindspore/lite/nnacl/op_base.h                |   7 +-
 mindspore/lite/nnacl/pack.c                   |   2 +-
 mindspore/lite/nnacl/winograd_utils.c         |   2 +-
 .../lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c |   4 +-
 mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c  |   4 +-
 .../nnacl/x86_64_sse/PackNHWCToNCHWFp32.c     |   4 +-
 .../lite/nnacl/x86_64_sse/PosFuncBiasRelu.c   |   4 +-
 .../lite/nnacl/x86_64_sse/WinogradTrans.c     |   4 +-
 .../kernel/arm/fp32/convolution_1x1_fp32.cc   |  43 +-
 .../kernel/arm/fp32/convolution_fp32.cc       |  16 +-
 .../kernel/arm/fp32/deconvolution_fp32.cc     |   8 +-
 .../kernel/arm/fp32/fullconnection_fp32.cc    |  38 +-
 .../runtime/kernel/arm/fp32/matmul_fp32.cc    |  57 +-
 mindspore/lite/test/CMakeLists.txt            |  10 +
 28 files changed, 1354 insertions(+), 82 deletions(-)
 create mode 100644 mindspore/lite/nnacl/assembly/avx/MatmulAvx.S

diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
index c0595ae818..717ab0170f 100644
--- a/mindspore/lite/CMakeLists.txt
+++ b/mindspore/lite/CMakeLists.txt
@@ -20,7 +20,8 @@ option(SUPPORT_GPU "if support gpu" off)
 option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
 option(BUILD_MINDDATA_EXAMPLE "" on)
 option(ENABLE_VERBOSE "" off)
-option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_AVX "if x86_64 support AVX instruction set" off)
 set(DIR_PREFIX mindspore-lite)
 set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@@ -187,7 +188,13 @@ endif()

 if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
     if ("${X86_64_SIMD}" STREQUAL "sse")
-        add_compile_definitions(ENABLE_X86_64_SSE)
+        add_compile_definitions(ENABLE_SSE)
+    endif ()
+    if ("${X86_64_SIMD}" STREQUAL "avx")
+        add_compile_definitions(ENABLE_SSE)
+        add_compile_definitions(ENABLE_AVX)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
     endif ()
 endif ()

diff --git a/mindspore/lite/nnacl/CMakeLists.txt b/mindspore/lite/nnacl/CMakeLists.txt
index df24b90f8e..ec9045de7d 100644
--- a/mindspore/lite/nnacl/CMakeLists.txt
+++ b/mindspore/lite/nnacl/CMakeLists.txt
@@ -37,6 +37,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
     set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
 endif()

+if ("${X86_64_SIMD}" STREQUAL "avx")
+    file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c
+                           ${NNACL_DIR}/assembly/avx/*.S)
+    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
+endif()
+
 ########################### build nnacl static library ########################
 string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
 add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})
diff --git 
a/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S b/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S new file mode 100644 index 0000000000..3382be2758 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/avx/MatmulAvx.S @@ -0,0 +1,941 @@ +#ifdef ENABLE_AVX +#ifndef WIN32 + .text + .align 4 + .global MatmulFloatAvxOpt +#ifndef __APPLE__ + .type MatmulFloatAvxOpt, %function +#endif + +// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// rdi: a +// rsi: b +// rdx: c +// rcx: bias +// r8: act_type +// r9: depth +// 8: row +// 16: col +// 24: stride +// 32: writeNhwc/writeWino + +MatmulFloatAvxOpt: + // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 + pushq %r8 + pushq %rcx + pushq %rdx + pushq %rsi + pushq %rdi + addq $96, %rsp + + movq 8(%rsp), %rbp + movq 16(%rsp), %rbx + movq 24(%rsp), %r10 + movq 32(%rsp), %r14 + + movq $24, %r11 + imul %r9, %r11 + cmpq $0, %r14 + jne NoC8Steps + movq $48, %r13 + imul %rbp, %r13 +NoC8Steps: + cmpq $2, %r14 + jne NoWinoSteps + movq $4, %r12 + imul %r10, %r12 + imul %rbx, %r12 + movq $48, %r13 + imul %r10, %r13 +NoWinoSteps: + movq $4, %rax + imul %rax, %r10 + +LoopRow: + movq -88(%rsp), %rsi + movq 16(%rsp), %rbx + movq -72(%rsp), %rcx + + LoopCol: + cmpq $0, %r14 + je NoReloadDst + movq -80(%rsp), %rdx + NoReloadDst: + movq -96(%rsp), %rdi + movq -56(%rsp), %r9 + + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm10 + vbroadcastss 4(%rdi), %ymm11 + vbroadcastss 8(%rdi), %ymm12 + vbroadcastss 12(%rdi), %ymm13 + vbroadcastss 16(%rdi), %ymm2 + vbroadcastss 20(%rdi), %ymm3 + addq $64, %rsi + vmulps %ymm0, %ymm10, %ymm4 + vmulps %ymm1, %ymm10, %ymm5 + vmulps %ymm0, %ymm11, %ymm6 + vmulps %ymm1, %ymm11, %ymm7 + vmulps %ymm0, %ymm12, %ymm8 + vmulps %ymm1, %ymm12, %ymm9 + vmulps %ymm0, %ymm13, %ymm10 + vmulps %ymm1, %ymm13, %ymm11 + add $24, %rdi + vmulps %ymm0, %ymm2, %ymm12 + vmulps %ymm1, %ymm2, %ymm13 + vmulps %ymm0, %ymm3, %ymm14 + vmulps %ymm1, %ymm3, %ymm15 + + subq $1, %r9 + cmpq $0, %r9 + je Bias + + LoopDepth: + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm2 + vbroadcastss 4(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm4 + addq $64, %rsi + vfmadd231ps %ymm1, %ymm2, %ymm5 + vbroadcastss 8(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm6 + vfmadd231ps %ymm1, %ymm3, %ymm7 + vbroadcastss 12(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm8 + prefetcht0 384(%rsi) + vfmadd231ps %ymm1, %ymm2, %ymm9 + vbroadcastss 16(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm10 + vfmadd231ps %ymm1, %ymm3, %ymm11 + vbroadcastss 20(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm12 + vfmadd231ps %ymm1, %ymm2, %ymm13 + addq $24, %rdi + vfmadd231ps %ymm0, %ymm3, %ymm14 + vfmadd231ps %ymm1, %ymm3, %ymm15 + + subq $1, %r9 + cmpq $0, %r9 + ja LoopDepth + + Bias: + cmpq $0, %rcx + je Activation + vmovups (%rcx), %ymm0 + vmovups 32(%rcx), %ymm1 + add $64, %rcx + vaddps %ymm0, %ymm4, %ymm4 + vaddps %ymm1, %ymm5, %ymm5 + vaddps %ymm0, %ymm6, %ymm6 + vaddps %ymm1, %ymm7, %ymm7 + vaddps %ymm0, %ymm8, %ymm8 + vaddps %ymm1, %ymm9, %ymm9 + vaddps %ymm0, %ymm10, %ymm10 + vaddps %ymm1, %ymm11, %ymm11 + vaddps %ymm0, %ymm12, %ymm12 + vaddps %ymm1, %ymm13, %ymm13 + vaddps %ymm0, %ymm14, %ymm14 + vaddps %ymm1, %ymm15, %ymm15 + + Activation: + cmpq $3, %r8 + je Relu6 + cmpq $1, %r8 + je Relu + jmp Write + + Relu6: + movq $6, 
%rax + vcvtsi2ss %rax, %xmm0, %xmm0 + vshufps $0, %xmm0, %xmm0, %xmm0 + vinsertf128 $1, %xmm0, %ymm0, %ymm0 + vminps %ymm0, %ymm4, %ymm4 + vminps %ymm0, %ymm5, %ymm5 + vminps %ymm0, %ymm6, %ymm6 + vminps %ymm0, %ymm7, %ymm7 + vminps %ymm0, %ymm8, %ymm8 + vminps %ymm0, %ymm9, %ymm9 + vminps %ymm0, %ymm10, %ymm10 + vminps %ymm0, %ymm11, %ymm11 + vminps %ymm0, %ymm12, %ymm12 + vminps %ymm0, %ymm13, %ymm13 + vminps %ymm0, %ymm14, %ymm14 + vminps %ymm0, %ymm15, %ymm15 + + Relu: + vxorps %ymm1, %ymm1, %ymm1 + vmaxps %ymm1, %ymm4, %ymm4 + vmaxps %ymm1, %ymm5, %ymm5 + vmaxps %ymm1, %ymm6, %ymm6 + vmaxps %ymm1, %ymm7, %ymm7 + vmaxps %ymm1, %ymm8, %ymm8 + vmaxps %ymm1, %ymm9, %ymm9 + vmaxps %ymm1, %ymm10, %ymm10 + vmaxps %ymm1, %ymm11, %ymm11 + vmaxps %ymm1, %ymm12, %ymm12 + vmaxps %ymm1, %ymm13, %ymm13 + vmaxps %ymm1, %ymm14, %ymm14 + vmaxps %ymm1, %ymm15, %ymm15 + + Write: + cmpq $2, %r14 + je WriteWino + cmpq $0, %r14 + je WriteC8 + cmpq $1, %rbx + je Write1 + cmpq $2, %rbx + je Write2 + cmpq $3, %rbx + je Write3 + cmpq $4, %rbx + je Write4 + cmpq $5, %rbx + je Write5 + cmpq $6, %rbx + je Write6 + cmpq $7, %rbx + je Write7 + cmpq $8, %rbx + je Write8 + cmpq $9, %rbx + je Write9 + cmpq $10, %rbx + je Write10 + cmpq $11, %rbx + je Write11 + cmpq $12, %rbx + je Write12 + cmpq $13, %rbx + je Write13 + cmpq $14, %rbx + je Write14 + cmpq $15, %rbx + je Write15 + jmp Write16 + + Write1: + movq %rdx, %rax + addq $4, %rax + movq %rax, -80(%rsp) + vmovss %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm14, (%rdx) + addq %r10, %rdx + addq $4, %rdx + jmp WriteEnd + Write2: + movq %rdx, %rax + addq $8, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + addq %r10, %rdx + addq $8, %rdx + jmp WriteEnd + Write3: + movq %rdx, %rax + addq $12, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 8(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 8(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 8(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 8(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 8(%rdx) + addq %r10, %rdx + addq $12, %rdx + jmp WriteEnd + Write4: + movq %rdx, %rax + addq $16, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + 
cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + addq %r10, %rdx + addq $16, %rdx + jmp WriteEnd + Write5: + movq %rdx, %rax + addq $20, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovss %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovss %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovss %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovss %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovss %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovss %xmm14, 16(%rdx) + addq %r10, %rdx + addq $20, %rdx + jmp WriteEnd + Write6: + movq %rdx, %rax + addq $24, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + addq %r10, %rdx + addq $24, %rdx + jmp WriteEnd + Write7: + movq %rdx, %rax + addq $28, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 24(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 24(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 24(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 24(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 24(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 24(%rdx) + addq %r10, %rdx + addq $28, %rdx + jmp WriteEnd + Write8: + movq %rdx, %rax + addq $32, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + addq %r10, %rdx + addq $32, %rdx + jmp WriteEnd + Write9: + movq %rdx, %rax + addq $36, %rax + movq %rax, 
-80(%rsp) + vmovups %ymm4, (%rdx) + vmovss %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovss %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovss %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovss %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovss %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovss %xmm15, 32(%rdx) + addq %r10, %rdx + addq $36, %rdx + jmp WriteEnd + Write10: + movq %rdx, %rax + addq $40, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + addq %r10, %rdx + addq $40, %rdx + jmp WriteEnd + Write11: + movq %rdx, %rax + addq $44, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 40(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 40(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 40(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 40(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 40(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 40(%rdx) + addq %r10, %rdx + addq $44, %rdx + jmp WriteEnd + Write12: + movq %rdx, %rax + addq $48, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + addq %r10, %rdx + addq $48, %rdx + jmp WriteEnd + Write13: + movq %rdx, %rax + addq $52, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovss %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovss %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovss %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovss %xmm11, 48(%rdx) + cmpq $4, %rbp + je 
WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovss %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovss %xmm15, 48(%rdx) + addq %r10, %rdx + addq $52, %rdx + jmp WriteEnd + Write14: + movq %rdx, %rax + addq $56, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + addq %r10, %rdx + addq $56, %rdx + jmp WriteEnd + Write15: + movq %rdx, %rax + addq $60, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 56(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 56(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 56(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 56(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 56(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 56(%rdx) + addq %r10, %rdx + addq $60, %rdx + jmp WriteEnd + WriteC8: + movq %rdx, %rax + addq %r11, %rdx + movq %rdx, %r15 + addq %r11, %rdx + movq %rdx, -80(%rsp) + vmovups %ymm4, (%rax) + vmovups %ymm6, 32(%rax) + vmovups %ymm8, 64(%rax) + vmovups %ymm10, 96(%rax) + vmovups %ymm12, 128(%rax) + vmovups %ymm14, 160(%rax) + vmovups %ymm5, (%r15) + vmovups %ymm7, 32(%r15) + vmovups %ymm9, 64(%r15) + vmovups %ymm11, 96(%r15) + vmovups %ymm13, 128(%r15) + vmovups %ymm15, 160(%r15) + jmp WriteEnd + WriteWino: + movq %rdx, %rax + addq %r13, %rdx + movq %rdx, %r15 + addq %r13, %rdx + movq %rdx, -80(%rsp) + vmovups %ymm4, (%rax) + vmovups %ymm5, (%r15) + addq %r12, %rax + addq %r12, %r15 + vmovups %ymm6, (%rax) + vmovups %ymm7, (%r15) + addq %r12, %rax + addq %r12, %r15 + vmovups %ymm8, (%rax) + vmovups %ymm9, (%r15) + addq %r12, %rax + addq %r12, %r15 + vmovups %ymm10, (%rax) + vmovups %ymm11, (%r15) + addq %r12, %rax + addq 
%r12, %r15 + vmovups %ymm12, (%rax) + vmovups %ymm13, (%r15) + addq %r12, %rax + addq %r12, %r15 + vmovups %ymm14, (%rax) + vmovups %ymm15, (%r15) + jmp WriteEnd + Write16: + movq %rdx, %rax + addq $64, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %ymm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %ymm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %ymm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %ymm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %ymm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %ymm15, 32(%rdx) + addq %r10, %rdx + addq $64, %rdx + + WriteEnd: + cmpq $16, %rbx + jbe LoopColEnd + subq $16, %rbx + jmp LoopCol + + LoopColEnd: + movq -96(%rsp), %rdi + addq %r11, %rdi + movq %rdi, -96(%rsp) + cmpq $0, %r14 + je C8DstStep + cmpq $2, %r14 + je WinoDstStep + movq $4, %rax + movq 16(%rsp), %rbx + imul %rbx, %rax + subq %rax, %rdx + movq %rdx, -80(%rsp) + jmp NoDstStep + C8DstStep: + movq -80(%rsp), %rax + addq $384, %rax + movq %rax, -80(%rsp) + jmp NoDstStep + WinoDstStep: + addq %r13, %rdx + movq %rdx, -80(%rsp) + NoDstStep: + cmpq $6, %rbp + jbe LoopRowEnd + subq $6, %rbp + jmp LoopRow + +LoopRowEnd: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif +#endif diff --git a/mindspore/lite/nnacl/fp32/adder_fp32.c b/mindspore/lite/nnacl/fp32/adder_fp32.c index 61efa9b2be..8060b8ce17 100644 --- a/mindspore/lite/nnacl/fp32/adder_fp32.c +++ b/mindspore/lite/nnacl/fp32/adder_fp32.c @@ -56,7 +56,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed int out_channel = conv_param->output_channel_; int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; int output_count = conv_param->output_h_ * conv_param->output_w_; -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) const int cal_num = C4NUM; #else const int cal_num = C12NUM; @@ -78,7 +78,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed int out_offset = thread_id * cal_num * out_channel + out_batch_offset; float *gemm_output = output_data + out_offset; -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); #else RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); diff --git a/mindspore/lite/nnacl/fp32/common_func_fp32.c b/mindspore/lite/nnacl/fp32/common_func_fp32.c index 0edd78a2bf..2ec26b5e98 100644 --- a/mindspore/lite/nnacl/fp32/common_func_fp32.c +++ b/mindspore/lite/nnacl/fp32/common_func_fp32.c @@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, size_t plane_size, size_t stride, size_t relu_type) { -#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); #else size_t oc8mod = output_channel % C8NUM; @@ -68,7 +68,7 @@ void PostConvFuncFp32C4(const float 
*c4_out_ptr, float *out_ptr, const float *bi return; } -#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { const int unitStep = 4 * length; for (int y = 0; y < h; ++y) { diff --git a/mindspore/lite/nnacl/fp32/common_func_fp32.h b/mindspore/lite/nnacl/fp32/common_func_fp32.h index 6c33afef44..a6b7c09cb7 100644 --- a/mindspore/lite/nnacl/fp32/common_func_fp32.h +++ b/mindspore/lite/nnacl/fp32/common_func_fp32.h @@ -39,7 +39,7 @@ float ShortToFloat32(uint16_t src_value); uint16_t Float32ToShort(float src_value); -#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 609c66abcb..372cefa8e7 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -202,7 +202,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; -#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); @@ -285,7 +285,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; -#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), @@ -839,7 +839,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; -#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), diff --git a/mindspore/lite/nnacl/fp32/conv_fp32.c b/mindspore/lite/nnacl/fp32/conv_fp32.c index d76e101237..7715e0d374 
100644 --- a/mindspore/lite/nnacl/fp32/conv_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_fp32.c @@ -26,7 +26,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ int out_channel = conv_param->output_channel_; int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; int output_count = conv_param->output_h_ * conv_param->output_w_; -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#ifdef ENABLE_AVX + const int cal_num = C6NUM; +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) const int cal_num = C4NUM; #else const int cal_num = C12NUM; @@ -48,7 +50,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ int out_offset = thread_id * cal_num * out_channel + out_batch_offset; float *gemm_output = output_data + out_offset; -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#ifdef ENABLE_AVX + RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); #else RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); @@ -97,7 +101,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *dst_ptr = gemm_out + task_id * gemm_out_offset; float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; for (int i = 0; i < input_unit_square; ++i) { -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); #else RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); diff --git a/mindspore/lite/nnacl/fp32/deconv_fp32.c b/mindspore/lite/nnacl/fp32/deconv_fp32.c index b644967c8e..bee3802e07 100644 --- a/mindspore/lite/nnacl/fp32/deconv_fp32.c +++ b/mindspore/lite/nnacl/fp32/deconv_fp32.c @@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; size_t output_plane = conv_param->output_w_ * conv_param->output_h_; int oc8 = UP_ROUND(output_channel, C8NUM); -#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) const int tile_num = 4; #else const int tile_num = 12; diff --git a/mindspore/lite/nnacl/fp32/matmul_fp32.c b/mindspore/lite/nnacl/fp32/matmul_fp32.c index 760e71b113..13ca315f9b 100644 --- a/mindspore/lite/nnacl/fp32/matmul_fp32.c +++ b/mindspore/lite/nnacl/fp32/matmul_fp32.c @@ -28,9 +28,21 @@ void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) for (int r = 0; r < row; r++) { const float *src = src_ptr + r * col; for (int c = 0; c < col; c++) { - int cd8 = c / 4; - int cm8 = c % 4; - dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c]; + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c]; + } + } + return; +} + +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * col; + for (int c = 0; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c]; } } return; @@ -40,9 +52,9 @@ void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) for (int r = 0; r < row; r++) { const float *src = src_ptr + r * col; for (int c = 0; c < col; c++) { - int cd8 = c / 8; - 
int cm8 = c % 8; - dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; } } return; @@ -52,9 +64,21 @@ void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) for (int r = 0; r < row; r++) { const float *src = src_ptr + r * col; for (int c = 0; c < col; c++) { - int cd8 = c / C12NUM; - int cm8 = c % C12NUM; - dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c]; + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c]; + } + } + return; +} + +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * col; + for (int c = 0; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; } } return; @@ -190,7 +214,7 @@ void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_ : : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#elif ENABLE_X86_64_SSE +#elif ENABLE_SSE __m128 src1 = _mm_loadu_ps(src_c); __m128 src2 = _mm_loadu_ps(src_c + col); __m128 src3 = _mm_loadu_ps(src_c + 2 * col); @@ -421,7 +445,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t : : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); -#elif ENABLE_X86_64_SSE +#elif ENABLE_SSE /* 8x4 row-major to col-major */ __m128 src1 = _mm_loadu_ps(src_c); __m128 src2 = _mm_loadu_ps(src_c + col); @@ -478,6 +502,145 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t return; } +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { + size_t row16 = row / C16NUM * C16NUM; + size_t col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + return; +} + +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { + size_t totalRow = UP_ROUND(row, C6NUM); + size_t row6 = row / C6NUM * C6NUM; + size_t col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row6; ri += C6NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + + /* 6x8 row-major to col-major */ +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + col); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * col); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * col); + __m256 src4 = 
_mm256_loadu_ps(src_c + 4 * col); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * col); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + 4, res1); + _mm_storeu_ps(dst_c + 8, res2); + _mm_storeu_ps(dst_c + 12, res3); + _mm_storeu_ps(dst_c + 16, res4); + _mm_storeu_ps(dst_c + 20, res5); + _mm_storeu_ps(dst_c + 24, res6); + _mm_storeu_ps(dst_c + 28, res7); + _mm_storeu_ps(dst_c + 32, res8); + _mm_storeu_ps(dst_c + 36, res9); + _mm_storeu_ps(dst_c + 40, res10); + _mm_storeu_ps(dst_c + 44, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (size_t i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C6NUM * col; + dst_r += C6NUM * col; + } + + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + for (; ri < totalRow; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } + return; +} + void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) { size_t row8 = row / C4NUM * C4NUM; size_t col4 = col / C4NUM * C4NUM; @@ -519,7 +682,7 @@ void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t : : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) : "r10", "r12", "q0", "q1", "q2", "q3"); -#elif ENABLE_X86_64_SSE +#elif ENABLE_SSE __m128 src1 = _mm_loadu_ps(src_c); __m128 src2 = _mm_loadu_ps(src_c + col); __m128 src3 = _mm_loadu_ps(src_c + 2 * col); @@ -630,6 +793,34 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A return; } +#ifdef ENABLE_AVX +#ifdef WIN32 +void MatMul6x16(const float *a, const float 
*b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride, int out_type) { + if (out_type == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r6div = r / C6NUM, r6mod = r % C6NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod; + size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod; + value = value + a[ai] * b[bi]; + } + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; + } + } + } + return; +} +#endif +#endif + void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, int col, int stride, int out_type) { if (out_type == OutType_C8) { @@ -670,7 +861,19 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT } else { MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); } -#elif ENABLE_X86_64_SSE +#elif ENABLE_AVX + if (out_type == OutType_Nhwc) { +#ifdef WIN32 + MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type); +#else + MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); +#endif + } else if (out_type == OutType_C8) { + MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else { + MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } +#elif ENABLE_SSE if (out_type == OutType_C8) { MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); } else { diff --git a/mindspore/lite/nnacl/fp32/matmul_fp32.h b/mindspore/lite/nnacl/fp32/matmul_fp32.h index b9c5dba47c..76db304915 100644 --- a/mindspore/lite/nnacl/fp32/matmul_fp32.h +++ b/mindspore/lite/nnacl/fp32/matmul_fp32.h @@ -31,11 +31,15 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col); void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col); #ifdef ENABLE_ARM void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); #endif @@ -49,11 +53,16 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi int col, int stride, size_t writeNhwc, size_t WriteWino); void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int 
depth, int row, int col, int stride, int write_mode);
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_SSE
+#include <x86intrin.h>
 void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                       int col, int stride, size_t writeNhwc, size_t WriteWino);
 void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                          int row, int col, int stride, int write_mode);
+#ifdef ENABLE_AVX
+void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, int stride, int write_mode);
+#endif
 #endif

 #ifdef ENABLE_NNACL_INFER_SHAPE
diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h
index 22a10a9597..6b0e2ad8ff 100644
--- a/mindspore/lite/nnacl/matmul_parameter.h
+++ b/mindspore/lite/nnacl/matmul_parameter.h
@@ -42,6 +42,7 @@ typedef struct MatMulParameter {
   int row_;
   int col_;
   int row_4_;
+  int row_6_;
   int row_8_;
   int row_12_;
   int row_16_;
diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.c b/mindspore/lite/nnacl/minimal_filtering_generator.c
index f5d8a5e077..acb8dbdce4 100644
--- a/mindspore/lite/nnacl/minimal_filtering_generator.c
+++ b/mindspore/lite/nnacl/minimal_filtering_generator.c
@@ -125,7 +125,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) {
   return NNACL_OK;
 }

-#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
+#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
 void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
                             int in_channel, int c4_channel) {
   int cnt = 0;
@@ -228,7 +228,7 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma
   return NNACL_OK;
 }

-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
 void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
                        const float *bias, int m, int k, int n) {
   int count = 0;
diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.h b/mindspore/lite/nnacl/minimal_filtering_generator.h
index 385bf244ee..46f4357b11 100644
--- a/mindspore/lite/nnacl/minimal_filtering_generator.h
+++ b/mindspore/lite/nnacl/minimal_filtering_generator.h
@@ -52,7 +52,7 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *
 int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt,
                             int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack);

-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
 void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
                        const float *bias, int m, int k, int n);
 #endif
diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h
index 6ae3042b17..e0fc630394 100644
--- a/mindspore/lite/nnacl/op_base.h
+++ b/mindspore/lite/nnacl/op_base.h
@@ -21,8 +21,8 @@
 #include <arm_neon.h>
 #endif

-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #endif

 #include <stdint.h>
@@ -31,6 +31,7 @@

 #define C2NUM 2
 #define C4NUM 4
+#define C6NUM 6
 #define C8NUM 8
 #define C12NUM 12
 #define C16NUM 16
@@ -91,7 +92,7 @@ typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6,
 #define MS_MAXQ_F32 vmaxq_f32
 #define MS_MINQ_F32 vminq_f32
 #define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2)
-#elif defined(ENABLE_X86_64_SSE)
+#elif defined(ENABLE_SSE)
 #define MS_FLOAT32X4 __m128
 #define MS_LDQ_F32 _mm_loadu_ps
 #define MS_ADDQ_F32 _mm_add_ps
diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c
index c4434ae980..245fbb11ca 100644
--- a/mindspore/lite/nnacl/pack.c
+++ b/mindspore/lite/nnacl/pack.c
@@ -756,7 +756,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
   return;
 }

-#ifndef ENABLE_X86_64_SSE
+#ifndef ENABLE_SSE
 void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
   int hw8 = plane / C8NUM * C8NUM;
   int c8 = channel / C8NUM * C8NUM;
diff --git a/mindspore/lite/nnacl/winograd_utils.c b/mindspore/lite/nnacl/winograd_utils.c
index fda1d925be..28f89e48fa 100644
--- a/mindspore/lite/nnacl/winograd_utils.c
+++ b/mindspore/lite/nnacl/winograd_utils.c
@@ -79,7 +79,7 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo
                                int src_step, int dst_step, int in_unit) {
   int len = in_unit * in_unit;
   if (len > MAX_LEN) return;
-#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
   MS_FLOAT32X4 src[MAX_LEN];
   MS_FLOAT32X4 t[MAX_LEN];
   MS_FLOAT32X4 m[MAX_LEN];
diff --git a/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c b/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
index bd2bad1487..486f1bd87b 100644
--- a/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
+++ b/mindspore/lite/nnacl/x86_64_sse/DepthwiseFp32_Sse.c
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #include "nnacl/fp32/conv_depthwise_fp32.h"

 void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
diff --git a/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c b/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c
index 9b4ad28147..75d5a563d7 100644
--- a/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c
+++ b/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #include "nnacl/minimal_filtering_generator.h"
 #include "nnacl/op_base.h"

diff --git a/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c b/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c
index ea9bd43ba2..26f602610e 100644
--- a/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c
+++ b/mindspore/lite/nnacl/x86_64_sse/PackNHWCToNCHWFp32.c
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #include "nnacl/pack.h"
 #include "nnacl/int8/conv_int8.h"

diff --git a/mindspore/lite/nnacl/x86_64_sse/PosFuncBiasRelu.c b/mindspore/lite/nnacl/x86_64_sse/PosFuncBiasRelu.c
index ffc54c1988..3fbded1ff7 100644
--- a/mindspore/lite/nnacl/x86_64_sse/PosFuncBiasRelu.c
+++ b/mindspore/lite/nnacl/x86_64_sse/PosFuncBiasRelu.c
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #include "nnacl/fp32/common_func_fp32.h"

 void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
diff --git a/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c b/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c
index 03282532a7..04e4f2333c 100644
--- a/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c
+++ b/mindspore/lite/nnacl/x86_64_sse/WinogradTrans.c
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifdef ENABLE_X86_64_SSE
-#include <nmmintrin.h>
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
 #include "nnacl/fp32/common_func_fp32.h"

 void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc
index 261215d386..b71914f669 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc
@@ -60,6 +60,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
   matmul_param_->col_ = conv_param_->output_channel_;
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
+  matmul_param_->row_6_ = UP_ROUND(matmul_param_->row_, C6NUM);
   matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
   matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
   matmul_param_->act_type_ = conv_param_->act_type_;
@@ -71,8 +72,13 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
   auto input_channel = filter_tensor->Channel();
   auto output_channel = filter_tensor->Batch();

+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
   if (in_tensors_.size() == 3) {
-    int size = UP_ROUND(output_channel, C8NUM) * sizeof(float);
+    int size = UP_ROUND(output_channel, col_tile) * sizeof(float);
     int weight_size = output_channel * sizeof(float);
     bias_data_ = malloc(size);
     if (bias_data_ == nullptr) {
@@ -83,22 +89,29 @@
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }

-  int size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float);
-  int down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float);
+  int size = input_channel * UP_ROUND(output_channel, col_tile) * sizeof(float);
+  int down_size = input_channel * DOWN_DIV(output_channel, col_tile) * col_tile * sizeof(float);
   weight_ptr_ = reinterpret_cast<float *>(malloc(size));
   if (weight_ptr_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
     return RET_ERROR;
   }
   memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
+                      input_channel);
+#else
   RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
                      input_channel);
+#endif
   return RET_OK;
 }

 int Convolution1x1CPUKernel::InitConv1x1Param() {
   int hw_tile = C12NUM;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  hw_tile = C6NUM;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   hw_tile = C4NUM;
 #endif
   if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
@@ -106,9 +119,14 @@
     thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile));
     thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile;
   } else {
+#ifdef ENABLE_AVX
+    int col_tile = C16NUM;
+#else
+    int col_tile = C8NUM;
+#endif
     multi_thread_by_hw_ = false;
-    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
-    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
+    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile));
+    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile), thread_count_) * col_tile;
   }

   pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
@@ -175,7 +193,9 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {
   float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
   float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if ENABLE_AVX
+  RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
 #else
   RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
@@ -202,7 +222,10 @@ int Convolution1x1CPUKernel::Run() {
   auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
   auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  pack_input_ =
+    reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float)));
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   pack_input_ =
     reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
 #else
@@ -226,7 +249,9 @@
   if (multi_thread_by_hw_) {
     ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_);
   } else {
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+    RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
     RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
 #else
     RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
index a9ab73c564..f887d2faf4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc
@@ -42,7 +42,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
   conv_param_->input_channel_ = in_channel;
   conv_param_->output_channel_ = out_channel;
   int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
-  int oc_block_num = UP_ROUND(out_channel, C8NUM);
+#ifdef ENABLE_AVX
+  const int oc_block = C16NUM;
+#else
+  const int oc_block = C8NUM;
+#endif
+  int oc_block_num = UP_ROUND(out_channel, oc_block);
   int pack_weight_size = oc_block_num * in_channel * kernel_plane;

   auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c());
@@ -52,7 +57,11 @@
     return RET_ERROR;
   }
   memset(packed_weight_, 0, pack_weight_size * sizeof(float));
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
+#else
   RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
+#endif

   bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float)));
   if (bias_data_ == nullptr) {
@@ -72,7 +81,10 @@ int ConvolutionCPUKernel::InitTmpBuffer() {
   MS_ASSERT(ctx_->allocator != nullptr);

-#ifdef ENABLE_ARM32
+
+#ifdef ENABLE_AVX
+  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_;
+#elif ENABLE_ARM32 || ENABLE_SSE
   int unit_size =
     conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_;
 #else
   int unit_size =
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
index 533d889182..94ead04cef 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_fp32.cc
@@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
     return RET_OK;
   }

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
   MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
             tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
@@ -174,7 +174,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
     return RET_NULL_PTR;
   }

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   tmp_buffer_ =
     reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
 #else
@@ -186,7 +186,7 @@
     return RET_NULL_PTR;
   }

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   pack_input_ =
     reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
 #else
@@ -215,7 +215,7 @@ int DeConvolutionCPUKernel::Run() {
     input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
     output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
     RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
 #else
     RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
index abf5149080..3d1ac6e91d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
@@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() {
   fc_param_->col_ = out_tensors_.at(0)->shape().back();
   fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1);

+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
   fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
-  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
+  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
+  fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM);
   fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);

-  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
-  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
+  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
+  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);

 #ifdef ENABLE_ARM
   if (fc_param_->row_ == 1) {
@@ -75,7 +81,9 @@
     memcpy(bias_ptr_, in_tensors_[2]->MutableData(), 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
index abf5149080..3d1ac6e91d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection_fp32.cc
@@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() {
   fc_param_->col_ = out_tensors_.at(0)->shape().back();
   fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1);

+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
   fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
-  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
+  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
+  fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM);
   fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);

-  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
-  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
+  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
+  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);

 #ifdef ENABLE_ARM
   if (fc_param_->row_ == 1) {
@@ -75,7 +81,9 @@ int FullconnectionCPUKernel::ReSize() {
     memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float));
   }

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_;
 #else
   int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_;
@@ -120,7 +128,9 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {
     return;
   }

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
 #else
   RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
@@ -132,8 +142,11 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
     memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float));
     return;
   }
-
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
+#else
   RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
+#endif
 }

 int FcFp32MatmulRun(void *cdata, int task_id) {
@@ -147,14 +160,19 @@ int FcFp32MatmulRun(void *cdata, int task_id) {
 }

 int FullconnectionCPUKernel::DoMatmul(int task_id) {
-  int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
+  int cur_oc = MSMIN(thread_stride_ * col_tile, fc_param_->col_ - task_id * thread_stride_ * col_tile);
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  auto b = b_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
-  auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * C8NUM;
-  auto c = c_ptr_ + task_id * thread_stride_ * C8NUM;
+  auto b = b_ptr_ + task_id * thread_stride_ * col_tile * fc_param_->deep_;
+  auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile;
+  auto c = c_ptr_ + task_id * thread_stride_ * col_tile;
   if (is_vector_input_) {
     MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc);
   } else {
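DoMatmul above gives each task thread_stride_ blocks of col_tile output columns and offsets the packed weights, bias, and output by the same column index, so every task writes a disjoint slice. A minimal model of that slicing in plain C (the TaskSlice struct and function name are hypothetical):

/* Minimal model of the per-task column slicing used by DoMatmul/RunImpl.
 * b_ptr/bias_ptr/c_ptr stand in for b_ptr_/bias_ptr_/c_ptr_. */
typedef struct {
  const float *b;    /* packed weights for this task's columns */
  const float *bias; /* bias slice, or NULL when there is no bias */
  float *c;          /* output slice */
  int cur_oc;        /* real (unpadded) columns in the slice; <= 0 means idle */
} TaskSlice;

TaskSlice SliceForTask(const float *b_ptr, const float *bias_ptr, float *c_ptr,
                       int deep, int col, int col_tile, int thread_stride, int task_id) {
  TaskSlice s;
  int offset = task_id * thread_stride * col_tile; /* first column of this task */
  int rest = col - offset;                         /* caller skips the task if <= 0 */
  s.cur_oc = rest < thread_stride * col_tile ? rest : thread_stride * col_tile;
  s.b = b_ptr + offset * deep; /* packed B advances by deep floats per column */
  s.bias = bias_ptr ? bias_ptr + offset : 0;
  s.c = c_ptr + offset;
  return s;
}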
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
index f4637b402f..410c0f6e67 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32.cc
@@ -75,9 +75,12 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
 #endif
   params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
   params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
+  params_->row_6_ = UP_ROUND(params_->row_, C6NUM);
   params_->row_12_ = UP_ROUND(params_->row_, C12NUM);

-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  int row_tmp = is_vector_a_ ? 1 : params_->row_6_;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   int row_tmp = is_vector_a_ ? 1 : params_->row_4_;
 #else
   int row_tmp = is_vector_a_ ? 1 : params_->row_12_;
@@ -106,9 +109,14 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
   for (size_t i = 0; i < b_shape.size() - 2; ++i) {
     batch *= b_shape[i];
   }
+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
   params_->batch = batch;
   params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
-  params_->col_8_ = UP_ROUND(params_->col_, 8);
+  params_->col_8_ = UP_ROUND(params_->col_, col_tile);
   params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];

   int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
@@ -123,8 +131,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
     return RET_MEMORY_FAILED;
   }

-  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
-  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_);
   return RET_OK;
 }

@@ -134,7 +142,12 @@ int MatmulCPUKernel::InitBias() {
   params_->col_ = params_->b_const_
                     ? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1))
                     : (c_shape.at(c_shape.size() - 1));
-  params_->col_8_ = UP_ROUND(params_->col_, 8);
+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
+  params_->col_8_ = UP_ROUND(params_->col_, col_tile);
   auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
   if (bias_ptr_ == nullptr) {
     bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
@@ -171,7 +184,14 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {
   for (int i = 0; i < params_->batch; i++) {
     const float *src = src_ptr + i * params_->deep_ * params_->row_;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+    float *dst = dst_ptr + i * params_->deep_ * params_->row_6_;
+    if (params_->a_transpose_) {
+      RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
+    } else {
+      RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
+    }
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
     float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
     if (params_->a_transpose_) {
       RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
@@ -207,11 +227,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
   for (int i = 0; i < params_->batch; i++) {
     const float *src = src_ptr + i * params_->deep_ * params_->col_;
     float *dst = dst_ptr + i * params_->deep_ * params_->col_8_;
+#ifdef ENABLE_AVX
+    if (params_->b_transpose_) {
+      RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
+    } else {
+      RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
+    }
+#else
     if (params_->b_transpose_) {
       RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
     } else {
       RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
     }
+#endif
   }
   return;
 }
@@ -247,13 +275,18 @@ int MatmulCPUKernel::Init() {
 }

 int MatmulCPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
+  int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile);
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  auto b = cur_b_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
-  auto c = cur_c_ptr_ + task_id * thread_stride_ * C8NUM;
-  auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * C8NUM : NULL;
+  auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_;
+  auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile;
+  auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL;
   MS_ASSERT(cur_a_ptr_);
   MS_ASSERT(b);
   MS_ASSERT(c);
@@ -323,7 +356,9 @@ int MatmulCPUKernel::Run() {
       cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
       cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
     } else {
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+      cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
       cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_;
 #else
       cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;
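On the AVX path the packed A buffer is laid out batch-major with a stride of row_6_ * deep_ floats per batch, which is what cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_ steps through above. A small worked sizing example (shapes made up; UP_ROUND mirrors nnacl/op_base.h):

#include <stdlib.h>

#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))

int main(void) {
  int batch = 3, row = 50, deep = 20;  /* illustrative shapes */
  int row_6 = UP_ROUND(row, 6);        /* 54: rows padded to the 6-row tile */
  /* one packed-A slab per batch, mirroring a_ptr_ + i * row_6 * deep */
  float *a_pack = (float *)calloc((size_t)batch * row_6 * deep, sizeof(float));
  if (a_pack == NULL) {
    return 1;
  }
  float *batch2 = a_pack + 2 * row_6 * deep; /* start of the third batch's slab */
  (void)batch2;
  free(a_pack);
  return 0;
}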
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 0f070bd21e..3dead6c5a4 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -75,6 +75,16 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
     )
 endif()

+if ("${X86_64_SIMD}" STREQUAL "avx")
+    file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
+            ${LITE_DIR}/nnacl/assembly/avx/*.S)
+    set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
+    set(KERNEL_OP_SRC
+        ${KERNEL_OP_SRC}
+        ${TEST_ASSEMBLY_SRC}
+    )
+endif()
+
 ### gpu kernel
 if (SUPPORT_GPU)
     file(GLOB GPU_KERNEL_OP_SRC