Merge pull request !7043 from fuzhiye/tmp
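In summary, this patch moves the fp16 Winograd convolution off the channel-padded ic8 * C8NUM (NC8HW8-style) buffers: the input transform now keeps plain in_channel (NHWC) strides, each interpolation point is re-packed on the fly into a small col16-major col_buffer, and MatMul16x8 replaces the packed IndirectGemmFp16_16x8 call. The input-transform units gain a real_c argument so partially filled 8-channel blocks are handled explicitly.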
@@ -170,7 +170,6 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
  int input_unit = conv_param->input_unit_;
  int in_batch = conv_param->input_batch_;
  int in_channel = conv_param->input_channel_;
  int ic8 = UP_DIV(in_channel, C8NUM);
  int out_unit = conv_param->output_unit_;
  int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
  int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
@@ -179,18 +178,19 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
  int out_channel = conv_param->output_channel_;
  int oc8 = UP_DIV(out_channel, C8NUM);
  int input_unit_square = input_unit * input_unit;
  size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t);
  float16_t *trans_input = buffer_list[0];
  float16_t *gemm_out = buffer_list[1];
  float16_t *tmp_data = buffer_list[2];
  int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM;
  float16_t *col_buffer = buffer_list[3];
  int trans_input_offset = tile_num * input_unit_square * in_channel;
  int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
  int tmp_data_offset = input_unit_square * C8NUM;
  int col_buffer_offset = tile_num * in_channel;
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
    int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
      int out_tile_index = thread_id * tile_num;
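The stride changes above fall out of that layout switch. A minimal sketch of the per-task input-transform strides before and after, assuming the standard nnacl rounding macro (UP_DIV rounds the quotient up, C8NUM == 8):

#include <stddef.h>

/* UP_DIV as defined in nnacl rounds the quotient up; C8NUM is 8. */
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

/* Per-task element stride of the transformed-input buffer, old vs. new. */
size_t TransInputStrideOld(int tile_num, int input_unit_square, int in_channel) {
  return (size_t)tile_num * input_unit_square * UP_DIV(in_channel, 8) * 8; /* ic8 * C8NUM */
}
size_t TransInputStrideNew(int tile_num, int input_unit_square, int in_channel) {
  return (size_t)tile_num * input_unit_square * in_channel; /* raw NHWC channels */
}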
@@ -200,8 +200,14 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
                                 tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
                                 in_func);
      // step 3 : gemm
      IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
                            trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
      float16_t *src_ptr = trans_input + task_id * trans_input_offset;
      float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
      float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
      for (int i = 0; i < input_unit_square; ++i) {
        RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel);
        MatMul16x8(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
                   cal_num, oc8 * C8NUM, input_unit_square, false);
      }
      // step 4 : output transform
      WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,
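RowMajor2Col16MajorFp16Opt itself is not part of this diff; judging from how MatMul16x8 indexes its a operand (rows grouped into C16NUM == 16 blocks, column-major inside a block), a naive scalar equivalent of the packing would be:

#include <arm_neon.h> /* float16_t; assumes the fp16-capable ARM toolchain nnacl targets */

/* Naive scalar sketch of row-major -> col16-major packing: rows are grouped
 * into blocks of 16 and stored column-by-column inside each block, so element
 * (r, c) lands at (r / 16) * 16 * col + c * 16 + (r % 16). This matches the
 * a-operand indexing in MatMul16x8 (ai = src_r_offset + d * C16NUM). */
void RowMajor2Col16MajorNaive(const float16_t *src, float16_t *dst, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      dst[(r / 16) * 16 * col + c * 16 + (r % 16)] = src[r * col + c];
    }
  }
}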
@@ -41,8 +41,8 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                int deep, int row, int col, int stride, bool write_nhwc) {
  int row_16 = UP_ROUND(row, C16NUM);
  int col_8 = UP_ROUND(col, C8NUM);
  // int row_16 = UP_ROUND(row, C16NUM);
  // int col_8 = UP_ROUND(col, C8NUM);
  if (write_nhwc) {
    /* col16-major * row8-major => col-major */
    for (int r = 0; r < row; r++) {
@@ -63,24 +63,42 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
      }
    }
  } else {
    /* col16-major * row8-major => row16x8-major */
    for (int r = 0; r < row_16; r++) {
      for (int c = 0; c < col_8; c++) {
        int r16div = r / C16NUM, r16mod = r % C16NUM;
        int c8div = c / C8NUM, c8mod = c % C8NUM;
        size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
        float16_t value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
    for (int i = 0; i < row; ++i) {
      int src_r_offset = i;
      int dst_r_offset = i * col * stride;
      for (int j = 0; j < col; ++j) {
        int c8div = j / 8, c8mod = j % 8;
        size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
        float value = 0;
        for (int d = 0; d < deep; ++d) {
          size_t ai = src_r_offset + d * C16NUM;
          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
          value = value + a[ai] * b[bi];
        }
        if (bias != NULL) value += bias[col];
        if (bias != NULL) value += bias[j];
        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
        if (act_type != ActType_No) value = MSMAX(0.0f, value);
        dst[ci] = value;
      }
    }
    // /* col16-major * row8-major => row16x8-major */
    // for (int r = 0; r < row_16; r++) {
    //   for (int c = 0; c < col_8; c++) {
    //     int r16div = r / C16NUM, r16mod = r % C16NUM;
    //     int c8div = c / C8NUM, c8mod = c % C8NUM;
    //     size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
    //     float16_t value = 0;
    //     for (int d = 0; d < deep; d++) {
    //       size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
    //       size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
    //       value = value + a[ai] * b[bi];
    //     }
    //     if (bias != NULL) value += bias[col];
    //     if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
    //     if (act_type != ActType_No) value = MSMAX(0.0f, value);
    //     dst[ci] = value;
    //   }
    // }
  }
  return;
}
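Combined with the caller's dst_ptr + i * C8NUM and stride == input_unit_square, the store index in the !write_nhwc branch lays gemm_out out as [tile][oc8][input_unit_square][8]. A small index helper makes that explicit (names here are illustrative, not from the source):

#include <stddef.h>

/* Destination index produced by the !write_nhwc branch above, for tile r,
 * interpolation point i and output channel j, with stride == unit_sq
 * (== input_unit_square). */
size_t GemmOutIndex(int r, int i, int j, int oc8, int unit_sq) {
  return (size_t)r * oc8 * 8 * unit_sq    /* dst_r_offset = r * col * stride */
         + (size_t)(j / 8) * 8 * unit_sq  /* 8-channel block: c8div * 8 * stride */
         + (size_t)i * 8                  /* the caller's i * C8NUM */
         + (size_t)(j % 8);               /* c8mod */
}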
@@ -29,6 +29,9 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                int deep, int row, int col, int stride, bool write_nhwc);
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                int depth, int row, int col, int stride, bool write_nhwc);
@@ -594,7 +594,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
    int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
    int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
    int dst_plane_offset = c * C4NUM;
    int dst_plane_offset = c * in_channel;
    for (int ic = 0; ic < ic8; ic++) {
      // clear tmp buffer
      memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
@@ -622,6 +622,30 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
#endif
          }
        }
      } else if (real_c < 8 && real_c >= 4) {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
          int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
          for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
            int src_x_offset = src_y_offset + j * in_channel;
            int dst_x_offset = dst_y_offset + j * C8NUM;
            const float16_t *src_addr = input_data + src_x_offset;
            float16_t *dst_addr = tmp_data + dst_x_offset;
            int rc = real_c - 4;
#ifdef ENABLE_NEON
            vst1_f16(dst_addr, vld1_f16(src_addr));
#else
            for (int k = 0; k < C4NUM; k++) {
              dst_addr[k] = src_addr[k];
            }
#endif
            src_addr += 4;
            dst_addr += 4;
            for (int i = 0; i < rc; ++i) {
              dst_addr[i] = src_addr[i];
            }
          }
        }
      } else {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
@@ -639,10 +663,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
      }
      // input transform
      int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM;
      size_t dst_step = ic8 * C8NUM * tile_num;
      int dst_ic8_offset = dst_plane_offset + ic * C8NUM;
      size_t dst_step = in_channel * tile_num;
      float16_t *trans_input_ptr = trans_input + dst_ic8_offset;
      func(tmp_data, trans_input_ptr, C8NUM, dst_step);
      func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c);
    }
    out_tile_index++;
  }  // cal_tile_num loop
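The new real_c branch above covers channel tails of 4 to 7: one 4-lane NEON copy plus a scalar loop for the remainder. Distilled into a standalone helper (a sketch, not code from the patch):

#include <arm_neon.h>

/* Copy real_c (1..8) fp16 channels from an NHWC row into a C8-blocked tile,
 * mirroring the tail handling above; the non-NEON path simply folds the
 * 4-lane copy and the remainder into one scalar loop. */
static void CopyChannelTailFp16(const float16_t *src_addr, float16_t *dst_addr, int real_c) {
  int c = 0;
#ifdef ENABLE_NEON
  if (real_c == 8) {
    vst1q_f16(dst_addr, vld1q_f16(src_addr)); /* full 8-lane block */
    return;
  }
  if (real_c >= 4) {
    vst1_f16(dst_addr, vld1_f16(src_addr)); /* first 4 lanes */
    c = 4;
  }
#endif
  for (; c < real_c; ++c) {
    dst_addr[c] = src_addr[c]; /* scalar remainder */
  }
}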
@@ -26,7 +26,8 @@
#ifdef __cplusplus
extern "C" {
#endif
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
                                   int real_c);
typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
                                    int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
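Here real_c is the number of valid channels in the current C8 block (in_channel % C8NUM for a trailing block, C8NUM otherwise), which lets the transform units pick 4-lane loads and stores for partial blocks instead of always assuming a fully padded block.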
@@ -56,6 +57,24 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
  src[14] = vld1q_f16(src_data + 14 * src_step); \
  src[15] = vld1q_f16(src_data + 15 * src_step);
#define Load16DataC4Fp16 \
  src[0] = vld1_f16(src_data + 0 * src_step); \
  src[1] = vld1_f16(src_data + 1 * src_step); \
  src[2] = vld1_f16(src_data + 2 * src_step); \
  src[3] = vld1_f16(src_data + 3 * src_step); \
  src[4] = vld1_f16(src_data + 4 * src_step); \
  src[5] = vld1_f16(src_data + 5 * src_step); \
  src[6] = vld1_f16(src_data + 6 * src_step); \
  src[7] = vld1_f16(src_data + 7 * src_step); \
  src[8] = vld1_f16(src_data + 8 * src_step); \
  src[9] = vld1_f16(src_data + 9 * src_step); \
  src[10] = vld1_f16(src_data + 10 * src_step); \
  src[11] = vld1_f16(src_data + 11 * src_step); \
  src[12] = vld1_f16(src_data + 12 * src_step); \
  src[13] = vld1_f16(src_data + 13 * src_step); \
  src[14] = vld1_f16(src_data + 14 * src_step); \
  src[15] = vld1_f16(src_data + 15 * src_step);
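The C4 macro variants differ from the originals only in load width: vld1q_f16 fills a 128-bit float16x8_t (8 half floats), while vld1_f16 fills a 64-bit float16x4_t (4 half floats), matching blocks where only four channels are valid. A minimal sketch:

#include <arm_neon.h>

/* 8-lane vs. 4-lane fp16 loads, as used by Load16DataFp16 vs. Load16DataC4Fp16. */
void LoadWidthDemo(const float16_t *src_data, int src_step) {
  float16x8_t full = vld1q_f16(src_data + 0 * src_step); /* 128-bit: 8 half floats */
  float16x4_t tail = vld1_f16(src_data + 0 * src_step);  /* 64-bit: 4 half floats */
  (void)full;
  (void)tail;
}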
#define Load36DataFp16 \
  src[0] = vld1q_f16(src_data + 0 * src_step); \
  src[1] = vld1q_f16(src_data + 1 * src_step); \
@@ -94,6 +113,44 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
  src[34] = vld1q_f16(src_data + 34 * src_step); \
  src[35] = vld1q_f16(src_data + 35 * src_step);
#define Load36DataC4Fp16 \
  src[0] = vld1_f16(src_data + 0 * src_step); \
  src[1] = vld1_f16(src_data + 1 * src_step); \
  src[2] = vld1_f16(src_data + 2 * src_step); \
  src[3] = vld1_f16(src_data + 3 * src_step); \
  src[4] = vld1_f16(src_data + 4 * src_step); \
  src[5] = vld1_f16(src_data + 5 * src_step); \
  src[6] = vld1_f16(src_data + 6 * src_step); \
  src[7] = vld1_f16(src_data + 7 * src_step); \
  src[8] = vld1_f16(src_data + 8 * src_step); \
  src[9] = vld1_f16(src_data + 9 * src_step); \
  src[10] = vld1_f16(src_data + 10 * src_step); \
  src[11] = vld1_f16(src_data + 11 * src_step); \
  src[12] = vld1_f16(src_data + 12 * src_step); \
  src[13] = vld1_f16(src_data + 13 * src_step); \
  src[14] = vld1_f16(src_data + 14 * src_step); \
  src[15] = vld1_f16(src_data + 15 * src_step); \
  src[16] = vld1_f16(src_data + 16 * src_step); \
  src[17] = vld1_f16(src_data + 17 * src_step); \
  src[18] = vld1_f16(src_data + 18 * src_step); \
  src[19] = vld1_f16(src_data + 19 * src_step); \
  src[20] = vld1_f16(src_data + 20 * src_step); \
  src[21] = vld1_f16(src_data + 21 * src_step); \
  src[22] = vld1_f16(src_data + 22 * src_step); \
  src[23] = vld1_f16(src_data + 23 * src_step); \
  src[24] = vld1_f16(src_data + 24 * src_step); \
  src[25] = vld1_f16(src_data + 25 * src_step); \
  src[26] = vld1_f16(src_data + 26 * src_step); \
  src[27] = vld1_f16(src_data + 27 * src_step); \
  src[28] = vld1_f16(src_data + 28 * src_step); \
  src[29] = vld1_f16(src_data + 29 * src_step); \
  src[30] = vld1_f16(src_data + 30 * src_step); \
  src[31] = vld1_f16(src_data + 31 * src_step); \
  src[32] = vld1_f16(src_data + 32 * src_step); \
  src[33] = vld1_f16(src_data + 33 * src_step); \
  src[34] = vld1_f16(src_data + 34 * src_step); \
  src[35] = vld1_f16(src_data + 35 * src_step);
#define Load64DataFp16 \
  src[0] = vld1q_f16(src_data + 0 * src_step); \
  src[1] = vld1q_f16(src_data + 1 * src_step); \
@@ -160,13 +217,79 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
  src[62] = vld1q_f16(src_data + 62 * src_step); \
  src[63] = vld1q_f16(src_data + 63 * src_step);
#define Load64DataC4Fp16 \
  src[0] = vld1_f16(src_data + 0 * src_step); \
  src[1] = vld1_f16(src_data + 1 * src_step); \
  src[2] = vld1_f16(src_data + 2 * src_step); \
  src[3] = vld1_f16(src_data + 3 * src_step); \
  src[4] = vld1_f16(src_data + 4 * src_step); \
  src[5] = vld1_f16(src_data + 5 * src_step); \
  src[6] = vld1_f16(src_data + 6 * src_step); \
  src[7] = vld1_f16(src_data + 7 * src_step); \
  src[8] = vld1_f16(src_data + 8 * src_step); \
  src[9] = vld1_f16(src_data + 9 * src_step); \
  src[10] = vld1_f16(src_data + 10 * src_step); \
  src[11] = vld1_f16(src_data + 11 * src_step); \
  src[12] = vld1_f16(src_data + 12 * src_step); \
  src[13] = vld1_f16(src_data + 13 * src_step); \
  src[14] = vld1_f16(src_data + 14 * src_step); \
  src[15] = vld1_f16(src_data + 15 * src_step); \
  src[16] = vld1_f16(src_data + 16 * src_step); \
  src[17] = vld1_f16(src_data + 17 * src_step); \
  src[18] = vld1_f16(src_data + 18 * src_step); \
  src[19] = vld1_f16(src_data + 19 * src_step); \
  src[20] = vld1_f16(src_data + 20 * src_step); \
  src[21] = vld1_f16(src_data + 21 * src_step); \
  src[22] = vld1_f16(src_data + 22 * src_step); \
  src[23] = vld1_f16(src_data + 23 * src_step); \
  src[24] = vld1_f16(src_data + 24 * src_step); \
  src[25] = vld1_f16(src_data + 25 * src_step); \
  src[26] = vld1_f16(src_data + 26 * src_step); \
  src[27] = vld1_f16(src_data + 27 * src_step); \
  src[28] = vld1_f16(src_data + 28 * src_step); \
  src[29] = vld1_f16(src_data + 29 * src_step); \
  src[30] = vld1_f16(src_data + 30 * src_step); \
  src[31] = vld1_f16(src_data + 31 * src_step); \
  src[32] = vld1_f16(src_data + 32 * src_step); \
  src[33] = vld1_f16(src_data + 33 * src_step); \
  src[34] = vld1_f16(src_data + 34 * src_step); \
  src[35] = vld1_f16(src_data + 35 * src_step); \
  src[36] = vld1_f16(src_data + 36 * src_step); \
  src[37] = vld1_f16(src_data + 37 * src_step); \
  src[38] = vld1_f16(src_data + 38 * src_step); \
  src[39] = vld1_f16(src_data + 39 * src_step); \
  src[40] = vld1_f16(src_data + 40 * src_step); \
  src[41] = vld1_f16(src_data + 41 * src_step); \
  src[42] = vld1_f16(src_data + 42 * src_step); \
  src[43] = vld1_f16(src_data + 43 * src_step); \
  src[44] = vld1_f16(src_data + 44 * src_step); \
  src[45] = vld1_f16(src_data + 45 * src_step); \
  src[46] = vld1_f16(src_data + 46 * src_step); \
  src[47] = vld1_f16(src_data + 47 * src_step); \
  src[48] = vld1_f16(src_data + 48 * src_step); \
  src[49] = vld1_f16(src_data + 49 * src_step); \
  src[50] = vld1_f16(src_data + 50 * src_step); \
  src[51] = vld1_f16(src_data + 51 * src_step); \
  src[52] = vld1_f16(src_data + 52 * src_step); \
  src[53] = vld1_f16(src_data + 53 * src_step); \
  src[54] = vld1_f16(src_data + 54 * src_step); \
  src[55] = vld1_f16(src_data + 55 * src_step); \
  src[56] = vld1_f16(src_data + 56 * src_step); \
  src[57] = vld1_f16(src_data + 57 * src_step); \
  src[58] = vld1_f16(src_data + 58 * src_step); \
  src[59] = vld1_f16(src_data + 59 * src_step); \
  src[60] = vld1_f16(src_data + 60 * src_step); \
  src[61] = vld1_f16(src_data + 61 * src_step); \
  src[62] = vld1_f16(src_data + 62 * src_step); \
  src[63] = vld1_f16(src_data + 63 * src_step);
InputTransFp16Func GetInputTransFp16Func(int input_unit);
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type);
@@ -176,6 +299,12 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
  vst1q_f16(dst_data + dst_step * out_c, m[2]); \
  vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]);
#define Store4DataC4Fp16 \
  vst1_f16(dst_data, m[0]); \
  vst1_f16(dst_data + out_c, m[1]); \
  vst1_f16(dst_data + dst_step * out_c, m[2]); \
  vst1_f16(dst_data + dst_step * out_c + out_c, m[3]);
#define Store9DataFp16 \
  vst1q_f16(dst_data, m[0]); \
  vst1q_f16(dst_data + out_c, m[1]); \
@@ -187,6 +316,17 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
  vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
  vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store9DataC4Fp16 \
  vst1_f16(dst_data, m[0]); \
  vst1_f16(dst_data + out_c, m[1]); \
  vst1_f16(dst_data + 2 * out_c, m[2]); \
  vst1_f16(dst_data + dst_step * out_c, m[3]); \
  vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \
  vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
  vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store16DataFp16 \
  vst1q_f16(dst_data, m[0]); \
  vst1q_f16(dst_data + out_c, m[1]); \
@@ -205,6 +345,24 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
  vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
  vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store16DataC4Fp16 \
  vst1_f16(dst_data, m[0]); \
  vst1_f16(dst_data + out_c, m[1]); \
  vst1_f16(dst_data + 2 * out_c, m[2]); \
  vst1_f16(dst_data + 3 * out_c, m[3]); \
  vst1_f16(dst_data + dst_step * out_c, m[4]); \
  vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \
  vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
  vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
  vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
  vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store25DataFp16 \
  vst1q_f16(dst_data, m[0]); \
  vst1q_f16(dst_data + out_c, m[1]); \
@@ -232,6 +390,33 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
  vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
  vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
#define Store25DataC4Fp16 \
  vst1_f16(dst_data, m[0]); \
  vst1_f16(dst_data + out_c, m[1]); \
  vst1_f16(dst_data + 2 * out_c, m[2]); \
  vst1_f16(dst_data + 3 * out_c, m[3]); \
  vst1_f16(dst_data + 4 * out_c, m[4]); \
  vst1_f16(dst_data + dst_step * out_c, m[5]); \
  vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \
  vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
  vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
  vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
  vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
  vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
  vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
  vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
  vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \
  vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
  vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
  vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
  vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
                                int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
@@ -197,25 +197,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
  auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  conv_param->input_h_ = inputs.front()->Height();
  conv_param->input_w_ = inputs.front()->Width();
  conv_param->output_h_ = outputs.front()->Height();
  conv_param->output_w_ = outputs.front()->Width();
  bool use_winograd = false;
  int out_unit;
  if (primitive != nullptr && primitive->GetInferFlag()) {
    conv_param->input_h_ = inputs.front()->Height();
    conv_param->input_w_ = inputs.front()->Width();
    conv_param->input_channel_ = inputs.front()->Channel();
    conv_param->output_h_ = outputs.front()->Height();
    conv_param->output_w_ = outputs.front()->Width();
    conv_param->output_channel_ = outputs.front()->Channel();
    conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
    CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
  }
  kernel::LiteKernel *kernel = nullptr;
  if (kernel_h == 1 && kernel_w == 1) {
    kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  } else if (use_winograd) {
    kernel = new (std::nothrow)
      kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
  } else {
    bool use_winograd = false;
    int out_unit;
    CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
    if (use_winograd) {
      kernel = new (std::nothrow)
        kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
    }
    if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
      kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
    }
    kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  }
  if (kernel == nullptr) {
    MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
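The restructuring above means conv_param's shape and channel fields, and the CheckIfUseWinograd decision, are now computed only when shape inference has completed (primitive->GetInferFlag()), instead of reading tensor shapes unconditionally at function entry and deciding on Winograd later inside the kernel-selection branch.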
@@ -39,8 +39,6 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
  // original weight format : ohwi
  auto channel_in = conv_param_->input_channel_;
  auto channel_out = conv_param_->output_channel_;
  int ic8 = UP_DIV(channel_in, C8NUM);
  int ic4 = ic8 * 2;
  int input_unit_square = input_unit_ * input_unit_;
  int oc_block_num = UP_DIV(channel_out, oc_block);
@@ -84,17 +82,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
    MS_LOG(ERROR) << "malloc trans_out_data failed.";
    return RET_ERROR;
  }
  std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
  std::vector<int> strides;
  for (int i = 0; i < 4; i++) {
    int stride = 1;
    for (int j = i + 1; j < 5; j++) {
      stride *= shape[j];
    }
    strides.push_back(stride);
  }
  int kernel_plane_stride = channel_in;
  if (oc_block == 0) {
    MS_LOG(ERROR) << "Divide by zero";
    free(tmp_weight_data);
@@ -104,18 +92,17 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
    free(matrix_gt_data_fp16);
    return RET_ERROR;
  }
  int stride1 = channel_in * oc_block;
  for (int i = 0; i < channel_out; i++) {
    int out_c_block = i / oc_block;
    int out_c_res = i % oc_block;
    int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
    int output_oz_offset = out_c_block * strides[1] * input_unit_ * input_unit_ + out_c_res;
    int output_oz_offset = out_c_block * stride1 + out_c_res;
    for (int j = 0; j < channel_in; j++) {
      int ic4_block = j / C4NUM;
      int ic4_res = j % C4NUM;
      int input_iz_offset = input_oz_offset + j;
      int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
      int output_iz_offset = output_oz_offset + j * oc_block;
      for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
        int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
        int input_xy_offset = input_iz_offset + k * channel_in;
        tmp_weight_data[k] = *(weight_data + input_xy_offset);
      }
      // now we only support row-major matrix-multiply
@@ -125,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
      MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_);
      for (int z = 0; z < input_unit_square; z++) {
        int output_xy_offset = output_iz_offset + z * strides[1];
        int output_xy_offset = output_iz_offset + z * oc_block_num * stride1;
        trans_weight_[output_xy_offset] = trans_out_data[z];
      }
    }
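With the strides vector gone, the transformed weight is addressed with a single stride1 = channel_in * oc_block, which implies the layout [input_unit * input_unit][oc_block_num][channel_in][oc_block]. An index helper equivalent to the new arithmetic (illustrative names, not from the source):

#include <stddef.h>

/* Transformed-weight index for interpolation point z, output channel out_ch
 * and input channel in_ch, per the offset computation above. */
size_t TransWeightIndex(int z, int out_ch, int in_ch, int oc_block, int oc_block_num, int channel_in) {
  int stride1 = channel_in * oc_block;
  return (size_t)z * oc_block_num * stride1      /* z * oc_block_num * stride1 */
         + (size_t)(out_ch / oc_block) * stride1 /* out_c_block * stride1 */
         + (size_t)in_ch * oc_block              /* j * oc_block */
         + (size_t)(out_ch % oc_block);          /* out_c_res */
}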
@@ -142,7 +129,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
  auto filter_tensor = in_tensors_.at(kWeightIndex);
  int in_channel = filter_tensor->Channel();
  int out_channel = filter_tensor->Batch();
  int ic8 = UP_DIV(in_channel, C8NUM);
  conv_param_->input_channel_ = in_channel;
  conv_param_->output_channel_ = out_channel;
@@ -157,7 +143,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
  }
  // set data
  auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float16_t);
  auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t);
  trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size));
  if (trans_weight_ == nullptr) {
    MS_LOG(ERROR) << "malloc trans_weight_ failed.";
@@ -209,9 +195,9 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
  const int cal_num = 16;
  int channel_out = conv_param_->output_channel_;
  int oc8 = UP_DIV(channel_out, C8NUM);
  int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM);
  size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t);
  size_t tile_buffer_size =
    thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
  trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
  if (trans_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc trans_input_ failed.";
@@ -232,9 +218,17 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
    return RET_ERROR;
  }
  col_buffer_ = reinterpret_cast<float16_t *>(
    ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t)));
  if (col_buffer_ == nullptr) {
    MS_LOG(ERROR) << "malloc col_buffer_ failed.";
    return RET_ERROR;
  }
  tmp_buffer_address_list_[0] = trans_input_;
  tmp_buffer_address_list_[1] = gemm_out_;
  tmp_buffer_address_list_[2] = tmp_data_;
  tmp_buffer_address_list_[3] = col_buffer_;
  return RET_OK;
}
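After this change the kernel manages four temporary buffers. A sketch of the total sizing in one place; the gemm_out_ and tmp_data_ terms are inferred from the per-task offsets used in ConvWinogardFp16, since their allocations sit outside this hunk:

#include <stddef.h>
#include <stdint.h>

/* Total temp-buffer bytes for the fp16 Winograd kernel (cal_num == 16, C8NUM == 8). */
size_t WinogradTmpBytesFp16(int thread_count, int input_unit, int in_channel, int out_channel) {
  int unit_sq = input_unit * input_unit;
  int oc8 = (out_channel + 7) / 8; /* UP_DIV(out_channel, C8NUM) */
  size_t elems = 0;
  elems += (size_t)thread_count * 16 * unit_sq * in_channel; /* trans_input_ */
  elems += (size_t)thread_count * 16 * unit_sq * oc8 * 8;    /* gemm_out_ (assumed unchanged) */
  elems += (size_t)thread_count * unit_sq * 8;               /* tmp_data_ (assumed unchanged) */
  elems += (size_t)thread_count * 16 * in_channel;           /* col_buffer_ */
  return elems * sizeof(uint16_t); /* fp16 element size */
}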
@@ -67,6 +67,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
      ctx_->allocator->Free(gemm_out_);
      gemm_out_ = nullptr;
    }
    if (col_buffer_ != nullptr) {
      ctx_->allocator->Free(col_buffer_);
      col_buffer_ = nullptr;
    }
  }
  int kernel_unit_;
  int input_unit_;
@@ -75,7 +79,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
  float16_t *trans_input_ = nullptr;
  float16_t *gemm_out_ = nullptr;
  float16_t *trans_weight_ = nullptr;
  TmpBufferAddressFp16 tmp_buffer_address_list_[3];
  float16_t *col_buffer_ = nullptr;
  TmpBufferAddressFp16 tmp_buffer_address_list_[4];
  InputTransFp16Func in_func_;
  OutputTransFp16Func out_func_;
};