Merge pull request !7043 from fuzhiye/tmptags/v1.1.0
| @@ -170,7 +170,6 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||||
| int input_unit = conv_param->input_unit_; | int input_unit = conv_param->input_unit_; | ||||
| int in_batch = conv_param->input_batch_; | int in_batch = conv_param->input_batch_; | ||||
| int in_channel = conv_param->input_channel_; | int in_channel = conv_param->input_channel_; | ||||
| int ic8 = UP_DIV(in_channel, C8NUM); | |||||
| int out_unit = conv_param->output_unit_; | int out_unit = conv_param->output_unit_; | ||||
| int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | int out_w_block = UP_DIV(conv_param->output_w_, out_unit); | ||||
| int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | int out_h_block = UP_DIV(conv_param->output_h_, out_unit); | ||||
| @@ -179,18 +178,19 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||||
| int out_channel = conv_param->output_channel_; | int out_channel = conv_param->output_channel_; | ||||
| int oc8 = UP_DIV(out_channel, C8NUM); | int oc8 = UP_DIV(out_channel, C8NUM); | ||||
| int input_unit_square = input_unit * input_unit; | int input_unit_square = input_unit * input_unit; | ||||
| size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t); | |||||
| float16_t *trans_input = buffer_list[0]; | float16_t *trans_input = buffer_list[0]; | ||||
| float16_t *gemm_out = buffer_list[1]; | float16_t *gemm_out = buffer_list[1]; | ||||
| float16_t *tmp_data = buffer_list[2]; | float16_t *tmp_data = buffer_list[2]; | ||||
| int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM; | |||||
| float16_t *col_buffer = buffer_list[3]; | |||||
| int trans_input_offset = tile_num * input_unit_square * in_channel; | |||||
| int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; | int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; | ||||
| int tmp_data_offset = input_unit_square * C8NUM; | int tmp_data_offset = input_unit_square * C8NUM; | ||||
| int col_buffer_offset = tile_num * in_channel; | |||||
| // step 1 : filter transform (pre-processed offline) | // step 1 : filter transform (pre-processed offline) | ||||
| // step 2 : input transform (online) | // step 2 : input transform (online) | ||||
| for (int b = 0; b < in_batch; b++) { | for (int b = 0; b < in_batch; b++) { | ||||
| int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; | |||||
| int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; | |||||
| int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; | int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; | ||||
| for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { | ||||
| int out_tile_index = thread_id * tile_num; | int out_tile_index = thread_id * tile_num; | ||||
| @@ -200,8 +200,14 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa | |||||
| tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, | ||||
| in_func); | in_func); | ||||
| // step 3 : gemm | // step 3 : gemm | ||||
| IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, | |||||
| trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0); | |||||
| float16_t *src_ptr = trans_input + task_id * trans_input_offset; | |||||
| float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; | |||||
| float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; | |||||
| for (int i = 0; i < input_unit_square; ++i) { | |||||
| RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel); | |||||
| MatMul16x8(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, | |||||
| cal_num, oc8 * C8NUM, input_unit_square, false); | |||||
| } | |||||
| // step 4 : output transform | // step 4 : output transform | ||||
| WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data, | WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data, | ||||
| @@ -41,8 +41,8 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_ | |||||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | ||||
| int deep, int row, int col, int stride, bool write_nhwc) { | int deep, int row, int col, int stride, bool write_nhwc) { | ||||
| int row_16 = UP_ROUND(row, C16NUM); | |||||
| int col_8 = UP_ROUND(col, C8NUM); | |||||
| // int row_16 = UP_ROUND(row, C16NUM); | |||||
| // int col_8 = UP_ROUND(col, C8NUM); | |||||
| if (write_nhwc) { | if (write_nhwc) { | ||||
| /* col16-major * row8-major => col-major */ | /* col16-major * row8-major => col-major */ | ||||
| for (int r = 0; r < row; r++) { | for (int r = 0; r < row; r++) { | ||||
| @@ -63,24 +63,42 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| /* col16-major * row8-major => row16x8-major */ | |||||
| for (int r = 0; r < row_16; r++) { | |||||
| for (int c = 0; c < col_8; c++) { | |||||
| int r16div = r / C16NUM, r16mod = r % C16NUM; | |||||
| int c8div = c / C8NUM, c8mod = c % C8NUM; | |||||
| size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; | |||||
| float16_t value = 0; | |||||
| for (int d = 0; d < deep; d++) { | |||||
| size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; | |||||
| size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | |||||
| for (int i = 0; i < row; ++i) { | |||||
| int src_r_offset = i; | |||||
| int dst_r_offset = i * col * stride; | |||||
| for (int j = 0; j < col; ++j) { | |||||
| int c8div = j / 8, c8mod = j % 8; | |||||
| size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; | |||||
| float value = 0; | |||||
| for (int d = 0; d < deep; ++d) { | |||||
| size_t ai = src_r_offset + d * C16NUM; | |||||
| size_t bi = c8div * deep * 8 + d * 8 + c8mod; | |||||
| value = value + a[ai] * b[bi]; | value = value + a[ai] * b[bi]; | ||||
| } | } | ||||
| if (bias != NULL) value += bias[col]; | |||||
| if (bias != NULL) value += bias[j]; | |||||
| if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | ||||
| if (act_type != ActType_No) value = MSMAX(0.0f, value); | if (act_type != ActType_No) value = MSMAX(0.0f, value); | ||||
| dst[ci] = value; | dst[ci] = value; | ||||
| } | } | ||||
| } | } | ||||
| // /* col16-major * row8-major => row16x8-major */ | |||||
| // for (int r = 0; r < row_16; r++) { | |||||
| // for (int c = 0; c < col_8; c++) { | |||||
| // int r16div = r / C16NUM, r16mod = r % C16NUM; | |||||
| // int c8div = c / C8NUM, c8mod = c % C8NUM; | |||||
| // size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; | |||||
| // float16_t value = 0; | |||||
| // for (int d = 0; d < deep; d++) { | |||||
| // size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; | |||||
| // size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; | |||||
| // value = value + a[ai] * b[bi]; | |||||
| // } | |||||
| // if (bias != NULL) value += bias[col]; | |||||
| // if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); | |||||
| // if (act_type != ActType_No) value = MSMAX(0.0f, value); | |||||
| // dst[ci] = value; | |||||
| // } | |||||
| // } | |||||
| } | } | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -29,6 +29,9 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, | |||||
| int deep, int row, int col, int stride, bool write_nhwc); | |||||
| void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, | ||||
| int depth, int row, int col, int stride, bool write_nhwc); | int depth, int row, int col, int stride, bool write_nhwc); | ||||
| @@ -594,7 +594,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in | |||||
| int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); | int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); | ||||
| int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); | int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); | ||||
| int dst_plane_offset = c * C4NUM; | |||||
| int dst_plane_offset = c * in_channel; | |||||
| for (int ic = 0; ic < ic8; ic++) { | for (int ic = 0; ic < ic8; ic++) { | ||||
| // clear tmp buffer | // clear tmp buffer | ||||
| memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); | memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); | ||||
| @@ -622,6 +622,30 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in | |||||
| #endif | #endif | ||||
| } | } | ||||
| } | } | ||||
| } else if (real_c < 8 && real_c >= 4) { | |||||
| for (int interval = interval_y_s; interval < interval_y_e; interval++) { | |||||
| int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; | |||||
| int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; | |||||
| for (int j = 0; j < (interval_x_e - interval_x_s); j++) { | |||||
| int src_x_offset = src_y_offset + j * in_channel; | |||||
| int dst_x_offset = dst_y_offset + j * C8NUM; | |||||
| const float16_t *src_addr = input_data + src_x_offset; | |||||
| float16_t *dst_addr = tmp_data + dst_x_offset; | |||||
| int rc = real_c - 4; | |||||
| #ifdef ENABLE_NEON | |||||
| vst1_f16(dst_addr, vld1_f16(src_addr)); | |||||
| #else | |||||
| for (int k = 0; k < C4NUM; k++) { | |||||
| dst_addr[k] = src_addr[k]; | |||||
| } | |||||
| #endif | |||||
| src_addr += 4; | |||||
| dst_addr += 4; | |||||
| for (int i = 0; i < rc; ++i) { | |||||
| dst_addr[i] = src_addr[i]; | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| for (int interval = interval_y_s; interval < interval_y_e; interval++) { | for (int interval = interval_y_s; interval < interval_y_e; interval++) { | ||||
| int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; | int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; | ||||
| @@ -639,10 +663,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in | |||||
| } | } | ||||
| // input transform | // input transform | ||||
| int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM; | |||||
| size_t dst_step = ic8 * C8NUM * tile_num; | |||||
| int dst_ic8_offset = dst_plane_offset + ic * C8NUM; | |||||
| size_t dst_step = in_channel * tile_num; | |||||
| float16_t *trans_input_ptr = trans_input + dst_ic8_offset; | float16_t *trans_input_ptr = trans_input + dst_ic8_offset; | ||||
| func(tmp_data, trans_input_ptr, C8NUM, dst_step); | |||||
| func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c); | |||||
| } | } | ||||
| out_tile_index++; | out_tile_index++; | ||||
| } // cal_tile_num loop | } // cal_tile_num loop | ||||
| @@ -26,7 +26,8 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, | |||||
| int real_c); | |||||
| typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | ||||
| int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); | int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); | ||||
| @@ -56,6 +57,24 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da | |||||
| src[14] = vld1q_f16(src_data + 14 * src_step); \ | src[14] = vld1q_f16(src_data + 14 * src_step); \ | ||||
| src[15] = vld1q_f16(src_data + 15 * src_step); | src[15] = vld1q_f16(src_data + 15 * src_step); | ||||
| #define Load16DataC4Fp16 \ | |||||
| src[0] = vld1_f16(src_data + 0 * src_step); \ | |||||
| src[1] = vld1_f16(src_data + 1 * src_step); \ | |||||
| src[2] = vld1_f16(src_data + 2 * src_step); \ | |||||
| src[3] = vld1_f16(src_data + 3 * src_step); \ | |||||
| src[4] = vld1_f16(src_data + 4 * src_step); \ | |||||
| src[5] = vld1_f16(src_data + 5 * src_step); \ | |||||
| src[6] = vld1_f16(src_data + 6 * src_step); \ | |||||
| src[7] = vld1_f16(src_data + 7 * src_step); \ | |||||
| src[8] = vld1_f16(src_data + 8 * src_step); \ | |||||
| src[9] = vld1_f16(src_data + 9 * src_step); \ | |||||
| src[10] = vld1_f16(src_data + 10 * src_step); \ | |||||
| src[11] = vld1_f16(src_data + 11 * src_step); \ | |||||
| src[12] = vld1_f16(src_data + 12 * src_step); \ | |||||
| src[13] = vld1_f16(src_data + 13 * src_step); \ | |||||
| src[14] = vld1_f16(src_data + 14 * src_step); \ | |||||
| src[15] = vld1_f16(src_data + 15 * src_step); | |||||
| #define Load36DataFp16 \ | #define Load36DataFp16 \ | ||||
| src[0] = vld1q_f16(src_data + 0 * src_step); \ | src[0] = vld1q_f16(src_data + 0 * src_step); \ | ||||
| src[1] = vld1q_f16(src_data + 1 * src_step); \ | src[1] = vld1q_f16(src_data + 1 * src_step); \ | ||||
| @@ -94,6 +113,44 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da | |||||
| src[34] = vld1q_f16(src_data + 34 * src_step); \ | src[34] = vld1q_f16(src_data + 34 * src_step); \ | ||||
| src[35] = vld1q_f16(src_data + 35 * src_step); | src[35] = vld1q_f16(src_data + 35 * src_step); | ||||
| #define Load36DataC4Fp16 \ | |||||
| src[0] = vld1_f16(src_data + 0 * src_step); \ | |||||
| src[1] = vld1_f16(src_data + 1 * src_step); \ | |||||
| src[2] = vld1_f16(src_data + 2 * src_step); \ | |||||
| src[3] = vld1_f16(src_data + 3 * src_step); \ | |||||
| src[4] = vld1_f16(src_data + 4 * src_step); \ | |||||
| src[5] = vld1_f16(src_data + 5 * src_step); \ | |||||
| src[6] = vld1_f16(src_data + 6 * src_step); \ | |||||
| src[7] = vld1_f16(src_data + 7 * src_step); \ | |||||
| src[8] = vld1_f16(src_data + 8 * src_step); \ | |||||
| src[9] = vld1_f16(src_data + 9 * src_step); \ | |||||
| src[10] = vld1_f16(src_data + 10 * src_step); \ | |||||
| src[11] = vld1_f16(src_data + 11 * src_step); \ | |||||
| src[12] = vld1_f16(src_data + 12 * src_step); \ | |||||
| src[13] = vld1_f16(src_data + 13 * src_step); \ | |||||
| src[14] = vld1_f16(src_data + 14 * src_step); \ | |||||
| src[15] = vld1_f16(src_data + 15 * src_step); \ | |||||
| src[16] = vld1_f16(src_data + 16 * src_step); \ | |||||
| src[17] = vld1_f16(src_data + 17 * src_step); \ | |||||
| src[18] = vld1_f16(src_data + 18 * src_step); \ | |||||
| src[19] = vld1_f16(src_data + 19 * src_step); \ | |||||
| src[20] = vld1_f16(src_data + 20 * src_step); \ | |||||
| src[21] = vld1_f16(src_data + 21 * src_step); \ | |||||
| src[22] = vld1_f16(src_data + 22 * src_step); \ | |||||
| src[23] = vld1_f16(src_data + 23 * src_step); \ | |||||
| src[24] = vld1_f16(src_data + 24 * src_step); \ | |||||
| src[25] = vld1_f16(src_data + 25 * src_step); \ | |||||
| src[26] = vld1_f16(src_data + 26 * src_step); \ | |||||
| src[27] = vld1_f16(src_data + 27 * src_step); \ | |||||
| src[28] = vld1_f16(src_data + 28 * src_step); \ | |||||
| src[29] = vld1_f16(src_data + 29 * src_step); \ | |||||
| src[30] = vld1_f16(src_data + 30 * src_step); \ | |||||
| src[31] = vld1_f16(src_data + 31 * src_step); \ | |||||
| src[32] = vld1_f16(src_data + 32 * src_step); \ | |||||
| src[33] = vld1_f16(src_data + 33 * src_step); \ | |||||
| src[34] = vld1_f16(src_data + 34 * src_step); \ | |||||
| src[35] = vld1_f16(src_data + 35 * src_step); | |||||
| #define Load64DataFp16 \ | #define Load64DataFp16 \ | ||||
| src[0] = vld1q_f16(src_data + 0 * src_step); \ | src[0] = vld1q_f16(src_data + 0 * src_step); \ | ||||
| src[1] = vld1q_f16(src_data + 1 * src_step); \ | src[1] = vld1q_f16(src_data + 1 * src_step); \ | ||||
| @@ -160,13 +217,79 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da | |||||
| src[62] = vld1q_f16(src_data + 62 * src_step); \ | src[62] = vld1q_f16(src_data + 62 * src_step); \ | ||||
| src[63] = vld1q_f16(src_data + 63 * src_step); | src[63] = vld1q_f16(src_data + 63 * src_step); | ||||
| #define Load64DataC4Fp16 \ | |||||
| src[0] = vld1_f16(src_data + 0 * src_step); \ | |||||
| src[1] = vld1_f16(src_data + 1 * src_step); \ | |||||
| src[2] = vld1_f16(src_data + 2 * src_step); \ | |||||
| src[3] = vld1_f16(src_data + 3 * src_step); \ | |||||
| src[4] = vld1_f16(src_data + 4 * src_step); \ | |||||
| src[5] = vld1_f16(src_data + 5 * src_step); \ | |||||
| src[6] = vld1_f16(src_data + 6 * src_step); \ | |||||
| src[7] = vld1_f16(src_data + 7 * src_step); \ | |||||
| src[8] = vld1_f16(src_data + 8 * src_step); \ | |||||
| src[9] = vld1_f16(src_data + 9 * src_step); \ | |||||
| src[10] = vld1_f16(src_data + 10 * src_step); \ | |||||
| src[11] = vld1_f16(src_data + 11 * src_step); \ | |||||
| src[12] = vld1_f16(src_data + 12 * src_step); \ | |||||
| src[13] = vld1_f16(src_data + 13 * src_step); \ | |||||
| src[14] = vld1_f16(src_data + 14 * src_step); \ | |||||
| src[15] = vld1_f16(src_data + 15 * src_step); \ | |||||
| src[16] = vld1_f16(src_data + 16 * src_step); \ | |||||
| src[17] = vld1_f16(src_data + 17 * src_step); \ | |||||
| src[18] = vld1_f16(src_data + 18 * src_step); \ | |||||
| src[19] = vld1_f16(src_data + 19 * src_step); \ | |||||
| src[20] = vld1_f16(src_data + 20 * src_step); \ | |||||
| src[21] = vld1_f16(src_data + 21 * src_step); \ | |||||
| src[22] = vld1_f16(src_data + 22 * src_step); \ | |||||
| src[23] = vld1_f16(src_data + 23 * src_step); \ | |||||
| src[24] = vld1_f16(src_data + 24 * src_step); \ | |||||
| src[25] = vld1_f16(src_data + 25 * src_step); \ | |||||
| src[26] = vld1_f16(src_data + 26 * src_step); \ | |||||
| src[27] = vld1_f16(src_data + 27 * src_step); \ | |||||
| src[28] = vld1_f16(src_data + 28 * src_step); \ | |||||
| src[29] = vld1_f16(src_data + 29 * src_step); \ | |||||
| src[30] = vld1_f16(src_data + 30 * src_step); \ | |||||
| src[31] = vld1_f16(src_data + 31 * src_step); \ | |||||
| src[32] = vld1_f16(src_data + 32 * src_step); \ | |||||
| src[33] = vld1_f16(src_data + 33 * src_step); \ | |||||
| src[34] = vld1_f16(src_data + 34 * src_step); \ | |||||
| src[35] = vld1_f16(src_data + 35 * src_step); \ | |||||
| src[36] = vld1_f16(src_data + 36 * src_step); \ | |||||
| src[37] = vld1_f16(src_data + 37 * src_step); \ | |||||
| src[38] = vld1_f16(src_data + 38 * src_step); \ | |||||
| src[39] = vld1_f16(src_data + 39 * src_step); \ | |||||
| src[40] = vld1_f16(src_data + 40 * src_step); \ | |||||
| src[41] = vld1_f16(src_data + 41 * src_step); \ | |||||
| src[42] = vld1_f16(src_data + 42 * src_step); \ | |||||
| src[43] = vld1_f16(src_data + 43 * src_step); \ | |||||
| src[44] = vld1_f16(src_data + 44 * src_step); \ | |||||
| src[45] = vld1_f16(src_data + 45 * src_step); \ | |||||
| src[46] = vld1_f16(src_data + 46 * src_step); \ | |||||
| src[47] = vld1_f16(src_data + 47 * src_step); \ | |||||
| src[48] = vld1_f16(src_data + 48 * src_step); \ | |||||
| src[49] = vld1_f16(src_data + 49 * src_step); \ | |||||
| src[50] = vld1_f16(src_data + 50 * src_step); \ | |||||
| src[51] = vld1_f16(src_data + 51 * src_step); \ | |||||
| src[52] = vld1_f16(src_data + 52 * src_step); \ | |||||
| src[53] = vld1_f16(src_data + 53 * src_step); \ | |||||
| src[54] = vld1_f16(src_data + 54 * src_step); \ | |||||
| src[55] = vld1_f16(src_data + 55 * src_step); \ | |||||
| src[56] = vld1_f16(src_data + 56 * src_step); \ | |||||
| src[57] = vld1_f16(src_data + 57 * src_step); \ | |||||
| src[58] = vld1_f16(src_data + 58 * src_step); \ | |||||
| src[59] = vld1_f16(src_data + 59 * src_step); \ | |||||
| src[60] = vld1_f16(src_data + 60 * src_step); \ | |||||
| src[61] = vld1_f16(src_data + 61 * src_step); \ | |||||
| src[62] = vld1_f16(src_data + 62 * src_step); \ | |||||
| src[63] = vld1_f16(src_data + 63 * src_step); | |||||
| InputTransFp16Func GetInputTransFp16Func(int input_unit); | InputTransFp16Func GetInputTransFp16Func(int input_unit); | ||||
| void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); | |||||
| void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); | |||||
| void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); | |||||
| void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); | |||||
| OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); | OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); | ||||
| @@ -176,6 +299,12 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT | |||||
| vst1q_f16(dst_data + dst_step * out_c, m[2]); \ | vst1q_f16(dst_data + dst_step * out_c, m[2]); \ | ||||
| vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); | vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); | ||||
| #define Store4DataC4Fp16 \ | |||||
| vst1_f16(dst_data, m[0]); \ | |||||
| vst1_f16(dst_data + out_c, m[1]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c, m[2]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + out_c, m[3]); | |||||
| #define Store9DataFp16 \ | #define Store9DataFp16 \ | ||||
| vst1q_f16(dst_data, m[0]); \ | vst1q_f16(dst_data, m[0]); \ | ||||
| vst1q_f16(dst_data + out_c, m[1]); \ | vst1q_f16(dst_data + out_c, m[1]); \ | ||||
| @@ -187,6 +316,17 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT | |||||
| vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ | vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ | ||||
| vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); | vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); | ||||
| #define Store9DataC4Fp16 \ | |||||
| vst1_f16(dst_data, m[0]); \ | |||||
| vst1_f16(dst_data + out_c, m[1]); \ | |||||
| vst1_f16(dst_data + 2 * out_c, m[2]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c, m[3]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); | |||||
| #define Store16DataFp16 \ | #define Store16DataFp16 \ | ||||
| vst1q_f16(dst_data, m[0]); \ | vst1q_f16(dst_data, m[0]); \ | ||||
| vst1q_f16(dst_data + out_c, m[1]); \ | vst1q_f16(dst_data + out_c, m[1]); \ | ||||
| @@ -205,6 +345,24 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT | |||||
| vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ | vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ | ||||
| vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); | vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); | ||||
| #define Store16DataC4Fp16 \ | |||||
| vst1_f16(dst_data, m[0]); \ | |||||
| vst1_f16(dst_data + out_c, m[1]); \ | |||||
| vst1_f16(dst_data + 2 * out_c, m[2]); \ | |||||
| vst1_f16(dst_data + 3 * out_c, m[3]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c, m[4]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); | |||||
| #define Store25DataFp16 \ | #define Store25DataFp16 \ | ||||
| vst1q_f16(dst_data, m[0]); \ | vst1q_f16(dst_data, m[0]); \ | ||||
| vst1q_f16(dst_data + out_c, m[1]); \ | vst1q_f16(dst_data + out_c, m[1]); \ | ||||
| @@ -232,6 +390,33 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT | |||||
| vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ | vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ | ||||
| vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); | vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); | ||||
| #define Store25DataC4Fp16 \ | |||||
| vst1_f16(dst_data, m[0]); \ | |||||
| vst1_f16(dst_data + out_c, m[1]); \ | |||||
| vst1_f16(dst_data + 2 * out_c, m[2]); \ | |||||
| vst1_f16(dst_data + 3 * out_c, m[3]); \ | |||||
| vst1_f16(dst_data + 4 * out_c, m[4]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c, m[5]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ | |||||
| vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ | |||||
| vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ | |||||
| vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ | |||||
| vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \ | |||||
| vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ | |||||
| vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ | |||||
| vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ | |||||
| vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); | |||||
| void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | ||||
| int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); | int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); | ||||
| void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, | ||||
| @@ -197,25 +197,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | auto conv_param = reinterpret_cast<ConvParameter *>(opParameter); | ||||
| int kernel_h = conv_param->kernel_h_; | int kernel_h = conv_param->kernel_h_; | ||||
| int kernel_w = conv_param->kernel_w_; | int kernel_w = conv_param->kernel_w_; | ||||
| conv_param->input_h_ = inputs.front()->Height(); | |||||
| conv_param->input_w_ = inputs.front()->Width(); | |||||
| conv_param->output_h_ = outputs.front()->Height(); | |||||
| conv_param->output_w_ = outputs.front()->Width(); | |||||
| bool use_winograd = false; | |||||
| int out_unit; | |||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | |||||
| conv_param->input_h_ = inputs.front()->Height(); | |||||
| conv_param->input_w_ = inputs.front()->Width(); | |||||
| conv_param->input_channel_ = inputs.front()->Channel(); | |||||
| conv_param->output_h_ = outputs.front()->Height(); | |||||
| conv_param->output_w_ = outputs.front()->Width(); | |||||
| conv_param->output_channel_ = outputs.front()->Channel(); | |||||
| conv_param->op_parameter_.thread_num_ = ctx->thread_num_; | |||||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); | |||||
| } | |||||
| kernel::LiteKernel *kernel = nullptr; | kernel::LiteKernel *kernel = nullptr; | ||||
| if (kernel_h == 1 && kernel_w == 1) { | if (kernel_h == 1 && kernel_w == 1) { | ||||
| kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | ||||
| } else if (use_winograd) { | |||||
| kernel = new (std::nothrow) | |||||
| kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); | |||||
| } else { | } else { | ||||
| bool use_winograd = false; | |||||
| int out_unit; | |||||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); | |||||
| if (use_winograd) { | |||||
| kernel = new (std::nothrow) | |||||
| kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); | |||||
| } | |||||
| if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { | |||||
| kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | } | ||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | ||||
| @@ -39,8 +39,6 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ | |||||
| // original weight format : ohwi | // original weight format : ohwi | ||||
| auto channel_in = conv_param_->input_channel_; | auto channel_in = conv_param_->input_channel_; | ||||
| auto channel_out = conv_param_->output_channel_; | auto channel_out = conv_param_->output_channel_; | ||||
| int ic8 = UP_DIV(channel_in, C8NUM); | |||||
| int ic4 = ic8 * 2; | |||||
| int input_unit_square = input_unit_ * input_unit_; | int input_unit_square = input_unit_ * input_unit_; | ||||
| int oc_block_num = UP_DIV(channel_out, oc_block); | int oc_block_num = UP_DIV(channel_out, oc_block); | ||||
| @@ -84,17 +82,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ | |||||
| MS_LOG(ERROR) << "malloc trans_out_data failed."; | MS_LOG(ERROR) << "malloc trans_out_data failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block}; | |||||
| std::vector<int> strides; | |||||
| for (int i = 0; i < 4; i++) { | |||||
| int stride = 1; | |||||
| for (int j = i + 1; j < 5; j++) { | |||||
| stride *= shape[j]; | |||||
| } | |||||
| strides.push_back(stride); | |||||
| } | |||||
| int kernel_plane_stride = channel_in; | |||||
| if (oc_block == 0) { | if (oc_block == 0) { | ||||
| MS_LOG(ERROR) << "Divide by zero"; | MS_LOG(ERROR) << "Divide by zero"; | ||||
| free(tmp_weight_data); | free(tmp_weight_data); | ||||
| @@ -104,18 +92,17 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ | |||||
| free(matrix_gt_data_fp16); | free(matrix_gt_data_fp16); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| int stride1 = channel_in * oc_block; | |||||
| for (int i = 0; i < channel_out; i++) { | for (int i = 0; i < channel_out; i++) { | ||||
| int out_c_block = i / oc_block; | int out_c_block = i / oc_block; | ||||
| int out_c_res = i % oc_block; | int out_c_res = i % oc_block; | ||||
| int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; | int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; | ||||
| int output_oz_offset = out_c_block * strides[1] * input_unit_ * input_unit_ + out_c_res; | |||||
| int output_oz_offset = out_c_block * stride1 + out_c_res; | |||||
| for (int j = 0; j < channel_in; j++) { | for (int j = 0; j < channel_in; j++) { | ||||
| int ic4_block = j / C4NUM; | |||||
| int ic4_res = j % C4NUM; | |||||
| int input_iz_offset = input_oz_offset + j; | int input_iz_offset = input_oz_offset + j; | ||||
| int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; | |||||
| int output_iz_offset = output_oz_offset + j * oc_block; | |||||
| for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { | for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { | ||||
| int input_xy_offset = input_iz_offset + k * kernel_plane_stride; | |||||
| int input_xy_offset = input_iz_offset + k * channel_in; | |||||
| tmp_weight_data[k] = *(weight_data + input_xy_offset); | tmp_weight_data[k] = *(weight_data + input_xy_offset); | ||||
| } | } | ||||
| // now we only support row-major matrix-multiply | // now we only support row-major matrix-multiply | ||||
| @@ -125,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ | |||||
| MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_); | MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_); | ||||
| for (int z = 0; z < input_unit_square; z++) { | for (int z = 0; z < input_unit_square; z++) { | ||||
| int output_xy_offset = output_iz_offset + z * strides[1]; | |||||
| int output_xy_offset = output_iz_offset + z * oc_block_num * stride1; | |||||
| trans_weight_[output_xy_offset] = trans_out_data[z]; | trans_weight_[output_xy_offset] = trans_out_data[z]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -142,7 +129,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| auto filter_tensor = in_tensors_.at(kWeightIndex); | auto filter_tensor = in_tensors_.at(kWeightIndex); | ||||
| int in_channel = filter_tensor->Channel(); | int in_channel = filter_tensor->Channel(); | ||||
| int out_channel = filter_tensor->Batch(); | int out_channel = filter_tensor->Batch(); | ||||
| int ic8 = UP_DIV(in_channel, C8NUM); | |||||
| conv_param_->input_channel_ = in_channel; | conv_param_->input_channel_ = in_channel; | ||||
| conv_param_->output_channel_ = out_channel; | conv_param_->output_channel_ = out_channel; | ||||
| @@ -157,7 +143,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { | |||||
| } | } | ||||
| // set data | // set data | ||||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float16_t); | |||||
| auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); | |||||
| trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); | trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); | ||||
| if (trans_weight_ == nullptr) { | if (trans_weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc trans_weight_ failed."; | MS_LOG(ERROR) << "malloc trans_weight_ failed."; | ||||
| @@ -209,9 +195,9 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||||
| const int cal_num = 16; | const int cal_num = 16; | ||||
| int channel_out = conv_param_->output_channel_; | int channel_out = conv_param_->output_channel_; | ||||
| int oc8 = UP_DIV(channel_out, C8NUM); | int oc8 = UP_DIV(channel_out, C8NUM); | ||||
| int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM); | |||||
| size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t); | |||||
| size_t tile_buffer_size = | |||||
| thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); | |||||
| trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); | trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); | ||||
| if (trans_input_ == nullptr) { | if (trans_input_ == nullptr) { | ||||
| MS_LOG(ERROR) << "malloc trans_input_ failed."; | MS_LOG(ERROR) << "malloc trans_input_ failed."; | ||||
| @@ -232,9 +218,17 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| col_buffer_ = reinterpret_cast<float16_t *>( | |||||
| ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t))); | |||||
| if (col_buffer_ == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc col_buffer_ failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| tmp_buffer_address_list_[0] = trans_input_; | tmp_buffer_address_list_[0] = trans_input_; | ||||
| tmp_buffer_address_list_[1] = gemm_out_; | tmp_buffer_address_list_[1] = gemm_out_; | ||||
| tmp_buffer_address_list_[2] = tmp_data_; | tmp_buffer_address_list_[2] = tmp_data_; | ||||
| tmp_buffer_address_list_[3] = col_buffer_; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -67,6 +67,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| ctx_->allocator->Free(gemm_out_); | ctx_->allocator->Free(gemm_out_); | ||||
| gemm_out_ = nullptr; | gemm_out_ = nullptr; | ||||
| } | } | ||||
| if (col_buffer_ != nullptr) { | |||||
| ctx_->allocator->Free(col_buffer_); | |||||
| col_buffer_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| int kernel_unit_; | int kernel_unit_; | ||||
| int input_unit_; | int input_unit_; | ||||
| @@ -75,7 +79,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { | |||||
| float16_t *trans_input_ = nullptr; | float16_t *trans_input_ = nullptr; | ||||
| float16_t *gemm_out_ = nullptr; | float16_t *gemm_out_ = nullptr; | ||||
| float16_t *trans_weight_ = nullptr; | float16_t *trans_weight_ = nullptr; | ||||
| TmpBufferAddressFp16 tmp_buffer_address_list_[3]; | |||||
| float16_t *col_buffer_ = nullptr; | |||||
| TmpBufferAddressFp16 tmp_buffer_address_list_[4]; | |||||
| InputTransFp16Func in_func_; | InputTransFp16Func in_func_; | ||||
| OutputTransFp16Func out_func_; | OutputTransFp16Func out_func_; | ||||
| }; | }; | ||||