Merge pull request !5113 from ling/conv1x1
@@ -371,26 +371,265 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
     }
   }

-void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
-                 const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param,
-                 MATMUL_OPT_R_FUNC matmul_func) {
-  if (matmul_func != NULL) {
-    matmul_func(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
-                conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
-                conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-                conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                (conv_param->conv_quant_arg_.filter_arg_num_ > 1));
+void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
+                   size_t output_channel, size_t plane_size, ConvParameter *conv_param) {
+  int ic4 = UP_ROUND(input_channel, C4NUM);
+  size_t hw_8div = plane_size / C8NUM * C8NUM;
+  size_t hw_8res = plane_size - hw_8div;
+  size_t ic_4div = input_channel / C4NUM * C4NUM;
+  int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
+
+  if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
+    const int8_t *src_r = src_input;
+    int8_t *pack_r = packed_input;
+    /* per layer */
+    for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
+      const int8_t *src_ic = src_r;
+      int8_t *pack_ic = pack_r;
+      int32_t *input_sum_r = input_sum + hwi;
+#ifdef ENABLE_ARM64
+      size_t src_stride = input_channel;
+      size_t ic_4res = input_channel - ic_4div;
+      asm volatile(
+        "dup v10.4s, wzr \n"
+        "dup v11.4s, wzr \n"
+        "mov x20, %[input_sum_r] \n"
+        "dup v20.4s, %w[filter_zp] \n"
+
+        "mov x10, %[src_ic] \n"
+        "mov x11, %[pack_ic] \n"
+
+        "mov x0, #0 \n"
+        "1: \n"
+        "cmp x0, %[ic_4div] \n"
+        "add x0, x0, #4\n"
+        "mov x12, x10 \n"
+        "add x10, x10, #4\n"
+        "blt 2f \n"
+        "cmp %[ic_4res], #0\n"
+        "beq 6f \n"
+        "cmp %[ic_4res], #1\n"
+        "beq 3f \n"
+        "cmp %[ic_4res], #2\n"
+        "beq 4f \n"
+        "cmp %[ic_4res], #3\n"
+        "beq 5f \n"
+
+        "2: \n"
+        "ld1 {v0.s}[0], [x12], %[src_stride]\n"
+        "ld1 {v0.s}[1], [x12], %[src_stride]\n"
+        "ld1 {v0.s}[2], [x12], %[src_stride]\n"
+        "ld1 {v0.s}[3], [x12], %[src_stride]\n"
+        "ld1 {v1.s}[0], [x12], %[src_stride]\n"
+        "ld1 {v1.s}[1], [x12], %[src_stride]\n"
+        "ld1 {v1.s}[2], [x12], %[src_stride]\n"
+        "ld1 {v1.s}[3], [x12], %[src_stride]\n"
+
+        "st1 {v0.16b}, [x11], #16\n"
+        "st1 {v1.16b}, [x11], #16\n"
+
+        "saddlp v4.8h, v0.16b \n"
+        "saddlp v5.8h, v1.16b \n"
+        "saddlp v0.4s, v4.8h \n"
+        "saddlp v1.4s, v5.8h \n"
+        "add v10.4s, v10.4s, v0.4s \n"
+        "add v11.4s, v11.4s, v1.4s \n"
+        "b 1b \n"
+
+        "3: \n" /* col res 1 */
+        "dup v0.4s, wzr \n"
+        "dup v1.4s, wzr \n"
+
+        "ld1 {v0.b}[0], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[4], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[8], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[12], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[0], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[4], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[8], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[12], [x12], %[src_stride]\n"
+
+        "st1 {v0.16b}, [x11], #16\n"
+        "st1 {v1.16b}, [x11], #16\n"
+        "saddlp v4.8h, v0.16b \n"
+        "saddlp v5.8h, v1.16b \n"
+        "saddlp v0.4s, v4.8h \n"
+        "saddlp v1.4s, v5.8h \n"
+        "add v10.4s, v10.4s, v0.4s \n"
+        "add v11.4s, v11.4s, v1.4s \n"
+        "b 6f \n"
+
+        "4: \n" /* col res 2 */
+        "dup v0.4s, wzr \n"
+        "dup v1.4s, wzr \n"
+
+        "ld1 {v0.h}[0], [x12], %[src_stride]\n"
+        "ld1 {v0.h}[2], [x12], %[src_stride]\n"
+        "ld1 {v0.h}[4], [x12], %[src_stride]\n"
+        "ld1 {v0.h}[6], [x12], %[src_stride]\n"
+        "ld1 {v1.h}[0], [x12], %[src_stride]\n"
+        "ld1 {v1.h}[2], [x12], %[src_stride]\n"
+        "ld1 {v1.h}[4], [x12], %[src_stride]\n"
+        "ld1 {v1.h}[6], [x12], %[src_stride]\n"
+
+        "st1 {v0.16b}, [x11], #16\n"
+        "st1 {v1.16b}, [x11], #16\n"
+        "saddlp v4.8h, v0.16b \n"
+        "saddlp v5.8h, v1.16b \n"
+        "saddlp v0.4s, v4.8h \n"
+        "saddlp v1.4s, v5.8h \n"
+        "add v10.4s, v10.4s, v0.4s \n"
+        "add v11.4s, v11.4s, v1.4s \n"
+        "b 6f \n"
+
+        "5: \n" /* col res 3 */
+        "dup v0.4s, wzr \n"
+        "dup v1.4s, wzr \n"
+        "add x13, x12, #2 \n"
+
+        "ld1 {v0.h}[0], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[2], [x13], %[src_stride]\n"
+        "ld1 {v0.h}[2], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[6], [x13], %[src_stride]\n"
+        "ld1 {v0.h}[4], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[10], [x13], %[src_stride]\n"
+        "ld1 {v0.h}[6], [x12], %[src_stride]\n"
+        "ld1 {v0.b}[14], [x13], %[src_stride]\n"
+        "ld1 {v1.h}[0], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[2], [x13], %[src_stride]\n"
+        "ld1 {v1.h}[2], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[6], [x13], %[src_stride]\n"
+        "ld1 {v1.h}[4], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[10], [x13], %[src_stride]\n"
+        "ld1 {v1.h}[6], [x12], %[src_stride]\n"
+        "ld1 {v1.b}[14], [x13], %[src_stride]\n"
+
+        "st1 {v0.16b}, [x11], #16\n"
+        "st1 {v1.16b}, [x11], #16\n"
+        "saddlp v4.8h, v0.16b \n"
+        "saddlp v5.8h, v1.16b \n"
+        "saddlp v0.4s, v4.8h \n"
+        "saddlp v1.4s, v5.8h \n"
+        "add v10.4s, v10.4s, v0.4s \n"
+        "add v11.4s, v11.4s, v1.4s \n"
+        "b 6f \n"
+
+        "6: \n"
+        "mul v10.4s, v10.4s, v20.4s \n"
+        "mul v11.4s, v11.4s, v20.4s \n"
+
+        "st1 {v10.4s}, [x20], #16 \n"
+        "st1 {v11.4s}, [x20], #16 \n"
+
+        :
+        : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
+          [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res),
+          [ filter_zp ] "r"(filter_zp)
+        : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
+          "v20");
+#else
+      int32_t tmp_sum_value[8] = {0};
+      for (int ici = 0; ici < ic_4div; ici += C4NUM) {
+        for (int i = 0; i < C8NUM; i++) {
+          tmp_sum_value[i] += src_ic[0 + i * input_channel];
+          tmp_sum_value[i] += src_ic[1 + i * input_channel];
+          tmp_sum_value[i] += src_ic[2 + i * input_channel];
+          tmp_sum_value[i] += src_ic[3 + i * input_channel];
+          pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
+          pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
+          pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
+          pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
+        }
+        src_ic += C4NUM;
+        pack_ic += C4NUM * C8NUM;
+      }
+      for (int ici = ic_4div; ici < input_channel; ici += 1) {
+        for (int i = 0; i < C8NUM; i++) {
+          tmp_sum_value[i] += src_ic[i * input_channel];
+          pack_ic[i * C4NUM] = src_ic[i * input_channel];
+        }
+        src_ic += 1;
+        pack_ic += 1;
+      }
+      for (int i = 0; i < C8NUM; i++) {
+        input_sum_r[i] = tmp_sum_value[i] * filter_zp;
+      }
+#endif
+      src_r += input_channel * C8NUM;
+      pack_r += ic4 * C8NUM;
+    }
+
+    if (hw_8div != plane_size) {
+      memset(pack_r, 0, C8NUM * ic4);
+      for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
+        int32_t tmp_sum_value = 0;
+        const int8_t *src_ic = src_r;
+        int8_t *pack_ic = pack_r;
+        for (int ici = 0; ici < ic_4div; ici += C4NUM) {
+          tmp_sum_value += src_ic[0];
+          tmp_sum_value += src_ic[1];
+          tmp_sum_value += src_ic[2];
+          tmp_sum_value += src_ic[3];
+          pack_ic[0] = src_ic[0];
+          pack_ic[1] = src_ic[1];
+          pack_ic[2] = src_ic[2];
+          pack_ic[3] = src_ic[3];
+          src_ic += C4NUM;
+          pack_ic += C4NUM * C8NUM;
+        }
+        for (int ici = ic_4div; ici < input_channel; ici += 1) {
+          tmp_sum_value += src_ic[0];
+          pack_ic[0] = src_ic[0];
+          src_ic += 1;
+          pack_ic += 1;
+        }
+        input_sum[hwi] = tmp_sum_value * filter_zp;
+        src_r += input_channel;
+        pack_r += C4NUM;
+      }
+      for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) {
+        input_sum[hwi] = 0;
+      }
+    }
   } else {
-    MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
-                      conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
-                      conv_param->conv_quant_arg_.quant_multiplier_,
-                      conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
-                      conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                      (conv_param->conv_quant_arg_.filter_arg_num_ > 1));
+    /* per channel */
+    RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel);
+    PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param);
   }
   return;
 }
+
+void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
+                    const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
+                    MATMUL_OPT_R_FUNC matmul_func) {
+  matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
+              conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
+              conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
+              conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false);
+  return;
+}
+
+void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
+                 const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) {
+#ifdef ENABLE_ARM64
+  MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
+                   bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
+                   conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
+                   conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
+                   conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
+#else
+  MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
+                    conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
+                    conv_param->conv_quant_arg_.quant_multiplier_,
+                    conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
+                    conv_param->conv_quant_arg_.out_act_max_[0], false);
+#endif
+  return;
+}

 // int8 convolution 3x3
 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
                  int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
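Editor's note on the math behind Conv1x1PreOpt's fused packing: with asymmetric quantization each real value is `scale * (q - zp)`, so one output accumulator expands as `sum_d (a_d - za)(w_d - zw) = sum_d a_d*w_d - zw*sum_d a_d - za*sum_d w_d + deep*za*zw`. In the per-layer case only the `zw * sum_d a_d` term varies per output pixel, so the kernel accumulates each row's activation sum while packing and stores `input_sum[r] = (sum_d a_d) * filter_zp` for the matmul to subtract; the weight-side terms are folded into the bias (see the `/* bias = bias - v2 x zp1 + zp1 x zp2 */` comment in InitWeightBias further down). A minimal self-check of that identity; all names here are illustrative, not from the patch:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Toy row: 5 int8 activations and weights with zero points za, zw.
  const int deep = 5;
  int8_t a[deep] = {3, -7, 12, 0, 5};
  int8_t w[deep] = {-2, 9, 4, 4, -1};
  int32_t za = 10, zw = 3;

  // Direct accumulation on zero-point-corrected values.
  int32_t direct = 0;
  for (int d = 0; d < deep; d++) direct += (a[d] - za) * (w[d] - zw);

  // Factored form: raw dot product minus precomputed correction terms.
  int32_t dot = 0, a_sum = 0, w_sum = 0;
  for (int d = 0; d < deep; d++) {
    dot += a[d] * w[d];
    a_sum += a[d];  // what Conv1x1PreOpt accumulates while packing
    w_sum += w[d];  // folded into the bias at init time
  }
  int32_t input_sum = a_sum * zw;  // input_sum[r] in the patch
  assert(direct == dot - input_sum - za * w_sum + deep * za * zw);
  return 0;
}
```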
@@ -54,9 +54,13 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
                  ConvParameter *conv_param, GEMM_FUNC gemm_func);

 // int8 convolution 1x1
+void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
+                   size_t output_channel, size_t plane_size, ConvParameter *conv_param);
 void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
-                 const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param,
-                 MATMUL_OPT_R_FUNC matmul_func);
+                 const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param);
+void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
+                    const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
+                    MATMUL_OPT_R_FUNC matmul_func);

 // int8 convolution 3x3
 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
@@ -172,7 +172,7 @@ void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp,
 void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
                         bool suppport_opt) {
   /* optimize normal -> same layout */
-  PackInputSum16x4PerLater(src, dst, filter_zp, row4, col16);
+  PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16);
   return;
 }
@@ -36,7 +36,24 @@ void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co
     for (int c = 0; c < col; c++) {
       int cd16 = c / C16NUM;
       int cm16 = c % C16NUM;
-      dst_ptr[cd16 * col16 * C4NUM + rd4 * C4NUM * C16NUM + rm4 * C16NUM + cm16] = src_ptr[r * col16 + c];
+      int dst_index = rd4 * col16 * C4NUM + cd16 * C4NUM * C16NUM + rm4 * C16NUM + cm16;
+      int src_index = r * col + c;
+      dst_ptr[dst_index] = src_ptr[src_index];
+    }
+  }
+}
+
+void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  int col4 = UP_ROUND(col, C4NUM);
+  for (int r = 0; r < row; r++) {
+    int rd8 = r / C8NUM;
+    int rm8 = r % C8NUM;
+    for (int c = 0; c < col; c++) {
+      int cd4 = c / C4NUM;
+      int cm4 = c % C4NUM;
+      int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4;
+      int src_index = r * col + c;
+      dst_ptr[dst_index] = src_ptr[src_index];
     }
   }
 }
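The new row8x4 packing stores element (r, c) at `rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4`: rows grouped in blocks of C8NUM, depth in blocks of C4NUM, and each 8x4 tile laid out row by row, so the GEMM can read four depth values for eight consecutive rows as one contiguous 32-byte run. A small sketch with hypothetical sizes that prints the index mapping (the round-trip harness is illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

#define C4NUM 4
#define C8NUM 8
#define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

int main() {
  const int row = 3, col = 5;
  const int col4 = UP_ROUND(col, C4NUM);  // 8: depth padded to 4-blocks
  int8_t src[row * col];
  int8_t dst[UP_ROUND(row, C8NUM) * col4];
  for (int i = 0; i < row * col; i++) src[i] = (int8_t)i;
  memset(dst, 0, sizeof(dst));  // padded slots stay zero

  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int rd8 = r / C8NUM, rm8 = r % C8NUM, cd4 = c / C4NUM, cm4 = c % C4NUM;
      int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4;
      dst[dst_index] = src[r * col + c];
      printf("(%d,%d) -> %d\n", r, c, dst_index);  // e.g. (0,4) -> 32: second depth block
    }
  }
  return 0;
}
```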
@@ -50,6 +67,29 @@ void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stri
   return;
 }

+void MatrixEmptyInt8(int8_t *dst, int row, int col) {
+  for (int r = 0; r < row; r++) {
+    int8_t *dst_r = dst + r * C16NUM;
+    memset(dst_r, 0, col * sizeof(int8_t));
+  }
+  return;
+}
+
+void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  /* Row-major to row4x8-major (block row-major) */
+  int col4 = UP_ROUND(col, C4NUM);
+  for (int r = 0; r < row; r++) {
+    int rd8 = r / C8NUM, rm8 = r % C8NUM;
+    for (int c = 0; c < col; c++) {
+      int cd4 = c / C4NUM, cm4 = c % C4NUM;
+      int src_index = r * col + c;
+      int dst_index = rd8 * col4 * C8NUM + cd4 * C4NUM * C8NUM + rm8 * C4NUM + cm4;
+      dst_ptr[dst_index] = src_ptr[src_index];
+    }
+  }
+  return;
+}
+
 void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
   /* Row-major to row16x4-major (block row-major) */
   int col16 = UP_ROUND(col, C16NUM);
@@ -90,12 +130,15 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
     if (col != col_16div) {
       MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col);
+      MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res);
     }
     src_r += C4NUM * col;
     dst_r += C4NUM * col16;
   }

   if (row != row_4div) {
+    memset(dst_r, 0, C4NUM * col16);
     for (int ci = 0; ci < col_16div; ci += C16NUM) {
       MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col);
     }
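The MatrixEmptyInt8 call and the memset added in this hunk zero the unused tail of partial 4x16 tiles. That is what lets the matmul always iterate over full col16/deep16 tiles: a zero lane contributes nothing to the raw dot product, and the zero-point corrections use the true depth, so padded lanes drop out entirely. A tiny check under that assumption:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Real depth 5, padded out to a full 16-lane tile with zeros.
  const int deep = 5, deep16 = 16;
  int8_t a[deep16] = {3, -7, 12, 0, 5};  // remaining 11 lanes zero-initialized
  int8_t w[deep16] = {-2, 9, 4, 4, -1};
  int32_t full = 0, real = 0;
  for (int d = 0; d < deep16; d++) full += a[d] * w[d];  // what the tiled GEMM computes
  for (int d = 0; d < deep; d++) real += a[d] * w[d];    // the true dot product
  assert(full == real);  // zero padding contributes nothing
  return 0;
}
```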
@@ -172,6 +215,38 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
   return;
 }

+void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
+                      size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
+                      int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
+                      bool per_channel) {
+  /* row8x4-major * row4x8-major => (int8)row-major */
+  for (int r = 0; r < row; r++) {
+    for (int c = 0; c < col; c++) {
+      int r8div = r / C8NUM, r8mod = r % C8NUM;
+      int c8div = c / C8NUM, c8mod = c % C8NUM;
+      size_t ci = r * stride + c;
+      int32_t value = 0;
+      for (int d = 0; d < deep_4; d++) {
+        int d4div = d / C4NUM, d4mod = d % C4NUM;
+        size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod;
+        size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod;
+        value = value + a[ai] * b[bi];
+      }
+      int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r];
+      value -= cur_input_sum;
+      value += bias[c];
+      int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
+      int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
+      int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
+      value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
+      value = MSMIN(maxi, value);
+      value = MSMAX(mini, value);
+      dst[ci] = (int8_t)value;
+    }
+  }
+  return;
+}
+
 /* row4x16-major * col16x4-major => row4x4-major */
 void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
                 int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
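MatMulInt8_8x8_r finishes each output with `MultiplyByQuantizedMultiplier(...) + output_zp` and a clamp. That helper is not part of this diff; below is a sketch of the gemmlowp-style fixed-point requantization it conventionally names, assuming the usual convention of `left_shift >= 0` and `right_shift <= 0` (the authoritative version lives in nnacl's fixed-point header and may differ in detail):

```cpp
#include <cstdint>

// Saturating, rounding doubling high multiply: round((int64)a * b / 2^31).
static int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  // the single overflow case
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

// Rounding arithmetic right shift (rounds half away from zero).
static int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  const int32_t mask = (1 << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// acc * 2^left_shift, scaled by multiplier/2^31, then shifted right by -right_shift.
static int32_t MultiplyByQuantizedMultiplierSketch(int32_t acc, int32_t multiplier, int32_t left_shift,
                                                   int32_t right_shift) {
  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), multiplier), -right_shift);
}
```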
@@ -35,6 +35,13 @@ void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co
 void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
+void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
+                      size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
+                      int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
+                      bool per_channel);
+void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
 void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
 void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
@@ -22,7 +22,7 @@
 typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
                                    const int *input_sum, const int *bias);
-typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
+typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                                   size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                   int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
                                   int32_t maxi, bool per_channel);
@@ -35,11 +35,15 @@ typedef struct MatMulParameter {
   OpParameter op_parameter_;
   int row_;
   int col_;
+  int row_4_;
   int row_8_;
   int row_12_;
   int row_16_;
+  int col_4_;
   int col_8_;
   int deep_;
+  int deep_4_;
+  int deep_16_;
   bool has_bias_;
   int batch;
   bool a_transpose_; /* false : row-major */
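The new fields cache tile-rounded dimensions so kernels stop recomputing them at run time. Assuming the usual nnacl rounding macros (defined in op_base.h, not in this diff):

```cpp
// Presumed definitions, reproduced here as an assumption:
#define UP_DIV(x, y) (((x) + (y)-1) / (y))   // number of tiles covering x
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))  // x rounded up to a multiple of y

// Example: deep_ = 21 input channels with C4NUM = 4:
//   UP_DIV(21, 4)   == 6   -> six 4-channel blocks
//   UP_ROUND(21, 4) == 24  -> deep_4_, last block zero-padded by 3 channels
```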
@@ -37,7 +37,7 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
                                        size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                        const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
                                        int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
-                                       size_t asymmetric, size_t per_channel) {
+                                       size_t asymmetric, size_t per_channel) {
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
                                   act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
 }
@@ -47,7 +47,7 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
   return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
 }

-void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
+void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
                                   size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                                   int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
                                   int32_t maxi, bool per_channel) {
@@ -194,7 +194,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
   return;
 }

-void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
+void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
   /* optimize normal -> same layout */
 #ifdef ENABLE_ARM64
   asm volatile(
@@ -267,12 +267,12 @@ void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp
   return;
 }

-void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
+void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
                           size_t plane_size, ConvParameter *conv_param) {
   size_t hw4 = UP_ROUND(plane_size, C4NUM);
   size_t ic16 = UP_ROUND(input_channel, C16NUM);
   if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
-    PackInputSum16x4PerLater(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
+    PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
   } else {
     for (int ri = 0; ri < plane_size; ri++) {
       int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
@@ -293,6 +293,40 @@ void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_
   return;
 }

+void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
+                         size_t plane_size, ConvParameter *conv_param) {
+  size_t hw8 = UP_ROUND(plane_size, C8NUM);
+  size_t ic4 = UP_ROUND(input_channel, C4NUM);
+  if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
+    for (int r = 0; r < hw8; r++) {
+      int32_t tmp_value = 0;
+      for (int c = 0; c < ic4; c++) {
+        int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM;
+        int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod;
+        tmp_value += input_value[src_index];
+      }
+      input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
+    }
+  } else {
+    for (int ri = 0; ri < plane_size; ri++) {
+      int ri8div = ri / C8NUM, ri8mod = ri % C8NUM;
+      for (int ci = 0; ci < output_channel; ci++) {
+        int32_t tmp_sum_value = 0;
+        int ci8div = ci / C8NUM, ci8mod = ci % C8NUM;
+        int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
+        for (int di = 0; di < input_channel; di++) {
+          size_t di4div = di / C4NUM, di4mod = di % C4NUM;
+          int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod;
+          tmp_sum_value += input_value[src_index];
+        }
+        int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod;
+        input_sum[dst_index] = tmp_sum_value * filter_zp;
+      }
+    }
+  }
+  return;
+}
+
 void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num,
                         int block_index) {
   // input format : nhwc
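The two branches above imply two input_sum shapes, which InitParam's `input_sum_size` computation later in this PR mirrors: per-layer needs one int32 per padded output pixel, per-channel needs one per pixel per padded output channel, because each channel subtracts a different `sum * zp[c]`. A sketch of that sizing rule, under those assumptions:

```cpp
#include <cstddef>

#define C8NUM 8
#define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

// Entries (int32) needed by the 8x4 pipeline's input_sum buffer.
// plane = output_h * output_w, oc = output_channel.
size_t InputSumCount(size_t plane, size_t oc, bool per_channel) {
  size_t hw8 = UP_ROUND(plane, C8NUM);
  return per_channel ? UP_ROUND(oc, C8NUM) * hw8  // one correction per (pixel, channel)
                     : hw8;                       // one correction per pixel
}
```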
@@ -35,15 +35,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
 void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
                            int32_t *input_sum, ConvParameter *conv_param);

-void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
+void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);

 void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);

 void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);

-void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
+void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
                           size_t plane_size, ConvParameter *conv_param);

+void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
+                         size_t plane_size, ConvParameter *conv_param);
+
 void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);

 void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
@@ -22,6 +22,15 @@ using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;

 namespace mindspore::kernel {
+int Convolution1x1Int8Pre(void *cdata, int task_id) {
+  auto conv = reinterpret_cast<Convolution1x1Int8CPUKernel *>(cdata);
+  auto error_code = conv->RunPre(task_id);
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "conv1x1 Int8 RunPre error task_id[" << task_id << "] error_code[" << error_code << "]";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
 Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
   if (matmul_param_ != nullptr) {
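Convolution1x1Int8Pre follows the runtime's task convention: a free function of type `int (*)(void *, int)` that casts `cdata` back to the kernel and forwards `task_id` to a member function. A mock of that contract; the serial ParallelLaunch below is a stand-in for the real thread pool, not its implementation:

```cpp
#include <iostream>

using TaskFunc = int (*)(void *cdata, int task_id);

// Stand-in for the runtime thread pool: runs every task slice serially.
int ParallelLaunch(TaskFunc task, void *cdata, int task_num) {
  for (int t = 0; t < task_num; ++t) {
    int ret = task(cdata, t);
    if (ret != 0) return ret;
  }
  return 0;
}

struct DemoKernel {
  int RunPre(int task_id) {
    std::cout << "pack slice " << task_id << "\n";
    return 0;
  }
};

int DemoPre(void *cdata, int task_id) {  // same shape as Convolution1x1Int8Pre
  return reinterpret_cast<DemoKernel *>(cdata)->RunPre(task_id);
}

int main() {
  DemoKernel kernel;
  return ParallelLaunch(DemoPre, &kernel, 4);
}
```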
@@ -37,20 +46,16 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
 }

 void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
-  if (packed_input_ != nullptr) {
-    free(packed_input_);
-    packed_input_ = nullptr;
-  }
-  if (input_sum_ != nullptr) {
-    free(input_sum_);
-    input_sum_ = nullptr;
+  if (pre_trans_input_ && input_ptr_ != nullptr) {
+    free(input_ptr_);
+    input_ptr_ = nullptr;
   }
   return;
 }

 void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
-  support_optimize_ = false;
-  matmul_func_ = MatMulInt8_16x4_r;
+  support_optimize_ = true;
+  matmul_func_ = MatMulInt8_8x8_r;
 #ifdef ENABLE_ARM64
   void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
   if (optimize_op_handler != nullptr) {
@@ -63,14 +68,13 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
       matmul_func_ = nullptr;
     } else {
       support_optimize_ = true;
+      matmul_func_ = MatMulInt8_8x8_r;
     }
   } else {
     support_optimize_ = false;
     matmul_func_ = nullptr;
   }
 #endif
-  matmul_func_ = MatMulInt8_16x4_r;
   return;
 }
@@ -80,24 +84,32 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
   auto output_channel = filter_tensor->Batch();

   /* weight */
-  size_t size = UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t);
+  size_t size = support_optimize_ ? UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C8NUM) * sizeof(int8_t)
+                                  : UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t);
   packed_weight_ = reinterpret_cast<int8_t *>(malloc(size));
   if (packed_weight_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 int8 Malloc weight error!";
     return RET_ERROR;
   }
   memset(packed_weight_, 0, size);
-  RowMajor2Row4x16MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->Data()), packed_weight_, output_channel,
-                            input_channel);
+  if (support_optimize_) {
+    RowMajor2Row8x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->Data()), packed_weight_, output_channel,
+                             input_channel);
+  } else {
+    RowMajor2Row4x16MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->Data()), packed_weight_, output_channel,
+                              input_channel);
+  }

   /* bias = bias - v2 x zp1 + zp1 x zp2 */
   int col4 = UP_ROUND(output_channel, C4NUM);
-  bias_data_ = malloc(col4 * sizeof(int32_t));
+  int col8 = UP_ROUND(output_channel, C8NUM);
+  size = support_optimize_ ? col8 * sizeof(int32_t) : col4 * sizeof(int32_t);
+  bias_data_ = malloc(size);
   if (bias_data_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!";
     return RET_ERROR;
   }
-  memset(bias_data_, 0, col4 * sizeof(int32_t));
+  memset(bias_data_, 0, size);
   if (in_tensors_.size() == 3) {
     memcpy(bias_data_, in_tensors_[kBiasIndex]->Data(), output_channel * sizeof(int32_t));
   }
@@ -119,9 +131,6 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
 }

 int Convolution1x1Int8CPUKernel::Init() {
-  if (!InferShapeDone()) {
-    return RET_OK;
-  }
   matmul_param_ = new (std::nothrow) MatMulParameter();
   if (matmul_param_ == nullptr) {
     MS_LOG(ERROR) << "Init matmul_param_ failed.";
@@ -142,6 +151,9 @@ int Convolution1x1Int8CPUKernel::Init() {
     return ret;
   }

+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
   return ReSize();
 }
@@ -152,30 +164,52 @@ int Convolution1x1Int8CPUKernel::InitParam() {
   matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->col_ = conv_param_->output_channel_;
-  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C4NUM));
-  thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C4NUM), thread_count_);
-  size_t size = UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
-  packed_input_ = reinterpret_cast<int8_t *>(malloc(size * sizeof(int8_t)));
-  if (packed_input_ == nullptr) {
-    MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
-    return RET_ERROR;
+  matmul_param_->col_4_ = UP_ROUND(matmul_param_->col_, C4NUM);
+  matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
+  matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
+  matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
+  matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
+  matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM);
+
+  /* init input sum size */
+  if (support_optimize_) {
+    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
+      input_sum_size = UP_ROUND(conv_param_->output_channel_, C8NUM) * UP_ROUND(matmul_param_->row_, C8NUM);
+    } else {
+      input_sum_size = UP_ROUND(matmul_param_->row_, C8NUM);
+    }
+  } else {
+    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
+      input_sum_size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(matmul_param_->row_, C4NUM);
+    } else {
+      input_sum_size = UP_ROUND(matmul_param_->row_, C4NUM);
+    }
   }
-  memset(packed_input_, 0, size * sizeof(int8_t));
-  if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-    size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(matmul_param_->row_, C4NUM);
+
+  if (support_optimize_) {
+    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
+    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_);
   } else {
-    size = UP_ROUND(matmul_param_->row_, C4NUM);
+    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C4NUM));
+    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C4NUM), thread_count_);
   }
-  input_sum_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
-  if (input_sum_ == nullptr) {
-    MS_LOG(ERROR) << "malloc input_sum_ failed.";
-    return RET_ERROR;
+
+  if (support_optimize_) {
+    thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C8NUM));
+    thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C8NUM), thread_count_hw_);
+  } else {
+    thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C4NUM));
+    thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C4NUM), thread_count_hw_);
   }
-  memset(input_sum_, 0, size * sizeof(int32_t));
+
+  if (pre_trans_input_) {
+    input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
+    if (input_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "Conv1x1 int8 Malloc input_ptr_ error!";
+      return RET_MEMORY_FAILED;
+    }
+    memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t));
+  }
+
   return RET_OK;
 }
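The thread_count_/thread_stride_ pairs above split column blocks across tasks (and the new thread_count_hw_/thread_stride_hw_ do the same for row blocks in the packing pass); the last task gets whatever remains. The arithmetic, assuming the usual MSMIN/UP_DIV definitions:

```cpp
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y)-1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main() {
  // e.g. col_ = 100 output channels, 8-channel tiles, 4 requested threads.
  int col = 100, tile = 8, thread_num = 4;
  int blocks = UP_DIV(col, tile);                    // 13 blocks of 8 channels
  int thread_count = MSMIN(thread_num, blocks);      // 4 workers
  int thread_stride = UP_DIV(blocks, thread_count);  // 4 blocks (32 channels) per task
  for (int task_id = 0; task_id < thread_count; task_id++) {
    int start = task_id * thread_stride * tile;
    int cur_oc = MSMIN(thread_stride * tile, col - start);  // same clamp as RunImpl
    if (cur_oc <= 0) continue;
    printf("task %d: channels [%d, %d)\n", task_id, start, start + cur_oc);
  }
  return 0;
}
```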
@@ -199,21 +233,54 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
   } else {
     input_ptr_ = src_input;
   }
-  RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
+
+  if (support_optimize_) {
+    ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_);
+  } else {
+    RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
+    PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_,
+                         conv_param_);
+  }
   return;
 }

 int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_ * C4NUM, matmul_param_->col_ - task_id * thread_stride_ * C4NUM);
-  if (cur_oc <= 0) {
-    return RET_OK;
+  if (support_optimize_) {
+    int cur_stride = thread_stride_ * C8NUM;
+    int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM;
+    int cur_oc = MSMIN(cur_stride, res_stride);
+    if (cur_oc <= 0) {
+      return RET_OK;
+    }
+    Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
+                   output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_,
+                   reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_,
+                   cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_);
+  } else {
+    int cur_stride = thread_stride_ * C4NUM;
+    int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM;
+    int cur_oc = MSMIN(cur_stride, res_stride);
+    if (cur_oc <= 0) {
+      return RET_OK;
+    }
+    Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
+                output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_,
+                reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc,
+                matmul_param_->deep_16_, conv_param_);
   }
+  return RET_OK;
+}
-  int32_t *bias = reinterpret_cast<int32_t *>(bias_data_) + thread_stride_ * C4NUM * task_id;
-  Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_,
-              output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_, bias + task_id * thread_stride_ * C4NUM,
-              matmul_param_->row_, cur_oc, UP_ROUND(matmul_param_->deep_, C16NUM), conv_param_, matmul_func_);
+
+int Convolution1x1Int8CPUKernel::RunPre(int task_id) {
+  int cur_hw = MSMIN(thread_stride_hw_ * C8NUM, matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM);
+  if (cur_hw <= 0) {
+    return RET_OK;
+  }
+  Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
+                packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
+                input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw,
+                conv_param_);
   return RET_OK;
 }
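A note on RunPre's pointer offsets: the unpacked input advances by `rows * deep_` (row-major, true depth), the packed buffer by `rows * deep_4_` (padded depth), and input_sum_ by one entry per pixel. Worked through with hypothetical numbers:

```cpp
#include <cstdio>

#define C8NUM 8
#define UP_ROUND(x, y) (((x) + (y)-1) / (y) * (y))

int main() {
  // RunPre's slice offsets for e.g. deep_ = 21, thread_stride_hw_ = 4, task_id = 2.
  int deep = 21, deep4 = UP_ROUND(deep, 4);  // 24
  int thread_stride_hw = 4, task_id = 2;
  int rows = task_id * thread_stride_hw * C8NUM;  // 64 output pixels already assigned
  printf("input_ptr_    += %d  (row-major, true depth)\n", rows * deep);      // 1344
  printf("packed_input_ += %d  (8x4 blocks, padded depth)\n", rows * deep4);  // 1536
  printf("input_sum_    += %d  (one int32 per pixel)\n", rows);               // 64
  return 0;
}
```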
@@ -227,6 +294,35 @@ int Convolution1x1Int8Impl(void *cdata, int task_id) {
   return RET_OK;
 }

+int Convolution1x1Int8CPUKernel::InitRunBuf() {
+  input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size * sizeof(int32_t)));
+  if (input_sum_ == nullptr) {
+    MS_LOG(ERROR) << "malloc input_sum_ failed.";
+    return RET_ERROR;
+  }
+
+  size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
+                                  : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
+  packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
+  if (packed_input_ == nullptr) {
+    MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+void Convolution1x1Int8CPUKernel::FreeRunBuf() {
+  if (packed_input_ != nullptr) {
+    ctx_->allocator->Free(packed_input_);
+    packed_input_ = nullptr;
+  }
+  if (input_sum_ != nullptr) {
+    ctx_->allocator->Free(input_sum_);
+    input_sum_ = nullptr;
+  }
+  return;
+}
+
 int Convolution1x1Int8CPUKernel::Run() {
   auto ret = Prepare();
   if (ret != RET_OK) {
@@ -234,13 +330,10 @@ int Convolution1x1Int8CPUKernel::Run() {
     return RET_ERROR;
   }

-  if (pre_trans_input_) {
-    input_ptr_ =
-      reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
-    if (input_ptr_ == nullptr) {
-      MS_LOG(ERROR) << "Conv1x1 int8 Malloc input_ptr_ error!";
-      return RET_MEMORY_FAILED;
-    }
+  int error_code = InitRunBuf();
+  if (error_code != RET_OK) {
+    MS_LOG(ERROR) << "conv1x1 int8 InitRunBuf error_code[" << error_code << "]";
+    return RET_ERROR;
   }

   int8_t *src_in = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
@@ -249,21 +342,10 @@ int Convolution1x1Int8CPUKernel::Run() {
   for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
     Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
                 src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
-    PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_,
-                         conv_param_);
-    int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Impl, this, thread_count_);
-    if (error_code != RET_OK) {
-      MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]";
-      return RET_ERROR;
-    }
+    ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Impl, this, thread_count_);
   }

-  if (pre_trans_input_ && input_ptr_ != nullptr) {
-    ctx_->allocator->Free(input_ptr_);
-    input_ptr_ = nullptr;
-  }
-
+  FreeRunBuf();
   return RET_OK;
 }
@@ -40,8 +40,13 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
   int ReSize() override;
   int Run() override;

+ private:
+  int InitRunBuf();
+  void FreeRunBuf();
+
  public:
   int RunImpl(int task_id);
+  int RunPre(int task_id);

  private:
   void FreeResizeBuf();
@@ -58,7 +63,10 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
   int8_t *output_ptr_ = nullptr;
   size_t thread_count_ = 1;
   size_t thread_stride_ = 0;
+  size_t thread_count_hw_ = 1;
+  size_t thread_stride_hw_ = 0;
   bool pre_trans_input_ = false;
+  size_t input_sum_size = 0;
   MatMulParameter *matmul_param_ = nullptr;
   MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
   bool support_optimize_ = false;
@@ -398,11 +398,11 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Ten
   int dilation_h = conv_param->dilation_h_;
   int dilation_w = conv_param->dilation_w_;
   kernel::LiteKernel *kernel;
+  auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
   if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  } else if (kernel_h == 1 && kernel_w == 1) {
-    /* Convolution1x1Int8CPUKernel */
-    kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  } else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) {
+    kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   } else {
     kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
   }