| @@ -0,0 +1,367 @@ | |||||
| #ifdef ENABLE_ARM32 | |||||
| .text | |||||
| .align 5 | |||||
| .global MatmulFloatNeon32Opt | |||||
| #ifndef __APPLE__ | |||||
| .type MatmulFloatNeon32Opt, %function | |||||
| #endif | |||||
| // void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth | |||||
| // int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) | |||||
| // r0: a | |||||
| // r1: b | |||||
| // r2: c | |||||
| // r3: bias | |||||
| // r4: act_type | |||||
| // r5: depth | |||||
| // r6: row | |||||
| // r7: col | |||||
| // r8: stride | |||||
| // lr: writeNhwc/writeWino | |||||
| MatmulFloatNeon32Opt: | |||||
| // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf | |||||
| push {r0-r8, r10, r11, lr} | |||||
| add sp, sp, #48 | |||||
| ldr r5, [sp, #4] | |||||
| ldr r7, [sp, #12] | |||||
| ldr r8, [sp, #16] | |||||
| mov lr, #32 // sizeof(float) * 8 | |||||
| mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 8 * depth | |||||
| ldr lr, [sp, #24] | |||||
| cmp lr, #0 | |||||
| beq NoWinoSteps | |||||
| mov lr, #4 | |||||
| mul r11, r7, r8 // stride * col * sizeof(float) | |||||
| mul r11, r11, lr | |||||
| mov lr, #32 | |||||
| mul r10, r8, lr // stride * 8 * sizeof(float) | |||||
| NoWinoSteps: | |||||
| mov lr, #4 | |||||
| mul r8, r8, lr // stride * sizeof(float) | |||||
| LoopCol: | |||||
| ldr r6, [sp, #8] // reload lhs row | |||||
| ldr r0, [sp, #-48] // reload lhs ptr | |||||
| ldr r2, [sp, #-40] // reload dst ptr | |||||
| LoopRow: | |||||
| ldr r1, [sp, #-44] // reload rhs ptr | |||||
| ldr r5, [sp, #4] // reload depth | |||||
| veor q8, q8, q8 | |||||
| veor q9, q9, q9 | |||||
| veor q10, q10, q10 | |||||
| veor q11, q11, q11 | |||||
| veor q12, q12, q12 | |||||
| veor q13, q13, q13 | |||||
| veor q14, q14, q14 | |||||
| veor q15, q15, q15 | |||||
| LoopDepth: | |||||
| vld1.32 {q0}, [r0]! | |||||
| vld1.32 {q1, q2}, [r1]! | |||||
| vmla.f32 q8, q1, d0[0] | |||||
| vmla.f32 q9, q2, d0[0] | |||||
| vmla.f32 q10, q1, d0[1] | |||||
| vmla.f32 q11, q2, d0[1] | |||||
| vmla.f32 q12, q1, d1[0] | |||||
| vmla.f32 q13, q2, d1[0] | |||||
| vmla.f32 q14, q1, d1[1] | |||||
| vmla.f32 q15, q2, d1[1] | |||||
| subs r5, r5, #1 | |||||
| bne LoopDepth | |||||
| Bias: | |||||
| cmp r3, #0 | |||||
| beq Activation | |||||
| vld1.32 {q0}, [r3]! | |||||
| vld1.32 {q1}, [r3] | |||||
| sub r3, r3, #16 | |||||
| vadd.f32 q8, q8, q0 | |||||
| vadd.f32 q9, q9, q1 | |||||
| vadd.f32 q10, q10, q0 | |||||
| vadd.f32 q11, q11, q1 | |||||
| vadd.f32 q12, q12, q0 | |||||
| vadd.f32 q13, q13, q1 | |||||
| vadd.f32 q14, q14, q0 | |||||
| vadd.f32 q15, q15, q1 | |||||
| Activation: | |||||
| ldr lr, [sp] | |||||
| cmp lr, #2 | |||||
| beq Relu6 | |||||
| cmp lr, #1 | |||||
| beq Relu | |||||
| b Write | |||||
| Relu6: | |||||
| vmov.i32 q2, #6 | |||||
| vcvt.f32.s32 q2, q2 | |||||
| vmin.f32 q8, q8, q2 | |||||
| vmin.f32 q9, q9, q2 | |||||
| vmin.f32 q10, q10, q2 | |||||
| vmin.f32 q11, q11, q2 | |||||
| vmin.f32 q12, q12, q2 | |||||
| vmin.f32 q13, q13, q2 | |||||
| vmin.f32 q14, q14, q2 | |||||
| vmin.f32 q15, q15, q2 | |||||
| Relu: | |||||
| veor q3, q3, q3 | |||||
| vmax.f32 q8, q8, q3 | |||||
| vmax.f32 q9, q9, q3 | |||||
| vmax.f32 q10, q10, q3 | |||||
| vmax.f32 q11, q11, q3 | |||||
| vmax.f32 q12, q12, q3 | |||||
| vmax.f32 q13, q13, q3 | |||||
| vmax.f32 q14, q14, q3 | |||||
| vmax.f32 q15, q15, q3 | |||||
| Write: | |||||
| ldr lr, [sp, #24] | |||||
| cmp lr, #0 | |||||
| bne WriteWino | |||||
| ldr lr, [sp, #20] | |||||
| cmp lr, #0 | |||||
| beq WriteC8 | |||||
| cmp r7, #1 | |||||
| beq Write1 | |||||
| cmp r7, #2 | |||||
| beq Write2 | |||||
| cmp r7, #3 | |||||
| beq Write3 | |||||
| cmp r7, #4 | |||||
| beq Write4 | |||||
| cmp r7, #5 | |||||
| beq Write5 | |||||
| cmp r7, #6 | |||||
| beq Write6 | |||||
| cmp r7, #7 | |||||
| beq Write7 | |||||
| b Write8 | |||||
| Write1: | |||||
| vst1.32 d16[0], [r2] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d20[0], [r2] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d24[0], [r2] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d28[0], [r2] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write2: | |||||
| vst1.32 d16, [r2] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d20, [r2] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d24, [r2] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 d28, [r2] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write3: | |||||
| add r4, r2, #8 | |||||
| vst1.32 d16, [r2] | |||||
| vst1.32 d17[0], [r4] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 d20, [r2] | |||||
| vst1.32 d21[0], [r4] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 d24, [r2] | |||||
| vst1.32 d25[0], [r4] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 d28, [r2] | |||||
| vst1.32 d29[0], [r4] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write4: | |||||
| vst1.32 q8, [r2] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 q10, [r2] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 q12, [r2] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 q14, [r2] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write5: | |||||
| add r4, r2, #16 | |||||
| vst1.32 q8, [r2] | |||||
| vst1.32 d18[0], [r4] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q10, [r2] | |||||
| vst1.32 d22[0], [r4] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q12, [r2] | |||||
| vst1.32 d26[0], [r4] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q14, [r2] | |||||
| vst1.32 d30[0], [r4] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write6: | |||||
| add r4, r2, #16 | |||||
| vst1.32 q8, [r2] | |||||
| vst1.32 d18, [r4] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q10, [r2] | |||||
| vst1.32 d22, [r4] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q12, [r2] | |||||
| vst1.32 d26, [r4] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| vst1.32 q14, [r2] | |||||
| vst1.32 d30, [r4] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| Write7: | |||||
| add lr, r2, #24 | |||||
| add r4, r2, #16 | |||||
| vst1.32 q8, [r2] | |||||
| vst1.32 d18, [r4] | |||||
| vst1.32 d19[0], [lr] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| add lr, lr, r8 | |||||
| vst1.32 q10, [r2] | |||||
| vst1.32 d22, [r4] | |||||
| vst1.32 d23[0], [lr] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| add lr, lr, r8 | |||||
| vst1.32 q12, [r2] | |||||
| vst1.32 d26, [r4] | |||||
| vst1.32 d27[0], [lr] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| add r4, r4, r8 | |||||
| add lr, lr, r8 | |||||
| vst1.32 q14, [r2] | |||||
| vst1.32 d30, [r4] | |||||
| vst1.32 d31[0], [lr] | |||||
| add r2, r2, r8 | |||||
| b WriteEnd | |||||
| WriteC8: | |||||
| vst1.32 {q8, q9}, [r2]! | |||||
| vst1.32 {q10, q11}, [r2]! | |||||
| vst1.32 {q12, q13}, [r2]! | |||||
| vst1.32 {q14, q15}, [r2]! | |||||
| b WriteEnd | |||||
| WriteWino: | |||||
| vst1.32 {q8, q9}, [r2] | |||||
| add r2, r2, r11 | |||||
| vst1.32 {q10, q11}, [r2] | |||||
| add r2, r2, r11 | |||||
| vst1.32 {q12, q13}, [r2] | |||||
| add r2, r2, r11 | |||||
| vst1.32 {q14, q15}, [r2] | |||||
| add r2, r2, r11 | |||||
| b WriteEnd | |||||
| Write8: | |||||
| vst1.32 {q8, q9}, [r2] | |||||
| cmp r6, #1 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 {q10, q11}, [r2] | |||||
| cmp r6, #2 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 {q12, q13}, [r2] | |||||
| cmp r6, #3 | |||||
| beq WriteEnd | |||||
| add r2, r2, r8 | |||||
| vst1.32 {q14, q15}, [r2] | |||||
| add r2, r2, r8 | |||||
| WriteEnd: | |||||
| cmp r6, #4 | |||||
| ble LoopRowEnd | |||||
| sub r6, r6, #4 // lhs row - 4 | |||||
| b LoopRow | |||||
| LoopRowEnd: | |||||
| ldr r1, [sp, #-44] | |||||
| add r1, r1, r12 // rhs ptr + stride | |||||
| str r1, [sp, #-44] | |||||
| cmp r3, #0 | |||||
| beq NoBiasStep | |||||
| add r3, r3, #32 // bias ptr + stride | |||||
| NoBiasStep: | |||||
| ldr lr, [sp, #24] | |||||
| cmp lr, #0 | |||||
| bne WinoDstStep | |||||
| ldr lr, [sp, #20] | |||||
| cmp lr, #0 | |||||
| beq NoDstStep | |||||
| ldr r2, [sp, #-40] | |||||
| add r2, r2, #32 // dst ptr + stride | |||||
| str r2, [sp, #-40] | |||||
| b NoDstStep | |||||
| WinoDstStep: | |||||
| ldr r2, [sp, #-40] | |||||
| add r2, r2, r10 | |||||
| str r2, [sp, #-40] | |||||
| NoDstStep: | |||||
| cmp r7, #8 | |||||
| ble LoopColEnd | |||||
| sub r7, r7, #8 // rhs col - 8 | |||||
| b LoopCol | |||||
| LoopColEnd: | |||||
| sub sp, sp, #48 | |||||
| pop {r0-r8, r10, r11, pc} | |||||
| #endif | |||||
| @@ -112,7 +112,8 @@ void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| // #ifndef ENABLE_ARM32 | |||||
| #ifndef ENABLE_ARM32 | |||||
| void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, | void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step, | ||||
| size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, | size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu, | ||||
| size_t relu6) { | size_t relu6) { | ||||
| @@ -155,7 +156,7 @@ void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // #endif | |||||
| #endif | |||||
| int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } | int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } | ||||
| @@ -270,7 +270,12 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||||
| int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | ||||
| int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | ||||
| int output_count = out_w_block * out_h_block; | int output_count = out_w_block * out_h_block; | ||||
| int output_tile_count = UP_DIV(output_count, C12NUM); | |||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| int output_tile_count = UP_DIV(output_count, tile_num); | |||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| int oc4 = UP_DIV(out_channel, C4NUM); | int oc4 = UP_DIV(out_channel, C4NUM); | ||||
| int oc8 = UP_DIV(out_channel, C8NUM); | int oc8 = UP_DIV(out_channel, C8NUM); | ||||
| @@ -281,19 +286,19 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||||
| float *tmp_out_data = buffer_list[2]; | float *tmp_out_data = buffer_list[2]; | ||||
| float *tmp_data = buffer_list[3]; | float *tmp_data = buffer_list[3]; | ||||
| float *col_buffer = buffer_list[4]; | float *col_buffer = buffer_list[4]; | ||||
| int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM; | |||||
| int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM; | |||||
| int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM; | |||||
| int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; | |||||
| int tmp_data_offset = input_unit_square * C4NUM; | int tmp_data_offset = input_unit_square * C4NUM; | ||||
| int col_buffer_offset = C12NUM * ic4 * C4NUM; | |||||
| int col_buffer_offset = tile_num * ic4 * C4NUM; | |||||
| // step 1 : filter transform (pre-processed offline) | // step 1 : filter transform (pre-processed offline) | ||||
| // step 2 : input transform (online) | // step 2 : input transform (online) | ||||
| for (int b = 0; b < in_batch; b++) { | for (int b = 0; b < in_batch; b++) { | ||||
| int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_; | ||||
| int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM; | int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM; | ||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | ||||
| int out_tile_index = thread_id * C12NUM; | |||||
| int cal_num = output_count - thread_id * C12NUM; | |||||
| cal_num = cal_num > C12NUM ? C12NUM : cal_num; | |||||
| int out_tile_index = thread_id * tile_num; | |||||
| int cal_num = output_count - thread_id * tile_num; | |||||
| cal_num = cal_num > tile_num ? tile_num : cal_num; | |||||
| WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, | WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, | ||||
| tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | ||||
| in_func); | in_func); | ||||
| @@ -302,7 +307,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ | |||||
| float *dst_ptr = gemm_out + task_id * gemm_out_offset; | float *dst_ptr = gemm_out + task_id * gemm_out_offset; | ||||
| float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | ||||
| for (int i = 0; i < input_unit_square; ++i) { | for (int i = 0; i < input_unit_square; ++i) { | ||||
| #ifdef ENABLE_ARM32 | |||||
| RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM); | |||||
| #else | |||||
| RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); | RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); | ||||
| #endif | |||||
| MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM, | MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM, | ||||
| cal_num, oc8 * C8NUM, input_unit_square, 2); | cal_num, oc8 * C8NUM, input_unit_square, 2); | ||||
| } | } | ||||
| @@ -460,7 +469,12 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||||
| int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); | int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); | ||||
| int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); | int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); | ||||
| int output_count = out_w_block * out_h_block; | int output_count = out_w_block * out_h_block; | ||||
| int output_tile_count = UP_DIV(output_count, C12NUM); | |||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| int output_tile_count = UP_DIV(output_count, tile_num); | |||||
| const int input_unit_square = 4 * 4; | const int input_unit_square = 4 * 4; | ||||
| float *tile_buffer = buffer_list[0]; | float *tile_buffer = buffer_list[0]; | ||||
| @@ -468,10 +482,10 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||||
| float *tmp_dst_buffer = buffer_list[2]; | float *tmp_dst_buffer = buffer_list[2]; | ||||
| float *nc4hw4_out = buffer_list[3]; | float *nc4hw4_out = buffer_list[3]; | ||||
| float *col_buffer = buffer_list[4]; | float *col_buffer = buffer_list[4]; | ||||
| int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM; | |||||
| int tile_buffer_offset = tile_num * input_unit_square * ic4 * C4NUM; | |||||
| int block_unit_buffer_offset = input_unit_square * C4NUM; | int block_unit_buffer_offset = input_unit_square * C4NUM; | ||||
| int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM; | |||||
| int col_buffer_offset = C12NUM * ic4 * C4NUM; | |||||
| int tmp_dst_buffer_offset = tile_num * input_unit_square * oc8 * C8NUM; | |||||
| int col_buffer_offset = tile_num * ic4 * C4NUM; | |||||
| int input_batch = conv_param->input_batch_; | int input_batch = conv_param->input_batch_; | ||||
| for (int batch = 0; batch < input_batch; batch++) { | for (int batch = 0; batch < input_batch; batch++) { | ||||
| @@ -479,8 +493,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||||
| int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_; | int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_; | ||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { | ||||
| int start_index = thread_id * C12NUM; | |||||
| int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM; | |||||
| int start_index = thread_id * tile_num; | |||||
| int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; | |||||
| Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, | ||||
| block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, | ||||
| out_w_block, conv_param); | out_w_block, conv_param); | ||||
| @@ -489,7 +503,11 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat | |||||
| float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | ||||
| float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset; | float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset; | ||||
| for (int i = 0; i < input_unit_square; ++i) { | for (int i = 0; i < input_unit_square; ++i) { | ||||
| #ifdef ENABLE_ARM32 | |||||
| RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM); | |||||
| #else | |||||
| RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); | RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); | ||||
| #endif | |||||
| MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, | MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, | ||||
| ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2); | ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2); | ||||
| } | } | ||||
| @@ -40,7 +40,12 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float * | |||||
| size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; | size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; | ||||
| size_t output_plane = conv_param->output_w_ * conv_param->output_h_; | size_t output_plane = conv_param->output_w_ * conv_param->output_h_; | ||||
| int oc8 = UP_ROUND(output_channel, C8NUM); | int oc8 = UP_ROUND(output_channel, C8NUM); | ||||
| int in_plane12 = UP_ROUND(input_plane, C12NUM); | |||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| int in_plane12 = UP_ROUND(input_plane, tile_num); | |||||
| int src_iw_stride = C8NUM; | int src_iw_stride = C8NUM; | ||||
| int src_ih_stride = conv_param->input_w_ * C8NUM; | int src_ih_stride = conv_param->input_w_ * C8NUM; | ||||
| int src_kw_stride = in_plane12 * C8NUM; | int src_kw_stride = in_plane12 * C8NUM; | ||||
| @@ -16,6 +16,18 @@ | |||||
| #include "nnacl/fp32/matmul.h" | #include "nnacl/fp32/matmul.h" | ||||
| void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col) { | |||||
| for (int r = 0; r < row; r++) { | |||||
| float *src = src_ptr + r * col; | |||||
| for (int c = 0; c < col; c++) { | |||||
| int cd8 = c / 4; | |||||
| int cm8 = c % 4; | |||||
| dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c]; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) { | void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) { | ||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| float *src = src_ptr + r * col; | float *src = src_ptr + r * col; | ||||
| @@ -115,6 +127,61 @@ void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) | |||||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | ||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | ||||
| "v30", "v31"); | "v30", "v31"); | ||||
| #elif ENABLE_ARM32 | |||||
| size_t stride = col * sizeof(float); | |||||
| asm volatile( | |||||
| "mov r10, %[src_c]\n" | |||||
| "mov r12, %[dst_c]\n" | |||||
| "vld1.32 {q0}, [r10], %[stride]\n" | |||||
| "vld1.32 {q3}, [r10], %[stride]\n" | |||||
| "vld1.32 {q10}, [r10], %[stride]\n" | |||||
| "vld1.32 {q13}, [r10], %[stride]\n" | |||||
| "vtrn.32 d0, d6\n" | |||||
| "vtrn.32 d1, d7\n" | |||||
| "vtrn.32 d20, d26\n" | |||||
| "vtrn.32 d21, d27\n" | |||||
| "vld1.32 {q1}, [r10], %[stride]\n" | |||||
| "vld1.32 {q8}, [r10], %[stride]\n" | |||||
| "vld1.32 {q11}, [r10], %[stride]\n" | |||||
| "vld1.32 {q14}, [r10], %[stride]\n" | |||||
| "vswp d1, d20\n" | |||||
| "vswp d7, d26\n" | |||||
| "vld1.32 {q2}, [r10], %[stride]\n" | |||||
| "vld1.32 {q9}, [r10], %[stride]\n" | |||||
| "vld1.32 {q12}, [r10], %[stride]\n" | |||||
| "vld1.32 {q15}, [r10], %[stride]\n" | |||||
| "vtrn.32 d2, d16\n" | |||||
| "vtrn.32 d3, d17\n" | |||||
| "vtrn.32 d22, d28\n" | |||||
| "vtrn.32 d23, d29\n" | |||||
| "vswp d3, d22\n" | |||||
| "vswp d17, d28\n" | |||||
| "vtrn.32 d4, d18\n" | |||||
| "vtrn.32 d5, d19\n" | |||||
| "vtrn.32 d24, d30\n" | |||||
| "vtrn.32 d25, d31\n" | |||||
| "vswp d5, d24\n" | |||||
| "vswp d19, d30\n" | |||||
| "vst1.32 {q0, q1}, [r12]!\n" | |||||
| "vst1.32 {q2, q3}, [r12]!\n" | |||||
| "vst1.32 {q8, q9}, [r12]!\n" | |||||
| "vst1.32 {q10, q11}, [r12]!\n" | |||||
| "vst1.32 {q12, q13}, [r12]!\n" | |||||
| "vst1.32 {q14, q15}, [r12]!\n" | |||||
| : | |||||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); | |||||
| #else | #else | ||||
| for (int tr = 0; tr < C12NUM; tr++) { | for (int tr = 0; tr < C12NUM; tr++) { | ||||
| for (int tc = 0; tc < C4NUM; tc++) { | for (int tc = 0; tc < C4NUM; tc++) { | ||||
| @@ -242,6 +309,75 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) | |||||
| return; | return; | ||||
| } | } | ||||
| void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) { | |||||
| size_t row8 = row / C4NUM * C4NUM; | |||||
| size_t col4 = col / C4NUM * C4NUM; | |||||
| float *src_r = src_ptr; | |||||
| float *dst_r = dst_ptr; | |||||
| size_t ri = 0; | |||||
| for (; ri < row8; ri += C4NUM) { | |||||
| size_t ci = 0; | |||||
| for (; ci < col4; ci += C4NUM) { | |||||
| float *src_c = src_r + ci; | |||||
| float *dst_c = dst_r + ci * C4NUM; | |||||
| /* 4x4 row-major to col-major */ | |||||
| #ifdef ENABLE_ARM32 | |||||
| size_t stride = col * 4; | |||||
| asm volatile( | |||||
| "mov r10, %[src_c]\n" | |||||
| "mov r12, %[dst_c]\n" | |||||
| "vld1.32 {q0}, [r10], %[stride]\n" | |||||
| "vld1.32 {q1}, [r10], %[stride]\n" | |||||
| "vld1.32 {q2}, [r10], %[stride]\n" | |||||
| "vld1.32 {q3}, [r10], %[stride]\n" | |||||
| "vtrn.32 d0, d2\n" | |||||
| "vtrn.32 d1, d3\n" | |||||
| "vtrn.32 d4, d6\n" | |||||
| "vtrn.32 d5, d7\n" | |||||
| "vswp d1, d4\n" | |||||
| "vswp d3, d6\n" | |||||
| "vst1.32 {q0}, [r12]!\n" | |||||
| "vst1.32 {q1}, [r12]!\n" | |||||
| "vst1.32 {q2}, [r12]!\n" | |||||
| "vst1.32 {q3}, [r12]!\n" | |||||
| : | |||||
| : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) | |||||
| : "r10", "r12", "q0", "q1", "q2", "q3"); | |||||
| #else | |||||
| for (int tr = 0; tr < C4NUM; tr++) { | |||||
| for (int tc = 0; tc < C4NUM; tc++) { | |||||
| dst_c[tc * C4NUM + tr] = src_c[tr * col + tc]; | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } | |||||
| for (; ci < col; ci++) { | |||||
| float *src_c = src_r + ci; | |||||
| float *dst_c = dst_r + ci * C4NUM; | |||||
| for (size_t i = 0; i < C4NUM; i++) { | |||||
| dst_c[i] = src_c[i * col]; | |||||
| } | |||||
| } | |||||
| src_r += C4NUM * col; | |||||
| dst_r += C4NUM * col; | |||||
| } | |||||
| for (; ri < row; ri++) { | |||||
| for (size_t i = 0; i < col; i++) { | |||||
| dst_r[i * C4NUM] = src_r[i]; | |||||
| } | |||||
| src_r += col; | |||||
| dst_r += 1; | |||||
| } | |||||
| return; | |||||
| } | |||||
| void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride, | void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride, | ||||
| size_t data_lenth) { | size_t data_lenth) { | ||||
| size_t copy_size = col * data_lenth; | size_t copy_size = col * data_lenth; | ||||
| @@ -418,6 +554,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT | |||||
| MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), | MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), | ||||
| (int)(out_type == OutType_TileC8)); | (int)(out_type == OutType_TileC8)); | ||||
| } | } | ||||
| #elif ENABLE_ARM32 | |||||
| MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), | |||||
| (int)(out_type == OutType_TileC8)); | |||||
| #else | #else | ||||
| MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); | MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); | ||||
| #endif | #endif | ||||
| @@ -29,8 +29,10 @@ extern "C" { | |||||
| void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, | void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, | ||||
| int col, size_t stride, int out_type); | int col, size_t stride, int out_type); | ||||
| void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col); | |||||
| void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col); | void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col); | ||||
| void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | |||||
| void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col); | ||||
| void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); | void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride); | ||||
| @@ -40,6 +42,9 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi | |||||
| void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | ||||
| int col, size_t stride, size_t write_nhwc, size_t write_c4); | int col, size_t stride, size_t write_nhwc, size_t write_c4); | ||||
| void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride); | void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride); | ||||
| #elif ENABLE_ARM32 | |||||
| void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, | |||||
| int col, size_t stride, size_t write_nhwc, size_t write_c4); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -1223,6 +1223,78 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int | |||||
| : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | ||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", | ||||
| "v30", "v31"); | "v30", "v31"); | ||||
| #elif ENABLE_ARM32 | |||||
| size_t srcStride = channel * sizeof(float); | |||||
| size_t dstStride = plane * sizeof(float); | |||||
| asm volatile( | |||||
| "mov r10, %[src_ptr]\n" | |||||
| "mov r12, %[dst_ptr]\n" | |||||
| "vld1.32 {q0, q1}, [r10], %[srcStride]\n" | |||||
| "vld1.32 {q2, q3}, [r10], %[srcStride]\n" | |||||
| "vtrn.32 d0, d4\n" | |||||
| "vtrn.32 d1, d5\n" | |||||
| "vtrn.32 d2, d6\n" | |||||
| "vtrn.32 d3, d7\n" | |||||
| "vld1.32 {q4, q5}, [r10], %[srcStride]\n" | |||||
| "vld1.32 {q6, q7}, [r10], %[srcStride]\n" | |||||
| "vtrn.32 d8, d12\n" | |||||
| "vtrn.32 d9, d13\n" | |||||
| "vtrn.32 d10, d14\n" | |||||
| "vtrn.32 d11, d15\n" | |||||
| "vld1.32 {q8, q9}, [r10], %[srcStride]\n" | |||||
| "vld1.32 {q10, q11}, [r10], %[srcStride]\n" | |||||
| "vswp d1, d8\n" | |||||
| "vswp d3, d10\n" | |||||
| "vswp d5, d12\n" | |||||
| "vswp d7, d14\n" | |||||
| "vtrn.32 d16, d20\n" | |||||
| "vtrn.32 d17, d21\n" | |||||
| "vtrn.32 d18, d22\n" | |||||
| "vtrn.32 d19, d23\n" | |||||
| "vld1.32 {q12, q13}, [r10], %[srcStride]\n" | |||||
| "vld1.32 {q14, q15}, [r10], %[srcStride]\n" | |||||
| "vtrn.32 d24, d28\n" | |||||
| "vtrn.32 d25, d29\n" | |||||
| "vtrn.32 d26, d30\n" | |||||
| "vtrn.32 d27, d31\n" | |||||
| "vswp d17, d24\n" | |||||
| "vswp d19, d26\n" | |||||
| "vswp d21, d28\n" | |||||
| "vswp d23, d30\n" | |||||
| "add r10, r12, #16\n" | |||||
| "vst1.32 {q0}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q8}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q2}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q10}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q4}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q12}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q6}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q14}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q1}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q9}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q3}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q11}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q5}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q13}, [r10], %[dstStride]\n" | |||||
| "vst1.32 {q7}, [r12], %[dstStride]\n" | |||||
| "vst1.32 {q15}, [r10], %[dstStride]\n" | |||||
| : | |||||
| : | |||||
| [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) | |||||
| : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", | |||||
| "q15"); | |||||
| #else | #else | ||||
| for (int tr = 0; tr < C8NUM; tr++) { | for (int tr = 0; tr < C8NUM; tr++) { | ||||
| for (int tc = 0; tc < C8NUM; tc++) { | for (int tc = 0; tc < C8NUM; tc++) { | ||||
| @@ -67,8 +67,13 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * | |||||
| } | } | ||||
| } | } | ||||
| // input transform | // input transform | ||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| int dst_ic4_offset = dst_plane_offset + ic * C4NUM; | int dst_ic4_offset = dst_plane_offset + ic * C4NUM; | ||||
| size_t dst_step = C12NUM * ic4 * C4NUM; | |||||
| size_t dst_step = tile_num * ic4 * C4NUM; | |||||
| float *trans_input_ptr = trans_input + dst_ic4_offset; | float *trans_input_ptr = trans_input + dst_ic4_offset; | ||||
| func(tmp_data, trans_input_ptr, C4NUM, dst_step); | func(tmp_data, trans_input_ptr, C4NUM, dst_step); | ||||
| // GeneralInputTransformUnit(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C4NUM, dst_step, input_unit); | // GeneralInputTransformUnit(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C4NUM, dst_step, input_unit); | ||||
| @@ -331,8 +336,13 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa | |||||
| } | } | ||||
| // input transform | // input transform | ||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| int dst_ic4_offset = dst_plane_offset + ic * C4NUM; | int dst_ic4_offset = dst_plane_offset + ic * C4NUM; | ||||
| size_t dst_step = C12NUM * ic4 * C4NUM; | |||||
| size_t dst_step = tile_num * ic4 * C4NUM; | |||||
| float *trans_input_ptr = trans_input + dst_ic4_offset; | float *trans_input_ptr = trans_input + dst_ic4_offset; | ||||
| Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step); | Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step); | ||||
| } | } | ||||
| @@ -26,15 +26,13 @@ if (PLATFORM_ARM64) | |||||
| set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | ||||
| endif() | endif() | ||||
| #[[ | |||||
| if (PLATFORM_ARM32) | if (PLATFORM_ARM32) | ||||
| # assembly | # assembly | ||||
| file(GLOB ASSEMBLY_SRC nnacl/assembly/arm32/*.s | |||||
| nnacl/assembly/arm32/*.S | |||||
| file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.s | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.S | |||||
| ) | ) | ||||
| set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) | ||||
| set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC}) | ||||
| endif() | endif() | ||||
| ]] | |||||
| add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) | add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC}) | ||||
| @@ -59,6 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() { | |||||
| matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; | ||||
| matmul_param_->col_ = conv_param_->output_channel_; | matmul_param_->col_ = conv_param_->output_channel_; | ||||
| matmul_param_->deep_ = conv_param_->input_channel_; | matmul_param_->deep_ = conv_param_->input_channel_; | ||||
| matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); | |||||
| matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | ||||
| matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); | ||||
| matmul_param_->act_type_ = conv_param_->act_type_; | matmul_param_->act_type_ = conv_param_->act_type_; | ||||
| @@ -120,8 +121,11 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) { | |||||
| } else { | } else { | ||||
| input_ptr_ = src_input; | input_ptr_ = src_input; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM32 | |||||
| RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | |||||
| #else | |||||
| RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); | ||||
| #endif | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -169,8 +173,13 @@ int Convolution1x1CPUKernel::Run() { | |||||
| auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData()); | ||||
| auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData()); | ||||
| #ifdef ENABLE_ARM32 | |||||
| pack_input_ = | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||||
| #else | |||||
| pack_input_ = | pack_input_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float))); | ||||
| #endif | |||||
| if (pack_input_ == nullptr) { | if (pack_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; | ||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| @@ -95,7 +95,12 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| const int k_plane = 16; | const int k_plane = 16; | ||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * ic4 * C4NUM * sizeof(float); | |||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float); | |||||
| tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | ||||
| if (tile_buffer_ == nullptr) { | if (tile_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tile buffer failed."; | MS_LOG(ERROR) << "malloc tile buffer failed."; | ||||
| @@ -109,14 +114,14 @@ int Convolution3x3CPUKernel::InitTmpBuffer() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| size_t tmp_dst_buffer_size = thread_count_ * C12NUM * k_plane * oC8 * C8NUM * sizeof(float); | |||||
| size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float); | |||||
| tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size)); | tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size)); | ||||
| if (tmp_dst_buffer_ == nullptr) { | if (tmp_dst_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; | MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| size_t col_buffer_size = thread_count_ * C12NUM * C4NUM * ic4 * sizeof(float); | |||||
| size_t col_buffer_size = thread_count_ * tile_num * C4NUM * ic4 * sizeof(float); | |||||
| col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size)); | col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size)); | ||||
| if (col_buffer_ == nullptr) { | if (col_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc col_buffer_ failed."; | MS_LOG(ERROR) << "malloc col_buffer_ failed."; | ||||
| @@ -150,9 +150,14 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| int oc4 = UP_DIV(channel_out, C4NUM); | int oc4 = UP_DIV(channel_out, C4NUM); | ||||
| int oc8 = UP_DIV(channel_out, C8NUM); | int oc8 = UP_DIV(channel_out, C8NUM); | ||||
| int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); | int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); | ||||
| #ifdef ENABLE_ARM32 | |||||
| int tile_num = 4; | |||||
| #else | |||||
| int tile_num = 12; | |||||
| #endif | |||||
| MS_ASSERT(ctx_->allocator != nullptr); | MS_ASSERT(ctx_->allocator != nullptr); | ||||
| size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | |||||
| size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); | |||||
| trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size)); | ||||
| if (trans_input_ == nullptr) { | if (trans_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | MS_LOG(ERROR) << "malloc trans_input_ failed."; | ||||
| @@ -160,7 +165,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| gemm_out_ = reinterpret_cast<float *>( | gemm_out_ = reinterpret_cast<float *>( | ||||
| ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float))); | |||||
| ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float))); | |||||
| if (gemm_out_ == nullptr) { | if (gemm_out_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc gemm_out_ failed."; | MS_LOG(ERROR) << "malloc gemm_out_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -184,7 +189,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { | |||||
| } | } | ||||
| col_buffer_ = | col_buffer_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * C12NUM * ic4 * C4NUM * sizeof(float))); | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float))); | |||||
| if (col_buffer_ == nullptr) { | if (col_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc col_buffer_ failed."; | MS_LOG(ERROR) << "malloc col_buffer_ failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -85,6 +85,7 @@ int DeConvolutionCPUKernel::InitParam() { | |||||
| matmul_param_->deep_ = conv_param_->input_channel_; | matmul_param_->deep_ = conv_param_->input_channel_; | ||||
| matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; | matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_; | ||||
| matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM); | ||||
| matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); | |||||
| matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_; | matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_; | ||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); | thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM)); | ||||
| @@ -112,10 +113,17 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM32 | |||||
| auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_; | |||||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | |||||
| tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_, | |||||
| matmul_param_->col_, OutType_C8); | |||||
| #else | |||||
| auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_; | auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_; | ||||
| MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, | ||||
| tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_, | tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_, | ||||
| matmul_param_->col_, OutType_C8); | matmul_param_->col_, OutType_C8); | ||||
| #endif | |||||
| DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, | DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_, | ||||
| reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM, | reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM, | ||||
| @@ -159,15 +167,25 @@ int DeConvolutionCPUKernel::InitRunBuf() { | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM32 | |||||
| tmp_buffer_ = | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float))); | |||||
| #else | |||||
| tmp_buffer_ = | tmp_buffer_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float))); | ||||
| #endif | |||||
| if (tmp_buffer_ == nullptr) { | if (tmp_buffer_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!"; | MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| #ifdef ENABLE_ARM32 | |||||
| pack_input_ = | |||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float))); | |||||
| #else | |||||
| pack_input_ = | pack_input_ = | ||||
| reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float))); | reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float))); | ||||
| #endif | |||||
| if (pack_input_ == nullptr) { | if (pack_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; | MS_LOG(ERROR) << "deconv Malloc pack_input_ error!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -49,6 +49,7 @@ int FullconnectionCPUKernel::ReSize() { | |||||
| fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM); | ||||
| fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); | fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM); | ||||
| fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM); | |||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); | thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8)); | ||||
| thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); | thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_); | ||||
| @@ -59,11 +60,19 @@ int FullconnectionCPUKernel::ReSize() { | |||||
| memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float)); | ||||
| } | } | ||||
| #ifdef ENABLE_ARM32 | |||||
| a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_4_ * fc_param_->deep_ * sizeof(float))); | |||||
| if (a_c12_ptr_ == nullptr) { | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(a_c12_ptr_, 0, fc_param_->row_4_ * fc_param_->deep_ * sizeof(float)); | |||||
| #else | |||||
| a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float))); | a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float))); | ||||
| if (a_c12_ptr_ == nullptr) { | if (a_c12_ptr_ == nullptr) { | ||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| } | } | ||||
| memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)); | memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)); | ||||
| #endif | |||||
| b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float))); | b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float))); | ||||
| if (b_r8_ptr_ == nullptr) { | if (b_r8_ptr_ == nullptr) { | ||||
| @@ -87,7 +96,11 @@ int FullconnectionCPUKernel::Init() { | |||||
| } | } | ||||
| void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | ||||
| #ifdef ENABLE_ARM32 | |||||
| RowMajor2Col4Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_); | |||||
| #else | |||||
| RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_); | RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_); | ||||
| #endif | |||||
| } | } | ||||
| void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { | void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { | ||||
| @@ -62,17 +62,27 @@ int MatmulCPUKernel::ReSize() { | |||||
| params_->row_ = c_shape[c_shape.size() - 2]; | params_->row_ = c_shape[c_shape.size() - 2]; | ||||
| params_->col_ = c_shape[c_shape.size() - 1]; | params_->col_ = c_shape[c_shape.size() - 1]; | ||||
| params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1]; | ||||
| params_->row_4_ = UP_ROUND(params_->row_, C4NUM); | |||||
| params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | params_->row_12_ = UP_ROUND(params_->row_, C12NUM); | ||||
| params_->col_8_ = UP_ROUND(params_->col_, 8); | params_->col_8_ = UP_ROUND(params_->col_, 8); | ||||
| thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); | thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8)); | ||||
| thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); | thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_); | ||||
| #ifdef ENABLE_ARM32 | |||||
| a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_4_ * params_->deep_ * sizeof(float))); | |||||
| if (a_c12_ptr_ == nullptr) { | |||||
| FreeTmpBuffer(); | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| memset(a_c12_ptr_, 0, params_->row_4_ * params_->deep_ * sizeof(float)); | |||||
| #else | |||||
| a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_12_ * params_->deep_ * sizeof(float))); | a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_12_ * params_->deep_ * sizeof(float))); | ||||
| if (a_c12_ptr_ == nullptr) { | if (a_c12_ptr_ == nullptr) { | ||||
| FreeTmpBuffer(); | FreeTmpBuffer(); | ||||
| return RET_MEMORY_FAILED; | return RET_MEMORY_FAILED; | ||||
| } | } | ||||
| memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float)); | memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float)); | ||||
| #endif | |||||
| b_r8_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_8_ * params_->deep_ * sizeof(float))); | b_r8_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_8_ * params_->deep_ * sizeof(float))); | ||||
| if (b_r8_ptr_ == nullptr) { | if (b_r8_ptr_ == nullptr) { | ||||
| @@ -106,12 +116,21 @@ int MatmulCPUKernel::ReSize() { | |||||
| void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) { | ||||
| for (int i = 0; i < params_->batch; i++) { | for (int i = 0; i < params_->batch; i++) { | ||||
| float *src = src_ptr + i * params_->deep_ * params_->row_; | float *src = src_ptr + i * params_->deep_ * params_->row_; | ||||
| #ifdef ENABLE_ARM32 | |||||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_4_; | |||||
| if (params_->a_transpose_) { | |||||
| RowMajor2Row4Major(src, dst, params_->deep_, params_->row_); | |||||
| } else { | |||||
| RowMajor2Col4Major(src, dst, params_->row_, params_->deep_); | |||||
| } | |||||
| #else | |||||
| float *dst = dst_ptr + i * params_->deep_ * params_->row_12_; | float *dst = dst_ptr + i * params_->deep_ * params_->row_12_; | ||||
| if (params_->a_transpose_) { | if (params_->a_transpose_) { | ||||
| RowMajor2Row12Major(src, dst, params_->deep_, params_->row_); | RowMajor2Row12Major(src, dst, params_->deep_, params_->row_); | ||||
| } else { | } else { | ||||
| RowMajor2Col12Major(src, dst, params_->row_, params_->deep_); | RowMajor2Col12Major(src, dst, params_->row_, params_->deep_); | ||||
| } | } | ||||
| #endif | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -79,7 +79,7 @@ if (PLATFORM_ARM64) | |||||
| ${TEST_ASSEMBLY_SRC} | ${TEST_ASSEMBLY_SRC} | ||||
| ) | ) | ||||
| endif() | endif() | ||||
| #[[ | |||||
| if (PLATFORM_ARM32) | if (PLATFORM_ARM32) | ||||
| # assembly | # assembly | ||||
| file(GLOB TEST_ASSEMBLY_SRC | file(GLOB TEST_ASSEMBLY_SRC | ||||
| @@ -91,7 +91,7 @@ if (PLATFORM_ARM32) | |||||
| ${TEST_ASSEMBLY_SRC} | ${TEST_ASSEMBLY_SRC} | ||||
| ) | ) | ||||
| endif() | endif() | ||||
| ]] | |||||
| if (ENABLE_FP16) | if (ENABLE_FP16) | ||||
| file(GLOB KERNEL_OP_FP16_SRC | file(GLOB KERNEL_OP_FP16_SRC | ||||
| ${LITE_DIR}/src/runtime/kernel/arm/fp16/*.cc | ${LITE_DIR}/src/runtime/kernel/arm/fp16/*.cc | ||||