From: @lzkcode Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongpull/14941/MERGE
| @@ -607,7 +607,7 @@ build_lite() | |||||
| cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ | cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ | ||||
| -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ | -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ | ||||
| -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ | -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ | ||||
| -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ | |||||
| -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DENABLE_FP16="on" \ | |||||
| -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ | -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ | ||||
| -DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \ | -DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \ | ||||
| -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ | -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ | ||||
| @@ -68,7 +68,7 @@ add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC}) | |||||
| add_dependencies(nnacl fbs_src) | add_dependencies(nnacl fbs_src) | ||||
| add_dependencies(nnacl_mid fbs_src) | add_dependencies(nnacl_mid fbs_src) | ||||
| ########################### arm64 build optimize library ######################## | |||||
| if(PLATFORM_ARM64) | |||||
| ########################### arm fp16 build optimize library ######################## | |||||
| if(ENABLE_FP16) | |||||
| add_subdirectory(${NNACL_DIR}/optimize) | add_subdirectory(${NNACL_DIR}/optimize) | ||||
| endif() | endif() | ||||
| @@ -30,7 +30,7 @@ typedef struct ArgElement { | |||||
| int8_t i8_data_; | int8_t i8_data_; | ||||
| int32_t i_data_; | int32_t i_data_; | ||||
| float f_data_; | float f_data_; | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_ARM | |||||
| float16_t f16_data_; | float16_t f16_data_; | ||||
| #endif | #endif | ||||
| } data_; | } data_; | ||||
| @@ -0,0 +1,602 @@ | |||||
| #ifdef ENABLE_ARM32 | |||||
| #include "nnacl/assembly_global.h" | |||||
| .text | |||||
| .align 5 | |||||
| .global MatMul12x8A32Fp16 | |||||
| #ifndef __APPLE__ | |||||
| .type MatMul12x8A32Fp16, %function | |||||
| #endif | |||||
| // void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| // int deep, int row, int col, int stride, bool write_mode); | |||||
| // r0: a | |||||
| // r1: b | |||||
| // r2: dst | |||||
| // r3: bias | |||||
| // #4: depth | |||||
| // #8: row | |||||
| // #12: col | |||||
| // #16: stride | |||||
| // #20: writeNhwc/writeWino | |||||
| asm_function MatMul12x8A32Fp16 | |||||
| // r13(sp) and r15(pc) can not be used!! | |||||
| // r9 r4 is tmp register | |||||
| // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf | |||||
| push {r3-r11, lr} | |||||
| vpush {q4-q7} | |||||
| add sp, sp, #104 | |||||
| ldr r5, [sp, #4] | |||||
| ldr r6, [sp, #8] | |||||
| ldr r7, [sp, #12] | |||||
| ldr r8, [sp, #16] | |||||
| ldr lr, [sp, #20] | |||||
| mov r10, r1 // b | |||||
| mov r11, r0 // a | |||||
| mov r12, r2 // dst | |||||
| cmp lr, #2 | |||||
| bne NoWinograd | |||||
| mul r4, r8, r7 // stride * col | |||||
| add r4, r4, r4 // r4 * sizeof(float16_t) | |||||
| mov r9, #16 | |||||
| mul r9, r8, r9 // stride * 8 * sizeof(float16_t) | |||||
| NoWinograd: | |||||
| add r8, r8, r8 // stride * sizeof(float16_t) | |||||
| a .req r0 | |||||
| weight .req r1 | |||||
| dst .req r2 | |||||
| bias .req r3 | |||||
| depth .req r5 | |||||
| row .req r6 | |||||
| col .req r7 | |||||
| stride .req r8 | |||||
| b_tmp .req r10 | |||||
| a_tmp .req r11 | |||||
| dst_tmp .req r12 | |||||
| .macro STORE_12x8 p1 | |||||
| vst1.16 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x7 p1, p2, p3 | |||||
| add r4, dst, #8 | |||||
| add r9, dst, #12 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.32 {\p2}, [r4] | |||||
| vst1.16 {\p3}, [r9] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x6 p1, p2 | |||||
| add r4, dst, #8 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.32 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x5 p1, p2 | |||||
| add r4, dst, #8 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.16 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x4 p1 | |||||
| vst1.16 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x3 p1, p2 | |||||
| add r4, dst, #4 | |||||
| vst1.32 {\p1}, [dst] | |||||
| vst1.16 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x2 p1 | |||||
| vst1.32 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_12x1 p1 | |||||
| vst1.16 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| .endm | |||||
| .macro STORE_C8 p1, p2 | |||||
| vst1.16 {\p1}, [dst] | |||||
| cmp row, \p2 | |||||
| add dst, dst, stride | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C7 p1, p2, p3, p4 | |||||
| add r4, dst, #8 | |||||
| add r9, dst, #12 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.32 {\p2}, [r4] | |||||
| vst1.16 {\p3}, [r9] | |||||
| add dst, dst, stride | |||||
| cmp row, \p4 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C6 p1, p2, p3 | |||||
| add r4, dst, #8 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.32 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| cmp row, \p3 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C5 p1, p2, p3 | |||||
| add r4, dst, #8 | |||||
| vst1.16 {\p1}, [dst] | |||||
| vst1.16 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| cmp row, \p3 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C4 p1, p2 | |||||
| vst1.16 {\p1}, [dst] | |||||
| cmp row, \p2 | |||||
| add dst, dst, stride | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C3 p1, p2, p3 | |||||
| add r4, dst, #4 | |||||
| vst1.32 {\p1}, [dst] | |||||
| vst1.16 {\p2}, [r4] | |||||
| add dst, dst, stride | |||||
| cmp row, \p3 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C2 p1, p2 | |||||
| vst1.32 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| cmp row, \p2 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| .macro STORE_C1 p1, p2 | |||||
| vst1.16 {\p1}, [dst] | |||||
| add dst, dst, stride | |||||
| cmp row, \p2 | |||||
| beq WriteEnd | |||||
| .endm | |||||
| LoopRow12: | |||||
| ldr bias, [sp, #-40] | |||||
| LoopCol8: | |||||
| mov dst, dst_tmp | |||||
| mov a, a_tmp | |||||
| ldr depth, [sp, #4] | |||||
| veor q4, q4, q4 | |||||
| veor q5, q5, q5 | |||||
| veor q6, q6, q6 | |||||
| veor q7, q7, q7 | |||||
| veor q8, q8, q8 | |||||
| veor q9, q9, q9 | |||||
| veor q10, q10, q10 | |||||
| veor q11, q11, q11 | |||||
| veor q12, q12, q12 | |||||
| veor q13, q13, q13 | |||||
| veor q14, q14, q14 | |||||
| veor q15, q15, q15 | |||||
| LoopDepth: | |||||
| vld1.16 {q0, d2}, [a]! | |||||
| vld1.16 {q2}, [weight]! | |||||
| vmla.f16 q4, q2, d0[0] | |||||
| vmla.f16 q5, q2, d0[1] | |||||
| vmla.f16 q6, q2, d0[2] | |||||
| vmla.f16 q7, q2, d0[3] | |||||
| vmla.f16 q8, q2, d1[0] | |||||
| vmla.f16 q9, q2, d1[1] | |||||
| vmla.f16 q10, q2, d1[2] | |||||
| vmla.f16 q11, q2, d1[3] | |||||
| vmla.f16 q12, q2, d2[0] | |||||
| vmla.f16 q13, q2, d2[1] | |||||
| vmla.f16 q14, q2, d2[2] | |||||
| vmla.f16 q15, q2, d2[3] | |||||
| subs depth, depth, #1 | |||||
| bne LoopDepth | |||||
| Bias: | |||||
| cmp bias, #0 | |||||
| beq Activation | |||||
| vld1.16 {q0}, [bias]! | |||||
| vadd.f16 q4, q4, q0 | |||||
| vadd.f16 q5, q5, q0 | |||||
| vadd.f16 q6, q6, q0 | |||||
| vadd.f16 q7, q7, q0 | |||||
| vadd.f16 q8, q8, q0 | |||||
| vadd.f16 q9, q9, q0 | |||||
| vadd.f16 q10, q10, q0 | |||||
| vadd.f16 q11, q11, q0 | |||||
| vadd.f16 q12, q12, q0 | |||||
| vadd.f16 q13, q13, q0 | |||||
| vadd.f16 q14, q14, q0 | |||||
| vadd.f16 q15, q15, q0 | |||||
| Activation: | |||||
| ldr lr, [sp] | |||||
| cmp lr, #3 | |||||
| beq Relu6 | |||||
| cmp lr, #1 | |||||
| beq Relu | |||||
| b Write | |||||
| Relu6: | |||||
| vmov.i16 q2, #0x4600 | |||||
| vadd.f16 q4, q4, q2 | |||||
| vadd.f16 q5, q5, q2 | |||||
| vadd.f16 q6, q6, q2 | |||||
| vadd.f16 q7, q7, q2 | |||||
| vmin.f16 q8, q8, q2 | |||||
| vmin.f16 q9, q9, q2 | |||||
| vmin.f16 q10, q10, q2 | |||||
| vmin.f16 q11, q11, q2 | |||||
| vmin.f16 q12, q12, q2 | |||||
| vmin.f16 q13, q13, q2 | |||||
| vmin.f16 q14, q14, q2 | |||||
| vmin.f16 q15, q15, q2 | |||||
| Relu: | |||||
| veor q3, q3, q3 | |||||
| vmax.f16 q4, q4, q3 | |||||
| vmax.f16 q5, q5, q3 | |||||
| vmax.f16 q6, q6, q3 | |||||
| vmax.f16 q7, q7, q3 | |||||
| vmax.f16 q8, q8, q3 | |||||
| vmax.f16 q9, q9, q3 | |||||
| vmax.f16 q10, q10, q3 | |||||
| vmax.f16 q11, q11, q3 | |||||
| vmax.f16 q12, q12, q3 | |||||
| vmax.f16 q13, q13, q3 | |||||
| vmax.f16 q14, q14, q3 | |||||
| vmax.f16 q15, q15, q3 | |||||
| Write: | |||||
| ldr lr, [sp, #20] | |||||
| cmp lr, #2 | |||||
| beq WriteWinograd | |||||
| cmp row, #12 | |||||
| bge Write12xCol | |||||
| b WriteRowxCol | |||||
| WriteWinograd: | |||||
| vst1.16 {q4}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q5}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q6}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q7}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q8}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q9}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q10}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q11}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q12}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q13}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q14}, [dst] | |||||
| add dst, dst, r4 | |||||
| vst1.16 {q15}, [dst] | |||||
| add dst_tmp, dst_tmp, r9 | |||||
| b WriteEnd | |||||
| Write12xCol: | |||||
| cmp col, #8 | |||||
| bge Write12x8 | |||||
| cmp col, #1 | |||||
| beq Write12x1 | |||||
| cmp col, #2 | |||||
| beq Write12x2 | |||||
| cmp col, #3 | |||||
| beq Write12x3 | |||||
| cmp col, #4 | |||||
| beq Write12x4 | |||||
| cmp col, #5 | |||||
| beq Write12x5 | |||||
| cmp col, #6 | |||||
| beq Write12x6 | |||||
| b Write12x7 | |||||
| WriteRowxCol: | |||||
| cmp col, #8 | |||||
| bge WriteRowx8 | |||||
| cmp col, #1 | |||||
| beq WriteRowx1 | |||||
| cmp col, #2 | |||||
| beq WriteRowx2 | |||||
| cmp col, #3 | |||||
| beq WriteRowx3 | |||||
| cmp col, #4 | |||||
| beq WriteRowx4 | |||||
| cmp col, #5 | |||||
| beq WriteRowx5 | |||||
| cmp col, #6 | |||||
| beq WriteRowx6 | |||||
| b WriteRowx7 | |||||
| Write12x8: | |||||
| STORE_12x8 q4 | |||||
| STORE_12x8 q5 | |||||
| STORE_12x8 q6 | |||||
| STORE_12x8 q7 | |||||
| STORE_12x8 q8 | |||||
| STORE_12x8 q9 | |||||
| STORE_12x8 q10 | |||||
| STORE_12x8 q11 | |||||
| STORE_12x8 q12 | |||||
| STORE_12x8 q13 | |||||
| STORE_12x8 q14 | |||||
| STORE_12x8 q15 | |||||
| b WriteEnd | |||||
| WriteRowx8: | |||||
| STORE_C8 q4, #1 | |||||
| STORE_C8 q5, #2 | |||||
| STORE_C8 q6, #3 | |||||
| STORE_C8 q7, #4 | |||||
| STORE_C8 q8, #5 | |||||
| STORE_C8 q9, #6 | |||||
| STORE_C8 q10, #7 | |||||
| STORE_C8 q11, #8 | |||||
| STORE_C8 q12, #9 | |||||
| STORE_C8 q13, #10 | |||||
| STORE_C8 q14, #11 | |||||
| STORE_C8 q15, #12 | |||||
| b WriteEnd | |||||
| Write12x1: | |||||
| STORE_12x1 d8[0] | |||||
| STORE_12x1 d10[0] | |||||
| STORE_12x1 d12[0] | |||||
| STORE_12x1 d14[0] | |||||
| STORE_12x1 d16[0] | |||||
| STORE_12x1 d18[0] | |||||
| STORE_12x1 d20[0] | |||||
| STORE_12x1 d22[0] | |||||
| STORE_12x1 d24[0] | |||||
| STORE_12x1 d26[0] | |||||
| STORE_12x1 d28[0] | |||||
| STORE_12x1 d30[0] | |||||
| b WriteEnd | |||||
| WriteRowx1: | |||||
| STORE_C1 d8[0], #1 | |||||
| STORE_C1 d10[0], #2 | |||||
| STORE_C1 d12[0], #3 | |||||
| STORE_C1 d14[0], #4 | |||||
| STORE_C1 d16[0], #5 | |||||
| STORE_C1 d18[0], #6 | |||||
| STORE_C1 d20[0], #7 | |||||
| STORE_C1 d22[0], #8 | |||||
| STORE_C1 d24[0], #9 | |||||
| STORE_C1 d26[0], #10 | |||||
| STORE_C1 d28[0], #11 | |||||
| STORE_C1 d30[0], #12 | |||||
| b WriteEnd | |||||
| Write12x2: | |||||
| STORE_12x2 d8[0] | |||||
| STORE_12x2 d10[0] | |||||
| STORE_12x2 d12[0] | |||||
| STORE_12x2 d14[0] | |||||
| STORE_12x2 d16[0] | |||||
| STORE_12x2 d18[0] | |||||
| STORE_12x2 d20[0] | |||||
| STORE_12x2 d22[0] | |||||
| STORE_12x2 d24[0] | |||||
| STORE_12x2 d26[0] | |||||
| STORE_12x2 d28[0] | |||||
| STORE_12x2 d30[0] | |||||
| b WriteEnd | |||||
| WriteRowx2: | |||||
| STORE_C2 d8[0], #1 | |||||
| STORE_C2 d10[0], #2 | |||||
| STORE_C2 d12[0], #3 | |||||
| STORE_C2 d14[0], #4 | |||||
| STORE_C2 d16[0], #5 | |||||
| STORE_C2 d18[0], #6 | |||||
| STORE_C2 d20[0], #7 | |||||
| STORE_C2 d22[0], #8 | |||||
| STORE_C2 d24[0], #9 | |||||
| STORE_C2 d26[0], #10 | |||||
| STORE_C2 d28[0], #11 | |||||
| STORE_C2 d30[0], #12 | |||||
| b WriteEnd | |||||
| Write12x3: | |||||
| STORE_12x3 d8[0], d8[2] | |||||
| STORE_12x3 d10[0], d10[2] | |||||
| STORE_12x3 d12[0], d12[2] | |||||
| STORE_12x3 d14[0], d14[2] | |||||
| STORE_12x3 d16[0], d16[2] | |||||
| STORE_12x3 d18[0], d18[2] | |||||
| STORE_12x3 d20[0], d20[2] | |||||
| STORE_12x3 d22[0], d22[2] | |||||
| STORE_12x3 d24[0], d24[2] | |||||
| STORE_12x3 d26[0], d26[2] | |||||
| STORE_12x3 d28[0], d28[2] | |||||
| STORE_12x3 d30[0], d30[2] | |||||
| b WriteEnd | |||||
| WriteRowx3: | |||||
| STORE_C3 d8[0], d8[2], #1 | |||||
| STORE_C3 d10[0], d10[2], #2 | |||||
| STORE_C3 d12[0], d12[2], #3 | |||||
| STORE_C3 d14[0], d14[2], #4 | |||||
| STORE_C3 d16[0], d16[2], #5 | |||||
| STORE_C3 d18[0], d18[2], #6 | |||||
| STORE_C3 d20[0], d20[2], #7 | |||||
| STORE_C3 d22[0], d22[2], #8 | |||||
| STORE_C3 d24[0], d24[2], #9 | |||||
| STORE_C3 d26[0], d26[2], #10 | |||||
| STORE_C3 d28[0], d28[2], #11 | |||||
| STORE_C3 d30[0], d30[2], #12 | |||||
| b WriteEnd | |||||
| Write12x4: | |||||
| STORE_12x4 d8 | |||||
| STORE_12x4 d10 | |||||
| STORE_12x4 d12 | |||||
| STORE_12x4 d14 | |||||
| STORE_12x4 d16 | |||||
| STORE_12x4 d18 | |||||
| STORE_12x4 d20 | |||||
| STORE_12x4 d22 | |||||
| STORE_12x4 d24 | |||||
| STORE_12x4 d26 | |||||
| STORE_12x4 d28 | |||||
| STORE_12x4 d30 | |||||
| b WriteEnd | |||||
| WriteRowx4: | |||||
| STORE_C4 d8, #1 | |||||
| STORE_C4 d10, #2 | |||||
| STORE_C4 d12, #3 | |||||
| STORE_C4 d14, #4 | |||||
| STORE_C4 d16, #5 | |||||
| STORE_C4 d18, #6 | |||||
| STORE_C4 d20, #7 | |||||
| STORE_C4 d22, #8 | |||||
| STORE_C4 d24, #9 | |||||
| STORE_C4 d26, #10 | |||||
| STORE_C4 d28, #11 | |||||
| STORE_C4 d30, #12 | |||||
| b WriteEnd | |||||
| Write12x5: | |||||
| STORE_12x5 d8, d9[0] | |||||
| STORE_12x5 d10, d11[0] | |||||
| STORE_12x5 d12, d13[0] | |||||
| STORE_12x5 d14, d15[0] | |||||
| STORE_12x5 d16, d17[0] | |||||
| STORE_12x5 d18, d19[0] | |||||
| STORE_12x5 d20, d21[0] | |||||
| STORE_12x5 d22, d23[0] | |||||
| STORE_12x5 d24, d25[0] | |||||
| STORE_12x5 d26, d27[0] | |||||
| STORE_12x5 d28, d29[0] | |||||
| STORE_12x5 d30, d31[0] | |||||
| b WriteEnd | |||||
| WriteRowx5: | |||||
| STORE_C5 d8, d9[0], #1 | |||||
| STORE_C5 d10, d11[0], #2 | |||||
| STORE_C5 d12, d13[0], #3 | |||||
| STORE_C5 d14, d15[0], #4 | |||||
| STORE_C5 d16, d17[0], #5 | |||||
| STORE_C5 d18, d19[0], #6 | |||||
| STORE_C5 d20, d21[0], #7 | |||||
| STORE_C5 d22, d23[0], #8 | |||||
| STORE_C5 d24, d25[0], #9 | |||||
| STORE_C5 d26, d27[0], #10 | |||||
| STORE_C5 d28, d29[0], #11 | |||||
| STORE_C5 d30, d31[0], #12 | |||||
| b WriteEnd | |||||
| Write12x6: | |||||
| STORE_12x6 d8, d9[0] | |||||
| STORE_12x6 d10, d11[0] | |||||
| STORE_12x6 d12, d13[0] | |||||
| STORE_12x6 d14, d15[0] | |||||
| STORE_12x6 d16, d17[0] | |||||
| STORE_12x6 d18, d19[0] | |||||
| STORE_12x6 d20, d21[0] | |||||
| STORE_12x6 d22, d23[0] | |||||
| STORE_12x6 d24, d25[0] | |||||
| STORE_12x6 d26, d27[0] | |||||
| STORE_12x6 d28, d29[0] | |||||
| STORE_12x6 d30, d31[0] | |||||
| b WriteEnd | |||||
| WriteRowx6: | |||||
| STORE_C6 d8, d9[0], #1 | |||||
| STORE_C6 d10, d11[0], #2 | |||||
| STORE_C6 d12, d13[0], #3 | |||||
| STORE_C6 d14, d15[0], #4 | |||||
| STORE_C6 d16, d17[0], #5 | |||||
| STORE_C6 d18, d19[0], #6 | |||||
| STORE_C6 d20, d21[0], #7 | |||||
| STORE_C6 d22, d23[0], #8 | |||||
| STORE_C6 d24, d25[0], #9 | |||||
| STORE_C6 d26, d27[0], #10 | |||||
| STORE_C6 d28, d29[0], #11 | |||||
| STORE_C6 d30, d31[0], #12 | |||||
| b WriteEnd | |||||
| Write12x7: | |||||
| STORE_12x7 d8, d9[0], d9[2] | |||||
| STORE_12x7 d10, d11[0], d11[2] | |||||
| STORE_12x7 d12, d13[0], d13[2] | |||||
| STORE_12x7 d14, d15[0], d15[2] | |||||
| STORE_12x7 d16, d17[0], d17[2] | |||||
| STORE_12x7 d18, d19[0], d19[2] | |||||
| STORE_12x7 d20, d21[0], d21[2] | |||||
| STORE_12x7 d22, d23[0], d23[2] | |||||
| STORE_12x7 d24, d25[0], d25[2] | |||||
| STORE_12x7 d26, d27[0], d27[2] | |||||
| STORE_12x7 d28, d29[0], d29[2] | |||||
| STORE_12x7 d30, d31[0], d31[2] | |||||
| b WriteEnd | |||||
| WriteRowx7: | |||||
| STORE_C7 d8, d9[0], d9[2], #1 | |||||
| STORE_C7 d10, d11[0], d11[2], #2 | |||||
| STORE_C7 d12, d13[0], d13[2], #3 | |||||
| STORE_C7 d14, d15[0], d15[2], #4 | |||||
| STORE_C7 d16, d17[0], d17[2], #5 | |||||
| STORE_C7 d18, d19[0], d19[2], #6 | |||||
| STORE_C7 d20, d21[0], d21[2], #7 | |||||
| STORE_C7 d22, d23[0], d23[2], #8 | |||||
| STORE_C7 d24, d25[0], d25[2], #9 | |||||
| STORE_C7 d26, d27[0], d27[2], #10 | |||||
| STORE_C7 d28, d29[0], d29[2], #11 | |||||
| STORE_C7 d30, d31[0], d31[2], #12 | |||||
| b WriteEnd | |||||
| WriteEnd: | |||||
| cmp col, #8 | |||||
| ble LoopColEnd | |||||
| sub col, col, #8 | |||||
| ldr lr, [sp, #20] | |||||
| cmp lr, #2 | |||||
| beq LoopCol8 | |||||
| add dst_tmp, dst_tmp, #16 | |||||
| b LoopCol8 | |||||
| LoopColEnd: | |||||
| cmp row, #12 | |||||
| ble LoopRowEnd | |||||
| sub row, row, #12 | |||||
| mov a_tmp, a | |||||
| mov weight, b_tmp | |||||
| ldr lr, [sp, #20] | |||||
| cmp lr, #2 | |||||
| beq WinogradDst | |||||
| ldr lr, [sp, #12] | |||||
| sub lr, lr, col | |||||
| add lr, lr, lr // col *= 2 | |||||
| sub dst_tmp, dst, lr | |||||
| b LoopRow | |||||
| WinogradDst: | |||||
| add dst_tmp, dst, r9 | |||||
| LoopRow: | |||||
| mov dst, dst_tmp | |||||
| ldr col, [sp, #12] | |||||
| b LoopRow12 | |||||
| LoopRowEnd: | |||||
| sub sp, sp, #104 | |||||
| vpop {q4-q7} | |||||
| pop {r3-r11, pc} | |||||
| #endif | |||||
| @@ -1,667 +0,0 @@ | |||||
| #ifdef ENABLE_ARM64 | |||||
| #include "nnacl/assembly_global.h" | |||||
| .text | |||||
| .align 5 | |||||
| // void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, | |||||
| // size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); | |||||
| // x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset, | |||||
| // x8:mode, x9: writeC4, x10:relu, x11: relu6 | |||||
| // compute 8 channel for 16 outputs | |||||
| asm_function IndirectGemmFp16_16x8 | |||||
| .macro INIT_BIAS | |||||
| dup v16.4s, wzr | |||||
| cbz x3, InitBias | |||||
| ld1 {v16.8h}, [x3] | |||||
| InitBias: | |||||
| mov v17.16b, v16.16b | |||||
| mov v18.16b, v16.16b | |||||
| mov v19.16b, v16.16b | |||||
| mov v20.16b, v16.16b | |||||
| mov v21.16b, v16.16b | |||||
| mov v22.16b, v16.16b | |||||
| mov v23.16b, v16.16b | |||||
| mov v24.16b, v16.16b | |||||
| mov v25.16b, v16.16b | |||||
| mov v26.16b, v16.16b | |||||
| mov v27.16b, v16.16b | |||||
| mov v28.16b, v16.16b | |||||
| mov v29.16b, v16.16b | |||||
| mov v30.16b, v16.16b | |||||
| mov v31.16b, v16.16b | |||||
| .endm | |||||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||||
| // x19 ~ r29 should be also preserved | |||||
| // whereas our coding style do not permit such amount of parameters | |||||
| sub sp, sp, #144 | |||||
| // performance between storing 4 registers at the same time and separately storing them on in-order cores | |||||
| // is not tested yet | |||||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||||
| stp x19, x20, [sp], #16 | |||||
| ldr x8, [sp, #0] | |||||
| ldr x9, [sp, #8] | |||||
| ldr x10, [sp, #16] | |||||
| ldr x11, [sp, #24] | |||||
| cbnz x8, IndirectGemmStart | |||||
| // step is one for common convolution, where ic8 should multiply by kernel size | |||||
| // step is (a+b-1) for F(a,b) in winograd | |||||
| mul x5, x4, x5 | |||||
| mov x4, #1 | |||||
| IndirectGemmStart: | |||||
| LoopOc: | |||||
| mov x14, x4 | |||||
| mov x12, x1 | |||||
| LoopKsize: | |||||
| mov x15, x0 | |||||
| INIT_BIAS | |||||
| // load input for output 1-8 | |||||
| ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 | |||||
| // load weight | |||||
| ld1 {v8.8h, v9.8h}, [x2], #32 | |||||
| // first 2 steps for output 1 and 3 | |||||
| fmla v16.8h, v8.8h, v0.h[0] | |||||
| fmla v18.8h, v8.8h, v1.h[0] | |||||
| fmla v16.8h, v9.8h, v0.h[1] | |||||
| fmla v18.8h, v9.8h, v1.h[1] | |||||
| // load weight | |||||
| ld1 {v10.8h, v11.8h}, [x2], #32 | |||||
| // first 2 steps for output 2 and 4 | |||||
| fmla v17.8h, v8.8h, v0.h[4] | |||||
| fmla v19.8h, v8.8h, v1.h[4] | |||||
| fmla v17.8h, v9.8h, v0.h[5] | |||||
| fmla v19.8h, v9.8h, v1.h[5] | |||||
| // load input for output 9-16 | |||||
| // input cache should be refreshed after loading | |||||
| // ATTENTION: advance is preferred, but advancing too much may lead to invalid prefetching | |||||
| ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 | |||||
| // last 2 steps for output 1 and 3 | |||||
| fmla v16.8h, v10.8h, v0.h[2] | |||||
| fmla v18.8h, v10.8h, v1.h[2] | |||||
| fmla v16.8h, v11.8h, v0.h[3] | |||||
| fmla v18.8h, v11.8h, v1.h[3] | |||||
| // check if ic4=1 | |||||
| subs x13, x5, #1 | |||||
| beq LoopIcEnd | |||||
| LoopIc: | |||||
| // last 2 steps for output 2 and 4 | |||||
| fmla v17.8h, v10.8h, v0.h[6] | |||||
| fmla v19.8h, v10.8h, v1.h[6] | |||||
| fmla v17.8h, v11.8h, v0.h[7] | |||||
| fmla v19.8h, v11.8h, v1.h[7] | |||||
| // steps for output 5-8 | |||||
| fmla v20.8h, v8.8h, v2.h[0] | |||||
| fmla v22.8h, v8.8h, v3.h[0] | |||||
| fmla v20.8h, v9.8h, v2.h[1] | |||||
| fmla v22.8h, v9.8h, v3.h[1] | |||||
| fmla v21.8h, v8.8h, v2.h[4] | |||||
| fmla v23.8h, v8.8h, v3.h[4] | |||||
| fmla v21.8h, v9.8h, v2.h[5] | |||||
| fmla v23.8h, v9.8h, v3.h[5] | |||||
| fmla v20.8h, v10.8h, v2.h[2] | |||||
| fmla v22.8h, v10.8h, v3.h[2] | |||||
| fmla v20.8h, v11.8h, v2.h[3] | |||||
| fmla v22.8h, v11.8h, v3.h[3] | |||||
| fmla v21.8h, v10.8h, v2.h[6] | |||||
| fmla v23.8h, v10.8h, v3.h[6] | |||||
| fmla v21.8h, v11.8h, v2.h[7] | |||||
| fmla v23.8h, v11.8h, v3.h[7] | |||||
| // load input for output 1-8 | |||||
| ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 | |||||
| // steps for output 9-12 | |||||
| fmla v24.8h, v8.8h, v4.h[0] | |||||
| fmla v26.8h, v8.8h, v5.h[0] | |||||
| fmla v24.8h, v9.8h, v4.h[1] | |||||
| fmla v26.8h, v9.8h, v5.h[1] | |||||
| fmla v25.8h, v8.8h, v4.h[4] | |||||
| fmla v27.8h, v8.8h, v5.h[4] | |||||
| fmla v25.8h, v9.8h, v4.h[5] | |||||
| fmla v27.8h, v9.8h, v5.h[5] | |||||
| fmla v24.8h, v10.8h, v4.h[2] | |||||
| fmla v26.8h, v10.8h, v5.h[2] | |||||
| fmla v24.8h, v11.8h, v4.h[3] | |||||
| fmla v26.8h, v11.8h, v5.h[3] | |||||
| fmla v25.8h, v10.8h, v4.h[6] | |||||
| fmla v27.8h, v10.8h, v5.h[6] | |||||
| fmla v25.8h, v11.8h, v4.h[7] | |||||
| fmla v27.8h, v11.8h, v5.h[7] | |||||
| // steps for output 13-16 | |||||
| fmla v28.8h, v8.8h, v6.h[0] | |||||
| fmla v30.8h, v8.8h, v7.h[0] | |||||
| fmla v28.8h, v9.8h, v6.h[1] | |||||
| fmla v30.8h, v9.8h, v7.h[1] | |||||
| fmla v29.8h, v8.8h, v6.h[4] | |||||
| fmla v31.8h, v8.8h, v7.h[4] | |||||
| fmla v29.8h, v9.8h, v6.h[5] | |||||
| fmla v31.8h, v9.8h, v7.h[5] | |||||
| // load weight | |||||
| ld1 {v8.8h, v9.8h}, [x2], #32 | |||||
| fmla v28.8h, v10.8h, v6.h[2] | |||||
| fmla v30.8h, v10.8h, v7.h[2] | |||||
| fmla v28.8h, v11.8h, v6.h[3] | |||||
| fmla v30.8h, v11.8h, v7.h[3] | |||||
| fmla v29.8h, v10.8h, v6.h[6] | |||||
| fmla v31.8h, v10.8h, v7.h[6] | |||||
| fmla v29.8h, v11.8h, v6.h[7] | |||||
| fmla v31.8h, v11.8h, v7.h[7] | |||||
| // load weight | |||||
| ld1 {v10.8h, v11.8h}, [x2], #32 | |||||
| // first 2 steps for output 1-4 | |||||
| fmla v16.8h, v8.8h, v0.h[0] | |||||
| fmla v18.8h, v8.8h, v1.h[0] | |||||
| fmla v16.8h, v9.8h, v0.h[1] | |||||
| fmla v18.8h, v9.8h, v1.h[1] | |||||
| fmla v17.8h, v8.8h, v0.h[4] | |||||
| fmla v19.8h, v8.8h, v1.h[4] | |||||
| fmla v17.8h, v9.8h, v0.h[5] | |||||
| fmla v19.8h, v9.8h, v1.h[5] | |||||
| // load input for output 9-16 | |||||
| ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 | |||||
| // last 2 steps for output 1 and 3 | |||||
| fmla v16.8h, v10.8h, v0.h[2] | |||||
| fmla v18.8h, v10.8h, v1.h[2] | |||||
| fmla v16.8h, v11.8h, v0.h[3] | |||||
| fmla v18.8h, v11.8h, v1.h[3] | |||||
| subs x13, x13, #1 | |||||
| bne LoopIc | |||||
| LoopIcEnd: | |||||
| fmla v17.8h, v10.8h, v0.h[6] | |||||
| fmla v19.8h, v10.8h, v1.h[6] | |||||
| fmla v17.8h, v11.8h, v0.h[7] | |||||
| fmla v19.8h, v11.8h, v1.h[7] | |||||
| // steps for output 5-8 | |||||
| fmla v20.8h, v8.8h, v2.h[0] | |||||
| fmla v22.8h, v8.8h, v3.h[0] | |||||
| fmla v20.8h, v9.8h, v2.h[1] | |||||
| fmla v22.8h, v9.8h, v3.h[1] | |||||
| fmla v21.8h, v8.8h, v2.h[4] | |||||
| fmla v23.8h, v8.8h, v3.h[4] | |||||
| fmla v21.8h, v9.8h, v2.h[5] | |||||
| fmla v23.8h, v9.8h, v3.h[5] | |||||
| fmla v20.8h, v10.8h, v2.h[2] | |||||
| fmla v22.8h, v10.8h, v3.h[2] | |||||
| fmla v20.8h, v11.8h, v2.h[3] | |||||
| fmla v22.8h, v11.8h, v3.h[3] | |||||
| fmla v21.8h, v10.8h, v2.h[6] | |||||
| fmla v23.8h, v10.8h, v3.h[6] | |||||
| fmla v21.8h, v11.8h, v2.h[7] | |||||
| fmla v23.8h, v11.8h, v3.h[7] | |||||
| // steps for output 9-12 | |||||
| fmla v24.8h, v8.8h, v4.h[0] | |||||
| fmla v26.8h, v8.8h, v5.h[0] | |||||
| fmla v24.8h, v9.8h, v4.h[1] | |||||
| fmla v26.8h, v9.8h, v5.h[1] | |||||
| fmla v25.8h, v8.8h, v4.h[4] | |||||
| fmla v27.8h, v8.8h, v5.h[4] | |||||
| fmla v25.8h, v9.8h, v4.h[5] | |||||
| fmla v27.8h, v9.8h, v5.h[5] | |||||
| fmla v24.8h, v10.8h, v4.h[2] | |||||
| fmla v26.8h, v10.8h, v5.h[2] | |||||
| fmla v24.8h, v11.8h, v4.h[3] | |||||
| fmla v26.8h, v11.8h, v5.h[3] | |||||
| fmla v25.8h, v10.8h, v4.h[6] | |||||
| fmla v27.8h, v10.8h, v5.h[6] | |||||
| fmla v25.8h, v11.8h, v4.h[7] | |||||
| fmla v27.8h, v11.8h, v5.h[7] | |||||
| // steps for output 13-16 | |||||
| fmla v28.8h, v8.8h, v6.h[0] | |||||
| fmla v30.8h, v8.8h, v7.h[0] | |||||
| fmla v28.8h, v9.8h, v6.h[1] | |||||
| fmla v30.8h, v9.8h, v7.h[1] | |||||
| fmla v29.8h, v8.8h, v6.h[4] | |||||
| fmla v31.8h, v8.8h, v7.h[4] | |||||
| fmla v29.8h, v9.8h, v6.h[5] | |||||
| fmla v31.8h, v9.8h, v7.h[5] | |||||
| fmla v28.8h, v10.8h, v6.h[2] | |||||
| fmla v30.8h, v10.8h, v7.h[2] | |||||
| fmla v28.8h, v11.8h, v6.h[3] | |||||
| fmla v30.8h, v11.8h, v7.h[3] | |||||
| fmla v29.8h, v10.8h, v6.h[6] | |||||
| fmla v31.8h, v10.8h, v7.h[6] | |||||
| fmla v29.8h, v11.8h, v6.h[7] | |||||
| fmla v31.8h, v11.8h, v7.h[7] | |||||
| cbnz x11, Relu6 | |||||
| cbnz x10, Relu | |||||
| b WriteStart | |||||
| Relu6: | |||||
| movi v9.8h, #0x46, lsl #8 | |||||
| fmin v16.8h, v16.8h, v9.8h | |||||
| fmin v17.8h, v17.8h, v9.8h | |||||
| fmin v18.8h, v18.8h, v9.8h | |||||
| fmin v19.8h, v19.8h, v9.8h | |||||
| fmin v20.8h, v20.8h, v9.8h | |||||
| fmin v21.8h, v21.8h, v9.8h | |||||
| fmin v22.8h, v22.8h, v9.8h | |||||
| fmin v23.8h, v23.8h, v9.8h | |||||
| fmin v24.8h, v24.8h, v9.8h | |||||
| fmin v25.8h, v25.8h, v9.8h | |||||
| fmin v26.8h, v26.8h, v9.8h | |||||
| fmin v27.8h, v27.8h, v9.8h | |||||
| fmin v28.8h, v28.8h, v9.8h | |||||
| fmin v29.8h, v29.8h, v9.8h | |||||
| fmin v30.8h, v30.8h, v9.8h | |||||
| fmin v31.8h, v31.8h, v9.8h | |||||
| Relu: | |||||
| dup v8.4s, wzr | |||||
| fmax v16.8h, v16.8h, v8.8h | |||||
| fmax v17.8h, v17.8h, v8.8h | |||||
| fmax v18.8h, v18.8h, v8.8h | |||||
| fmax v19.8h, v19.8h, v8.8h | |||||
| fmax v20.8h, v20.8h, v8.8h | |||||
| fmax v21.8h, v21.8h, v8.8h | |||||
| fmax v22.8h, v22.8h, v8.8h | |||||
| fmax v23.8h, v23.8h, v8.8h | |||||
| fmax v24.8h, v24.8h, v8.8h | |||||
| fmax v25.8h, v25.8h, v8.8h | |||||
| fmax v26.8h, v26.8h, v8.8h | |||||
| fmax v27.8h, v27.8h, v8.8h | |||||
| fmax v28.8h, v28.8h, v8.8h | |||||
| fmax v29.8h, v29.8h, v8.8h | |||||
| fmax v30.8h, v30.8h, v8.8h | |||||
| fmax v31.8h, v31.8h, v8.8h | |||||
| WriteStart: | |||||
| cbnz x9, Write8 | |||||
| cmp x6, #1 | |||||
| beq Write1 | |||||
| cmp x6, #2 | |||||
| beq Write2 | |||||
| cmp x6, #3 | |||||
| beq Write3 | |||||
| cmp x6, #4 | |||||
| beq Write4 | |||||
| cmp x6, #5 | |||||
| beq Write5 | |||||
| cmp x6, #6 | |||||
| beq Write6 | |||||
| cmp x6, #7 | |||||
| beq Write7 | |||||
| b Write8 | |||||
| // prefetching is not preferred while writing results in spite of cache missing | |||||
| // you could try prfm pstl2strm | |||||
| // there are almost no benefits observed though | |||||
| Write1: | |||||
| str h16, [x15] | |||||
| add x15, x15, x7 | |||||
| str h17, [x15] | |||||
| add x15, x15, x7 | |||||
| str h18, [x15] | |||||
| add x15, x15, x7 | |||||
| str h19, [x15] | |||||
| add x15, x15, x7 | |||||
| str h20, [x15] | |||||
| add x15, x15, x7 | |||||
| str h21, [x15] | |||||
| add x15, x15, x7 | |||||
| str h22, [x15] | |||||
| add x15, x15, x7 | |||||
| str h23, [x15] | |||||
| add x15, x15, x7 | |||||
| str h24, [x15] | |||||
| add x15, x15, x7 | |||||
| str h25, [x15] | |||||
| add x15, x15, x7 | |||||
| str h26, [x15] | |||||
| add x15, x15, x7 | |||||
| str h27, [x15] | |||||
| add x15, x15, x7 | |||||
| str h28, [x15] | |||||
| add x15, x15, x7 | |||||
| str h29, [x15] | |||||
| add x15, x15, x7 | |||||
| str h30, [x15] | |||||
| add x15, x15, x7 | |||||
| str h31, [x15] | |||||
| add x0, x0, #2 | |||||
| b WriteEnd | |||||
| Write2: | |||||
| add x17, x15, #2 | |||||
| st1 {v16.h}[0], [x15], x7 | |||||
| st1 {v16.h}[1], [x17], x7 | |||||
| st1 {v17.h}[0], [x15], x7 | |||||
| st1 {v17.h}[1], [x17], x7 | |||||
| st1 {v18.h}[0], [x15], x7 | |||||
| st1 {v18.h}[1], [x17], x7 | |||||
| st1 {v19.h}[0], [x15], x7 | |||||
| st1 {v19.h}[1], [x17], x7 | |||||
| st1 {v20.h}[0], [x15], x7 | |||||
| st1 {v20.h}[1], [x17], x7 | |||||
| st1 {v21.h}[0], [x15], x7 | |||||
| st1 {v21.h}[1], [x17], x7 | |||||
| st1 {v22.h}[0], [x15], x7 | |||||
| st1 {v22.h}[1], [x17], x7 | |||||
| st1 {v23.h}[0], [x15], x7 | |||||
| st1 {v23.h}[1], [x17], x7 | |||||
| st1 {v24.h}[0], [x15], x7 | |||||
| st1 {v24.h}[1], [x17], x7 | |||||
| st1 {v25.h}[0], [x15], x7 | |||||
| st1 {v25.h}[1], [x17], x7 | |||||
| st1 {v26.h}[0], [x15], x7 | |||||
| st1 {v26.h}[1], [x17], x7 | |||||
| st1 {v27.h}[0], [x15], x7 | |||||
| st1 {v27.h}[1], [x17], x7 | |||||
| st1 {v28.h}[0], [x15], x7 | |||||
| st1 {v28.h}[1], [x17], x7 | |||||
| st1 {v29.h}[0], [x15], x7 | |||||
| st1 {v29.h}[1], [x17], x7 | |||||
| st1 {v30.h}[0], [x15], x7 | |||||
| st1 {v30.h}[1], [x17], x7 | |||||
| st1 {v31.h}[0], [x15] | |||||
| st1 {v31.h}[1], [x17] | |||||
| add x0, x0, #4 | |||||
| b WriteEnd | |||||
| Write3: | |||||
| add x17, x15, #4 | |||||
| add x16, x15, #2 | |||||
| st1 {v16.h}[0], [x15], x7 | |||||
| st1 {v16.h}[1], [x16], x7 | |||||
| st1 {v16.h}[2], [x17], x7 | |||||
| st1 {v17.h}[0], [x15], x7 | |||||
| st1 {v17.h}[1], [x16], x7 | |||||
| st1 {v17.h}[2], [x17], x7 | |||||
| st1 {v18.h}[0], [x15], x7 | |||||
| st1 {v18.h}[1], [x16], x7 | |||||
| st1 {v18.h}[2], [x17], x7 | |||||
| st1 {v19.h}[0], [x15], x7 | |||||
| st1 {v19.h}[1], [x16], x7 | |||||
| st1 {v19.h}[2], [x17], x7 | |||||
| st1 {v20.h}[0], [x15], x7 | |||||
| st1 {v20.h}[1], [x16], x7 | |||||
| st1 {v20.h}[2], [x17], x7 | |||||
| st1 {v21.h}[0], [x15], x7 | |||||
| st1 {v21.h}[1], [x16], x7 | |||||
| st1 {v21.h}[2], [x17], x7 | |||||
| st1 {v22.h}[0], [x15], x7 | |||||
| st1 {v22.h}[1], [x16], x7 | |||||
| st1 {v22.h}[2], [x17], x7 | |||||
| st1 {v23.h}[0], [x15], x7 | |||||
| st1 {v23.h}[1], [x16], x7 | |||||
| st1 {v23.h}[2], [x17], x7 | |||||
| st1 {v24.h}[0], [x15], x7 | |||||
| st1 {v24.h}[1], [x16], x7 | |||||
| st1 {v24.h}[2], [x17], x7 | |||||
| st1 {v25.h}[0], [x15], x7 | |||||
| st1 {v25.h}[1], [x16], x7 | |||||
| st1 {v25.h}[2], [x17], x7 | |||||
| st1 {v26.h}[0], [x15], x7 | |||||
| st1 {v26.h}[1], [x16], x7 | |||||
| st1 {v26.h}[2], [x17], x7 | |||||
| st1 {v27.h}[0], [x15], x7 | |||||
| st1 {v27.h}[1], [x16], x7 | |||||
| st1 {v27.h}[2], [x17], x7 | |||||
| st1 {v28.h}[0], [x15], x7 | |||||
| st1 {v28.h}[1], [x16], x7 | |||||
| st1 {v28.h}[2], [x17], x7 | |||||
| st1 {v29.h}[0], [x15], x7 | |||||
| st1 {v29.h}[1], [x16], x7 | |||||
| st1 {v29.h}[2], [x17], x7 | |||||
| st1 {v30.h}[0], [x15], x7 | |||||
| st1 {v30.h}[1], [x16], x7 | |||||
| st1 {v30.h}[2], [x17], x7 | |||||
| st1 {v31.h}[0], [x15] | |||||
| st1 {v31.h}[1], [x16] | |||||
| st1 {v31.h}[2], [x17] | |||||
| add x0, x0, #6 | |||||
| b WriteEnd | |||||
| Write4: | |||||
| st1 {v16.4h}, [x15], x7 | |||||
| st1 {v17.4h}, [x15], x7 | |||||
| st1 {v18.4h}, [x15], x7 | |||||
| st1 {v19.4h}, [x15], x7 | |||||
| st1 {v20.4h}, [x15], x7 | |||||
| st1 {v21.4h}, [x15], x7 | |||||
| st1 {v22.4h}, [x15], x7 | |||||
| st1 {v23.4h}, [x15], x7 | |||||
| st1 {v24.4h}, [x15], x7 | |||||
| st1 {v25.4h}, [x15], x7 | |||||
| st1 {v26.4h}, [x15], x7 | |||||
| st1 {v27.4h}, [x15], x7 | |||||
| st1 {v28.4h}, [x15], x7 | |||||
| st1 {v29.4h}, [x15], x7 | |||||
| st1 {v30.4h}, [x15], x7 | |||||
| st1 {v31.4h}, [x15] | |||||
| add x0, x0, #8 | |||||
| b WriteEnd | |||||
| Write5: | |||||
| add x17, x15, #8 | |||||
| st1 {v16.4h}, [x15], x7 | |||||
| st1 {v16.h}[4], [x17], x7 | |||||
| st1 {v17.4h}, [x15], x7 | |||||
| st1 {v17.h}[4], [x17], x7 | |||||
| st1 {v18.4h}, [x15], x7 | |||||
| st1 {v18.h}[4], [x17], x7 | |||||
| st1 {v19.4h}, [x15], x7 | |||||
| st1 {v19.h}[4], [x17], x7 | |||||
| st1 {v20.4h}, [x15], x7 | |||||
| st1 {v20.h}[4], [x17], x7 | |||||
| st1 {v21.4h}, [x15], x7 | |||||
| st1 {v21.h}[4], [x17], x7 | |||||
| st1 {v22.4h}, [x15], x7 | |||||
| st1 {v22.h}[4], [x17], x7 | |||||
| st1 {v23.4h}, [x15], x7 | |||||
| st1 {v23.h}[4], [x17], x7 | |||||
| st1 {v24.4h}, [x15], x7 | |||||
| st1 {v24.h}[4], [x17], x7 | |||||
| st1 {v25.4h}, [x15], x7 | |||||
| st1 {v25.h}[4], [x17], x7 | |||||
| st1 {v26.4h}, [x15], x7 | |||||
| st1 {v26.h}[4], [x17], x7 | |||||
| st1 {v27.4h}, [x15], x7 | |||||
| st1 {v27.h}[4], [x17], x7 | |||||
| st1 {v28.4h}, [x15], x7 | |||||
| st1 {v28.h}[4], [x17], x7 | |||||
| st1 {v29.4h}, [x15], x7 | |||||
| st1 {v29.h}[4], [x17], x7 | |||||
| st1 {v30.4h}, [x15], x7 | |||||
| st1 {v30.h}[4], [x17], x7 | |||||
| st1 {v31.4h}, [x15] | |||||
| st1 {v31.h}[4], [x17] | |||||
| add x0, x0, #10 | |||||
| b WriteEnd | |||||
Write6:
    // Store a 16x6 fp16 tile. Low 4 halfwords of each row go out with a single
    // st1 {vN.4h}; halfwords 4-5 are staged by copying vN.s[2] (one 32-bit pair)
    // into a scratch register, then stored lane-by-lane. x7 is the row stride.
    add x17, x15, #8                 // column 4 (byte offset 8)
    add x16, x15, #10                // column 5 (byte offset 10)
    st1 {v16.4h}, [x15], x7
    ins v0.s[0], v16.s[2]
    st1 {v0.h}[0], [x17], x7
    st1 {v0.h}[1], [x16], x7
    st1 {v17.4h}, [x15], x7
    ins v1.s[0], v17.s[2]
    st1 {v1.h}[0], [x17], x7
    st1 {v1.h}[1], [x16], x7
    st1 {v18.4h}, [x15], x7
    ins v2.s[0], v18.s[2]
    st1 {v2.h}[0], [x17], x7
    st1 {v2.h}[1], [x16], x7
    st1 {v19.4h}, [x15], x7
    ins v3.s[0], v19.s[2]
    st1 {v3.h}[0], [x17], x7
    st1 {v3.h}[1], [x16], x7
    st1 {v20.4h}, [x15], x7
    ins v4.s[0], v20.s[2]
    st1 {v4.h}[0], [x17], x7
    st1 {v4.h}[1], [x16], x7
    st1 {v21.4h}, [x15], x7
    ins v5.s[0], v21.s[2]
    st1 {v5.h}[0], [x17], x7
    st1 {v5.h}[1], [x16], x7
    st1 {v22.4h}, [x15], x7
    ins v6.s[0], v22.s[2]
    st1 {v6.h}[0], [x17], x7
    st1 {v6.h}[1], [x16], x7
    st1 {v23.4h}, [x15], x7
    ins v7.s[0], v23.s[2]
    st1 {v7.h}[0], [x17], x7
    st1 {v7.h}[1], [x16], x7
    st1 {v24.4h}, [x15], x7
    ins v8.s[0], v24.s[2]
    st1 {v8.h}[0], [x17], x7
    st1 {v8.h}[1], [x16], x7
    st1 {v25.4h}, [x15], x7
    ins v9.s[0], v25.s[2]
    st1 {v9.h}[0], [x17], x7
    st1 {v9.h}[1], [x16], x7
    st1 {v26.4h}, [x15], x7
    ins v10.s[0], v26.s[2]
    st1 {v10.h}[0], [x17], x7
    st1 {v10.h}[1], [x16], x7
    st1 {v27.4h}, [x15], x7
    ins v11.s[0], v27.s[2]
    st1 {v11.h}[0], [x17], x7
    st1 {v11.h}[1], [x16], x7
    st1 {v28.4h}, [x15], x7
    ins v12.s[0], v28.s[2]
    st1 {v12.h}[0], [x17], x7
    st1 {v12.h}[1], [x16], x7
    st1 {v29.4h}, [x15], x7
    ins v13.s[0], v29.s[2]
    st1 {v13.h}[0], [x17], x7
    st1 {v13.h}[1], [x16], x7
    st1 {v30.4h}, [x15], x7
    ins v14.s[0], v30.s[2]
    st1 {v14.h}[0], [x17], x7
    st1 {v14.h}[1], [x16], x7
    st1 {v31.4h}, [x15]
    ins v15.s[0], v31.s[2]
    st1 {v15.h}[0], [x17]            // fix: was v14 — wrote row 14's tail into row 15
    st1 {v15.h}[1], [x16]            // fix: was v14 (compare the Write7 section, which uses v15)
    add x0, x0, #12
    b WriteEnd
| Write7: | |||||
| add x17, x15, #8 | |||||
| add x19, x15, #10 | |||||
| add x16, x15, #12 | |||||
| st1 {v16.4h}, [x15], x7 | |||||
| ins v0.s[0], v16.s[2] | |||||
| st1 {v0.h}[0], [x17], x7 | |||||
| st1 {v0.h}[1], [x19], x7 | |||||
| st1 {v16.h}[6], [x16], x7 | |||||
| st1 {v17.4h}, [x15], x7 | |||||
| ins v1.s[0], v17.s[2] | |||||
| st1 {v1.h}[0], [x17], x7 | |||||
| st1 {v1.h}[1], [x19], x7 | |||||
| st1 {v17.h}[6], [x16], x7 | |||||
| st1 {v18.4h}, [x15], x7 | |||||
| ins v2.s[0], v18.s[2] | |||||
| st1 {v2.h}[0], [x17], x7 | |||||
| st1 {v2.h}[1], [x19], x7 | |||||
| st1 {v18.h}[6], [x16], x7 | |||||
| st1 {v19.4h}, [x15], x7 | |||||
| ins v3.s[0], v19.s[2] | |||||
| st1 {v3.h}[0], [x17], x7 | |||||
| st1 {v3.h}[1], [x19], x7 | |||||
| st1 {v19.h}[6], [x16], x7 | |||||
| st1 {v20.4h}, [x15], x7 | |||||
| ins v4.s[0], v20.s[2] | |||||
| st1 {v4.h}[0], [x17], x7 | |||||
| st1 {v4.h}[1], [x19], x7 | |||||
| st1 {v20.h}[6], [x16], x7 | |||||
| st1 {v21.4h}, [x15], x7 | |||||
| ins v5.s[0], v21.s[2] | |||||
| st1 {v5.h}[0], [x17], x7 | |||||
| st1 {v5.h}[1], [x19], x7 | |||||
| st1 {v21.h}[6], [x16], x7 | |||||
| st1 {v22.4h}, [x15], x7 | |||||
| ins v6.s[0], v22.s[2] | |||||
| st1 {v6.h}[0], [x17], x7 | |||||
| st1 {v6.h}[1], [x19], x7 | |||||
| st1 {v22.h}[6], [x16], x7 | |||||
| st1 {v23.4h}, [x15], x7 | |||||
| ins v7.s[0], v23.s[2] | |||||
| st1 {v7.h}[0], [x17], x7 | |||||
| st1 {v7.h}[1], [x19], x7 | |||||
| st1 {v23.h}[6], [x16], x7 | |||||
| st1 {v24.4h}, [x15], x7 | |||||
| ins v8.s[0], v24.s[2] | |||||
| st1 {v8.h}[0], [x17], x7 | |||||
| st1 {v8.h}[1], [x19], x7 | |||||
| st1 {v24.h}[6], [x16], x7 | |||||
| st1 {v25.4h}, [x15], x7 | |||||
| ins v9.s[0], v25.s[2] | |||||
| st1 {v9.h}[0], [x17], x7 | |||||
| st1 {v9.h}[1], [x19], x7 | |||||
| st1 {v25.h}[6], [x16], x7 | |||||
| st1 {v26.4h}, [x15], x7 | |||||
| ins v10.s[0], v26.s[2] | |||||
| st1 {v10.h}[0], [x17], x7 | |||||
| st1 {v10.h}[1], [x19], x7 | |||||
| st1 {v26.h}[6], [x16], x7 | |||||
| st1 {v27.4h}, [x15], x7 | |||||
| ins v11.s[0], v27.s[2] | |||||
| st1 {v11.h}[0], [x17], x7 | |||||
| st1 {v11.h}[1], [x19], x7 | |||||
| st1 {v27.h}[6], [x16], x7 | |||||
| st1 {v28.4h}, [x15], x7 | |||||
| ins v12.s[0], v28.s[2] | |||||
| st1 {v12.h}[0], [x17], x7 | |||||
| st1 {v12.h}[1], [x19], x7 | |||||
| st1 {v28.h}[6], [x16], x7 | |||||
| st1 {v29.4h}, [x15], x7 | |||||
| ins v13.s[0], v29.s[2] | |||||
| st1 {v13.h}[0], [x17], x7 | |||||
| st1 {v13.h}[1], [x19], x7 | |||||
| st1 {v29.h}[6], [x16], x7 | |||||
| st1 {v30.4h}, [x15], x7 | |||||
| ins v14.s[0], v30.s[2] | |||||
| st1 {v14.h}[0], [x17], x7 | |||||
| st1 {v14.h}[1], [x19], x7 | |||||
| st1 {v30.h}[6], [x16], x7 | |||||
| st1 {v31.4h}, [x15] | |||||
| ins v15.s[0], v31.s[2] | |||||
| st1 {v15.h}[0], [x17] | |||||
| st1 {v15.h}[1], [x19] | |||||
| st1 {v31.h}[6], [x16] | |||||
| add x0, x0, #14 | |||||
| b WriteEnd | |||||
| Write8: | |||||
| st1 {v16.8h}, [x15], x7 | |||||
| st1 {v17.8h}, [x15], x7 | |||||
| st1 {v18.8h}, [x15], x7 | |||||
| st1 {v19.8h}, [x15], x7 | |||||
| st1 {v20.8h}, [x15], x7 | |||||
| st1 {v21.8h}, [x15], x7 | |||||
| st1 {v22.8h}, [x15], x7 | |||||
| st1 {v23.8h}, [x15], x7 | |||||
| st1 {v24.8h}, [x15], x7 | |||||
| st1 {v25.8h}, [x15], x7 | |||||
| st1 {v26.8h}, [x15], x7 | |||||
| st1 {v27.8h}, [x15], x7 | |||||
| st1 {v28.8h}, [x15], x7 | |||||
| st1 {v29.8h}, [x15], x7 | |||||
| st1 {v30.8h}, [x15], x7 | |||||
| st1 {v31.8h}, [x15] | |||||
| add x0, x0, #16 | |||||
| WriteEnd: | |||||
| subs x14, x14, #1 | |||||
| bne LoopKsize | |||||
| subs x6, x6, #8 | |||||
| cbz x3, NoStepForward | |||||
| add x3, x3, #16 | |||||
| NoStepForward: | |||||
| bgt LoopOc | |||||
| sub sp, sp, #144 | |||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||||
| ldp x19, x20, [sp], #16 | |||||
| ret | |||||
| #endif | |||||
| @@ -29,7 +29,7 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| for (; offset < ele_num; offset++) { | for (; offset < ele_num; offset++) { | ||||
| dst[offset] = src[offset] < 0 ? 0 : src[offset]; | |||||
| dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset]; | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -47,14 +47,24 @@ int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| for (; offset < ele_num; offset++) { | for (; offset < ele_num; offset++) { | ||||
| dst[offset] = data[offset] < 0 ? 0 : data[offset]; | |||||
| dst[offset] = dst[offset] > 6 ? 6 : dst[offset]; | |||||
| dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset]; | |||||
| dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset]; | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { | int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { | ||||
| for (int i = 0; i < ele_num; ++i) { | |||||
| int i = 0; | |||||
| #ifdef ENABLE_NEON | |||||
| int ele_c8 = UP_ROUND(ele_num, C8NUM); | |||||
| for (; i < ele_c8; i += C8NUM) { | |||||
| float16x8_t src_tmp = vld1q_f16(src + i); | |||||
| float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha); | |||||
| float16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f)); | |||||
| vst1q_f16(dst + i, vbslq_f32(mask, src_tmp, mul_tmp)); | |||||
| } | |||||
| #endif | |||||
| for (; i < ele_num; ++i) { | |||||
| dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha); | dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha); | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| @@ -62,12 +72,12 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha | |||||
| int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { | int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { | ||||
| int i = 0; | int i = 0; | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_NEON | |||||
| int count = (ele_num / C4NUM) * C4NUM; | int count = (ele_num / C4NUM) * C4NUM; | ||||
| for (; i < count; i += C4NUM) { | for (; i < count; i += C4NUM) { | ||||
| float32x4_t tmp; | float32x4_t tmp; | ||||
| simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); | simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); | ||||
| vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); | |||||
| vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); | |||||
| } | } | ||||
| #endif | #endif | ||||
| for (; i < ele_num; ++i) { | for (; i < ele_num; ++i) { | ||||
| @@ -79,9 +89,9 @@ int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||||
| } | } | ||||
| float16_t TanhOptFp16(float16_t src) { | float16_t TanhOptFp16(float16_t src) { | ||||
| if (src > 5.0) { | |||||
| if (src > 5.0f) { | |||||
| return 1.0f; | return 1.0f; | ||||
| } else if (src < -5.0) { | |||||
| } else if (src < -5.0f) { | |||||
| return -1.0f; | return -1.0f; | ||||
| } else { | } else { | ||||
| float square = src * src; | float square = src * src; | ||||
| @@ -93,7 +103,7 @@ float16_t TanhOptFp16(float16_t src) { | |||||
| int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | ||||
| int i = 0; | int i = 0; | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_NEON | |||||
| static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, | static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, | ||||
| {17325.0f, 17325.0f, 17325.0f, 17325.0f}, | {17325.0f, 17325.0f, 17325.0f, 17325.0f}, | ||||
| {135135.0f, 135135.0f, 135135.0f, 135135.0f}, | {135135.0f, 135135.0f, 135135.0f, 135135.0f}, | ||||
| @@ -112,7 +122,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||||
| float32x4_t b = vaddq_f32( | float32x4_t b = vaddq_f32( | ||||
| vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), | vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), | ||||
| paramv[2]); | paramv[2]); | ||||
| vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one))); | |||||
| vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); | |||||
| } | } | ||||
| #endif | #endif | ||||
| for (; i < ele_num; ++i) { | for (; i < ele_num; ++i) { | ||||
| @@ -130,7 +140,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { | |||||
| int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { | int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { | ||||
| for (int i = 0; i < ele_num; ++i) { | for (int i = 0; i < ele_num; ++i) { | ||||
| float16_t in = src[i]; | float16_t in = src[i]; | ||||
| float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6); | |||||
| float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); | |||||
| dst[i] = in * relu6 / (float16_t)6.0f; | dst[i] = in * relu6 / (float16_t)6.0f; | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| @@ -181,26 +191,26 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) | |||||
| for (; i < C8; i += C8NUM) { | for (; i < C8; i += C8NUM) { | ||||
| float16x8_t in = vld1q_f16(src + i); | float16x8_t in = vld1q_f16(src + i); | ||||
| float16x8_t res = | float16x8_t res = | ||||
| 0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in)); | |||||
| 0.5f * in * (1.0f + MS_TANHX8_F16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * in * in) * in)); | |||||
| vst1q_f16(dst + i, res); | vst1q_f16(dst + i, res); | ||||
| } | } | ||||
| #endif | #endif | ||||
| for (; i < length; i++) { | for (; i < length; i++) { | ||||
| dst[i] = | dst[i] = | ||||
| 0.5 * src[i] * | |||||
| (1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i])); | |||||
| 0.5f * src[i] * | |||||
| (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); | |||||
| } | } | ||||
| } else { | } else { | ||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| int C8 = UP_ROUND(length, C8NUM); | int C8 = UP_ROUND(length, C8NUM); | ||||
| for (; i < C8; i += C8NUM) { | for (; i < C8; i += C8NUM) { | ||||
| float16x8_t in = vld1q_f16(src + i); | float16x8_t in = vld1q_f16(src + i); | ||||
| const float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); | |||||
| const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); | |||||
| vst1q_f16(dst + i, res); | vst1q_f16(dst + i, res); | ||||
| } | } | ||||
| #endif | #endif | ||||
| for (; i < length; i++) { | for (; i < length; i++) { | ||||
| dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); | |||||
| dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); | |||||
| } | } | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| @@ -16,11 +16,9 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | #define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include <math.h> | #include <math.h> | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/int8/fixed_point.h" | #include "nnacl/int8/fixed_point.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -569,7 +569,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t * | |||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -591,7 +591,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vdivq_f16(vin0_opt, vin1); | |||||
| float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -606,7 +606,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vout = vdivq_f16(vin0, vin1_opt); | |||||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -624,7 +624,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16 | |||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||||
| vout = vmaxq_f16(vout, zeros); | vout = vmaxq_f16(vout, zeros); | ||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| @@ -652,7 +652,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros); | |||||
| float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -670,7 +670,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros); | |||||
| float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -689,7 +689,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1 | |||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vdivq_f16(vin0, vin1); | |||||
| float16x8_t vout = MS_DIVQ_F16(vin0, vin1); | |||||
| vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); | vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); | ||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| @@ -716,7 +716,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin1 = vld1q_f16(input1 + index); | float16x8_t vin1 = vld1q_f16(input1 + index); | ||||
| float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros), bounds); | |||||
| float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -733,7 +733,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo | |||||
| #ifdef ENABLE_NEON | #ifdef ENABLE_NEON | ||||
| for (; index <= element_size - 8; index += C8NUM) { | for (; index <= element_size - 8; index += C8NUM) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros), bounds); | |||||
| float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -16,10 +16,8 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | #define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/base/arithmetic_base.h" | #include "nnacl/base/arithmetic_base.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| @@ -80,7 +80,7 @@ int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size) | |||||
| int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) { | int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) { | ||||
| for (int i = 0; i < element_size; i++) { | for (int i = 0; i < element_size; i++) { | ||||
| output[i] = round(input[i]); | |||||
| output[i] = roundf(input[i]); | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -94,7 +94,7 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) { | |||||
| int ElementCeilFp16(float16_t *input, float16_t *output, int number) { | int ElementCeilFp16(float16_t *input, float16_t *output, int number) { | ||||
| for (int i = 0; i < number; ++i) { | for (int i = 0; i < number; ++i) { | ||||
| output[i] = ceil(input[i]); | |||||
| output[i] = ceilf(input[i]); | |||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| @@ -16,10 +16,8 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | #define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -26,7 +26,7 @@ void BatchNormFp16(const float16_t *input, const void *mean, const void *varianc | |||||
| for (int i = 0; i < cur_unit; i++) { | for (int i = 0; i < cur_unit; i++) { | ||||
| for (int c = 0; c < param->channel_; c++) { | for (int c = 0; c < param->channel_; c++) { | ||||
| float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); | |||||
| float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); | |||||
| if (variance_sqrt != 0) { | if (variance_sqrt != 0) { | ||||
| output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | ||||
| } | } | ||||
| @@ -44,7 +44,7 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset | |||||
| for (int i = 0; i < cur_unit; i++) { | for (int i = 0; i < cur_unit; i++) { | ||||
| for (int c = 0; c < param->channel_; c++) { | for (int c = 0; c < param->channel_; c++) { | ||||
| float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); | |||||
| float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); | |||||
| if (variance_sqrt != 0) { | if (variance_sqrt != 0) { | ||||
| float16_t norm_val = | float16_t norm_val = | ||||
| (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; | ||||
| @@ -13,12 +13,9 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| #ifndef MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| #define MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/batchnorm_parameter.h" | #include "nnacl/batchnorm_parameter.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -34,4 +31,4 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset | |||||
| } | } | ||||
| #endif | #endif | ||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| #endif // MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ | |||||
| @@ -16,7 +16,6 @@ | |||||
| #ifndef MINDSPORE_NNACL_CAST_FP16_H_ | #ifndef MINDSPORE_NNACL_CAST_FP16_H_ | ||||
| #define MINDSPORE_NNACL_CAST_FP16_H_ | #define MINDSPORE_NNACL_CAST_FP16_H_ | ||||
| #include <arm_neon.h> | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -56,3 +56,17 @@ void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const floa | |||||
| PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); | PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); | ||||
| return; | return; | ||||
| } | } | ||||
#ifdef ENABLE_ARM82_A32
// WARNING(review): unimplemented stubs for the ARM32 fp16 build. Bias addition
// and ReLU/ReLU6 post-processing are silently skipped, so any kernel that
// relies on these on the A32 path will produce un-post-processed output with
// no error. Implement them (or fail loudly) before enabling this path.
// Parameters mirror the ARM64 assembly implementations: dst/src buffers,
// per-channel bias, channel split (ocNdiv/ocNmod), plane geometry, stride,
// and the fused activation selector (relu_type).
void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod,
                            size_t plane_size, size_t plane_stride, size_t relu_type) {
  // TODO(fun): implement C4-packed bias + activation post-processing for A32.
  return;
}
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
                            size_t plane_size, size_t stride, size_t relu_type) {
  // TODO(fun): implement C8-packed bias + activation post-processing for A32.
  return;
}
#endif
| @@ -16,7 +16,6 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | #define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ | ||||
| #include <arm_neon.h> | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -16,9 +16,6 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | #define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #include "nnacl/constant_of_shape_parameter.h" | #include "nnacl/constant_of_shape_parameter.h" | ||||
| @@ -27,7 +24,7 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_FP16 | |||||
| inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { | inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { | ||||
| for (int i = start; i < end; i++) { | for (int i = start; i < end; i++) { | ||||
| output[i] = value; | output[i] = value; | ||||
| @@ -18,6 +18,18 @@ | |||||
| #include <string.h> | #include <string.h> | ||||
| #include "nnacl/fp16/activation_fp16.h" | #include "nnacl/fp16/activation_fp16.h" | ||||
#ifdef ENABLE_ARM82_A32
// Scalar fallback for the ARM32 fp16 build (the ARM64 build uses assembly).
// Accumulates one depthwise-convolution row:
//   output_ptr[p * output_channel + c] += weight_ptr[c] * input_ptr[p * input_step + c]
// for p in [0, num_pixels) and c in [0, output_channel).
// output_ptr advances contiguously; input rows are input_step elements apart.
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels,
                   size_t output_channel, size_t input_step) {
  // size_t counters: the bounds are size_t, so int counters would trigger
  // signed/unsigned comparison warnings (and misbehave for huge sizes).
  for (size_t i = 0; i < num_pixels; i++) {
    for (size_t c = 0; c < output_channel; c++) {
      *output_ptr++ += weight_ptr[c] * input_ptr[c];
    }
    input_ptr += input_step;
  }
}
#endif
| void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, | void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, | ||||
| const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { | const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { | ||||
| int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); | int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); | ||||
| @@ -57,7 +69,6 @@ void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float | |||||
| const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; | const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; | ||||
| int num_pixels = out_w_end - out_w_start; | int num_pixels = out_w_end - out_w_start; | ||||
| ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); | ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); | ||||
| weight_kh += conv_param->output_channel_; | weight_kh += conv_param->output_channel_; | ||||
| } | } | ||||
| @@ -23,9 +23,9 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_ARM64 | |||||
| void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, | void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, | ||||
| size_t input_channel, size_t input_step); | size_t input_channel, size_t input_step); | ||||
| #ifdef ENABLE_ARM64 | |||||
| void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, | void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, | ||||
| size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, | size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, | ||||
| size_t relu6); | size_t relu6); | ||||
| @@ -31,7 +31,7 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| #ifndef ENABLE_NEON | |||||
| #ifndef ENABLE_ARM64 | |||||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | ||||
| size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, | ||||
| size_t relu6) { | size_t relu6) { | ||||
| @@ -124,7 +124,11 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we | |||||
| // fp16 convolution common (im2col+gemm) | // fp16 convolution common (im2col+gemm) | ||||
| void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, | ||||
| float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { | float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| const int tile_n = 16; | const int tile_n = 16; | ||||
| #else | |||||
| const int tile_n = 12; | |||||
| #endif | |||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | int output_count = conv_param->output_h_ * conv_param->output_w_; | ||||
| int output_tile_count = UP_DIV(output_count, tile_n); | int output_tile_count = UP_DIV(output_count, tile_n); | ||||
| @@ -144,7 +148,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | ||||
| int out_offset = thread_id * tile_n * out_channel + out_batch_offset; | int out_offset = thread_id * tile_n * out_channel + out_batch_offset; | ||||
| #ifdef ENABLE_ARM64 | |||||
| RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); | RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); | ||||
| #else | |||||
| RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); | |||||
| #endif | |||||
| MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, | MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, | ||||
| real_cal_num, out_channel, out_channel, OutType_Nhwc); | real_cal_num, out_channel, out_channel, OutType_Nhwc); | ||||
| } | } | ||||
| @@ -155,7 +163,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ | |||||
| void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, | ||||
| float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, | ||||
| InputTransFp16Func in_func, OutputTransFp16Func out_func) { | InputTransFp16Func in_func, OutputTransFp16Func out_func) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| const int tile_num = 16; | const int tile_num = 16; | ||||
| #else | |||||
| const int tile_num = 12; | |||||
| #endif | |||||
| int in_channel = conv_param->input_channel_; | int in_channel = conv_param->input_channel_; | ||||
| int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); | int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); | ||||
| int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); | int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); | ||||
| @@ -194,7 +206,11 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||||
| float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; | float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; | ||||
| float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | ||||
| for (int i = 0; i < input_unit_square; ++i) { | for (int i = 0; i < input_unit_square; ++i) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); | RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); | ||||
| #else | |||||
| RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); | |||||
| #endif | |||||
| MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, | MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, | ||||
| cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); | cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ | |||||
| typedef float16_t *TmpBufferAddressFp16; | typedef float16_t *TmpBufferAddressFp16; | ||||
| typedef float16_t *MatricesFp16; | typedef float16_t *MatricesFp16; | ||||
| #ifndef ENABLE_NEON | |||||
| #ifndef ENABLE_ARM64 | |||||
| void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, | ||||
| size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, | size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, | ||||
| size_t relu6); | size_t relu6); | ||||
| @@ -17,7 +17,6 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_CROP_FP16_H_ | #define MINDSPORE_NNACL_FP16_CROP_FP16_H_ | ||||
| #include <arm_neon.h> | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/crop_parameter.h" | #include "nnacl/crop_parameter.h" | ||||
| @@ -53,11 +53,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, | |||||
| int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; | int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; | ||||
| float16_t *tmp_dst = dst_ptr + dst_index; | float16_t *tmp_dst = dst_ptr + dst_index; | ||||
| const float16_t *tmp_src = src_ptr + src_index; | const float16_t *tmp_src = src_ptr + src_index; | ||||
| #ifdef DEBUG_CODE | |||||
| for (int i = 0; i < C8NUM; i++) { | |||||
| tmp_dst[i] += tmp_src[i]; | |||||
| } | |||||
| #else | |||||
| #ifdef ENABLE_ARM64 | |||||
| asm volatile( | asm volatile( | ||||
| "mov x0, %[tmp_src] \n" | "mov x0, %[tmp_src] \n" | ||||
| "mov x1, %[tmp_dst] \n" | "mov x1, %[tmp_dst] \n" | ||||
| @@ -72,6 +68,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, | |||||
| : | : | ||||
| : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) | : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) | ||||
| : "x0", "x1", "v0", "v1"); | : "x0", "x1", "v0", "v1"); | ||||
| #else | |||||
| for (int i = 0; i < C8NUM; i++) { | |||||
| tmp_dst[i] += tmp_src[i]; | |||||
| } | |||||
| #endif | #endif | ||||
| } /*kw*/ | } /*kw*/ | ||||
| } /*kh*/ | } /*kh*/ | ||||
| @@ -47,6 +47,7 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||||
| size_t cuont8 = count / C8NUM * C8NUM; | size_t cuont8 = count / C8NUM * C8NUM; | ||||
| int i = 0; | int i = 0; | ||||
| for (; i < cuont8; i += C8NUM) { | for (; i < cuont8; i += C8NUM) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| size_t src_step = src_stride * sizeof(float16_t); | size_t src_step = src_stride * sizeof(float16_t); | ||||
| size_t dst_step = dst_stride * sizeof(float16_t); | size_t dst_step = dst_stride * sizeof(float16_t); | ||||
| asm volatile( | asm volatile( | ||||
| @@ -93,7 +94,9 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||||
| : | : | ||||
| : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) | : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) | ||||
| : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); | : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); | ||||
| #else | |||||
| // TODO(fun): arm32 | |||||
| #endif | |||||
| src_ptr += C8NUM * src_stride; | src_ptr += C8NUM * src_stride; | ||||
| dst_ptr += C8NUM * dst_stride; | dst_ptr += C8NUM * dst_stride; | ||||
| } | } | ||||
| @@ -373,3 +376,23 @@ void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParamet | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM82_A32 | |||||
| void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, | |||||
| size_t length) { | |||||
| // TODO(fun): function | |||||
| return; | |||||
| } | |||||
| void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, | |||||
| size_t length) { | |||||
| // TODO(fun): function | |||||
| return; | |||||
| } | |||||
| void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num, | |||||
| size_t oc4) { | |||||
| // TODO(fun): function | |||||
| return; | |||||
| } | |||||
| #endif | |||||
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_NNACL_FP16_EXP_H_ | #define MINDSPORE_NNACL_FP16_EXP_H_ | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "nnacl/fp16/instance_norm_fp16.h" | #include "nnacl/fp16/instance_norm_fp16.h" | ||||
| #include <math.h> | #include <math.h> | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, | int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, | ||||
| const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { | const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { | ||||
| @@ -17,7 +17,6 @@ | |||||
| #define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_ | #define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_ | ||||
| #include "nnacl/instance_norm_parameter.h" | #include "nnacl/instance_norm_parameter.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include "nnacl/fp16/arithmetic_fp16.h" | #include "nnacl/fp16/arithmetic_fp16.h" | ||||
| #include "nnacl/fp16/matmul_fp16.h" | #include "nnacl/fp16/matmul_fp16.h" | ||||
| #include "nnacl/fp16/cast_fp16.h" | #include "nnacl/fp16/cast_fp16.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { | void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { | ||||
| for (int i = 0; i < batch; i++) { | for (int i = 0; i < batch; i++) { | ||||
| @@ -121,7 +122,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1 | |||||
| for (; index <= element_size - 8; index += 8) { | for (; index <= element_size - 8; index += 8) { | ||||
| float16x8_t vin0 = vld1q_f16(input0 + index); | float16x8_t vin0 = vld1q_f16(input0 + index); | ||||
| float16x8_t vout = vld1q_f16(output + index); | float16x8_t vout = vld1q_f16(output + index); | ||||
| vout = vfmaq_n_f16(vout, vin0, input1); | |||||
| vout = MS_FMAQ_N_F16(vout, vin0, input1); | |||||
| vst1q_f16(output + index, vout); | vst1q_f16(output + index, vout); | ||||
| } | } | ||||
| for (; index < element_size; index++) { | for (; index < element_size; index++) { | ||||
| @@ -226,24 +226,43 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, | |||||
| return; | return; | ||||
| } | } | ||||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, bool write_nhwc) { | |||||
| if (write_nhwc) { | |||||
| /* col16-major * row8-major => col-major */ | |||||
| void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, int write_mode) { | |||||
| if (write_mode == OutType_Nhwc) { | |||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| for (int c = 0; c < col; c++) { | for (int c = 0; c < col; c++) { | ||||
| int r16div = r / 16, r16mod = r % 16; | |||||
| int c8div = c / 8, c8mod = c % 8; | |||||
| size_t ci = r * stride + c; | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r16div * deep * 16 + d * 16 + r16mod; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| ADD_BIAS(value, bias, c) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } else if (write_mode == OutType_C8) { | |||||
| int col_8 = UP_ROUND(col, C8NUM); | |||||
| int row_16 = UP_ROUND(row, C16NUM); | |||||
| for (int r = 0; r < row_16; r++) { | |||||
| for (int c = 0; c < col_8; c++) { | |||||
| int r16div = r / C16NUM, r16mod = r % C16NUM; | int r16div = r / C16NUM, r16mod = r % C16NUM; | ||||
| int c8div = c / C8NUM, c8mod = c % C8NUM; | int c8div = c / C8NUM, c8mod = c % C8NUM; | ||||
| size_t ci = r * stride + c; | |||||
| float value = 0; | |||||
| size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; d++) { | for (int d = 0; d < deep; d++) { | ||||
| size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; | size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; | ||||
| size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | ||||
| value = value + a[ai] * b[bi]; | value = value + a[ai] * b[bi]; | ||||
| } | } | ||||
| if (bias != NULL) value += bias[c]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| ADD_BIAS(value, bias, c) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | dst[ci] = value; | ||||
| } | } | ||||
| } | } | ||||
| @@ -254,37 +273,119 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl | |||||
| for (int j = 0; j < col; ++j) { | for (int j = 0; j < col; ++j) { | ||||
| int c8div = j / 8, c8mod = j % 8; | int c8div = j / 8, c8mod = j % 8; | ||||
| size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | ||||
| float value = 0; | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; ++d) { | for (int d = 0; d < deep; ++d) { | ||||
| size_t ai = src_r_offset + d * C16NUM; | size_t ai = src_r_offset + d * C16NUM; | ||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | size_t bi = c8div * deep * 8 + d * 8 + c8mod; | ||||
| value = value + a[ai] * b[bi]; | value = value + a[ai] * b[bi]; | ||||
| } | } | ||||
| if (bias != NULL) value += bias[j]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| ADD_BIAS(value, bias, j) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, int write_mode) { | |||||
| if (write_mode == OutType_Nhwc) { | |||||
| for (int r = 0; r < row; r++) { | |||||
| for (int c = 0; c < col; c++) { | |||||
| int r12div = r / 12, r12mod = r % 12; | |||||
| int c8div = c / 8, c8mod = c % 8; | |||||
| size_t ci = r * stride + c; | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r12div * deep * 12 + d * 12 + r12mod; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| ADD_BIAS(value, bias, c) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } else if (write_mode == OutType_C8) { | |||||
| int col_8 = UP_ROUND(col, C8NUM); | |||||
| int row_12 = UP_ROUND(row, C12NUM); | |||||
| for (int r = 0; r < row_12; r++) { | |||||
| for (int c = 0; c < col_8; c++) { | |||||
| int r12div = r / C12NUM, r12mod = r % C12NUM; | |||||
| int c8div = c / C8NUM, c8mod = c % C8NUM; | |||||
| size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; | |||||
| size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| ADD_BIAS(value, bias, c) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < row; ++i) { | |||||
| int src_r_offset = i; | |||||
| int dst_r_offset = i * col * stride; | |||||
| for (int j = 0; j < col; ++j) { | |||||
| int c8div = j / 8, c8mod = j % 8; | |||||
| size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; ++d) { | |||||
| size_t ai = src_r_offset + d * C12NUM; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | |||||
| } | |||||
| ADD_BIAS(value, bias, j) | |||||
| DO_RELU(value, act_type) | |||||
| DO_RELU6(value, act_type) | |||||
| dst[ci] = value; | dst[ci] = value; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return; | |||||
| } | } | ||||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | ||||
| int depth, int row, int col, int stride, int out_type) { | int depth, int row, int col, int stride, int out_type) { | ||||
| if (out_type == OutType_C8) { | if (out_type == OutType_C8) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); | MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); | ||||
| #else | |||||
| MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | |||||
| #endif | |||||
| } else { | } else { | ||||
| #ifdef ENABLE_ARM64 | |||||
| MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | ||||
| #else | |||||
| MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); | |||||
| #endif | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM82_A32 | |||||
| void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| int depth, int col) { | |||||
| // TODO(fun): function | |||||
| return; | |||||
| } | |||||
| #endif | |||||
| void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | ||||
| int depth, int col) { | int depth, int col) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); | MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); | ||||
| #else | |||||
| MatVecMulA32Fp16(a, b, c, bias, (int)act_type, depth, col); | |||||
| #endif | |||||
| } | } | ||||
| #ifdef ENABLE_ARM64 | |||||
| static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { | static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { | ||||
| size_t stride = col * 2; | size_t stride = col * 2; | ||||
| asm volatile( | asm volatile( | ||||
| @@ -392,6 +493,7 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_ | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | ||||
| "v31"); | "v31"); | ||||
| } | } | ||||
| #endif | |||||
| void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | ||||
| size_t row_up_16 = UP_ROUND(row, C16NUM); | size_t row_up_16 = UP_ROUND(row, C16NUM); | ||||
| @@ -442,6 +544,54 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si | |||||
| return; | return; | ||||
| } | } | ||||
| void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { | |||||
| size_t row_up_12 = UP_ROUND(row, C12NUM); | |||||
| size_t row12 = row / C12NUM * C12NUM; | |||||
| size_t col8 = col / C8NUM * C8NUM; | |||||
| const float16_t *src_r = src_ptr; | |||||
| float16_t *dst_r = dst_ptr; | |||||
| size_t ri = 0; | |||||
| // transpose 12x8 | |||||
| for (; ri < row12; ri += C12NUM) { | |||||
| size_t ci = 0; | |||||
| for (; ci < col8; ci += C8NUM) { | |||||
| const float16_t *src_c = src_r + ci; | |||||
| float16_t *dst_c = dst_r + ci * C12NUM; | |||||
| #ifdef ENABLE_ARM82_A32 | |||||
| Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); | |||||
| #else | |||||
| for (int tr = 0; tr < C12NUM; tr++) { | |||||
| for (int tc = 0; tc < C8NUM; tc++) { | |||||
| dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| for (; ci < col; ci++) { | |||||
| const float16_t *src_c = src_r + ci; | |||||
| float16_t *dst_c = dst_r + ci * C12NUM; | |||||
| for (size_t i = 0; i < C12NUM; i++) { | |||||
| dst_c[i] = src_c[i * col]; | |||||
| } | |||||
| } | |||||
| src_r += C12NUM * col; | |||||
| dst_r += C12NUM * col; | |||||
| } | |||||
| for (; ri < row; ri++) { | |||||
| for (size_t i = 0; i < col; ++i) { | |||||
| dst_r[i * C12NUM] = src_r[i]; | |||||
| } | |||||
| src_r += col; | |||||
| dst_r += 1; | |||||
| } | |||||
| for (; ri < row_up_12; ri++) { | |||||
| for (size_t i = 0; i < col; i++) { | |||||
| dst_r[i * C12NUM] = 0; | |||||
| } | |||||
| dst_r += 1; | |||||
| } | |||||
| } | |||||
| void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { | void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { | ||||
| if (is_fp32_src) { | if (is_fp32_src) { | ||||
| const float *fp32_src = (const float *)src; | const float *fp32_src = (const float *)src; | ||||
| @@ -19,18 +19,51 @@ | |||||
| #include <float.h> | #include <float.h> | ||||
| #include <string.h> | #include <string.h> | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | #endif | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #include "nnacl/matmul_parameter.h" | #include "nnacl/matmul_parameter.h" | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/fp16/pack_fp16.h" | |||||
| #define ADD_BIAS(value, bias, c) \ | |||||
| if (bias != NULL) value = value + bias[c]; | |||||
| #define DO_RELU(value, act_type) \ | |||||
| if (act_type == ActType_Relu) value = MSMAX(0.0f, value); | |||||
| #define DO_RELU6(value, act_type) \ | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ | |||||
| if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, bool write_nhwc); | |||||
| void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, int write_mode); | |||||
| void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, int write_mode); | |||||
| #ifdef ENABLE_ARM64 | |||||
| void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); | |||||
| void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); | |||||
| void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| int depth, int col); | |||||
| #elif ENABLE_ARM82_A32 | |||||
| void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, int write_mode); | |||||
| void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| int depth, int col); | |||||
| #endif | |||||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | ||||
| int depth, int row, int col, int stride, int out_type); | int depth, int row, int col, int stride, int out_type); | ||||
| @@ -42,14 +75,7 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, | |||||
| void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | ||||
| void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); | |||||
| void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); | |||||
| void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, | |||||
| int depth, int col); | |||||
| void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); | |||||
| void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); | void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); | ||||
| @@ -39,7 +39,7 @@ void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matri | |||||
| for (int i = 0; i < m; ++i) { | for (int i = 0; i < m; ++i) { | ||||
| for (int j = 0; j < n; ++j) { | for (int j = 0; j < n; ++j) { | ||||
| for (int y = 0; y < in_channel; ++y) { | for (int y = 0; y < in_channel; ++y) { | ||||
| float16_t tmp = 0; | |||||
| float tmp = 0; | |||||
| for (int z = 0; z < k; ++z) { | for (int z = 0; z < k; ++z) { | ||||
| tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; | tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; | ||||
| } | } | ||||
| @@ -160,129 +160,35 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int | |||||
| } | } | ||||
| void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) { | void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) { | ||||
| int hw16 = plane / C16NUM * C16NUM; | |||||
| #ifdef ENABLE_ARM64 | |||||
| // Transpose16x8 in arm64 | |||||
| const int hw_tile = C16NUM; | |||||
| #else | |||||
| // Transpose8x8 in others | |||||
| const int hw_tile = C8NUM; | |||||
| #endif | |||||
| int hw_align = plane / hw_tile * hw_tile; | |||||
| int c8 = channel / C8NUM * C8NUM; | int c8 = channel / C8NUM * C8NUM; | ||||
| int batch = plane * channel; | int batch = plane * channel; | ||||
| for (int n = 0; n < batches; n++) { | for (int n = 0; n < batches; n++) { | ||||
| const float16_t *src_batch = (const float16_t *)src + n * batch; | const float16_t *src_batch = (const float16_t *)src + n * batch; | ||||
| float16_t *dst_batch = (float16_t *)dst + n * batch; | float16_t *dst_batch = (float16_t *)dst + n * batch; | ||||
| int hw = 0; | int hw = 0; | ||||
| for (; hw < hw16; hw += C16NUM) { | |||||
| for (; hw < hw_align; hw += hw_tile) { | |||||
| int c = 0; | int c = 0; | ||||
| for (; c < c8; c += C8NUM) { | for (; c < c8; c += C8NUM) { | ||||
| const float16_t *src_ptr = src_batch + hw * channel + c; | const float16_t *src_ptr = src_batch + hw * channel + c; | ||||
| float16_t *dst_ptr = dst_batch + c * plane + hw; | float16_t *dst_ptr = dst_batch + c * plane + hw; | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| size_t srcStride = channel * sizeof(float16_t); | |||||
| size_t dstStride = plane * sizeof(float16_t); | |||||
| asm volatile( | |||||
| "mov x10, %[src_ptr]\n" | |||||
| "mov x11, %[dst_ptr]\n" | |||||
| "ld1 {v0.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v1.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v2.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v3.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v4.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v5.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v6.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v7.8h}, [x10], %[srcStride]\n" | |||||
| "zip1 v16.8h, v0.8h, v1.8h\n" | |||||
| "zip1 v17.8h, v2.8h, v3.8h\n" | |||||
| "zip1 v18.8h, v4.8h, v5.8h\n" | |||||
| "zip1 v19.8h, v6.8h, v7.8h\n" | |||||
| "ld1 {v8.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v9.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v10.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v11.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v12.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v13.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v14.8h}, [x10], %[srcStride]\n" | |||||
| "ld1 {v15.8h}, [x10], %[srcStride]\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||||
| "zip1 v16.8h, v8.8h, v9.8h\n" | |||||
| "zip1 v17.8h, v10.8h, v11.8h\n" | |||||
| "zip1 v18.8h, v12.8h, v13.8h\n" | |||||
| "zip1 v19.8h, v14.8h, v15.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||||
| "add x10, x11, #16\n" | |||||
| "st1 {v24.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v28.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v26.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v30.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v25.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v29.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v27.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v31.8h}, [x10], %[dstStride]\n" | |||||
| "zip2 v16.8h, v0.8h, v1.8h\n" | |||||
| "zip2 v17.8h, v2.8h, v3.8h\n" | |||||
| "zip2 v18.8h, v4.8h, v5.8h\n" | |||||
| "zip2 v19.8h, v6.8h, v7.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||||
| "zip2 v16.8h, v8.8h, v9.8h\n" | |||||
| "zip2 v17.8h, v10.8h, v11.8h\n" | |||||
| "zip2 v18.8h, v12.8h, v13.8h\n" | |||||
| "zip2 v19.8h, v14.8h, v15.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||||
| "st1 {v24.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v28.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v26.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v30.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v25.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v29.8h}, [x10], %[dstStride]\n" | |||||
| "st1 {v27.8h}, [x11], %[dstStride]\n" | |||||
| "st1 {v31.8h}, [x10], %[dstStride]\n" | |||||
| : | |||||
| : | |||||
| [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) | |||||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | |||||
| "v30", "v31"); | |||||
| size_t src_stride = channel * sizeof(float16_t); | |||||
| size_t dst_stride = plane * sizeof(float16_t); | |||||
| Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); | |||||
| #elif defined(ENABLE_ARM82_A32) | |||||
| size_t src_stride = channel * sizeof(float16_t); | |||||
| size_t dst_stride = plane * sizeof(float16_t); | |||||
| Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); | |||||
| #else | #else | ||||
| for (int tr = 0; tr < C16NUM; tr++) { | |||||
| for (int tr = 0; tr < hw_tile; tr++) { | |||||
| for (int tc = 0; tc < C8NUM; tc++) { | for (int tc = 0; tc < C8NUM; tc++) { | ||||
| dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; | dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; | ||||
| } | } | ||||
| @@ -292,7 +198,7 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int | |||||
| for (; c < channel; c++) { | for (; c < channel; c++) { | ||||
| const float16_t *src_ptr = src_batch + hw * channel + c; | const float16_t *src_ptr = src_batch + hw * channel + c; | ||||
| float16_t *dst_ptr = dst_batch + c * plane + hw; | float16_t *dst_ptr = dst_batch + c * plane + hw; | ||||
| for (size_t i = 0; i < C16NUM; i++) { | |||||
| for (size_t i = 0; i < hw_tile; i++) { | |||||
| dst_ptr[i] = src_ptr[i * channel]; | dst_ptr[i] = src_ptr[i * channel]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -305,7 +211,6 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return; | |||||
| } | } | ||||
| void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { | void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { | ||||
| @@ -565,3 +470,246 @@ void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, i | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| #ifdef ENABLE_ARM82_A32 | |||||
| inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) { | |||||
| asm volatile( | |||||
| "mov r10, %[src]\n" | |||||
| "mov r12, %[dst]\n" | |||||
| "vld1.16 {q0}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q2}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q4}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q6}, [r10], %[src_stride]\n" | |||||
| "vtrn.16 d0, d4\n" | |||||
| "vtrn.16 d1, d5\n" | |||||
| "vtrn.16 d8, d12\n" | |||||
| "vtrn.16 d9, d13\n" | |||||
| "vld1.16 {q8}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q10}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q12}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q14}, [r10], %[src_stride]\n" | |||||
| "vtrn.32 d0, d8\n" | |||||
| "vtrn.32 d4, d12\n" | |||||
| "vtrn.32 d1, d9\n" | |||||
| "vtrn.32 d5, d13\n" | |||||
| "vtrn.16 d16, d20\n" | |||||
| "vtrn.16 d17, d21\n" | |||||
| "vtrn.16 d24, d28\n" | |||||
| "vtrn.16 d25, d29\n" | |||||
| "vtrn.32 d16, d24\n" | |||||
| "vtrn.32 d20, d28\n" | |||||
| "vtrn.32 d17, d25\n" | |||||
| "vtrn.32 d21, d29\n" | |||||
| "vswp d1, d16\n" | |||||
| "vswp d5, d20\n" | |||||
| "vswp d9, d24\n" | |||||
| "vswp d13, d28\n" | |||||
| "vst1.16 {q0}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q2}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q4}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q6}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q8}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q10}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q12}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q14}, [r12], %[dst_stride]\n" | |||||
| : | |||||
| : [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) | |||||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", | |||||
| "q15"); | |||||
| } | |||||
| inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) { | |||||
| asm volatile( | |||||
| "mov r10, %[src_c]\n" | |||||
| "mov r12, %[dst_c]\n" | |||||
| "vld1.16 {q0}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q2}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q4}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q6}, [r10], %[src_stride]\n" | |||||
| "vtrn.16 d0, d4\n" | |||||
| "vtrn.16 d1, d5\n" | |||||
| "vtrn.16 d8, d12\n" | |||||
| "vtrn.16 d9, d13\n" | |||||
| "vld1.16 {q8}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q10}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q12}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q14}, [r10], %[src_stride]\n" | |||||
| "vtrn.32 d0, d8\n" | |||||
| "vtrn.32 d4, d12\n" | |||||
| "vtrn.32 d1, d9\n" | |||||
| "vtrn.32 d5, d13\n" | |||||
| "vtrn.16 d16, d20\n" | |||||
| "vtrn.16 d17, d21\n" | |||||
| "vtrn.16 d24, d28\n" | |||||
| "vtrn.16 d25, d29\n" | |||||
| "vld1.16 {q1}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q3}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q5}, [r10], %[src_stride]\n" | |||||
| "vld1.16 {q7}, [r10], %[src_stride]\n" | |||||
| "vtrn.32 d16, d24\n" | |||||
| "vtrn.32 d20, d28\n" | |||||
| "vtrn.32 d17, d25\n" | |||||
| "vtrn.32 d21, d29\n" | |||||
| "vswp d1, d16\n" | |||||
| "vswp d5, d20\n" | |||||
| "vswp d9, d24\n" | |||||
| "vswp d13, d28\n" | |||||
| "vtrn.16 d2, d6\n" | |||||
| "vtrn.16 d3, d7\n" | |||||
| "vtrn.16 d10, d14\n" | |||||
| "vtrn.16 d11, d15\n" | |||||
| "vtrn.32 d2, d10\n" | |||||
| "vtrn.32 d6, d14\n" | |||||
| "vtrn.32 d3, d11\n" | |||||
| "vtrn.32 d7, d15\n" | |||||
| "vst1.16 {q0, d2}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q2, d6}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q4, d10}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q6, d14}, [r12], %[dst_stride]\n" | |||||
| "vswp d3, d18\n" | |||||
| "vswp d7, d22\n" | |||||
| "vswp d11, d26\n" | |||||
| "vswp d15, d30\n" | |||||
| "vst1.16 {q8, d18}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q10, d22}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q12, d26}, [r12], %[dst_stride]\n" | |||||
| "vst1.16 {q14, d30}, [r12], %[dst_stride]\n" | |||||
| : | |||||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) | |||||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", | |||||
| "q15"); | |||||
| } | |||||
| #endif | |||||
| #ifdef ENABLE_ARM64 | |||||
| inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { | |||||
| asm volatile( | |||||
| "mov x10, %[src_ptr]\n" | |||||
| "mov x11, %[dst_ptr]\n" | |||||
| "ld1 {v0.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v1.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v2.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v3.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v4.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v5.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v6.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v7.8h}, [x10], %[src_stride]\n" | |||||
| "zip1 v16.8h, v0.8h, v1.8h\n" | |||||
| "zip1 v17.8h, v2.8h, v3.8h\n" | |||||
| "zip1 v18.8h, v4.8h, v5.8h\n" | |||||
| "zip1 v19.8h, v6.8h, v7.8h\n" | |||||
| "ld1 {v8.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v9.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v10.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v11.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v12.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v13.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v14.8h}, [x10], %[src_stride]\n" | |||||
| "ld1 {v15.8h}, [x10], %[src_stride]\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||||
| "zip1 v16.8h, v8.8h, v9.8h\n" | |||||
| "zip1 v17.8h, v10.8h, v11.8h\n" | |||||
| "zip1 v18.8h, v12.8h, v13.8h\n" | |||||
| "zip1 v19.8h, v14.8h, v15.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||||
| "add x10, x11, #16\n" | |||||
| "st1 {v24.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v28.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v26.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v30.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v25.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v29.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v27.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v31.8h}, [x10], %[dst_stride]\n" | |||||
| "zip2 v16.8h, v0.8h, v1.8h\n" | |||||
| "zip2 v17.8h, v2.8h, v3.8h\n" | |||||
| "zip2 v18.8h, v4.8h, v5.8h\n" | |||||
| "zip2 v19.8h, v6.8h, v7.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v24.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v25.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v26.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v27.2d, v21.2d, v23.2d\n" | |||||
| "zip2 v16.8h, v8.8h, v9.8h\n" | |||||
| "zip2 v17.8h, v10.8h, v11.8h\n" | |||||
| "zip2 v18.8h, v12.8h, v13.8h\n" | |||||
| "zip2 v19.8h, v14.8h, v15.8h\n" | |||||
| "trn1 v20.4s, v16.4s, v17.4s\n" | |||||
| "trn2 v21.4s, v16.4s, v17.4s\n" | |||||
| "trn1 v22.4s, v18.4s, v19.4s\n" | |||||
| "trn2 v23.4s, v18.4s, v19.4s\n" | |||||
| "trn1 v28.2d, v20.2d, v22.2d\n" | |||||
| "trn2 v29.2d, v20.2d, v22.2d\n" | |||||
| "trn1 v30.2d, v21.2d, v23.2d\n" | |||||
| "trn2 v31.2d, v21.2d, v23.2d\n" | |||||
| "st1 {v24.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v28.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v26.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v30.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v25.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v29.8h}, [x10], %[dst_stride]\n" | |||||
| "st1 {v27.8h}, [x11], %[dst_stride]\n" | |||||
| "st1 {v31.8h}, [x10], %[dst_stride]\n" | |||||
| : | |||||
| : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) | |||||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31"); | |||||
| } | |||||
| #endif | |||||
| @@ -17,11 +17,9 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_PACK_FP16_H_ | #define MINDSPORE_NNACL_FP16_PACK_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/conv_parameter.h" | #include "nnacl/conv_parameter.h" | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| @@ -72,6 +70,17 @@ void PackNHWCFp16ToC8HWN8Fp16(float16_t *src, float16_t *dst, int batch, int pla | |||||
| void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); | void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); | ||||
| void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); | void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); | ||||
| #ifdef ENABLE_ARM82_A32 | |||||
| void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||||
| void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||||
| #endif | |||||
| #ifdef ENABLE_ARM64 | |||||
| void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); | |||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -16,9 +16,6 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_ | #ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_ | ||||
| #define MINDSPORE_NNACL_FP16_PAD_FP16_H_ | #define MINDSPORE_NNACL_FP16_PAD_FP16_H_ | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/fp32/pad_fp32.h" | #include "nnacl/fp32/pad_fp32.h" | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -18,10 +18,8 @@ | |||||
| #define MINDSPORE_NNACL_FP16_POOLING_FP16_H_ | #define MINDSPORE_NNACL_FP16_POOLING_FP16_H_ | ||||
| #include <math.h> | #include <math.h> | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/pooling_parameter.h" | #include "nnacl/pooling_parameter.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include "nnacl/fp16/power_fp16.h" | #include "nnacl/fp16/power_fp16.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { | float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { | ||||
| int tmp = (int)(*(float16_t *)exponent); | int tmp = (int)(*(float16_t *)exponent); | ||||
| int exp = abs(tmp); | int exp = abs(tmp); | ||||
| @@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { | |||||
| void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, | void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, | ||||
| float shift) { | float shift) { | ||||
| PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; | PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; | ||||
| #endif | #endif | ||||
| if (CheckInteger(*exponent)) { | if (CheckInteger(*exponent)) { | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| PowerSimdFunFp16_ = OptimizedPowerSimdFp16; | PowerSimdFunFp16_ = OptimizedPowerSimdFp16; | ||||
| #endif | #endif | ||||
| PowerScalarFunFp16_ = OptimizedPowerScalarFp16; | PowerScalarFunFp16_ = OptimizedPowerScalarFp16; | ||||
| } else { | } else { | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| PowerSimdFunFp16_ = StdPowerSimdFp16; | PowerSimdFunFp16_ = StdPowerSimdFp16; | ||||
| #endif | #endif | ||||
| PowerScalarFunFp16_ = StdPowerScalarFp16; | PowerScalarFunFp16_ = StdPowerScalarFp16; | ||||
| } | } | ||||
| int i = 0; | int i = 0; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| int len_c8 = UP_ROUND(len, C8NUM); | int len_c8 = UP_ROUND(len, C8NUM); | ||||
| float16x8_t scale_8 = vmovq_n_f16(scale); | float16x8_t scale_8 = vmovq_n_f16(scale); | ||||
| float16x8_t shift_8 = vmovq_n_f16(shift); | float16x8_t shift_8 = vmovq_n_f16(shift); | ||||
| @@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_ | |||||
| float shift) { | float shift) { | ||||
| int i = 0; | int i = 0; | ||||
| PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| int len_c8 = UP_ROUND(len, C8NUM); | int len_c8 = UP_ROUND(len, C8NUM); | ||||
| float16x8_t scale_8 = vmovq_n_f16(scale); | float16x8_t scale_8 = vmovq_n_f16(scale); | ||||
| float16x8_t shift_8 = vmovq_n_f16(shift); | float16x8_t shift_8 = vmovq_n_f16(shift); | ||||
| @@ -19,9 +19,10 @@ | |||||
| #include <math.h> | #include <math.h> | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/power_parameter.h" | #include "nnacl/power_parameter.h" | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); | typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); | ||||
| #endif | #endif | ||||
| typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); | typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); | ||||
| @@ -36,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { | |||||
| return powf(x, *(float16_t *)exponent); | return powf(x, *(float16_t *)exponent); | ||||
| } | } | ||||
| #if defined(ENABLE_NEON) | |||||
| #if defined(ENABLE_ARM64) | |||||
| static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { | static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { | ||||
| float16x8_t result; | float16x8_t result; | ||||
| result[0] = powf(x[0], *(float16_t *)exponent); | result[0] = powf(x[0], *(float16_t *)exponent); | ||||
| @@ -18,10 +18,7 @@ | |||||
| #define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_ | #define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_ | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| @@ -19,9 +19,6 @@ | |||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/reduce_parameter.h" | #include "nnacl/reduce_parameter.h" | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| @@ -18,10 +18,9 @@ | |||||
| #define MINDSPORE_NNACL_SCALE_FP16_H_ | #define MINDSPORE_NNACL_SCALE_FP16_H_ | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/scale.h" | #include "nnacl/scale.h" | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| @@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe | |||||
| } | } | ||||
| } | } | ||||
| int k = 0; | int k = 0; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| int count2 = (channel / C8NUM) * C8NUM; | int count2 = (channel / C8NUM) * C8NUM; | ||||
| for (; k < count2; k += C8NUM) { | for (; k < count2; k += C8NUM) { | ||||
| float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); | float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); | ||||
| @@ -58,9 +58,9 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe | |||||
| void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { | void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { | ||||
| int cur_batch_offset = 0; | int cur_batch_offset = 0; | ||||
| for (int i = 0; i < batch; i++, cur_batch_offset += channel) { | for (int i = 0; i < batch; i++, cur_batch_offset += channel) { | ||||
| float16_t sum = 0; | |||||
| float16_t sum = 0.0f; | |||||
| int j = 0; | int j = 0; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| float16x8_t sum8 = vdupq_n_f16(0); | float16x8_t sum8 = vdupq_n_f16(0); | ||||
| int count = (channel / C8NUM) * C8NUM; | int count = (channel / C8NUM) * C8NUM; | ||||
| for (; j < count; j += C8NUM) { | for (; j < count; j += C8NUM) { | ||||
| @@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) | |||||
| sum += src[cur_batch_offset + j]; | sum += src[cur_batch_offset + j]; | ||||
| } | } | ||||
| int k = 0; | int k = 0; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_ARM64 | |||||
| const float16_t div = 1.0f / sum; | const float16_t div = 1.0f / sum; | ||||
| for (; k < count; k += C8NUM) { | for (; k < count; k += C8NUM) { | ||||
| vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); | vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); | ||||
| @@ -117,7 +117,7 @@ void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *s | |||||
| } | } | ||||
| for (int j = 0; j < input_shape[axis]; j++) { | for (int j = 0; j < input_shape[axis]; j++) { | ||||
| int axis_offset = inner_offset + j * inner_size; | int axis_offset = inner_offset + j * inner_size; | ||||
| output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); | |||||
| output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); | |||||
| sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; | sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -18,10 +18,9 @@ | |||||
| #define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_ | #define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_ | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/softmax_parameter.h" | #include "nnacl/softmax_parameter.h" | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| @@ -18,10 +18,8 @@ | |||||
| #define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_ | #define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_ | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #include "nnacl/transpose.h" | #include "nnacl/transpose.h" | ||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| @@ -16,562 +16,15 @@ | |||||
| #include "nnacl/fp16/winograd_transform_fp16.h" | #include "nnacl/fp16/winograd_transform_fp16.h" | ||||
| // for fp16 convolution 3x3 filter/input/output transform F(4,3) | |||||
| void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { | |||||
| float16x8_t d00 = vld1q_f16(tmp_data); | |||||
| float16x8_t d01 = vld1q_f16(tmp_data + 8); | |||||
| float16x8_t d02 = vld1q_f16(tmp_data + 2 * 8); | |||||
| float16x8_t d03 = vld1q_f16(tmp_data + 3 * 8); | |||||
| float16x8_t d04 = vld1q_f16(tmp_data + 4 * 8); | |||||
| float16x8_t d05 = vld1q_f16(tmp_data + 5 * 8); | |||||
| float16x8_t d10 = vld1q_f16(tmp_data + 6 * 8); | |||||
| float16x8_t d11 = vld1q_f16(tmp_data + 7 * 8); | |||||
| float16x8_t d12 = vld1q_f16(tmp_data + 8 * 8); | |||||
| float16x8_t d13 = vld1q_f16(tmp_data + 9 * 8); | |||||
| float16x8_t d14 = vld1q_f16(tmp_data + 10 * 8); | |||||
| float16x8_t d15 = vld1q_f16(tmp_data + 11 * 8); | |||||
| float16x8_t d20 = vld1q_f16(tmp_data + 12 * 8); | |||||
| float16x8_t d21 = vld1q_f16(tmp_data + 13 * 8); | |||||
| float16x8_t d22 = vld1q_f16(tmp_data + 14 * 8); | |||||
| float16x8_t d23 = vld1q_f16(tmp_data + 15 * 8); | |||||
| float16x8_t d24 = vld1q_f16(tmp_data + 16 * 8); | |||||
| float16x8_t d25 = vld1q_f16(tmp_data + 17 * 8); | |||||
| float16x8_t d30 = vld1q_f16(tmp_data + 18 * 8); | |||||
| float16x8_t d31 = vld1q_f16(tmp_data + 19 * 8); | |||||
| float16x8_t d32 = vld1q_f16(tmp_data + 20 * 8); | |||||
| float16x8_t d33 = vld1q_f16(tmp_data + 21 * 8); | |||||
| float16x8_t d34 = vld1q_f16(tmp_data + 22 * 8); | |||||
| float16x8_t d35 = vld1q_f16(tmp_data + 23 * 8); | |||||
| float16x8_t d40 = vld1q_f16(tmp_data + 24 * 8); | |||||
| float16x8_t d41 = vld1q_f16(tmp_data + 25 * 8); | |||||
| float16x8_t d42 = vld1q_f16(tmp_data + 26 * 8); | |||||
| float16x8_t d43 = vld1q_f16(tmp_data + 27 * 8); | |||||
| float16x8_t d44 = vld1q_f16(tmp_data + 28 * 8); | |||||
| float16x8_t d45 = vld1q_f16(tmp_data + 29 * 8); | |||||
| float16x8_t d50 = vld1q_f16(tmp_data + 30 * 8); | |||||
| float16x8_t d51 = vld1q_f16(tmp_data + 31 * 8); | |||||
| float16x8_t d52 = vld1q_f16(tmp_data + 32 * 8); | |||||
| float16x8_t d53 = vld1q_f16(tmp_data + 33 * 8); | |||||
| float16x8_t d54 = vld1q_f16(tmp_data + 34 * 8); | |||||
| float16x8_t d55 = vld1q_f16(tmp_data + 35 * 8); | |||||
| float16x8_t t00 = vaddq_f16(vsubq_f16(vmulq_n_f16(d00, 4), vmulq_n_f16(d20, 5)), d40); | |||||
| float16x8_t t01 = vaddq_f16(vsubq_f16(vmulq_n_f16(d01, 4), vmulq_n_f16(d21, 5)), d41); | |||||
| float16x8_t t02 = vaddq_f16(vsubq_f16(vmulq_n_f16(d02, 4), vmulq_n_f16(d22, 5)), d42); | |||||
| float16x8_t t03 = vaddq_f16(vsubq_f16(vmulq_n_f16(d03, 4), vmulq_n_f16(d23, 5)), d43); | |||||
| float16x8_t t04 = vaddq_f16(vsubq_f16(vmulq_n_f16(d04, 4), vmulq_n_f16(d24, 5)), d44); | |||||
| float16x8_t t05 = vaddq_f16(vsubq_f16(vmulq_n_f16(d05, 4), vmulq_n_f16(d25, 5)), d45); | |||||
| float16x8_t t10 = vaddq_f16(vaddq_f16(d30, d40), vmulq_n_f16(vaddq_f16(d10, d20), -4)); | |||||
| float16x8_t t11 = vaddq_f16(vaddq_f16(d31, d41), vmulq_n_f16(vaddq_f16(d11, d21), -4)); | |||||
| float16x8_t t12 = vaddq_f16(vaddq_f16(d32, d42), vmulq_n_f16(vaddq_f16(d12, d22), -4)); | |||||
| float16x8_t t13 = vaddq_f16(vaddq_f16(d33, d43), vmulq_n_f16(vaddq_f16(d13, d23), -4)); | |||||
| float16x8_t t14 = vaddq_f16(vaddq_f16(d34, d44), vmulq_n_f16(vaddq_f16(d14, d24), -4)); | |||||
| float16x8_t t15 = vaddq_f16(vaddq_f16(d35, d45), vmulq_n_f16(vaddq_f16(d15, d25), -4)); | |||||
| float16x8_t t20 = vaddq_f16(vsubq_f16(d40, d30), vmulq_n_f16(vsubq_f16(d10, d20), 4)); | |||||
| float16x8_t t21 = vaddq_f16(vsubq_f16(d41, d31), vmulq_n_f16(vsubq_f16(d11, d21), 4)); | |||||
| float16x8_t t22 = vaddq_f16(vsubq_f16(d42, d32), vmulq_n_f16(vsubq_f16(d12, d22), 4)); | |||||
| float16x8_t t23 = vaddq_f16(vsubq_f16(d43, d33), vmulq_n_f16(vsubq_f16(d13, d23), 4)); | |||||
| float16x8_t t24 = vaddq_f16(vsubq_f16(d44, d34), vmulq_n_f16(vsubq_f16(d14, d24), 4)); | |||||
| float16x8_t t25 = vaddq_f16(vsubq_f16(d45, d35), vmulq_n_f16(vsubq_f16(d15, d25), 4)); | |||||
| float16x8_t t30 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d30, d10), 2)); | |||||
| float16x8_t t31 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d31, d11), 2)); | |||||
| float16x8_t t32 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d32, d12), 2)); | |||||
| float16x8_t t33 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d33, d13), 2)); | |||||
| float16x8_t t34 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d34, d14), 2)); | |||||
| float16x8_t t35 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d35, d15), 2)); | |||||
| float16x8_t t40 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d10, d30), 2)); | |||||
| float16x8_t t41 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d11, d31), 2)); | |||||
| float16x8_t t42 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d12, d32), 2)); | |||||
| float16x8_t t43 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d13, d33), 2)); | |||||
| float16x8_t t44 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d14, d34), 2)); | |||||
| float16x8_t t45 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d15, d35), 2)); | |||||
| float16x8_t t50 = vaddq_f16(vsubq_f16(vmulq_n_f16(d10, 4), vmulq_n_f16(d30, 5)), d50); | |||||
| float16x8_t t51 = vaddq_f16(vsubq_f16(vmulq_n_f16(d11, 4), vmulq_n_f16(d31, 5)), d51); | |||||
| float16x8_t t52 = vaddq_f16(vsubq_f16(vmulq_n_f16(d12, 4), vmulq_n_f16(d32, 5)), d52); | |||||
| float16x8_t t53 = vaddq_f16(vsubq_f16(vmulq_n_f16(d13, 4), vmulq_n_f16(d33, 5)), d53); | |||||
| float16x8_t t54 = vaddq_f16(vsubq_f16(vmulq_n_f16(d14, 4), vmulq_n_f16(d34, 5)), d54); | |||||
| float16x8_t t55 = vaddq_f16(vsubq_f16(vmulq_n_f16(d15, 4), vmulq_n_f16(d35, 5)), d55); | |||||
| float16x8_t m00 = vaddq_f16(vsubq_f16(vmulq_n_f16(t00, 4), vmulq_n_f16(t02, 5)), t04); | |||||
| float16x8_t m01 = vaddq_f16(vaddq_f16(t03, t04), vmulq_n_f16(vaddq_f16(t01, t02), -4)); | |||||
| float16x8_t m02 = vaddq_f16(vsubq_f16(t04, t03), vmulq_n_f16(vsubq_f16(t01, t02), 4)); | |||||
| float16x8_t m03 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t03, t01), 2)); | |||||
| float16x8_t m04 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t01, t03), 2)); | |||||
| float16x8_t m05 = vaddq_f16(vsubq_f16(vmulq_n_f16(t01, 4), vmulq_n_f16(t03, 5)), t05); | |||||
| float16x8_t m10 = vaddq_f16(vsubq_f16(vmulq_n_f16(t10, 4), vmulq_n_f16(t12, 5)), t14); | |||||
| float16x8_t m11 = vaddq_f16(vaddq_f16(t13, t14), vmulq_n_f16(vaddq_f16(t11, t12), -4)); | |||||
| float16x8_t m12 = vaddq_f16(vsubq_f16(t14, t13), vmulq_n_f16(vsubq_f16(t11, t12), 4)); | |||||
| float16x8_t m13 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t13, t11), 2)); | |||||
| float16x8_t m14 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t11, t13), 2)); | |||||
| float16x8_t m15 = vaddq_f16(vsubq_f16(vmulq_n_f16(t11, 4), vmulq_n_f16(t13, 5)), t15); | |||||
| float16x8_t m20 = vaddq_f16(vsubq_f16(vmulq_n_f16(t20, 4), vmulq_n_f16(t22, 5)), t24); | |||||
| float16x8_t m21 = vaddq_f16(vaddq_f16(t23, t24), vmulq_n_f16(vaddq_f16(t21, t22), -4)); | |||||
| float16x8_t m22 = vaddq_f16(vsubq_f16(t24, t23), vmulq_n_f16(vsubq_f16(t21, t22), 4)); | |||||
| float16x8_t m23 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t23, t21), 2)); | |||||
| float16x8_t m24 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t21, t23), 2)); | |||||
| float16x8_t m25 = vaddq_f16(vsubq_f16(vmulq_n_f16(t21, 4), vmulq_n_f16(t23, 5)), t25); | |||||
| float16x8_t m30 = vaddq_f16(vsubq_f16(vmulq_n_f16(t30, 4), vmulq_n_f16(t32, 5)), t34); | |||||
| float16x8_t m31 = vaddq_f16(vaddq_f16(t33, t34), vmulq_n_f16(vaddq_f16(t31, t32), -4)); | |||||
| float16x8_t m32 = vaddq_f16(vsubq_f16(t34, t33), vmulq_n_f16(vsubq_f16(t31, t32), 4)); | |||||
| float16x8_t m33 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t33, t31), 2)); | |||||
| float16x8_t m34 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t31, t33), 2)); | |||||
| float16x8_t m35 = vaddq_f16(vsubq_f16(vmulq_n_f16(t31, 4), vmulq_n_f16(t33, 5)), t35); | |||||
| float16x8_t m40 = vaddq_f16(vsubq_f16(vmulq_n_f16(t40, 4), vmulq_n_f16(t42, 5)), t44); | |||||
| float16x8_t m41 = vaddq_f16(vaddq_f16(t43, t44), vmulq_n_f16(vaddq_f16(t41, t42), -4)); | |||||
| float16x8_t m42 = vaddq_f16(vsubq_f16(t44, t43), vmulq_n_f16(vsubq_f16(t41, t42), 4)); | |||||
| float16x8_t m43 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t43, t41), 2)); | |||||
| float16x8_t m44 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t41, t43), 2)); | |||||
| float16x8_t m45 = vaddq_f16(vsubq_f16(vmulq_n_f16(t41, 4), vmulq_n_f16(t43, 5)), t45); | |||||
| float16x8_t m50 = vaddq_f16(vsubq_f16(vmulq_n_f16(t50, 4), vmulq_n_f16(t52, 5)), t54); | |||||
| float16x8_t m51 = vaddq_f16(vaddq_f16(t53, t54), vmulq_n_f16(vaddq_f16(t51, t52), -4)); | |||||
| float16x8_t m52 = vaddq_f16(vsubq_f16(t54, t53), vmulq_n_f16(vsubq_f16(t51, t52), 4)); | |||||
| float16x8_t m53 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t53, t51), 2)); | |||||
| float16x8_t m54 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t51, t53), 2)); | |||||
| float16x8_t m55 = vaddq_f16(vsubq_f16(vmulq_n_f16(t51, 4), vmulq_n_f16(t53, 5)), t55); | |||||
| vst1_f16(trans_input_data, vget_low_f16(m00)); | |||||
| vst1_f16(trans_input_data + 64, vget_high_f16(m00)); | |||||
| vst1_f16(trans_input_data + step, vget_low_f16(m01)); | |||||
| vst1_f16(trans_input_data + step + 64, vget_high_f16(m01)); | |||||
| vst1_f16(trans_input_data + 2 * step, vget_low_f16(m02)); | |||||
| vst1_f16(trans_input_data + 2 * step + 64, vget_high_f16(m02)); | |||||
| vst1_f16(trans_input_data + 3 * step, vget_low_f16(m03)); | |||||
| vst1_f16(trans_input_data + 3 * step + 64, vget_high_f16(m03)); | |||||
| vst1_f16(trans_input_data + 4 * step, vget_low_f16(m04)); | |||||
| vst1_f16(trans_input_data + 4 * step + 64, vget_high_f16(m04)); | |||||
| vst1_f16(trans_input_data + 5 * step, vget_low_f16(m05)); | |||||
| vst1_f16(trans_input_data + 5 * step + 64, vget_high_f16(m05)); | |||||
| vst1_f16(trans_input_data + 6 * step, vget_low_f16(m10)); | |||||
| vst1_f16(trans_input_data + 6 * step + 64, vget_high_f16(m10)); | |||||
| vst1_f16(trans_input_data + 7 * step, vget_low_f16(m11)); | |||||
| vst1_f16(trans_input_data + 7 * step + 64, vget_high_f16(m11)); | |||||
| vst1_f16(trans_input_data + 8 * step, vget_low_f16(m12)); | |||||
| vst1_f16(trans_input_data + 8 * step + 64, vget_high_f16(m12)); | |||||
| vst1_f16(trans_input_data + 9 * step, vget_low_f16(m13)); | |||||
| vst1_f16(trans_input_data + 9 * step + 64, vget_high_f16(m13)); | |||||
| vst1_f16(trans_input_data + 10 * step, vget_low_f16(m14)); | |||||
| vst1_f16(trans_input_data + 10 * step + 64, vget_high_f16(m14)); | |||||
| vst1_f16(trans_input_data + 11 * step, vget_low_f16(m15)); | |||||
| vst1_f16(trans_input_data + 11 * step + 64, vget_high_f16(m15)); | |||||
| vst1_f16(trans_input_data + 12 * step, vget_low_f16(m20)); | |||||
| vst1_f16(trans_input_data + 12 * step + 64, vget_high_f16(m20)); | |||||
| vst1_f16(trans_input_data + 13 * step, vget_low_f16(m21)); | |||||
| vst1_f16(trans_input_data + 13 * step + 64, vget_high_f16(m21)); | |||||
| vst1_f16(trans_input_data + 14 * step, vget_low_f16(m22)); | |||||
| vst1_f16(trans_input_data + 14 * step + 64, vget_high_f16(m22)); | |||||
| vst1_f16(trans_input_data + 15 * step, vget_low_f16(m23)); | |||||
| vst1_f16(trans_input_data + 15 * step + 64, vget_high_f16(m23)); | |||||
| vst1_f16(trans_input_data + 16 * step, vget_low_f16(m24)); | |||||
| vst1_f16(trans_input_data + 16 * step + 64, vget_high_f16(m24)); | |||||
| vst1_f16(trans_input_data + 17 * step, vget_low_f16(m25)); | |||||
| vst1_f16(trans_input_data + 17 * step + 64, vget_high_f16(m25)); | |||||
| vst1_f16(trans_input_data + 18 * step, vget_low_f16(m30)); | |||||
| vst1_f16(trans_input_data + 18 * step + 64, vget_high_f16(m30)); | |||||
| vst1_f16(trans_input_data + 19 * step, vget_low_f16(m31)); | |||||
| vst1_f16(trans_input_data + 19 * step + 64, vget_high_f16(m31)); | |||||
| vst1_f16(trans_input_data + 20 * step, vget_low_f16(m32)); | |||||
| vst1_f16(trans_input_data + 20 * step + 64, vget_high_f16(m32)); | |||||
| vst1_f16(trans_input_data + 21 * step, vget_low_f16(m33)); | |||||
| vst1_f16(trans_input_data + 21 * step + 64, vget_high_f16(m33)); | |||||
| vst1_f16(trans_input_data + 22 * step, vget_low_f16(m34)); | |||||
| vst1_f16(trans_input_data + 22 * step + 64, vget_high_f16(m34)); | |||||
| vst1_f16(trans_input_data + 23 * step, vget_low_f16(m35)); | |||||
| vst1_f16(trans_input_data + 23 * step + 64, vget_high_f16(m35)); | |||||
| vst1_f16(trans_input_data + 24 * step, vget_low_f16(m40)); | |||||
| vst1_f16(trans_input_data + 24 * step + 64, vget_high_f16(m40)); | |||||
| vst1_f16(trans_input_data + 25 * step, vget_low_f16(m41)); | |||||
| vst1_f16(trans_input_data + 25 * step + 64, vget_high_f16(m41)); | |||||
| vst1_f16(trans_input_data + 26 * step, vget_low_f16(m42)); | |||||
| vst1_f16(trans_input_data + 26 * step + 64, vget_high_f16(m42)); | |||||
| vst1_f16(trans_input_data + 27 * step, vget_low_f16(m43)); | |||||
| vst1_f16(trans_input_data + 27 * step + 64, vget_high_f16(m43)); | |||||
| vst1_f16(trans_input_data + 28 * step, vget_low_f16(m44)); | |||||
| vst1_f16(trans_input_data + 28 * step + 64, vget_high_f16(m44)); | |||||
| vst1_f16(trans_input_data + 29 * step, vget_low_f16(m45)); | |||||
| vst1_f16(trans_input_data + 29 * step + 64, vget_high_f16(m45)); | |||||
| vst1_f16(trans_input_data + 30 * step, vget_low_f16(m50)); | |||||
| vst1_f16(trans_input_data + 30 * step + 64, vget_high_f16(m50)); | |||||
| vst1_f16(trans_input_data + 31 * step, vget_low_f16(m51)); | |||||
| vst1_f16(trans_input_data + 31 * step + 64, vget_high_f16(m51)); | |||||
| vst1_f16(trans_input_data + 32 * step, vget_low_f16(m52)); | |||||
| vst1_f16(trans_input_data + 32 * step + 64, vget_high_f16(m52)); | |||||
| vst1_f16(trans_input_data + 33 * step, vget_low_f16(m53)); | |||||
| vst1_f16(trans_input_data + 33 * step + 64, vget_high_f16(m53)); | |||||
| vst1_f16(trans_input_data + 34 * step, vget_low_f16(m54)); | |||||
| vst1_f16(trans_input_data + 34 * step + 64, vget_high_f16(m54)); | |||||
| vst1_f16(trans_input_data + 35 * step, vget_low_f16(m55)); | |||||
| vst1_f16(trans_input_data + 35 * step + 64, vget_high_f16(m55)); | |||||
| } | |||||
// Winograd F(4x4, 3x3) input transform for fp16: for each of `real_cal_num`
// output tiles starting at `start_index`, gathers the (possibly padded) 6x6
// input patch — NHWC layout, channels packed in C8 blocks — into `tmp_data`,
// then calls Conv3x3Fp16InputUnit to emit the transformed tile into
// `trans_input` with stride ic8 * C8NUM * 16.
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
                               int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  const int output_unit = 4;
  const int channel = conv_param->input_channel_;
  const int width = conv_param->input_w_;
  const int height = conv_param->input_h_;
  const int pad_w = conv_param->pad_l_;
  const int pad_h = conv_param->pad_u_;
  const int ic8 = UP_DIV(channel, C8NUM);
  if (out_w_block == 0) {
    return;
  }
  for (int tile = 0; tile < real_cal_num; tile++) {
    const int tile_index = start_index + tile;
    // top-left corner of this tile's 6x6 receptive field; negative when inside the padding
    const int base_x = (tile_index % out_w_block) * output_unit - pad_w;
    const int base_y = (tile_index / out_w_block) * output_unit - pad_h;
    // clip the 6x6 window against the input borders
    const int x_begin = base_x > 0 ? 0 : -base_x;
    const int x_stop = (base_x + 6) < width ? 6 : (width - base_x);
    const int y_begin = base_y > 0 ? 0 : -base_y;
    const int y_stop = (base_y + 6) < height ? 6 : (height - base_y);
    const int src_plane_offset = ic8 * C8NUM * (base_y * width + base_x);
    const int dst_plane_offset = tile * C4NUM;
    for (int ic = 0; ic < ic8; ic++) {
      // zero the staging buffer so out-of-image pixels contribute zero padding
      memset(tmp_data, 0, 6 * 6 * C8NUM * sizeof(float16_t));
      const int src_ic_offset = src_plane_offset + ic * C8NUM;
      for (int row = y_begin; row < y_stop; row++) {
        const int src_row_offset = src_ic_offset + (row * width + x_begin) * ic8 * C8NUM;
        const int dst_row_offset = row * 6 * C8NUM + x_begin * C8NUM;
        for (int col = 0; col < (x_stop - x_begin); col++) {
          // copy one C8 channel group of one pixel into the 6x6 staging tile
          const float16_t *src_addr = input_data + src_row_offset + col * ic8 * C8NUM;
          float16_t *dst_addr = tmp_data + dst_row_offset + col * C8NUM;
          vst1q_f16(dst_addr, vld1q_f16(src_addr));
        }
      }
      // transform the gathered patch into the interleaved gemm input layout
      float16_t *trans_input_ptr = trans_input + dst_plane_offset + ic * 16 * C8NUM;
      Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, (size_t)(ic8 * C8NUM * 16));
    }
  }
}
// Winograd F(4x4, 3x3) filter transform for fp16: computes G * g * G^T for each
// 3x3 kernel slice, expanding every kernel to a 6x6 (= 36 point) transformed tile.
// weight_data   : source weights; per output channel, iC4 slices of [kernel_plane, C4NUM]
// trans_weight  : destination; OC8 blocks, 36 tile points spaced dst_step apart
// iC4           : number of C4 input-channel blocks.
//                 NOTE(review): the header declares this parameter as `iC8`, while the
//                 indexing here steps by C4NUM — confirm the intended channel blocking.
// output_channel: number of output channels
// kernel_plane  : kernel spatial size (the 9 loads below assume 3x3, i.e. 9)
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
                                int kernel_plane) {
  // element distance between consecutive points of the 6x6 tile in trans_weight
  int dst_step = iC4 * C4NUM * 8;
  for (int o = 0; o < output_channel; o++) {
    int oc8_block_num = o / C8NUM;  // which OC8 block this output channel belongs to
    int oc8_block_rem = o % C8NUM;  // lane of this channel inside the OC8 block
    int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
    int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem;
    for (int i = 0; i < iC4; i++) {
      const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
      float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM;
      // Load the 3x3 kernel, 4 input channels per vector (gRC = kernel row R, col C).
      float16x4_t g00 = vld1_f16(src_ic4_ptr);
      float16x4_t g01 = vld1_f16(src_ic4_ptr + 4);
      float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4);
      float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4);
      float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4);
      float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4);
      float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4);
      float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4);
      float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4);
      // dstRC = row R of (G * g). G rows are:
      //   [1/4 0 0], [-1/6 -1/6 -1/6], [-1/6 1/6 -1/6], [1/24 1/12 1/6], [1/24 -1/12 1/6], [0 0 1]
      float16x4_t dst00 = vmul_n_f16(g00, 0.25);
      float16x4_t dst01 = vmul_n_f16(g01, 0.25);
      float16x4_t dst02 = vmul_n_f16(g02, 0.25);
      float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667);
      float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667);
      float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667);
      float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667);
      float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667);
      float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667);
      float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333),
                                   vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)));
      float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333),
                                   vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)));
      float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333),
                                   vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)));
      float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)),
                                   vmul_n_f16(g10, 0.08333333333333));
      float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)),
                                   vmul_n_f16(g11, 0.08333333333333));
      float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)),
                                   vmul_n_f16(g12, 0.08333333333333));
      float16x4_t dst50 = g20;
      float16x4_t dst51 = g21;
      float16x4_t dst52 = g22;
      // mRC = row R, col C of (G * g) * G^T — the same coefficients applied across columns.
      float16x4_t m00 = vmul_n_f16(dst00, 0.25);
      float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667);
      float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667);
      float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)));
      float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)),
                                 vmul_n_f16(dst01, 0.08333333333333));
      float16x4_t m05 = dst02;
      float16x4_t m10 = vmul_n_f16(dst10, 0.25);
      float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667);
      float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667);
      float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)));
      float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)),
                                 vmul_n_f16(dst11, 0.08333333333333));
      float16x4_t m15 = dst12;
      float16x4_t m20 = vmul_n_f16(dst20, 0.25);
      float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667);
      float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667);
      float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)));
      float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)),
                                 vmul_n_f16(dst21, 0.08333333333333));
      float16x4_t m25 = dst22;
      float16x4_t m30 = vmul_n_f16(dst30, 0.25);
      float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667);
      float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667);
      float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)));
      float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)),
                                 vmul_n_f16(dst31, 0.08333333333333));
      float16x4_t m35 = dst32;
      float16x4_t m40 = vmul_n_f16(dst40, 0.25);
      float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667);
      float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667);
      float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)));
      float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)),
                                 vmul_n_f16(dst41, 0.08333333333333));
      float16x4_t m45 = dst42;
      float16x4_t m50 = vmul_n_f16(dst50, 0.25);
      float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667);
      float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667);
      float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333),
                                 vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)));
      float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)),
                                 vmul_n_f16(dst51, 0.08333333333333));
      float16x4_t m55 = dst52;
      // Scatter the 36 tile points: lane j of each vector goes to channel-major
      // position j * 8, and tile point k lives k * dst_step elements apart.
      for (int j = 0; j < 4; j++) {
        dst_ic4_ptr[j * 8] = m00[j];
        dst_ic4_ptr[j * 8 + dst_step] = m01[j];
        dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j];
        dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j];
        dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j];
        dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j];
        dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j];
        dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j];
        dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j];
        dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j];
        dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j];
        dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j];
        dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j];
        dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j];
        dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j];
        dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j];
        dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j];
        dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j];
        dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j];
        dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j];
        dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j];
        dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j];
        dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j];
        dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j];
        dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j];
        dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j];
        dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j];
        dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j];
        dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j];
        dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j];
        dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j];
        dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j];
        dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j];
        dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j];
        dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j];
        dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j];
      }
    }
  }
}
// Winograd F(4x4, 3x3) output transform for one fp16 tile of 8 channels:
// computes A^T * M * A on the 6x6 gemm result, adds per-channel bias, and
// writes a 4x4 output block.
// gemm_out   : 6x6 tile of float16x8 vectors, row-major with stride 8 elements
// bias_data  : 8 bias values, one per channel lane
// output_data: destination; rows are output_w * 8 elements apart
// output_w   : output row width in pixels
// NOTE(review): a full 4x4 block is always stored — callers appear responsible
// for right/bottom remainder handling; confirm before reuse.
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data,
                           int output_w) {
  // Load the 6x6 tile (sRC = row R, col C), 8 channels per vector.
  float16x8_t s00 = vld1q_f16(gemm_out);
  float16x8_t s01 = vld1q_f16(gemm_out + 8);
  float16x8_t s02 = vld1q_f16(gemm_out + 16);
  float16x8_t s03 = vld1q_f16(gemm_out + 24);
  float16x8_t s04 = vld1q_f16(gemm_out + 32);
  float16x8_t s05 = vld1q_f16(gemm_out + 40);
  float16x8_t s10 = vld1q_f16(gemm_out + 48);
  float16x8_t s11 = vld1q_f16(gemm_out + 56);
  float16x8_t s12 = vld1q_f16(gemm_out + 64);
  float16x8_t s13 = vld1q_f16(gemm_out + 72);
  float16x8_t s14 = vld1q_f16(gemm_out + 80);
  float16x8_t s15 = vld1q_f16(gemm_out + 88);
  float16x8_t s20 = vld1q_f16(gemm_out + 96);
  float16x8_t s21 = vld1q_f16(gemm_out + 104);
  float16x8_t s22 = vld1q_f16(gemm_out + 112);
  float16x8_t s23 = vld1q_f16(gemm_out + 120);
  float16x8_t s24 = vld1q_f16(gemm_out + 128);
  float16x8_t s25 = vld1q_f16(gemm_out + 136);
  float16x8_t s30 = vld1q_f16(gemm_out + 144);
  float16x8_t s31 = vld1q_f16(gemm_out + 152);
  float16x8_t s32 = vld1q_f16(gemm_out + 160);
  float16x8_t s33 = vld1q_f16(gemm_out + 168);
  float16x8_t s34 = vld1q_f16(gemm_out + 176);
  float16x8_t s35 = vld1q_f16(gemm_out + 184);
  float16x8_t s40 = vld1q_f16(gemm_out + 192);
  float16x8_t s41 = vld1q_f16(gemm_out + 200);
  float16x8_t s42 = vld1q_f16(gemm_out + 208);
  float16x8_t s43 = vld1q_f16(gemm_out + 216);
  float16x8_t s44 = vld1q_f16(gemm_out + 224);
  float16x8_t s45 = vld1q_f16(gemm_out + 232);
  float16x8_t s50 = vld1q_f16(gemm_out + 240);
  float16x8_t s51 = vld1q_f16(gemm_out + 248);
  float16x8_t s52 = vld1q_f16(gemm_out + 256);
  float16x8_t s53 = vld1q_f16(gemm_out + 264);
  float16x8_t s54 = vld1q_f16(gemm_out + 272);
  float16x8_t s55 = vld1q_f16(gemm_out + 280);
  // tRC = A^T * s applied down the rows. A^T rows:
  //   [1 1 1 1 1 0], [0 1 -1 2 -2 0], [0 1 1 4 4 0], [0 1 -1 8 -8 1]
  float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40);
  float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41);
  float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42);
  float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43);
  float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44);
  float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45);
  float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2));
  float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2));
  float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2));
  float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2));
  float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2));
  float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2));
  float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4));
  float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4));
  float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4));
  float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4));
  float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4));
  float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4));
  float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50);
  float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51);
  float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52);
  float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53);
  float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54);
  float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55);
  // dRC = t * A across the columns (same coefficients), plus bias.
  float16x8_t bias_ptr = vld1q_f16(bias_data);
  float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr);
  float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr);
  float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr);
  float16x8_t d03 =
    vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr);
  float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr);
  float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr);
  float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr);
  float16x8_t d13 =
    vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr);
  float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr);
  float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr);
  float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr);
  float16x8_t d23 =
    vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr);
  float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr);
  float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr);
  float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr);
  float16x8_t d33 =
    vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr);
  // Store the 4x4 block: 4 C8 vectors per output row, rows output_w * 8 apart.
  vst1q_f16(output_data, d00);
  vst1q_f16(output_data + 8, d01);
  vst1q_f16(output_data + 16, d02);
  vst1q_f16(output_data + 24, d03);
  vst1q_f16(output_data + output_w * 8, d10);
  vst1q_f16(output_data + output_w * 8 + 8, d11);
  vst1q_f16(output_data + output_w * 8 + 16, d12);
  vst1q_f16(output_data + output_w * 8 + 24, d13);
  vst1q_f16(output_data + 2 * output_w * 8, d20);
  vst1q_f16(output_data + 2 * output_w * 8 + 8, d21);
  vst1q_f16(output_data + 2 * output_w * 8 + 16, d22);
  vst1q_f16(output_data + 2 * output_w * 8 + 24, d23);
  vst1q_f16(output_data + 3 * output_w * 8, d30);
  vst1q_f16(output_data + 3 * output_w * 8 + 8, d31);
  vst1q_f16(output_data + 3 * output_w * 8 + 16, d32);
  vst1q_f16(output_data + 3 * output_w * 8 + 24, d33);
}
// Winograd F(4x4, 3x3) output transform driver for fp16: for each computed
// tile, runs Conv3x3Fp16OutputUnit once per OC8 channel block, adding bias and
// scattering the 4x4 result into out_data.
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
                                int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  const int channel = conv_param->output_channel_;
  const int out_h_block = UP_DIV(conv_param->output_h_, C4NUM);
  const int oc8 = UP_DIV(channel, C8NUM);
  if (out_w_block == 0) {
    return;
  }
  // element distance between consecutive OC8 blocks in the destination
  const int dst_oc8_stride = C8NUM * out_h_block * out_w_block * C4NUM * C4NUM;
  for (int tile = 0; tile < real_cal_num; tile++) {
    // decompose the linear tile index into its (h, w) tile position
    const int w_index = (start_index + tile) % out_w_block;
    const int h_index = (start_index + tile) / out_w_block;
    const int src_tile_offset = tile * oc8 * C8NUM * 36;
    const int dst_tile_offset = C8NUM * (w_index * C4NUM + h_index * C4NUM * out_w_block * C4NUM);
    for (int oc = 0; oc < oc8; oc++) {
      const float16_t *src_ptr = gemm_out + src_tile_offset + oc * 36 * C8NUM;
      const float16_t *bias_ptr = bias_data + oc * C8NUM;
      float16_t *dst_ptr = out_data + dst_tile_offset + oc * dst_oc8_stride;
      // transform one 6x6 gemm tile into a 4x4 output block
      Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM);
    }
  }
}
| // fp16 common winograd | // fp16 common winograd | ||||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | ||||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | ||||
| InputTransFp16Func func) { | InputTransFp16Func func) { | ||||
| #ifdef ENABLE_ARM64 | |||||
| const int tile_num = 16; | const int tile_num = 16; | ||||
| #else | |||||
| const int tile_num = 12; | |||||
| #endif | |||||
| int input_unit = conv_param->input_unit_; | int input_unit = conv_param->input_unit_; | ||||
| int output_unit = conv_param->output_unit_; | int output_unit = conv_param->output_unit_; | ||||
| int in_channel = conv_param->input_channel_; | int in_channel = conv_param->input_channel_; | ||||
| @@ -23,25 +23,11 @@ | |||||
| #include "nnacl/fp16/cast_fp16.h" | #include "nnacl/fp16/cast_fp16.h" | ||||
| #include "nnacl/fp16/conv_fp16.h" | #include "nnacl/fp16/conv_fp16.h" | ||||
| #include "nnacl/fp16/matrix_fp16.h" | #include "nnacl/fp16/matrix_fp16.h" | ||||
| #include "nnacl/fp16/pack_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| // for fp16 convolution 3x3 filter/input/output transform | |||||
| void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); | |||||
| void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, | |||||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | |||||
| void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, | |||||
| int kernel_plane); | |||||
| void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); | |||||
| void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, | |||||
| int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); | |||||
| // fp16 common winograd | // fp16 common winograd | ||||
| void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, | ||||
| int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | int out_tile_index, int out_w_block_num, ConvParameter *conv_param, | ||||
| @@ -17,9 +17,9 @@ | |||||
| #ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | #ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | ||||
| #define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | #define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ | ||||
| #include <arm_neon.h> | |||||
| #include "nnacl/conv_parameter.h" | #include "nnacl/conv_parameter.h" | ||||
| #include "nnacl/op_base.h" | #include "nnacl/op_base.h" | ||||
| #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" | |||||
| #define MAX_LEN 256 | #define MAX_LEN 256 | ||||
| @@ -17,9 +17,11 @@ | |||||
| #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | ||||
| #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | ||||
| #include <math.h> | #include <math.h> | ||||
| #ifdef ENABLE_ARM | #ifdef ENABLE_ARM | ||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | #endif | ||||
| #if defined(ENABLE_SSE) || defined(ENABLE_AVX) | #if defined(ENABLE_SSE) || defined(ENABLE_AVX) | ||||
| #include <x86intrin.h> | #include <x86intrin.h> | ||||
| #endif | #endif | ||||
| @@ -46,7 +48,7 @@ | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) | #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) | ||||
| #else | #else | ||||
| inline static float32x4_t vrecp(float32x4_t v) { | |||||
| static inline float32x4_t vrecp(float32x4_t v) { | |||||
| float32x4_t r = vrecpeq_f32(v); | float32x4_t r = vrecpeq_f32(v); | ||||
| r = vmulq_f32(vrecpsq_f32(v, r), r); | r = vmulq_f32(vrecpsq_f32(v, r), r); | ||||
| r = vmulq_f32(vrecpsq_f32(v, r), r); | r = vmulq_f32(vrecpsq_f32(v, r), r); | ||||
| @@ -205,25 +207,4 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) { | |||||
| return dst; | return dst; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM64 | |||||
| static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { | |||||
| float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src)); | |||||
| float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src)); | |||||
| return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high))); | |||||
| } | |||||
| static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { | |||||
| float16x8_t dst; | |||||
| dst[0] = erff(src[0]); | |||||
| dst[1] = erff(src[1]); | |||||
| dst[2] = erff(src[2]); | |||||
| dst[3] = erff(src[3]); | |||||
| dst[4] = erff(src[4]); | |||||
| dst[5] = erff(src[5]); | |||||
| dst[6] = erff(src[6]); | |||||
| dst[7] = erff(src[7]); | |||||
| return dst; | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ | ||||
| @@ -0,0 +1,99 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||||
| #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||||
| #include <math.h> | |||||
| #include "nnacl/intrinsics/ms_simd_instructions.h" | |||||
| #if defined(ENABLE_ARM82_A32) | |||||
| static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) { | |||||
| float16x8_t dst; | |||||
| asm volatile( | |||||
| "vrecpe.f16 q14, %3\n" | |||||
| "vrecps.f16 q15, %3, q14\n" | |||||
| "vmul.f16 q14, q15, q14\n" | |||||
| "vrecps.f16 q15, %3, q14\n" | |||||
| "vmul.f16 q14, q15, q14\n" | |||||
| "vmul.f16 %0, %2, q14\n" | |||||
| : "=w"(dst) | |||||
| : "0"(dst), "w"(in1), "w"(in2) | |||||
| : "q14", "q15"); | |||||
| return dst; | |||||
| } | |||||
| static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) { | |||||
| float16x4_t dst; | |||||
| asm volatile( | |||||
| "vrecpe.f16 d14, %3\n" | |||||
| "vrecps.f16 d16, %3, d14\n" | |||||
| "vmul.f16 d14, d16, d14\n" | |||||
| "vrecps.f16 d16, %3, d14\n" | |||||
| "vmul.f16 d14, d16, d14\n" | |||||
| "vmul.f16 %0, %2, d14\n" | |||||
| : "=w"(dst) | |||||
| : "0"(dst), "w"(in1), "w"(in2) | |||||
| : "d14", "d16"); | |||||
| return dst; | |||||
| } | |||||
| static inline float vaddvq_f32(float32x4_t in) { // is not support in arm82 aarch32 | |||||
| return in[0] + in[1] + in[2] + in[3]; | |||||
| } | |||||
| static inline float32x4_t cvt_f32_f16(float16x4_t in) { | |||||
| float32x4_t dst; | |||||
| asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); | |||||
| return dst; | |||||
| } | |||||
| static inline float16x4_t cvt_f16_f32(float32x4_t in) { | |||||
| float16x4_t dst; | |||||
| asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); | |||||
| return dst; | |||||
| } | |||||
| #define MS_CVT_F32_F16(src) cvt_f32_f16(src) | |||||
| #define MS_CVT_F16_F32(src) cvt_f16_f32(src) | |||||
| #define MS_DIV_F16(src1, src2) div_f16(src1, src2) | |||||
| #define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2) | |||||
| #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) | |||||
| #else | |||||
| #define MS_CVT_F32_F16(src) vcvt_f32_f16(src) | |||||
| #define MS_CVT_F16_F32(src) vcvt_f16_f32(src) | |||||
| #define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) | |||||
| #define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) | |||||
| #define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) | |||||
| #endif | |||||
| static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { | |||||
| float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); | |||||
| float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); | |||||
| return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); | |||||
| } | |||||
| static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { | |||||
| float16x8_t dst; | |||||
| dst[0] = erff(src[0]); | |||||
| dst[1] = erff(src[1]); | |||||
| dst[2] = erff(src[2]); | |||||
| dst[3] = erff(src[3]); | |||||
| dst[4] = erff(src[4]); | |||||
| dst[5] = erff(src[5]); | |||||
| dst[6] = erff(src[6]); | |||||
| dst[7] = erff(src[7]); | |||||
| return dst; | |||||
| } | |||||
| #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ | |||||
| @@ -4,11 +4,15 @@ set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) | |||||
| include_directories(NNACL_DIR) | include_directories(NNACL_DIR) | ||||
| ########################### optimized files ########################### | ########################### optimized files ########################### | ||||
| file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) | |||||
| file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c) | file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c) | ||||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) | |||||
| if(PLATFORM_ARM32) | |||||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S) | |||||
| else() | |||||
| file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) | |||||
| file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) | |||||
| set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) | |||||
| endif() | |||||
| set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) | |||||
| set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) | set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) | ||||
| set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) | set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) | ||||
| @@ -17,7 +21,6 @@ if(APPLE) | |||||
| set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") | set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") | ||||
| endif() | endif() | ||||
| ########################### share library build ######################## | ########################### share library build ######################## | ||||
| list(APPEND SDOT_FILES ${SDOT_SRC}) | |||||
| list(APPEND FP16_FILES ${FP16_C_SRC}) | list(APPEND FP16_FILES ${FP16_C_SRC}) | ||||
| list(APPEND FP16_FILES ${FP16_NEON_SRC}) | list(APPEND FP16_FILES ${FP16_NEON_SRC}) | ||||
| @@ -27,13 +30,20 @@ if(SUPPORT_TRAIN) | |||||
| endif() | endif() | ||||
| string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") | string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") | ||||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||||
| add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) | |||||
| add_dependencies(nnacl_optimize_mid fbs_src) | |||||
| if(NOT PLATFORM_ARM32) | |||||
| list(APPEND SDOT_FILES ${SDOT_SRC}) | |||||
| add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) | |||||
| add_dependencies(nnacl_optimize_mid fbs_src) | |||||
| endif() | |||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) | add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) | ||||
| if(PLATFORM_ARM32) | |||||
| target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp) | |||||
| else() | |||||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") | |||||
| endif() | |||||
| add_dependencies(nnacl_fp16_mid fbs_src) | add_dependencies(nnacl_fp16_mid fbs_src) | ||||
| endif() | endif() | ||||
| @@ -5,6 +5,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L | |||||
| message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") | message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") | ||||
| endif() | endif() | ||||
| if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0) | |||||
| set(ENABLE_FP16 "off") | |||||
| message(WARNING "If you want to build fp16 in arm82_a32, \ | |||||
| your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!") | |||||
| endif() | |||||
| option(MS_VERSION_MAJOR "major version" 0) | option(MS_VERSION_MAJOR "major version" 0) | ||||
| option(MS_VERSION_MINOR "minor version" 7) | option(MS_VERSION_MINOR "minor version" 7) | ||||
| option(MS_VERSION_REVISION "revision version" 0) | option(MS_VERSION_REVISION "revision version" 0) | ||||
| @@ -15,6 +21,7 @@ option(PLATFORM_ARM64 "if build device for arm64" off) | |||||
| option(PLATFORM_ARM32 "if build device for arm32" off) | option(PLATFORM_ARM32 "if build device for arm32" off) | ||||
| option(ENABLE_CONVERTER "if build converter" on) | option(ENABLE_CONVERTER "if build converter" on) | ||||
| option(ENABLE_FP16 "if build fp16 ops" off) | option(ENABLE_FP16 "if build fp16 ops" off) | ||||
| option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off) | |||||
| option(ENABLE_TOOLS "if build tools" on) | option(ENABLE_TOOLS "if build tools" on) | ||||
| option(BUILD_TESTCASES "if build testcase" on) | option(BUILD_TESTCASES "if build testcase" on) | ||||
| option(SUPPORT_GPU "if support gpu" off) | option(SUPPORT_GPU "if support gpu" off) | ||||
| @@ -177,6 +184,9 @@ if(ENABLE_NEON) | |||||
| endif() | endif() | ||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| add_compile_definitions(ENABLE_FP16) | add_compile_definitions(ENABLE_FP16) | ||||
| if(PLATFORM_ARM32) | |||||
| add_compile_definitions(ENABLE_ARM82_A32) | |||||
| endif() | |||||
| endif() | endif() | ||||
| if(SUPPORT_GPU STREQUAL opencl) | if(SUPPORT_GPU STREQUAL opencl) | ||||
| add_definitions(-DGPU_OPENCL) | add_definitions(-DGPU_OPENCL) | ||||
| @@ -3,6 +3,9 @@ if(ENABLE_V0) | |||||
| add_definitions(-DENABLE_V0) | add_definitions(-DENABLE_V0) | ||||
| endif() | endif() | ||||
| include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu) | include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu) | ||||
| set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) | |||||
| include_directories(${LITE_DIR}/nnacl/) | |||||
| include_directories(${LITE_DIR}/nnacl/optimize) | |||||
| if(PLATFORM_ARM32 OR PLATFORM_ARM64) | if(PLATFORM_ARM32 OR PLATFORM_ARM64) | ||||
| #for performance | #for performance | ||||
| @@ -210,9 +213,11 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") | |||||
| endif() | endif() | ||||
| ########################## build optimize and float16 library ################################# | ########################## build optimize and float16 library ################################# | ||||
| if(PLATFORM_ARM64) | |||||
| target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| if(PLATFORM_ARM) | |||||
| if(PLATFORM_ARM64) | |||||
| target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| endif() | |||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid) | target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid) | ||||
| target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid) | target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid) | ||||
| @@ -248,8 +253,10 @@ if(DEFINED ARCHS) | |||||
| target_link_libraries(mindspore_lite mindrt_mid) | target_link_libraries(mindspore_lite mindrt_mid) | ||||
| endif() | endif() | ||||
| if(PLATFORM_ARM64) | |||||
| target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| if(PLATFORM_ARM) | |||||
| if(PLATFORM_ARM64) | |||||
| target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) | |||||
| endif() | |||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid) | target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid) | ||||
| endif() | endif() | ||||
| @@ -155,7 +155,11 @@ bool IsSupportSDot() { | |||||
| bool IsSupportFloat16() { | bool IsSupportFloat16() { | ||||
| bool status = false; | bool status = false; | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_ARM32 | |||||
| status = true; | |||||
| #endif | |||||
| #if defined(ENABLE_ARM64) | |||||
| #if defined(__ANDROID__) | #if defined(__ANDROID__) | ||||
| int hwcap_type = 16; | int hwcap_type = 16; | ||||
| uint32_t hwcap = getHwCap(hwcap_type); | uint32_t hwcap = getHwCap(hwcap_type); | ||||
| @@ -44,7 +44,7 @@ uint64_t GetTimeUs(); | |||||
| bool IsSupportSDot(); | bool IsSupportSDot(); | ||||
| bool IsSupportFloat16(); | bool IsSupportFloat16(); | ||||
| #if defined(__arm__) || defined(__aarch64__) | |||||
| #if defined(__arm__) | |||||
| uint32_t getHwCap(int hwcap_type); | uint32_t getHwCap(int hwcap_type); | ||||
| #endif | #endif | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "src/common/version_manager.h" | #include "src/common/version_manager.h" | ||||
| #include "nnacl/pooling_parameter.h" | #include "nnacl/pooling_parameter.h" | ||||
| #include "src/ios_reg_kernels.h" | #include "src/ios_reg_kernels.h" | ||||
| #ifdef ENABLE_ARM64 | |||||
| #if defined(ENABLE_FP16) && defined(ENABLE_ARM) | |||||
| #if defined(__ANDROID__) | #if defined(__ANDROID__) | ||||
| #include <asm/hwcap.h> | #include <asm/hwcap.h> | ||||
| #endif | #endif | ||||
| @@ -55,6 +55,8 @@ int KernelRegistry::Init() { | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "The current device NOT supports Sdot."; | MS_LOG(INFO) << "The current device NOT supports Sdot."; | ||||
| } | } | ||||
| #endif | |||||
| #ifdef ENABLE_FP16 | |||||
| if (mindspore::lite::IsSupportFloat16()) { | if (mindspore::lite::IsSupportFloat16()) { | ||||
| MS_LOG(INFO) << "The current device supports float16."; | MS_LOG(INFO) << "The current device supports float16."; | ||||
| } else { | } else { | ||||
| @@ -17,7 +17,7 @@ endif() | |||||
| add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC}) | add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC}) | ||||
| add_dependencies(cpu_kernel_mid fbs_src) | add_dependencies(cpu_kernel_mid fbs_src) | ||||
| if(PLATFORM_ARM64) | |||||
| if(PLATFORM_ARM) | |||||
| if(ENABLE_FP16) | if(ENABLE_FP16) | ||||
| file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) | file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) | ||||
| if(SUPPORT_TRAIN) | if(SUPPORT_TRAIN) | ||||
| @@ -52,7 +52,7 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { | |||||
| ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, | ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, | ||||
| param_->value_.int32_value_); | param_->value_.int32_value_); | ||||
| break; | break; | ||||
| #ifdef ENABLE_NEON | |||||
| #ifdef ENABLE_FP16 | |||||
| case kNumberTypeFloat16: | case kNumberTypeFloat16: | ||||
| ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride, | ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride, | ||||
| param_->value_.f32_value_); | param_->value_.f32_value_); | ||||
| @@ -31,8 +31,8 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() { | |||||
| matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | ||||
| matmul_param_->col_ = conv_param_->output_channel_; | matmul_param_->col_ = conv_param_->output_channel_; | ||||
| matmul_param_->deep_ = conv_param_->input_channel_; | matmul_param_->deep_ = conv_param_->input_channel_; | ||||
| matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM); | |||||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | |||||
| matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_); | |||||
| matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_); | |||||
| matmul_param_->act_type_ = conv_param_->act_type_; | matmul_param_->act_type_ = conv_param_->act_type_; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -54,14 +54,14 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() { | |||||
| pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || | ||||
| conv_param_->stride_w_ != 1); | conv_param_->stride_w_ != 1); | ||||
| if ((matmul_param_->row_ > (C16NUM * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||||
| if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { | |||||
| multi_thread_by_hw_ = true; | multi_thread_by_hw_ = true; | ||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C16NUM)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, C16NUM), thread_count_) * C16NUM; | |||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_; | |||||
| } else { | } else { | ||||
| multi_thread_by_hw_ = false; | multi_thread_by_hw_ = false; | ||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; | |||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_)); | |||||
| thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_; | |||||
| } | } | ||||
| if (pre_trans_input_) { | if (pre_trans_input_) { | ||||
| @@ -81,8 +81,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||||
| auto output_channel = weight_tensor->Batch(); | auto output_channel = weight_tensor->Batch(); | ||||
| if (in_tensors_.size() == 3) { | if (in_tensors_.size() == 3) { | ||||
| size_t size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); | |||||
| size_t weight_size = output_channel * sizeof(float16_t); | |||||
| size_t size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); | |||||
| size_t bias_size = output_channel * sizeof(float16_t); | |||||
| bias_data_ = malloc(size); | bias_data_ = malloc(size); | ||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; | ||||
| @@ -94,11 +94,11 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||||
| MS_LOG(ERROR) << "Conv1x1 only support fp16 weight"; | MS_LOG(ERROR) << "Conv1x1 only support fp16 weight"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); | |||||
| memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size); | |||||
| } | } | ||||
| size_t size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); | |||||
| size_t down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float16_t); | |||||
| size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); | |||||
| size_t down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float16_t); | |||||
| weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size)); | weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size)); | ||||
| if (weight_ptr_ == nullptr) { | if (weight_ptr_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; | ||||
| @@ -111,6 +111,12 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| int Convolution1x1FP16CPUKernel::Init() { | int Convolution1x1FP16CPUKernel::Init() { | ||||
| col_tile_ = C8NUM; | |||||
| #ifdef ENABLE_ARM64 | |||||
| row_tile_ = C16NUM; | |||||
| #else | |||||
| row_tile_ = C12NUM; | |||||
| #endif | |||||
| matmul_param_ = new (std::nothrow) MatMulParameter(); | matmul_param_ = new (std::nothrow) MatMulParameter(); | ||||
| if (matmul_param_ == nullptr) { | if (matmul_param_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Init matmul_param_ failed."; | MS_LOG(ERROR) << "Init matmul_param_ failed."; | ||||
| @@ -177,8 +183,11 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) { | |||||
| float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; | ||||
| float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; | ||||
| #ifdef ENABLE_ARM64 | |||||
| RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | ||||
| #else | |||||
| RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); | |||||
| #endif | |||||
| float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; | float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; | ||||
| MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_), | MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_), | ||||
| matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_, | matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_, | ||||
| @@ -211,7 +220,7 @@ int Convolution1x1FP16CPUKernel::Run() { | |||||
| ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); | ||||
| pack_input_ = reinterpret_cast<float16_t *>( | pack_input_ = reinterpret_cast<float16_t *>( | ||||
| ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t))); | |||||
| ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float16_t))); | |||||
| if (pack_input_ == nullptr) { | if (pack_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | ||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| @@ -231,7 +240,11 @@ int Convolution1x1FP16CPUKernel::Run() { | |||||
| if (multi_thread_by_hw_) { | if (multi_thread_by_hw_) { | ||||
| ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_); | ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_); | ||||
| } else { | } else { | ||||
| #ifdef ENABLE_ARM64 | |||||
| RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| #else | |||||
| RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||||
| #endif | |||||
| ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_); | ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_); | ||||
| } | } | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| @@ -62,6 +62,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| float16_t *pack_input_ = nullptr; | float16_t *pack_input_ = nullptr; | ||||
| float16_t *output_ptr_ = nullptr; | float16_t *output_ptr_ = nullptr; | ||||
| MatMulParameter *matmul_param_ = nullptr; | MatMulParameter *matmul_param_ = nullptr; | ||||
| int col_tile_; | |||||
| int row_tile_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -34,7 +34,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||||
| int out_channel = filter_tensor->Batch(); | int out_channel = filter_tensor->Batch(); | ||||
| conv_param_->input_channel_ = in_channel; | conv_param_->input_channel_ = in_channel; | ||||
| conv_param_->output_channel_ = out_channel; | conv_param_->output_channel_ = out_channel; | ||||
| int oc8 = UP_ROUND(out_channel, C8NUM); | |||||
| int oc8 = UP_ROUND(out_channel, col_tile_); | |||||
| int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); | ||||
| int pack_weight_size = oc8 * in_channel * kernel_plane; | int pack_weight_size = oc8 * in_channel * kernel_plane; | ||||
| @@ -69,9 +69,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| int ConvolutionFP16CPUKernel::InitTmpBuffer() { | int ConvolutionFP16CPUKernel::InitTmpBuffer() { | ||||
| const int cal_num = 16; | |||||
| int unit_size = | int unit_size = | ||||
| conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * cal_num * thread_count_; | |||||
| conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_; | |||||
| packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); | packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); | ||||
| if (packed_input_ == nullptr) { | if (packed_input_ == nullptr) { | ||||
| @@ -88,6 +87,12 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| int ConvolutionFP16CPUKernel::Init() { | int ConvolutionFP16CPUKernel::Init() { | ||||
| #ifdef ENABLE_ARM64 | |||||
| row_tile_ = C16NUM; | |||||
| #else | |||||
| row_tile_ = C12NUM; | |||||
| #endif | |||||
| col_tile_ = C8NUM; | |||||
| auto ret = InitWeightBias(); | auto ret = InitWeightBias(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "Init weight bias failed."; | MS_LOG(ERROR) << "Init weight bias failed."; | ||||
| @@ -99,7 +104,7 @@ int ConvolutionFP16CPUKernel::Init() { | |||||
| void ConvolutionFP16CPUKernel::AdjustNumberOfThread() { | void ConvolutionFP16CPUKernel::AdjustNumberOfThread() { | ||||
| auto out_tensor = out_tensors_.front(); | auto out_tensor = out_tensors_.front(); | ||||
| int out_plane = out_tensor->Height() * out_tensor->Width(); | int out_plane = out_tensor->Height() * out_tensor->Width(); | ||||
| thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM)); | |||||
| thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, row_tile_)); | |||||
| conv_param_->thread_num_ = thread_count_; | conv_param_->thread_num_ = thread_count_; | ||||
| } | } | ||||
| @@ -62,6 +62,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| float16_t *packed_input_ = nullptr; | float16_t *packed_input_ = nullptr; | ||||
| float16_t *packed_weight_ = nullptr; | float16_t *packed_weight_ = nullptr; | ||||
| float16_t *col_major_input_ = nullptr; | float16_t *col_major_input_ = nullptr; | ||||
| int col_tile_; | |||||
| int row_tile_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -38,13 +38,10 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| int out_channel = filter_tensor->Batch(); | int out_channel = filter_tensor->Batch(); | ||||
| conv_param_->input_channel_ = in_channel; | conv_param_->input_channel_ = in_channel; | ||||
| conv_param_->output_channel_ = out_channel; | conv_param_->output_channel_ = out_channel; | ||||
| const int oc_block = C8NUM; | |||||
| int oc_block_num = UP_DIV(out_channel, C8NUM); | |||||
| int oc_block_num = UP_DIV(out_channel, col_tile_); | |||||
| // init weight | // init weight | ||||
| // set data | // set data | ||||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); | |||||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t); | |||||
| trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); | trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); | ||||
| if (trans_weight_ == nullptr) { | if (trans_weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc trans_weight_ failed."; | MS_LOG(ERROR) << "malloc trans_weight_ failed."; | ||||
| @@ -73,7 +70,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| MS_LOG(ERROR) << "get execute filter failed."; | MS_LOG(ERROR) << "get execute filter failed."; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block); | |||||
| ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, col_tile_); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "winograd filter transform failed."; | MS_LOG(ERROR) << "winograd filter transform failed."; | ||||
| return ret; | return ret; | ||||
| @@ -85,12 +82,12 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| // init bias | // init bias | ||||
| bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t)); | |||||
| bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t)); | |||||
| if (bias_data_ == nullptr) { | if (bias_data_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc bias_data_ failed."; | MS_LOG(ERROR) << "malloc bias_data_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t)); | |||||
| memset(bias_data_, 0, oc_block_num * col_tile_ * sizeof(float16_t)); | |||||
| if (in_tensors_.size() == kInputSize2) { | if (in_tensors_.size() == kInputSize2) { | ||||
| if (origin_bias_data_type_ == kNumberTypeFloat16) { | if (origin_bias_data_type_ == kNumberTypeFloat16) { | ||||
| @@ -106,11 +103,9 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | ||||
| const int cal_num = 16; | |||||
| int channel_out = conv_param_->output_channel_; | int channel_out = conv_param_->output_channel_; | ||||
| size_t tile_buffer_size = | size_t tile_buffer_size = | ||||
| thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); | |||||
| thread_count_ * row_tile_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); | |||||
| trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); | trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); | ||||
| if (trans_input_ == nullptr) { | if (trans_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | MS_LOG(ERROR) << "malloc trans_input_ failed."; | ||||
| @@ -118,7 +113,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc( | gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc( | ||||
| thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); | |||||
| thread_count_ * row_tile_ * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); | |||||
| if (gemm_out_ == nullptr) { | if (gemm_out_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc gemm_out_ failed."; | MS_LOG(ERROR) << "malloc gemm_out_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -132,7 +127,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| col_buffer_ = reinterpret_cast<float16_t *>( | col_buffer_ = reinterpret_cast<float16_t *>( | ||||
| ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t))); | |||||
| ctx_->allocator->Malloc(thread_count_ * row_tile_ * conv_param_->input_channel_ * sizeof(float16_t))); | |||||
| if (col_buffer_ == nullptr) { | if (col_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc col_buffer_ failed."; | MS_LOG(ERROR) << "malloc col_buffer_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -160,6 +155,12 @@ int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() { | |||||
| } | } | ||||
| int ConvolutionWinogradFP16CPUKernel::Init() { | int ConvolutionWinogradFP16CPUKernel::Init() { | ||||
| col_tile_ = C8NUM; | |||||
| #ifdef ENABLE_ARM64 | |||||
| row_tile_ = C16NUM; | |||||
| #else | |||||
| row_tile_ = C12NUM; | |||||
| #endif | |||||
| kernel_unit_ = conv_param_->kernel_h_; | kernel_unit_ = conv_param_->kernel_h_; | ||||
| input_unit_ = output_unit_ + kernel_unit_ - 1; | input_unit_ = output_unit_ + kernel_unit_ - 1; | ||||
| conv_param_->input_unit_ = input_unit_; | conv_param_->input_unit_ = input_unit_; | ||||
| @@ -86,6 +86,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| TmpBufferAddressFp16 tmp_buffer_address_list_[4]; | TmpBufferAddressFp16 tmp_buffer_address_list_[4]; | ||||
| InputTransFp16Func in_func_; | InputTransFp16Func in_func_; | ||||
| OutputTransFp16Func out_func_; | OutputTransFp16Func out_func_; | ||||
| int col_tile_; | |||||
| int row_tile_; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -13,26 +13,20 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_ARM | |||||
| #include <arm_neon.h> | #include <arm_neon.h> | ||||
| #endif | #endif | ||||
| #include "nnacl/fp16/cast_fp16.h" | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_ARM64 | |||||
| extern void Float32ToFloat16(const float *input, float16_t *output, int number); | |||||
| extern void Float16ToFloat32(const float16_t *input, float *output, int number); | |||||
| inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { | |||||
| static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { | |||||
| Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number); | Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number); | ||||
| } | } | ||||
| inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { | |||||
| static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { | |||||
| Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number); | Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number); | ||||
| } | } | ||||
| #endif | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -53,8 +53,8 @@ int PoolingFp16CPUKernel::ReSize() { | |||||
| } | } | ||||
| int PoolingFp16CPUKernel::RunImpl(int task_id) { | int PoolingFp16CPUKernel::RunImpl(int task_id) { | ||||
| float16_t minf = -FLT_MAX; | |||||
| float16_t maxf = FLT_MAX; | |||||
| float16_t minf = -FLT16_MAX; | |||||
| float16_t maxf = FLT16_MAX; | |||||
| if (pooling_param_->act_type_ == ActType_Relu) { | if (pooling_param_->act_type_ == ActType_Relu) { | ||||
| minf = 0.f; | minf = 0.f; | ||||
| } else if (pooling_param_->act_type_ == ActType_Relu6) { | } else if (pooling_param_->act_type_ == ActType_Relu6) { | ||||
| @@ -45,7 +45,7 @@ | |||||
| #include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h" | #include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h" | ||||
| #include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h" | #include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h" | ||||
| #endif | #endif | ||||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||||
| #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | |||||
| #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | ||||
| #endif | #endif | ||||
| @@ -230,7 +230,7 @@ int CastConstTensorData(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_o | |||||
| auto origin_data = tensor->data_c(); | auto origin_data = tensor->data_c(); | ||||
| MS_ASSERT(origin_data != nullptr); | MS_ASSERT(origin_data != nullptr); | ||||
| if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) { | if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) { | ||||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||||
| #if defined(ENABLE_ARM) && defined(ENABLE_FP16) | |||||
| auto restore_tensor = Tensor::CopyTensor(*tensor, false); | auto restore_tensor = Tensor::CopyTensor(*tensor, false); | ||||
| restore_tensor->set_data(origin_data); | restore_tensor->set_data(origin_data); | ||||
| restore_tensor->set_own_data(tensor->own_data()); | restore_tensor->set_own_data(tensor->own_data()); | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include "src/sub_graph_kernel.h" | #include "src/sub_graph_kernel.h" | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/tensorlist.h" | #include "src/tensorlist.h" | ||||
| #if defined(ENABLE_ARM64) && defined(ENABLE_FP16) | |||||
| #ifdef ENABLE_FP16 | |||||
| #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" | ||||
| #endif | #endif | ||||
| #include "src/common/version_manager.h" | #include "src/common/version_manager.h" | ||||
| @@ -283,7 +283,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) { | |||||
| } | } | ||||
| int CpuFp16SubGraph::PreProcess() { | int CpuFp16SubGraph::PreProcess() { | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_FP16 | |||||
| if (!mindspore::lite::IsSupportFloat16()) { | if (!mindspore::lite::IsSupportFloat16()) { | ||||
| MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -347,7 +347,7 @@ int CpuFp16SubGraph::PreProcess() { | |||||
| } | } | ||||
| int CpuFp16SubGraph::PostProcess() { | int CpuFp16SubGraph::PostProcess() { | ||||
| #ifdef ENABLE_ARM64 | |||||
| #ifdef ENABLE_FP16 | |||||
| if (!mindspore::lite::IsSupportFloat16()) { | if (!mindspore::lite::IsSupportFloat16()) { | ||||
| MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | MS_LOG(ERROR) << "Unsupported fp16 in this devices"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -378,8 +378,11 @@ add_dependencies(lite-test fbs_src) | |||||
| target_link_libraries(lite-test dl mindspore::gtest) | target_link_libraries(lite-test dl mindspore::gtest) | ||||
| if(PLATFORM_ARM64 AND ENABLE_FP16) | |||||
| target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) | |||||
| if(PLATFORM_ARM AND ENABLE_FP16) | |||||
| target_link_libraries(lite-test nnacl_fp16_mid) | |||||
| if(PLATFORM_ARM64) | |||||
| target_link_libraries(lite-test nnacl_optimize_mid) | |||||
| endif() | |||||
| endif() | endif() | ||||
| if(PLATFORM_ARM) | if(PLATFORM_ARM) | ||||