From cf9d13b24e05964a763298041bf9b6945b4d553a Mon Sep 17 00:00:00 2001 From: lixian Date: Mon, 28 Sep 2020 11:27:41 +0800 Subject: [PATCH] optimization for fp32 matmul kernel on arm64 --- mindspore/lite/internal/CMakeLists.txt | 3 +- .../assembly/arm32/IndirectGemmFp32_8x4.S | 302 ---- .../lite/nnacl/assembly/arm32/MatmulFp32.S | 368 ++++ .../assembly/arm64/IndirectGemmFp32_8x8.S | 730 -------- .../lite/nnacl/assembly/arm64/MatmulFp32.S | 597 ++++--- .../lite/nnacl/assembly/arm64/MatmulFp32Opt.S | 1516 +++++++++-------- .../assembly/arm64/MatmulFp32OptRemain.S | 880 ++++++++-- mindspore/lite/nnacl/common_func.c | 149 -- mindspore/lite/nnacl/common_func.h | 12 - mindspore/lite/nnacl/fp32/matmul.c | 15 +- mindspore/lite/nnacl/fp32/matmul.h | 9 +- 11 files changed, 2307 insertions(+), 2274 deletions(-) delete mode 100644 mindspore/lite/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S create mode 100644 mindspore/lite/nnacl/assembly/arm32/MatmulFp32.S delete mode 100644 mindspore/lite/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S diff --git a/mindspore/lite/internal/CMakeLists.txt b/mindspore/lite/internal/CMakeLists.txt index 6237f6c369..cb8e6c076c 100644 --- a/mindspore/lite/internal/CMakeLists.txt +++ b/mindspore/lite/internal/CMakeLists.txt @@ -39,7 +39,8 @@ if (PLATFORM_ARM64) # assembly file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/MatmulFp32OptRemain.S - ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/MatmulFp32Opt.S) + ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/MatmulFp32Opt.S + ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/assembly/arm64/MatmulFp32.S) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) add_library(mslite_internal SHARED ${CCSRC} ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) diff --git a/mindspore/lite/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S b/mindspore/lite/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S deleted file mode 100644 index 215178d35a..0000000000 --- a/mindspore/lite/nnacl/assembly/arm32/IndirectGemmFp32_8x4.S +++ /dev/null @@ -1,302 +0,0 @@ -#ifdef __arm__ -#ifndef __aarch64__ - -.text -.align 5 -.global IndirectGemmFp32_8x4 -#ifndef __APPLE__ -.type IndirectGemmFp32_8x4, %function -#endif - -// void IndirectGemmFp32_8x4(float *output, float *input, float *weight, float *bias, -// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); -// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset -// r8:mode, r10: writeMode, r10: relu, r10:relu6 -// mode = 0 for general convolution, where one conv unit is a row -// mode = 1 for winograd/common gemm, where the total channels of one input is a row -IndirectGemmFp32_8x4: - - .macro INIT_BIAS - veor q8, q8, q8 - cmp r3, #0 - beq InitBias - vld1.32 {q8}, [r3] - InitBias: - vmov q9, q8 - vmov q10, q8 - vmov q11, q8 - vmov q12, q8 - vmov q13, q8 - vmov q14, q8 - vmov q15, q8 - .endm - - // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" - // according to https://stackoverflow.com/questions/53625807 - // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway - // clang's rule seems more simple, though there are no subroutine calls here - // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf - push {r4-r8, r10, r11, lr} - vpush {q4-q7} - add sp, sp, #96 - - ldr r4, [sp] - ldr r5, [sp, #4] - ldr r6, [sp, #8] - ldr r7, [sp, #12] - ldr r8, [sp, #16] - - cmp r8, #0 - bne LoopOc - // step is one for common convolution, where ic8 should multiply by kernel size - // step is (a+b-1) for F(a,b) in winograd - mul r5, r4, r5 - mov r4, #1 - - LoopOc: - mov r8, r4 - mov r12, r1 - - LoopKsize: - - mov r11, r0 - INIT_BIAS - - // load input for output 1-2 - vld1.32 {q0, q1}, [r12]! - vld1.32 {q2, q3}, [r12]! - // load weight - vld1.32 {q4, q5}, [r2]! - // step for output 1-2 - vmla.f32 q8, q4, d0[0] - vmla.f32 q9, q4, d2[0] - vmla.f32 q8, q5, d0[1] - vmla.f32 q9, q5, d2[1] - vld1.32 {q6, q7}, [r2]! - - subs r10, r5, #1 - beq LoopIcEnd - - LoopIc: - vmla.f32 q8, q6, d1[0] - vmla.f32 q9, q6, d3[0] - vmla.f32 q8, q7, d1[1] - vmla.f32 q9, q7, d3[1] - vmla.f32 q10, q4, d4[0] - vmla.f32 q11, q4, d6[0] - vmla.f32 q10, q5, d4[1] - vmla.f32 q11, q5, d6[1] - vld1.s32 {q0, q1}, [r12]! - vmla.f32 q10, q6, d5[0] - vmla.f32 q11, q6, d7[0] - vmla.f32 q10, q7, d5[1] - vmla.f32 q11, q7, d7[1] - vld1.s32 {q2, q3}, [r12]! - vmla.f32 q12, q4, d0[0] - vmla.f32 q13, q4, d2[0] - vmla.f32 q12, q5, d0[1] - vmla.f32 q13, q5, d2[1] - vmla.f32 q14, q4, d4[0] - vmla.f32 q15, q4, d6[0] - vmla.f32 q14, q5, d4[1] - vmla.f32 q15, q5, d6[1] - vld1.s32 {q4, q5}, [r2]! - vmla.f32 q12, q6, d1[0] - vmla.f32 q13, q6, d3[0] - vmla.f32 q12, q7, d1[1] - vmla.f32 q13, q7, d3[1] - vld1.s32 {q0, q1}, [r12]! - vmla.f32 q14, q6, d5[0] - vmla.f32 q15, q6, d7[0] - vmla.f32 q14, q7, d5[1] - vmla.f32 q15, q7, d7[1] - vld1.s32 {q6, q7}, [r2]! - vmla.f32 q8, q4, d0[0] - vmla.f32 q9, q4, d2[0] - vmla.f32 q8, q5, d0[1] - vmla.f32 q9, q5, d2[1] - vld1.s32 {q2, q3}, [r12]! - - subs r10, r10, #1 - bne LoopIc - - LoopIcEnd: - vmla.f32 q8, q6, d1[0] - vmla.f32 q9, q6, d3[0] - vmla.f32 q8, q7, d1[1] - vmla.f32 q9, q7, d3[1] - vmla.f32 q10, q4, d4[0] - vmla.f32 q11, q4, d6[0] - vmla.f32 q10, q5, d4[1] - vmla.f32 q11, q5, d6[1] - vld1.s32 {q0, q1}, [r12]! - vmla.f32 q10, q6, d5[0] - vmla.f32 q11, q6, d7[0] - vmla.f32 q10, q7, d5[1] - vmla.f32 q11, q7, d7[1] - vld1.s32 {q2, q3}, [r12]! - vmla.f32 q12, q4, d0[0] - vmla.f32 q13, q4, d2[0] - vmla.f32 q12, q5, d0[1] - vmla.f32 q13, q5, d2[1] - vmla.f32 q14, q4, d4[0] - vmla.f32 q15, q4, d6[0] - vmla.f32 q14, q5, d4[1] - vmla.f32 q15, q5, d6[1] - vmla.f32 q12, q6, d1[0] - vmla.f32 q13, q6, d3[0] - vmla.f32 q12, q7, d1[1] - vmla.f32 q13, q7, d3[1] - vmla.f32 q14, q6, d5[0] - vmla.f32 q15, q6, d7[0] - vmla.f32 q14, q7, d5[1] - vmla.f32 q15, q7, d7[1] - - ldr r10, [sp, #28] - cmp r10, #0 - bne Relu6 - ldr r10, [sp, #24] - cmp r10, #0 - bne Relu - b WriteStart - Relu6: - vmov.i32 q7, #6 - vcvt.f32.s32 q7, q7 - vmin.f32 q8, q8, q7 - vmin.f32 q9, q9, q7 - vmin.f32 q10, q10, q7 - vmin.f32 q11, q11, q7 - vmin.f32 q12, q12, q7 - vmin.f32 q13, q13, q7 - vmin.f32 q14, q14, q7 - vmin.f32 q15, q15, q7 - Relu: - veor q7, q7, q7 - vmax.f32 q8, q8, q7 - vmax.f32 q9, q9, q7 - vmax.f32 q10, q10, q7 - vmax.f32 q11, q11, q7 - vmax.f32 q12, q12, q7 - vmax.f32 q13, q13, q7 - vmax.f32 q14, q14, q7 - vmax.f32 q15, q15, q7 - - WriteStart: - ldr r10, [sp, #20] - cmp r10, #0 - bne Write4 - cmp r6, #1 - beq Write1 - cmp r6, #2 - beq Write2 - cmp r6, #3 - beq Write3 - b Write4 - Write1: - vst1.32 d16[0], [r11] - add r11, r11, r7 - vst1.32 d18[0], [r11] - add r11, r11, r7 - vst1.32 d20[0], [r11] - add r11, r11, r7 - vst1.32 d22[0], [r11] - add r11, r11, r7 - vst1.32 d24[0], [r11] - add r11, r11, r7 - vst1.32 d26[0], [r11] - add r11, r11, r7 - vst1.32 d28[0], [r11] - add r11, r11, r7 - vst1.32 d30[0], [r11] - add r11, r11, r7 - add r0, r0, #4 - b WriteEnd - Write2: - vst1.32 d16, [r11] - add r11, r11, r7 - vst1.32 d18, [r11] - add r11, r11, r7 - vst1.32 d20, [r11] - add r11, r11, r7 - vst1.32 d22, [r11] - add r11, r11, r7 - vst1.32 d24, [r11] - add r11, r11, r7 - vst1.32 d26, [r11] - add r11, r11, r7 - vst1.32 d28, [r11] - add r11, r11, r7 - vst1.32 d30, [r11] - add r11, r11, r7 - add r0, r0, #8 - b WriteEnd - Write3: - add lr, r11, #8 - vst1.32 d16, [r11] - add r11, r11, r7 - vst1.32 d17[0], [lr] - add lr, lr, r7 - vst1.32 d18, [r11] - add r11, r11, r7 - vst1.32 d19[0], [lr] - add lr, lr, r7 - vst1.32 d20, [r11] - add r11, r11, r7 - vst1.32 d21[0], [lr] - add lr, lr, r7 - vst1.32 d22, [r11] - add r11, r11, r7 - vst1.32 d23[0], [lr] - add lr, lr, r7 - vst1.32 d24, [r11] - add r11, r11, r7 - vst1.32 d25[0], [lr] - add lr, lr, r7 - vst1.32 d26, [r11] - add r11, r11, r7 - vst1.32 d27[0], [lr] - add lr, lr, r7 - vst1.32 d28, [r11] - add r11, r11, r7 - vst1.32 d29[0], [lr] - add lr, lr, r7 - vst1.32 d30, [r11] - add r11, r11, r7 - vst1.32 d31[0], [lr] - add lr, lr, r7 - add r0, r0, #12 - b WriteEnd - Write4: - // prefetching is not prefered while writing results in spite of cache missings - // you could try pld - // there are almost no benefits observed though - vst1.32 {q8}, [r11], r7 - vst1.32 {q9}, [r11], r7 - vst1.32 {q10}, [r11], r7 - vst1.32 {q11}, [r11], r7 - vst1.32 {q12}, [r11], r7 - vst1.32 {q13}, [r11], r7 - vst1.32 {q14}, [r11], r7 - vst1.32 {q15}, [r11], r7 - add r0, r0, #16 - - WriteEnd: - - subs r8, r8, #1 - bne LoopKsize - - cmp r6, #4 - ble LoopOcEnd - sub r6, r6, #4 - cmp r3, #0 - beq NoStepFowrard - add r3, r3, #16 - NoStepFowrard: - b LoopOc - -LoopOcEnd: - sub sp, sp, #96 - vpop {q4-q7} - pop {r4-r8, r10, r11, pc} -#endif -#endif diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulFp32.S b/mindspore/lite/nnacl/assembly/arm32/MatmulFp32.S new file mode 100644 index 0000000000..76efe3cda4 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/MatmulFp32.S @@ -0,0 +1,368 @@ +#ifdef ENABLE_ARM32 + .text + .align 5 + .global MatmulFloatNeon32 +#ifndef __APPLE__ + .type MatmulFloatNeon32, %function +#endif + +// void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +MatmulFloatNeon32: + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #32 // sizeof(float) * 8 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 8 * depth + ldr lr, [sp, #24] + cmp lr, #0 + beq NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopCol: + ldr r6, [sp, #8] // reload lhs row + ldr r0, [sp, #-48] // reload lhs ptr + ldr r2, [sp, #-40] // reload dst ptr + + LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r5, [sp, #4] // reload depth + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3] + sub r3, r3, #16 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #2 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #24] + cmp lr, #0 + bne WriteWino + ldr lr, [sp, #20] + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + b WriteEnd + Write2: + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + b WriteEnd + Write3: + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + b WriteEnd + Write4: + vst1.32 q8, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 q10, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 q12, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 q14, [r2] + add r2, r2, r8 + b WriteEnd + Write5: + add r4, r2, #16 + vst1.32 q8, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q10, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q12, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q14, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + b WriteEnd + Write6: + add r4, r2, #16 + vst1.32 q8, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q10, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q12, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 q14, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + b WriteEnd + Write7: + add lr, r2, #24 + add r4, r2, #16 + vst1.32 q8, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 q10, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 q12, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 q14, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + b WriteEnd + WriteC8: + vst1.32 {q8, q9}, [r2]! + vst1.32 {q10, q11}, [r2]! + vst1.32 {q12, q13}, [r2]! + vst1.32 {q14, q15}, [r2]! + str r2, [sp, #-40] + b WriteEnd + WriteWino: + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + add r2, r2, r11 + b WriteEnd + Write8: + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + + WriteEnd: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + + LoopRowEnd: + ldr r1, [sp, #-44] + add r1, r1, r12 // rhs ptr + stride + str r1, [sp, #-44] + cmp r3, #0 + beq NoBiasStep + add r3, r3, #32 // bias ptr + stride + NoBiasStep: + ldr lr, [sp, #24] + cmp lr, #0 + bne WinoDstStep + ldr lr, [sp, #20] + cmp lr, #0 + beq NoDstStep + ldr r2, [sp, #-40] + add r2, r2, #32 // dst ptr + stride + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + ldr r2, [sp, #-40] + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + +LoopColEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S b/mindspore/lite/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S deleted file mode 100644 index 483dfac09b..0000000000 --- a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmFp32_8x8.S +++ /dev/null @@ -1,730 +0,0 @@ -#ifdef __aarch64__ - -.text -.align 5 -.global IndirectGemmFp32_8x8 -#ifndef __APPLE__ -.type IndirectGemmFp32_8x8, %function -#endif - -// void IndirectGemmFp32_8x8(float *output, float *input, float *weight, float *bias, -// size_t kSize, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); -// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset -// x8:mode, x9: writeMode, x10: relu, x11:relu6 -// mode = 0 for general convolution, where one conv unit is a row -// mode = 1 for winograd/common gemm, where the total channels of one input is a row -IndirectGemmFp32_8x8: - - .macro INIT_BIAS - dup v16.4s, wzr - dup v17.4s, wzr - cbz x3, InitBias - ld1 {v16.4s, v17.4s}, [x3] - InitBias: - mov v18.16b, v16.16b - mov v19.16b, v17.16b - mov v20.16b, v16.16b - mov v21.16b, v17.16b - mov v22.16b, v16.16b - mov v23.16b, v17.16b - mov v24.16b, v16.16b - mov v25.16b, v17.16b - mov v26.16b, v16.16b - mov v27.16b, v17.16b - mov v28.16b, v16.16b - mov v29.16b, v17.16b - mov v30.16b, v16.16b - mov v31.16b, v17.16b - .endm - - .macro INIT_BIAS_HALF - dup v16.4s, wzr - cbz x3, InitBiasHalf - ld1 {v16.4s}, [x3] - InitBiasHalf: - mov v18.16b, v16.16b - mov v20.16b, v16.16b - mov v22.16b, v16.16b - mov v24.16b, v16.16b - mov v26.16b, v16.16b - mov v28.16b, v16.16b - mov v30.16b, v16.16b - .endm - - // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to - // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers - // r19 ~ r29 should be also preserved - // whereas our coding style do not permit such amount of parameters - sub sp, sp, #128 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - - ldr x8, [sp, #0] - ldr x9, [sp, #8] - ldr x10, [sp, #16] - ldr x11, [sp, #24] - - cbnz x8, NoStepShuffle - // step is one for common convolution, where ic8 should multiply by kernel size - // step is (a+b-1) for F(a,b) in winograd - mul x5, x4, x5 - mov x4, #1 - -NoStepShuffle: - // x8 is used to store offset now - // only useful for WriteC4 - mov x8, #16 - mul x8, x8, x4 - -IndirectGemmStart: - - cmp x6, #4 - ble LoopOcHalf - - LoopOc: - - mov x14, x4 - mov x12, x1 - - LoopKsize: - - mov x15, x0 - INIT_BIAS - - // load input for output 1-2 - ld1 {v0.4s, v1.4s}, [x12], #32 - // load weight - ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 - // step for output 1-2 - fmla v16.4s, v8.4s, v0.s[0] - fmla v17.4s, v9.4s, v0.s[0] - fmla v18.4s, v8.4s, v1.s[0] - fmla v19.4s, v9.4s, v1.s[0] - // load input for output 3-4 - ld1 {v2.4s, v3.4s}, [x12], #32 - // another step for output 1-2 - fmla v16.4s, v10.4s, v0.s[1] - fmla v17.4s, v11.4s, v0.s[1] - fmla v18.4s, v10.4s, v1.s[1] - fmla v19.4s, v11.4s, v1.s[1] - // load input for output 5-8 - // input cache should be refreshed after loading - // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 - // step for output 3-8 - fmla v20.4s, v8.4s, v2.s[0] - fmla v21.4s, v9.4s, v2.s[0] - fmla v22.4s, v8.4s, v3.s[0] - fmla v23.4s, v9.4s, v3.s[0] - - subs x13, x5, #1 - beq LoopIcEnd - - LoopIc: - fmla v24.4s, v8.4s, v4.s[0] - fmla v25.4s, v9.4s, v4.s[0] - fmla v26.4s, v8.4s, v5.s[0] - fmla v27.4s, v9.4s, v5.s[0] - fmla v28.4s, v8.4s, v6.s[0] - fmla v29.4s, v9.4s, v6.s[0] - fmla v30.4s, v8.4s, v7.s[0] - fmla v31.4s, v9.4s, v7.s[0] - // load weight - ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 - // step for output 3-8 - fmla v20.4s, v10.4s, v2.s[1] - fmla v21.4s, v11.4s, v2.s[1] - fmla v22.4s, v10.4s, v3.s[1] - fmla v23.4s, v11.4s, v3.s[1] - fmla v24.4s, v10.4s, v4.s[1] - fmla v25.4s, v11.4s, v4.s[1] - fmla v26.4s, v10.4s, v5.s[1] - fmla v27.4s, v11.4s, v5.s[1] - fmla v28.4s, v10.4s, v6.s[1] - fmla v29.4s, v11.4s, v6.s[1] - fmla v30.4s, v10.4s, v7.s[1] - fmla v31.4s, v11.4s, v7.s[1] - // another step for output 1-8 - fmla v16.4s, v12.4s, v0.s[2] - fmla v17.4s, v13.4s, v0.s[2] - fmla v18.4s, v12.4s, v1.s[2] - fmla v19.4s, v13.4s, v1.s[2] - fmla v20.4s, v12.4s, v2.s[2] - fmla v21.4s, v13.4s, v2.s[2] - fmla v22.4s, v12.4s, v3.s[2] - fmla v23.4s, v13.4s, v3.s[2] - fmla v24.4s, v12.4s, v4.s[2] - fmla v25.4s, v13.4s, v4.s[2] - fmla v26.4s, v12.4s, v5.s[2] - fmla v27.4s, v13.4s, v5.s[2] - fmla v28.4s, v12.4s, v6.s[2] - fmla v29.4s, v13.4s, v6.s[2] - fmla v30.4s, v12.4s, v7.s[2] - fmla v31.4s, v13.4s, v7.s[2] - // load weight - ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 - // another step for output 1-8 - fmla v16.4s, v14.4s, v0.s[3] - fmla v17.4s, v15.4s, v0.s[3] - fmla v18.4s, v14.4s, v1.s[3] - fmla v19.4s, v15.4s, v1.s[3] - fmla v20.4s, v14.4s, v2.s[3] - fmla v21.4s, v15.4s, v2.s[3] - fmla v22.4s, v14.4s, v3.s[3] - fmla v23.4s, v15.4s, v3.s[3] - // load input for output 1-4 - ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 - fmla v24.4s, v14.4s, v4.s[3] - fmla v25.4s, v15.4s, v4.s[3] - fmla v26.4s, v14.4s, v5.s[3] - fmla v27.4s, v15.4s, v5.s[3] - fmla v28.4s, v14.4s, v6.s[3] - fmla v29.4s, v15.4s, v6.s[3] - fmla v30.4s, v14.4s, v7.s[3] - fmla v31.4s, v15.4s, v7.s[3] - // load input for output 5-8 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 - // step for output 1-8 - fmla v16.4s, v8.4s, v0.s[0] - fmla v17.4s, v9.4s, v0.s[0] - fmla v18.4s, v8.4s, v1.s[0] - fmla v19.4s, v9.4s, v1.s[0] - fmla v16.4s, v10.4s, v0.s[1] - fmla v17.4s, v11.4s, v0.s[1] - fmla v18.4s, v10.4s, v1.s[1] - fmla v19.4s, v11.4s, v1.s[1] - fmla v20.4s, v8.4s, v2.s[0] - fmla v21.4s, v9.4s, v2.s[0] - fmla v22.4s, v8.4s, v3.s[0] - fmla v23.4s, v9.4s, v3.s[0] - - subs x13, x13, #1 - bne LoopIc - - LoopIcEnd: - fmla v24.4s, v8.4s, v4.s[0] - fmla v25.4s, v9.4s, v4.s[0] - fmla v26.4s, v8.4s, v5.s[0] - fmla v27.4s, v9.4s, v5.s[0] - fmla v28.4s, v8.4s, v6.s[0] - fmla v29.4s, v9.4s, v6.s[0] - fmla v30.4s, v8.4s, v7.s[0] - fmla v31.4s, v9.4s, v7.s[0] - // load weight - ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 - // step for output 3-8 - fmla v20.4s, v10.4s, v2.s[1] - fmla v21.4s, v11.4s, v2.s[1] - fmla v22.4s, v10.4s, v3.s[1] - fmla v23.4s, v11.4s, v3.s[1] - fmla v24.4s, v10.4s, v4.s[1] - fmla v25.4s, v11.4s, v4.s[1] - fmla v26.4s, v10.4s, v5.s[1] - fmla v27.4s, v11.4s, v5.s[1] - fmla v28.4s, v10.4s, v6.s[1] - fmla v29.4s, v11.4s, v6.s[1] - fmla v30.4s, v10.4s, v7.s[1] - fmla v31.4s, v11.4s, v7.s[1] - // another step for output 1-8 - fmla v16.4s, v12.4s, v0.s[2] - fmla v17.4s, v13.4s, v0.s[2] - fmla v18.4s, v12.4s, v1.s[2] - fmla v19.4s, v13.4s, v1.s[2] - fmla v20.4s, v12.4s, v2.s[2] - fmla v21.4s, v13.4s, v2.s[2] - fmla v22.4s, v12.4s, v3.s[2] - fmla v23.4s, v13.4s, v3.s[2] - fmla v24.4s, v12.4s, v4.s[2] - fmla v25.4s, v13.4s, v4.s[2] - fmla v26.4s, v12.4s, v5.s[2] - fmla v27.4s, v13.4s, v5.s[2] - fmla v28.4s, v12.4s, v6.s[2] - fmla v29.4s, v13.4s, v6.s[2] - fmla v30.4s, v12.4s, v7.s[2] - fmla v31.4s, v13.4s, v7.s[2] - // another step for output 1-8 - fmla v16.4s, v14.4s, v0.s[3] - fmla v17.4s, v15.4s, v0.s[3] - fmla v18.4s, v14.4s, v1.s[3] - fmla v19.4s, v15.4s, v1.s[3] - fmla v20.4s, v14.4s, v2.s[3] - fmla v21.4s, v15.4s, v2.s[3] - fmla v22.4s, v14.4s, v3.s[3] - fmla v23.4s, v15.4s, v3.s[3] - fmla v24.4s, v14.4s, v4.s[3] - fmla v25.4s, v15.4s, v4.s[3] - fmla v26.4s, v14.4s, v5.s[3] - fmla v27.4s, v15.4s, v5.s[3] - fmla v28.4s, v14.4s, v6.s[3] - fmla v29.4s, v15.4s, v6.s[3] - fmla v30.4s, v14.4s, v7.s[3] - fmla v31.4s, v15.4s, v7.s[3] - // prefetching is not prefered while writing results in spite of cache missings - // you could try prfm pstl2strm - // there are almost no benefits observed though - cbnz x11, Relu6 - cbnz x10, Relu - b WriteStart - Relu6: - movi v1.4s, #6 - scvtf v1.4s, v1.4s - fmin v16.4s, v16.4s, v1.4s - fmin v17.4s, v17.4s, v1.4s - fmin v18.4s, v18.4s, v1.4s - fmin v19.4s, v19.4s, v1.4s - fmin v20.4s, v20.4s, v1.4s - fmin v21.4s, v21.4s, v1.4s - fmin v22.4s, v22.4s, v1.4s - fmin v23.4s, v23.4s, v1.4s - fmin v24.4s, v24.4s, v1.4s - fmin v25.4s, v25.4s, v1.4s - fmin v26.4s, v26.4s, v1.4s - fmin v27.4s, v27.4s, v1.4s - fmin v28.4s, v28.4s, v1.4s - fmin v29.4s, v29.4s, v1.4s - fmin v30.4s, v30.4s, v1.4s - fmin v31.4s, v31.4s, v1.4s - Relu: - dup v0.4s, wzr - fmax v16.4s, v16.4s, v0.4s - fmax v17.4s, v17.4s, v0.4s - fmax v18.4s, v18.4s, v0.4s - fmax v19.4s, v19.4s, v0.4s - fmax v20.4s, v20.4s, v0.4s - fmax v21.4s, v21.4s, v0.4s - fmax v22.4s, v22.4s, v0.4s - fmax v23.4s, v23.4s, v0.4s - fmax v24.4s, v24.4s, v0.4s - fmax v25.4s, v25.4s, v0.4s - fmax v26.4s, v26.4s, v0.4s - fmax v27.4s, v27.4s, v0.4s - fmax v28.4s, v28.4s, v0.4s - fmax v29.4s, v29.4s, v0.4s - fmax v30.4s, v30.4s, v0.4s - fmax v31.4s, v31.4s, v0.4s - - WriteStart: - cbnz x9, WriteC4 - cmp x6, #5 - beq Write5 - cmp x6, #6 - beq Write6 - cmp x6, #7 - beq Write7 - b Write8 - Write5: - add x17, x15, #16 - st1 {v16.4s}, [x15], x7 - str s17, [x17] - add x17, x17, x7 - st1 {v18.4s}, [x15], x7 - str s19, [x17] - add x17, x17, x7 - st1 {v20.4s}, [x15], x7 - str s21, [x17] - add x17, x17, x7 - st1 {v22.4s}, [x15], x7 - str s23, [x17] - add x17, x17, x7 - st1 {v24.4s}, [x15], x7 - str s25, [x17] - add x17, x17, x7 - st1 {v26.4s}, [x15], x7 - str s27, [x17] - add x17, x17, x7 - st1 {v28.4s}, [x15], x7 - str s29, [x17] - add x17, x17, x7 - st1 {v30.4s}, [x15] - str s31, [x17] - add x0, x0, #20 - b WriteEnd - Write6: - add x17, x15, #16 - st1 {v16.4s}, [x15], x7 - dup s16, v17.s[1] - stp s17, s16, [x17] - add x17, x17, x7 - st1 {v18.4s}, [x15], x7 - dup s18, v19.s[1] - stp s19, s18, [x17] - add x17, x17, x7 - st1 {v20.4s}, [x15], x7 - dup s20, v21.s[1] - stp s21, s20, [x17] - add x17, x17, x7 - st1 {v22.4s}, [x15], x7 - dup s22, v23.s[1] - stp s23, s22, [x17] - add x17, x17, x7 - st1 {v24.4s}, [x15], x7 - dup s24, v25.s[1] - stp s25, s24, [x17] - add x17, x17, x7 - st1 {v26.4s}, [x15], x7 - dup s26, v27.s[1] - stp s27, s26, [x17] - add x17, x17, x7 - st1 {v28.4s}, [x15], x7 - dup s28, v29.s[1] - stp s29, s28, [x17] - add x17, x17, x7 - st1 {v30.4s}, [x15] - dup s30, v31.s[1] - stp s31, s30, [x17] - add x0, x0, #24 - b WriteEnd - Write7: - add x17, x15, #16 - add x16, x15, #24 - st1 {v16.4s}, [x15], x7 - dup s16, v17.s[1] - stp s17, s16, [x17] - add x17, x17, x7 - st1 {v17.s}[2], [x16], x7 - st1 {v18.4s}, [x15], x7 - dup s18, v19.s[1] - stp s19, s18, [x17] - add x17, x17, x7 - st1 {v19.s}[2], [x16], x7 - st1 {v20.4s}, [x15], x7 - dup s20, v21.s[1] - stp s21, s20, [x17] - add x17, x17, x7 - st1 {v21.s}[2], [x16], x7 - st1 {v22.4s}, [x15], x7 - dup s22, v23.s[1] - stp s23, s22, [x17] - add x17, x17, x7 - st1 {v23.s}[2], [x16], x7 - st1 {v24.4s}, [x15], x7 - dup s24, v25.s[1] - stp s25, s24, [x17] - add x17, x17, x7 - st1 {v25.s}[2], [x16], x7 - st1 {v26.4s}, [x15], x7 - dup s26, v27.s[1] - stp s27, s26, [x17] - add x17, x17, x7 - st1 {v27.s}[2], [x16], x7 - st1 {v28.4s}, [x15], x7 - dup s28, v29.s[1] - stp s29, s28, [x17] - add x17, x17, x7 - st1 {v29.s}[2], [x16], x7 - st1 {v30.4s}, [x15], x7 - dup s30, v31.s[1] - stp s31, s30, [x17] - add x17, x17, x7 - st1 {v31.s}[2], [x16], x7 - add x0, x0, #28 - b WriteEnd - WriteC4: - st1 {v16.4s}, [x15], x7 - st1 {v18.4s}, [x15], x7 - st1 {v20.4s}, [x15], x7 - st1 {v22.4s}, [x15], x7 - st1 {v24.4s}, [x15], x7 - st1 {v26.4s}, [x15], x7 - st1 {v28.4s}, [x15], x7 - st1 {v30.4s}, [x15] - add x15, x8, x0 - st1 {v17.4s}, [x15], x7 - st1 {v19.4s}, [x15], x7 - st1 {v21.4s}, [x15], x7 - st1 {v23.4s}, [x15], x7 - st1 {v25.4s}, [x15], x7 - st1 {v27.4s}, [x15], x7 - st1 {v29.4s}, [x15], x7 - st1 {v31.4s}, [x15] - add x0, x0, #16 - b WriteEnd - Write8: - st1 {v16.4s, v17.4s}, [x15], x7 - st1 {v18.4s, v19.4s}, [x15], x7 - st1 {v20.4s, v21.4s}, [x15], x7 - st1 {v22.4s, v23.4s}, [x15], x7 - st1 {v24.4s, v25.4s}, [x15], x7 - st1 {v26.4s, v27.4s}, [x15], x7 - st1 {v28.4s, v29.4s}, [x15], x7 - st1 {v30.4s, v31.4s}, [x15] - add x0, x0, #32 - - WriteEnd: - - subs x14, x14, #1 - bne LoopKsize - - subs x6, x6, #8 - ble LoopOcEnd - cbz x9, NoStepC4Block - add x0, x0, x8 - NoStepC4Block: - cbz x3, NoStepForward - add x3, x3, #32 - NoStepForward: - cmp x6, #4 - bgt LoopOc - - LoopOcHalf: - mov x18, #32 - - mov x14, x4 - mov x12, x1 - - LoopKsizeHalf: - - mov x15, x0 - INIT_BIAS_HALF - - // load input for output 1-2 - ld1 {v0.4s, v1.4s}, [x12], #32 - // load weight - ld1 {v8.4s}, [x2], x18 - ld1 {v10.4s}, [x2], x18 - // step for output 1-2 - fmla v16.4s, v8.4s, v0.s[0] - fmla v18.4s, v8.4s, v1.s[0] - // load input for output 3-4 - ld1 {v2.4s, v3.4s}, [x12], #32 - // another step for output 1-2 - fmla v16.4s, v10.4s, v0.s[1] - fmla v18.4s, v10.4s, v1.s[1] - // load input for output 5-8 - // input cache should be refreshed after loading - // ATTENTION: advance is prefered, but advancing too much may lead to invalid prefetching - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 - // step for output 3-8 - fmla v20.4s, v8.4s, v2.s[0] - fmla v22.4s, v8.4s, v3.s[0] - - subs x13, x5, #1 - beq LoopIcEndHalf - - LoopIcHalf: - fmla v24.4s, v8.4s, v4.s[0] - fmla v26.4s, v8.4s, v5.s[0] - fmla v28.4s, v8.4s, v6.s[0] - fmla v30.4s, v8.4s, v7.s[0] - // load weight - ld1 {v12.4s}, [x2], x18 - // step for output 3-8 - fmla v20.4s, v10.4s, v2.s[1] - fmla v22.4s, v10.4s, v3.s[1] - // load weight - ld1 {v14.4s}, [x2], x18 - fmla v24.4s, v10.4s, v4.s[1] - fmla v26.4s, v10.4s, v5.s[1] - fmla v28.4s, v10.4s, v6.s[1] - fmla v30.4s, v10.4s, v7.s[1] - // another step for output 1-8 - fmla v16.4s, v12.4s, v0.s[2] - fmla v18.4s, v12.4s, v1.s[2] - fmla v20.4s, v12.4s, v2.s[2] - fmla v22.4s, v12.4s, v3.s[2] - fmla v24.4s, v12.4s, v4.s[2] - fmla v26.4s, v12.4s, v5.s[2] - fmla v28.4s, v12.4s, v6.s[2] - fmla v30.4s, v12.4s, v7.s[2] - // load weight - ld1 {v8.4s}, [x2], x18 - // another step for output 1-8 - fmla v16.4s, v14.4s, v0.s[3] - fmla v18.4s, v14.4s, v1.s[3] - // load weight - ld1 {v10.4s}, [x2], x18 - fmla v20.4s, v14.4s, v2.s[3] - fmla v22.4s, v14.4s, v3.s[3] - // load input for output 1-4 - ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 - fmla v24.4s, v14.4s, v4.s[3] - fmla v26.4s, v14.4s, v5.s[3] - fmla v28.4s, v14.4s, v6.s[3] - fmla v30.4s, v14.4s, v7.s[3] - // load input for output 5-8 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12], #64 - // step for output 1-8 - fmla v16.4s, v8.4s, v0.s[0] - fmla v18.4s, v8.4s, v1.s[0] - fmla v16.4s, v10.4s, v0.s[1] - fmla v18.4s, v10.4s, v1.s[1] - fmla v20.4s, v8.4s, v2.s[0] - fmla v22.4s, v8.4s, v3.s[0] - - subs x13, x13, #1 - bne LoopIcHalf - - LoopIcEndHalf: - fmla v24.4s, v8.4s, v4.s[0] - fmla v26.4s, v8.4s, v5.s[0] - fmla v28.4s, v8.4s, v6.s[0] - fmla v30.4s, v8.4s, v7.s[0] - // load weight - ld1 {v12.4s}, [x2], x18 - // step for output 3-8 - fmla v20.4s, v10.4s, v2.s[1] - fmla v22.4s, v10.4s, v3.s[1] - // load weight - ld1 {v14.4s}, [x2], x18 - fmla v24.4s, v10.4s, v4.s[1] - fmla v26.4s, v10.4s, v5.s[1] - fmla v28.4s, v10.4s, v6.s[1] - fmla v30.4s, v10.4s, v7.s[1] - // another step for output 1-8 - fmla v16.4s, v12.4s, v0.s[2] - fmla v18.4s, v12.4s, v1.s[2] - fmla v20.4s, v12.4s, v2.s[2] - fmla v22.4s, v12.4s, v3.s[2] - fmla v24.4s, v12.4s, v4.s[2] - fmla v26.4s, v12.4s, v5.s[2] - fmla v28.4s, v12.4s, v6.s[2] - fmla v30.4s, v12.4s, v7.s[2] - // another step for output 1-8 - fmla v16.4s, v14.4s, v0.s[3] - fmla v18.4s, v14.4s, v1.s[3] - fmla v20.4s, v14.4s, v2.s[3] - fmla v22.4s, v14.4s, v3.s[3] - fmla v24.4s, v14.4s, v4.s[3] - fmla v26.4s, v14.4s, v5.s[3] - fmla v28.4s, v14.4s, v6.s[3] - fmla v30.4s, v14.4s, v7.s[3] - - cbnz x11, Relu6Half - cbnz x10, ReluHalf - b WriteStartHalf - Relu6Half: - movi v1.4s, #6 - scvtf v1.4s, v1.4s - fmin v16.4s, v16.4s, v1.4s - fmin v18.4s, v18.4s, v1.4s - fmin v20.4s, v20.4s, v1.4s - fmin v22.4s, v22.4s, v1.4s - fmin v24.4s, v24.4s, v1.4s - fmin v26.4s, v26.4s, v1.4s - fmin v28.4s, v28.4s, v1.4s - fmin v30.4s, v30.4s, v1.4s - ReluHalf: - dup v0.4s, wzr - fmax v16.4s, v16.4s, v0.4s - fmax v18.4s, v18.4s, v0.4s - fmax v20.4s, v20.4s, v0.4s - fmax v22.4s, v22.4s, v0.4s - fmax v24.4s, v24.4s, v0.4s - fmax v26.4s, v26.4s, v0.4s - fmax v28.4s, v28.4s, v0.4s - fmax v30.4s, v30.4s, v0.4s - - WriteStartHalf: - cbnz x9, Write4 - cmp x6, #1 - beq Write1 - cmp x6, #2 - beq Write2 - cmp x6, #3 - beq Write3 - b Write4 - Write1: - str s16, [x15] - add x15, x15, x7 - str s18, [x15] - add x15, x15, x7 - str s20, [x15] - add x15, x15, x7 - str s22, [x15] - add x15, x15, x7 - str s24, [x15] - add x15, x15, x7 - str s26, [x15] - add x15, x15, x7 - str s28, [x15] - add x15, x15, x7 - str s30, [x15] - add x0, x0, #4 - b WriteEndHalf - Write2: - dup s17, v16.s[1] - stp s16, s17, [x15] - add x15, x15, x7 - dup s19, v18.s[1] - stp s18, s19, [x15] - add x15, x15, x7 - dup s21, v20.s[1] - stp s20, s21, [x15] - add x15, x15, x7 - dup s23, v22.s[1] - stp s22, s23, [x15] - add x15, x15, x7 - dup s25, v24.s[1] - stp s24, s25, [x15] - add x15, x15, x7 - dup s27, v26.s[1] - stp s26, s27, [x15] - add x15, x15, x7 - dup s29, v28.s[1] - stp s28, s29, [x15] - add x15, x15, x7 - dup s31, v30.s[1] - stp s30, s31, [x15] - add x0, x0, #8 - b WriteEndHalf - Write3: - add x17, x15, #8 - dup s17, v16.s[1] - stp s16, s17, [x15] - add x15, x15, x7 - st1 {v16.s}[2], [x17], x7 - dup s19, v18.s[1] - stp s18, s19, [x15] - add x15, x15, x7 - st1 {v18.s}[2], [x17], x7 - dup s21, v20.s[1] - stp s20, s21, [x15] - add x15, x15, x7 - st1 {v20.s}[2], [x17], x7 - dup s23, v22.s[1] - stp s22, s23, [x15] - add x15, x15, x7 - st1 {v22.s}[2], [x17], x7 - dup s25, v24.s[1] - stp s24, s25, [x15] - add x15, x15, x7 - st1 {v24.s}[2], [x17], x7 - dup s27, v26.s[1] - stp s26, s27, [x15] - add x15, x15, x7 - st1 {v26.s}[2], [x17], x7 - dup s29, v28.s[1] - stp s28, s29, [x15] - add x15, x15, x7 - st1 {v28.s}[2], [x17], x7 - dup s31, v30.s[1] - stp s30, s31, [x15] - st1 {v30.s}[2], [x17] - add x0, x0, #12 - b WriteEndHalf - Write4: - // prefetching is not prefered while writing results in spite of cache missings - // you could try prfm pstl2strm - // there are almost no benefits observed though - st1 {v16.4s}, [x15], x7 - st1 {v18.4s}, [x15], x7 - st1 {v20.4s}, [x15], x7 - st1 {v22.4s}, [x15], x7 - st1 {v24.4s}, [x15], x7 - st1 {v26.4s}, [x15], x7 - st1 {v28.4s}, [x15], x7 - st1 {v30.4s}, [x15] - add x0, x0, #16 - - WriteEndHalf: - - subs x14, x14, #1 - bne LoopKsizeHalf - -LoopOcEnd: - - sub sp, sp, #128 - ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - ret -#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32.S b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32.S index b2bbea889c..63d07cf80b 100644 --- a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32.S +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32.S @@ -7,7 +7,7 @@ #endif // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth -// int row, int col, int stride, bool write_nhwc) +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) // x0: a // x1: b // x2: c @@ -17,18 +17,27 @@ // w6: row // w7: col // w17: stride -// w13: writeC8 +// w13: c8_nhwc_c4 MatmulFloatNeon64: sub sp, sp, #128 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldr x9, [sp, #8] + ldr x14, [sp, #16] + mov w18, #32 // sizeof(float) * 8 mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth - mov x11, x3 // bias flag mov x18, #4 ldr x17, [sp] + cbz x14, NoWinoSteps + mul x8, x7, x17 + mov x11, #8 + mul x11, x11, x17 + mul x8, x8, x18 + mul x11, x11, x18 +NoWinoSteps: mul x17, x17, x18 L1: @@ -39,7 +48,14 @@ L1: L2: mov x16, x1 // reload rhs ptr mov w13, w5 // reload depth - mov x14, x3 // reload bias ptr + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr dup v16.4s, wzr dup v17.4s, wzr dup v18.4s, wzr @@ -57,116 +73,86 @@ L2: dup v30.4s, wzr dup v31.4s, wzr - cmp w13, #4 - blt CommLoopMul - -OptLoopMul4: - ld1 {v0.4s, v1.4s}, [x12], #32 - ld1 {v8.4s, v9.4s}, [x16], #32 - fmla v16.4s, v8.4s, v0.s[0] - fmla v17.4s, v9.4s, v0.s[0] - fmla v18.4s, v8.4s, v0.s[1] - fmla v19.4s, v9.4s, v0.s[1] - fmla v20.4s, v8.4s, v0.s[2] - fmla v21.4s, v9.4s, v0.s[2] - fmla v22.4s, v8.4s, v0.s[3] - fmla v23.4s, v9.4s, v0.s[3] - ld1 {v10.4s, v11.4s}, [x16], #32 - fmla v24.4s, v8.4s, v1.s[0] - fmla v25.4s, v9.4s, v1.s[0] - fmla v26.4s, v8.4s, v1.s[1] - fmla v27.4s, v9.4s, v1.s[1] - ld1 {v2.4s, v3.4s}, [x12], #32 - fmla v28.4s, v8.4s, v1.s[2] - fmla v29.4s, v9.4s, v1.s[2] - fmla v30.4s, v8.4s, v1.s[3] - fmla v31.4s, v9.4s, v1.s[3] - fmla v16.4s, v10.4s, v2.s[0] - fmla v17.4s, v11.4s, v2.s[0] - fmla v18.4s, v10.4s, v2.s[1] - fmla v19.4s, v11.4s, v2.s[1] - fmla v20.4s, v10.4s, v2.s[2] - fmla v21.4s, v11.4s, v2.s[2] - fmla v22.4s, v10.4s, v2.s[3] - fmla v23.4s, v11.4s, v2.s[3] - ld1 {v12.4s, v13.4s}, [x16], #32 - fmla v24.4s, v10.4s, v3.s[0] - fmla v25.4s, v11.4s, v3.s[0] - fmla v26.4s, v10.4s, v3.s[1] - fmla v27.4s, v11.4s, v3.s[1] - ld1 {v4.4s, v5.4s}, [x12], #32 - fmla v28.4s, v10.4s, v3.s[2] - fmla v29.4s, v11.4s, v3.s[2] - fmla v30.4s, v10.4s, v3.s[3] - fmla v31.4s, v11.4s, v3.s[3] - fmla v16.4s, v12.4s, v4.s[0] - fmla v17.4s, v13.4s, v4.s[0] - fmla v18.4s, v12.4s, v4.s[1] - fmla v19.4s, v13.4s, v4.s[1] - fmla v20.4s, v12.4s, v4.s[2] - fmla v21.4s, v13.4s, v4.s[2] - fmla v22.4s, v12.4s, v4.s[3] - fmla v23.4s, v13.4s, v4.s[3] - ld1 {v6.4s,v7.4s}, [x12], #32 - fmla v24.4s, v12.4s, v5.s[0] - fmla v25.4s, v13.4s, v5.s[0] - fmla v26.4s, v12.4s, v5.s[1] - fmla v27.4s, v13.4s, v5.s[1] - ld1 {v14.4s, v15.4s}, [x16], #32 - fmla v28.4s, v12.4s, v5.s[2] - fmla v29.4s, v13.4s, v5.s[2] - fmla v30.4s, v12.4s, v5.s[3] - fmla v31.4s, v13.4s, v5.s[3] - fmla v16.4s, v14.4s, v6.s[0] - fmla v17.4s, v15.4s, v6.s[0] - fmla v18.4s, v14.4s, v6.s[1] - fmla v19.4s, v15.4s, v6.s[1] - fmla v20.4s, v14.4s, v6.s[2] - fmla v21.4s, v15.4s, v6.s[2] - fmla v22.4s, v14.4s, v6.s[3] - fmla v23.4s, v15.4s, v6.s[3] - fmla v24.4s, v14.4s, v7.s[0] - fmla v25.4s, v15.4s, v7.s[0] - fmla v26.4s, v14.4s, v7.s[1] - fmla v27.4s, v15.4s, v7.s[1] - fmla v28.4s, v14.4s, v7.s[2] - fmla v29.4s, v15.4s, v7.s[2] - fmla v30.4s, v14.4s, v7.s[3] - fmla v31.4s, v15.4s, v7.s[3] +LoopStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] - sub w13, w13, #4 - cmp w13, #0 - ble Bias - cmp w13, #4 - bge OptLoopMul4 + subs w13, w13, #1 + beq LoopEnd -CommLoopMul: - ld1 {v0.4s, v1.4s}, [x12], #32 - ld1 {v2.4s, v3.4s}, [x16], #32 - fmla v16.4s, v2.4s, v0.s[0] - fmla v17.4s, v3.4s, v0.s[0] - fmla v18.4s, v2.4s, v0.s[1] - fmla v19.4s, v3.4s, v0.s[1] - fmla v20.4s, v2.4s, v0.s[2] - fmla v21.4s, v3.4s, v0.s[2] - fmla v22.4s, v2.4s, v0.s[3] - fmla v23.4s, v3.4s, v0.s[3] - fmla v24.4s, v2.4s, v1.s[0] - fmla v25.4s, v3.4s, v1.s[0] - fmla v26.4s, v2.4s, v1.s[1] - fmla v27.4s, v3.4s, v1.s[1] - fmla v28.4s, v2.4s, v1.s[2] - fmla v29.4s, v3.4s, v1.s[2] - fmla v30.4s, v2.4s, v1.s[3] - fmla v31.4s, v3.4s, v1.s[3] +Loop: + ld1 {v0.4s}, [x12], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ld1 {v1.4s}, [x12], #16 + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ld1 {v3.4s}, [x16], #16 + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ld1 {v4.4s}, [x16], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v2.4s}, [x12], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] subs w13, w13, #1 - bgt CommLoopMul + bgt Loop + +LoopEnd: + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] Bias: - cbz x11, Activation - ld1 {v0.4s}, [x14], #16 - ld1 {v1.4s}, [x14], #16 + cbz x3, Activation + ld1 {v0.4s}, [x3], #16 + ld1 {v1.4s}, [x3] + sub x3, x3, #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s fadd v16.4s, v16.4s, v0.4s fadd v17.4s, v17.4s, v1.4s fadd v18.4s, v18.4s, v0.4s @@ -192,48 +178,64 @@ Activation: b Write Relu6: - mov w8, #6 - dup v15.4s, w8 - scvtf v15.4s, v15.4s - fmin v16.4s, v16.4s, v15.4s - fmin v17.4s, v17.4s, v15.4s - fmin v18.4s, v18.4s, v15.4s - fmin v19.4s, v19.4s, v15.4s - fmin v20.4s, v20.4s, v15.4s - fmin v21.4s, v21.4s, v15.4s - fmin v22.4s, v22.4s, v15.4s - fmin v23.4s, v23.4s, v15.4s - fmin v24.4s, v24.4s, v15.4s - fmin v25.4s, v25.4s, v15.4s - fmin v26.4s, v26.4s, v15.4s - fmin v27.4s, v27.4s, v15.4s - fmin v28.4s, v28.4s, v15.4s - fmin v29.4s, v29.4s, v15.4s - fmin v30.4s, v30.4s, v15.4s - fmin v31.4s, v31.4s, v15.4s + mov w13, #6 + dup v2.4s, w13 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s Relu: - dup v14.4s, wzr - fmax v16.4s, v16.4s, v14.4s - fmax v17.4s, v17.4s, v14.4s - fmax v18.4s, v18.4s, v14.4s - fmax v19.4s, v19.4s, v14.4s - fmax v20.4s, v20.4s, v14.4s - fmax v21.4s, v21.4s, v14.4s - fmax v22.4s, v22.4s, v14.4s - fmax v23.4s, v23.4s, v14.4s - fmax v24.4s, v24.4s, v14.4s - fmax v25.4s, v25.4s, v14.4s - fmax v26.4s, v26.4s, v14.4s - fmax v27.4s, v27.4s, v14.4s - fmax v28.4s, v28.4s, v14.4s - fmax v29.4s, v29.4s, v14.4s - fmax v30.4s, v30.4s, v14.4s - fmax v31.4s, v31.4s, v14.4s + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s Write: - ldrb w13, [sp, #8] - cbz w13, WriteC8 + cbnz x14, WriteWino + cbz x9, WriteC8 cmp w7, #1 beq Write1 cmp w7, #2 @@ -251,71 +253,107 @@ Write: b Write8 Write1: - str s16, [x18] + str s8, [x18] cmp w10, #1 beq WriteEnd add x18, x18, x17 - str s18, [x18] + str s10, [x18] cmp w10, #2 beq WriteEnd add x18, x18, x17 - str s20, [x18] + str s12, [x18] cmp w10, #3 beq WriteEnd add x18, x18, x17 - str s22, [x18] + str s14, [x18] cmp w10, #4 beq WriteEnd add x18, x18, x17 - str s24, [x18] + str s16, [x18] cmp w10, #5 beq WriteEnd add x18, x18, x17 - str s26, [x18] + str s18, [x18] cmp w10, #6 beq WriteEnd add x18, x18, x17 - str s28, [x18] + str s20, [x18] cmp w10, #7 beq WriteEnd add x18, x18, x17 + str s22, [x18] + cmp w10, #8 + beq WriteEnd + add x18, x18, x17 + str s24, [x18] + cmp w10, #9 + beq WriteEnd + add x18, x18, x17 + str s26, [x18] + cmp w10, #10 + beq WriteEnd + add x18, x18, x17 + str s28, [x18] + cmp w10, #11 + beq WriteEnd + add x18, x18, x17 str s30, [x18] add x18, x18, x17 b WriteEnd Write2: + dup s9, v8.s[1] + stp s8, s9, [x18] + cmp w10, #1 + beq WriteEnd + add x18, x18, x17 + dup s11, v10.s[1] + stp s10, s11, [x18] + cmp w10, #2 + beq WriteEnd + add x18, x18, x17 + dup s13, v12.s[1] + stp s12, s13, [x18] + cmp w10, #3 + beq WriteEnd + add x18, x18, x17 + dup s15, v14.s[1] + stp s14, s15, [x18] + cmp w10, #4 + beq WriteEnd + add x18, x18, x17 dup s17, v16.s[1] stp s16, s17, [x18] - cmp w10, #1 + cmp w10, #5 beq WriteEnd add x18, x18, x17 dup s19, v18.s[1] stp s18, s19, [x18] - cmp w10, #2 + cmp w10, #6 beq WriteEnd add x18, x18, x17 dup s21, v20.s[1] stp s20, s21, [x18] - cmp w10, #3 + cmp w10, #7 beq WriteEnd add x18, x18, x17 dup s23, v22.s[1] stp s22, s23, [x18] - cmp w10, #4 + cmp w10, #8 beq WriteEnd add x18, x18, x17 dup s25, v24.s[1] stp s24, s25, [x18] - cmp w10, #5 + cmp w10, #9 beq WriteEnd add x18, x18, x17 dup s27, v26.s[1] stp s26, s27, [x18] - cmp w10, #6 + cmp w10, #10 beq WriteEnd add x18, x18, x17 dup s29, v28.s[1] stp s28, s29, [x18] - cmp w10, #7 + cmp w10, #11 beq WriteEnd add x18, x18, x17 dup s31, v30.s[1] @@ -324,47 +362,71 @@ Write2: b WriteEnd Write3: add x13, x18, #8 + dup s9, v8.s[1] + stp s8, s9, [x18] + add x18, x18, x17 + st1 {v8.s}[2], [x13], x17 + cmp w10, #1 + beq WriteEnd + dup s11, v10.s[1] + stp s10, s11, [x18] + add x18, x18, x17 + st1 {v10.s}[2], [x13], x17 + cmp w10, #2 + beq WriteEnd + dup s13, v12.s[1] + stp s12, s13, [x18] + add x18, x18, x17 + st1 {v12.s}[2], [x13], x17 + cmp w10, #3 + beq WriteEnd + dup s15, v14.s[1] + stp s14, s15, [x18] + add x18, x18, x17 + st1 {v14.s}[2], [x13], x17 + cmp w10, #4 + beq WriteEnd dup s17, v16.s[1] stp s16, s17, [x18] add x18, x18, x17 st1 {v16.s}[2], [x13], x17 - cmp w10, #1 + cmp w10, #5 beq WriteEnd dup s19, v18.s[1] stp s18, s19, [x18] add x18, x18, x17 st1 {v18.s}[2], [x13], x17 - cmp w10, #2 + cmp w10, #6 beq WriteEnd dup s21, v20.s[1] stp s20, s21, [x18] add x18, x18, x17 st1 {v20.s}[2], [x13], x17 - cmp w10, #3 + cmp w10, #7 beq WriteEnd dup s23, v22.s[1] stp s22, s23, [x18] add x18, x18, x17 st1 {v22.s}[2], [x13], x17 - cmp w10, #4 + cmp w10, #8 beq WriteEnd dup s25, v24.s[1] stp s24, s25, [x18] add x18, x18, x17 st1 {v24.s}[2], [x13], x17 - cmp w10, #5 + cmp w10, #9 beq WriteEnd dup s27, v26.s[1] stp s26, s27, [x18] add x18, x18, x17 st1 {v26.s}[2], [x13], x17 - cmp w10, #6 + cmp w10, #10 beq WriteEnd dup s29, v28.s[1] stp s28, s29, [x18] add x18, x18, x17 st1 {v28.s}[2], [x13], x17 - cmp w10, #7 + cmp w10, #11 beq WriteEnd dup s31, v30.s[1] stp s30, s31, [x18] @@ -372,64 +434,96 @@ Write3: st1 {v30.s}[2], [x13] b WriteEnd Write4: - st1 {v16.4s}, [x18], x17 + st1 {v8.4s}, [x18], x17 cmp w10, #1 beq WriteEnd - st1 {v18.4s}, [x18], x17 + st1 {v10.4s}, [x18], x17 cmp w10, #2 beq WriteEnd - st1 {v20.4s}, [x18], x17 + st1 {v12.4s}, [x18], x17 cmp w10, #3 beq WriteEnd - st1 {v22.4s}, [x18], x17 + st1 {v14.4s}, [x18], x17 cmp w10, #4 beq WriteEnd - st1 {v24.4s}, [x18], x17 + st1 {v16.4s}, [x18], x17 cmp w10, #5 beq WriteEnd - st1 {v26.4s}, [x18], x17 + st1 {v18.4s}, [x18], x17 cmp w10, #6 beq WriteEnd - st1 {v28.4s}, [x18], x17 + st1 {v20.4s}, [x18], x17 cmp w10, #7 beq WriteEnd + st1 {v22.4s}, [x18], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x18], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x18], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x18], x17 + cmp w10, #11 + beq WriteEnd st1 {v30.4s}, [x18], x17 b WriteEnd Write5: add x13, x18, #16 + st1 {v8.4s}, [x18], x17 + str s9, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x18], x17 + str s11, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x18], x17 + str s13, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x18], x17 + str s15, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 st1 {v16.4s}, [x18], x17 str s17, [x13] - cmp w10, #1 + cmp w10, #5 beq WriteEnd add x13, x13, x17 st1 {v18.4s}, [x18], x17 str s19, [x13] - cmp w10, #2 + cmp w10, #6 beq WriteEnd add x13, x13, x17 st1 {v20.4s}, [x18], x17 str s21, [x13] - cmp w10, #3 + cmp w10, #7 beq WriteEnd add x13, x13, x17 st1 {v22.4s}, [x18], x17 str s23, [x13] - cmp w10, #4 + cmp w10, #8 beq WriteEnd add x13, x13, x17 st1 {v24.4s}, [x18], x17 str s25, [x13] - cmp w10, #5 + cmp w10, #9 beq WriteEnd add x13, x13, x17 st1 {v26.4s}, [x18], x17 str s27, [x13] - cmp w10, #6 + cmp w10, #10 beq WriteEnd add x13, x13, x17 st1 {v28.4s}, [x18], x17 str s29, [x13] - cmp w10, #7 + cmp w10, #11 beq WriteEnd add x13, x13, x17 st1 {v30.4s}, [x18], x17 @@ -437,46 +531,70 @@ Write5: b WriteEnd Write6: add x13, x18, #16 + st1 {v8.4s}, [x18], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x18], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x18], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x18], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 st1 {v16.4s}, [x18], x17 dup s16, v17.s[1] stp s17, s16, [x13] - cmp w10, #1 + cmp w10, #5 beq WriteEnd add x13, x13, x17 st1 {v18.4s}, [x18], x17 dup s18, v19.s[1] stp s19, s18, [x13] - cmp w10, #2 + cmp w10, #6 beq WriteEnd add x13, x13, x17 st1 {v20.4s}, [x18], x17 dup s20, v21.s[1] stp s21, s20, [x13] - cmp w10, #3 + cmp w10, #7 beq WriteEnd add x13, x13, x17 st1 {v22.4s}, [x18], x17 dup s22, v23.s[1] stp s23, s22, [x13] - cmp w10, #4 + cmp w10, #8 beq WriteEnd add x13, x13, x17 st1 {v24.4s}, [x18], x17 dup s24, v25.s[1] stp s25, s24, [x13] - cmp w10, #5 + cmp w10, #9 beq WriteEnd add x13, x13, x17 st1 {v26.4s}, [x18], x17 dup s26, v27.s[1] stp s27, s26, [x13] - cmp w10, #6 + cmp w10, #10 beq WriteEnd add x13, x13, x17 st1 {v28.4s}, [x18], x17 dup s28, v29.s[1] stp s29, s28, [x13] - cmp w10, #7 + cmp w10, #11 beq WriteEnd add x13, x13, x17 st1 {v30.4s}, [x18], x17 @@ -486,54 +604,82 @@ Write6: Write7: add x13, x18, #16 add x16, x18, #24 + st1 {v8.4s}, [x18], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + add x13, x13, x17 + st1 {v9.s}[2], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x18], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + add x13, x13, x17 + st1 {v11.s}[2], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x18], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + add x13, x13, x17 + st1 {v13.s}[2], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x18], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + add x13, x13, x17 + st1 {v15.s}[2], [x16], x17 + cmp w10, #4 + beq WriteEnd st1 {v16.4s}, [x18], x17 dup s16, v17.s[1] stp s17, s16, [x13] add x13, x13, x17 st1 {v17.s}[2], [x16], x17 - cmp w10, #1 + cmp w10, #5 beq WriteEnd st1 {v18.4s}, [x18], x17 dup s18, v19.s[1] stp s19, s18, [x13] add x13, x13, x17 st1 {v19.s}[2], [x16], x17 - cmp w10, #2 + cmp w10, #6 beq WriteEnd st1 {v20.4s}, [x18], x17 dup s20, v21.s[1] stp s21, s20, [x13] add x13, x13, x17 st1 {v21.s}[2], [x16], x17 - cmp w10, #3 + cmp w10, #7 beq WriteEnd st1 {v22.4s}, [x18], x17 dup s22, v23.s[1] stp s23, s22, [x13] add x13, x13, x17 st1 {v23.s}[2], [x16], x17 - cmp w10, #4 + cmp w10, #8 beq WriteEnd st1 {v24.4s}, [x18], x17 dup s24, v25.s[1] stp s25, s24, [x13] add x13, x13, x17 st1 {v25.s}[2], [x16], x17 - cmp w10, #5 + cmp w10, #9 beq WriteEnd st1 {v26.4s}, [x18], x17 dup s26, v27.s[1] stp s27, s26, [x13] add x13, x13, x17 st1 {v27.s}[2], [x16], x17 - cmp w10, #6 + cmp w10, #10 beq WriteEnd st1 {v28.4s}, [x18], x17 dup s28, v29.s[1] stp s29, s28, [x13] add x13, x13, x17 st1 {v29.s}[2], [x16], x17 - cmp w10, #7 + cmp w10, #11 beq WriteEnd st1 {v30.4s}, [x18], x17 dup s30, v31.s[1] @@ -542,46 +688,79 @@ Write7: st1 {v31.s}[2], [x16], x17 b WriteEnd WriteC8: - st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64 - st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64 - st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 - st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + b WriteEnd +WriteWino: + st1 {v8.4s, v9.4s}, [x18], x8 + st1 {v10.4s, v11.4s}, [x18], x8 + st1 {v12.4s, v13.4s}, [x18], x8 + st1 {v14.4s, v15.4s}, [x18], x8 + st1 {v16.4s, v17.4s}, [x18], x8 + st1 {v18.4s, v19.4s}, [x18], x8 + st1 {v20.4s, v21.4s}, [x18], x8 + st1 {v22.4s, v23.4s}, [x18], x8 + st1 {v24.4s, v25.4s}, [x18], x8 + st1 {v26.4s, v27.4s}, [x18], x8 + st1 {v28.4s, v29.4s}, [x18], x8 + st1 {v30.4s, v31.4s}, [x18], x8 b WriteEnd Write8: - st1 {v16.4s, v17.4s}, [x18], x17 + st1 {v8.4s, v9.4s}, [x18], x17 cmp w10, #1 beq WriteEnd - st1 {v18.4s, v19.4s}, [x18], x17 + st1 {v10.4s, v11.4s}, [x18], x17 cmp w10, #2 beq WriteEnd - st1 {v20.4s, v21.4s}, [x18], x17 + st1 {v12.4s, v13.4s}, [x18], x17 cmp w10, #3 beq WriteEnd - st1 {v22.4s, v23.4s}, [x18], x17 + st1 {v14.4s, v15.4s}, [x18], x17 cmp w10, #4 beq WriteEnd - st1 {v24.4s, v25.4s}, [x18], x17 + st1 {v16.4s, v17.4s}, [x18], x17 cmp w10, #5 beq WriteEnd - st1 {v26.4s, v27.4s}, [x18], x17 + st1 {v18.4s, v19.4s}, [x18], x17 cmp w10, #6 beq WriteEnd - st1 {v28.4s, v29.4s}, [x18], x17 + st1 {v20.4s, v21.4s}, [x18], x17 cmp w10, #7 beq WriteEnd + st1 {v22.4s, v23.4s}, [x18], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x18], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x18], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x18], x17 + cmp w10, #11 + beq WriteEnd st1 {v30.4s, v31.4s}, [x18], x17 WriteEnd: - subs w10, w10, #8 // lhs row - 8 + subs w10, w10, #12 // lhs row - 12 bgt L2 End2: subs w7, w7, #8 // rhs col - 8 add x1, x1, x15 // rhs ptr + stride + cbz x3, NoBiasStep add x3, x3, #32 // bias ptr + stride - ldrb w13, [sp, #8] - cbz w13, NoDstStep +NoBiasStep: + cbnz x14, WinoDstStep + cbz x9, NoDstStep add x2, x2, #32 // dst ptr + stride + b NoDstStep +WinoDstStep: + add x2, x2, x11 NoDstStep: bgt L1 diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S index d998c791bc..16d6bba647 100644 --- a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32Opt.S @@ -7,766 +7,814 @@ #endif // void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth -// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// int row, int col, size_t stride, size_t writeMode) // x0: a // x1: b // x2: c // x3: bias -// w4: act_type -// w5: depth -// w6: row -// w7: col -// w17: stride -// w13: c8_nhwc_c4 +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode MatmulFloatNeon64Opt: - sub sp, sp, #128 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 - ldr x9, [sp, #8] - ldr x14, [sp, #16] + ldr x8, [sp] + ldr x9, [sp, #8] - mov w18, #32 // sizeof(float) * 8 - mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth - mov x18, #4 - ldr x17, [sp] - cbz x14, NoWinoSteps - mul x8, x7, x17 - mov x11, #8 - mul x11, x11, x17 - mul x8, x8, x18 - mul x11, x11, x18 + mov x18, #48 // sizeof(float) * 12 + mul x17, x5, x18 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cbnz x9, NoC8Steps + mov x11, x2 + mov x18, #32 + mul x16, x6, x18 // row * 8 * sizeof(float) +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x18, #4 + mul x15, x7, x8 + mul x15, x15, x18 // kernel_size * col *sizeof(float) + mov x18, #32 + mul x16, x8, x18 // kernel_size * 8 * sizeof(float) NoWinoSteps: - mul x17, x17, x18 + mov x18, #4 + mul x8, x8, x18 -L1: - mov w10, w6 // reload lhs row - mov x12, x0 // reload lhs ptr - mov x18, x2 // reload dst ptr +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias -L2: - mov x16, x1 // reload rhs ptr - mov w13, w5 // reload depth - dup v8.4s, wzr - dup v9.4s, wzr - dup v10.4s, wzr - dup v11.4s, wzr - dup v12.4s, wzr - dup v13.4s, wzr - dup v14.4s, wzr - dup v15.4s, wzr - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr + LoopCol: + cbz x9, NoReloadDst + mov x11, x2 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth -LoopStart: - ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 - ld1 {v3.4s, v4.4s}, [x16], #32 - fmla v8.4s, v3.4s, v0.s[0] - fmla v10.4s, v3.4s, v0.s[1] - fmla v12.4s, v3.4s, v0.s[2] - fmla v14.4s, v3.4s, v0.s[3] - fmla v9.4s, v4.4s, v0.s[0] - fmla v11.4s, v4.4s, v0.s[1] - fmla v13.4s, v4.4s, v0.s[2] - fmla v15.4s, v4.4s, v0.s[3] + cmp x13, #4 + ble LoopDepthStartHalf - subs w13, w13, #1 - beq LoopEnd + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] -Loop: - ld1 {v0.4s}, [x12], #16 - fmla v16.4s, v3.4s, v1.s[0] - fmla v18.4s, v3.4s, v1.s[1] - fmla v20.4s, v3.4s, v1.s[2] - fmla v22.4s, v3.4s, v1.s[3] - fmla v17.4s, v4.4s, v1.s[0] - fmla v19.4s, v4.4s, v1.s[1] - fmla v21.4s, v4.4s, v1.s[2] - fmla v23.4s, v4.4s, v1.s[3] - ld1 {v1.4s}, [x12], #16 - fmla v24.4s, v3.4s, v2.s[0] - fmla v26.4s, v3.4s, v2.s[1] - fmla v28.4s, v3.4s, v2.s[2] - fmla v30.4s, v3.4s, v2.s[3] - ld1 {v3.4s}, [x16], #16 - fmla v25.4s, v4.4s, v2.s[0] - fmla v27.4s, v4.4s, v2.s[1] - fmla v29.4s, v4.4s, v2.s[2] - fmla v31.4s, v4.4s, v2.s[3] - ld1 {v4.4s}, [x16], #16 - fmla v8.4s, v3.4s, v0.s[0] - fmla v10.4s, v3.4s, v0.s[1] - fmla v12.4s, v3.4s, v0.s[2] - fmla v14.4s, v3.4s, v0.s[3] - ld1 {v2.4s}, [x12], #16 - fmla v9.4s, v4.4s, v0.s[0] - fmla v11.4s, v4.4s, v0.s[1] - fmla v13.4s, v4.4s, v0.s[2] - fmla v15.4s, v4.4s, v0.s[3] + subs x19, x19, #1 + beq Bias - subs w13, w13, #1 - bgt Loop + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] -LoopEnd: - fmla v16.4s, v3.4s, v1.s[0] - fmla v18.4s, v3.4s, v1.s[1] - fmla v20.4s, v3.4s, v1.s[2] - fmla v22.4s, v3.4s, v1.s[3] - fmla v17.4s, v4.4s, v1.s[0] - fmla v19.4s, v4.4s, v1.s[1] - fmla v21.4s, v4.4s, v1.s[2] - fmla v23.4s, v4.4s, v1.s[3] - fmla v24.4s, v3.4s, v2.s[0] - fmla v26.4s, v3.4s, v2.s[1] - fmla v28.4s, v3.4s, v2.s[2] - fmla v30.4s, v3.4s, v2.s[3] - fmla v25.4s, v4.4s, v2.s[0] - fmla v27.4s, v4.4s, v2.s[1] - fmla v29.4s, v4.4s, v2.s[2] - fmla v31.4s, v4.4s, v2.s[3] + subs x19, x19, #1 + bgt LoopDepth -Bias: - cbz x3, Activation - ld1 {v0.4s}, [x3], #16 - ld1 {v1.4s}, [x3] - sub x3, x3, #16 - fadd v8.4s, v8.4s, v0.4s - fadd v9.4s, v9.4s, v1.4s - fadd v10.4s, v10.4s, v0.4s - fadd v11.4s, v11.4s, v1.4s - fadd v12.4s, v12.4s, v0.4s - fadd v13.4s, v13.4s, v1.4s - fadd v14.4s, v14.4s, v0.4s - fadd v15.4s, v15.4s, v1.4s - fadd v16.4s, v16.4s, v0.4s - fadd v17.4s, v17.4s, v1.4s - fadd v18.4s, v18.4s, v0.4s - fadd v19.4s, v19.4s, v1.4s - fadd v20.4s, v20.4s, v0.4s - fadd v21.4s, v21.4s, v1.4s - fadd v22.4s, v22.4s, v0.4s - fadd v23.4s, v23.4s, v1.4s - fadd v24.4s, v24.4s, v0.4s - fadd v25.4s, v25.4s, v1.4s - fadd v26.4s, v26.4s, v0.4s - fadd v27.4s, v27.4s, v1.4s - fadd v28.4s, v28.4s, v0.4s - fadd v29.4s, v29.4s, v1.4s - fadd v30.4s, v30.4s, v0.4s - fadd v31.4s, v31.4s, v1.4s + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s -Activation: - cmp w4, #2 - beq Relu6 - cmp w4, #1 - beq Relu - b Write + Activation: + cmp x4, #2 + beq Relu6 + cmp x4, #1 + beq Relu + b Write -Relu6: - mov w13, #6 - dup v2.4s, w13 - scvtf v2.4s, v2.4s - fmin v8.4s, v8.4s, v2.4s - fmin v9.4s, v9.4s, v2.4s - fmin v10.4s, v10.4s, v2.4s - fmin v11.4s, v11.4s, v2.4s - fmin v12.4s, v12.4s, v2.4s - fmin v13.4s, v13.4s, v2.4s - fmin v14.4s, v14.4s, v2.4s - fmin v15.4s, v15.4s, v2.4s - fmin v16.4s, v16.4s, v2.4s - fmin v17.4s, v17.4s, v2.4s - fmin v18.4s, v18.4s, v2.4s - fmin v19.4s, v19.4s, v2.4s - fmin v20.4s, v20.4s, v2.4s - fmin v21.4s, v21.4s, v2.4s - fmin v22.4s, v22.4s, v2.4s - fmin v23.4s, v23.4s, v2.4s - fmin v24.4s, v24.4s, v2.4s - fmin v25.4s, v25.4s, v2.4s - fmin v26.4s, v26.4s, v2.4s - fmin v27.4s, v27.4s, v2.4s - fmin v28.4s, v28.4s, v2.4s - fmin v29.4s, v29.4s, v2.4s - fmin v30.4s, v30.4s, v2.4s - fmin v31.4s, v31.4s, v2.4s + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write -Relu: - dup v3.4s, wzr - fmax v8.4s, v8.4s, v3.4s - fmax v9.4s, v9.4s, v3.4s - fmax v10.4s, v10.4s, v3.4s - fmax v11.4s, v11.4s, v3.4s - fmax v12.4s, v12.4s, v3.4s - fmax v13.4s, v13.4s, v3.4s - fmax v14.4s, v14.4s, v3.4s - fmax v15.4s, v15.4s, v3.4s - fmax v16.4s, v16.4s, v3.4s - fmax v17.4s, v17.4s, v3.4s - fmax v18.4s, v18.4s, v3.4s - fmax v19.4s, v19.4s, v3.4s - fmax v20.4s, v20.4s, v3.4s - fmax v21.4s, v21.4s, v3.4s - fmax v22.4s, v22.4s, v3.4s - fmax v23.4s, v23.4s, v3.4s - fmax v24.4s, v24.4s, v3.4s - fmax v25.4s, v25.4s, v3.4s - fmax v26.4s, v26.4s, v3.4s - fmax v27.4s, v27.4s, v3.4s - fmax v28.4s, v28.4s, v3.4s - fmax v29.4s, v29.4s, v3.4s - fmax v30.4s, v30.4s, v3.4s - fmax v31.4s, v31.4s, v3.4s + LoopDepthStartHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] -Write: - cbnz x14, WriteWino - cbz x9, WriteC8 - cmp w7, #1 - beq Write1 - cmp w7, #2 - beq Write2 - cmp w7, #3 - beq Write3 - cmp w7, #4 - beq Write4 - cmp w7, #5 - beq Write5 - cmp w7, #6 - beq Write6 - cmp w7, #7 - beq Write7 - b Write8 + subs x19, x19, #1 + beq BiasHalf -Write1: - str s8, [x18] - cmp w10, #1 - beq WriteEnd - add x18, x18, x17 - str s10, [x18] - cmp w10, #2 - beq WriteEnd - add x18, x18, x17 - str s12, [x18] - cmp w10, #3 - beq WriteEnd - add x18, x18, x17 - str s14, [x18] - cmp w10, #4 - beq WriteEnd - add x18, x18, x17 - str s16, [x18] - cmp w10, #5 - beq WriteEnd - add x18, x18, x17 - str s18, [x18] - cmp w10, #6 - beq WriteEnd - add x18, x18, x17 - str s20, [x18] - cmp w10, #7 - beq WriteEnd - add x18, x18, x17 - str s22, [x18] - cmp w10, #8 - beq WriteEnd - add x18, x18, x17 - str s24, [x18] - cmp w10, #9 - beq WriteEnd - add x18, x18, x17 - str s26, [x18] - cmp w10, #10 - beq WriteEnd - add x18, x18, x17 - str s28, [x18] - cmp w10, #11 - beq WriteEnd - add x18, x18, x17 - str s30, [x18] - add x18, x18, x17 - b WriteEnd -Write2: - dup s9, v8.s[1] - stp s8, s9, [x18] - cmp w10, #1 - beq WriteEnd - add x18, x18, x17 - dup s11, v10.s[1] - stp s10, s11, [x18] - cmp w10, #2 - beq WriteEnd - add x18, x18, x17 - dup s13, v12.s[1] - stp s12, s13, [x18] - cmp w10, #3 - beq WriteEnd - add x18, x18, x17 - dup s15, v14.s[1] - stp s14, s15, [x18] - cmp w10, #4 - beq WriteEnd - add x18, x18, x17 - dup s17, v16.s[1] - stp s16, s17, [x18] - cmp w10, #5 - beq WriteEnd - add x18, x18, x17 - dup s19, v18.s[1] - stp s18, s19, [x18] - cmp w10, #6 - beq WriteEnd - add x18, x18, x17 - dup s21, v20.s[1] - stp s20, s21, [x18] - cmp w10, #7 - beq WriteEnd - add x18, x18, x17 - dup s23, v22.s[1] - stp s22, s23, [x18] - cmp w10, #8 - beq WriteEnd - add x18, x18, x17 - dup s25, v24.s[1] - stp s24, s25, [x18] - cmp w10, #9 - beq WriteEnd - add x18, x18, x17 - dup s27, v26.s[1] - stp s26, s27, [x18] - cmp w10, #10 - beq WriteEnd - add x18, x18, x17 - dup s29, v28.s[1] - stp s28, s29, [x18] - cmp w10, #11 - beq WriteEnd - add x18, x18, x17 - dup s31, v30.s[1] - stp s30, s31, [x18] - add x18, x18, x17 - b WriteEnd -Write3: - add x13, x18, #8 - dup s9, v8.s[1] - stp s8, s9, [x18] - add x18, x18, x17 - st1 {v8.s}[2], [x13], x17 - cmp w10, #1 - beq WriteEnd - dup s11, v10.s[1] - stp s10, s11, [x18] - add x18, x18, x17 - st1 {v10.s}[2], [x13], x17 - cmp w10, #2 - beq WriteEnd - dup s13, v12.s[1] - stp s12, s13, [x18] - add x18, x18, x17 - st1 {v12.s}[2], [x13], x17 - cmp w10, #3 - beq WriteEnd - dup s15, v14.s[1] - stp s14, s15, [x18] - add x18, x18, x17 - st1 {v14.s}[2], [x13], x17 - cmp w10, #4 - beq WriteEnd - dup s17, v16.s[1] - stp s16, s17, [x18] - add x18, x18, x17 - st1 {v16.s}[2], [x13], x17 - cmp w10, #5 - beq WriteEnd - dup s19, v18.s[1] - stp s18, s19, [x18] - add x18, x18, x17 - st1 {v18.s}[2], [x13], x17 - cmp w10, #6 - beq WriteEnd - dup s21, v20.s[1] - stp s20, s21, [x18] - add x18, x18, x17 - st1 {v20.s}[2], [x13], x17 - cmp w10, #7 - beq WriteEnd - dup s23, v22.s[1] - stp s22, s23, [x18] - add x18, x18, x17 - st1 {v22.s}[2], [x13], x17 - cmp w10, #8 - beq WriteEnd - dup s25, v24.s[1] - stp s24, s25, [x18] - add x18, x18, x17 - st1 {v24.s}[2], [x13], x17 - cmp w10, #9 - beq WriteEnd - dup s27, v26.s[1] - stp s26, s27, [x18] - add x18, x18, x17 - st1 {v26.s}[2], [x13], x17 - cmp w10, #10 - beq WriteEnd - dup s29, v28.s[1] - stp s28, s29, [x18] - add x18, x18, x17 - st1 {v28.s}[2], [x13], x17 - cmp w10, #11 - beq WriteEnd - dup s31, v30.s[1] - stp s30, s31, [x18] - add x18, x18, x17 - st1 {v30.s}[2], [x13] - b WriteEnd -Write4: - st1 {v8.4s}, [x18], x17 - cmp w10, #1 - beq WriteEnd - st1 {v10.4s}, [x18], x17 - cmp w10, #2 - beq WriteEnd - st1 {v12.4s}, [x18], x17 - cmp w10, #3 - beq WriteEnd - st1 {v14.4s}, [x18], x17 - cmp w10, #4 - beq WriteEnd - st1 {v16.4s}, [x18], x17 - cmp w10, #5 - beq WriteEnd - st1 {v18.4s}, [x18], x17 - cmp w10, #6 - beq WriteEnd - st1 {v20.4s}, [x18], x17 - cmp w10, #7 - beq WriteEnd - st1 {v22.4s}, [x18], x17 - cmp w10, #8 - beq WriteEnd - st1 {v24.4s}, [x18], x17 - cmp w10, #9 - beq WriteEnd - st1 {v26.4s}, [x18], x17 - cmp w10, #10 - beq WriteEnd - st1 {v28.4s}, [x18], x17 - cmp w10, #11 - beq WriteEnd - st1 {v30.4s}, [x18], x17 - b WriteEnd -Write5: - add x13, x18, #16 - st1 {v8.4s}, [x18], x17 - str s9, [x13] - cmp w10, #1 - beq WriteEnd - add x13, x13, x17 - st1 {v10.4s}, [x18], x17 - str s11, [x13] - cmp w10, #2 - beq WriteEnd - add x13, x13, x17 - st1 {v12.4s}, [x18], x17 - str s13, [x13] - cmp w10, #3 - beq WriteEnd - add x13, x13, x17 - st1 {v14.4s}, [x18], x17 - str s15, [x13] - cmp w10, #4 - beq WriteEnd - add x13, x13, x17 - st1 {v16.4s}, [x18], x17 - str s17, [x13] - cmp w10, #5 - beq WriteEnd - add x13, x13, x17 - st1 {v18.4s}, [x18], x17 - str s19, [x13] - cmp w10, #6 - beq WriteEnd - add x13, x13, x17 - st1 {v20.4s}, [x18], x17 - str s21, [x13] - cmp w10, #7 - beq WriteEnd - add x13, x13, x17 - st1 {v22.4s}, [x18], x17 - str s23, [x13] - cmp w10, #8 - beq WriteEnd - add x13, x13, x17 - st1 {v24.4s}, [x18], x17 - str s25, [x13] - cmp w10, #9 - beq WriteEnd - add x13, x13, x17 - st1 {v26.4s}, [x18], x17 - str s27, [x13] - cmp w10, #10 - beq WriteEnd - add x13, x13, x17 - st1 {v28.4s}, [x18], x17 - str s29, [x13] - cmp w10, #11 - beq WriteEnd - add x13, x13, x17 - st1 {v30.4s}, [x18], x17 - str s31, [x13] - b WriteEnd -Write6: - add x13, x18, #16 - st1 {v8.4s}, [x18], x17 - dup s8, v9.s[1] - stp s9, s8, [x13] - cmp w10, #1 - beq WriteEnd - add x13, x13, x17 - st1 {v10.4s}, [x18], x17 - dup s10, v11.s[1] - stp s11, s10, [x13] - cmp w10, #2 - beq WriteEnd - add x13, x13, x17 - st1 {v12.4s}, [x18], x17 - dup s12, v13.s[1] - stp s13, s12, [x13] - cmp w10, #3 - beq WriteEnd - add x13, x13, x17 - st1 {v14.4s}, [x18], x17 - dup s14, v15.s[1] - stp s15, s14, [x13] - cmp w10, #4 - beq WriteEnd - add x13, x13, x17 - st1 {v16.4s}, [x18], x17 - dup s16, v17.s[1] - stp s17, s16, [x13] - cmp w10, #5 - beq WriteEnd - add x13, x13, x17 - st1 {v18.4s}, [x18], x17 - dup s18, v19.s[1] - stp s19, s18, [x13] - cmp w10, #6 - beq WriteEnd - add x13, x13, x17 - st1 {v20.4s}, [x18], x17 - dup s20, v21.s[1] - stp s21, s20, [x13] - cmp w10, #7 - beq WriteEnd - add x13, x13, x17 - st1 {v22.4s}, [x18], x17 - dup s22, v23.s[1] - stp s23, s22, [x13] - cmp w10, #8 - beq WriteEnd - add x13, x13, x17 - st1 {v24.4s}, [x18], x17 - dup s24, v25.s[1] - stp s25, s24, [x13] - cmp w10, #9 - beq WriteEnd - add x13, x13, x17 - st1 {v26.4s}, [x18], x17 - dup s26, v27.s[1] - stp s27, s26, [x13] - cmp w10, #10 - beq WriteEnd - add x13, x13, x17 - st1 {v28.4s}, [x18], x17 - dup s28, v29.s[1] - stp s29, s28, [x13] - cmp w10, #11 - beq WriteEnd - add x13, x13, x17 - st1 {v30.4s}, [x18], x17 - dup s30, v31.s[1] - stp s31, s30, [x13] - b WriteEnd -Write7: - add x13, x18, #16 - add x16, x18, #24 - st1 {v8.4s}, [x18], x17 - dup s8, v9.s[1] - stp s9, s8, [x13] - add x13, x13, x17 - st1 {v9.s}[2], [x16], x17 - cmp w10, #1 - beq WriteEnd - st1 {v10.4s}, [x18], x17 - dup s10, v11.s[1] - stp s11, s10, [x13] - add x13, x13, x17 - st1 {v11.s}[2], [x16], x17 - cmp w10, #2 - beq WriteEnd - st1 {v12.4s}, [x18], x17 - dup s12, v13.s[1] - stp s13, s12, [x13] - add x13, x13, x17 - st1 {v13.s}[2], [x16], x17 - cmp w10, #3 - beq WriteEnd - st1 {v14.4s}, [x18], x17 - dup s14, v15.s[1] - stp s15, s14, [x13] - add x13, x13, x17 - st1 {v15.s}[2], [x16], x17 - cmp w10, #4 - beq WriteEnd - st1 {v16.4s}, [x18], x17 - dup s16, v17.s[1] - stp s17, s16, [x13] - add x13, x13, x17 - st1 {v17.s}[2], [x16], x17 - cmp w10, #5 - beq WriteEnd - st1 {v18.4s}, [x18], x17 - dup s18, v19.s[1] - stp s19, s18, [x13] - add x13, x13, x17 - st1 {v19.s}[2], [x16], x17 - cmp w10, #6 - beq WriteEnd - st1 {v20.4s}, [x18], x17 - dup s20, v21.s[1] - stp s21, s20, [x13] - add x13, x13, x17 - st1 {v21.s}[2], [x16], x17 - cmp w10, #7 - beq WriteEnd - st1 {v22.4s}, [x18], x17 - dup s22, v23.s[1] - stp s23, s22, [x13] - add x13, x13, x17 - st1 {v23.s}[2], [x16], x17 - cmp w10, #8 - beq WriteEnd - st1 {v24.4s}, [x18], x17 - dup s24, v25.s[1] - stp s25, s24, [x13] - add x13, x13, x17 - st1 {v25.s}[2], [x16], x17 - cmp w10, #9 - beq WriteEnd - st1 {v26.4s}, [x18], x17 - dup s26, v27.s[1] - stp s27, s26, [x13] - add x13, x13, x17 - st1 {v27.s}[2], [x16], x17 - cmp w10, #10 - beq WriteEnd - st1 {v28.4s}, [x18], x17 - dup s28, v29.s[1] - stp s29, s28, [x13] - add x13, x13, x17 - st1 {v29.s}[2], [x16], x17 - cmp w10, #11 - beq WriteEnd - st1 {v30.4s}, [x18], x17 - dup s30, v31.s[1] - stp s31, s30, [x13] - add x13, x13, x17 - st1 {v31.s}[2], [x16], x17 - b WriteEnd -WriteC8: - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 - b WriteEnd -WriteWino: - st1 {v8.4s, v9.4s}, [x18], x8 - st1 {v10.4s, v11.4s}, [x18], x8 - st1 {v12.4s, v13.4s}, [x18], x8 - st1 {v14.4s, v15.4s}, [x18], x8 - st1 {v16.4s, v17.4s}, [x18], x8 - st1 {v18.4s, v19.4s}, [x18], x8 - st1 {v20.4s, v21.4s}, [x18], x8 - st1 {v22.4s, v23.4s}, [x18], x8 - st1 {v24.4s, v25.4s}, [x18], x8 - st1 {v26.4s, v27.4s}, [x18], x8 - st1 {v28.4s, v29.4s}, [x18], x8 - st1 {v30.4s, v31.4s}, [x18], x8 - b WriteEnd -Write8: - st1 {v8.4s, v9.4s}, [x18], x17 - cmp w10, #1 - beq WriteEnd - st1 {v10.4s, v11.4s}, [x18], x17 - cmp w10, #2 - beq WriteEnd - st1 {v12.4s, v13.4s}, [x18], x17 - cmp w10, #3 - beq WriteEnd - st1 {v14.4s, v15.4s}, [x18], x17 - cmp w10, #4 - beq WriteEnd - st1 {v16.4s, v17.4s}, [x18], x17 - cmp w10, #5 - beq WriteEnd - st1 {v18.4s, v19.4s}, [x18], x17 - cmp w10, #6 - beq WriteEnd - st1 {v20.4s, v21.4s}, [x18], x17 - cmp w10, #7 - beq WriteEnd - st1 {v22.4s, v23.4s}, [x18], x17 - cmp w10, #8 - beq WriteEnd - st1 {v24.4s, v25.4s}, [x18], x17 - cmp w10, #9 - beq WriteEnd - st1 {v26.4s, v27.4s}, [x18], x17 - cmp w10, #10 - beq WriteEnd - st1 {v28.4s, v29.4s}, [x18], x17 - cmp w10, #11 - beq WriteEnd - st1 {v30.4s, v31.4s}, [x18], x17 + LoopDepthHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] -WriteEnd: - subs w10, w10, #12 // lhs row - 12 - bgt L2 + subs x19, x19, #1 + bgt LoopDepthHalf -End2: - subs w7, w7, #8 // rhs col - 8 - add x1, x1, x15 // rhs ptr + stride - cbz x3, NoBiasStep - add x3, x3, #32 // bias ptr + stride -NoBiasStep: - cbnz x14, WinoDstStep - cbz x9, NoDstStep - add x2, x2, #32 // dst ptr + stride - b NoDstStep -WinoDstStep: - add x2, x2, x11 -NoDstStep: - bgt L1 + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s -End1: - sub sp, sp, #128 + ActivationHalf: + cmp x4, #2 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + str d8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str d10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str d12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str d14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str d16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str d18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str d20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str d22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str d24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str d26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str d28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str d30, [x11] + add x11, x11, x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + str d8, [x11] + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str d10, [x11] + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str d12, [x11] + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str d14, [x11] + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str d16, [x11] + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str d18, [x11] + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str d20, [x11] + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str d22, [x11] + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str d24, [x11] + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str d26, [x11] + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str d28, [x11] + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str d30, [x11] + st1 {v30.s}[2], [x19] + add x11, x11, x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str d9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str d11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str d13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str d15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str d17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str d19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str d21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str d23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str d25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str d27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str d29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str d31, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + str d9, [x19] + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str d11, [x19] + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str d13, [x19] + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str d15, [x19] + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str d17, [x19] + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str d19, [x19] + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str d21, [x19] + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str d23, [x19] + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str d25, [x19] + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str d27, [x19] + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str d29, [x19] + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str d31, [x19] + add x19, x19, x8 + st1 {v31.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol + + add x0, x0, x17 + cbz x9, C8DstStep + mov x18, #4 + mul x18, x18, x7 + sub x11, x11, x18 + mov x2, x11 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow + + sub sp, sp, #144 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 ret #endif diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S index dd42b49245..ccc6ce534f 100644 --- a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S @@ -6,139 +6,761 @@ .type MatmulFloatNeon64OptRemain, %function #endif -// void MatmulFloatNeon64(const float *a, const float *b, float *c, int depth -// int row, int col, size_t stride) +// void MatmulFloatNeon64Remain(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) // x0: a // x1: b // x2: c -// x3: depth -// x4: row -// x5: col -// x6: stride -// only for winograd +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + MatmulFloatNeon64OptRemain: - mov x18, #32 // sizeof(float) * 8 - mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + + mov x18, #48 // sizeof(float) * 12 + mul x17, x5, x18 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cbnz x9, NoC8Steps + mov x11, x2 + mov x18, #32 + mul x16, x6, x18 // row * 8 * sizeof(float) +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x18, #4 + mul x15, x7, x8 + mul x15, x15, x18 // kernel_size * col *sizeof(float) + mov x18, #32 + mul x16, x8, x18 // kernel_size * 8 * sizeof(float) +NoWinoSteps: mov x18, #4 - mul x8, x5, x6 - mov x11, #8 - mul x11, x11, x6 mul x8, x8, x18 - mul x11, x11, x18 - - cmp x4, #4 - ble LoopH4 - - LoopH8: - mov x10, x4 // reload lhs row - mov x12, x0 // reload lhs ptr - mov x18, x2 // reload dst ptr - - LoopW8: - mov x16, x1 // reload rhs ptr - mov x13, x3 // reload depth - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr - - LoopD8: - ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 - ld1 {v3.4s, v4.4s}, [x16], #32 - fmla v16.4s, v3.4s, v0.s[0] - fmla v18.4s, v3.4s, v0.s[1] - fmla v20.4s, v3.4s, v0.s[2] - fmla v22.4s, v3.4s, v0.s[3] - fmla v17.4s, v4.4s, v0.s[0] - fmla v19.4s, v4.4s, v0.s[1] - fmla v21.4s, v4.4s, v0.s[2] - fmla v23.4s, v4.4s, v0.s[3] - fmla v24.4s, v3.4s, v1.s[0] - fmla v26.4s, v3.4s, v1.s[1] - fmla v28.4s, v3.4s, v1.s[2] - fmla v30.4s, v3.4s, v1.s[3] - fmla v25.4s, v4.4s, v1.s[0] - fmla v27.4s, v4.4s, v1.s[1] - fmla v29.4s, v4.4s, v1.s[2] - fmla v31.4s, v4.4s, v1.s[3] - - subs w13, w13, #1 - bgt LoopD8 - - st1 {v16.4s, v17.4s}, [x18], x8 - st1 {v18.4s, v19.4s}, [x18], x8 - st1 {v20.4s, v21.4s}, [x18], x8 - st1 {v22.4s, v23.4s}, [x18], x8 - st1 {v24.4s, v25.4s}, [x18], x8 - st1 {v26.4s, v27.4s}, [x18], x8 - st1 {v28.4s, v29.4s}, [x18], x8 - st1 {v30.4s, v31.4s}, [x18], x8 - - subs x10, x10, #8 // lhs row - 8 - bgt LoopW8 - - subs x5, x5, #8 // rhs col - 8 - add x1, x1, x9 // rhs ptr + stride - add x2, x2, x11 - bgt LoopH8 - - ret - - LoopH4: - mov x10, x4 // reload lhs row - mov x12, x0 // reload lhs ptr - mov x18, x2 // reload dst ptr - - LoopW4: - mov x16, x1 // reload rhs ptr - mov x13, x3 // reload depth - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - - LoopD4: - ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 - ld1 {v3.4s, v4.4s}, [x16], #32 - fmla v16.4s, v3.4s, v0.s[0] - fmla v18.4s, v3.4s, v0.s[1] - fmla v20.4s, v3.4s, v0.s[2] - fmla v22.4s, v3.4s, v0.s[3] - fmla v17.4s, v4.4s, v0.s[0] - fmla v19.4s, v4.4s, v0.s[1] - fmla v21.4s, v4.4s, v0.s[2] - fmla v23.4s, v4.4s, v0.s[3] - - subs x13, x13, #1 - bgt LoopD4 - - st1 {v16.4s, v17.4s}, [x18], x8 - st1 {v18.4s, v19.4s}, [x18], x8 - st1 {v20.4s, v21.4s}, [x18], x8 - st1 {v22.4s, v23.4s}, [x18], x8 - - subs x10, x10, #4 // lhs row - 4 - bgt LoopW4 - - subs x5, x5, #8 // rhs col - 8 - add x1, x1, x9 // rhs ptr + stride - add x2, x2, x11 - bgt LoopH4 - ret + +LoopRow: + cmp x6, #4 + ble LoopRow4 + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + mov x11, x2 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #2 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #2 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + mov x11, x2 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #2 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #2 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + str d8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str d10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str d12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str d14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str d16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str d18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str d20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str d22, [x11] + add x11, x11, x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + str d8, [x11] + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str d10, [x11] + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str d12, [x11] + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str d14, [x11] + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str d16, [x11] + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str d18, [x11] + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str d20, [x11] + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str d22, [x11] + st1 {v22.s}[2], [x19], x8 + add x11, x11, x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str d9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str d11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str d13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str d15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str d17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str d19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str d21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str d23, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + str d9, [x19] + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str d11, [x19] + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str d13, [x19] + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str d15, [x19] + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str d17, [x19] + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str d19, [x19] + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str d21, [x19] + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str d23, [x19] + st1 {v23.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + add x11, x11, #32 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + b LoopCol8 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + mov x18, #4 + mul x18, x18, x7 + sub x11, x11, x18 + mov x2, x11 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow + + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret #endif diff --git a/mindspore/lite/nnacl/common_func.c b/mindspore/lite/nnacl/common_func.c index def3913b67..ec6e632a89 100644 --- a/mindspore/lite/nnacl/common_func.c +++ b/mindspore/lite/nnacl/common_func.c @@ -27,137 +27,6 @@ int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2) int offset4d(const int *shape, const int *dims) { return offset(shape, dims[0], dims[1], dims[2], dims[3]); } -#ifndef ENABLE_ARM64 -void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, - int output_channel, size_t offset, size_t relu, size_t relu6) { - for (int i = 0; i < TILE_NUM; i++) { - int input_tile_offset = i * C4NUM; - int output_tile_offset = i * output_channel; - for (int j = 0; j < output_channel; j++) { - int oc8_block = j / C8NUM; - int oc8_res = j % C8NUM; - int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res; - int out_oc_offset = output_tile_offset + j; - - float acc = 0; - for (int n = 0; n < step; n++) { - int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; - int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; - - for (int k = 0; k < ic4; k++) { - int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; - int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; - for (int m = 0; m < C4NUM; m++) { - int input_ic_offset = input_ic4_offset + m; - int weight_ic_offset = weight_ic4_offset + m * C8NUM; - acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; - } - } - } - acc += bias[j]; - if (relu) { - acc = acc > 0 ? acc : 0; - } else if (relu6) { - if (acc < 0) { - acc = 0; - } else if (acc > 6) { - acc = 6; - } else { - } - } - (output + out_oc_offset)[0] = acc; - } - } -} - -void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, - size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, - size_t relu6) { - int oc4 = UP_DIV(output_channel, C4NUM); - if (mode && writeC4) { - for (int i = 0; i < TILE_NUM; i++) { - int input_tile_offset = i * C4NUM; - int output_tile_offset = i * oc4 * C4NUM * step; - for (int j = 0; j < output_channel; j++) { - int oc4_block = j / 4; - int oc4_res = j % 4; - int oc8_block = oc4_block / 2; - int oc8_res = oc4_block % 2; - int weight_oc_offset = oc8_block * step * ic4 * C4NUM * C8NUM + oc8_res * C4NUM + oc4_res; - int out_oc_offset = output_tile_offset + oc4_block * step * C4NUM + oc4_res; - - for (int n = 0; n < step; n++) { - int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; - int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C8NUM; - int output_kw_offset = out_oc_offset + n * C4NUM; - float acc = 0; - - for (int k = 0; k < ic4; k++) { - int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; - int weight_ic4_offset = weight_kw_offset + k * C4NUM * C8NUM; - for (int m = 0; m < 4; m++) { - int input_ic_offset = input_ic4_offset + m; - int weight_ic_offset = weight_ic4_offset + m * C8NUM; - acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; - } - } - (output + output_kw_offset)[0] = acc; - } - } - } - } else if (mode) { - IndirectGemmFp32_Comm(output, input, weight, ic4, C8NUM, output_channel, offset); - } else { - IndirectGemmFp32(output, input, weight, bias, step, ic4, output_channel, offset, relu, relu6); - } -} -#endif - -#ifndef ENABLE_ARM32 -void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, - size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, - size_t relu6) { - for (int i = 0; i < TILE_NUM; i++) { - int input_tile_offset = i * C4NUM; - int output_tile_offset = i * output_channel; - for (int j = 0; j < output_channel; j++) { - int oc4_block = j / C4NUM; - int oc4_res = j % C4NUM; - int weight_oc_offset = oc4_block * step * ic4 * C4NUM * C4NUM + oc4_res; - int out_oc_offset = output_tile_offset + j; - - float acc = 0; - for (int n = 0; n < step; n++) { - int input_kw_offset = input_tile_offset + n * ic4 * C4NUM * TILE_NUM; - int weight_kw_offset = weight_oc_offset + n * ic4 * C4NUM * C4NUM; - - for (int k = 0; k < ic4; k++) { - int input_ic4_offset = input_kw_offset + k * TILE_NUM * C4NUM; - int weight_ic4_offset = weight_kw_offset + k * C4NUM * C4NUM; - for (int m = 0; m < C4NUM; m++) { - int input_ic_offset = input_ic4_offset + m; - int weight_ic_offset = weight_ic4_offset + m * C4NUM; - acc += (weight + weight_ic_offset)[0] * (input + input_ic_offset)[0]; - } - } - } - acc += bias[j]; - if (relu) { - acc = acc > 0 ? acc : 0; - } else if (relu6) { - if (acc < 0) { - acc = 0; - } else if (acc > 6) { - acc = 6; - } else { - } - } - (output + out_oc_offset)[0] = acc; - } - } -} -#endif - int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } @@ -210,21 +79,3 @@ void Relu6Fp32(float *data, float *dst, int ele_num) { data[j] = data[j] > 6 ? 6 : data[j]; } } - -void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, - size_t offset) { - for (int r = 0; r < hw; r++) { - for (int c = 0; c < oc; c++) { - float value = 0; - for (int deep = 0; deep < ic4; deep++) { - int d4mod = deep % 4; - int d4div = deep / 4; - int a_index = d4div * 4 * 8 + r * 4 + d4mod; - const int b_index = 8 * deep + c; - value += input[a_index] * weight[b_index]; - } - output[r * offset + c] = value; - } - } - return; -} diff --git a/mindspore/lite/nnacl/common_func.h b/mindspore/lite/nnacl/common_func.h index 93a761e1fd..f126227e30 100644 --- a/mindspore/lite/nnacl/common_func.h +++ b/mindspore/lite/nnacl/common_func.h @@ -31,18 +31,6 @@ int8_t MinInt8(int8_t a, int8_t b); int8_t MaxInt8(int8_t a, int8_t b); void ReluFp32(float *data, float *dst, int ele_num); void Relu6Fp32(float *data, float *dst, int ele_num); -void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, - int32_t left_shift, int32_t right_shift, int32_t zp); -void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, - size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, - size_t relu6); -void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, - size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, - size_t relu6); -void IndirectGemmFp32_Comm(float *output, const float *input, const float *weight, size_t ic4, size_t hw, size_t oc, - size_t offset); -void IndirectGemmFp32(float *output, const float *input, const float *weight, const float *bias, size_t step, int ic4, - int output_channel, size_t offset, size_t relu, size_t relu6); int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2); int offset4d(const int *shape, const int *dims); diff --git a/mindspore/lite/nnacl/fp32/matmul.c b/mindspore/lite/nnacl/fp32/matmul.c index e138f56a2d..934480fb28 100644 --- a/mindspore/lite/nnacl/fp32/matmul.c +++ b/mindspore/lite/nnacl/fp32/matmul.c @@ -470,14 +470,19 @@ void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, Ac void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, size_t stride, int out_type) { #ifdef ENABLE_ARM64 - if (out_type == 2 && row <= 8) { - MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride); + if (out_type == OutType_C8) { + MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else if (row <= 8) { + MatmulFloatNeon64OptRemain(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); } else { - MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), - (int)(out_type == OutType_TileC8)); + MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); } #elif ENABLE_ARM32 - MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + if (out_type == OutType_C8) { + MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else { + MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } #else MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); #endif diff --git a/mindspore/lite/nnacl/fp32/matmul.h b/mindspore/lite/nnacl/fp32/matmul.h index bb1d88295c..da3d7a7ac2 100644 --- a/mindspore/lite/nnacl/fp32/matmul.h +++ b/mindspore/lite/nnacl/fp32/matmul.h @@ -36,11 +36,14 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); #ifdef ENABLE_ARM64 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, size_t stride, bool write_nhwc); + int col, size_t stride, size_t writeNhwc, size_t WriteWino); void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, - int col, size_t stride, size_t write_nhwc, size_t write_c4); -void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride); + int col, size_t stride, size_t write_mode); +void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, size_t stride, size_t write_mode); #elif ENABLE_ARM32 +void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, size_t writeNhwc, size_t WriteWino); void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col, int stride, int write_mode); #endif