| @@ -0,0 +1,633 @@ | |||||
| #ifdef __aarch64__ | |||||
| .text | |||||
| .align 5 | |||||
| .global AdderFloatNeon64 | |||||
| #ifndef __APPLE__ | |||||
| .type AdderFloatNeon64, %function | |||||
| #endif | |||||
// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                       int row, int col, size_t stride)
| // x0: a | |||||
| // x1: b | |||||
| // x2: c | |||||
| // x3: bias | |||||
| // x4: act_type | |||||
| // x5: depth | |||||
| // x6: row | |||||
| // x7: col | |||||
| // x8: stride | |||||
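//
// Computes, per output element (see the C reference Adder12x4 in adder_fp32.c):
//   c[r * stride + col] = act(-(sum over depth of |a[r][d] - b[d][col]|) + bias[col])
// a is packed into 12-row blocks (12 consecutive floats per depth step) and b into
// 4-column blocks (4 consecutive floats per depth step). The main path below emits a
// 12x4 output tile; LoopRow8 and LoopRow4 cover row remainders of at most 8 or 4 rows.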
| AdderFloatNeon64: | |||||
| sub sp, sp, #144 | |||||
| st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||||
| stp x19, x20, [sp], #16 | |||||
| ldr x8, [sp] | |||||
| mov x18, #48 // sizeof(float) * 12 | |||||
| mul x17, x5, x18 // block stride of lhs/rhs: sizeof(float) * 12 * depth | |||||
| mov x18, #4 | |||||
| mul x8, x8, x18 | |||||
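// x17: byte stride of one packed 12-row lhs block (sizeof(float) * 12 * depth)
// x8:  output row stride, converted from elements to bytes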
| LoopRowStart: | |||||
| cmp x6, #4 | |||||
| ble LoopRow4 | |||||
| cmp x6, #8 | |||||
| blt LoopRow8 | |||||
| LoopRow: | |||||
| mov x14, x1 // reload rhs ptr | |||||
| mov x13, x7 // reload rhs col | |||||
| mov x12, x3 // reload bias | |||||
| LoopCol: | |||||
| mov x11, x2 | |||||
| mov x10, x0 // reload lhs ptr | |||||
| mov x19, x5 // reload depth | |||||
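// First depth step: initialize the twelve accumulators (v9, v11, ..., v31) directly
// with |a - b|; LoopDepth then adds the remaining depth steps onto them.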
| LoopDepthStart: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v9.4s, v3.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v11.4s, v3.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v13.4s, v3.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v15.4s, v3.4s, v14.4s | |||||
| dup v16.4s, v1.s[0] | |||||
| fabd v17.4s, v3.4s, v16.4s | |||||
| dup v18.4s, v1.s[1] | |||||
| fabd v19.4s, v3.4s, v18.4s | |||||
| dup v20.4s, v1.s[2] | |||||
| fabd v21.4s, v3.4s, v20.4s | |||||
| dup v22.4s, v1.s[3] | |||||
| fabd v23.4s, v3.4s, v22.4s | |||||
| dup v24.4s, v2.s[0] | |||||
| fabd v25.4s, v3.4s, v24.4s | |||||
| dup v26.4s, v2.s[1] | |||||
| fabd v27.4s, v3.4s, v26.4s | |||||
| dup v28.4s, v2.s[2] | |||||
| fabd v29.4s, v3.4s, v28.4s | |||||
| dup v30.4s, v2.s[3] | |||||
| fabd v31.4s, v3.4s, v30.4s | |||||
| subs x19, x19, #1 | |||||
| beq Bias | |||||
| LoopDepth: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v8.4s, v3.4s, v8.4s | |||||
| fadd v9.4s, v9.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v10.4s, v3.4s, v10.4s | |||||
| fadd v11.4s, v11.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v12.4s, v3.4s, v12.4s | |||||
| fadd v13.4s, v13.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v14.4s, v3.4s, v14.4s | |||||
| fadd v15.4s, v15.4s, v14.4s | |||||
| dup v16.4s, v1.s[0] | |||||
| fabd v16.4s, v3.4s, v16.4s | |||||
| fadd v17.4s, v17.4s, v16.4s | |||||
| dup v18.4s, v1.s[1] | |||||
| fabd v18.4s, v3.4s, v18.4s | |||||
| fadd v19.4s, v19.4s, v18.4s | |||||
| dup v20.4s, v1.s[2] | |||||
| fabd v20.4s, v3.4s, v20.4s | |||||
| fadd v21.4s, v21.4s, v20.4s | |||||
| dup v22.4s, v1.s[3] | |||||
| fabd v22.4s, v3.4s, v22.4s | |||||
| fadd v23.4s, v23.4s, v22.4s | |||||
| dup v24.4s, v2.s[0] | |||||
| fabd v24.4s, v3.4s, v24.4s | |||||
| fadd v25.4s, v25.4s, v24.4s | |||||
| dup v26.4s, v2.s[1] | |||||
| fabd v26.4s, v3.4s, v26.4s | |||||
| fadd v27.4s, v27.4s, v26.4s | |||||
| dup v28.4s, v2.s[2] | |||||
| fabd v28.4s, v3.4s, v28.4s | |||||
| fadd v29.4s, v29.4s, v28.4s | |||||
| dup v30.4s, v2.s[3] | |||||
| fabd v30.4s, v3.4s, v30.4s | |||||
| fadd v31.4s, v31.4s, v30.4s | |||||
| subs x19, x19, #1 | |||||
| bgt LoopDepth | |||||
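// Negate the accumulated L1 distances (the adder kernel outputs -sum|a - b|), then
// add the per-column bias when a bias pointer was supplied (x3 != 0).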
| Bias: | |||||
| fneg v9.4s, v9.4s | |||||
| fneg v11.4s, v11.4s | |||||
| fneg v13.4s, v13.4s | |||||
| fneg v15.4s, v15.4s | |||||
| fneg v17.4s, v17.4s | |||||
| fneg v19.4s, v19.4s | |||||
| fneg v21.4s, v21.4s | |||||
| fneg v23.4s, v23.4s | |||||
| fneg v25.4s, v25.4s | |||||
| fneg v27.4s, v27.4s | |||||
| fneg v29.4s, v29.4s | |||||
| fneg v31.4s, v31.4s | |||||
| cbz x3, Activation | |||||
| ld1 {v0.4s}, [x12], #16 | |||||
| fadd v9.4s, v9.4s, v0.4s | |||||
| fadd v11.4s, v11.4s, v0.4s | |||||
| fadd v13.4s, v13.4s, v0.4s | |||||
| fadd v15.4s, v15.4s, v0.4s | |||||
| fadd v17.4s, v17.4s, v0.4s | |||||
| fadd v19.4s, v19.4s, v0.4s | |||||
| fadd v21.4s, v21.4s, v0.4s | |||||
| fadd v23.4s, v23.4s, v0.4s | |||||
| fadd v25.4s, v25.4s, v0.4s | |||||
| fadd v27.4s, v27.4s, v0.4s | |||||
| fadd v29.4s, v29.4s, v0.4s | |||||
| fadd v31.4s, v31.4s, v0.4s | |||||
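// act_type: 3 = Relu6 (clamp to [0, 6]), 1 = Relu. Relu6 falls through into Relu so
// both bounds are applied; any other act_type writes the results unmodified.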
| Activation: | |||||
| cmp x4, #3 | |||||
| beq Relu6 | |||||
| cmp x4, #1 | |||||
| beq Relu | |||||
| b Write | |||||
| Relu6: | |||||
| mov w19, #6 | |||||
| dup v2.4s, w19 | |||||
| scvtf v2.4s, v2.4s | |||||
| fmin v9.4s, v9.4s, v2.4s | |||||
| fmin v11.4s, v11.4s, v2.4s | |||||
| fmin v13.4s, v13.4s, v2.4s | |||||
| fmin v15.4s, v15.4s, v2.4s | |||||
| fmin v17.4s, v17.4s, v2.4s | |||||
| fmin v19.4s, v19.4s, v2.4s | |||||
| fmin v21.4s, v21.4s, v2.4s | |||||
| fmin v23.4s, v23.4s, v2.4s | |||||
| fmin v25.4s, v25.4s, v2.4s | |||||
| fmin v27.4s, v27.4s, v2.4s | |||||
| fmin v29.4s, v29.4s, v2.4s | |||||
| fmin v31.4s, v31.4s, v2.4s | |||||
| Relu: | |||||
| dup v3.4s, wzr | |||||
| fmax v9.4s, v9.4s, v3.4s | |||||
| fmax v11.4s, v11.4s, v3.4s | |||||
| fmax v13.4s, v13.4s, v3.4s | |||||
| fmax v15.4s, v15.4s, v3.4s | |||||
| fmax v17.4s, v17.4s, v3.4s | |||||
| fmax v19.4s, v19.4s, v3.4s | |||||
| fmax v21.4s, v21.4s, v3.4s | |||||
| fmax v23.4s, v23.4s, v3.4s | |||||
| fmax v25.4s, v25.4s, v3.4s | |||||
| fmax v27.4s, v27.4s, v3.4s | |||||
| fmax v29.4s, v29.4s, v3.4s | |||||
| fmax v31.4s, v31.4s, v3.4s | |||||
| b Write | |||||
| LoopRow8: | |||||
| mov x14, x1 // reload rhs ptr | |||||
| mov x13, x7 // reload rhs col | |||||
| mov x12, x3 // reload bias | |||||
| LoopCol8: | |||||
| mov x11, x2 | |||||
| mov x10, x0 // reload lhs ptr | |||||
| mov x19, x5 // reload depth | |||||
| LoopDepthStart8: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v9.4s, v3.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v11.4s, v3.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v13.4s, v3.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v15.4s, v3.4s, v14.4s | |||||
| dup v16.4s, v1.s[0] | |||||
| fabd v17.4s, v3.4s, v16.4s | |||||
| dup v18.4s, v1.s[1] | |||||
| fabd v19.4s, v3.4s, v18.4s | |||||
| dup v20.4s, v1.s[2] | |||||
| fabd v21.4s, v3.4s, v20.4s | |||||
| dup v22.4s, v1.s[3] | |||||
| fabd v23.4s, v3.4s, v22.4s | |||||
| subs x19, x19, #1 | |||||
| beq Bias8 | |||||
| LoopDepth8: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v8.4s, v3.4s, v8.4s | |||||
| fadd v9.4s, v9.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v10.4s, v3.4s, v10.4s | |||||
| fadd v11.4s, v11.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v12.4s, v3.4s, v12.4s | |||||
| fadd v13.4s, v13.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v14.4s, v3.4s, v14.4s | |||||
| fadd v15.4s, v15.4s, v14.4s | |||||
| dup v16.4s, v1.s[0] | |||||
| fabd v16.4s, v3.4s, v16.4s | |||||
| fadd v17.4s, v17.4s, v16.4s | |||||
| dup v18.4s, v1.s[1] | |||||
| fabd v18.4s, v3.4s, v18.4s | |||||
| fadd v19.4s, v19.4s, v18.4s | |||||
| dup v20.4s, v1.s[2] | |||||
| fabd v20.4s, v3.4s, v20.4s | |||||
| fadd v21.4s, v21.4s, v20.4s | |||||
| dup v22.4s, v1.s[3] | |||||
| fabd v22.4s, v3.4s, v22.4s | |||||
| fadd v23.4s, v23.4s, v22.4s | |||||
| subs x19, x19, #1 | |||||
| bgt LoopDepth8 | |||||
| Bias8: | |||||
| fneg v9.4s, v9.4s | |||||
| fneg v11.4s, v11.4s | |||||
| fneg v13.4s, v13.4s | |||||
| fneg v15.4s, v15.4s | |||||
| fneg v17.4s, v17.4s | |||||
| fneg v19.4s, v19.4s | |||||
| fneg v21.4s, v21.4s | |||||
| fneg v23.4s, v23.4s | |||||
| cbz x3, Activation8 | |||||
| ld1 {v0.4s}, [x12], #16 | |||||
| fadd v9.4s, v9.4s, v0.4s | |||||
| fadd v11.4s, v11.4s, v0.4s | |||||
| fadd v13.4s, v13.4s, v0.4s | |||||
| fadd v15.4s, v15.4s, v0.4s | |||||
| fadd v17.4s, v17.4s, v0.4s | |||||
| fadd v19.4s, v19.4s, v0.4s | |||||
| fadd v21.4s, v21.4s, v0.4s | |||||
| fadd v23.4s, v23.4s, v0.4s | |||||
| Activation8: | |||||
| cmp x4, #3 | |||||
| beq Relu68 | |||||
| cmp x4, #1 | |||||
| beq Relu8 | |||||
| b Write | |||||
| Relu68: | |||||
| mov w19, #6 | |||||
| dup v2.4s, w19 | |||||
| scvtf v2.4s, v2.4s | |||||
| fmin v9.4s, v9.4s, v2.4s | |||||
| fmin v11.4s, v11.4s, v2.4s | |||||
| fmin v13.4s, v13.4s, v2.4s | |||||
| fmin v15.4s, v15.4s, v2.4s | |||||
| fmin v17.4s, v17.4s, v2.4s | |||||
| fmin v19.4s, v19.4s, v2.4s | |||||
| fmin v21.4s, v21.4s, v2.4s | |||||
| fmin v23.4s, v23.4s, v2.4s | |||||
| Relu8: | |||||
| dup v3.4s, wzr | |||||
| fmax v9.4s, v9.4s, v3.4s | |||||
| fmax v11.4s, v11.4s, v3.4s | |||||
| fmax v13.4s, v13.4s, v3.4s | |||||
| fmax v15.4s, v15.4s, v3.4s | |||||
| fmax v17.4s, v17.4s, v3.4s | |||||
| fmax v19.4s, v19.4s, v3.4s | |||||
| fmax v21.4s, v21.4s, v3.4s | |||||
| fmax v23.4s, v23.4s, v3.4s | |||||
| b Write | |||||
| LoopRow4: | |||||
| mov x14, x1 // reload rhs ptr | |||||
| mov x13, x7 // reload rhs col | |||||
| mov x12, x3 // reload bias | |||||
| LoopCol4: | |||||
| mov x11, x2 | |||||
| mov x10, x0 // reload lhs ptr | |||||
| mov x19, x5 // reload depth | |||||
| LoopDepthStart4: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v9.4s, v3.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v11.4s, v3.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v13.4s, v3.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v15.4s, v3.4s, v14.4s | |||||
| subs x19, x19, #1 | |||||
| beq Bias4 | |||||
| LoopDepth4: | |||||
| ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 | |||||
| ld1 {v3.4s}, [x14], #16 | |||||
| dup v8.4s, v0.s[0] | |||||
| fabd v8.4s, v3.4s, v8.4s | |||||
| fadd v9.4s, v9.4s, v8.4s | |||||
| dup v10.4s, v0.s[1] | |||||
| fabd v10.4s, v3.4s, v10.4s | |||||
| fadd v11.4s, v11.4s, v10.4s | |||||
| dup v12.4s, v0.s[2] | |||||
| fabd v12.4s, v3.4s, v12.4s | |||||
| fadd v13.4s, v13.4s, v12.4s | |||||
| dup v14.4s, v0.s[3] | |||||
| fabd v14.4s, v3.4s, v14.4s | |||||
| fadd v15.4s, v15.4s, v14.4s | |||||
| subs x19, x19, #1 | |||||
| bgt LoopDepth4 | |||||
| Bias4: | |||||
| fneg v9.4s, v9.4s | |||||
| fneg v11.4s, v11.4s | |||||
| fneg v13.4s, v13.4s | |||||
| fneg v15.4s, v15.4s | |||||
| cbz x3, Activation4 | |||||
| ld1 {v0.4s}, [x12], #16 | |||||
| fadd v9.4s, v9.4s, v0.4s | |||||
| fadd v11.4s, v11.4s, v0.4s | |||||
| fadd v13.4s, v13.4s, v0.4s | |||||
| fadd v15.4s, v15.4s, v0.4s | |||||
| Activation4: | |||||
| cmp x4, #3 | |||||
| beq Relu64 | |||||
| cmp x4, #1 | |||||
| beq Relu4 | |||||
| b Write | |||||
| Relu64: | |||||
| mov w19, #6 | |||||
| dup v2.4s, w19 | |||||
| scvtf v2.4s, v2.4s | |||||
| fmin v9.4s, v9.4s, v2.4s | |||||
| fmin v11.4s, v11.4s, v2.4s | |||||
| fmin v13.4s, v13.4s, v2.4s | |||||
| fmin v15.4s, v15.4s, v2.4s | |||||
| Relu4: | |||||
| dup v3.4s, wzr | |||||
fmax v9.4s, v9.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
| b Write | |||||
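// Store the output tile: Write1/2/3/4 handle a remaining column count (x13) of 1, 2,
// 3 or >= 4 respectively, and the cmp-x6 checks stop after the last valid row.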
| Write: | |||||
| cmp x13, #1 | |||||
| beq Write1 | |||||
| cmp x13, #2 | |||||
| beq Write2 | |||||
| cmp x13, #3 | |||||
| beq Write3 | |||||
| b Write4 | |||||
| Write1: | |||||
| add x2, x2, #4 | |||||
| str s9, [x11] | |||||
| cmp x6, #1 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s11, [x11] | |||||
| cmp x6, #2 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s13, [x11] | |||||
| cmp x6, #3 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s15, [x11] | |||||
| cmp x6, #4 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s17, [x11] | |||||
| cmp x6, #5 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s19, [x11] | |||||
| cmp x6, #6 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s21, [x11] | |||||
| cmp x6, #7 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s23, [x11] | |||||
| cmp x6, #8 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s25, [x11] | |||||
| cmp x6, #9 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s27, [x11] | |||||
| cmp x6, #10 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s29, [x11] | |||||
| cmp x6, #11 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str s31, [x11] | |||||
| add x11, x11, x8 | |||||
| add x11, x11, #4 | |||||
| b WriteEnd | |||||
| Write2: | |||||
| add x2, x2, #8 | |||||
| str d9, [x11] | |||||
| cmp x6, #1 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d11, [x11] | |||||
| cmp x6, #2 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d13, [x11] | |||||
| cmp x6, #3 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d15, [x11] | |||||
| cmp x6, #4 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d17, [x11] | |||||
| cmp x6, #5 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d19, [x11] | |||||
| cmp x6, #6 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d21, [x11] | |||||
| cmp x6, #7 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d23, [x11] | |||||
| cmp x6, #8 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d25, [x11] | |||||
| cmp x6, #9 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d27, [x11] | |||||
| cmp x6, #10 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d29, [x11] | |||||
| cmp x6, #11 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d31, [x11] | |||||
| add x11, x11, x8 | |||||
| add x11, x11, #8 | |||||
| b WriteEnd | |||||
| Write3: | |||||
| add x2, x2, #12 | |||||
| add x19, x11, #8 | |||||
| str d9, [x11] | |||||
| st1 {v9.s}[2], [x19], x8 | |||||
| cmp x6, #1 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d11, [x11] | |||||
| st1 {v11.s}[2], [x19], x8 | |||||
| cmp x6, #2 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d13, [x11] | |||||
| st1 {v13.s}[2], [x19], x8 | |||||
| cmp x6, #3 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d15, [x11] | |||||
| st1 {v15.s}[2], [x19], x8 | |||||
| cmp x6, #4 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d17, [x11] | |||||
| st1 {v17.s}[2], [x19], x8 | |||||
| cmp x6, #5 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d19, [x11] | |||||
| st1 {v19.s}[2], [x19], x8 | |||||
| cmp x6, #6 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d21, [x11] | |||||
| st1 {v21.s}[2], [x19], x8 | |||||
| cmp x6, #7 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d23, [x11] | |||||
| st1 {v23.s}[2], [x19], x8 | |||||
| cmp x6, #8 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d25, [x11] | |||||
| st1 {v25.s}[2], [x19], x8 | |||||
| cmp x6, #9 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d27, [x11] | |||||
| st1 {v27.s}[2], [x19], x8 | |||||
| cmp x6, #10 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d29, [x11] | |||||
| st1 {v29.s}[2], [x19], x8 | |||||
| cmp x6, #11 | |||||
| beq WriteEnd | |||||
| add x11, x11, x8 | |||||
| str d31, [x11] | |||||
| st1 {v31.s}[2], [x19] | |||||
| add x11, x11, x8 | |||||
| add x11, x11, #12 | |||||
| b WriteEnd | |||||
| Write4: | |||||
| add x2, x2, #16 | |||||
| st1 {v9.4s}, [x11], x8 | |||||
| cmp x6, #1 | |||||
| beq WriteEnd | |||||
| st1 {v11.4s}, [x11], x8 | |||||
| cmp x6, #2 | |||||
| beq WriteEnd | |||||
| st1 {v13.4s}, [x11], x8 | |||||
| cmp x6, #3 | |||||
| beq WriteEnd | |||||
| st1 {v15.4s}, [x11], x8 | |||||
| cmp x6, #4 | |||||
| beq WriteEnd | |||||
| st1 {v17.4s}, [x11], x8 | |||||
| cmp x6, #5 | |||||
| beq WriteEnd | |||||
| st1 {v19.4s}, [x11], x8 | |||||
| cmp x6, #6 | |||||
| beq WriteEnd | |||||
| st1 {v21.4s}, [x11], x8 | |||||
| cmp x6, #7 | |||||
| beq WriteEnd | |||||
| st1 {v23.4s}, [x11], x8 | |||||
| cmp x6, #8 | |||||
| beq WriteEnd | |||||
| st1 {v25.4s}, [x11], x8 | |||||
| cmp x6, #9 | |||||
| beq WriteEnd | |||||
| st1 {v27.4s}, [x11], x8 | |||||
| cmp x6, #10 | |||||
| beq WriteEnd | |||||
| st1 {v29.4s}, [x11], x8 | |||||
| cmp x6, #11 | |||||
| beq WriteEnd | |||||
| st1 {v31.4s}, [x11], x8 | |||||
| add x11, x11, #16 | |||||
| b WriteEnd | |||||
| WriteEnd: | |||||
| subs x13, x13, #4 // rhs col - 4 | |||||
| ble LoopColEnd | |||||
| cmp x6, #4 | |||||
| ble LoopCol4 | |||||
| cmp x6, #8 | |||||
| ble LoopCol8 | |||||
| b LoopCol | |||||
| LoopColEnd: | |||||
| add x0, x0, x17 | |||||
| mov x18, #4 | |||||
| mul x18, x18, x7 | |||||
| sub x11, x11, x18 | |||||
| mov x2, x11 | |||||
| subs x6, x6, #12 | |||||
| bgt LoopRowStart | |||||
| sub sp, sp, #144 | |||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 | |||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 | |||||
| ldp x19, x20, [sp], #16 | |||||
| ret | |||||
| #endif | |||||
| @@ -0,0 +1,90 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "nnacl/fp32/adder_fp32.h" | |||||
| #include <string.h> | |||||
| #include <math.h> | |||||
| #include "nnacl/fp32/common_func_fp32.h" | |||||
| #include "nnacl/fp32/matmul_fp32.h" | |||||
| void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, | |||||
| int col, int stride) { | |||||
| for (int r = 0; r < row; r++) { | |||||
| for (int c = 0; c < col; c++) { | |||||
| int r12div = r / 12, r12mod = r % 12; | |||||
| int c4div = c / 4, c4mod = c % 4; | |||||
| size_t ci = r * stride + c; | |||||
| float value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r12div * deep * 12 + d * 12 + r12mod; | |||||
| size_t bi = c4div * deep * 4 + d * 4 + c4mod; | |||||
| value += fabsf(a[ai] - b[bi]); | |||||
| } | |||||
| value = -value; | |||||
| if (bias != NULL) value += bias[c]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| dst[ci] = value; | |||||
| } | |||||
| } | |||||
| } | |||||
| void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, | |||||
| size_t stride) { | |||||
| #ifdef ENABLE_ARM64 | |||||
| AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride); | |||||
| #else | |||||
| Adder12x4(a, b, c, bias, act_type, deep, row, col, stride); | |||||
| #endif | |||||
| } | |||||
| void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||||
| float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) { | |||||
| int out_channel = conv_param->output_channel_; | |||||
| int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; | |||||
| int output_count = conv_param->output_h_ * conv_param->output_w_; | |||||
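// cal_num is the row tile packed per step: 4 on ARM32/SSE builds, 12 elsewhere
// (the 12-row tile used by AdderFloatNeon64 and Adder12x4).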
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| const int cal_num = C4NUM; | |||||
| #else | |||||
| const int cal_num = C12NUM; | |||||
| #endif | |||||
| int output_tile_count = UP_DIV(output_count, cal_num); | |||||
| for (int b = 0; b < conv_param->input_batch_; b++) { | |||||
| int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; | |||||
| int out_batch_offset = b * out_channel * output_count; | |||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { | |||||
| int start_index = thread_id * cal_num; | |||||
| int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; | |||||
| float *gemm_input = packed_input + task_id * deep * cal_num; | |||||
| float *col_major_gemm_input = col_major_input + task_id * deep * cal_num; | |||||
| size_t packed_input_size = deep * cal_num * sizeof(float); | |||||
| memset(gemm_input, 0, packed_input_size); | |||||
| memset(col_major_gemm_input, 0, packed_input_size); | |||||
| Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); | |||||
| int out_offset = thread_id * cal_num * out_channel + out_batch_offset; | |||||
| float *gemm_output = output_data + out_offset; | |||||
| #if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE) | |||||
| RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||||
| #else | |||||
| RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); | |||||
| #endif | |||||
| AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num, | |||||
| out_channel, out_channel); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,47 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_NNACL_FP32_ADDER_H_ | |||||
| #define MINDSPORE_LITE_NNACL_FP32_ADDER_H_ | |||||
| #ifdef ENABLE_NEON | |||||
| #include <arm_neon.h> | |||||
| #endif | |||||
| #include "nnacl/pack.h" | |||||
| #include "nnacl/op_base.h" | |||||
| #include "nnacl/common_func.h" | |||||
| #include "nnacl/conv_parameter.h" | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| #ifdef ENABLE_ARM64 | |||||
| void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||||
| int col, size_t stride); | |||||
| #endif | |||||
| void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, | |||||
| size_t stride); | |||||
| void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, | |||||
| float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // MINDSPORE_LITE_NNACL_FP32_ADDER_H_ | |||||
| @@ -207,6 +207,26 @@ table Conv2D { | |||||
| activationType: ActivationType = 0; | activationType: ActivationType = 0; | ||||
| } | } | ||||
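// Adder mirrors Conv2D's attribute layout; the lite::Adder primitive derives from
// Conv2D, so its attributes are parsed and populated the same way.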
| table Adder { | |||||
| format: Format = 0; | |||||
| group: int; | |||||
| channelIn: int; | |||||
| channelOut: int; | |||||
| kernelW: int; | |||||
| kernelH: int; | |||||
| strideW: int; | |||||
| strideH: int; | |||||
| padMode: PadMode; | |||||
| padUp: int; | |||||
| padDown: int; | |||||
| padLeft: int; | |||||
| padRight: int; | |||||
| dilateW: int; | |||||
| dilateH: int; | |||||
| hasBias: bool = false; | |||||
| activationType: ActivationType = 0; | |||||
| } | |||||
| table Conv2DGradFilter { | table Conv2DGradFilter { | ||||
| format: Format = 0; | format: Format = 0; | ||||
| group: int; | group: int; | ||||
| @@ -1176,6 +1196,3 @@ table All { | |||||
| table Assert { | table Assert { | ||||
| summarize : int; | summarize : int; | ||||
| } | } | ||||
| table Adder { | |||||
| } | |||||
| @@ -15,6 +15,14 @@ | |||||
| */ | */ | ||||
| #include "src/ops/adder.h" | #include "src/ops/adder.h" | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #endif | |||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -23,6 +31,118 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| int Adder::GetFormat() const { return this->primitive_->value.AsAdder()->format; } | |||||
| int Adder::GetGroup() const { return this->primitive_->value.AsAdder()->group; } | |||||
| int Adder::GetChannelIn() const { return this->primitive_->value.AsAdder()->channelIn; } | |||||
| int Adder::GetChannelOut() const { return this->primitive_->value.AsAdder()->channelOut; } | |||||
| int Adder::GetKernelW() const { return this->primitive_->value.AsAdder()->kernelW; } | |||||
| int Adder::GetKernelH() const { return this->primitive_->value.AsAdder()->kernelH; } | |||||
| int Adder::GetStrideW() const { return this->primitive_->value.AsAdder()->strideW; } | |||||
| int Adder::GetStrideH() const { return this->primitive_->value.AsAdder()->strideH; } | |||||
| int Adder::GetPadMode() const { return this->primitive_->value.AsAdder()->padMode; } | |||||
| int Adder::GetPadUp() const { return this->primitive_->value.AsAdder()->padUp; } | |||||
| int Adder::GetPadDown() const { return this->primitive_->value.AsAdder()->padDown; } | |||||
| int Adder::GetPadLeft() const { return this->primitive_->value.AsAdder()->padLeft; } | |||||
| int Adder::GetPadRight() const { return this->primitive_->value.AsAdder()->padRight; } | |||||
| int Adder::GetDilateW() const { return this->primitive_->value.AsAdder()->dilateW; } | |||||
| int Adder::GetDilateH() const { return this->primitive_->value.AsAdder()->dilateH; } | |||||
| bool Adder::GetHasBias() const { return this->primitive_->value.AsAdder()->hasBias; } | |||||
| int Adder::GetActivationType() const { return this->primitive_->value.AsAdder()->activationType; } | |||||
| void Adder::SetFormat(int format) { this->primitive_->value.AsAdder()->format = (schema::Format)format; } | |||||
| void Adder::SetGroup(int group) { this->primitive_->value.AsAdder()->group = group; } | |||||
| void Adder::SetChannelIn(int channel_in) { this->primitive_->value.AsAdder()->channelIn = channel_in; } | |||||
| void Adder::SetChannelOut(int channel_out) { this->primitive_->value.AsAdder()->channelOut = channel_out; } | |||||
| void Adder::SetKernelW(int kernel_w) { this->primitive_->value.AsAdder()->kernelW = kernel_w; } | |||||
| void Adder::SetKernelH(int kernel_h) { this->primitive_->value.AsAdder()->kernelH = kernel_h; } | |||||
| void Adder::SetStrideW(int stride_w) { this->primitive_->value.AsAdder()->strideW = stride_w; } | |||||
| void Adder::SetStrideH(int stride_h) { this->primitive_->value.AsAdder()->strideH = stride_h; } | |||||
| void Adder::SetPadMode(int pad_mode) { this->primitive_->value.AsAdder()->padMode = (schema::PadMode)pad_mode; } | |||||
| void Adder::SetPadUp(int pad_up) { this->primitive_->value.AsAdder()->padUp = pad_up; } | |||||
| void Adder::SetPadDown(int pad_down) { this->primitive_->value.AsAdder()->padDown = pad_down; } | |||||
| void Adder::SetPadLeft(int pad_left) { this->primitive_->value.AsAdder()->padLeft = pad_left; } | |||||
| void Adder::SetPadRight(int pad_right) { this->primitive_->value.AsAdder()->padRight = pad_right; } | |||||
| void Adder::SetDilateW(int dilate_w) { this->primitive_->value.AsAdder()->dilateW = dilate_w; } | |||||
| void Adder::SetDilateH(int dilate_h) { this->primitive_->value.AsAdder()->dilateH = dilate_h; } | |||||
| void Adder::SetHasBias(bool has_bias) { this->primitive_->value.AsAdder()->hasBias = has_bias; } | |||||
| void Adder::SetActivationType(int activation_type) { | |||||
| this->primitive_->value.AsAdder()->activationType = (schema::ActivationType)activation_type; | |||||
| } | |||||
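// Translate the ANF Primitive attributes (data_format, pad_list, dilation, kernel_size,
// stride, out_channel, pad_mode, activation_name) into a schema::AdderT and attach it
// to the PrimitiveT as a PrimitiveType_Adder value.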
| void Adder::PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { | |||||
| auto attr = std::make_unique<schema::AdderT>(); | |||||
| attr->group = group; | |||||
| auto format = GetValue<std::string>(prim.GetAttr("data_format")); | |||||
| if (format == "NCHW") { | |||||
| attr->format = schema::Format::Format_NCHW; | |||||
| } else if (format == "NHWC") { | |||||
| attr->format = schema::Format::Format_NHWC; | |||||
| } else { | |||||
| attr->format = schema::Format::Format_NUM_OF_FORMAT; | |||||
| } | |||||
| auto pad_list = CastToInt(prim.GetAttr("pad_list")); | |||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| auto dilation = CastToInt(prim.GetAttr("dilation")); | |||||
| attr->dilateH = dilation[2]; | |||||
| attr->dilateW = dilation[3]; | |||||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | |||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| auto stride = CastToInt(prim.GetAttr("stride")); | |||||
| attr->strideH = stride[2]; | |||||
| attr->strideW = stride[3]; | |||||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); | |||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | |||||
| if (pad_mode == "valid") { | |||||
| attr->padMode = schema::PadMode_VALID; | |||||
| } else if (pad_mode == "same") { | |||||
| attr->padMode = schema::PadMode_SAME_UPPER; | |||||
| } else { | |||||
| attr->padMode = schema::PadMode_NOTSET; | |||||
| } | |||||
| if (prim.GetAttr("activation_name") != nullptr) { | |||||
| auto activate_name = GetValue<std::string>(prim.GetAttr("activation_name")); | |||||
| attr->activationType = kActivationTypeMap[activate_name]; | |||||
| } else { | |||||
| attr->activationType = schema::ActivationType_NO_ACTIVATION; | |||||
| } | |||||
| primitive->value.type = schema::PrimitiveType_Adder; | |||||
| primitive->value.value = attr.release(); | |||||
| } | |||||
| int Adder::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { | |||||
| if (this->primitive_ == nullptr) { | |||||
| this->primitive_ = new (std::nothrow) schema::PrimitiveT; | |||||
| if (this->primitive_ == nullptr) { | |||||
| MS_LOG(ERROR) << "new primitiveT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| this->primitive_->value.type = schema::PrimitiveType_Adder; | |||||
| } | |||||
| if (this->primitive_->value.type != schema::PrimitiveType_Adder) { | |||||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto groupAttr = prim.GetAttr("group"); | |||||
| if (groupAttr == nullptr) { | |||||
| MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| int group = CastToInt(groupAttr).front(); | |||||
| PopulaterAdderSingleGroup(prim, this->primitive_, group); | |||||
| PopulaterQuantParam(prim, inputs); | |||||
| return RET_OK; | |||||
| } | |||||
| #else | #else | ||||
| int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | ||||
| @@ -33,39 +153,36 @@ int Adder::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||||
| MS_LOG(ERROR) << "value_as_Adder return nullptr"; | MS_LOG(ERROR) << "value_as_Adder return nullptr"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto val_offset = schema::CreateAdder(*fbb); | |||||
| auto val_offset = schema::CreateAdder( | |||||
| *fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(), | |||||
| attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(), | |||||
| attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adder, val_offset.o); | auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adder, val_offset.o); | ||||
| fbb->Finish(prim_offset); | fbb->Finish(prim_offset); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int Adder::GetFormat() const { return this->primitive_->value_as_Adder()->format(); } | |||||
| int Adder::GetGroup() const { return this->primitive_->value_as_Adder()->group(); } | |||||
| int Adder::GetChannelIn() const { return this->primitive_->value_as_Adder()->channelIn(); } | |||||
| int Adder::GetChannelOut() const { return this->primitive_->value_as_Adder()->channelOut(); } | |||||
| int Adder::GetKernelW() const { return this->primitive_->value_as_Adder()->kernelW(); } | |||||
| int Adder::GetKernelH() const { return this->primitive_->value_as_Adder()->kernelH(); } | |||||
| int Adder::GetStrideW() const { return this->primitive_->value_as_Adder()->strideW(); } | |||||
| int Adder::GetStrideH() const { return this->primitive_->value_as_Adder()->strideH(); } | |||||
| int Adder::GetPadMode() const { return this->primitive_->value_as_Adder()->padMode(); } | |||||
| int Adder::GetPadUp() const { return this->primitive_->value_as_Adder()->padUp(); } | |||||
| int Adder::GetPadDown() const { return this->primitive_->value_as_Adder()->padDown(); } | |||||
| int Adder::GetPadLeft() const { return this->primitive_->value_as_Adder()->padLeft(); } | |||||
| int Adder::GetPadRight() const { return this->primitive_->value_as_Adder()->padRight(); } | |||||
| int Adder::GetDilateW() const { return this->primitive_->value_as_Adder()->dilateW(); } | |||||
| int Adder::GetDilateH() const { return this->primitive_->value_as_Adder()->dilateH(); } | |||||
| bool Adder::GetHasBias() const { return this->primitive_->value_as_Adder()->hasBias(); } | |||||
| int Adder::GetActivationType() const { return this->primitive_->value_as_Adder()->activationType(); } | |||||
| PrimitiveC *AdderCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adder>(primitive); } | PrimitiveC *AdderCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adder>(primitive); } | ||||
| Registry AdderRegistry(schema::PrimitiveType_Adder, AdderCreator); | Registry AdderRegistry(schema::PrimitiveType_Adder, AdderCreator); | ||||
| #endif | #endif | ||||
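// Shape inference follows the matmul convention: both inputs are treated as 2-D
// matrices and the output shape is {input0 rows, input1 columns}.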
| int Adder::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| MS_ASSERT(this->primitive_ != nullptr); | |||||
| MS_ASSERT(inputs_.size() == 2); | |||||
| auto input0 = inputs_.front(); | |||||
| MS_ASSERT(input0 != nullptr); | |||||
| MS_ASSERT(input0->shape().size() == 2); | |||||
| auto input1 = inputs_.at(1); | |||||
| MS_ASSERT(input1 != nullptr); | |||||
| MS_ASSERT(input1->shape().size() == 2); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_data_type(input0->data_type()); | |||||
| output->set_format(input0->format()); | |||||
| if (!infer_flag()) { | |||||
| return RET_OK; | |||||
| } | |||||
| std::vector<int> in_shape; | |||||
| in_shape.push_back(input0->shape().at(0)); | |||||
| in_shape.push_back(input1->shape().at(1)); | |||||
| output->set_shape(in_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,22 +20,62 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <set> | #include <set> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "src/ops/primitive_c.h" | |||||
| #include <memory> | |||||
| #include "src/ops/conv2d.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| class Adder : public PrimitiveC { | |||||
| class Adder : public Conv2D { | |||||
| public: | public: | ||||
| Adder() = default; | Adder() = default; | ||||
| ~Adder() = default; | ~Adder() = default; | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| MS_DECLARE_PARENT(Adder, PrimitiveC); | |||||
| explicit Adder(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| MS_DECLARE_PARENT(Adder, Conv2D); | |||||
| explicit Adder(schema::PrimitiveT *primitive) : Conv2D(primitive) {} | |||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| void SetFormat(int format); | |||||
| void SetGroup(int group); | |||||
| void SetChannelIn(int channel_in); | |||||
| void SetChannelOut(int channel_out); | |||||
| void SetKernelW(int kernel_w); | |||||
| void SetKernelH(int kernel_h); | |||||
| void SetStrideW(int stride_w); | |||||
| void SetStrideH(int stride_h); | |||||
| void SetPadMode(int pad_mode); | |||||
| void SetPadUp(int pad_up); | |||||
| void SetPadDown(int pad_down); | |||||
| void SetPadLeft(int pad_left); | |||||
| void SetPadRight(int pad_right); | |||||
| void SetDilateW(int dilate_w); | |||||
| void SetDilateH(int dilate_h); | |||||
| void SetHasBias(bool has_bias); | |||||
| void SetActivationType(int activation_type); | |||||
| private: | |||||
| void PopulaterAdderSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); | |||||
| #else | #else | ||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb); | |||||
| #endif | #endif | ||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| public: | |||||
| int GetFormat() const; | |||||
| int GetGroup() const; | |||||
| int GetChannelIn() const; | |||||
| int GetChannelOut() const; | |||||
| int GetKernelW() const; | |||||
| int GetKernelH() const; | |||||
| int GetStrideW() const; | |||||
| int GetStrideH() const; | |||||
| int GetPadMode() const; | |||||
| int GetPadUp() const; | |||||
| int GetPadDown() const; | |||||
| int GetPadLeft() const; | |||||
| int GetPadRight() const; | |||||
| int GetDilateW() const; | |||||
| int GetDilateH() const; | |||||
| bool GetHasBias() const; | |||||
| int GetActivationType() const; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,23 +34,23 @@ class Conv2D : public PrimitiveC { | |||||
| explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | ||||
| int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | ||||
| void SetFormat(int format); | |||||
| void SetGroup(int group); | |||||
| void SetChannelIn(int channel_in); | |||||
| void SetChannelOut(int channel_out); | |||||
| void SetKernelW(int kernel_w); | |||||
| void SetKernelH(int kernel_h); | |||||
| void SetStrideW(int stride_w); | |||||
| void SetStrideH(int stride_h); | |||||
| void SetPadMode(int pad_mode); | |||||
| void SetPadUp(int pad_up); | |||||
| void SetPadDown(int pad_down); | |||||
| void SetPadLeft(int pad_left); | |||||
| void SetPadRight(int pad_right); | |||||
| void SetDilateW(int dilate_w); | |||||
| void SetDilateH(int dilate_h); | |||||
| void SetHasBias(bool has_bias); | |||||
| void SetActivationType(int activation_type); | |||||
| virtual void SetFormat(int format); | |||||
| virtual void SetGroup(int group); | |||||
| virtual void SetChannelIn(int channel_in); | |||||
| virtual void SetChannelOut(int channel_out); | |||||
| virtual void SetKernelW(int kernel_w); | |||||
| virtual void SetKernelH(int kernel_h); | |||||
| virtual void SetStrideW(int stride_w); | |||||
| virtual void SetStrideH(int stride_h); | |||||
| virtual void SetPadMode(int pad_mode); | |||||
| virtual void SetPadUp(int pad_up); | |||||
| virtual void SetPadDown(int pad_down); | |||||
| virtual void SetPadLeft(int pad_left); | |||||
| virtual void SetPadRight(int pad_right); | |||||
| virtual void SetDilateW(int dilate_w); | |||||
| virtual void SetDilateH(int dilate_h); | |||||
| virtual void SetHasBias(bool has_bias); | |||||
| virtual void SetActivationType(int activation_type); | |||||
| private: | private: | ||||
| void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, | ||||
| @@ -67,23 +67,23 @@ class Conv2D : public PrimitiveC { | |||||
| int PadLeft() const; | int PadLeft() const; | ||||
| int PadRight() const; | int PadRight() const; | ||||
| int GetFormat() const; | |||||
| int GetGroup() const; | |||||
| int GetChannelIn() const; | |||||
| int GetChannelOut() const; | |||||
| int GetKernelW() const; | |||||
| int GetKernelH() const; | |||||
| int GetStrideW() const; | |||||
| int GetStrideH() const; | |||||
| int GetPadMode() const; | |||||
| int GetPadUp() const; | |||||
| int GetPadDown() const; | |||||
| int GetPadLeft() const; | |||||
| int GetPadRight() const; | |||||
| int GetDilateW() const; | |||||
| int GetDilateH() const; | |||||
| bool GetHasBias() const; | |||||
| int GetActivationType() const; | |||||
| virtual int GetFormat() const; | |||||
| virtual int GetGroup() const; | |||||
| virtual int GetChannelIn() const; | |||||
| virtual int GetChannelOut() const; | |||||
| virtual int GetKernelW() const; | |||||
| virtual int GetKernelH() const; | |||||
| virtual int GetStrideW() const; | |||||
| virtual int GetStrideH() const; | |||||
| virtual int GetPadMode() const; | |||||
| virtual int GetPadUp() const; | |||||
| virtual int GetPadDown() const; | |||||
| virtual int GetPadLeft() const; | |||||
| virtual int GetPadRight() const; | |||||
| virtual int GetDilateW() const; | |||||
| virtual int GetDilateH() const; | |||||
| virtual bool GetHasBias() const; | |||||
| virtual int GetActivationType() const; | |||||
| protected: | protected: | ||||
| void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w); | void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w); | ||||
| @@ -15,24 +15,53 @@ | |||||
| */ | */ | ||||
| #include "src/ops/adder.h" | #include "src/ops/adder.h" | ||||
| #include "src/common/log_adapter.h" | |||||
| #include "nnacl/conv_parameter.h" | |||||
| #include "src/ops/primitive_c.h" | #include "src/ops/primitive_c.h" | ||||
| #include "src/ops/populate/populate_register.h" | #include "src/ops/populate/populate_register.h" | ||||
| #include "nnacl/adder.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
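// Adder is populated into a ConvParameter rather than a dedicated parameter struct,
// so AdderCPUKernel can reuse the convolution kernel's im2col packing and buffers.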
| OpParameter *PopulateAdderParameter(const mindspore::lite::PrimitiveC *primitive) { | OpParameter *PopulateAdderParameter(const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *adder_param = reinterpret_cast<AdderParameter *>(malloc(sizeof(AdderParameter))); | |||||
| if (adder_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc AdderParameter failed."; | |||||
| ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); | |||||
| if (conv_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc ConvParameter failed."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| memset(adder_param, 0, sizeof(AdderParameter)); | |||||
| adder_param->op_parameter_.type_ = primitive->Type(); | |||||
| return reinterpret_cast<OpParameter *>(adder_param); | |||||
| memset(conv_param, 0, sizeof(ConvParameter)); | |||||
| conv_param->op_parameter_.type_ = primitive->Type(); | |||||
| auto adder_primitive = | |||||
| reinterpret_cast<mindspore::lite::Adder *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); | |||||
| conv_param->kernel_h_ = adder_primitive->GetKernelH(); | |||||
| conv_param->kernel_w_ = adder_primitive->GetKernelW(); | |||||
| conv_param->group_ = adder_primitive->GetGroup(); | |||||
| conv_param->stride_h_ = adder_primitive->GetStrideH(); | |||||
| conv_param->stride_w_ = adder_primitive->GetStrideW(); | |||||
| auto adder_lite_primitive = (lite::Adder *)primitive; | |||||
| conv_param->pad_u_ = adder_lite_primitive->PadUp(); | |||||
| conv_param->pad_d_ = adder_lite_primitive->PadDown(); | |||||
| conv_param->pad_l_ = adder_lite_primitive->PadLeft(); | |||||
| conv_param->pad_r_ = adder_lite_primitive->PadRight(); | |||||
| conv_param->dilation_h_ = adder_primitive->GetDilateH(); | |||||
| conv_param->dilation_w_ = adder_primitive->GetDilateW(); | |||||
| conv_param->input_channel_ = adder_primitive->GetChannelIn(); | |||||
| conv_param->output_channel_ = adder_primitive->GetChannelOut(); | |||||
| auto act_type = adder_primitive->GetActivationType(); | |||||
| switch (act_type) { | |||||
| case schema::ActivationType_RELU: | |||||
| conv_param->act_type_ = ActType_Relu; | |||||
| break; | |||||
| case schema::ActivationType_RELU6: | |||||
| conv_param->act_type_ = ActType_Relu6; | |||||
| break; | |||||
| default: | |||||
| conv_param->act_type_ = ActType_No; | |||||
| break; | |||||
| } | |||||
| return reinterpret_cast<OpParameter *>(conv_param); | |||||
| } | } | ||||
| Registry AdderParameterRegistry(schema::PrimitiveType_Adder, PopulateAdderParameter); | Registry AdderParameterRegistry(schema::PrimitiveType_Adder, PopulateAdderParameter); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,133 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/runtime/kernel/arm/fp32/adder_fp32.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "nnacl/fp32/adder_fp32.h" | |||||
| #include "nnacl/fp32/matmul_fp32.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_INFER_INVALID; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Adder; | |||||
| using mindspore::schema::Format::Format_NHWC; | |||||
| namespace mindspore::kernel { | |||||
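// Pack the filter with RowMajor2Col4Major into OC4 blocks and zero-pad the bias to a
// multiple of C4NUM, matching the 4-column tiling expected by AdderOpt.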
| int AdderCPUKernel::InitWeightBias() { | |||||
| auto filter_tensor = in_tensors_.at(kWeightIndex); | |||||
| int kernel_h = filter_tensor->Height(); | |||||
| int kernel_w = filter_tensor->Width(); | |||||
| int in_channel = filter_tensor->Channel(); | |||||
| int out_channel = filter_tensor->Batch(); | |||||
| conv_param_->input_channel_ = in_channel; | |||||
| conv_param_->output_channel_ = out_channel; | |||||
| int kernel_plane = kernel_h * kernel_w; | |||||
| const int oc_block = C4NUM; | |||||
| int oc_block_num = UP_DIV(out_channel, C4NUM); | |||||
| int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane; | |||||
| auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData()); | |||||
| packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float))); | |||||
| if (packed_weight_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc packed weight failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(packed_weight_, 0, pack_weight_size * sizeof(float)); | |||||
| RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); | |||||
| bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float))); | |||||
| if (bias_data_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc bias failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float)); | |||||
| if (in_tensors_.size() == kInputSize2) { | |||||
| auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData()); | |||||
| memcpy(bias_data_, ori_bias, out_channel * sizeof(float)); | |||||
| } else { | |||||
| MS_ASSERT(in_tensors_.size() == kInputSize1); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AdderCPUKernel::RunImpl(int task_id) { | |||||
| auto input_tensor = in_tensors_.at(kInputIndex); | |||||
| auto ori_input_data = reinterpret_cast<float *>(input_tensor->data_c()); | |||||
| auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data_c()); | |||||
| AdderFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_, | |||||
| output_addr, task_id, conv_param_); | |||||
| return RET_OK; | |||||
| } | |||||
| int AdderImpl(void *cdata, int task_id) { | |||||
| auto adder = reinterpret_cast<AdderCPUKernel *>(cdata); | |||||
| auto error_code = adder->RunImpl(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "Adder Run error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AdderCPUKernel::Run() { | |||||
| auto ret = InitTmpBuffer(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init tmp buffer failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, AdderImpl, this, thread_count_); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "adder error error_code[" << error_code << "]"; | |||||
| FreeTmpBuffer(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| FreeTmpBuffer(); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *CpuAdderFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, | |||||
| const InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| MS_ASSERT(op_parameter != nullptr); | |||||
| MS_ASSERT(desc.type == schema::PrimitiveType_Adder); | |||||
| MS_ASSERT(desc.data_type == kNumberTypeFloat32); | |||||
| kernel::LiteKernel *kernel = new (std::nothrow) kernel::AdderCPUKernel(op_parameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||||
| free(op_parameter); | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK && ret != RET_INFER_INVALID) { | |||||
| delete kernel; | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " | |||||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Adder, CpuAdderFp32KernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ | |||||
| #include <vector> | |||||
| #include "src/lite_kernel.h" | |||||
| #include "nnacl/op_base.h" | |||||
| #include "src/runtime/kernel/arm/fp32/convolution_fp32.h" | |||||
| #include "nnacl/fp32/conv_fp32.h" | |||||
| namespace mindspore::kernel { | |||||
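// AdderCPUKernel reuses ConvolutionCPUKernel's Init/ReSize and temporary buffers; it
// only overrides weight/bias packing and the per-task compute, which calls AdderFp32
// instead of the convolution GEMM.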
| class AdderCPUKernel : public ConvolutionCPUKernel { | |||||
| public: | |||||
| AdderCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||||
| const mindspore::lite::PrimitiveC *primitive) | |||||
| : ConvolutionCPUKernel(parameter, inputs, outputs, ctx, primitive) {} | |||||
| ~AdderCPUKernel() override = default; | |||||
| int InitWeightBias() override; | |||||
| int Run() override; | |||||
| int RunImpl(int task_id) override; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ADDER_H_ | |||||
| @@ -38,13 +38,13 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| } | } | ||||
| int Init() override; | int Init() override; | ||||
| virtual int InitWeightBias(); | |||||
| int InitTmpBuffer(); | |||||
| int ReSize() override; | int ReSize() override; | ||||
| int Run() override; | int Run() override; | ||||
| int RunImpl(int task_id); | |||||
| int InitWeightBias(); | |||||
| int InitTmpBuffer(); | |||||
| virtual int RunImpl(int task_id); | |||||
| private: | |||||
| protected: | |||||
| void FreeTmpBuffer() { | void FreeTmpBuffer() { | ||||
| if (packed_input_ != nullptr) { | if (packed_input_ != nullptr) { | ||||
| ctx_->allocator->Free(packed_input_); | ctx_->allocator->Free(packed_input_); | ||||
| @@ -55,8 +55,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { | |||||
| col_major_input_ = nullptr; | col_major_input_ = nullptr; | ||||
| } | } | ||||
| } | } | ||||
| float *packed_input_ = nullptr; | |||||
| protected: | |||||
| float *packed_weight_ = nullptr; | float *packed_weight_ = nullptr; | ||||
| float *packed_input_ = nullptr; | |||||
| float *col_major_input_ = nullptr; | float *col_major_input_ = nullptr; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||