From: @lzkcode Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongpull/14941/MERGE
| @@ -607,7 +607,7 @@ build_lite() | |||
| cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ | |||
| -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ | |||
| -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ | |||
| -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ | |||
| -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DENABLE_FP16="on" \ | |||
| -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ | |||
| -DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \ | |||
| -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ | |||
| @@ -68,7 +68,7 @@ add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC}) | |||
| add_dependencies(nnacl fbs_src) | |||
| add_dependencies(nnacl_mid fbs_src) | |||
| ########################### arm64 build optimize library ######################## | |||
| if(PLATFORM_ARM64) | |||
| ########################### arm fp16 build optimize library ######################## | |||
| if(ENABLE_FP16) | |||
| add_subdirectory(${NNACL_DIR}/optimize) | |||
| endif() | |||
| @@ -30,7 +30,7 @@ typedef struct ArgElement { | |||
| int8_t i8_data_; | |||
| int32_t i_data_; | |||
| float f_data_; | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_ARM | |||
| float16_t f16_data_; | |||
| #endif | |||
| } data_; | |||
| @@ -0,0 +1,602 @@ | |||
| #ifdef ENABLE_ARM32 | |||
| #include "nnacl/assembly_global.h" | |||
| .text | |||
| .align 5 | |||
| .global MatMul12x8A32Fp16 | |||
| #ifndef __APPLE__ | |||
| .type MatMul12x8A32Fp16, %function | |||
| #endif | |||
| // void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| // int deep, int row, int col, int stride, bool write_mode); | |||
| // r0: a | |||
| // r1: b | |||
| // r2: dst | |||
| // r3: bias | |||
| // #4: depth | |||
| // #8: row | |||
| // #12: col | |||
| // #16: stride | |||
| // #20: writeNhwc/writeWino | |||
| asm_function MatMul12x8A32Fp16 | |||
| // r13(sp) and r15(pc) can not be used!! | |||
| // r9 r4 is tmp register | |||
| // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf | |||
| push {r3-r11, lr} | |||
| vpush {q4-q7} | |||
| add sp, sp, #104 | |||
| ldr r5, [sp, #4] | |||
| ldr r6, [sp, #8] | |||
| ldr r7, [sp, #12] | |||
| ldr r8, [sp, #16] | |||
| ldr lr, [sp, #20] | |||
| mov r10, r1 // b | |||
| mov r11, r0 // a | |||
| mov r12, r2 // dst | |||
| cmp lr, #2 | |||
| bne NoWinograd | |||
| mul r4, r8, r7 // stride * col | |||
| add r4, r4, r4 // r4 * sizeof(float16_t) | |||
| mov r9, #16 | |||
| mul r9, r8, r9 // stride * 8 * sizeof(float16_t) | |||
| NoWinograd: | |||
| add r8, r8, r8 // stride * sizeof(float16_t) | |||
| a .req r0 | |||
| weight .req r1 | |||
| dst .req r2 | |||
| bias .req r3 | |||
| depth .req r5 | |||
| row .req r6 | |||
| col .req r7 | |||
| stride .req r8 | |||
| b_tmp .req r10 | |||
| a_tmp .req r11 | |||
| dst_tmp .req r12 | |||
| .macro STORE_12x8 p1 | |||
| vst1.16 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x7 p1, p2, p3 | |||
| add r4, dst, #8 | |||
| add r9, dst, #12 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.32 {\p2}, [r4] | |||
| vst1.16 {\p3}, [r9] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x6 p1, p2 | |||
| add r4, dst, #8 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.32 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x5 p1, p2 | |||
| add r4, dst, #8 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.16 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x4 p1 | |||
| vst1.16 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x3 p1, p2 | |||
| add r4, dst, #4 | |||
| vst1.32 {\p1}, [dst] | |||
| vst1.16 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x2 p1 | |||
| vst1.32 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_12x1 p1 | |||
| vst1.16 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| .endm | |||
| .macro STORE_C8 p1, p2 | |||
| vst1.16 {\p1}, [dst] | |||
| cmp row, \p2 | |||
| add dst, dst, stride | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C7 p1, p2, p3, p4 | |||
| add r4, dst, #8 | |||
| add r9, dst, #12 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.32 {\p2}, [r4] | |||
| vst1.16 {\p3}, [r9] | |||
| add dst, dst, stride | |||
| cmp row, \p4 | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C6 p1, p2, p3 | |||
| add r4, dst, #8 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.32 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| cmp row, \p3 | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C5 p1, p2, p3 | |||
| add r4, dst, #8 | |||
| vst1.16 {\p1}, [dst] | |||
| vst1.16 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| cmp row, \p3 | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C4 p1, p2 | |||
| vst1.16 {\p1}, [dst] | |||
| cmp row, \p2 | |||
| add dst, dst, stride | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C3 p1, p2, p3 | |||
| add r4, dst, #4 | |||
| vst1.32 {\p1}, [dst] | |||
| vst1.16 {\p2}, [r4] | |||
| add dst, dst, stride | |||
| cmp row, \p3 | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C2 p1, p2 | |||
| vst1.32 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| cmp row, \p2 | |||
| beq WriteEnd | |||
| .endm | |||
| .macro STORE_C1 p1, p2 | |||
| vst1.16 {\p1}, [dst] | |||
| add dst, dst, stride | |||
| cmp row, \p2 | |||
| beq WriteEnd | |||
| .endm | |||
| LoopRow12: | |||
| ldr bias, [sp, #-40] | |||
| LoopCol8: | |||
| mov dst, dst_tmp | |||
| mov a, a_tmp | |||
| ldr depth, [sp, #4] | |||
| veor q4, q4, q4 | |||
| veor q5, q5, q5 | |||
| veor q6, q6, q6 | |||
| veor q7, q7, q7 | |||
| veor q8, q8, q8 | |||
| veor q9, q9, q9 | |||
| veor q10, q10, q10 | |||
| veor q11, q11, q11 | |||
| veor q12, q12, q12 | |||
| veor q13, q13, q13 | |||
| veor q14, q14, q14 | |||
| veor q15, q15, q15 | |||
| LoopDepth: | |||
| vld1.16 {q0, d2}, [a]! | |||
| vld1.16 {q2}, [weight]! | |||
| vmla.f16 q4, q2, d0[0] | |||
| vmla.f16 q5, q2, d0[1] | |||
| vmla.f16 q6, q2, d0[2] | |||
| vmla.f16 q7, q2, d0[3] | |||
| vmla.f16 q8, q2, d1[0] | |||
| vmla.f16 q9, q2, d1[1] | |||
| vmla.f16 q10, q2, d1[2] | |||
| vmla.f16 q11, q2, d1[3] | |||
| vmla.f16 q12, q2, d2[0] | |||
| vmla.f16 q13, q2, d2[1] | |||
| vmla.f16 q14, q2, d2[2] | |||
| vmla.f16 q15, q2, d2[3] | |||
| subs depth, depth, #1 | |||
| bne LoopDepth | |||
| Bias: | |||
| cmp bias, #0 | |||
| beq Activation | |||
| vld1.16 {q0}, [bias]! | |||
| vadd.f16 q4, q4, q0 | |||
| vadd.f16 q5, q5, q0 | |||
| vadd.f16 q6, q6, q0 | |||
| vadd.f16 q7, q7, q0 | |||
| vadd.f16 q8, q8, q0 | |||
| vadd.f16 q9, q9, q0 | |||
| vadd.f16 q10, q10, q0 | |||
| vadd.f16 q11, q11, q0 | |||
| vadd.f16 q12, q12, q0 | |||
| vadd.f16 q13, q13, q0 | |||
| vadd.f16 q14, q14, q0 | |||
| vadd.f16 q15, q15, q0 | |||
| Activation: | |||
| ldr lr, [sp] | |||
| cmp lr, #3 | |||
| beq Relu6 | |||
| cmp lr, #1 | |||
| beq Relu | |||
| b Write | |||
| Relu6: | |||
| vmov.i16 q2, #0x4600 | |||
| vadd.f16 q4, q4, q2 | |||
| vadd.f16 q5, q5, q2 | |||
| vadd.f16 q6, q6, q2 | |||
| vadd.f16 q7, q7, q2 | |||
| vmin.f16 q8, q8, q2 | |||
| vmin.f16 q9, q9, q2 | |||
| vmin.f16 q10, q10, q2 | |||
| vmin.f16 q11, q11, q2 | |||
| vmin.f16 q12, q12, q2 | |||
| vmin.f16 q13, q13, q2 | |||
| vmin.f16 q14, q14, q2 | |||
| vmin.f16 q15, q15, q2 | |||
| Relu: | |||
| veor q3, q3, q3 | |||
| vmax.f16 q4, q4, q3 | |||
| vmax.f16 q5, q5, q3 | |||
| vmax.f16 q6, q6, q3 | |||
| vmax.f16 q7, q7, q3 | |||
| vmax.f16 q8, q8, q3 | |||
| vmax.f16 q9, q9, q3 | |||
| vmax.f16 q10, q10, q3 | |||
| vmax.f16 q11, q11, q3 | |||
| vmax.f16 q12, q12, q3 | |||
| vmax.f16 q13, q13, q3 | |||
| vmax.f16 q14, q14, q3 | |||
| vmax.f16 q15, q15, q3 | |||
| Write: | |||
| ldr lr, [sp, #20] | |||
| cmp lr, #2 | |||
| beq WriteWinograd | |||
| cmp row, #12 | |||
| bge Write12xCol | |||
| b WriteRowxCol | |||
| WriteWinograd: | |||
| vst1.16 {q4}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q5}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q6}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q7}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q8}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q9}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q10}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q11}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q12}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q13}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q14}, [dst] | |||
| add dst, dst, r4 | |||
| vst1.16 {q15}, [dst] | |||
| add dst_tmp, dst_tmp, r9 | |||
| b WriteEnd | |||
| Write12xCol: | |||
| cmp col, #8 | |||
| bge Write12x8 | |||
| cmp col, #1 | |||
| beq Write12x1 | |||
| cmp col, #2 | |||
| beq Write12x2 | |||
| cmp col, #3 | |||
| beq Write12x3 | |||
| cmp col, #4 | |||
| beq Write12x4 | |||
| cmp col, #5 | |||
| beq Write12x5 | |||
| cmp col, #6 | |||
| beq Write12x6 | |||
| b Write12x7 | |||
| WriteRowxCol: | |||
| cmp col, #8 | |||
| bge WriteRowx8 | |||
| cmp col, #1 | |||
| beq WriteRowx1 | |||
| cmp col, #2 | |||
| beq WriteRowx2 | |||
| cmp col, #3 | |||
| beq WriteRowx3 | |||
| cmp col, #4 | |||
| beq WriteRowx4 | |||
| cmp col, #5 | |||
| beq WriteRowx5 | |||
| cmp col, #6 | |||
| beq WriteRowx6 | |||
| b WriteRowx7 | |||
| Write12x8: | |||
| STORE_12x8 q4 | |||
| STORE_12x8 q5 | |||
| STORE_12x8 q6 | |||
| STORE_12x8 q7 | |||
| STORE_12x8 q8 | |||
| STORE_12x8 q9 | |||
| STORE_12x8 q10 | |||
| STORE_12x8 q11 | |||
| STORE_12x8 q12 | |||
| STORE_12x8 q13 | |||
| STORE_12x8 q14 | |||
| STORE_12x8 q15 | |||
| b WriteEnd | |||
| WriteRowx8: | |||
| STORE_C8 q4, #1 | |||
| STORE_C8 q5, #2 | |||
| STORE_C8 q6, #3 | |||
| STORE_C8 q7, #4 | |||
| STORE_C8 q8, #5 | |||
| STORE_C8 q9, #6 | |||
| STORE_C8 q10, #7 | |||
| STORE_C8 q11, #8 | |||
| STORE_C8 q12, #9 | |||
| STORE_C8 q13, #10 | |||
| STORE_C8 q14, #11 | |||
| STORE_C8 q15, #12 | |||
| b WriteEnd | |||
| Write12x1: | |||
| STORE_12x1 d8[0] | |||
| STORE_12x1 d10[0] | |||
| STORE_12x1 d12[0] | |||
| STORE_12x1 d14[0] | |||
| STORE_12x1 d16[0] | |||
| STORE_12x1 d18[0] | |||
| STORE_12x1 d20[0] | |||
| STORE_12x1 d22[0] | |||
| STORE_12x1 d24[0] | |||
| STORE_12x1 d26[0] | |||
| STORE_12x1 d28[0] | |||
| STORE_12x1 d30[0] | |||
| b WriteEnd | |||
| WriteRowx1: | |||
| STORE_C1 d8[0], #1 | |||
| STORE_C1 d10[0], #2 | |||
| STORE_C1 d12[0], #3 | |||
| STORE_C1 d14[0], #4 | |||
| STORE_C1 d16[0], #5 | |||
| STORE_C1 d18[0], #6 | |||
| STORE_C1 d20[0], #7 | |||
| STORE_C1 d22[0], #8 | |||
| STORE_C1 d24[0], #9 | |||
| STORE_C1 d26[0], #10 | |||
| STORE_C1 d28[0], #11 | |||
| STORE_C1 d30[0], #12 | |||
| b WriteEnd | |||
| Write12x2: | |||
| STORE_12x2 d8[0] | |||
| STORE_12x2 d10[0] | |||
| STORE_12x2 d12[0] | |||
| STORE_12x2 d14[0] | |||
| STORE_12x2 d16[0] | |||
| STORE_12x2 d18[0] | |||
| STORE_12x2 d20[0] | |||
| STORE_12x2 d22[0] | |||
| STORE_12x2 d24[0] | |||
| STORE_12x2 d26[0] | |||
| STORE_12x2 d28[0] | |||
| STORE_12x2 d30[0] | |||
| b WriteEnd | |||
| WriteRowx2: | |||
| STORE_C2 d8[0], #1 | |||
| STORE_C2 d10[0], #2 | |||
| STORE_C2 d12[0], #3 | |||
| STORE_C2 d14[0], #4 | |||
| STORE_C2 d16[0], #5 | |||
| STORE_C2 d18[0], #6 | |||
| STORE_C2 d20[0], #7 | |||
| STORE_C2 d22[0], #8 | |||
| STORE_C2 d24[0], #9 | |||
| STORE_C2 d26[0], #10 | |||
| STORE_C2 d28[0], #11 | |||
| STORE_C2 d30[0], #12 | |||
| b WriteEnd | |||
| Write12x3: | |||
| STORE_12x3 d8[0], d8[2] | |||
| STORE_12x3 d10[0], d10[2] | |||
| STORE_12x3 d12[0], d12[2] | |||
| STORE_12x3 d14[0], d14[2] | |||
| STORE_12x3 d16[0], d16[2] | |||
| STORE_12x3 d18[0], d18[2] | |||
| STORE_12x3 d20[0], d20[2] | |||
| STORE_12x3 d22[0], d22[2] | |||
| STORE_12x3 d24[0], d24[2] | |||
| STORE_12x3 d26[0], d26[2] | |||
| STORE_12x3 d28[0], d28[2] | |||
| STORE_12x3 d30[0], d30[2] | |||
| b WriteEnd | |||
| WriteRowx3: | |||
| STORE_C3 d8[0], d8[2], #1 | |||
| STORE_C3 d10[0], d10[2], #2 | |||
| STORE_C3 d12[0], d12[2], #3 | |||
| STORE_C3 d14[0], d14[2], #4 | |||
| STORE_C3 d16[0], d16[2], #5 | |||
| STORE_C3 d18[0], d18[2], #6 | |||
| STORE_C3 d20[0], d20[2], #7 | |||
| STORE_C3 d22[0], d22[2], #8 | |||
| STORE_C3 d24[0], d24[2], #9 | |||
| STORE_C3 d26[0], d26[2], #10 | |||
| STORE_C3 d28[0], d28[2], #11 | |||
| STORE_C3 d30[0], d30[2], #12 | |||
| b WriteEnd | |||
| Write12x4: | |||
| STORE_12x4 d8 | |||
| STORE_12x4 d10 | |||
| STORE_12x4 d12 | |||
| STORE_12x4 d14 | |||
| STORE_12x4 d16 | |||
| STORE_12x4 d18 | |||
| STORE_12x4 d20 | |||
| STORE_12x4 d22 | |||
| STORE_12x4 d24 | |||
| STORE_12x4 d26 | |||
| STORE_12x4 d28 | |||
| STORE_12x4 d30 | |||
| b WriteEnd | |||
| WriteRowx4: | |||
| STORE_C4 d8, #1 | |||
| STORE_C4 d10, #2 | |||
| STORE_C4 d12, #3 | |||
| STORE_C4 d14, #4 | |||
| STORE_C4 d16, #5 | |||
| STORE_C4 d18, #6 | |||
| STORE_C4 d20, #7 | |||
| STORE_C4 d22, #8 | |||
| STORE_C4 d24, #9 | |||
| STORE_C4 d26, #10 | |||
| STORE_C4 d28, #11 | |||
| STORE_C4 d30, #12 | |||
| b WriteEnd | |||
| Write12x5: | |||
| STORE_12x5 d8, d9[0] | |||
| STORE_12x5 d10, d11[0] | |||
| STORE_12x5 d12, d13[0] | |||
| STORE_12x5 d14, d15[0] | |||
| STORE_12x5 d16, d17[0] | |||
| STORE_12x5 d18, d19[0] | |||
| STORE_12x5 d20, d21[0] | |||
| STORE_12x5 d22, d23[0] | |||
| STORE_12x5 d24, d25[0] | |||
| STORE_12x5 d26, d27[0] | |||
| STORE_12x5 d28, d29[0] | |||
| STORE_12x5 d30, d31[0] | |||
| b WriteEnd | |||
| WriteRowx5: | |||
| STORE_C5 d8, d9[0], #1 | |||
| STORE_C5 d10, d11[0], #2 | |||
| STORE_C5 d12, d13[0], #3 | |||
| STORE_C5 d14, d15[0], #4 | |||
| STORE_C5 d16, d17[0], #5 | |||
| STORE_C5 d18, d19[0], #6 | |||
| STORE_C5 d20, d21[0], #7 | |||
| STORE_C5 d22, d23[0], #8 | |||
| STORE_C5 d24, d25[0], #9 | |||
| STORE_C5 d26, d27[0], #10 | |||
| STORE_C5 d28, d29[0], #11 | |||
| STORE_C5 d30, d31[0], #12 | |||
| b WriteEnd | |||
| Write12x6: | |||
| STORE_12x6 d8, d9[0] | |||
| STORE_12x6 d10, d11[0] | |||
| STORE_12x6 d12, d13[0] | |||
| STORE_12x6 d14, d15[0] | |||
| STORE_12x6 d16, d17[0] | |||
| STORE_12x6 d18, d19[0] | |||
| STORE_12x6 d20, d21[0] | |||
| STORE_12x6 d22, d23[0] | |||
| STORE_12x6 d24, d25[0] | |||
| STORE_12x6 d26, d27[0] | |||
| STORE_12x6 d28, d29[0] | |||
| STORE_12x6 d30, d31[0] | |||
| b WriteEnd | |||
| WriteRowx6: | |||
| STORE_C6 d8, d9[0], #1 | |||
| STORE_C6 d10, d11[0], #2 | |||
| STORE_C6 d12, d13[0], #3 | |||
| STORE_C6 d14, d15[0], #4 | |||
| STORE_C6 d16, d17[0], #5 | |||
| STORE_C6 d18, d19[0], #6 | |||
| STORE_C6 d20, d21[0], #7 | |||
| STORE_C6 d22, d23[0], #8 | |||
| STORE_C6 d24, d25[0], #9 | |||
| STORE_C6 d26, d27[0], #10 | |||
| STORE_C6 d28, d29[0], #11 | |||
| STORE_C6 d30, d31[0], #12 | |||
| b WriteEnd | |||
| Write12x7: | |||
| STORE_12x7 d8, d9[0], d9[2] | |||
| STORE_12x7 d10, d11[0], d11[2] | |||
| STORE_12x7 d12, d13[0], d13[2] | |||
| STORE_12x7 d14, d15[0], d15[2] | |||
| STORE_12x7 d16, d17[0], d17[2] | |||
| STORE_12x7 d18, d19[0], d19[2] | |||
| STORE_12x7 d20, d21[0], d21[2] | |||
| STORE_12x7 d22, d23[0], d23[2] | |||
| STORE_12x7 d24, d25[0], d25[2] | |||
| STORE_12x7 d26, d27[0], d27[2] | |||
| STORE_12x7 d28, d29[0], d29[2] | |||
| STORE_12x7 d30, d31[0], d31[2] | |||
| b WriteEnd | |||
| WriteRowx7: | |||
| STORE_C7 d8, d9[0], d9[2], #1 | |||
| STORE_C7 d10, d11[0], d11[2], #2 | |||
| STORE_C7 d12, d13[0], d13[2], #3 | |||
| STORE_C7 d14, d15[0], d15[2], #4 | |||
| STORE_C7 d16, d17[0], d17[2], #5 | |||
| STORE_C7 d18, d19[0], d19[2], #6 | |||
| STORE_C7 d20, d21[0], d21[2], #7 | |||
| STORE_C7 d22, d23[0], d23[2], #8 | |||
| STORE_C7 d24, d25[0], d25[2], #9 | |||
| STORE_C7 d26, d27[0], d27[2], #10 | |||
| STORE_C7 d28, d29[0], d29[2], #11 | |||
| STORE_C7 d30, d31[0], d31[2], #12 | |||
| b WriteEnd | |||
| WriteEnd: | |||
| cmp col, #8 | |||
| ble LoopColEnd | |||
| sub col, col, #8 | |||
| ldr lr, [sp, #20] | |||
| cmp lr, #2 | |||
| beq LoopCol8 | |||
| add dst_tmp, dst_tmp, #16 | |||
| b LoopCol8 | |||
| LoopColEnd: | |||
| cmp row, #12 | |||
| ble LoopRowEnd | |||
| sub row, row, #12 | |||
| mov a_tmp, a | |||
| mov weight, b_tmp | |||
| ldr lr, [sp, #20] | |||
| cmp lr, #2 | |||
| beq WinogradDst | |||
| ldr lr, [sp, #12] | |||
| sub lr, lr, col | |||
| add lr, lr, lr // col *= 2 | |||
| sub dst_tmp, dst, lr | |||
| b LoopRow | |||
| WinogradDst: | |||
| add dst_tmp, dst, r9 | |||
| LoopRow: | |||
| mov dst, dst_tmp | |||
| ldr col, [sp, #12] | |||
| b LoopRow12 | |||
| LoopRowEnd: | |||
| sub sp, sp, #104 | |||
| vpop {q4-q7} | |||
| pop {r3-r11, pc} | |||
| #endif | |||
| @@ -1,667 +0,0 @@ | |||
| #ifdef ENABLE_ARM64 | |||
| #include "nnacl/assembly_global.h" | |||
| .text | |||
| .align 5 | |||
| // void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, | |||
| // size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); | |||
| // x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset, | |||
| // x8:mode, x9: writeC4, x10:relu, x11: relu6 | |||
| // compute 8 channel for 16 outputs | |||
| asm_function IndirectGemmFp16_16x8 | |||
| .macro INIT_BIAS | |||
| dup v16.4s, wzr | |||
| cbz x3, InitBias | |||
| ld1 {v16.8h}, [x3] | |||
| InitBias: | |||
| mov v17.16b, v16.16b | |||
| mov v18.16b, v16.16b | |||
| mov v19.16b, v16.16b | |||
| mov v20.16b, v16.16b | |||
| mov v21.16b, v16.16b | |||
| mov v22.16b, v16.16b | |||
| mov v23.16b, v16.16b | |||
| mov v24.16b, v16.16b | |||
| mov v25.16b, v16.16b | |||
| mov v26.16b, v16.16b | |||
| mov v27.16b, v16.16b | |||
| mov v28.16b, v16.16b | |||
| mov v29.16b, v16.16b | |||
| mov v30.16b, v16.16b | |||
| mov v31.16b, v16.16b | |||
| .endm | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ r29 should be also preserved | |||
| // whereas our coding style do not permit such amount of parameters | |||
| sub sp, sp, #144 | |||
| // performance between storing 4 registers at the same time and separately storing them on in-order cores | |||
| // is not tested yet | |||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||
| stp x19, x20, [sp], #16 | |||
| ldr x8, [sp, #0] | |||
| ldr x9, [sp, #8] | |||
| ldr x10, [sp, #16] | |||
| ldr x11, [sp, #24] | |||
| cbnz x8, IndirectGemmStart | |||
| // step is one for common convolution, where ic8 should multiply by kernel size | |||
| // step is (a+b-1) for F(a,b) in winograd | |||
| mul x5, x4, x5 | |||
| mov x4, #1 | |||
| IndirectGemmStart: | |||
| LoopOc: | |||
| mov x14, x4 | |||
| mov x12, x1 | |||
| LoopKsize: | |||
| mov x15, x0 | |||
| INIT_BIAS | |||
| // load input for output 1-8 | |||
| ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 | |||
| // load weight | |||
| ld1 {v8.8h, v9.8h}, [x2], #32 | |||
| // first 2 steps for output 1 and 3 | |||
| fmla v16.8h, v8.8h, v0.h[0] | |||
| fmla v18.8h, v8.8h, v1.h[0] | |||
| fmla v16.8h, v9.8h, v0.h[1] | |||
| fmla v18.8h, v9.8h, v1.h[1] | |||
| // load weight | |||
| ld1 {v10.8h, v11.8h}, [x2], #32 | |||
| // first 2 steps for output 2 and 4 | |||
| fmla v17.8h, v8.8h, v0.h[4] | |||
| fmla v19.8h, v8.8h, v1.h[4] | |||
| fmla v17.8h, v9.8h, v0.h[5] | |||
| fmla v19.8h, v9.8h, v1.h[5] | |||
| // load input for output 9-16 | |||
| // input cache should be refreshed after loading | |||
| // ATTENTION: advance is preferred, but advancing too much may lead to invalid prefetching | |||
| ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 | |||
| // last 2 steps for output 1 and 3 | |||
| fmla v16.8h, v10.8h, v0.h[2] | |||
| fmla v18.8h, v10.8h, v1.h[2] | |||
| fmla v16.8h, v11.8h, v0.h[3] | |||
| fmla v18.8h, v11.8h, v1.h[3] | |||
| // check if ic4=1 | |||
| subs x13, x5, #1 | |||
| beq LoopIcEnd | |||
| LoopIc: | |||
| // last 2 steps for output 2 and 4 | |||
| fmla v17.8h, v10.8h, v0.h[6] | |||
| fmla v19.8h, v10.8h, v1.h[6] | |||
| fmla v17.8h, v11.8h, v0.h[7] | |||
| fmla v19.8h, v11.8h, v1.h[7] | |||
| // steps for output 5-8 | |||
| fmla v20.8h, v8.8h, v2.h[0] | |||
| fmla v22.8h, v8.8h, v3.h[0] | |||
| fmla v20.8h, v9.8h, v2.h[1] | |||
| fmla v22.8h, v9.8h, v3.h[1] | |||
| fmla v21.8h, v8.8h, v2.h[4] | |||
| fmla v23.8h, v8.8h, v3.h[4] | |||
| fmla v21.8h, v9.8h, v2.h[5] | |||
| fmla v23.8h, v9.8h, v3.h[5] | |||
| fmla v20.8h, v10.8h, v2.h[2] | |||
| fmla v22.8h, v10.8h, v3.h[2] | |||
| fmla v20.8h, v11.8h, v2.h[3] | |||
| fmla v22.8h, v11.8h, v3.h[3] | |||
| fmla v21.8h, v10.8h, v2.h[6] | |||
| fmla v23.8h, v10.8h, v3.h[6] | |||
| fmla v21.8h, v11.8h, v2.h[7] | |||
| fmla v23.8h, v11.8h, v3.h[7] | |||
| // load input for output 1-8 | |||
| ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 | |||
| // steps for output 9-12 | |||
| fmla v24.8h, v8.8h, v4.h[0] | |||
| fmla v26.8h, v8.8h, v5.h[0] | |||
| fmla v24.8h, v9.8h, v4.h[1] | |||
| fmla v26.8h, v9.8h, v5.h[1] | |||
| fmla v25.8h, v8.8h, v4.h[4] | |||
| fmla v27.8h, v8.8h, v5.h[4] | |||
| fmla v25.8h, v9.8h, v4.h[5] | |||
| fmla v27.8h, v9.8h, v5.h[5] | |||
| fmla v24.8h, v10.8h, v4.h[2] | |||
| fmla v26.8h, v10.8h, v5.h[2] | |||
| fmla v24.8h, v11.8h, v4.h[3] | |||
| fmla v26.8h, v11.8h, v5.h[3] | |||
| fmla v25.8h, v10.8h, v4.h[6] | |||
| fmla v27.8h, v10.8h, v5.h[6] | |||
| fmla v25.8h, v11.8h, v4.h[7] | |||
| fmla v27.8h, v11.8h, v5.h[7] | |||
| // steps for output 13-16 | |||
| fmla v28.8h, v8.8h, v6.h[0] | |||
| fmla v30.8h, v8.8h, v7.h[0] | |||
| fmla v28.8h, v9.8h, v6.h[1] | |||
| fmla v30.8h, v9.8h, v7.h[1] | |||
| fmla v29.8h, v8.8h, v6.h[4] | |||
| fmla v31.8h, v8.8h, v7.h[4] | |||
| fmla v29.8h, v9.8h, v6.h[5] | |||
| fmla v31.8h, v9.8h, v7.h[5] | |||
| // load weight | |||
| ld1 {v8.8h, v9.8h}, [x2], #32 | |||
| fmla v28.8h, v10.8h, v6.h[2] | |||
| fmla v30.8h, v10.8h, v7.h[2] | |||
| fmla v28.8h, v11.8h, v6.h[3] | |||
| fmla v30.8h, v11.8h, v7.h[3] | |||
| fmla v29.8h, v10.8h, v6.h[6] | |||
| fmla v31.8h, v10.8h, v7.h[6] | |||
| fmla v29.8h, v11.8h, v6.h[7] | |||
| fmla v31.8h, v11.8h, v7.h[7] | |||
| // load weight | |||
| ld1 {v10.8h, v11.8h}, [x2], #32 | |||
| // first 2 steps for output 1-4 | |||
| fmla v16.8h, v8.8h, v0.h[0] | |||
| fmla v18.8h, v8.8h, v1.h[0] | |||
| fmla v16.8h, v9.8h, v0.h[1] | |||
| fmla v18.8h, v9.8h, v1.h[1] | |||
| fmla v17.8h, v8.8h, v0.h[4] | |||
| fmla v19.8h, v8.8h, v1.h[4] | |||
| fmla v17.8h, v9.8h, v0.h[5] | |||
| fmla v19.8h, v9.8h, v1.h[5] | |||
| // load input for output 9-16 | |||
| ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 | |||
| // last 2 steps for output 1 and 3 | |||
| fmla v16.8h, v10.8h, v0.h[2] | |||
| fmla v18.8h, v10.8h, v1.h[2] | |||
| fmla v16.8h, v11.8h, v0.h[3] | |||
| fmla v18.8h, v11.8h, v1.h[3] | |||
| subs x13, x13, #1 | |||
| bne LoopIc | |||
| LoopIcEnd: | |||
| fmla v17.8h, v10.8h, v0.h[6] | |||
| fmla v19.8h, v10.8h, v1.h[6] | |||
| fmla v17.8h, v11.8h, v0.h[7] | |||
| fmla v19.8h, v11.8h, v1.h[7] | |||
| // steps for output 5-8 | |||
| fmla v20.8h, v8.8h, v2.h[0] | |||
| fmla v22.8h, v8.8h, v3.h[0] | |||
| fmla v20.8h, v9.8h, v2.h[1] | |||
| fmla v22.8h, v9.8h, v3.h[1] | |||
| fmla v21.8h, v8.8h, v2.h[4] | |||
| fmla v23.8h, v8.8h, v3.h[4] | |||
| fmla v21.8h, v9.8h, v2.h[5] | |||
| fmla v23.8h, v9.8h, v3.h[5] | |||
| fmla v20.8h, v10.8h, v2.h[2] | |||
| fmla v22.8h, v10.8h, v3.h[2] | |||
| fmla v20.8h, v11.8h, v2.h[3] | |||
| fmla v22.8h, v11.8h, v3.h[3] | |||
| fmla v21.8h, v10.8h, v2.h[6] | |||
| fmla v23.8h, v10.8h, v3.h[6] | |||
| fmla v21.8h, v11.8h, v2.h[7] | |||
| fmla v23.8h, v11.8h, v3.h[7] | |||
| // steps for output 9-12 | |||
| fmla v24.8h, v8.8h, v4.h[0] | |||
| fmla v26.8h, v8.8h, v5.h[0] | |||
| fmla v24.8h, v9.8h, v4.h[1] | |||
| fmla v26.8h, v9.8h, v5.h[1] | |||
| fmla v25.8h, v8.8h, v4.h[4] | |||
| fmla v27.8h, v8.8h, v5.h[4] | |||
| fmla v25.8h, v9.8h, v4.h[5] | |||
| fmla v27.8h, v9.8h, v5.h[5] | |||
| fmla v24.8h, v10.8h, v4.h[2] | |||
| fmla v26.8h, v10.8h, v5.h[2] | |||
| fmla v24.8h, v11.8h, v4.h[3] | |||
| fmla v26.8h, v11.8h, v5.h[3] | |||
| fmla v25.8h, v10.8h, v4.h[6] | |||
| fmla v27.8h, v10.8h, v5.h[6] | |||
| fmla v25.8h, v11.8h, v4.h[7] | |||
| fmla v27.8h, v11.8h, v5.h[7] | |||
| // steps for output 13-16 | |||
| fmla v28.8h, v8.8h, v6.h[0] | |||
| fmla v30.8h, v8.8h, v7.h[0] | |||
| fmla v28.8h, v9.8h, v6.h[1] | |||
| fmla v30.8h, v9.8h, v7.h[1] | |||
| fmla v29.8h, v8.8h, v6.h[4] | |||
| fmla v31.8h, v8.8h, v7.h[4] | |||
| fmla v29.8h, v9.8h, v6.h[5] | |||
| fmla v31.8h, v9.8h, v7.h[5] | |||
| fmla v28.8h, v10.8h, v6.h[2] | |||
| fmla v30.8h, v10.8h, v7.h[2] | |||
| fmla v28.8h, v11.8h, v6.h[3] | |||
| fmla v30.8h, v11.8h, v7.h[3] | |||
| fmla v29.8h, v10.8h, v6.h[6] | |||
| fmla v31.8h, v10.8h, v7.h[6] | |||
| fmla v29.8h, v11.8h, v6.h[7] | |||
| fmla v31.8h, v11.8h, v7.h[7] | |||
| cbnz x11, Relu6 | |||
| cbnz x10, Relu | |||
| b WriteStart | |||
| Relu6: | |||
| movi v9.8h, #0x46, lsl #8 | |||
| fmin v16.8h, v16.8h, v9.8h | |||
| fmin v17.8h, v17.8h, v9.8h | |||
| fmin v18.8h, v18.8h, v9.8h | |||
| fmin v19.8h, v19.8h, v9.8h | |||
| fmin v20.8h, v20.8h, v9.8h | |||
| fmin v21.8h, v21.8h, v9.8h | |||
| fmin v22.8h, v22.8h, v9.8h | |||
| fmin v23.8h, v23.8h, v9.8h | |||
| fmin v24.8h, v24.8h, v9.8h | |||
| fmin v25.8h, v25.8h, v9.8h | |||
| fmin v26.8h, v26.8h, v9.8h | |||
| fmin v27.8h, v27.8h, v9.8h | |||
| fmin v28.8h, v28.8h, v9.8h | |||
| fmin v29.8h, v29.8h, v9.8h | |||
| fmin v30.8h, v30.8h, v9.8h | |||
| fmin v31.8h, v31.8h, v9.8h | |||
| Relu: | |||
| dup v8.4s, wzr | |||
| fmax v16.8h, v16.8h, v8.8h | |||
| fmax v17.8h, v17.8h, v8.8h | |||
| fmax v18.8h, v18.8h, v8.8h | |||
| fmax v19.8h, v19.8h, v8.8h | |||
| fmax v20.8h, v20.8h, v8.8h | |||
| fmax v21.8h, v21.8h, v8.8h | |||
| fmax v22.8h, v22.8h, v8.8h | |||
| fmax v23.8h, v23.8h, v8.8h | |||
| fmax v24.8h, v24.8h, v8.8h | |||
| fmax v25.8h, v25.8h, v8.8h | |||
| fmax v26.8h, v26.8h, v8.8h | |||
| fmax v27.8h, v27.8h, v8.8h | |||
| fmax v28.8h, v28.8h, v8.8h | |||
| fmax v29.8h, v29.8h, v8.8h | |||
| fmax v30.8h, v30.8h, v8.8h | |||
| fmax v31.8h, v31.8h, v8.8h | |||
| WriteStart: | |||
| cbnz x9, Write8 | |||
| cmp x6, #1 | |||
| beq Write1 | |||
| cmp x6, #2 | |||
| beq Write2 | |||
| cmp x6, #3 | |||
| beq Write3 | |||
| cmp x6, #4 | |||
| beq Write4 | |||
| cmp x6, #5 | |||
| beq Write5 | |||
| cmp x6, #6 | |||
| beq Write6 | |||
| cmp x6, #7 | |||
| beq Write7 | |||
| b Write8 | |||
| // prefetching is not preferred while writing results in spite of cache missing | |||
| // you could try prfm pstl2strm | |||
| // there are almost no benefits observed though | |||
| Write1: | |||
| str h16, [x15] | |||
| add x15, x15, x7 | |||
| str h17, [x15] | |||
| add x15, x15, x7 | |||
| str h18, [x15] | |||
| add x15, x15, x7 | |||
| str h19, [x15] | |||
| add x15, x15, x7 | |||
| str h20, [x15] | |||
| add x15, x15, x7 | |||
| str h21, [x15] | |||
| add x15, x15, x7 | |||
| str h22, [x15] | |||
| add x15, x15, x7 | |||
| str h23, [x15] | |||
| add x15, x15, x7 | |||
| str h24, [x15] | |||
| add x15, x15, x7 | |||
| str h25, [x15] | |||
| add x15, x15, x7 | |||
| str h26, [x15] | |||
| add x15, x15, x7 | |||
| str h27, [x15] | |||
| add x15, x15, x7 | |||
| str h28, [x15] | |||
| add x15, x15, x7 | |||
| str h29, [x15] | |||
| add x15, x15, x7 | |||
| str h30, [x15] | |||
| add x15, x15, x7 | |||
| str h31, [x15] | |||
| add x0, x0, #2 | |||
| b WriteEnd | |||
| Write2: | |||
| add x17, x15, #2 | |||
| st1 {v16.h}[0], [x15], x7 | |||
| st1 {v16.h}[1], [x17], x7 | |||
| st1 {v17.h}[0], [x15], x7 | |||
| st1 {v17.h}[1], [x17], x7 | |||
| st1 {v18.h}[0], [x15], x7 | |||
| st1 {v18.h}[1], [x17], x7 | |||
| st1 {v19.h}[0], [x15], x7 | |||
| st1 {v19.h}[1], [x17], x7 | |||
| st1 {v20.h}[0], [x15], x7 | |||
| st1 {v20.h}[1], [x17], x7 | |||
| st1 {v21.h}[0], [x15], x7 | |||
| st1 {v21.h}[1], [x17], x7 | |||
| st1 {v22.h}[0], [x15], x7 | |||
| st1 {v22.h}[1], [x17], x7 | |||
| st1 {v23.h}[0], [x15], x7 | |||
| st1 {v23.h}[1], [x17], x7 | |||
| st1 {v24.h}[0], [x15], x7 | |||
| st1 {v24.h}[1], [x17], x7 | |||
| st1 {v25.h}[0], [x15], x7 | |||
| st1 {v25.h}[1], [x17], x7 | |||
| st1 {v26.h}[0], [x15], x7 | |||
| st1 {v26.h}[1], [x17], x7 | |||
| st1 {v27.h}[0], [x15], x7 | |||
| st1 {v27.h}[1], [x17], x7 | |||
| st1 {v28.h}[0], [x15], x7 | |||
| st1 {v28.h}[1], [x17], x7 | |||
| st1 {v29.h}[0], [x15], x7 | |||
| st1 {v29.h}[1], [x17], x7 | |||
| st1 {v30.h}[0], [x15], x7 | |||
| st1 {v30.h}[1], [x17], x7 | |||
| st1 {v31.h}[0], [x15] | |||
| st1 {v31.h}[1], [x17] | |||
| add x0, x0, #4 | |||
| b WriteEnd | |||
| Write3: | |||
| add x17, x15, #4 | |||
| add x16, x15, #2 | |||
| st1 {v16.h}[0], [x15], x7 | |||
| st1 {v16.h}[1], [x16], x7 | |||
| st1 {v16.h}[2], [x17], x7 | |||
| st1 {v17.h}[0], [x15], x7 | |||
| st1 {v17.h}[1], [x16], x7 | |||
| st1 {v17.h}[2], [x17], x7 | |||
| st1 {v18.h}[0], [x15], x7 | |||
| st1 {v18.h}[1], [x16], x7 | |||
| st1 {v18.h}[2], [x17], x7 | |||
| st1 {v19.h}[0], [x15], x7 | |||
| st1 {v19.h}[1], [x16], x7 | |||
| st1 {v19.h}[2], [x17], x7 | |||
| st1 {v20.h}[0], [x15], x7 | |||
| st1 {v20.h}[1], [x16], x7 | |||
| st1 {v20.h}[2], [x17], x7 | |||
| st1 {v21.h}[0], [x15], x7 | |||
| st1 {v21.h}[1], [x16], x7 | |||
| st1 {v21.h}[2], [x17], x7 | |||
| st1 {v22.h}[0], [x15], x7 | |||
| st1 {v22.h}[1], [x16], x7 | |||
| st1 {v22.h}[2], [x17], x7 | |||
| st1 {v23.h}[0], [x15], x7 | |||
| st1 {v23.h}[1], [x16], x7 | |||
| st1 {v23.h}[2], [x17], x7 | |||
| st1 {v24.h}[0], [x15], x7 | |||
| st1 {v24.h}[1], [x16], x7 | |||
| st1 {v24.h}[2], [x17], x7 | |||
| st1 {v25.h}[0], [x15], x7 | |||
| st1 {v25.h}[1], [x16], x7 | |||
| st1 {v25.h}[2], [x17], x7 | |||
| st1 {v26.h}[0], [x15], x7 | |||
| st1 {v26.h}[1], [x16], x7 | |||
| st1 {v26.h}[2], [x17], x7 | |||
| st1 {v27.h}[0], [x15], x7 | |||
| st1 {v27.h}[1], [x16], x7 | |||
| st1 {v27.h}[2], [x17], x7 | |||
| st1 {v28.h}[0], [x15], x7 | |||
| st1 {v28.h}[1], [x16], x7 | |||
| st1 {v28.h}[2], [x17], x7 | |||
| st1 {v29.h}[0], [x15], x7 | |||
| st1 {v29.h}[1], [x16], x7 | |||
| st1 {v29.h}[2], [x17], x7 | |||
| st1 {v30.h}[0], [x15], x7 | |||
| st1 {v30.h}[1], [x16], x7 | |||
| st1 {v30.h}[2], [x17], x7 | |||
| st1 {v31.h}[0], [x15] | |||
| st1 {v31.h}[1], [x16] | |||
| st1 {v31.h}[2], [x17] | |||
| add x0, x0, #6 | |||
| b WriteEnd | |||
| Write4: | |||
| st1 {v16.4h}, [x15], x7 | |||
| st1 {v17.4h}, [x15], x7 | |||
| st1 {v18.4h}, [x15], x7 | |||
| st1 {v19.4h}, [x15], x7 | |||
| st1 {v20.4h}, [x15], x7 | |||
| st1 {v21.4h}, [x15], x7 | |||
| st1 {v22.4h}, [x15], x7 | |||
| st1 {v23.4h}, [x15], x7 | |||
| st1 {v24.4h}, [x15], x7 | |||
| st1 {v25.4h}, [x15], x7 | |||
| st1 {v26.4h}, [x15], x7 | |||
| st1 {v27.4h}, [x15], x7 | |||
| st1 {v28.4h}, [x15], x7 | |||
| st1 {v29.4h}, [x15], x7 | |||
| st1 {v30.4h}, [x15], x7 | |||
| st1 {v31.4h}, [x15] | |||
| add x0, x0, #8 | |||
| b WriteEnd | |||
    Write5:
        // Column tail of 5: per row, store the low 4 fp16 lanes from x15 and
        // halfword lane 4 through a second pointer x17 = x15 + 8 bytes
        // (i.e. column 4). Both pointers step by x7, the row stride in bytes.
        add x17, x15, #8
        st1 {v16.4h}, [x15], x7
        st1 {v16.h}[4], [x17], x7
        st1 {v17.4h}, [x15], x7
        st1 {v17.h}[4], [x17], x7
        st1 {v18.4h}, [x15], x7
        st1 {v18.h}[4], [x17], x7
        st1 {v19.4h}, [x15], x7
        st1 {v19.h}[4], [x17], x7
        st1 {v20.4h}, [x15], x7
        st1 {v20.h}[4], [x17], x7
        st1 {v21.4h}, [x15], x7
        st1 {v21.h}[4], [x17], x7
        st1 {v22.4h}, [x15], x7
        st1 {v22.h}[4], [x17], x7
        st1 {v23.4h}, [x15], x7
        st1 {v23.h}[4], [x17], x7
        st1 {v24.4h}, [x15], x7
        st1 {v24.h}[4], [x17], x7
        st1 {v25.4h}, [x15], x7
        st1 {v25.h}[4], [x17], x7
        st1 {v26.4h}, [x15], x7
        st1 {v26.h}[4], [x17], x7
        st1 {v27.4h}, [x15], x7
        st1 {v27.h}[4], [x17], x7
        st1 {v28.4h}, [x15], x7
        st1 {v28.h}[4], [x17], x7
        st1 {v29.4h}, [x15], x7
        st1 {v29.h}[4], [x17], x7
        st1 {v30.4h}, [x15], x7
        st1 {v30.h}[4], [x17], x7
        st1 {v31.4h}, [x15]         // last row: no post-increment
        st1 {v31.h}[4], [x17]
        add x0, x0, #10             // advance output base by 5 fp16 elements (10 bytes)
        b WriteEnd
| Write6: | |||
| add x17, x15, #8 | |||
| add x16, x15, #10 | |||
| st1 {v16.4h}, [x15], x7 | |||
| ins v0.s[0], v16.s[2] | |||
| st1 {v0.h}[0], [x17], x7 | |||
| st1 {v0.h}[1], [x16], x7 | |||
| st1 {v17.4h}, [x15], x7 | |||
| ins v1.s[0], v17.s[2] | |||
| st1 {v1.h}[0], [x17], x7 | |||
| st1 {v1.h}[1], [x16], x7 | |||
| st1 {v18.4h}, [x15], x7 | |||
| ins v2.s[0], v18.s[2] | |||
| st1 {v2.h}[0], [x17], x7 | |||
| st1 {v2.h}[1], [x16], x7 | |||
| st1 {v19.4h}, [x15], x7 | |||
| ins v3.s[0], v19.s[2] | |||
| st1 {v3.h}[0], [x17], x7 | |||
| st1 {v3.h}[1], [x16], x7 | |||
| st1 {v20.4h}, [x15], x7 | |||
| ins v4.s[0], v20.s[2] | |||
| st1 {v4.h}[0], [x17], x7 | |||
| st1 {v4.h}[1], [x16], x7 | |||
| st1 {v21.4h}, [x15], x7 | |||
| ins v5.s[0], v21.s[2] | |||
| st1 {v5.h}[0], [x17], x7 | |||
| st1 {v5.h}[1], [x16], x7 | |||
| st1 {v22.4h}, [x15], x7 | |||
| ins v6.s[0], v22.s[2] | |||
| st1 {v6.h}[0], [x17], x7 | |||
| st1 {v6.h}[1], [x16], x7 | |||
| st1 {v23.4h}, [x15], x7 | |||
| ins v7.s[0], v23.s[2] | |||
| st1 {v7.h}[0], [x17], x7 | |||
| st1 {v7.h}[1], [x16], x7 | |||
| st1 {v24.4h}, [x15], x7 | |||
| ins v8.s[0], v24.s[2] | |||
| st1 {v8.h}[0], [x17], x7 | |||
| st1 {v8.h}[1], [x16], x7 | |||
| st1 {v25.4h}, [x15], x7 | |||
| ins v9.s[0], v25.s[2] | |||
| st1 {v9.h}[0], [x17], x7 | |||
| st1 {v9.h}[1], [x16], x7 | |||
| st1 {v26.4h}, [x15], x7 | |||
| ins v10.s[0], v26.s[2] | |||
| st1 {v10.h}[0], [x17], x7 | |||
| st1 {v10.h}[1], [x16], x7 | |||
| st1 {v27.4h}, [x15], x7 | |||
| ins v11.s[0], v27.s[2] | |||
| st1 {v11.h}[0], [x17], x7 | |||
| st1 {v11.h}[1], [x16], x7 | |||
| st1 {v28.4h}, [x15], x7 | |||
| ins v12.s[0], v28.s[2] | |||
| st1 {v12.h}[0], [x17], x7 | |||
| st1 {v12.h}[1], [x16], x7 | |||
| st1 {v29.4h}, [x15], x7 | |||
| ins v13.s[0], v29.s[2] | |||
| st1 {v13.h}[0], [x17], x7 | |||
| st1 {v13.h}[1], [x16], x7 | |||
| st1 {v30.4h}, [x15], x7 | |||
| ins v14.s[0], v30.s[2] | |||
| st1 {v14.h}[0], [x17], x7 | |||
| st1 {v14.h}[1], [x16], x7 | |||
| st1 {v31.4h}, [x15] | |||
| ins v15.s[0], v31.s[2] | |||
| st1 {v14.h}[0], [x17] | |||
| st1 {v14.h}[1], [x16] | |||
| add x0, x0, #12 | |||
| b WriteEnd | |||
    Write7:
        // Column tail of 7: per row, store the low 4 fp16 lanes, then
        // halfwords 4..5 via a scratch register (ins copies the 32-bit pair
        // .s[2] -> .s[0]), then halfword 6 directly by lane store.
        //   x15 -> column 0, x17 -> column 4, x19 -> column 5, x16 -> column 6
        // All four pointers step by x7, the destination row stride in bytes.
        add x17, x15, #8
        add x19, x15, #10
        add x16, x15, #12
        st1 {v16.4h}, [x15], x7
        ins v0.s[0], v16.s[2]
        st1 {v0.h}[0], [x17], x7
        st1 {v0.h}[1], [x19], x7
        st1 {v16.h}[6], [x16], x7
        st1 {v17.4h}, [x15], x7
        ins v1.s[0], v17.s[2]
        st1 {v1.h}[0], [x17], x7
        st1 {v1.h}[1], [x19], x7
        st1 {v17.h}[6], [x16], x7
        st1 {v18.4h}, [x15], x7
        ins v2.s[0], v18.s[2]
        st1 {v2.h}[0], [x17], x7
        st1 {v2.h}[1], [x19], x7
        st1 {v18.h}[6], [x16], x7
        st1 {v19.4h}, [x15], x7
        ins v3.s[0], v19.s[2]
        st1 {v3.h}[0], [x17], x7
        st1 {v3.h}[1], [x19], x7
        st1 {v19.h}[6], [x16], x7
        st1 {v20.4h}, [x15], x7
        ins v4.s[0], v20.s[2]
        st1 {v4.h}[0], [x17], x7
        st1 {v4.h}[1], [x19], x7
        st1 {v20.h}[6], [x16], x7
        st1 {v21.4h}, [x15], x7
        ins v5.s[0], v21.s[2]
        st1 {v5.h}[0], [x17], x7
        st1 {v5.h}[1], [x19], x7
        st1 {v21.h}[6], [x16], x7
        st1 {v22.4h}, [x15], x7
        ins v6.s[0], v22.s[2]
        st1 {v6.h}[0], [x17], x7
        st1 {v6.h}[1], [x19], x7
        st1 {v22.h}[6], [x16], x7
        st1 {v23.4h}, [x15], x7
        ins v7.s[0], v23.s[2]
        st1 {v7.h}[0], [x17], x7
        st1 {v7.h}[1], [x19], x7
        st1 {v23.h}[6], [x16], x7
        st1 {v24.4h}, [x15], x7
        ins v8.s[0], v24.s[2]
        st1 {v8.h}[0], [x17], x7
        st1 {v8.h}[1], [x19], x7
        st1 {v24.h}[6], [x16], x7
        st1 {v25.4h}, [x15], x7
        ins v9.s[0], v25.s[2]
        st1 {v9.h}[0], [x17], x7
        st1 {v9.h}[1], [x19], x7
        st1 {v25.h}[6], [x16], x7
        st1 {v26.4h}, [x15], x7
        ins v10.s[0], v26.s[2]
        st1 {v10.h}[0], [x17], x7
        st1 {v10.h}[1], [x19], x7
        st1 {v26.h}[6], [x16], x7
        st1 {v27.4h}, [x15], x7
        ins v11.s[0], v27.s[2]
        st1 {v11.h}[0], [x17], x7
        st1 {v11.h}[1], [x19], x7
        st1 {v27.h}[6], [x16], x7
        st1 {v28.4h}, [x15], x7
        ins v12.s[0], v28.s[2]
        st1 {v12.h}[0], [x17], x7
        st1 {v12.h}[1], [x19], x7
        st1 {v28.h}[6], [x16], x7
        st1 {v29.4h}, [x15], x7
        ins v13.s[0], v29.s[2]
        st1 {v13.h}[0], [x17], x7
        st1 {v13.h}[1], [x19], x7
        st1 {v29.h}[6], [x16], x7
        st1 {v30.4h}, [x15], x7
        ins v14.s[0], v30.s[2]
        st1 {v14.h}[0], [x17], x7
        st1 {v14.h}[1], [x19], x7
        st1 {v30.h}[6], [x16], x7
        st1 {v31.4h}, [x15]         // last row: no post-increment
        ins v15.s[0], v31.s[2]
        st1 {v15.h}[0], [x17]
        st1 {v15.h}[1], [x19]
        st1 {v31.h}[6], [x16]
        add x0, x0, #14             // advance output base by 7 fp16 elements (14 bytes)
        b WriteEnd
    Write8:
        // Full 8-column case: each accumulator register holds a complete
        // fp16 output row; store all 8 lanes and step x15 by the row stride x7.
        st1 {v16.8h}, [x15], x7
        st1 {v17.8h}, [x15], x7
        st1 {v18.8h}, [x15], x7
        st1 {v19.8h}, [x15], x7
        st1 {v20.8h}, [x15], x7
        st1 {v21.8h}, [x15], x7
        st1 {v22.8h}, [x15], x7
        st1 {v23.8h}, [x15], x7
        st1 {v24.8h}, [x15], x7
        st1 {v25.8h}, [x15], x7
        st1 {v26.8h}, [x15], x7
        st1 {v27.8h}, [x15], x7
        st1 {v28.8h}, [x15], x7
        st1 {v29.8h}, [x15], x7
        st1 {v30.8h}, [x15], x7
        st1 {v31.8h}, [x15]
        add x0, x0, #16             // advance output base by 8 fp16 elements (16 bytes)
    WriteEnd:
        subs x14, x14, #1           // presumably the remaining-row-tile count — confirm in prologue
        bne LoopKsize
        subs x6, x6, #8             // 8 output columns consumed this pass
        cbz x3, NoStepForward       // bias pointer may be NULL: skip the step
        add x3, x3, #16             // advance bias by 8 fp16 values (16 bytes)
    NoStepForward:
        bgt LoopOc                  // flags still from "subs x6" (cbz/add don't set flags)
        // Epilogue: rewind sp to the callee-saved spill area written by the
        // prologue (outside this view), restore v8-v15 and x19/x20, return.
        sub sp, sp, #144
        ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
        ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
        ldp x19, x20, [sp], #16
        ret
    #endif
| @@ -29,7 +29,7 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| } | |||
| #endif | |||
| for (; offset < ele_num; offset++) { | |||
| dst[offset] = src[offset] < 0 ? 0 : src[offset]; | |||
| dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -47,14 +47,24 @@ int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { | |||
| } | |||
| #endif | |||
| for (; offset < ele_num; offset++) { | |||
| dst[offset] = data[offset] < 0 ? 0 : data[offset]; | |||
| dst[offset] = dst[offset] > 6 ? 6 : dst[offset]; | |||
| dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset]; | |||
| dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset]; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { | |||
| for (int i = 0; i < ele_num; ++i) { | |||
| int i = 0; | |||
| #ifdef ENABLE_NEON | |||
| int ele_c8 = UP_ROUND(ele_num, C8NUM); | |||
| for (; i < ele_c8; i += C8NUM) { | |||
| float16x8_t src_tmp = vld1q_f16(src + i); | |||
| float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha); | |||
| uint16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f)); | |||
| vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp)); | |||
| } | |||
| #endif | |||
| for (; i < ele_num; ++i) { | |||
| dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha); | |||
| } | |||
| return NNACL_OK; | |||
| @@ -62,12 +72,12 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha | |||
| int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_NEON | |||
| int count = (ele_num / C4NUM) * C4NUM; | |||
| for (; i < count; i += C4NUM) { | |||
| float32x4_t tmp; | |||
| simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); | |||
| vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); | |||
| vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); | |||
| } | |||
| #endif | |||
| for (; i < ele_num; ++i) { | |||
| @@ -79,9 +89,9 @@ int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| } | |||
| float16_t TanhOptFp16(float16_t src) { | |||
| if (src > 5.0) { | |||
| if (src > 5.0f) { | |||
| return 1.0f; | |||
| } else if (src < -5.0) { | |||
| } else if (src < -5.0f) { | |||
| return -1.0f; | |||
| } else { | |||
| float square = src * src; | |||
| @@ -93,7 +103,7 @@ float16_t TanhOptFp16(float16_t src) { | |||
| int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| int i = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_NEON | |||
| static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, | |||
| {17325.0f, 17325.0f, 17325.0f, 17325.0f}, | |||
| {135135.0f, 135135.0f, 135135.0f, 135135.0f}, | |||
| @@ -112,7 +122,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| float32x4_t b = vaddq_f32( | |||
| vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), | |||
| paramv[2]); | |||
| vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one))); | |||
| vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); | |||
| } | |||
| #endif | |||
| for (; i < ele_num; ++i) { | |||
| @@ -130,7 +140,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||
| for (int i = 0; i < ele_num; ++i) { | |||
| float16_t in = src[i]; | |||
| float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6); | |||
| float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); | |||
| dst[i] = in * relu6 / (float16_t)6.0f; | |||
| } | |||
| return NNACL_OK; | |||
| @@ -181,26 +191,26 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) | |||
| for (; i < C8; i += C8NUM) { | |||
| float16x8_t in = vld1q_f16(src + i); | |||
| float16x8_t res = | |||
| 0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in)); | |||
| 0.5f * in * (1.0f + MS_TANHX8_F16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * in * in) * in)); | |||
| vst1q_f16(dst + i, res); | |||
| } | |||
| #endif | |||
| for (; i < length; i++) { | |||
| dst[i] = | |||
| 0.5 * src[i] * | |||
| (1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i])); | |||
| 0.5f * src[i] * | |||
| (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); | |||
| } | |||
| } else { | |||
| #ifdef ENABLE_NEON | |||
| int C8 = UP_ROUND(length, C8NUM); | |||
| for (; i < C8; i += C8NUM) { | |||
| float16x8_t in = vld1q_f16(src + i); | |||
| const float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); | |||
| const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); | |||
| vst1q_f16(dst + i, res); | |||
| } | |||
| #endif | |||
| for (; i < length; i++) { | |||
| dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); | |||
| dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| @@ -16,11 +16,9 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/int8/fixed_point.h" | |||
| #ifdef __cplusplus | |||
| @@ -569,7 +569,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t * | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -591,7 +591,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vdivq_f16(vin0_opt, vin1); | |||
| float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -606,7 +606,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vdivq_f16(vin0, vin1_opt); | |||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -624,7 +624,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16 | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||
| vout = vmaxq_f16(vout, zeros); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| @@ -652,7 +652,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros); | |||
| float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -670,7 +670,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros); | |||
| float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -689,7 +689,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1 | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||
| vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| @@ -716,7 +716,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin1 = vld1q_f16(input1 + index); | |||
| float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros), bounds); | |||
| float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -733,7 +733,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo | |||
| #ifdef ENABLE_NEON | |||
| for (; index <= element_size - 8; index += C8NUM) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros), bounds); | |||
| float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| #endif | |||
| @@ -16,10 +16,8 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/base/arithmetic_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| @@ -80,7 +80,7 @@ int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size) | |||
| int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| output[i] = round(input[i]); | |||
| output[i] = roundf(input[i]); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -94,7 +94,7 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) { | |||
| int ElementCeilFp16(float16_t *input, float16_t *output, int number) { | |||
| for (int i = 0; i < number; ++i) { | |||
| output[i] = ceil(input[i]); | |||
| output[i] = ceilf(input[i]); | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| @@ -16,10 +16,8 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/errorcode.h" | |||
| #ifdef __cplusplus | |||
| @@ -26,7 +26,7 @@ void BatchNormFp16(const float16_t *input, const void *mean, const void *varianc | |||
| for (int i = 0; i < cur_unit; i++) { | |||
| for (int c = 0; c < param->channel_; c++) { | |||
| float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); | |||
| float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); | |||
| if (variance_sqrt != 0) { | |||
| output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | |||
| } | |||
| @@ -44,7 +44,7 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset | |||
| for (int i = 0; i < cur_unit; i++) { | |||
| for (int c = 0; c < param->channel_; c++) { | |||
| float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); | |||
| float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); | |||
| if (variance_sqrt != 0) { | |||
| float16_t norm_val = | |||
| (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | |||
| @@ -13,12 +13,9 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| #ifndef MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/batchnorm_parameter.h" | |||
| #ifdef __cplusplus | |||
| @@ -34,4 +31,4 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| #endif // MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||
| @@ -16,7 +16,6 @@ | |||
| #ifndef MINDSPORE_NNACL_CAST_FP16_H_ | |||
| #define MINDSPORE_NNACL_CAST_FP16_H_ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/op_base.h" | |||
| #ifdef __cplusplus | |||
| @@ -56,3 +56,17 @@ void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const floa | |||
| PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); | |||
| return; | |||
| } | |||
| #ifdef ENABLE_ARM82_A32 | |||
// Stub for the ARMv8.2 A32 build: the optimized kernel is not implemented
// yet, so this intentionally writes nothing to dst (callers get unmodified
// output until the A32 path is filled in).
// TODO(fun): implement fp16 C4 bias + relu post-processing for A32.
void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod,
                            size_t plane_size, size_t plane_stride, size_t relu_type) {
  return;
}
// Stub for the ARMv8.2 A32 build: the optimized kernel is not implemented
// yet, so this intentionally writes nothing to dst.
// TODO(fun): implement fp16 C8 bias + relu post-processing for A32.
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
                            size_t plane_size, size_t stride, size_t relu_type) {
  return;
}
| #endif | |||
| @@ -16,7 +16,6 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/op_base.h" | |||
| #ifdef __cplusplus | |||
| @@ -16,9 +16,6 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/constant_of_shape_parameter.h" | |||
| @@ -27,7 +24,7 @@ | |||
| extern "C" { | |||
| #endif | |||
| #ifdef __cplusplus | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_FP16 | |||
| inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { | |||
| for (int i = start; i < end; i++) { | |||
| output[i] = value; | |||
| @@ -18,6 +18,18 @@ | |||
| #include <string.h> | |||
| #include "nnacl/fp16/activation_fp16.h" | |||
| #ifdef ENABLE_ARM82_A32 | |||
| void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels, | |||
| size_t output_channel, size_t input_step) { | |||
| for (int i = 0; i < num_pixels; i++) { | |||
| for (int c = 0; c < output_channel; c++) { | |||
| *output_ptr++ += weight_ptr[c] * input_ptr[c]; | |||
| } | |||
| input_ptr += input_step; | |||
| } | |||
| } | |||
| #endif | |||
| void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, | |||
| const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { | |||
| int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); | |||
| @@ -57,7 +69,6 @@ void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float | |||
| const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; | |||
| int num_pixels = out_w_end - out_w_start; | |||
| ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); | |||
| weight_kh += conv_param->output_channel_; | |||
| } | |||
| @@ -23,9 +23,9 @@ | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| #ifdef ENABLE_ARM64 | |||
| void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, | |||
| size_t input_channel, size_t input_step); | |||
| #ifdef ENABLE_ARM64 | |||
| void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, | |||
| size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, | |||
| size_t relu6); | |||
| @@ -31,7 +31,7 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #ifndef ENABLE_NEON | |||
| #ifndef ENABLE_ARM64 | |||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||
| size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | |||
| size_t relu6) { | |||
| @@ -124,7 +124,11 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we | |||
| // fp16 convolution common (im2col+gemm) | |||
| void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | |||
| float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { | |||
| #ifdef ENABLE_ARM64 | |||
| const int tile_n = 16; | |||
| #else | |||
| const int tile_n = 12; | |||
| #endif | |||
| int out_channel = conv_param->output_channel_; | |||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||
| int output_tile_count = UP_DIV(output_count, tile_n); | |||
| @@ -144,7 +148,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||
| Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | |||
| int out_offset = thread_id * tile_n * out_channel + out_batch_offset; | |||
| #ifdef ENABLE_ARM64 | |||
| RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); | |||
| #else | |||
| RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); | |||
| #endif | |||
| MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, | |||
| real_cal_num, out_channel, out_channel, OutType_Nhwc); | |||
| } | |||
| @@ -155,7 +163,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | |||
| float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | |||
| InputTransFp16Func in_func, OutputTransFp16Func out_func) { | |||
| #ifdef ENABLE_ARM64 | |||
| const int tile_num = 16; | |||
| #else | |||
| const int tile_num = 12; | |||
| #endif | |||
| int in_channel = conv_param->input_channel_; | |||
| int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); | |||
| int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); | |||
| @@ -194,7 +206,11 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||
| float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; | |||
| float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | |||
| for (int i = 0; i < input_unit_square; ++i) { | |||
| #ifdef ENABLE_ARM64 | |||
| RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); | |||
| #else | |||
| RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); | |||
| #endif | |||
| MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, | |||
| cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); | |||
| } | |||
| @@ -24,7 +24,7 @@ | |||
| typedef float16_t *TmpBufferAddressFp16; | |||
| typedef float16_t *MatricesFp16; | |||
| #ifndef ENABLE_NEON | |||
| #ifndef ENABLE_ARM64 | |||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | |||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, | |||
| size_t relu6); | |||
| @@ -17,7 +17,6 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_CROP_FP16_H_ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/crop_parameter.h" | |||
| @@ -53,11 +53,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, | |||
| int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; | |||
| float16_t *tmp_dst = dst_ptr + dst_index; | |||
| const float16_t *tmp_src = src_ptr + src_index; | |||
| #ifdef DEBUG_CODE | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_dst[i] += tmp_src[i]; | |||
| } | |||
| #else | |||
| #ifdef ENABLE_ARM64 | |||
| asm volatile( | |||
| "mov x0, %[tmp_src] \n" | |||
| "mov x1, %[tmp_dst] \n" | |||
| @@ -72,6 +68,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, | |||
| : | |||
| : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) | |||
| : "x0", "x1", "v0", "v1"); | |||
| #else | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_dst[i] += tmp_src[i]; | |||
| } | |||
| #endif | |||
| } /*kw*/ | |||
| } /*kh*/ | |||
| @@ -47,6 +47,7 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||
| size_t cuont8 = count / C8NUM * C8NUM; | |||
| int i = 0; | |||
| for (; i < cuont8; i += C8NUM) { | |||
| #ifdef ENABLE_ARM64 | |||
| size_t src_step = src_stride * sizeof(float16_t); | |||
| size_t dst_step = dst_stride * sizeof(float16_t); | |||
| asm volatile( | |||
| @@ -93,7 +94,9 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||
| : | |||
| : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) | |||
| : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); | |||
| #else | |||
| // TODO(fun): arm32 | |||
| #endif | |||
| src_ptr += C8NUM * src_stride; | |||
| dst_ptr += C8NUM * dst_stride; | |||
| } | |||
| @@ -373,3 +376,23 @@ void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParamet | |||
| } | |||
| return; | |||
| } | |||
| #ifdef ENABLE_ARM82_A32 | |||
// Stub for the ARMv8.2 A32 build: the optimized kernel is not implemented
// yet, so this intentionally leaves M untouched.
// TODO(fun): implement the fp16 Winograd left transform for A32.
void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
                           size_t length) {
  return;
}
// Stub for the ARMv8.2 A32 build: the optimized kernel is not implemented
// yet, so this intentionally leaves M untouched.
// TODO(fun): implement the fp16 Winograd right transform for A32.
void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
                            size_t length) {
  return;
}
// Stub for the ARMv8.2 A32 build: the optimized kernel is not implemented
// yet, so this intentionally writes nothing to dst.
// TODO(fun): implement the fp16 tiled C4 matmul for A32.
void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num,
                       size_t oc4) {
  return;
}
| #endif | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_NNACL_FP16_EXP_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -16,6 +16,7 @@ | |||
| #include "nnacl/fp16/instance_norm_fp16.h" | |||
| #include <math.h> | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, | |||
| const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { | |||
| @@ -17,7 +17,6 @@ | |||
| #define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_ | |||
| #include "nnacl/instance_norm_parameter.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -21,6 +21,7 @@ | |||
| #include "nnacl/fp16/arithmetic_fp16.h" | |||
| #include "nnacl/fp16/matmul_fp16.h" | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { | |||
| for (int i = 0; i < batch; i++) { | |||
| @@ -121,7 +122,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1 | |||
| for (; index <= element_size - 8; index += 8) { | |||
| float16x8_t vin0 = vld1q_f16(input0 + index); | |||
| float16x8_t vout = vld1q_f16(output + index); | |||
| vout = vfmaq_n_f16(vout, vin0, input1); | |||
| vout = MS_FMAQ_N_F16(vout, vin0, input1); | |||
| vst1q_f16(output + index, vout); | |||
| } | |||
| for (; index < element_size; index++) { | |||
| @@ -226,24 +226,43 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, | |||
| return; | |||
| } | |||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, bool write_nhwc) { | |||
| if (write_nhwc) { | |||
| /* col16-major * row8-major => col-major */ | |||
| void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, int write_mode) { | |||
| if (write_mode == OutType_Nhwc) { | |||
| for (int r = 0; r < row; r++) { | |||
| for (int c = 0; c < col; c++) { | |||
| int r16div = r / 16, r16mod = r % 16; | |||
| int c8div = c / 8, c8mod = c % 8; | |||
| size_t ci = r * stride + c; | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r16div * deep * 16 + d * 16 + r16mod; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| ADD_BIAS(value, bias, c) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } else if (write_mode == OutType_C8) { | |||
| int col_8 = UP_ROUND(col, C8NUM); | |||
| int row_16 = UP_ROUND(row, C16NUM); | |||
| for (int r = 0; r < row_16; r++) { | |||
| for (int c = 0; c < col_8; c++) { | |||
| int r16div = r / C16NUM, r16mod = r % C16NUM; | |||
| int c8div = c / C8NUM, c8mod = c % C8NUM; | |||
| size_t ci = r * stride + c; | |||
| float value = 0; | |||
| size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; | |||
| size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| if (bias != NULL) value += bias[c]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| ADD_BIAS(value, bias, c) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| @@ -254,37 +273,119 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl | |||
| for (int j = 0; j < col; ++j) { | |||
| int c8div = j / 8, c8mod = j % 8; | |||
| size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | |||
| float value = 0; | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; ++d) { | |||
| size_t ai = src_r_offset + d * C16NUM; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| if (bias != NULL) value += bias[j]; | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||
| ADD_BIAS(value, bias, j) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, int write_mode) { | |||
| if (write_mode == OutType_Nhwc) { | |||
| for (int r = 0; r < row; r++) { | |||
| for (int c = 0; c < col; c++) { | |||
| int r12div = r / 12, r12mod = r % 12; | |||
| int c8div = c / 8, c8mod = c % 8; | |||
| size_t ci = r * stride + c; | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r12div * deep * 12 + d * 12 + r12mod; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| ADD_BIAS(value, bias, c) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } else if (write_mode == OutType_C8) { | |||
| int col_8 = UP_ROUND(col, C8NUM); | |||
| int row_12 = UP_ROUND(row, C12NUM); | |||
| for (int r = 0; r < row_12; r++) { | |||
| for (int c = 0; c < col_8; c++) { | |||
| int r12div = r / C12NUM, r12mod = r % C12NUM; | |||
| int c8div = c / C8NUM, c8mod = c % C8NUM; | |||
| size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; d++) { | |||
| size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; | |||
| size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| ADD_BIAS(value, bias, c) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } else { | |||
| for (int i = 0; i < row; ++i) { | |||
| int src_r_offset = i; | |||
| int dst_r_offset = i * col * stride; | |||
| for (int j = 0; j < col; ++j) { | |||
| int c8div = j / 8, c8mod = j % 8; | |||
| size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | |||
| float16_t value = 0; | |||
| for (int d = 0; d < deep; ++d) { | |||
| size_t ai = src_r_offset + d * C12NUM; | |||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| ADD_BIAS(value, bias, j) | |||
| DO_RELU(value, act_type) | |||
| DO_RELU6(value, act_type) | |||
| dst[ci] = value; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | |||
| int depth, int row, int col, int stride, int out_type) { | |||
| if (out_type == OutType_C8) { | |||
| #ifdef ENABLE_ARM64 | |||
| MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); | |||
| #else | |||
| MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | |||
| #endif | |||
| } else { | |||
| #ifdef ENABLE_ARM64 | |||
| MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | |||
| #else | |||
| MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | |||
| #endif | |||
| } | |||
| return; | |||
| } | |||
#ifdef ENABLE_ARM82_A32
// fp16 matrix-vector multiply for the armv8.2 A32 build.
// NOTE(review): this is an unimplemented stub — it returns without writing
// `c`, so on A32 any caller reaching MatVecMulFp16 currently gets an
// untouched output buffer.  Confirm no A32 code path selects the
// mat-vec route until this is filled in.
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      int depth, int col) {
  // TODO(fun): implement (see MatVecMulFp16Neon64 for the arm64 counterpart).
  return;
}
#endif
// Dispatch the fp16 matrix-vector multiply to the per-target kernel.
// NOTE(review): the #else branch calls MatVecMulA32Fp16, which is only
// declared under ENABLE_ARM82_A32 — a build with neither ARM flag defined
// will fail to compile/link here; verify such configurations never reach
// this translation unit.
void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                   int depth, int col) {
#ifdef ENABLE_ARM64
  MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
#else
  MatVecMulA32Fp16(a, b, c, bias, (int)act_type, depth, col);
#endif
}
| #ifdef ENABLE_ARM64 | |||
| static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { | |||
| size_t stride = col * 2; | |||
| asm volatile( | |||
| @@ -392,6 +493,7 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_ | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31"); | |||
| } | |||
| #endif | |||
| void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | |||
| size_t row_up_16 = UP_ROUND(row, C16NUM); | |||
| @@ -442,6 +544,54 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si | |||
| return; | |||
| } | |||
| void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | |||
| size_t row_up_12 = UP_ROUND(row, C12NUM); | |||
| size_t row12 = row / C12NUM * C12NUM; | |||
| size_t col8 = col / C8NUM * C8NUM; | |||
| const float16_t *src_r = src_ptr; | |||
| float16_t *dst_r = dst_ptr; | |||
| size_t ri = 0; | |||
| // transpose 12x8 | |||
| for (; ri < row12; ri += C12NUM) { | |||
| size_t ci = 0; | |||
| for (; ci < col8; ci += C8NUM) { | |||
| const float16_t *src_c = src_r + ci; | |||
| float16_t *dst_c = dst_r + ci * C12NUM; | |||
| #ifdef ENABLE_ARM82_A32 | |||
| Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); | |||
| #else | |||
| for (int tr = 0; tr < C12NUM; tr++) { | |||
| for (int tc = 0; tc < C8NUM; tc++) { | |||
| dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; | |||
| } | |||
| } | |||
| #endif | |||
| } | |||
| for (; ci < col; ci++) { | |||
| const float16_t *src_c = src_r + ci; | |||
| float16_t *dst_c = dst_r + ci * C12NUM; | |||
| for (size_t i = 0; i < C12NUM; i++) { | |||
| dst_c[i] = src_c[i * col]; | |||
| } | |||
| } | |||
| src_r += C12NUM * col; | |||
| dst_r += C12NUM * col; | |||
| } | |||
| for (; ri < row; ri++) { | |||
| for (size_t i = 0; i < col; ++i) { | |||
| dst_r[i * C12NUM] = src_r[i]; | |||
| } | |||
| src_r += col; | |||
| dst_r += 1; | |||
| } | |||
| for (; ri < row_up_12; ri++) { | |||
| for (size_t i = 0; i < col; i++) { | |||
| dst_r[i * C12NUM] = 0; | |||
| } | |||
| dst_r += 1; | |||
| } | |||
| } | |||
| void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { | |||
| if (is_fp32_src) { | |||
| const float *fp32_src = (const float *)src; | |||
| @@ -19,18 +19,51 @@ | |||
| #include <float.h> | |||
| #include <string.h> | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/errorcode.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/fp16/pack_fp16.h" | |||
| #define ADD_BIAS(value, bias, c) \ | |||
| if (bias != NULL) value = value + bias[c]; | |||
| #define DO_RELU(value, act_type) \ | |||
| if (act_type == ActType_Relu) value = MSMAX(0.0f, value); | |||
| #define DO_RELU6(value, act_type) \ | |||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ | |||
| if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, bool write_nhwc); | |||
| void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, int write_mode); | |||
| void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, int write_mode); | |||
| #ifdef ENABLE_ARM64 | |||
| void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); | |||
| void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); | |||
| void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| int depth, int col); | |||
| #elif ENABLE_ARM82_A32 | |||
| void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||
| int deep, int row, int col, int stride, int write_mode); | |||
| void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| int depth, int col); | |||
| #endif | |||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | |||
| int depth, int row, int col, int stride, int out_type); | |||
| @@ -42,14 +75,7 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, | |||
| void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | |||
| void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); | |||
| void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); | |||
| void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||
| int depth, int col); | |||
| void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | |||
| void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); | |||
| @@ -39,7 +39,7 @@ void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matri | |||
| for (int i = 0; i < m; ++i) { | |||
| for (int j = 0; j < n; ++j) { | |||
| for (int y = 0; y < in_channel; ++y) { | |||
| float16_t tmp = 0; | |||
| float tmp = 0; | |||
| for (int z = 0; z < k; ++z) { | |||
| tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; | |||
| } | |||
| @@ -160,129 +160,35 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int | |||
| } | |||
| void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) { | |||
| int hw16 = plane / C16NUM * C16NUM; | |||
| #ifdef ENABLE_ARM64 | |||
| // Transpose16x8 in arm64 | |||
| const int hw_tile = C16NUM; | |||
| #else | |||
| // Transpose8x8 in others | |||
| const int hw_tile = C8NUM; | |||
| #endif | |||
| int hw_align = plane / hw_tile * hw_tile; | |||
| int c8 = channel / C8NUM * C8NUM; | |||
| int batch = plane * channel; | |||
| for (int n = 0; n < batches; n++) { | |||
| const float16_t *src_batch = (const float16_t *)src + n * batch; | |||
| float16_t *dst_batch = (float16_t *)dst + n * batch; | |||
| int hw = 0; | |||
| for (; hw < hw16; hw += C16NUM) { | |||
| for (; hw < hw_align; hw += hw_tile) { | |||
| int c = 0; | |||
| for (; c < c8; c += C8NUM) { | |||
| const float16_t *src_ptr = src_batch + hw * channel + c; | |||
| float16_t *dst_ptr = dst_batch + c * plane + hw; | |||
| #ifdef ENABLE_ARM64 | |||
| size_t srcStride = channel * sizeof(float16_t); | |||
| size_t dstStride = plane * sizeof(float16_t); | |||
| asm volatile( | |||
| "mov x10, %[src_ptr]\n" | |||
| "mov x11, %[dst_ptr]\n" | |||
| "ld1 {v0.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v1.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v2.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v3.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v4.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v5.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v6.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v7.8h}, [x10], %[srcStride]\n" | |||
| "zip1 v16.8h, v0.8h, v1.8h\n" | |||
| "zip1 v17.8h, v2.8h, v3.8h\n" | |||
| "zip1 v18.8h, v4.8h, v5.8h\n" | |||
| "zip1 v19.8h, v6.8h, v7.8h\n" | |||
| "ld1 {v8.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v9.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v10.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v11.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v12.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v13.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v14.8h}, [x10], %[srcStride]\n" | |||
| "ld1 {v15.8h}, [x10], %[srcStride]\n" | |||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||
| "zip1 v16.8h, v8.8h, v9.8h\n" | |||
| "zip1 v17.8h, v10.8h, v11.8h\n" | |||
| "zip1 v18.8h, v12.8h, v13.8h\n" | |||
| "zip1 v19.8h, v14.8h, v15.8h\n" | |||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||
| "add x10, x11, #16\n" | |||
| "st1 {v24.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v28.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v26.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v30.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v25.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v29.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v27.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v31.8h}, [x10], %[dstStride]\n" | |||
| "zip2 v16.8h, v0.8h, v1.8h\n" | |||
| "zip2 v17.8h, v2.8h, v3.8h\n" | |||
| "zip2 v18.8h, v4.8h, v5.8h\n" | |||
| "zip2 v19.8h, v6.8h, v7.8h\n" | |||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||
| "zip2 v16.8h, v8.8h, v9.8h\n" | |||
| "zip2 v17.8h, v10.8h, v11.8h\n" | |||
| "zip2 v18.8h, v12.8h, v13.8h\n" | |||
| "zip2 v19.8h, v14.8h, v15.8h\n" | |||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||
| "st1 {v24.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v28.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v26.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v30.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v25.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v29.8h}, [x10], %[dstStride]\n" | |||
| "st1 {v27.8h}, [x11], %[dstStride]\n" | |||
| "st1 {v31.8h}, [x10], %[dstStride]\n" | |||
| : | |||
| : | |||
| [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) | |||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | |||
| "v30", "v31"); | |||
| size_t src_stride = channel * sizeof(float16_t); | |||
| size_t dst_stride = plane * sizeof(float16_t); | |||
| Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); | |||
| #elif defined(ENABLE_ARM82_A32) | |||
| size_t src_stride = channel * sizeof(float16_t); | |||
| size_t dst_stride = plane * sizeof(float16_t); | |||
| Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); | |||
| #else | |||
| for (int tr = 0; tr < C16NUM; tr++) { | |||
| for (int tr = 0; tr < hw_tile; tr++) { | |||
| for (int tc = 0; tc < C8NUM; tc++) { | |||
| dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; | |||
| } | |||
| @@ -292,7 +198,7 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int | |||
| for (; c < channel; c++) { | |||
| const float16_t *src_ptr = src_batch + hw * channel + c; | |||
| float16_t *dst_ptr = dst_batch + c * plane + hw; | |||
| for (size_t i = 0; i < C16NUM; i++) { | |||
| for (size_t i = 0; i < hw_tile; i++) { | |||
| dst_ptr[i] = src_ptr[i * channel]; | |||
| } | |||
| } | |||
| @@ -305,7 +211,6 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { | |||
| @@ -565,3 +470,246 @@ void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, i | |||
| } | |||
| } | |||
| } | |||
#ifdef ENABLE_ARM82_A32
// Transpose an 8x8 block of fp16 values with A32 NEON (vtrn.16 / vtrn.32 /
// vswp ladder).  `src_stride` / `dst_stride` are the byte distances between
// consecutive source / destination rows.  Both pointers must cover a full
// 8x8 block; no tail handling is done here.
inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov r10, %[src]\n"
    "mov r12, %[dst]\n"
    "vld1.16 {q0}, [r10], %[src_stride]\n"
    "vld1.16 {q2}, [r10], %[src_stride]\n"
    "vld1.16 {q4}, [r10], %[src_stride]\n"
    "vld1.16 {q6}, [r10], %[src_stride]\n"
    "vtrn.16 d0, d4\n"
    "vtrn.16 d1, d5\n"
    "vtrn.16 d8, d12\n"
    "vtrn.16 d9, d13\n"
    "vld1.16 {q8}, [r10], %[src_stride]\n"
    "vld1.16 {q10}, [r10], %[src_stride]\n"
    "vld1.16 {q12}, [r10], %[src_stride]\n"
    "vld1.16 {q14}, [r10], %[src_stride]\n"
    "vtrn.32 d0, d8\n"
    "vtrn.32 d4, d12\n"
    "vtrn.32 d1, d9\n"
    "vtrn.32 d5, d13\n"
    "vtrn.16 d16, d20\n"
    "vtrn.16 d17, d21\n"
    "vtrn.16 d24, d28\n"
    "vtrn.16 d25, d29\n"
    "vtrn.32 d16, d24\n"
    "vtrn.32 d20, d28\n"
    "vtrn.32 d17, d25\n"
    "vtrn.32 d21, d29\n"
    "vswp d1, d16\n"
    "vswp d5, d20\n"
    "vswp d9, d24\n"
    "vswp d13, d28\n"
    "vst1.16 {q0}, [r12], %[dst_stride]\n"
    "vst1.16 {q2}, [r12], %[dst_stride]\n"
    "vst1.16 {q4}, [r12], %[dst_stride]\n"
    "vst1.16 {q6}, [r12], %[dst_stride]\n"
    "vst1.16 {q8}, [r12], %[dst_stride]\n"
    "vst1.16 {q10}, [r12], %[dst_stride]\n"
    "vst1.16 {q12}, [r12], %[dst_stride]\n"
    "vst1.16 {q14}, [r12], %[dst_stride]\n"
    :
    : [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15");
}
// Transpose a 12x8 fp16 block (12 source rows of 8) into 8 destination rows
// of 12, using A32 NEON vtrn/vswp; each `vst1.16 {qN, dM}` writes 12
// halfwords.  `src_stride` / `dst_stride` are byte distances between rows;
// the col12-major packer passes dst_stride == 24 (12 * sizeof(float16_t)).
inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov r10, %[src_c]\n"
    "mov r12, %[dst_c]\n"
    "vld1.16 {q0}, [r10], %[src_stride]\n"
    "vld1.16 {q2}, [r10], %[src_stride]\n"
    "vld1.16 {q4}, [r10], %[src_stride]\n"
    "vld1.16 {q6}, [r10], %[src_stride]\n"
    "vtrn.16 d0, d4\n"
    "vtrn.16 d1, d5\n"
    "vtrn.16 d8, d12\n"
    "vtrn.16 d9, d13\n"
    "vld1.16 {q8}, [r10], %[src_stride]\n"
    "vld1.16 {q10}, [r10], %[src_stride]\n"
    "vld1.16 {q12}, [r10], %[src_stride]\n"
    "vld1.16 {q14}, [r10], %[src_stride]\n"
    "vtrn.32 d0, d8\n"
    "vtrn.32 d4, d12\n"
    "vtrn.32 d1, d9\n"
    "vtrn.32 d5, d13\n"
    "vtrn.16 d16, d20\n"
    "vtrn.16 d17, d21\n"
    "vtrn.16 d24, d28\n"
    "vtrn.16 d25, d29\n"
    "vld1.16 {q1}, [r10], %[src_stride]\n"
    "vld1.16 {q3}, [r10], %[src_stride]\n"
    "vld1.16 {q5}, [r10], %[src_stride]\n"
    "vld1.16 {q7}, [r10], %[src_stride]\n"
    "vtrn.32 d16, d24\n"
    "vtrn.32 d20, d28\n"
    "vtrn.32 d17, d25\n"
    "vtrn.32 d21, d29\n"
    "vswp d1, d16\n"
    "vswp d5, d20\n"
    "vswp d9, d24\n"
    "vswp d13, d28\n"
    "vtrn.16 d2, d6\n"
    "vtrn.16 d3, d7\n"
    "vtrn.16 d10, d14\n"
    "vtrn.16 d11, d15\n"
    "vtrn.32 d2, d10\n"
    "vtrn.32 d6, d14\n"
    "vtrn.32 d3, d11\n"
    "vtrn.32 d7, d15\n"
    "vst1.16 {q0, d2}, [r12], %[dst_stride]\n"
    "vst1.16 {q2, d6}, [r12], %[dst_stride]\n"
    "vst1.16 {q4, d10}, [r12], %[dst_stride]\n"
    "vst1.16 {q6, d14}, [r12], %[dst_stride]\n"
    "vswp d3, d18\n"
    "vswp d7, d22\n"
    "vswp d11, d26\n"
    "vswp d15, d30\n"
    "vst1.16 {q8, d18}, [r12], %[dst_stride]\n"
    "vst1.16 {q10, d22}, [r12], %[dst_stride]\n"
    "vst1.16 {q12, d26}, [r12], %[dst_stride]\n"
    "vst1.16 {q14, d30}, [r12], %[dst_stride]\n"
    :
    : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15");
}
#endif
#ifdef ENABLE_ARM64
// Transpose a 16x8 fp16 block (16 source rows of 8) into 8 destination rows
// of 16 on aarch64: zip1/zip2 + trn1/trn2 ladders build the transposed
// halves, and each destination row is written as two 8-halfword stores
// (x11 and x11 + 16 bytes).  `src_stride` / `dst_stride` are byte distances
// between consecutive rows.  Extracted from PackNHWCToNCHWFp16 so other
// packers can reuse it.
inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"
    "ld1 {v0.8h}, [x10], %[src_stride]\n"
    "ld1 {v1.8h}, [x10], %[src_stride]\n"
    "ld1 {v2.8h}, [x10], %[src_stride]\n"
    "ld1 {v3.8h}, [x10], %[src_stride]\n"
    "ld1 {v4.8h}, [x10], %[src_stride]\n"
    "ld1 {v5.8h}, [x10], %[src_stride]\n"
    "ld1 {v6.8h}, [x10], %[src_stride]\n"
    "ld1 {v7.8h}, [x10], %[src_stride]\n"
    "zip1 v16.8h, v0.8h, v1.8h\n"
    "zip1 v17.8h, v2.8h, v3.8h\n"
    "zip1 v18.8h, v4.8h, v5.8h\n"
    "zip1 v19.8h, v6.8h, v7.8h\n"
    "ld1 {v8.8h}, [x10], %[src_stride]\n"
    "ld1 {v9.8h}, [x10], %[src_stride]\n"
    "ld1 {v10.8h}, [x10], %[src_stride]\n"
    "ld1 {v11.8h}, [x10], %[src_stride]\n"
    "ld1 {v12.8h}, [x10], %[src_stride]\n"
    "ld1 {v13.8h}, [x10], %[src_stride]\n"
    "ld1 {v14.8h}, [x10], %[src_stride]\n"
    "ld1 {v15.8h}, [x10], %[src_stride]\n"
    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"
    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"
    "zip1 v16.8h, v8.8h, v9.8h\n"
    "zip1 v17.8h, v10.8h, v11.8h\n"
    "zip1 v18.8h, v12.8h, v13.8h\n"
    "zip1 v19.8h, v14.8h, v15.8h\n"
    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"
    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"
    "add x10, x11, #16\n"
    "st1 {v24.8h}, [x11], %[dst_stride]\n"
    "st1 {v28.8h}, [x10], %[dst_stride]\n"
    "st1 {v26.8h}, [x11], %[dst_stride]\n"
    "st1 {v30.8h}, [x10], %[dst_stride]\n"
    "st1 {v25.8h}, [x11], %[dst_stride]\n"
    "st1 {v29.8h}, [x10], %[dst_stride]\n"
    "st1 {v27.8h}, [x11], %[dst_stride]\n"
    "st1 {v31.8h}, [x10], %[dst_stride]\n"
    "zip2 v16.8h, v0.8h, v1.8h\n"
    "zip2 v17.8h, v2.8h, v3.8h\n"
    "zip2 v18.8h, v4.8h, v5.8h\n"
    "zip2 v19.8h, v6.8h, v7.8h\n"
    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"
    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"
    "zip2 v16.8h, v8.8h, v9.8h\n"
    "zip2 v17.8h, v10.8h, v11.8h\n"
    "zip2 v18.8h, v12.8h, v13.8h\n"
    "zip2 v19.8h, v14.8h, v15.8h\n"
    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"
    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"
    "st1 {v24.8h}, [x11], %[dst_stride]\n"
    "st1 {v28.8h}, [x10], %[dst_stride]\n"
    "st1 {v26.8h}, [x11], %[dst_stride]\n"
    "st1 {v30.8h}, [x10], %[dst_stride]\n"
    "st1 {v25.8h}, [x11], %[dst_stride]\n"
    "st1 {v29.8h}, [x10], %[dst_stride]\n"
    "st1 {v27.8h}, [x11], %[dst_stride]\n"
    "st1 {v31.8h}, [x10], %[dst_stride]\n"
    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31");
}
#endif
| @@ -17,11 +17,9 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_PACK_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -72,6 +70,17 @@ void PackNHWCFp16ToC8HWN8Fp16(float16_t *src, float16_t *dst, int batch, int pla | |||
| void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); | |||
| void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); | |||
| #ifdef ENABLE_ARM82_A32 | |||
| void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||
| void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||
| #endif | |||
| #ifdef ENABLE_ARM64 | |||
| void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -16,9 +16,6 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_ | |||
| #define MINDSPORE_NNACL_FP16_PAD_FP16_H_ | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/fp32/pad_fp32.h" | |||
| #ifdef __cplusplus | |||
| @@ -18,10 +18,8 @@ | |||
| #define MINDSPORE_NNACL_FP16_POOLING_FP16_H_ | |||
| #include <math.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/pooling_parameter.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -17,7 +17,7 @@ | |||
| #include "nnacl/fp16/power_fp16.h" | |||
| #include "nnacl/errorcode.h" | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { | |||
| int tmp = (int)(*(float16_t *)exponent); | |||
| int exp = abs(tmp); | |||
| @@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { | |||
| void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, | |||
| float shift) { | |||
| PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; | |||
| #endif | |||
| if (CheckInteger(*exponent)) { | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| PowerSimdFunFp16_ = OptimizedPowerSimdFp16; | |||
| #endif | |||
| PowerScalarFunFp16_ = OptimizedPowerScalarFp16; | |||
| } else { | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| PowerSimdFunFp16_ = StdPowerSimdFp16; | |||
| #endif | |||
| PowerScalarFunFp16_ = StdPowerScalarFp16; | |||
| } | |||
| int i = 0; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| int len_c8 = UP_ROUND(len, C8NUM); | |||
| float16x8_t scale_8 = vmovq_n_f16(scale); | |||
| float16x8_t shift_8 = vmovq_n_f16(shift); | |||
| @@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_ | |||
| float shift) { | |||
| int i = 0; | |||
| PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| int len_c8 = UP_ROUND(len, C8NUM); | |||
| float16x8_t scale_8 = vmovq_n_f16(scale); | |||
| float16x8_t shift_8 = vmovq_n_f16(shift); | |||
| @@ -19,9 +19,10 @@ | |||
| #include <math.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/power_parameter.h" | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); | |||
| #endif | |||
| typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); | |||
| @@ -36,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { | |||
| return powf(x, *(float16_t *)exponent); | |||
| } | |||
| #if defined(ENABLE_NEON) | |||
| #if defined(ENABLE_ARM64) | |||
| static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { | |||
| float16x8_t result; | |||
| result[0] = powf(x[0], *(float16_t *)exponent); | |||
| @@ -18,10 +18,7 @@ | |||
| #define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_ | |||
| #include "nnacl/op_base.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -19,9 +19,6 @@ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/reduce_parameter.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -18,10 +18,9 @@ | |||
| #define MINDSPORE_NNACL_SCALE_FP16_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/scale.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe | |||
| } | |||
| } | |||
| int k = 0; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| int count2 = (channel / C8NUM) * C8NUM; | |||
| for (; k < count2; k += C8NUM) { | |||
| float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); | |||
| @@ -58,9 +58,9 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe | |||
| void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { | |||
| int cur_batch_offset = 0; | |||
| for (int i = 0; i < batch; i++, cur_batch_offset += channel) { | |||
| float16_t sum = 0; | |||
| float16_t sum = 0.0f; | |||
| int j = 0; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| float16x8_t sum8 = vdupq_n_f16(0); | |||
| int count = (channel / C8NUM) * C8NUM; | |||
| for (; j < count; j += C8NUM) { | |||
| @@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) | |||
| sum += src[cur_batch_offset + j]; | |||
| } | |||
| int k = 0; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_ARM64 | |||
| const float16_t div = 1.0f / sum; | |||
| for (; k < count; k += C8NUM) { | |||
| vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); | |||
| @@ -117,7 +117,7 @@ void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *s | |||
| } | |||
| for (int j = 0; j < input_shape[axis]; j++) { | |||
| int axis_offset = inner_offset + j * inner_size; | |||
| output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); | |||
| output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); | |||
| sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; | |||
| } | |||
| } | |||
| @@ -18,10 +18,9 @@ | |||
| #define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/softmax_parameter.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| @@ -18,10 +18,8 @@ | |||
| #define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_ | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #include "nnacl/transpose.h" | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| @@ -16,562 +16,15 @@ | |||
| #include "nnacl/fp16/winograd_transform_fp16.h" | |||
| // for fp16 convolution 3x3 filter/input/output transform F(4,3) | |||
| void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { | |||
| float16x8_t d00 = vld1q_f16(tmp_data); | |||
| float16x8_t d01 = vld1q_f16(tmp_data + 8); | |||
| float16x8_t d02 = vld1q_f16(tmp_data + 2 * 8); | |||
| float16x8_t d03 = vld1q_f16(tmp_data + 3 * 8); | |||
| float16x8_t d04 = vld1q_f16(tmp_data + 4 * 8); | |||
| float16x8_t d05 = vld1q_f16(tmp_data + 5 * 8); | |||
| float16x8_t d10 = vld1q_f16(tmp_data + 6 * 8); | |||
| float16x8_t d11 = vld1q_f16(tmp_data + 7 * 8); | |||
| float16x8_t d12 = vld1q_f16(tmp_data + 8 * 8); | |||
| float16x8_t d13 = vld1q_f16(tmp_data + 9 * 8); | |||
| float16x8_t d14 = vld1q_f16(tmp_data + 10 * 8); | |||
| float16x8_t d15 = vld1q_f16(tmp_data + 11 * 8); | |||
| float16x8_t d20 = vld1q_f16(tmp_data + 12 * 8); | |||
| float16x8_t d21 = vld1q_f16(tmp_data + 13 * 8); | |||
| float16x8_t d22 = vld1q_f16(tmp_data + 14 * 8); | |||
| float16x8_t d23 = vld1q_f16(tmp_data + 15 * 8); | |||
| float16x8_t d24 = vld1q_f16(tmp_data + 16 * 8); | |||
| float16x8_t d25 = vld1q_f16(tmp_data + 17 * 8); | |||
| float16x8_t d30 = vld1q_f16(tmp_data + 18 * 8); | |||
| float16x8_t d31 = vld1q_f16(tmp_data + 19 * 8); | |||
| float16x8_t d32 = vld1q_f16(tmp_data + 20 * 8); | |||
| float16x8_t d33 = vld1q_f16(tmp_data + 21 * 8); | |||
| float16x8_t d34 = vld1q_f16(tmp_data + 22 * 8); | |||
| float16x8_t d35 = vld1q_f16(tmp_data + 23 * 8); | |||
| float16x8_t d40 = vld1q_f16(tmp_data + 24 * 8); | |||
| float16x8_t d41 = vld1q_f16(tmp_data + 25 * 8); | |||
| float16x8_t d42 = vld1q_f16(tmp_data + 26 * 8); | |||
| float16x8_t d43 = vld1q_f16(tmp_data + 27 * 8); | |||
| float16x8_t d44 = vld1q_f16(tmp_data + 28 * 8); | |||
| float16x8_t d45 = vld1q_f16(tmp_data + 29 * 8); | |||
| float16x8_t d50 = vld1q_f16(tmp_data + 30 * 8); | |||
| float16x8_t d51 = vld1q_f16(tmp_data + 31 * 8); | |||
| float16x8_t d52 = vld1q_f16(tmp_data + 32 * 8); | |||
| float16x8_t d53 = vld1q_f16(tmp_data + 33 * 8); | |||
| float16x8_t d54 = vld1q_f16(tmp_data + 34 * 8); | |||
| float16x8_t d55 = vld1q_f16(tmp_data + 35 * 8); | |||
| float16x8_t t00 = vaddq_f16(vsubq_f16(vmulq_n_f16(d00, 4), vmulq_n_f16(d20, 5)), d40); | |||
| float16x8_t t01 = vaddq_f16(vsubq_f16(vmulq_n_f16(d01, 4), vmulq_n_f16(d21, 5)), d41); | |||
| float16x8_t t02 = vaddq_f16(vsubq_f16(vmulq_n_f16(d02, 4), vmulq_n_f16(d22, 5)), d42); | |||
| float16x8_t t03 = vaddq_f16(vsubq_f16(vmulq_n_f16(d03, 4), vmulq_n_f16(d23, 5)), d43); | |||
| float16x8_t t04 = vaddq_f16(vsubq_f16(vmulq_n_f16(d04, 4), vmulq_n_f16(d24, 5)), d44); | |||
| float16x8_t t05 = vaddq_f16(vsubq_f16(vmulq_n_f16(d05, 4), vmulq_n_f16(d25, 5)), d45); | |||
| float16x8_t t10 = vaddq_f16(vaddq_f16(d30, d40), vmulq_n_f16(vaddq_f16(d10, d20), -4)); | |||
| float16x8_t t11 = vaddq_f16(vaddq_f16(d31, d41), vmulq_n_f16(vaddq_f16(d11, d21), -4)); | |||
| float16x8_t t12 = vaddq_f16(vaddq_f16(d32, d42), vmulq_n_f16(vaddq_f16(d12, d22), -4)); | |||
| float16x8_t t13 = vaddq_f16(vaddq_f16(d33, d43), vmulq_n_f16(vaddq_f16(d13, d23), -4)); | |||
| float16x8_t t14 = vaddq_f16(vaddq_f16(d34, d44), vmulq_n_f16(vaddq_f16(d14, d24), -4)); | |||
| float16x8_t t15 = vaddq_f16(vaddq_f16(d35, d45), vmulq_n_f16(vaddq_f16(d15, d25), -4)); | |||
| float16x8_t t20 = vaddq_f16(vsubq_f16(d40, d30), vmulq_n_f16(vsubq_f16(d10, d20), 4)); | |||
| float16x8_t t21 = vaddq_f16(vsubq_f16(d41, d31), vmulq_n_f16(vsubq_f16(d11, d21), 4)); | |||
| float16x8_t t22 = vaddq_f16(vsubq_f16(d42, d32), vmulq_n_f16(vsubq_f16(d12, d22), 4)); | |||
| float16x8_t t23 = vaddq_f16(vsubq_f16(d43, d33), vmulq_n_f16(vsubq_f16(d13, d23), 4)); | |||
| float16x8_t t24 = vaddq_f16(vsubq_f16(d44, d34), vmulq_n_f16(vsubq_f16(d14, d24), 4)); | |||
| float16x8_t t25 = vaddq_f16(vsubq_f16(d45, d35), vmulq_n_f16(vsubq_f16(d15, d25), 4)); | |||
| float16x8_t t30 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d30, d10), 2)); | |||
| float16x8_t t31 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d31, d11), 2)); | |||
| float16x8_t t32 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d32, d12), 2)); | |||
| float16x8_t t33 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d33, d13), 2)); | |||
| float16x8_t t34 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d34, d14), 2)); | |||
| float16x8_t t35 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d35, d15), 2)); | |||
| float16x8_t t40 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d10, d30), 2)); | |||
| float16x8_t t41 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d11, d31), 2)); | |||
| float16x8_t t42 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d12, d32), 2)); | |||
| float16x8_t t43 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d13, d33), 2)); | |||
| float16x8_t t44 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d14, d34), 2)); | |||
| float16x8_t t45 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d15, d35), 2)); | |||
| float16x8_t t50 = vaddq_f16(vsubq_f16(vmulq_n_f16(d10, 4), vmulq_n_f16(d30, 5)), d50); | |||
| float16x8_t t51 = vaddq_f16(vsubq_f16(vmulq_n_f16(d11, 4), vmulq_n_f16(d31, 5)), d51); | |||
| float16x8_t t52 = vaddq_f16(vsubq_f16(vmulq_n_f16(d12, 4), vmulq_n_f16(d32, 5)), d52); | |||
| float16x8_t t53 = vaddq_f16(vsubq_f16(vmulq_n_f16(d13, 4), vmulq_n_f16(d33, 5)), d53); | |||
| float16x8_t t54 = vaddq_f16(vsubq_f16(vmulq_n_f16(d14, 4), vmulq_n_f16(d34, 5)), d54); | |||
| float16x8_t t55 = vaddq_f16(vsubq_f16(vmulq_n_f16(d15, 4), vmulq_n_f16(d35, 5)), d55); | |||
| float16x8_t m00 = vaddq_f16(vsubq_f16(vmulq_n_f16(t00, 4), vmulq_n_f16(t02, 5)), t04); | |||
| float16x8_t m01 = vaddq_f16(vaddq_f16(t03, t04), vmulq_n_f16(vaddq_f16(t01, t02), -4)); | |||
| float16x8_t m02 = vaddq_f16(vsubq_f16(t04, t03), vmulq_n_f16(vsubq_f16(t01, t02), 4)); | |||
| float16x8_t m03 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t03, t01), 2)); | |||
| float16x8_t m04 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t01, t03), 2)); | |||
| float16x8_t m05 = vaddq_f16(vsubq_f16(vmulq_n_f16(t01, 4), vmulq_n_f16(t03, 5)), t05); | |||
| float16x8_t m10 = vaddq_f16(vsubq_f16(vmulq_n_f16(t10, 4), vmulq_n_f16(t12, 5)), t14); | |||
| float16x8_t m11 = vaddq_f16(vaddq_f16(t13, t14), vmulq_n_f16(vaddq_f16(t11, t12), -4)); | |||
| float16x8_t m12 = vaddq_f16(vsubq_f16(t14, t13), vmulq_n_f16(vsubq_f16(t11, t12), 4)); | |||
| float16x8_t m13 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t13, t11), 2)); | |||
| float16x8_t m14 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t11, t13), 2)); | |||
| float16x8_t m15 = vaddq_f16(vsubq_f16(vmulq_n_f16(t11, 4), vmulq_n_f16(t13, 5)), t15); | |||
| float16x8_t m20 = vaddq_f16(vsubq_f16(vmulq_n_f16(t20, 4), vmulq_n_f16(t22, 5)), t24); | |||
| float16x8_t m21 = vaddq_f16(vaddq_f16(t23, t24), vmulq_n_f16(vaddq_f16(t21, t22), -4)); | |||
| float16x8_t m22 = vaddq_f16(vsubq_f16(t24, t23), vmulq_n_f16(vsubq_f16(t21, t22), 4)); | |||
| float16x8_t m23 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t23, t21), 2)); | |||
| float16x8_t m24 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t21, t23), 2)); | |||
| float16x8_t m25 = vaddq_f16(vsubq_f16(vmulq_n_f16(t21, 4), vmulq_n_f16(t23, 5)), t25); | |||
| float16x8_t m30 = vaddq_f16(vsubq_f16(vmulq_n_f16(t30, 4), vmulq_n_f16(t32, 5)), t34); | |||
| float16x8_t m31 = vaddq_f16(vaddq_f16(t33, t34), vmulq_n_f16(vaddq_f16(t31, t32), -4)); | |||
| float16x8_t m32 = vaddq_f16(vsubq_f16(t34, t33), vmulq_n_f16(vsubq_f16(t31, t32), 4)); | |||
| float16x8_t m33 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t33, t31), 2)); | |||
| float16x8_t m34 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t31, t33), 2)); | |||
| float16x8_t m35 = vaddq_f16(vsubq_f16(vmulq_n_f16(t31, 4), vmulq_n_f16(t33, 5)), t35); | |||
| float16x8_t m40 = vaddq_f16(vsubq_f16(vmulq_n_f16(t40, 4), vmulq_n_f16(t42, 5)), t44); | |||
| float16x8_t m41 = vaddq_f16(vaddq_f16(t43, t44), vmulq_n_f16(vaddq_f16(t41, t42), -4)); | |||
| float16x8_t m42 = vaddq_f16(vsubq_f16(t44, t43), vmulq_n_f16(vsubq_f16(t41, t42), 4)); | |||
| float16x8_t m43 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t43, t41), 2)); | |||
| float16x8_t m44 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t41, t43), 2)); | |||
| float16x8_t m45 = vaddq_f16(vsubq_f16(vmulq_n_f16(t41, 4), vmulq_n_f16(t43, 5)), t45); | |||
| float16x8_t m50 = vaddq_f16(vsubq_f16(vmulq_n_f16(t50, 4), vmulq_n_f16(t52, 5)), t54); | |||
| float16x8_t m51 = vaddq_f16(vaddq_f16(t53, t54), vmulq_n_f16(vaddq_f16(t51, t52), -4)); | |||
| float16x8_t m52 = vaddq_f16(vsubq_f16(t54, t53), vmulq_n_f16(vsubq_f16(t51, t52), 4)); | |||
| float16x8_t m53 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t53, t51), 2)); | |||
| float16x8_t m54 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t51, t53), 2)); | |||
| float16x8_t m55 = vaddq_f16(vsubq_f16(vmulq_n_f16(t51, 4), vmulq_n_f16(t53, 5)), t55); | |||
| vst1_f16(trans_input_data, vget_low_f16(m00)); | |||
| vst1_f16(trans_input_data + 64, vget_high_f16(m00)); | |||
| vst1_f16(trans_input_data + step, vget_low_f16(m01)); | |||
| vst1_f16(trans_input_data + step + 64, vget_high_f16(m01)); | |||
| vst1_f16(trans_input_data + 2 * step, vget_low_f16(m02)); | |||
| vst1_f16(trans_input_data + 2 * step + 64, vget_high_f16(m02)); | |||
| vst1_f16(trans_input_data + 3 * step, vget_low_f16(m03)); | |||
| vst1_f16(trans_input_data + 3 * step + 64, vget_high_f16(m03)); | |||
| vst1_f16(trans_input_data + 4 * step, vget_low_f16(m04)); | |||
| vst1_f16(trans_input_data + 4 * step + 64, vget_high_f16(m04)); | |||
| vst1_f16(trans_input_data + 5 * step, vget_low_f16(m05)); | |||
| vst1_f16(trans_input_data + 5 * step + 64, vget_high_f16(m05)); | |||
| vst1_f16(trans_input_data + 6 * step, vget_low_f16(m10)); | |||
| vst1_f16(trans_input_data + 6 * step + 64, vget_high_f16(m10)); | |||
| vst1_f16(trans_input_data + 7 * step, vget_low_f16(m11)); | |||
| vst1_f16(trans_input_data + 7 * step + 64, vget_high_f16(m11)); | |||
| vst1_f16(trans_input_data + 8 * step, vget_low_f16(m12)); | |||
| vst1_f16(trans_input_data + 8 * step + 64, vget_high_f16(m12)); | |||
| vst1_f16(trans_input_data + 9 * step, vget_low_f16(m13)); | |||
| vst1_f16(trans_input_data + 9 * step + 64, vget_high_f16(m13)); | |||
| vst1_f16(trans_input_data + 10 * step, vget_low_f16(m14)); | |||
| vst1_f16(trans_input_data + 10 * step + 64, vget_high_f16(m14)); | |||
| vst1_f16(trans_input_data + 11 * step, vget_low_f16(m15)); | |||
| vst1_f16(trans_input_data + 11 * step + 64, vget_high_f16(m15)); | |||
| vst1_f16(trans_input_data + 12 * step, vget_low_f16(m20)); | |||
| vst1_f16(trans_input_data + 12 * step + 64, vget_high_f16(m20)); | |||
| vst1_f16(trans_input_data + 13 * step, vget_low_f16(m21)); | |||
| vst1_f16(trans_input_data + 13 * step + 64, vget_high_f16(m21)); | |||
| vst1_f16(trans_input_data + 14 * step, vget_low_f16(m22)); | |||
| vst1_f16(trans_input_data + 14 * step + 64, vget_high_f16(m22)); | |||
| vst1_f16(trans_input_data + 15 * step, vget_low_f16(m23)); | |||
| vst1_f16(trans_input_data + 15 * step + 64, vget_high_f16(m23)); | |||
| vst1_f16(trans_input_data + 16 * step, vget_low_f16(m24)); | |||
| vst1_f16(trans_input_data + 16 * step + 64, vget_high_f16(m24)); | |||
| vst1_f16(trans_input_data + 17 * step, vget_low_f16(m25)); | |||
| vst1_f16(trans_input_data + 17 * step + 64, vget_high_f16(m25)); | |||
| vst1_f16(trans_input_data + 18 * step, vget_low_f16(m30)); | |||
| vst1_f16(trans_input_data + 18 * step + 64, vget_high_f16(m30)); | |||
| vst1_f16(trans_input_data + 19 * step, vget_low_f16(m31)); | |||
| vst1_f16(trans_input_data + 19 * step + 64, vget_high_f16(m31)); | |||
| vst1_f16(trans_input_data + 20 * step, vget_low_f16(m32)); | |||
| vst1_f16(trans_input_data + 20 * step + 64, vget_high_f16(m32)); | |||
| vst1_f16(trans_input_data + 21 * step, vget_low_f16(m33)); | |||
| vst1_f16(trans_input_data + 21 * step + 64, vget_high_f16(m33)); | |||
| vst1_f16(trans_input_data + 22 * step, vget_low_f16(m34)); | |||
| vst1_f16(trans_input_data + 22 * step + 64, vget_high_f16(m34)); | |||
| vst1_f16(trans_input_data + 23 * step, vget_low_f16(m35)); | |||
| vst1_f16(trans_input_data + 23 * step + 64, vget_high_f16(m35)); | |||
| vst1_f16(trans_input_data + 24 * step, vget_low_f16(m40)); | |||
| vst1_f16(trans_input_data + 24 * step + 64, vget_high_f16(m40)); | |||
| vst1_f16(trans_input_data + 25 * step, vget_low_f16(m41)); | |||
| vst1_f16(trans_input_data + 25 * step + 64, vget_high_f16(m41)); | |||
| vst1_f16(trans_input_data + 26 * step, vget_low_f16(m42)); | |||
| vst1_f16(trans_input_data + 26 * step + 64, vget_high_f16(m42)); | |||
| vst1_f16(trans_input_data + 27 * step, vget_low_f16(m43)); | |||
| vst1_f16(trans_input_data + 27 * step + 64, vget_high_f16(m43)); | |||
| vst1_f16(trans_input_data + 28 * step, vget_low_f16(m44)); | |||
| vst1_f16(trans_input_data + 28 * step + 64, vget_high_f16(m44)); | |||
| vst1_f16(trans_input_data + 29 * step, vget_low_f16(m45)); | |||
| vst1_f16(trans_input_data + 29 * step + 64, vget_high_f16(m45)); | |||
| vst1_f16(trans_input_data + 30 * step, vget_low_f16(m50)); | |||
| vst1_f16(trans_input_data + 30 * step + 64, vget_high_f16(m50)); | |||
| vst1_f16(trans_input_data + 31 * step, vget_low_f16(m51)); | |||
| vst1_f16(trans_input_data + 31 * step + 64, vget_high_f16(m51)); | |||
| vst1_f16(trans_input_data + 32 * step, vget_low_f16(m52)); | |||
| vst1_f16(trans_input_data + 32 * step + 64, vget_high_f16(m52)); | |||
| vst1_f16(trans_input_data + 33 * step, vget_low_f16(m53)); | |||
| vst1_f16(trans_input_data + 33 * step + 64, vget_high_f16(m53)); | |||
| vst1_f16(trans_input_data + 34 * step, vget_low_f16(m54)); | |||
| vst1_f16(trans_input_data + 34 * step + 64, vget_high_f16(m54)); | |||
| vst1_f16(trans_input_data + 35 * step, vget_low_f16(m55)); | |||
| vst1_f16(trans_input_data + 35 * step + 64, vget_high_f16(m55)); | |||
| } | |||
| void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, | |||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| // input data format : nhwc | |||
| const int output_unit = 4; | |||
| int input_channel = conv_param->input_channel_; | |||
| int input_width = conv_param->input_w_; | |||
| int input_height = conv_param->input_h_; | |||
| int pad_w = conv_param->pad_l_; | |||
| int pad_h = conv_param->pad_u_; | |||
| int ic8 = UP_DIV(input_channel, C8NUM); | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { | |||
| int x_id = start_index + cal_id; | |||
| int origin_x = (x_id % out_w_block) * output_unit - pad_w; | |||
| int origin_y = (x_id / out_w_block) * output_unit - pad_h; | |||
| int real_x_start = origin_x > 0 ? 0 : -origin_x; | |||
| int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x); | |||
| int real_y_start = origin_y > 0 ? 0 : -origin_y; | |||
| int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); | |||
| int src_plane_offset = ic8 * C8NUM * (origin_y * input_width + origin_x); | |||
| int dst_plane_offset = cal_id * C4NUM; | |||
| for (int ic = 0; ic < ic8; ic++) { | |||
| // clear tmp buffer | |||
| memset(tmp_data, 0, 6 * 6 * C8NUM * sizeof(float16_t)); | |||
| // get real input block with padding | |||
| int src_ic4_offset = src_plane_offset + ic * C8NUM; | |||
| for (int interval = real_y_start; interval < real_y_end; interval++) { | |||
| int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic8 * C8NUM; | |||
| int dst_y_offset = interval * 6 * C8NUM + real_x_start * C8NUM; | |||
| for (int j = 0; j < (real_x_end - real_x_start); j++) { | |||
| int src_x_offset = src_y_offset + j * ic8 * C8NUM; | |||
| int dst_x_offset = dst_y_offset + j * C8NUM; | |||
| float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; | |||
| float16_t *dst_addr = tmp_data + dst_x_offset; | |||
| vst1q_f16(dst_addr, vld1q_f16(src_addr)); | |||
| } | |||
| } | |||
| // input transform | |||
| int dst_ic4_offset = dst_plane_offset + ic * 16 * C8NUM; | |||
| size_t dst_step = ic8 * C8NUM * 16; | |||
| float16_t *trans_input_ptr = trans_input + dst_ic4_offset; | |||
| Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); | |||
| } | |||
| } | |||
| } | |||
| void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, | |||
| int kernel_plane) { | |||
| int dst_step = iC4 * C4NUM * 8; | |||
| for (int o = 0; o < output_channel; o++) { | |||
| int oc8_block_num = o / C8NUM; | |||
| int oc8_block_rem = o % C8NUM; | |||
| int src_oc_offset = o * iC4 * C4NUM * kernel_plane; | |||
| int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; | |||
| for (int i = 0; i < iC4; i++) { | |||
| const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; | |||
| float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; | |||
| float16x4_t g00 = vld1_f16(src_ic4_ptr); | |||
| float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); | |||
| float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); | |||
| float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); | |||
| float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); | |||
| float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); | |||
| float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4); | |||
| float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); | |||
| float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); | |||
| float16x4_t dst00 = vmul_n_f16(g00, 0.25); | |||
| float16x4_t dst01 = vmul_n_f16(g01, 0.25); | |||
| float16x4_t dst02 = vmul_n_f16(g02, 0.25); | |||
| float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); | |||
| float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); | |||
| float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); | |||
| float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); | |||
| float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); | |||
| float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); | |||
| float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); | |||
| float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); | |||
| float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667))); | |||
| float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), | |||
| vmul_n_f16(g10, 0.08333333333333)); | |||
| float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), | |||
| vmul_n_f16(g11, 0.08333333333333)); | |||
| float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), | |||
| vmul_n_f16(g12, 0.08333333333333)); | |||
| float16x4_t dst50 = g20; | |||
| float16x4_t dst51 = g21; | |||
| float16x4_t dst52 = g22; | |||
| float16x4_t m00 = vmul_n_f16(dst00, 0.25); | |||
| float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); | |||
| float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); | |||
| float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); | |||
| float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), | |||
| vmul_n_f16(dst01, 0.08333333333333)); | |||
| float16x4_t m05 = dst02; | |||
| float16x4_t m10 = vmul_n_f16(dst10, 0.25); | |||
| float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); | |||
| float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); | |||
| float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); | |||
| float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), | |||
| vmul_n_f16(dst11, 0.08333333333333)); | |||
| float16x4_t m15 = dst12; | |||
| float16x4_t m20 = vmul_n_f16(dst20, 0.25); | |||
| float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); | |||
| float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); | |||
| float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); | |||
| float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), | |||
| vmul_n_f16(dst21, 0.08333333333333)); | |||
| float16x4_t m25 = dst22; | |||
| float16x4_t m30 = vmul_n_f16(dst30, 0.25); | |||
| float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667); | |||
| float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); | |||
| float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); | |||
| float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), | |||
| vmul_n_f16(dst31, 0.08333333333333)); | |||
| float16x4_t m35 = dst32; | |||
| float16x4_t m40 = vmul_n_f16(dst40, 0.25); | |||
| float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); | |||
| float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); | |||
| float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); | |||
| float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), | |||
| vmul_n_f16(dst41, 0.08333333333333)); | |||
| float16x4_t m45 = dst42; | |||
| float16x4_t m50 = vmul_n_f16(dst50, 0.25); | |||
| float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); | |||
| float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); | |||
| float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), | |||
| vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); | |||
| float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), | |||
| vmul_n_f16(dst51, 0.08333333333333)); | |||
| float16x4_t m55 = dst52; | |||
| for (int j = 0; j < 4; j++) { | |||
| dst_ic4_ptr[j * 8] = m00[j]; | |||
| dst_ic4_ptr[j * 8 + dst_step] = m01[j]; | |||
| dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; | |||
| dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; | |||
| dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; | |||
| dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; | |||
| dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; | |||
| dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; | |||
| dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; | |||
| dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; | |||
| dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; | |||
| dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; | |||
| dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; | |||
| dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; | |||
| dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; | |||
| dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; | |||
| dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; | |||
| dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; | |||
| dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; | |||
| dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; | |||
| dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; | |||
| dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; | |||
| dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; | |||
| dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; | |||
| dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; | |||
| dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; | |||
| dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; | |||
| dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; | |||
| dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; | |||
| dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; | |||
| dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; | |||
| dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; | |||
| dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; | |||
| dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; | |||
| dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; | |||
| dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, | |||
| int output_w) { | |||
| float16x8_t s00 = vld1q_f16(gemm_out); | |||
| float16x8_t s01 = vld1q_f16(gemm_out + 8); | |||
| float16x8_t s02 = vld1q_f16(gemm_out + 16); | |||
| float16x8_t s03 = vld1q_f16(gemm_out + 24); | |||
| float16x8_t s04 = vld1q_f16(gemm_out + 32); | |||
| float16x8_t s05 = vld1q_f16(gemm_out + 40); | |||
| float16x8_t s10 = vld1q_f16(gemm_out + 48); | |||
| float16x8_t s11 = vld1q_f16(gemm_out + 56); | |||
| float16x8_t s12 = vld1q_f16(gemm_out + 64); | |||
| float16x8_t s13 = vld1q_f16(gemm_out + 72); | |||
| float16x8_t s14 = vld1q_f16(gemm_out + 80); | |||
| float16x8_t s15 = vld1q_f16(gemm_out + 88); | |||
| float16x8_t s20 = vld1q_f16(gemm_out + 96); | |||
| float16x8_t s21 = vld1q_f16(gemm_out + 104); | |||
| float16x8_t s22 = vld1q_f16(gemm_out + 112); | |||
| float16x8_t s23 = vld1q_f16(gemm_out + 120); | |||
| float16x8_t s24 = vld1q_f16(gemm_out + 128); | |||
| float16x8_t s25 = vld1q_f16(gemm_out + 136); | |||
| float16x8_t s30 = vld1q_f16(gemm_out + 144); | |||
| float16x8_t s31 = vld1q_f16(gemm_out + 152); | |||
| float16x8_t s32 = vld1q_f16(gemm_out + 160); | |||
| float16x8_t s33 = vld1q_f16(gemm_out + 168); | |||
| float16x8_t s34 = vld1q_f16(gemm_out + 176); | |||
| float16x8_t s35 = vld1q_f16(gemm_out + 184); | |||
| float16x8_t s40 = vld1q_f16(gemm_out + 192); | |||
| float16x8_t s41 = vld1q_f16(gemm_out + 200); | |||
| float16x8_t s42 = vld1q_f16(gemm_out + 208); | |||
| float16x8_t s43 = vld1q_f16(gemm_out + 216); | |||
| float16x8_t s44 = vld1q_f16(gemm_out + 224); | |||
| float16x8_t s45 = vld1q_f16(gemm_out + 232); | |||
| float16x8_t s50 = vld1q_f16(gemm_out + 240); | |||
| float16x8_t s51 = vld1q_f16(gemm_out + 248); | |||
| float16x8_t s52 = vld1q_f16(gemm_out + 256); | |||
| float16x8_t s53 = vld1q_f16(gemm_out + 264); | |||
| float16x8_t s54 = vld1q_f16(gemm_out + 272); | |||
| float16x8_t s55 = vld1q_f16(gemm_out + 280); | |||
| float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); | |||
| float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); | |||
| float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); | |||
| float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); | |||
| float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); | |||
| float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); | |||
| float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); | |||
| float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); | |||
| float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); | |||
| float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); | |||
| float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); | |||
| float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); | |||
| float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); | |||
| float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); | |||
| float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); | |||
| float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); | |||
| float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); | |||
| float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); | |||
| float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); | |||
| float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); | |||
| float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); | |||
| float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53); | |||
| float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); | |||
| float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); | |||
| float16x8_t bias_ptr = vld1q_f16(bias_data); | |||
| float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr); | |||
| float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr); | |||
| float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr); | |||
| float16x8_t d03 = | |||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr); | |||
| float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr); | |||
| float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr); | |||
| float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr); | |||
| float16x8_t d13 = | |||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr); | |||
| float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr); | |||
| float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr); | |||
| float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr); | |||
| float16x8_t d23 = | |||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr); | |||
| float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr); | |||
| float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr); | |||
| float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr); | |||
| float16x8_t d33 = | |||
| vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr); | |||
| vst1q_f16(output_data, d00); | |||
| vst1q_f16(output_data + 8, d01); | |||
| vst1q_f16(output_data + 16, d02); | |||
| vst1q_f16(output_data + 24, d03); | |||
| vst1q_f16(output_data + output_w * 8, d10); | |||
| vst1q_f16(output_data + output_w * 8 + 8, d11); | |||
| vst1q_f16(output_data + output_w * 8 + 16, d12); | |||
| vst1q_f16(output_data + output_w * 8 + 24, d13); | |||
| vst1q_f16(output_data + 2 * output_w * 8, d20); | |||
| vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); | |||
| vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); | |||
| vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); | |||
| vst1q_f16(output_data + 3 * output_w * 8, d30); | |||
| vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); | |||
| vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); | |||
| vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); | |||
| } | |||
| void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, | |||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { | |||
| int output_channel = conv_param->output_channel_; | |||
| int output_h = conv_param->output_h_; | |||
| int out_h_block = UP_DIV(output_h, C4NUM); | |||
| int oc8 = UP_DIV(output_channel, C8NUM); | |||
| if (out_w_block == 0) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < real_cal_num; i++) { | |||
| int out_w_index = (start_index + i) % out_w_block; | |||
| int out_h_index = (start_index + i) / out_w_block; | |||
| int src_tile_offset = i * oc8 * C8NUM * 36; | |||
| int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * out_w_block * C4NUM); | |||
| for (int j = 0; j < oc8; j++) { | |||
| int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; | |||
| int dst_oc8_offset = dst_tile_offset + j * C8NUM * out_h_block * out_w_block * C4NUM * C4NUM; | |||
| const float16_t *src_ptr = gemm_out + src_oc8_offset; | |||
| const float16_t *bias_ptr = bias_data + j * C8NUM; | |||
| float16_t *dst_ptr = out_data + dst_oc8_offset; | |||
| // output transform | |||
| Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM); | |||
| } | |||
| } | |||
| } | |||
| // fp16 common winograd | |||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | |||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | |||
| InputTransFp16Func func) { | |||
| #ifdef ENABLE_ARM64 | |||
| const int tile_num = 16; | |||
| #else | |||
| const int tile_num = 12; | |||
| #endif | |||
| int input_unit = conv_param->input_unit_; | |||
| int output_unit = conv_param->output_unit_; | |||
| int in_channel = conv_param->input_channel_; | |||
| @@ -23,25 +23,11 @@ | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #include "nnacl/fp16/conv_fp16.h" | |||
| #include "nnacl/fp16/matrix_fp16.h" | |||
| #include "nnacl/fp16/pack_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| // for fp16 convolution 3x3 filter/input/output transform | |||
| void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); | |||
| void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, | |||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | |||
| void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, | |||
| int kernel_plane); | |||
| void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); | |||
| void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, | |||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | |||
| // fp16 common winograd | |||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | |||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | |||
| @@ -17,9 +17,9 @@ | |||
| #ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | |||
| #define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | |||
| #include <arm_neon.h> | |||
| #include "nnacl/conv_parameter.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||
| #define MAX_LEN 256 | |||
| @@ -17,9 +17,11 @@ | |||
| #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | |||
| #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | |||
| #include <math.h> | |||
| #ifdef ENABLE_ARM | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #if defined(ENABLE_SSE) || defined(ENABLE_AVX) | |||
| #include <x86intrin.h> | |||
| #endif | |||
| @@ -46,7 +48,7 @@ | |||
| #ifdef ENABLE_ARM64 | |||
| #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) | |||
| #else | |||
| inline static float32x4_t vrecp(float32x4_t v) { | |||
| static inline float32x4_t vrecp(float32x4_t v) { | |||
| float32x4_t r = vrecpeq_f32(v); | |||
| r = vmulq_f32(vrecpsq_f32(v, r), r); | |||
| r = vmulq_f32(vrecpsq_f32(v, r), r); | |||
| @@ -205,25 +207,4 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) { | |||
| return dst; | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { | |||
| float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src)); | |||
| float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src)); | |||
| return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high))); | |||
| } | |||
| static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { | |||
| float16x8_t dst; | |||
| dst[0] = erff(src[0]); | |||
| dst[1] = erff(src[1]); | |||
| dst[2] = erff(src[2]); | |||
| dst[3] = erff(src[3]); | |||
| dst[4] = erff(src[4]); | |||
| dst[5] = erff(src[5]); | |||
| dst[6] = erff(src[6]); | |||
| dst[7] = erff(src[7]); | |||
| return dst; | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | |||
| @@ -0,0 +1,99 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||
| #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||
| #include <math.h> | |||
| #include "nnacl/intrinsics/ms_simd_instructions.h" | |||
| #if defined(ENABLE_ARM82_A32) | |||
| static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) { | |||
| float16x8_t dst; | |||
| asm volatile( | |||
| "vrecpe.f16 q14, %3\n" | |||
| "vrecps.f16 q15, %3, q14\n" | |||
| "vmul.f16 q14, q15, q14\n" | |||
| "vrecps.f16 q15, %3, q14\n" | |||
| "vmul.f16 q14, q15, q14\n" | |||
| "vmul.f16 %0, %2, q14\n" | |||
| : "=w"(dst) | |||
| : "0"(dst), "w"(in1), "w"(in2) | |||
| : "q14", "q15"); | |||
| return dst; | |||
| } | |||
| static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) { | |||
| float16x4_t dst; | |||
| asm volatile( | |||
| "vrecpe.f16 d14, %3\n" | |||
| "vrecps.f16 d16, %3, d14\n" | |||
| "vmul.f16 d14, d16, d14\n" | |||
| "vrecps.f16 d16, %3, d14\n" | |||
| "vmul.f16 d14, d16, d14\n" | |||
| "vmul.f16 %0, %2, d14\n" | |||
| : "=w"(dst) | |||
| : "0"(dst), "w"(in1), "w"(in2) | |||
| : "d14", "d16"); | |||
| return dst; | |||
| } | |||
| static inline float vaddvq_f32(float32x4_t in) { // is not support in arm82 aarch32 | |||
| return in[0] + in[1] + in[2] + in[3]; | |||
| } | |||
| static inline float32x4_t cvt_f32_f16(float16x4_t in) { | |||
| float32x4_t dst; | |||
| asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); | |||
| return dst; | |||
| } | |||
| static inline float16x4_t cvt_f16_f32(float32x4_t in) { | |||
| float16x4_t dst; | |||
| asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); | |||
| return dst; | |||
| } | |||
| #define MS_CVT_F32_F16(src) cvt_f32_f16(src) | |||
| #define MS_CVT_F16_F32(src) cvt_f16_f32(src) | |||
| #define MS_DIV_F16(src1, src2) div_f16(src1, src2) | |||
| #define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2) | |||
| #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) | |||
| #else | |||
| #define MS_CVT_F32_F16(src) vcvt_f32_f16(src) | |||
| #define MS_CVT_F16_F32(src) vcvt_f16_f32(src) | |||
| #define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) | |||
| #define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) | |||
| #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) | |||
| #endif | |||
| static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { | |||
| float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); | |||
| float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); | |||
| return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); | |||
| } | |||
| static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { | |||
| float16x8_t dst; | |||
| dst[0] = erff(src[0]); | |||
| dst[1] = erff(src[1]); | |||
| dst[2] = erff(src[2]); | |||
| dst[3] = erff(src[3]); | |||
| dst[4] = erff(src[4]); | |||
| dst[5] = erff(src[5]); | |||
| dst[6] = erff(src[6]); | |||
| dst[7] = erff(src[7]); | |||
| return dst; | |||
| } | |||
| #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||
| @@ -4,11 +4,15 @@ set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) | |||
| include_directories(NNACL_DIR) | |||
| ########################### optimized files ########################### | |||
| file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) | |||
| file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c) | |||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) | |||
| if(PLATFORM_ARM32) | |||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S) | |||
| else() | |||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) | |||
| file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) | |||
| set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) | |||
| endif() | |||
| set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) | |||
| set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) | |||
| set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) | |||
| @@ -17,7 +21,6 @@ if(APPLE) | |||
| set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") | |||
| endif() | |||
| ########################### share library build ######################## | |||
| list(APPEND SDOT_FILES ${SDOT_SRC}) | |||
| list(APPEND FP16_FILES ${FP16_C_SRC}) | |||
| list(APPEND FP16_FILES ${FP16_NEON_SRC}) | |||
| @@ -27,13 +30,20 @@ if(SUPPORT_TRAIN) | |||
| endif() | |||
| string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||
| add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) | |||
| add_dependencies(nnacl_optimize_mid fbs_src) | |||
| if(NOT PLATFORM_ARM32) | |||
| list(APPEND SDOT_FILES ${SDOT_SRC}) | |||
| add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) | |||
| add_dependencies(nnacl_optimize_mid fbs_src) | |||
| endif() | |||
| if(ENABLE_FP16) | |||
| add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) | |||
| if(PLATFORM_ARM32) | |||
| target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp) | |||
| else() | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||
| endif() | |||
| add_dependencies(nnacl_fp16_mid fbs_src) | |||
| endif() | |||
| @@ -5,6 +5,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L | |||
| message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") | |||
| endif() | |||
| if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0) | |||
| set(ENABLE_FP16 "off") | |||
| message(WARNING "If you want to build fp16 in arm82_a32, \ | |||
| your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!") | |||
| endif() | |||
| option(MS_VERSION_MAJOR "major version" 0) | |||
| option(MS_VERSION_MINOR "minor version" 7) | |||
| option(MS_VERSION_REVISION "revision version" 0) | |||
| @@ -15,6 +21,7 @@ option(PLATFORM_ARM64 "if build device for arm64" off) | |||
| option(PLATFORM_ARM32 "if build device for arm32" off) | |||
| option(ENABLE_CONVERTER "if build converter" on) | |||
| option(ENABLE_FP16 "if build fp16 ops" off) | |||
| option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off) | |||
| option(ENABLE_TOOLS "if build tools" on) | |||
| option(BUILD_TESTCASES "if build testcase" on) | |||
| option(SUPPORT_GPU "if support gpu" off) | |||
| @@ -177,6 +184,9 @@ if(ENABLE_NEON) | |||
| endif() | |||
| if(ENABLE_FP16) | |||
| add_compile_definitions(ENABLE_FP16) | |||
| if(PLATFORM_ARM32) | |||
| add_compile_definitions(ENABLE_ARM82_A32) | |||
| endif() | |||
| endif() | |||
| if(SUPPORT_GPU STREQUAL opencl) | |||
| add_definitions(-DGPU_OPENCL) | |||
| @@ -3,6 +3,9 @@ if(ENABLE_V0) | |||
| add_definitions(-DENABLE_V0) | |||
| endif() | |||
| include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu) | |||
| set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) | |||
| include_directories(${LITE_DIR}/nnacl/) | |||
| include_directories(${LITE_DIR}/nnacl/optimize) | |||
| if(PLATFORM_ARM32 OR PLATFORM_ARM64) | |||
| #for performance | |||
| @@ -210,9 +213,11 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||
| endif() | |||
| ########################## build optimize and float16 library ################################# | |||
| if(PLATFORM_ARM64) | |||
| target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| if(PLATFORM_ARM) | |||
| if(PLATFORM_ARM64) | |||
| target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| endif() | |||
| if(ENABLE_FP16) | |||
| target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid) | |||
| target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid) | |||
| @@ -248,8 +253,10 @@ if(DEFINED ARCHS) | |||
| target_link_libraries(mindspore_lite mindrt_mid) | |||
| endif() | |||
| if(PLATFORM_ARM64) | |||
| target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| if(PLATFORM_ARM) | |||
| if(PLATFORM_ARM64) | |||
| target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||
| endif() | |||
| if(ENABLE_FP16) | |||
| target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid) | |||
| endif() | |||
| @@ -155,7 +155,11 @@ bool IsSupportSDot() { | |||
| bool IsSupportFloat16() { | |||
| bool status = false; | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_ARM32 | |||
| status = true; | |||
| #endif | |||
| #if defined(ENABLE_ARM64) | |||
| #if defined(__ANDROID__) | |||
| int hwcap_type = 16; | |||
| uint32_t hwcap = getHwCap(hwcap_type); | |||
| @@ -44,7 +44,7 @@ uint64_t GetTimeUs(); | |||
| bool IsSupportSDot(); | |||
| bool IsSupportFloat16(); | |||
| #if defined(__arm__) || defined(__aarch64__) | |||
| #if defined(__arm__) | |||
| uint32_t getHwCap(int hwcap_type); | |||
| #endif | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/common/version_manager.h" | |||
| #include "nnacl/pooling_parameter.h" | |||
| #include "src/ios_reg_kernels.h" | |||
| #ifdef ENABLE_ARM64 | |||
| #if defined(ENABLE_FP16) && defined(ENABLE_ARM) | |||
| #if defined(__ANDROID__) | |||
| #include <asm/hwcap.h> | |||
| #endif | |||
| @@ -55,6 +55,8 @@ int KernelRegistry::Init() { | |||
| } else { | |||
| MS_LOG(INFO) << "The current device NOT supports Sdot."; | |||
| } | |||
| #endif | |||
| #ifdef ENABLE_FP16 | |||
| if (mindspore::lite::IsSupportFloat16()) { | |||
| MS_LOG(INFO) << "The current device supports float16."; | |||
| } else { | |||
| @@ -17,7 +17,7 @@ endif() | |||
| add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC}) | |||
| add_dependencies(cpu_kernel_mid fbs_src) | |||
| if(PLATFORM_ARM64) | |||
| if(PLATFORM_ARM) | |||
| if(ENABLE_FP16) | |||
| file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) | |||
| if(SUPPORT_TRAIN) | |||
| @@ -52,7 +52,7 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { | |||
| ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, | |||
| param_->value_.int32_value_); | |||
| break; | |||
| #ifdef ENABLE_NEON | |||
| #ifdef ENABLE_FP16 | |||
| case kNumberTypeFloat16: | |||
| ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride, | |||
| param_->value_.f32_value_); | |||
| @@ -31,8 +31,8 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() { | |||
| matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | |||
| matmul_param_->col_ = conv_param_->output_channel_; | |||
| matmul_param_->deep_ = conv_param_->input_channel_; | |||
| matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM); | |||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | |||
| matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_); | |||
| matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_); | |||
| matmul_param_->act_type_ = conv_param_->act_type_; | |||
| return RET_OK; | |||
| } | |||
| @@ -54,14 +54,14 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() { | |||
| pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | |||
| conv_param_->stride_w_ != 1); | |||
| if ((matmul_param_->row_ > (C16NUM * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||
| if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||
| multi_thread_by_hw_ = true; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C16NUM)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, C16NUM), thread_count_) * C16NUM; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_; | |||
| } else { | |||
| multi_thread_by_hw_ = false; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_)); | |||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_; | |||
| } | |||
| if (pre_trans_input_) { | |||
| @@ -81,8 +81,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||
| auto output_channel = weight_tensor->Batch(); | |||
| if (in_tensors_.size() == 3) { | |||
| size_t size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); | |||
| size_t weight_size = output_channel * sizeof(float16_t); | |||
| size_t size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); | |||
| size_t bias_size = output_channel * sizeof(float16_t); | |||
| bias_data_ = malloc(size); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; | |||
| @@ -94,11 +94,11 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||
| MS_LOG(ERROR) << "Conv1x1 only support fp16 weight"; | |||
| return RET_ERROR; | |||
| } | |||
| memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); | |||
| memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size); | |||
| } | |||
| size_t size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); | |||
| size_t down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float16_t); | |||
| size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); | |||
| size_t down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float16_t); | |||
| weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size)); | |||
| if (weight_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | |||
| @@ -111,6 +111,12 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||
| } | |||
| int Convolution1x1FP16CPUKernel::Init() { | |||
| col_tile_ = C8NUM; | |||
| #ifdef ENABLE_ARM64 | |||
| row_tile_ = C16NUM; | |||
| #else | |||
| row_tile_ = C12NUM; | |||
| #endif | |||
| matmul_param_ = new (std::nothrow) MatMulParameter(); | |||
| if (matmul_param_ == nullptr) { | |||
| MS_LOG(ERROR) << "Init matmul_param_ failed."; | |||
| @@ -177,8 +183,11 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) { | |||
| float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | |||
| #ifdef ENABLE_ARM64 | |||
| RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||
| #endif | |||
| float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; | |||
| MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_), | |||
| matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_, | |||
| @@ -211,7 +220,7 @@ int Convolution1x1FP16CPUKernel::Run() { | |||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | |||
| pack_input_ = reinterpret_cast<float16_t *>( | |||
| ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t))); | |||
| ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float16_t))); | |||
| if (pack_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | |||
| return RET_MEMORY_FAILED; | |||
| @@ -231,7 +240,11 @@ int Convolution1x1FP16CPUKernel::Run() { | |||
| if (multi_thread_by_hw_) { | |||
| ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_); | |||
| } else { | |||
| #ifdef ENABLE_ARM64 | |||
| RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #else | |||
| RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| #endif | |||
| ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_); | |||
| } | |||
| if (ret != RET_OK) { | |||
| @@ -62,6 +62,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| float16_t *pack_input_ = nullptr; | |||
| float16_t *output_ptr_ = nullptr; | |||
| MatMulParameter *matmul_param_ = nullptr; | |||
| int col_tile_; | |||
| int row_tile_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -34,7 +34,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||
| int out_channel = filter_tensor->Batch(); | |||
| conv_param_->input_channel_ = in_channel; | |||
| conv_param_->output_channel_ = out_channel; | |||
| int oc8 = UP_ROUND(out_channel, C8NUM); | |||
| int oc8 = UP_ROUND(out_channel, col_tile_); | |||
| int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | |||
| int pack_weight_size = oc8 * in_channel * kernel_plane; | |||
| @@ -69,9 +69,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||
| } | |||
| int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||
| const int cal_num = 16; | |||
| int unit_size = | |||
| conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * cal_num * thread_count_; | |||
| conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_; | |||
| packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); | |||
| if (packed_input_ == nullptr) { | |||
| @@ -88,6 +87,12 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||
| } | |||
| int ConvolutionFP16CPUKernel::Init() { | |||
| #ifdef ENABLE_ARM64 | |||
| row_tile_ = C16NUM; | |||
| #else | |||
| row_tile_ = C12NUM; | |||
| #endif | |||
| col_tile_ = C8NUM; | |||
| auto ret = InitWeightBias(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init weight bias failed."; | |||
| @@ -99,7 +104,7 @@ int ConvolutionFP16CPUKernel::Init() { | |||
| void ConvolutionFP16CPUKernel::AdjustNumberOfThread() { | |||
| auto out_tensor = out_tensors_.front(); | |||
| int out_plane = out_tensor->Height() * out_tensor->Width(); | |||
| thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM)); | |||
| thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, row_tile_)); | |||
| conv_param_->thread_num_ = thread_count_; | |||
| } | |||
| @@ -62,6 +62,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| float16_t *packed_input_ = nullptr; | |||
| float16_t *packed_weight_ = nullptr; | |||
| float16_t *col_major_input_ = nullptr; | |||
| int col_tile_; | |||
| int row_tile_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -38,13 +38,10 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||
| int out_channel = filter_tensor->Batch(); | |||
| conv_param_->input_channel_ = in_channel; | |||
| conv_param_->output_channel_ = out_channel; | |||
| const int oc_block = C8NUM; | |||
| int oc_block_num = UP_DIV(out_channel, C8NUM); | |||
| int oc_block_num = UP_DIV(out_channel, col_tile_); | |||
| // init weight | |||
| // set data | |||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); | |||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t); | |||
| trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); | |||
| if (trans_weight_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc trans_weight_ failed."; | |||
| @@ -73,7 +70,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||
| MS_LOG(ERROR) << "get execute filter failed."; | |||
| return ret; | |||
| } | |||
| ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block); | |||
| ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, col_tile_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "winograd filter transform failed."; | |||
| return ret; | |||
| @@ -85,12 +82,12 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||
| } | |||
| // init bias | |||
| bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t)); | |||
| bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t)); | |||
| if (bias_data_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc bias_data_ failed."; | |||
| return RET_ERROR; | |||
| } | |||
| memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t)); | |||
| memset(bias_data_, 0, oc_block_num * col_tile_ * sizeof(float16_t)); | |||
| if (in_tensors_.size() == kInputSize2) { | |||
| if (origin_bias_data_type_ == kNumberTypeFloat16) { | |||
| @@ -106,11 +103,9 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||
| } | |||
| int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||
| const int cal_num = 16; | |||
| int channel_out = conv_param_->output_channel_; | |||
| size_t tile_buffer_size = | |||
| thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); | |||
| thread_count_ * row_tile_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); | |||
| trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); | |||
| if (trans_input_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | |||
| @@ -118,7 +113,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||
| } | |||
| gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc( | |||
| thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); | |||
| thread_count_ * row_tile_ * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); | |||
| if (gemm_out_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc gemm_out_ failed."; | |||
| return RET_ERROR; | |||
| @@ -132,7 +127,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||
| } | |||
| col_buffer_ = reinterpret_cast<float16_t *>( | |||
| ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t))); | |||
| ctx_->allocator->Malloc(thread_count_ * row_tile_ * conv_param_->input_channel_ * sizeof(float16_t))); | |||
| if (col_buffer_ == nullptr) { | |||
| MS_LOG(ERROR) << "malloc col_buffer_ failed."; | |||
| return RET_ERROR; | |||
| @@ -160,6 +155,12 @@ int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() { | |||
| } | |||
| int ConvolutionWinogradFP16CPUKernel::Init() { | |||
| col_tile_ = C8NUM; | |||
| #ifdef ENABLE_ARM64 | |||
| row_tile_ = C16NUM; | |||
| #else | |||
| row_tile_ = C12NUM; | |||
| #endif | |||
| kernel_unit_ = conv_param_->kernel_h_; | |||
| input_unit_ = output_unit_ + kernel_unit_ - 1; | |||
| conv_param_->input_unit_ = input_unit_; | |||
| @@ -86,6 +86,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||
| TmpBufferAddressFp16 tmp_buffer_address_list_[4]; | |||
| InputTransFp16Func in_func_; | |||
| OutputTransFp16Func out_func_; | |||
| int col_tile_; | |||
| int row_tile_; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -13,26 +13,20 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_ARM | |||
| #include <arm_neon.h> | |||
| #endif | |||
| #include "nnacl/fp16/cast_fp16.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| #ifdef ENABLE_ARM64 | |||
| extern void Float32ToFloat16(const float *input, float16_t *output, int number); | |||
| extern void Float16ToFloat32(const float16_t *input, float *output, int number); | |||
| inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { | |||
| static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { | |||
| Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number); | |||
| } | |||
| inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { | |||
| static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { | |||
| Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number); | |||
| } | |||
| #endif | |||
| #ifdef __cplusplus | |||
| } | |||
| @@ -53,8 +53,8 @@ int PoolingFp16CPUKernel::ReSize() { | |||
| } | |||
| int PoolingFp16CPUKernel::RunImpl(int task_id) { | |||
| float16_t minf = -FLT_MAX; | |||
| float16_t maxf = FLT_MAX; | |||
| float16_t minf = -FLT16_MAX; | |||
| float16_t maxf = FLT16_MAX; | |||
| if (pooling_param_->act_type_ == ActType_Relu) { | |||
| minf = 0.f; | |||
| } else if (pooling_param_->act_type_ == ActType_Relu6) { | |||
| @@ -45,7 +45,7 @@ | |||
| #include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h" | |||
| #include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h" | |||
| #endif | |||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||
| #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | |||
| #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | |||
| #endif | |||
| @@ -230,7 +230,7 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o | |||
| auto origin_data = tensor->data_c(); | |||
| MS_ASSERT(origin_data != nullptr); | |||
| if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) { | |||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||
| #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | |||
| auto restore_tensor = Tensor::CopyTensor(*tensor, false); | |||
| restore_tensor->set_data(origin_data); | |||
| restore_tensor->set_own_data(tensor->own_data()); | |||
| @@ -17,7 +17,7 @@ | |||
| #include "src/sub_graph_kernel.h" | |||
| #include "src/tensor.h" | |||
| #include "src/tensorlist.h" | |||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||
| #ifdef ENABLE_FP16 | |||
| #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | |||
| #endif | |||
| #include "src/common/version_manager.h" | |||
| @@ -283,7 +283,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) { | |||
| } | |||
| int CpuFp16SubGraph::PreProcess() { | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_FP16 | |||
| if (!mindspore::lite::IsSupportFloat16()) { | |||
| MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | |||
| return RET_ERROR; | |||
| @@ -347,7 +347,7 @@ int CpuFp16SubGraph::PreProcess() { | |||
| } | |||
| int CpuFp16SubGraph::PostProcess() { | |||
| #ifdef ENABLE_ARM64 | |||
| #ifdef ENABLE_FP16 | |||
| if (!mindspore::lite::IsSupportFloat16()) { | |||
| MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | |||
| return RET_ERROR; | |||
| @@ -378,8 +378,11 @@ add_dependencies(lite-test fbs_src) | |||
| target_link_libraries(lite-test dl mindspore::gtest) | |||
| if(PLATFORM_ARM64 AND ENABLE_FP16) | |||
| target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | |||
| if(PLATFORM_ARM AND ENABLE_FP16) | |||
| target_link_libraries(lite-test nnacl_fp16_mid) | |||
| if(PLATFORM_ARM64) | |||
| target_link_libraries(lite-test nnacl_optimize_mid) | |||
| endif() | |||
| endif() | |||
| if(PLATFORM_ARM) | |||