From: @zhanyuan1
Reviewed-by: @zhang_xue_tong, @zhanghaibo5
Signed-off-by: @zhang_xue_tong
Tags: v1.1.0
@@ -0,0 +1,405 @@
+#ifdef ENABLE_ARM32
+
+    .text
+    .align 5
+    .global MatmulFloatNeon32Opt12x4
+#ifndef __APPLE__
+    .type MatmulFloatNeon32Opt12x4, %function
+#endif
+
+// void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+//                               int row, int col, size_t stride, size_t writeMode)
+// r0: a
+// r1: b
+// r2: c
+// r3: bias
+// r4: act_type
+// r5: depth
+// r6: row
+// r7: col
+// r8: stride
+// lr: OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2
+
+MatmulFloatNeon32Opt12x4:
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r0-r8, r10, r11, lr}
+    add sp, sp, #48
+
+    ldr r5, [sp, #4]
+    ldr r6, [sp, #8]
+    ldr r7, [sp, #12]
+    ldr r8, [sp, #16]
+
+    mov lr, #48         // sizeof(float) * 12
+    mul r12, r5, lr     // block stride of lhs: sizeof(float) * 12 * depth
+    mov lr, #4
+    mul r8, r8, lr      // stride * sizeof(float)
+
+LoopRow:
+    ldr r1, [sp, #-44]  // reload rhs ptr
+    ldr r7, [sp, #12]   // reload rhs col
+    ldr r3, [sp, #-36]  // reload bias ptr
+
+LoopCol:
+    ldr r2, [sp, #-40]  // reload dst ptr
+    ldr r0, [sp, #-48]  // reload lhs ptr
+    ldr r5, [sp, #4]    // reload depth
+    vld1.32 {q3}, [r1]!
+    vld1.32 {q0, q1}, [r0]!
+    vmul.f32 q4, q3, d0[0]
+    vmul.f32 q5, q3, d0[1]
+    vmul.f32 q6, q3, d1[0]
+    vld1.32 {q2}, [r0]!
+    vmul.f32 q7, q3, d1[1]
+    vmul.f32 q8, q3, d2[0]
+    vmul.f32 q9, q3, d2[1]
+    vmul.f32 q10, q3, d3[0]
+    vmul.f32 q11, q3, d3[1]
+    vmul.f32 q12, q3, d4[0]
+    vmul.f32 q13, q3, d4[1]
+    vmul.f32 q14, q3, d5[0]
+    vmul.f32 q15, q3, d5[1]
+
+    subs r5, r5, #1
+    beq Bias
+
+LoopDepth:
+    vld1.32 {q3}, [r1]!
+    vld1.32 {q0, q1}, [r0]!
+    vmla.f32 q4, q3, d0[0]
+    vmla.f32 q5, q3, d0[1]
+    vmla.f32 q6, q3, d1[0]
+    vld1.32 {q2}, [r0]!
+    vmla.f32 q7, q3, d1[1]
+    vmla.f32 q8, q3, d2[0]
+    vmla.f32 q9, q3, d2[1]
+    vmla.f32 q10, q3, d3[0]
+    vmla.f32 q11, q3, d3[1]
+    vmla.f32 q12, q3, d4[0]
+    vmla.f32 q13, q3, d4[1]
+    vmla.f32 q14, q3, d5[0]
+    vmla.f32 q15, q3, d5[1]
+
+    subs r5, r5, #1
+    bne LoopDepth
+
+Bias:
+    cmp r3, #0
+    beq Activation
+    vld1.32 {q0}, [r3]!
+    vadd.f32 q4, q4, q0
+    vadd.f32 q5, q5, q0
+    vadd.f32 q6, q6, q0
+    vadd.f32 q7, q7, q0
+    vadd.f32 q8, q8, q0
+    vadd.f32 q9, q9, q0
+    vadd.f32 q10, q10, q0
+    vadd.f32 q11, q11, q0
+    vadd.f32 q12, q12, q0
+    vadd.f32 q13, q13, q0
+    vadd.f32 q14, q14, q0
+    vadd.f32 q15, q15, q0
+
+Activation:
+    ldr lr, [sp]
+    cmp lr, #3
+    beq Relu6
+    cmp lr, #1
+    beq Relu
+    b Write
+
+Relu6:
+    vmov.i32 q2, #6
+    vcvt.f32.s32 q2, q2
+    vmin.f32 q4, q4, q2
+    vmin.f32 q5, q5, q2
+    vmin.f32 q6, q6, q2
+    vmin.f32 q7, q7, q2
+    vmin.f32 q8, q8, q2
+    vmin.f32 q9, q9, q2
+    vmin.f32 q10, q10, q2
+    vmin.f32 q11, q11, q2
+    vmin.f32 q12, q12, q2
+    vmin.f32 q13, q13, q2
+    vmin.f32 q14, q14, q2
+    vmin.f32 q15, q15, q2
+
+Relu:
+    veor q3, q3, q3
+    vmax.f32 q4, q4, q3
+    vmax.f32 q5, q5, q3
+    vmax.f32 q6, q6, q3
+    vmax.f32 q7, q7, q3
+    vmax.f32 q8, q8, q3
+    vmax.f32 q9, q9, q3
+    vmax.f32 q10, q10, q3
+    vmax.f32 q11, q11, q3
+    vmax.f32 q12, q12, q3
+    vmax.f32 q13, q13, q3
+    vmax.f32 q14, q14, q3
+    vmax.f32 q15, q15, q3
+
+Write:
+    cmp r7, #1
+    beq Write1
+    cmp r7, #2
+    beq Write2
+    cmp r7, #3
+    beq Write3
+    b Write4
+
+Write1:
+    add lr, r2, #4
+    str lr, [sp, #-40]
+    vst1.32 d8[0], [r2]
+    cmp r6, #1
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d10[0], [r2]
+    cmp r6, #2
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d12[0], [r2]
+    cmp r6, #3
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d14[0], [r2]
+    cmp r6, #4
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d16[0], [r2]
+    cmp r6, #5
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d18[0], [r2]
+    cmp r6, #6
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d20[0], [r2]
+    cmp r6, #7
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d22[0], [r2]
+    cmp r6, #8
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d24[0], [r2]
+    cmp r6, #9
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d26[0], [r2]
+    cmp r6, #10
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d28[0], [r2]
+    cmp r6, #11
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d30[0], [r2]
+    add r2, r2, r8
+    add r2, r2, #4
+    b WriteEnd
+
+Write2:
+    add lr, r2, #8
+    str lr, [sp, #-40]
+    vst1.32 d8, [r2]
+    cmp r6, #1
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d10, [r2]
+    cmp r6, #2
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d12, [r2]
+    cmp r6, #3
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d14, [r2]
+    cmp r6, #4
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d16, [r2]
+    cmp r6, #5
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d18, [r2]
+    cmp r6, #6
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d20, [r2]
+    cmp r6, #7
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d22, [r2]
+    cmp r6, #8
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d24, [r2]
+    cmp r6, #9
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d26, [r2]
+    cmp r6, #10
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d28, [r2]
+    cmp r6, #11
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 d30, [r2]
+    add r2, r2, r8
+    add r2, r2, #8
+    b WriteEnd
+
+Write3:
+    add lr, r2, #12
+    str lr, [sp, #-40]
+    add r4, r2, #8
+    vst1.32 d8, [r2]
+    vst1.32 d9[0], [r4]
+    cmp r6, #1
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d10, [r2]
+    vst1.32 d11[0], [r4]
+    cmp r6, #2
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d12, [r2]
+    vst1.32 d13[0], [r4]
+    cmp r6, #3
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d14, [r2]
+    vst1.32 d15[0], [r4]
+    cmp r6, #4
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d16, [r2]
+    vst1.32 d17[0], [r4]
+    cmp r6, #5
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d18, [r2]
+    vst1.32 d19[0], [r4]
+    cmp r6, #6
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d20, [r2]
+    vst1.32 d21[0], [r4]
+    cmp r6, #7
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d22, [r2]
+    vst1.32 d23[0], [r4]
+    cmp r6, #8
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d24, [r2]
+    vst1.32 d25[0], [r4]
+    cmp r6, #9
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d26, [r2]
+    vst1.32 d27[0], [r4]
+    cmp r6, #10
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d28, [r2]
+    vst1.32 d29[0], [r4]
+    cmp r6, #11
+    beq WriteEnd
+    add r2, r2, r8
+    add r4, r4, r8
+    vst1.32 d30, [r2]
+    vst1.32 d31[0], [r4]
+    add r2, r2, r8
+    add r2, r2, #12
+    b WriteEnd
+
+Write4:
+    add lr, r2, #16
+    str lr, [sp, #-40]
+    vst1.32 q4, [r2]
+    cmp r6, #1
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q5, [r2]
+    cmp r6, #2
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q6, [r2]
+    cmp r6, #3
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q7, [r2]
+    cmp r6, #4
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q8, [r2]
+    cmp r6, #5
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q9, [r2]
+    cmp r6, #6
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q10, [r2]
+    cmp r6, #7
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q11, [r2]
+    cmp r6, #8
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q12, [r2]
+    cmp r6, #9
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q13, [r2]
+    cmp r6, #10
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q14, [r2]
+    cmp r6, #11
+    beq WriteEnd
+    add r2, r2, r8
+    vst1.32 q15, [r2]
+    add r2, r2, r8
+    add r2, r2, #16
+    b WriteEnd
+
+WriteEnd:
+    cmp r7, #4
+    ble LoopColEnd
+    sub r7, r7, #4      // rhs col - 4
+    b LoopCol
+
+LoopColEnd:
+    ldr r0, [sp, #-48]
+    add r0, r0, r12     // lhs ptr + stride
+    str r0, [sp, #-48]
+    mov lr, #4
+    ldr r7, [sp, #12]   // reload rhs col
+    mul lr, lr, r7
+    sub r2, r2, lr
+    str r2, [sp, #-40]
+    cmp r6, #12
+    ble LoopRowEnd
+    sub r6, r6, #12     // lhs row - 12
+    b LoopRow
+
+LoopRowEnd:
+    sub sp, sp, #48
+    pop {r0-r8, r10, r11, pc}
+#endif
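For reference, a plain-C sketch (not part of the patch) of what the kernel above computes on its NHWC write path, under the packed layouts implied by its load pattern: a is stored in 12-row, depth-major blocks, b in 4-column, depth-major blocks, and the output is written row-major with stride floats between rows. The helper name and the assumption that padding rows/columns are pre-zeroed are illustrative.

#include <stddef.h>

// Scalar reference for MatmulFloatNeon32Opt12x4 semantics (write_mode == 1, NHWC).
// act_type: 1 = Relu, 3 = Relu6 (constants taken from the branches in the assembly).
void MatmulFloat12x4Ref(const float *a, const float *b, float *c, const float *bias,
                        int act_type, int depth, int row, int col, size_t stride) {
  for (int r = 0; r < row; ++r) {
    // a packed in 12-row blocks: element (r, d) sits at block*12*depth + d*12 + r%12.
    const float *a_blk = a + (r / 12) * 12 * depth;
    for (int ci = 0; ci < col; ++ci) {
      // b packed in 4-column blocks: element (d, ci) sits at block*4*depth + d*4 + ci%4.
      const float *b_blk = b + (ci / 4) * 4 * depth;
      float acc = 0.0f;
      for (int d = 0; d < depth; ++d) {
        acc += a_blk[d * 12 + (r % 12)] * b_blk[d * 4 + (ci % 4)];
      }
      if (bias != NULL) acc += bias[ci];
      if (act_type == 3 && acc > 6.0f) acc = 6.0f;                      // Relu6 clamps to [0, 6]
      if ((act_type == 1 || act_type == 3) && acc < 0.0f) acc = 0.0f;   // Relu / Relu6 lower bound
      c[r * stride + ci] = acc;
    }
  }
}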
@@ -28,7 +28,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
 int output_count = conv_param->output_h_ * conv_param->output_w_;
 #ifdef ENABLE_AVX
 const int cal_num = C6NUM;
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 const int cal_num = C4NUM;
 #else
 const int cal_num = C12NUM;
@@ -52,7 +52,7 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
 float *gemm_output = output_data + out_offset;
 #ifdef ENABLE_AVX
 RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep);
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
 #else
 RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
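The convolution and matmul paths in this patch pack their operands with RowMajor2Col4Major / RowMajor2Col12Major before calling the NEON kernels. As a hedged reference (the real nnacl routines also zero-pad the tail block and have vectorized fast paths), the 4-row variant amounts to:

// Pack a row-major (row x col) matrix into blocks of 4 rows, column-major inside
// each block, so each depth step of the kernel can load 4 contiguous floats:
//   dst[(r / 4) * col * 4 + c * 4 + (r % 4)] = src[r * col + c]
// Simplified sketch; tail rows up to the next multiple of 4 are assumed pre-zeroed.
void RowMajor2Col4MajorRef(const float *src, float *dst, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      dst[(r / 4) * col * 4 + c * 4 + (r % 4)] = src[r * col + c];
    }
  }
}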
@@ -874,6 +874,8 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
 #elif ENABLE_ARM32
 if (out_type == OutType_C8) {
 MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
+} else if (out_type == OutType_Nhwc) {
+MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
 } else {
 MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
 }
@@ -53,6 +53,8 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi
 int col, int stride, size_t writeNhwc, size_t WriteWino);
 void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
 int col, int stride, int write_mode);
+void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
+int row, int col, int stride, int write_mode);
 #elif ENABLE_SSE
 #include <x86intrin.h>
 void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
@@ -544,7 +544,11 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa
 }
 }
 if (incremental) output = fworkspace;
+#ifdef ENABLE_ARM32
+MatmulFloatNeon32Opt(mat_a_input, mat_b_input, output, gcb->bias, (int)gcb->atype, K, M, N, ldc, 1);
+#else
 MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc);
+#endif
 if (incremental) addv(output, mat_c, beta, M, N, ldc);
 gcb->mat_a = mat_a_input;
 gcb->mat_b = mat_b_input;
@@ -46,10 +46,12 @@ typedef struct MatMulParameter {
 int row_8_;
 int row_12_;
 int row_16_;
+int row_align_;
 int col_2_;
 int col_4_;
 int col_8_;
 int col_16_;
+int col_align_;
 int deep_;
 int deep_4_;
 int deep_16_;
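The new row_align_/col_align_ fields generalize the old per-tile row_4_/row_6_/row_12_ and col_8_ members: they hold the row and column counts rounded up to whatever tile the current build uses. A small illustrative helper (not in the patch) for the ARM32 case, where the new kernel consumes 12-row blocks of A and 4-column blocks of B, assuming the usual nnacl UP_ROUND definition:

#define UP_ROUND(x, y) ((((x) + (y) - 1) / (y)) * (y))  // round x up to a multiple of y (assumed definition)

void FillAlignedDimsArm32(int row, int col, int *row_align, int *col_align) {
  *row_align = UP_ROUND(row, 12);  // C12NUM: lhs packed in 12-row blocks
  *col_align = UP_ROUND(col, 4);   // C4NUM: rhs packed in 4-column blocks
}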
@@ -74,6 +74,8 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
 #ifdef ENABLE_AVX
 int col_tile = C16NUM;
+#elif defined(ENABLE_ARM32)
+int col_tile = C4NUM;
 #else
 int col_tile = C8NUM;
 #endif
@@ -100,6 +102,9 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
 #ifdef ENABLE_AVX
 RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
 input_channel);
+#elif defined(ENABLE_ARM32)
+RowMajor2Col4Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
+input_channel);
 #else
 RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
 input_channel);
@@ -111,7 +116,7 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
 int hw_tile = C12NUM;
 #ifdef ENABLE_AVX
 hw_tile = C6NUM;
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 hw_tile = C4NUM;
 #endif
 if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
@@ -121,6 +126,8 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
 } else {
 #ifdef ENABLE_AVX
 int col_tile = C16NUM;
+#elif defined(ENABLE_ARM32)
+int col_tile = C4NUM;
 #else
 int col_tile = C8NUM;
 #endif
@@ -195,7 +202,7 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {
 #if ENABLE_AVX
 RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
 #else
 RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
@@ -225,7 +232,7 @@ int Convolution1x1CPUKernel::Run() {
 #ifdef ENABLE_AVX
 pack_input_ =
 reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float)));
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 pack_input_ =
 reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
 #else
@@ -251,7 +258,7 @@ int Convolution1x1CPUKernel::Run() {
 } else {
 #ifdef ENABLE_AVX
 RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
 #else
 RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
@@ -44,6 +44,8 @@ int ConvolutionCPUKernel::InitWeightBias() {
 int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
 #ifdef ENABLE_AVX
 const int oc_block = C16NUM;
+#elif ENABLE_ARM32
+const int oc_block = C4NUM;
 #else
 const int oc_block = C8NUM;
 #endif
@@ -59,6 +61,8 @@ int ConvolutionCPUKernel::InitWeightBias() {
 memset(packed_weight_, 0, pack_weight_size * sizeof(float));
 #ifdef ENABLE_AVX
 RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
+#elif ENABLE_ARM32
+RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
 #else
 RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
 #endif
@@ -84,7 +88,7 @@ int ConvolutionCPUKernel::InitTmpBuffer() {
 #ifdef ENABLE_AVX
 int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_;
-#elif ENABLE_ARM32 || ENABLE_SSE
+#elif ENABLE_SSE
 int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_;
 #else
 int unit_size =
@@ -53,16 +53,18 @@ int FullconnectionCPUKernel::ReSize() {
 #ifdef ENABLE_AVX
 int col_tile = C16NUM;
+#elif defined(ENABLE_ARM32)
+int col_tile = C4NUM;
 #else
 int col_tile = C8NUM;
 #endif
 fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
-fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
+fc_param_->col_align_ = UP_ROUND(fc_param_->col_, col_tile);
 fc_param_->row_6_ = UP_ROUND(fc_param_->col_, C6NUM);
 fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
-thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
-thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);
+thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_align_, col_tile));
+thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_align_, col_tile), thread_count_);
 #ifdef ENABLE_ARM
 if (fc_param_->row_ == 1) {
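The thread partitioning above splits the aligned column range into col_tile-wide strips and hands each thread a contiguous run of strips. A standalone sketch with the usual UP_DIV/UP_ROUND/MSMIN definitions (assumed here) and a small numeric example:

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int col = 10, col_tile = 4, thread_num = 4;        // e.g. ARM32: col_tile = C4NUM
  int col_align = UP_ROUND(col, col_tile);           // 12
  int strips = UP_DIV(col_align, col_tile);          // 3 strips of 4 output columns
  int thread_count = MSMIN(thread_num, strips);      // 3: more threads than strips is pointless
  int thread_stride = UP_DIV(strips, thread_count);  // 1 strip per thread
  printf("col_align=%d strips=%d thread_count=%d thread_stride=%d\n",
         col_align, strips, thread_count, thread_stride);
  return 0;
}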
@@ -72,7 +74,7 @@ int FullconnectionCPUKernel::ReSize() {
 }
 #endif
 if (in_tensors_.size() == 3) {
-int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_8_;
+int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_align_;
 bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
 if (bias_ptr_ == nullptr) {
 MS_LOG(ERROR) << "malloc bias_ptr_ failed";
@@ -83,7 +85,7 @@ int FullconnectionCPUKernel::ReSize() {
 #ifdef ENABLE_AVX
 int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_;
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_;
 #else
 int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_;
@@ -94,7 +96,7 @@ int FullconnectionCPUKernel::ReSize() {
 }
 memset(a_pack_ptr_, 0, row_tmp * fc_param_->deep_ * sizeof(float));
-int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_8_;
+int col_tmp = is_vector_input_ ? fc_param_->col_ : fc_param_->col_align_;
 b_pack_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * fc_param_->deep_ * sizeof(float)));
 if (b_pack_ptr_ == nullptr) {
 FreeBuf();
@@ -130,7 +132,7 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr)
 #ifdef ENABLE_AVX
 RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
+#elif defined(ENABLE_SSE)
 RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
 #else
 RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
@@ -144,6 +146,8 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr)
 }
 #ifdef ENABLE_AVX
 RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
+#elif defined(ENABLE_ARM32)
+RowMajor2Col4Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
 #else
 RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
 #endif
@@ -162,6 +166,8 @@ int FcFp32MatmulRun(void *cdata, int task_id) {
 int FullconnectionCPUKernel::DoMatmul(int task_id) {
 #ifdef ENABLE_AVX
 int col_tile = C16NUM;
+#elif defined(ENABLE_ARM32)
+int col_tile = C4NUM;
 #else
 int col_tile = C8NUM;
 #endif
@@ -74,17 +74,15 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
 }
 #endif
 params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
-params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
-params_->row_6_ = UP_ROUND(params_->row_, C6NUM);
-params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
 #ifdef ENABLE_AVX
-int row_tmp = is_vector_a_ ? 1 : params_->row_6_;
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
-int row_tmp = is_vector_a_ ? 1 : params_->row_4_;
+params_->row_align_ = UP_ROUND(params_->row_, C6NUM);
+#elif defined(ENABLE_SSE)
+params_->row_align_ = UP_ROUND(params_->row_, C4NUM);
 #else
-int row_tmp = is_vector_a_ ? 1 : params_->row_12_;
+params_->row_align_ = UP_ROUND(params_->row_, C12NUM);
 #endif
+int row_tmp = is_vector_a_ ? 1 : params_->row_align_;
 if (params_->a_const_) {
 a_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * row_tmp * params_->deep_ * sizeof(float)));
 } else {
@@ -109,17 +107,12 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
 for (size_t i = 0; i < b_shape.size() - 2; ++i) {
 batch *= b_shape[i];
 }
-#ifdef ENABLE_AVX
-int col_tile = C16NUM;
-#else
-int col_tile = C8NUM;
-#endif
 params_->batch = batch;
 params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
-params_->col_8_ = UP_ROUND(params_->col_, col_tile);
+params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
 params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];
-int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
+int col_tmp = is_vector_a_ ? params_->col_ : params_->col_align_;
 if (params_->b_const_) {
 b_pack_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * col_tmp * params_->deep_ * sizeof(float)));
 } else {
@@ -131,8 +124,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
 return RET_MEMORY_FAILED;
 }
-thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile));
-thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_);
+thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_align_, col_tile_));
+thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
 return RET_OK;
 }
@@ -142,13 +135,8 @@ int MatmulCPUKernel::InitBias() {
 params_->col_ = params_->b_const_
 ? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1))
 : (c_shape.at(c_shape.size() - 1));
-#ifdef ENABLE_AVX
-int col_tile = C16NUM;
-#else
-int col_tile = C8NUM;
-#endif
-params_->col_8_ = UP_ROUND(params_->col_, col_tile);
-auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
+params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
+auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_align_;
 if (bias_ptr_ == nullptr) {
 bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
 if (bias_ptr_ == nullptr) {
@@ -184,22 +172,20 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {
 for (int i = 0; i < params_->batch; i++) {
 const float *src = src_ptr + i * params_->deep_ * params_->row_;
+float *dst = dst_ptr + i * params_->deep_ * params_->row_align_;
 #ifdef ENABLE_AVX
-float *dst = dst_ptr + i * params_->deep_ * params_->row_6_;
 if (params_->a_transpose_) {
 RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
 } else {
 RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
 }
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
-float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
+#elif defined(ENABLE_SSE)
 if (params_->a_transpose_) {
 RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
 } else {
 RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
 }
 #else
-float *dst = dst_ptr + i * params_->deep_ * params_->row_12_;
 if (params_->a_transpose_) {
 RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
 } else {
@@ -226,13 +212,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
 for (int i = 0; i < params_->batch; i++) {
 const float *src = src_ptr + i * params_->deep_ * params_->col_;
-float *dst = dst_ptr + i * params_->deep_ * params_->col_8_;
+float *dst = dst_ptr + i * params_->deep_ * params_->col_align_;
 #ifdef ENABLE_AVX
 if (params_->b_transpose_) {
 RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
 } else {
 RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
 }
+#elif defined(ENABLE_ARM32)
+if (params_->b_transpose_) {
+RowMajor2Col4Major(src, dst, params_->col_, params_->deep_);
+} else {
+RowMajor2Row4Major(src, dst, params_->deep_, params_->col_);
+}
 #else
 if (params_->b_transpose_) {
 RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
@@ -245,6 +237,13 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
 }
 int MatmulCPUKernel::Init() {
+#ifdef ENABLE_AVX
+col_tile_ = C16NUM;
+#elif defined(ENABLE_ARM32)
+col_tile_ = C4NUM;
+#else
+col_tile_ = C8NUM;
+#endif
 params_->a_const_ = (in_tensors_.at(0)->data_c() != nullptr);
 params_->b_const_ = (in_tensors_.at(1)->data_c() != nullptr);
 if (params_->a_const_) {
@@ -275,18 +274,13 @@ int MatmulCPUKernel::Init() {
 }
 int MatmulCPUKernel::RunImpl(int task_id) {
-#ifdef ENABLE_AVX
-int col_tile = C16NUM;
-#else
-int col_tile = C8NUM;
-#endif
-int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile);
+int cur_oc = MSMIN(thread_stride_ * col_tile_, params_->col_ - task_id * thread_stride_ * col_tile_);
 if (cur_oc <= 0) {
 return RET_OK;
 }
-auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_;
-auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile;
-auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL;
+auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile_ * params_->deep_;
+auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile_;
+auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile_ : NULL;
 MS_ASSERT(cur_a_ptr_);
 MS_ASSERT(b);
 MS_ASSERT(c);
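Because B is packed in col_tile_-wide blocks of deep_ floats each, a task's starting column translates into start_col * deep_ floats of offset in the packed buffer, while C and the bias are offset by start_col directly; cur_oc clamps the last task to the real, unpadded column count. An illustrative restatement of that addressing (hypothetical helper, mirroring the arithmetic above):

#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

void SliceForTask(int task_id, int thread_stride, int col_tile, int col, int deep,
                  const float *b_pack, float *c, const float *bias,
                  const float **b_slice, float **c_slice, const float **bias_slice, int *cur_oc) {
  int start_col = task_id * thread_stride * col_tile;
  *cur_oc = MSMIN(thread_stride * col_tile, col - start_col);  // <= 0 means nothing left for this task
  *b_slice = b_pack + start_col * deep;  // skip start_col packed columns (start_col is a multiple of col_tile)
  *c_slice = c + start_col;              // row-major C: this strip's columns begin at start_col
  *bias_slice = (bias != 0) ? bias + start_col : 0;
}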
@@ -356,14 +350,8 @@ int MatmulCPUKernel::Run() {
 cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
 cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
 } else {
-#ifdef ENABLE_AVX
-cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_;
-#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
-cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_;
-#else
-cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;
-#endif
-cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_8_;
+cur_a_ptr_ = a_ptr_ + i * params_->row_align_ * params_->deep_;
+cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_align_;
 cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
 }
 auto ret = ParallelLaunch(this->context_->thread_pool_, MatmulFloatRun, this, thread_count_);
@@ -54,6 +54,7 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
 float *cur_b_ptr_ = nullptr;
 float *cur_c_ptr_ = nullptr;
 bool is_vector_a_ = false;
+int col_tile_ = 0;
 };
 } // namespace mindspore::kernel