| @@ -0,0 +1,248 @@ | |||||
| .text | |||||
| .align 5 | |||||
| //.p2align 5,,15 | |||||
| .global PostFuncBiasReluC4 | |||||
| #ifndef __APPLE__ | |||||
| .type PostFuncBiasReluC4, %function | |||||
| #endif | |||||
| //void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | |||||
| // size_t plane_size, size_t plane_stride, size_t relu_type); | |||||
| // Adds the per-channel bias to C4-packed src, optionally clamps the result (relu_type == 1: max with 0, | |||||
| // relu_type == 3: min with 6 then max with 0, anything else: no clamp) and stores it to dst, where | |||||
| // consecutive pixels are (oc4div + oc4mod) * 4 bytes apart. | |||||
| // r0 dst r1 src r2 bias r3 oc4div (passed in registers) | |||||
| // r4 oc4mod r5 plane_size (loaded from the stack) | |||||
| // r6 plane_stride r7 relu_type (loaded from the stack) | |||||
| // r6 is a byte count that is added to src after every group of 4 channels | |||||
| // q0 ~ q3 value | |||||
| // q12 bias data | |||||
| // q14 relu6 #6.0; q15 relu #0.0 | |||||
| // r12 dst stride in bytes between consecutive pixels | |||||
| // r10 r11 write loop tmp buf | |||||
| // lr oc4 loop control | |||||
| // r8 hw loop control | |||||
| PostFuncBiasReluC4: | |||||
| push {r4-r8, r10, r11, lr} | |||||
| // Eight registers (32 bytes) were pushed, so the stacked arguments start at sp + 32. | |||||
| // sp itself is not moved: the saved registers must stay at or above sp for the whole function. | |||||
| ldr r4, [sp, #32] | |||||
| ldr r5, [sp, #36] | |||||
| ldr r6, [sp, #40] | |||||
| ldr r7, [sp, #44] | |||||
| // q14 = 6.0 in every lane (integer 6 converted to float), q15 = 0.0 | |||||
| vmov.i32 q14, #6 | |||||
| vcvt.f32.s32 q14, q14 | |||||
| veor q15, q15, q15 | |||||
| // r12 = (oc4div + oc4mod) * 4 | |||||
| mov lr, #4 | |||||
| add r12, r3, r4 | |||||
| mul r12, r12, lr | |||||
| mov lr, #0 | |||||
| // One iteration per group of 4 channels; lr counts the channels already handled. | |||||
| Loop_C4: | |||||
| cmp lr, r3 | |||||
| beq Loop_C1 | |||||
| // r11 = dst + lr * 4 is the write pointer for this channel group | |||||
| mov r11, #4 | |||||
| mul r10, lr, r11 | |||||
| add r11, r0, r10 | |||||
| add lr, lr, #4 | |||||
| mov r8, r5 | |||||
| vld1.32 {q12}, [r2]! | |||||
| // Four pixels per iteration while at least four remain. | |||||
| Loop_4x4: | |||||
| cmp r8, #4 | |||||
| blt Loop_1x4 | |||||
| sub r8, r8, #4 | |||||
| vld1.32 {q0-q1}, [r1]! | |||||
| vld1.32 {q2-q3}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vadd.f32 q1, q1, q12 | |||||
| vadd.f32 q2, q2, q12 | |||||
| vadd.f32 q3, q3, q12 | |||||
| cmp r7, #3 | |||||
| beq Relu6_4x4 | |||||
| cmp r7, #1 | |||||
| beq Relu_4x4 | |||||
| b Write_4x4 | |||||
| // The relu6 path falls through into the relu path: min with 6, then max with 0. | |||||
| Relu6_4x4: | |||||
| vmin.f32 q0, q0, q14 | |||||
| vmin.f32 q1, q1, q14 | |||||
| vmin.f32 q2, q2, q14 | |||||
| vmin.f32 q3, q3, q14 | |||||
| Relu_4x4: | |||||
| vmax.f32 q0, q0, q15 | |||||
| vmax.f32 q1, q1, q15 | |||||
| vmax.f32 q2, q2, q15 | |||||
| vmax.f32 q3, q3, q15 | |||||
| Write_4x4: | |||||
| vst1.32 {q0}, [r11], r12 | |||||
| vst1.32 {q1}, [r11], r12 | |||||
| vst1.32 {q2}, [r11], r12 | |||||
| vst1.32 {q3}, [r11], r12 | |||||
| b Loop_4x4 | |||||
| // Remaining pixels one at a time; the relu_type dispatch is done once, outside the loops. | |||||
| Loop_1x4: | |||||
| cmp r7, #3 | |||||
| beq Relu6_1x4 | |||||
| cmp r7, #1 | |||||
| beq Relu_1x4 | |||||
| b Write_1x4 | |||||
| Relu6_1x4: | |||||
| cmp r8, #0 | |||||
| beq HW_Add | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmin.f32 q0, q0, q14 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {q0}, [r11], r12 | |||||
| b Relu6_1x4 | |||||
| Relu_1x4: | |||||
| cmp r8, #0 | |||||
| beq HW_Add | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {q0}, [r11], r12 | |||||
| b Relu_1x4 | |||||
| Write_1x4: | |||||
| cmp r8, #0 | |||||
| beq HW_Add | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vst1.32 {q0}, [r11], r12 | |||||
| b Write_1x4 | |||||
| // Skip r6 bytes of src before the next channel group. | |||||
| HW_Add: | |||||
| add r1, r1, r6 | |||||
| b Loop_C4 | |||||
| // Tail: the last oc4mod (1, 2 or 3) channels. src still holds 4 floats per pixel, | |||||
| // but only the first oc4mod of them are stored. | |||||
| Loop_C1: | |||||
| cmp r4, #0 | |||||
| beq End | |||||
| mov r8, r5 | |||||
| vld1.32 {q12}, [r2]! | |||||
| // r0 = dst + lr * 4 | |||||
| mov r11, #4 | |||||
| mul r10, lr, r11 | |||||
| add r0, r0, r10 | |||||
| cmp r4, #1 | |||||
| beq Loop_C1_1 | |||||
| cmp r4, #2 | |||||
| beq Loop_C1_2 | |||||
| cmp r4, #3 | |||||
| beq Loop_C1_3 | |||||
| // Any other non-zero oc4mod falls through to the single-channel tail. | |||||
| Loop_C1_1: | |||||
| cmp r7, #3 | |||||
| beq Loop_C1_1_Relu6 | |||||
| cmp r7, #1 | |||||
| beq Loop_C1_1_Relu | |||||
| b Loop_C1_1_Write | |||||
| Loop_C1_1_Relu6: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmin.f32 q0, q0, q14 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0[0]}, [r0], r12 | |||||
| b Loop_C1_1_Relu6 | |||||
| Loop_C1_1_Relu: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0[0]}, [r0], r12 | |||||
| b Loop_C1_1_Relu | |||||
| Loop_C1_1_Write: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vst1.32 {d0[0]}, [r0], r12 | |||||
| b Loop_C1_1_Write | |||||
| Loop_C1_2: | |||||
| cmp r7, #3 | |||||
| beq Loop_C1_2_Relu6 | |||||
| cmp r7, #1 | |||||
| beq Loop_C1_2_Relu | |||||
| b Loop_C1_2_Write | |||||
| Loop_C1_2_Relu6: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmin.f32 q0, q0, q14 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| b Loop_C1_2_Relu6 | |||||
| Loop_C1_2_Relu: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| b Loop_C1_2_Relu | |||||
| Loop_C1_2_Write: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| b Loop_C1_2_Write | |||||
| // Three channels: r0 writes floats 0-1 and r11 = r0 + 8 writes float 2. | |||||
| // Both pointers must advance by the dst pixel stride r12 so they stay 8 bytes apart. | |||||
| Loop_C1_3: | |||||
| add r11, r0, #8 | |||||
| cmp r7, #3 | |||||
| beq Loop_C1_3_Relu6 | |||||
| cmp r7, #1 | |||||
| beq Loop_C1_3_Relu | |||||
| b Loop_C1_3_Write | |||||
| Loop_C1_3_Relu6: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmin.f32 q0, q0, q14 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| vst1.32 {d1[0]}, [r11], r12 | |||||
| b Loop_C1_3_Relu6 | |||||
| Loop_C1_3_Relu: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vmax.f32 q0, q0, q15 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| vst1.32 {d1[0]}, [r11], r12 | |||||
| b Loop_C1_3_Relu | |||||
| Loop_C1_3_Write: | |||||
| cmp r8, #0 | |||||
| beq End | |||||
| sub r8, r8, #1 | |||||
| vld1.32 {q0}, [r1]! | |||||
| vadd.f32 q0, q0, q12 | |||||
| vst1.32 {d0}, [r0], r12 | |||||
| vst1.32 {d1[0]}, [r11], r12 | |||||
| b Loop_C1_3_Write | |||||
| End: | |||||
| pop {r4-r8, r10, r11, pc} | |||||
| @@ -0,0 +1,305 @@ | |||||
| #ifdef __aarch64__ | |||||
| .text | |||||
| .align 5 | |||||
| //.p2align 5,,15 | |||||
| .global PostFuncBiasReluC4 | |||||
| #ifndef __APPLE__ | |||||
| .type PostFuncBiasReluC4, %function | |||||
| #endif | |||||
| //void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | |||||
| // size_t plane_size, size_t plane_stride, size_t relu_type); | |||||
| // Adds the per-channel bias to C4-packed src, optionally clamps the result (relu_type == 1: max with 0, | |||||
| // relu_type == 3: min with 6 then max with 0, anything else: no clamp) and stores it to dst, where | |||||
| // consecutive pixels are (oc4div + oc4mod) * 4 bytes apart. | |||||
| // x0 dst x1 src x2 bias | |||||
| // w3 oc4div w4 oc4mod w5 plane_size | |||||
| // x6 plane_stride x7 relu_type | |||||
| // x6 is a byte count that is added to src after every group of 4 channels | |||||
| // v0 ~ v7 value | |||||
| // v16 bias data | |||||
| // x12 oc_stride | |||||
| // x14 x15 write loop tmp buf | |||||
| // v26 relu6 #6; v27 relu #0 | |||||
| // w10 oc4 loop control | |||||
| // w13 hw loop control | |||||
| PostFuncBiasReluC4: | |||||
| // v26 = 6.0 in every lane (integer 6 converted to float), v27 = 0.0 | |||||
| movi v26.4s, #6 | |||||
| scvtf v26.4s, v26.4s | |||||
| dup v27.4s, wzr | |||||
| // x12 = (oc4div + oc4mod) * 4 | |||||
| mov x10, #4 | |||||
| add x12, x3, x4 | |||||
| mul x12, x12, x10 | |||||
| mov w10, #0 | |||||
| // One iteration per group of 4 channels; w10 counts the channels already handled. | |||||
| Loop_C4: | |||||
| cmp w10, w3 | |||||
| beq Loop_C1 | |||||
| // x15 = dst + w10 * 4 is the write pointer for this channel group | |||||
| mov x15, #4 | |||||
| mul x14, x10, x15 | |||||
| add x15, x0, x14 | |||||
| add w10, w10, #4 | |||||
| mov w13, w5 | |||||
| ld1 {v16.4s}, [x2], #16 | |||||
| // Eight pixels per iteration while at least eight remain. | |||||
| Loop_8x4: | |||||
| cmp w13, #8 | |||||
| blt Loop_4x4 | |||||
| sub w13, w13, #8 | |||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | |||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fadd v1.4s, v1.4s, v16.4s | |||||
| fadd v2.4s, v2.4s, v16.4s | |||||
| fadd v3.4s, v3.4s, v16.4s | |||||
| fadd v4.4s, v4.4s, v16.4s | |||||
| fadd v5.4s, v5.4s, v16.4s | |||||
| fadd v6.4s, v6.4s, v16.4s | |||||
| fadd v7.4s, v7.4s, v16.4s | |||||
| cmp x7, #3 | |||||
| beq Relu6_8x4 | |||||
| cmp x7, #1 | |||||
| beq Relu_8x4 | |||||
| b Write_8x4 | |||||
| // The relu6 path falls through into the relu path: min with 6, then max with 0. | |||||
| Relu6_8x4: | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmin v1.4s, v1.4s, v26.4s | |||||
| fmin v2.4s, v2.4s, v26.4s | |||||
| fmin v3.4s, v3.4s, v26.4s | |||||
| fmin v4.4s, v4.4s, v26.4s | |||||
| fmin v5.4s, v5.4s, v26.4s | |||||
| fmin v6.4s, v6.4s, v26.4s | |||||
| fmin v7.4s, v7.4s, v26.4s | |||||
| Relu_8x4: | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| fmax v1.4s, v1.4s, v27.4s | |||||
| fmax v2.4s, v2.4s, v27.4s | |||||
| fmax v3.4s, v3.4s, v27.4s | |||||
| fmax v4.4s, v4.4s, v27.4s | |||||
| fmax v5.4s, v5.4s, v27.4s | |||||
| fmax v6.4s, v6.4s, v27.4s | |||||
| fmax v7.4s, v7.4s, v27.4s | |||||
| Write_8x4: | |||||
| st1 {v0.4s}, [x15], x12 | |||||
| st1 {v1.4s}, [x15], x12 | |||||
| st1 {v2.4s}, [x15], x12 | |||||
| st1 {v3.4s}, [x15], x12 | |||||
| st1 {v4.4s}, [x15], x12 | |||||
| st1 {v5.4s}, [x15], x12 | |||||
| st1 {v6.4s}, [x15], x12 | |||||
| st1 {v7.4s}, [x15], x12 | |||||
| b Loop_8x4 | |||||
| // Fewer than eight pixels remain here, so this four-pixel step runs at most once | |||||
| // and then falls through into Loop_1x4 instead of branching back. | |||||
| Loop_4x4: | |||||
| cmp w13, #4 | |||||
| blt Loop_1x4 | |||||
| sub w13, w13, #4 | |||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fadd v1.4s, v1.4s, v16.4s | |||||
| fadd v2.4s, v2.4s, v16.4s | |||||
| fadd v3.4s, v3.4s, v16.4s | |||||
| cmp x7, #3 | |||||
| beq Relu6_4x4 | |||||
| cmp x7, #1 | |||||
| beq Relu_4x4 | |||||
| b Write_4x4 | |||||
| Relu6_4x4: | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmin v1.4s, v1.4s, v26.4s | |||||
| fmin v2.4s, v2.4s, v26.4s | |||||
| fmin v3.4s, v3.4s, v26.4s | |||||
| Relu_4x4: | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| fmax v1.4s, v1.4s, v27.4s | |||||
| fmax v2.4s, v2.4s, v27.4s | |||||
| fmax v3.4s, v3.4s, v27.4s | |||||
| Write_4x4: | |||||
| st1 {v0.4s}, [x15], x12 | |||||
| st1 {v1.4s}, [x15], x12 | |||||
| st1 {v2.4s}, [x15], x12 | |||||
| st1 {v3.4s}, [x15], x12 | |||||
| // Remaining pixels one at a time; the relu_type dispatch is done once, outside the loops. | |||||
| Loop_1x4: | |||||
| cmp x7, #3 | |||||
| beq Relu6_1x4 | |||||
| cmp x7, #1 | |||||
| beq Relu_1x4 | |||||
| b Write_1x4 | |||||
| Relu6_1x4: | |||||
| cmp w13, #0 | |||||
| beq HW_Add | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| st1 {v0.4s}, [x15], x12 | |||||
| b Relu6_1x4 | |||||
| Relu_1x4: | |||||
| cmp w13, #0 | |||||
| beq HW_Add | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| st1 {v0.4s}, [x15], x12 | |||||
| b Relu_1x4 | |||||
| Write_1x4: | |||||
| cmp w13, #0 | |||||
| beq HW_Add | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| st1 {v0.4s}, [x15], x12 | |||||
| b Write_1x4 | |||||
| // Skip x6 bytes of src before the next channel group. | |||||
| HW_Add: | |||||
| add x1, x1, x6 | |||||
| b Loop_C4 | |||||
| // Tail: the last oc4mod (1, 2 or 3) channels. src still holds 4 floats per pixel, | |||||
| // but only the first oc4mod of them are stored. | |||||
| Loop_C1: | |||||
| cmp x4, #0 | |||||
| beq End | |||||
| mov w13, w5 | |||||
| ld1 {v16.4s}, [x2], #16 | |||||
| // x0 = dst + w10 * 4 | |||||
| mov x15, #4 | |||||
| mul x14, x10, x15 | |||||
| add x0, x0, x14 | |||||
| cmp x4, #1 | |||||
| beq Loop_C1_1 | |||||
| cmp x4, #2 | |||||
| beq Loop_C1_2 | |||||
| cmp x4, #3 | |||||
| beq Loop_C1_3 | |||||
| // Any other non-zero oc4mod falls through to the single-channel tail. | |||||
| Loop_C1_1: | |||||
| cmp x7, #3 | |||||
| beq Loop_C1_1_Relu6 | |||||
| cmp x7, #1 | |||||
| beq Loop_C1_1_Relu | |||||
| b Loop_C1_1_Write | |||||
| Loop_C1_1_Relu6: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| str s0, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_1_Relu6 | |||||
| Loop_C1_1_Relu: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| str s0, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_1_Relu | |||||
| Loop_C1_1_Write: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| str s0, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_1_Write | |||||
| // Two channels: lane 1 is copied into s1 so that lanes 0 and 1 can be stored as a pair. | |||||
| Loop_C1_2: | |||||
| cmp x7, #3 | |||||
| beq Loop_C1_2_Relu6 | |||||
| cmp x7, #1 | |||||
| beq Loop_C1_2_Relu | |||||
| b Loop_C1_2_Write | |||||
| Loop_C1_2_Relu6: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_2_Relu6 | |||||
| Loop_C1_2_Relu: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_2_Relu | |||||
| Loop_C1_2_Write: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| b Loop_C1_2_Write | |||||
| // Three channels: x0 writes floats 0-1 and x15 = x0 + 8 writes float 2. | |||||
| // Both pointers advance by x12 per pixel. | |||||
| Loop_C1_3: | |||||
| add x15, x0, #8 | |||||
| cmp x7, #3 | |||||
| beq Loop_C1_3_Relu6 | |||||
| cmp x7, #1 | |||||
| beq Loop_C1_3_Relu | |||||
| b Loop_C1_3_Write | |||||
| Loop_C1_3_Relu6: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmin v0.4s, v0.4s, v26.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| st1 {v0.s}[2], [x15], x12 | |||||
| b Loop_C1_3_Relu6 | |||||
| Loop_C1_3_Relu: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| fmax v0.4s, v0.4s, v27.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| st1 {v0.s}[2], [x15], x12 | |||||
| b Loop_C1_3_Relu | |||||
| Loop_C1_3_Write: | |||||
| cmp w13, #0 | |||||
| beq End | |||||
| sub w13, w13, #1 | |||||
| ld1 {v0.4s}, [x1], #16 | |||||
| fadd v0.4s, v0.4s, v16.4s | |||||
| dup s1, v0.s[1] | |||||
| stp s0, s1, [x0] | |||||
| add x0, x0, x12 | |||||
| st1 {v0.s}[2], [x15], x12 | |||||
| b Loop_C1_3_Write | |||||
| End: | |||||
| ret | |||||
| #endif | |||||
| @@ -89,7 +89,6 @@ typedef struct DeConvWg { | |||||
| // Scratch buffers for the deconvolution Winograd A-transform. Instances live in | |||||
| // deconv_param->a_buffer_[], which InitParameter indexes by the unit's winograd_.kh_. | |||||
| // NOTE(review): the trans_formed_ row below appears in one column of this diff only; | |||||
| // the compute functions now receive a `bool *transfered` array in its place. | |||||
| typedef struct DeConvWgABuffer { | typedef struct DeConvWgABuffer { | ||||
| bool buf_init_; | bool buf_init_; | ||||
| bool trans_formed_; | |||||
| void *middle_buffer_; | void *middle_buffer_; | ||||
| void *dest_buffer_; | void *dest_buffer_; | ||||
| } DeConvWgABuffer; | } DeConvWgABuffer; | ||||
| @@ -79,15 +79,16 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, | |||||
| } | } | ||||
| void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf, | void _deConvWinogradFp16(float16_t *tile_in, float16_t *tile_out, float16_t *weight_buf, float16_t *tmp_buf, | ||||
| float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool a_trans, | |||||
| float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transfered, | |||||
| float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, | float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, | ||||
| ConvParameter *conv_param, DeConvParam *deconv_param) { | ConvParameter *conv_param, DeConvParam *deconv_param) { | ||||
| int winograd_plane = unit_size * unit_size; | int winograd_plane = unit_size * unit_size; | ||||
| if (!a_trans) { | |||||
| // Run the A-transform only when it has not been done yet for this unit_size. | |||||
| if (!transfered[unit_size]) { | |||||
| WinogradMatrixProductLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, | WinogradMatrixProductLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, | ||||
| DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | ||||
| WinogradMatrixProductRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, | WinogradMatrixProductRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, | ||||
| deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | ||||
| // Record completion so the guard above skips the transform on the next call with this unit_size. | |||||
| transfered[unit_size] = true; | |||||
| } | } | ||||
| for (int index = 0; index < winograd_plane; index++) { | for (int index = 0; index < winograd_plane; index++) { | ||||
| @@ -265,6 +266,7 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou | |||||
| } | } | ||||
| /* compute */ | /* compute */ | ||||
| bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; | |||||
| for (int i = 0; i < deconv_param->compute_size_; i++) { | for (int i = 0; i < deconv_param->compute_size_; i++) { | ||||
| DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; | DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; | ||||
| if (unit->use_winograd_) { | if (unit->use_winograd_) { | ||||
| @@ -281,9 +283,8 @@ void DeconvWgFp16(float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_ou | |||||
| DECONV_WINOGRAD_DEFAULT_TILE * | DECONV_WINOGRAD_DEFAULT_TILE * | ||||
| deconv_param->oc_up4_; | deconv_param->oc_up4_; | ||||
| _deConvWinogradFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, | _deConvWinogradFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, | ||||
| tmp_a->trans_formed_, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, | |||||
| unit->h_start_, conv_param, deconv_param); | |||||
| tmp_a->trans_formed_ = true; | |||||
| transfered, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, | |||||
| conv_param, deconv_param); | |||||
| } else { | } else { | ||||
| float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * | float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * | ||||
| unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; | unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; | ||||
| @@ -56,8 +56,15 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi | |||||
| // Post-processes a C4-packed convolution result into out_ptr using bias_ptr and relu_type. | |||||
| // With ENABLE_ARM the work is done by PostFuncBiasReluC4; otherwise by PostConvFuncComm. | |||||
| void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, | ||||
| size_t plane_size, size_t plane_stride, size_t relu_type) { | size_t plane_size, size_t plane_stride, size_t relu_type) { | ||||
| #ifdef ENABLE_ARM | |||||
| // oc4mod: channels left over after the last full group of C4NUM. | |||||
| size_t oc4mod = output_channel % C4NUM; | |||||
| // oc4div: channel count rounded down to a multiple of C4NUM (a channel count, not a group count). | |||||
| size_t oc4div = output_channel - oc4mod; | |||||
| // Padding between planes, converted from pixels to bytes of C4-packed floats. | |||||
| size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); | |||||
| PostFuncBiasReluC4(out_ptr, c4_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type); | |||||
| #else | |||||
| PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, | PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, | ||||
| C4NUM); | C4NUM); | ||||
| #endif | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -53,6 +53,8 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f | |||||
| size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); | size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); | ||||
| void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, | ||||
| size_t plane_size, size_t stride, size_t relu_type); | size_t plane_size, size_t stride, size_t relu_type); | ||||
| // C4 counterpart of PostFuncBiasReluC8. PostConvFuncFp32C4 passes oc4div as the channel count rounded | |||||
| // down to a multiple of 4, oc4mod as the remaining channels, and plane_stride as a byte count. | |||||
| void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, | |||||
| size_t plane_size, size_t plane_stride, size_t relu_type); | |||||
| #endif | #endif | ||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| @@ -109,7 +109,11 @@ void DeConvWgInputPack(float *src_ptr, float *dst_ptr, int channel, int stride) | |||||
| float *dst = dst_ptr; | float *dst = dst_ptr; | ||||
| for (int ic = 0; ic < ic4div; ic++) { | for (int ic = 0; ic < ic4div; ic++) { | ||||
| #ifdef ENABLE_ARM | |||||
| vst1q_f32(dst, vld1q_f32(src)); | |||||
| #else | |||||
| memcpy(dst, src, C4NUM * sizeof(float)); | memcpy(dst, src, C4NUM * sizeof(float)); | ||||
| #endif | |||||
| dst += stride; | dst += stride; | ||||
| src += C4NUM; | src += C4NUM; | ||||
| } | } | ||||
| @@ -159,25 +163,27 @@ void MSGemmFloatUnit_4(float *dstOrigin, const float *src, const float *weight, | |||||
| weight_depth_offset); | weight_depth_offset); | ||||
| } | } | ||||
| // Accumulates `count` groups of 4 floats: for each i, the 4 floats at src + i * src_stride | |||||
| // are added element-wise into the 4 floats at dst + i * dst_stride. | |||||
| // NOTE(review): the loop index is int while count is size_t - confirm count always fits in int. | |||||
| void DeConvWgMerge(const float *source, float *dest, size_t srcStride, size_t dstStride, size_t count) { | |||||
| void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { | |||||
| for (int i = 0; i < count; ++i) { | for (int i = 0; i < count; ++i) { | ||||
| const float *s = source + i * srcStride; | |||||
| float *d = dest + i * dstStride; | |||||
| const float *s = src + i * src_stride; | |||||
| float *d = dst + i * dst_stride; | |||||
| for (int j = 0; j < 4; ++j) { | for (int j = 0; j < 4; ++j) { | ||||
| d[j] += s[j]; | d[j] += s[j]; | ||||
| } | } | ||||
| } | } | ||||
| return; | |||||
| } | } | ||||
| void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float *tmp_buf, float *at_buf, | void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float *tmp_buf, float *at_buf, | ||||
| float *a_mid_buf, float *trans_a_buf, bool a_trans, float *bt_buf, float *b_tmp_buf, int unit_size, | |||||
| int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { | |||||
| float *a_mid_buf, float *trans_a_buf, bool *transfered, float *bt_buf, float *b_tmp_buf, | |||||
| int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { | |||||
| int winograd_plane = unit_size * unit_size; | int winograd_plane = unit_size * unit_size; | ||||
| if (!a_trans) { | |||||
| if (!transfered[unit_size]) { | |||||
| WinogradMatrixProductLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, | WinogradMatrixProductLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, | ||||
| DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | ||||
| WinogradMatrixProductRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, | WinogradMatrixProductRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, | ||||
| deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); | ||||
| transfered[unit_size] = true; | |||||
| } | } | ||||
| for (int index = 0; index < winograd_plane; index++) { | for (int index = 0; index < winograd_plane; index++) { | ||||
| @@ -274,6 +280,7 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind | |||||
| } | } | ||||
| /* compute */ | /* compute */ | ||||
| bool transfered[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; | |||||
| for (int i = 0; i < deconv_param->compute_size_; i++) { | for (int i = 0; i < deconv_param->compute_size_; i++) { | ||||
| DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; | DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; | ||||
| if (unit->use_winograd_) { | if (unit->use_winograd_) { | ||||
| @@ -289,9 +296,8 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind | |||||
| float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * | float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * | ||||
| deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE; | deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE; | ||||
| _deConvWinograd(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, | _deConvWinograd(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, | ||||
| wg_dst_a_buf, wg_buf->trans_formed_, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, | |||||
| unit->w_start_, unit->h_start_, conv_param, deconv_param); | |||||
| wg_buf->trans_formed_ = true; | |||||
| wg_dst_a_buf, transfered, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_, | |||||
| unit->h_start_, conv_param, deconv_param); | |||||
| } else { | } else { | ||||
| float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * unit->h_size_ * | float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div4_ * unit->w_size_ * unit->h_size_ * | ||||
| DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; | DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; | ||||
| @@ -75,7 +75,6 @@ int DeConvWinogradFp16CPUKernel::InitParameter() { | |||||
| if (unit.use_winograd_) { | if (unit.use_winograd_) { | ||||
| if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { | if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; | deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false; | |||||
| size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; | size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = | deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = | ||||
| @@ -111,9 +110,6 @@ int DeConvWinogradFp16CPUKernel::DoDeconv(int task_id) { | |||||
| int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, | int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, | ||||
| deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); | deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); | ||||
| for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { | |||||
| deconv_param_->a_buffer_[i].trans_formed_ = false; | |||||
| } | |||||
| DeconvWgFp16(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); | DeconvWgFp16(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); | ||||
| std::unique_lock<std::mutex> merge_lock(lock_); | std::unique_lock<std::mutex> merge_lock(lock_); | ||||
| @@ -138,7 +138,6 @@ int DeConvolutionWinogradCPUKernel::InitParameter() { | |||||
| if (unit.use_winograd_) { | if (unit.use_winograd_) { | ||||
| if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { | if (deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ == false) { | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; | deconv_param_->a_buffer_[unit.winograd_.kh_].buf_init_ = true; | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].trans_formed_ = false; | |||||
| size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; | size = unit.winograd_.kh_ * unit.winograd_.kw_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param_->ic_up4_; | ||||
| deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = | deconv_param_->a_buffer_[unit.winograd_.kh_].middle_buffer_ = | ||||
| @@ -308,9 +307,6 @@ int DeConvolutionWinogradCPUKernel::DoDeconv(int task_id) { | |||||
| int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, | int calculate_count = MSMIN(DECONV_WINOGRAD_DEFAULT_TILE, | ||||
| deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); | deconv_param_->in_tile_w_count_ * deconv_param_->in_tile_h_count_ - start_index); | ||||
| for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { | |||||
| deconv_param_->a_buffer_[i].trans_formed_ = false; | |||||
| } | |||||
| DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); | DeconvWg(nhwc_input_, tile_in, tile_out, start_index, calculate_count, conv_param_, deconv_param_, task_id); | ||||
| std::unique_lock<std::mutex> merge_lock(lock_); | std::unique_lock<std::mutex> merge_lock(lock_); | ||||