| @@ -371,249 +371,380 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight | |||
| } | |||
| } | |||
| void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t output_channel, size_t plane_size, ConvParameter *conv_param) { | |||
| void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) { | |||
| int ic4 = UP_ROUND(input_channel, C4NUM); | |||
| int oc8 = UP_ROUND(output_channel, C8NUM); | |||
| int hw8 = UP_ROUND(plane_size, C8NUM); | |||
| size_t hw_8div = plane_size / C8NUM * C8NUM; | |||
| size_t hw_8res = plane_size - hw_8div; | |||
| size_t oc_8div = output_channel / C8NUM * C8NUM; | |||
| size_t oc_8res = output_channel - oc_8div; | |||
| size_t ic_4div = input_channel / C4NUM * C4NUM; | |||
| int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { | |||
| const int8_t *src_r = src_input; | |||
| int8_t *pack_r = packed_input; | |||
| /* per layer */ | |||
| for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { | |||
| const int8_t *src_r = src_input; | |||
| int8_t *pack_r = packed_input; | |||
| int32_t *input_sum_r = input_sum; | |||
| /* per layer */ | |||
| for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { | |||
| const int8_t *src_ic = src_r; | |||
| int8_t *pack_ic = pack_r; | |||
| int32_t *input_sum_oc = input_sum_r; | |||
| int32_t tmp_sum_value[8] = {0}; | |||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_sum_value[i] += src_ic[0 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[1 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[2 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[3 + i * input_channel]; | |||
| pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; | |||
| pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; | |||
| pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; | |||
| pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; | |||
| } | |||
| src_ic += C4NUM; | |||
| pack_ic += C4NUM * C8NUM; | |||
| } | |||
| for (int ici = ic_4div; ici < input_channel; ici += 1) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_sum_value[i] += src_ic[i * input_channel]; | |||
| pack_ic[i * C4NUM] = src_ic[i * input_channel]; | |||
| } | |||
| src_ic += 1; | |||
| pack_ic += 1; | |||
| } | |||
| for (int oci = 0; oci < oc_8div; oci += C8NUM) { | |||
| for (int ri = 0; ri < C8NUM; ri++) { | |||
| input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; | |||
| input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; | |||
| input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; | |||
| input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; | |||
| input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; | |||
| input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; | |||
| input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; | |||
| input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; | |||
| } | |||
| input_sum_oc += inputsum_stride; | |||
| } | |||
| if (oc_8div != output_channel) { | |||
| for (int oci = 0; oci < oc_8res; oci += 1) { | |||
| for (int ri = 0; ri < C8NUM; ri++) { | |||
| input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; | |||
| } | |||
| } | |||
| for (int oci = oc_8res; oci < C8NUM; oci += 1) { | |||
| for (int ri = 0; ri < C8NUM; ri++) { | |||
| input_sum_oc[ri * C8NUM + oci] = 0; | |||
| } | |||
| } | |||
| } /* oc8 res done */ | |||
| src_r += input_channel * C8NUM; | |||
| pack_r += ic4 * C8NUM; | |||
| input_sum_r += C8NUM * C8NUM; | |||
| } | |||
| if (hw_8div != plane_size) { | |||
| memset(pack_r, 0, C8NUM * ic4); | |||
| for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { | |||
| int32_t *input_sum_oc = input_sum_r; | |||
| int32_t tmp_sum_value = 0; | |||
| const int8_t *src_ic = src_r; | |||
| int8_t *pack_ic = pack_r; | |||
| int32_t *input_sum_r = input_sum + hwi; | |||
| #ifdef ENABLE_ARM64 | |||
| size_t src_stride = input_channel; | |||
| size_t ic_4res = input_channel - ic_4div; | |||
| asm volatile( | |||
| "dup v10.4s, wzr \n" | |||
| "dup v11.4s, wzr \n" | |||
| "mov x20, %[input_sum_r] \n" | |||
| "dup v20.4s, %w[filter_zp] \n" | |||
| "mov x10, %[src_ic] \n" | |||
| "mov x11, %[pack_ic] \n" | |||
| "mov x0, #0 \n" | |||
| "1: \n" | |||
| "cmp x0, %[ic_4div] \n" | |||
| "add x0, x0, #4\n" | |||
| "mov x12, x10 \n" | |||
| "add x10, x10, #4\n" | |||
| "blt 2f \n" | |||
| "cmp %[ic_4res], #0\n" | |||
| "beq 6f \n" | |||
| "cmp %[ic_4res], #1\n" | |||
| "beq 3f \n" | |||
| "cmp %[ic_4res], #2\n" | |||
| "beq 4f \n" | |||
| "cmp %[ic_4res], #3\n" | |||
| "beq 5f \n" | |||
| "2: \n" | |||
| "ld1 {v0.s}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[1], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[3], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[1], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[3], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 1b \n" | |||
| "3: \n" /* col res 1 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "ld1 {v0.b}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[8], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[12], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[8], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[12], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "4: \n" /* col res 2 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "5: \n" /* col res 3 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "add x13, x12, #2 \n" | |||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[2], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[6], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[10], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[14], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[2], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[6], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[10], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[14], [x13], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "6: \n" | |||
| "mul v10.4s, v10.4s, v20.4s \n" | |||
| "mul v11.4s, v11.4s, v20.4s \n" | |||
| "st1 {v10.4s}, [x20], #16 \n" | |||
| "st1 {v11.4s}, [x20], #16 \n" | |||
| : | |||
| : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), | |||
| [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), | |||
| [ filter_zp ] "r"(filter_zp) | |||
| : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", | |||
| "v20"); | |||
| #else | |||
| int32_t tmp_sum_value[8] = {0}; | |||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_sum_value[i] += src_ic[0 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[1 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[2 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[3 + i * input_channel]; | |||
| pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; | |||
| pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; | |||
| pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; | |||
| pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; | |||
| } | |||
| tmp_sum_value += src_ic[0]; | |||
| tmp_sum_value += src_ic[1]; | |||
| tmp_sum_value += src_ic[2]; | |||
| tmp_sum_value += src_ic[3]; | |||
| pack_ic[0] = src_ic[0]; | |||
| pack_ic[1] = src_ic[1]; | |||
| pack_ic[2] = src_ic[2]; | |||
| pack_ic[3] = src_ic[3]; | |||
| src_ic += C4NUM; | |||
| pack_ic += C4NUM * C8NUM; | |||
| } | |||
| for (int ici = ic_4div; ici < input_channel; ici += 1) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_sum_value[i] += src_ic[i * input_channel]; | |||
| pack_ic[i * C4NUM] = src_ic[i * input_channel]; | |||
| } | |||
| tmp_sum_value += src_ic[0]; | |||
| pack_ic[0] = src_ic[0]; | |||
| src_ic += 1; | |||
| pack_ic += 1; | |||
| } | |||
| for (int oci = 0; oci < oc_8div; oci += C8NUM) { | |||
| for (int curoi = 0; curoi < C8NUM; curoi++) { | |||
| input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; | |||
| } | |||
| input_sum_oc += inputsum_stride; | |||
| } | |||
| if (oc_8div != output_channel) { | |||
| for (int oci = 0; oci < oc_8res; oci += 1) { | |||
| input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; | |||
| } | |||
| for (int oci = oc_8res; oci < C8NUM; oci += 1) { | |||
| input_sum_oc[oci] = 0; | |||
| } | |||
| } /* oc8 res done */ | |||
| src_r += input_channel; | |||
| pack_r += C4NUM; | |||
| input_sum_r += C8NUM; | |||
| } | |||
| for (int hwi = plane_size; hwi < hw8; hwi++) { | |||
| for (int oc = 0; oc < oc8; oc++) { | |||
| int oc8div = oc / C8NUM, oc8res = oc % C8NUM; | |||
| input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t plane_size, ConvParameter *conv_param) { | |||
| int ic4 = UP_ROUND(input_channel, C4NUM); | |||
| size_t hw_8div = plane_size / C8NUM * C8NUM; | |||
| size_t ic_4div = input_channel / C4NUM * C4NUM; | |||
| int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| const int8_t *src_r = src_input; | |||
| int8_t *pack_r = packed_input; | |||
| /* per layer */ | |||
| for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { | |||
| const int8_t *src_ic = src_r; | |||
| int8_t *pack_ic = pack_r; | |||
| int32_t *input_sum_r = input_sum + hwi; | |||
| #ifdef ENABLE_ARM64 | |||
| size_t src_stride = input_channel; | |||
| size_t ic_4res = input_channel - ic_4div; | |||
| asm volatile( | |||
| "dup v10.4s, wzr \n" | |||
| "dup v11.4s, wzr \n" | |||
| "mov x20, %[input_sum_r] \n" | |||
| "dup v20.4s, %w[filter_zp] \n" | |||
| "mov x10, %[src_ic] \n" | |||
| "mov x11, %[pack_ic] \n" | |||
| "mov x0, #0 \n" | |||
| "1: \n" | |||
| "cmp x0, %[ic_4div] \n" | |||
| "add x0, x0, #4\n" | |||
| "mov x12, x10 \n" | |||
| "add x10, x10, #4\n" | |||
| "blt 2f \n" | |||
| "cmp %[ic_4res], #0\n" | |||
| "beq 6f \n" | |||
| "cmp %[ic_4res], #1\n" | |||
| "beq 3f \n" | |||
| "cmp %[ic_4res], #2\n" | |||
| "beq 4f \n" | |||
| "cmp %[ic_4res], #3\n" | |||
| "beq 5f \n" | |||
| "2: \n" | |||
| "ld1 {v0.s}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[1], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.s}[3], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[1], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.s}[3], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 1b \n" | |||
| "3: \n" /* col res 1 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "ld1 {v0.b}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[8], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[12], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[8], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[12], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "4: \n" /* col res 2 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "5: \n" /* col res 3 */ | |||
| "dup v0.4s, wzr \n" | |||
| "dup v1.4s, wzr \n" | |||
| "add x13, x12, #2 \n" | |||
| "ld1 {v0.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[2], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[6], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[10], [x13], %[src_stride]\n" | |||
| "ld1 {v0.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v0.b}[14], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[0], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[2], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[2], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[6], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[4], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[10], [x13], %[src_stride]\n" | |||
| "ld1 {v1.h}[6], [x12], %[src_stride]\n" | |||
| "ld1 {v1.b}[14], [x13], %[src_stride]\n" | |||
| "st1 {v0.16b}, [x11], #16\n" | |||
| "st1 {v1.16b}, [x11], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v0.4s, v4.8h \n" | |||
| "saddlp v1.4s, v5.8h \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "add v11.4s, v11.4s, v1.4s \n" | |||
| "b 6f \n" | |||
| "6: \n" | |||
| "mul v10.4s, v10.4s, v20.4s \n" | |||
| "mul v11.4s, v11.4s, v20.4s \n" | |||
| "st1 {v10.4s}, [x20], #16 \n" | |||
| "st1 {v11.4s}, [x20], #16 \n" | |||
| : | |||
| : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), | |||
| [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) | |||
| : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", | |||
| "v20"); | |||
| #else | |||
| int32_t tmp_sum_value[8] = {0}; | |||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| input_sum_r[i] = tmp_sum_value[i] * filter_zp; | |||
| tmp_sum_value[i] += src_ic[0 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[1 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[2 + i * input_channel]; | |||
| tmp_sum_value[i] += src_ic[3 + i * input_channel]; | |||
| pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; | |||
| pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; | |||
| pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; | |||
| pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; | |||
| } | |||
| #endif | |||
| src_r += input_channel * C8NUM; | |||
| pack_r += ic4 * C8NUM; | |||
| src_ic += C4NUM; | |||
| pack_ic += C4NUM * C8NUM; | |||
| } | |||
| for (int ici = ic_4div; ici < input_channel; ici += 1) { | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| tmp_sum_value[i] += src_ic[i * input_channel]; | |||
| pack_ic[i * C4NUM] = src_ic[i * input_channel]; | |||
| } | |||
| src_ic += 1; | |||
| pack_ic += 1; | |||
| } | |||
| if (hw_8div != plane_size) { | |||
| memset(pack_r, 0, C8NUM * ic4); | |||
| for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { | |||
| int32_t tmp_sum_value = 0; | |||
| const int8_t *src_ic = src_r; | |||
| int8_t *pack_ic = pack_r; | |||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | |||
| tmp_sum_value += src_ic[0]; | |||
| tmp_sum_value += src_ic[1]; | |||
| tmp_sum_value += src_ic[2]; | |||
| tmp_sum_value += src_ic[3]; | |||
| pack_ic[0] = src_ic[0]; | |||
| pack_ic[1] = src_ic[1]; | |||
| pack_ic[2] = src_ic[2]; | |||
| pack_ic[3] = src_ic[3]; | |||
| src_ic += C4NUM; | |||
| pack_ic += C4NUM * C8NUM; | |||
| } | |||
| for (int ici = ic_4div; ici < input_channel; ici += 1) { | |||
| tmp_sum_value += src_ic[0]; | |||
| pack_ic[0] = src_ic[0]; | |||
| src_ic += 1; | |||
| pack_ic += 1; | |||
| } | |||
| input_sum[hwi] = tmp_sum_value * filter_zp; | |||
| src_r += input_channel; | |||
| pack_r += C4NUM; | |||
| for (int i = 0; i < C8NUM; i++) { | |||
| input_sum_r[i] = tmp_sum_value[i] * filter_zp; | |||
| } | |||
| #endif | |||
| src_r += input_channel * C8NUM; | |||
| pack_r += ic4 * C8NUM; | |||
| } | |||
| if (hw_8div != plane_size) { | |||
| memset(pack_r, 0, C8NUM * ic4); | |||
| for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { | |||
| int32_t tmp_sum_value = 0; | |||
| const int8_t *src_ic = src_r; | |||
| int8_t *pack_ic = pack_r; | |||
| for (int ici = 0; ici < ic_4div; ici += C4NUM) { | |||
| tmp_sum_value += src_ic[0]; | |||
| tmp_sum_value += src_ic[1]; | |||
| tmp_sum_value += src_ic[2]; | |||
| tmp_sum_value += src_ic[3]; | |||
| pack_ic[0] = src_ic[0]; | |||
| pack_ic[1] = src_ic[1]; | |||
| pack_ic[2] = src_ic[2]; | |||
| pack_ic[3] = src_ic[3]; | |||
| src_ic += C4NUM; | |||
| pack_ic += C4NUM * C8NUM; | |||
| } | |||
| for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) { | |||
| input_sum[hwi] = 0; | |||
| for (int ici = ic_4div; ici < input_channel; ici += 1) { | |||
| tmp_sum_value += src_ic[0]; | |||
| pack_ic[0] = src_ic[0]; | |||
| src_ic += 1; | |||
| pack_ic += 1; | |||
| } | |||
| input_sum[hwi] = tmp_sum_value * filter_zp; | |||
| src_r += input_channel; | |||
| pack_r += C4NUM; | |||
| } | |||
| for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { | |||
| input_sum[hwi] = 0; | |||
| } | |||
| } else { | |||
| /* per channel */ | |||
| RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel); | |||
| PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param); | |||
| } | |||
| return; | |||
| } | |||
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, | |||
| MATMUL_OPT_R_FUNC matmul_func) { | |||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { | |||
| matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, | |||
| conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, | |||
| conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false); | |||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||
| return; | |||
| } | |||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) { | |||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param) { | |||
| if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) { | |||
| return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, | |||
| bias, left_shift, right_shift, multiplier, | |||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, | |||
| bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| @@ -622,10 +753,9 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t | |||
| conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_); | |||
| #else | |||
| MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, | |||
| conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, | |||
| conv_param->conv_quant_arg_.quant_multiplier_, | |||
| conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], | |||
| conv_param->conv_quant_arg_.out_act_max_[0], false); | |||
| left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], | |||
| conv_param->conv_quant_arg_.filter_arg_num_ != 1); | |||
| #endif | |||
| return; | |||
| } | |||
| @@ -54,13 +54,16 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight | |||
| ConvParameter *conv_param, GEMM_FUNC gemm_func); | |||
| // int8 convolution 1x1 | |||
| void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t output_channel, size_t plane_size, ConvParameter *conv_param); | |||
| void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride); | |||
| void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, | |||
| size_t plane_size, ConvParameter *conv_param); | |||
| void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param); | |||
| const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param); | |||
| void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, | |||
| const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, | |||
| MATMUL_OPT_R_FUNC matmul_func); | |||
| const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, | |||
| int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func); | |||
| // int8 convolution 3x3 | |||
| void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, | |||
| @@ -186,8 +186,9 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int | |||
| void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, | |||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | |||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | |||
| bool per_channel) { | |||
| /* row4x16-major * row16x4-major => (int8)row-major : per-channel */ | |||
| bool peroc) { | |||
| /* support per-layer && weight per-channel */ | |||
| /* row4x16-major * row16x4-major => (int8)row-major*/ | |||
| for (int r = 0; r < row; r++) { | |||
| for (int c = 0; c < col; c++) { | |||
| int r4div = r / C4NUM, r4mod = r % C4NUM; | |||
| @@ -200,12 +201,13 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row | |||
| size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r]; | |||
| int32_t cur_input_sum = | |||
| peroc ? input_sum[c4div * UP_ROUND(row, C4NUM) * C4NUM + r * C4NUM + c4mod] : input_sum[r]; | |||
| value -= cur_input_sum; | |||
| value += bias[c]; | |||
| int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; | |||
| int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; | |||
| int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; | |||
| int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0]; | |||
| int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0]; | |||
| int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0]; | |||
| value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; | |||
| value = MSMIN(maxi, value); | |||
| value = MSMAX(mini, value); | |||
| @@ -232,7 +234,8 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, | |||
| size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; | |||
| value = value + a[ai] * b[bi]; | |||
| } | |||
| int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r]; | |||
| int32_t cur_input_sum = | |||
| per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r]; | |||
| value -= cur_input_sum; | |||
| value += bias[c]; | |||
| int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ | |||
| #define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ | |||
| #include <stdio.h> | |||
| #include <string.h> | |||
| #include "nnacl/op_base.h" | |||
| #include "nnacl/matmul_parameter.h" | |||
| @@ -195,7 +195,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam | |||
| } | |||
| void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { | |||
| /* optimize normal -> same layout */ | |||
| /* normal matmul : 4x16 * 16x4 -> 4x4 */ | |||
| #ifdef ENABLE_ARM64 | |||
| asm volatile( | |||
| "mov x10, %[src] \n" | |||
| @@ -267,62 +267,158 @@ void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp | |||
| return; | |||
| } | |||
| void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, | |||
| size_t plane_size, ConvParameter *conv_param) { | |||
| void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, | |||
| size_t plane_size, size_t input_channel, size_t output_channel) { | |||
| size_t hw4 = UP_ROUND(plane_size, C4NUM); | |||
| size_t ic16 = UP_ROUND(input_channel, C16NUM); | |||
| if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { | |||
| PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); | |||
| } else { | |||
| for (int ri = 0; ri < plane_size; ri++) { | |||
| int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; | |||
| for (int ci = 0; ci < output_channel; ci++) { | |||
| int32_t tmp_sum_value = 0; | |||
| int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; | |||
| int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_; | |||
| for (int di = 0; di < input_channel; di++) { | |||
| size_t di16div = di / C16NUM, di16mod = di % C16NUM; | |||
| int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; | |||
| tmp_sum_value += input_value[src_index]; | |||
| } | |||
| int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; | |||
| input_sum[dst_index] = tmp_sum_value * filter_zp; | |||
| #ifdef ENABLE_ARM64 | |||
| size_t oc_div4 = output_channel / C4NUM * C4NUM; | |||
| size_t oc_res4 = output_channel - oc_div4; | |||
| size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; | |||
| asm volatile( | |||
| "mov x10, %[input_value] \n" | |||
| "mov x11, %[input_sum] \n" | |||
| "mov x15, %[filter_zp_ptr] \n" | |||
| "mov x0, #0 \n" // row 4 count | |||
| "1: \n" | |||
| "cmp x0, %[hw4] \n" | |||
| "beq 11f \n" | |||
| "add x0, x0, #4\n" | |||
| "dup v10.4s, wzr \n" | |||
| "mov x2, #0 \n" // input deep count | |||
| "mov x16, x15 \n" | |||
| "2: \n" | |||
| "cmp x2, %[ic16] \n" | |||
| "beq 3f \n" | |||
| "add x2, x2, #16 \n" | |||
| "ld1 {v0.16b}, [x10], #16\n" | |||
| "ld1 {v1.16b}, [x10], #16\n" | |||
| "ld1 {v2.16b}, [x10], #16\n" | |||
| "ld1 {v3.16b}, [x10], #16\n" | |||
| "saddlp v4.8h, v0.16b \n" | |||
| "saddlp v5.8h, v1.16b \n" | |||
| "saddlp v6.8h, v2.16b \n" | |||
| "saddlp v7.8h, v3.16b \n" | |||
| "saddlp v0.4S, v4.8h \n" | |||
| "saddlp v1.4S, v5.8h \n" | |||
| "saddlp v2.4S, v6.8h \n" | |||
| "saddlp v3.4S, v7.8h \n" | |||
| "addv s4, v0.4S \n" | |||
| "addv s5, v1.4S \n" | |||
| "addv s6, v2.4S \n" | |||
| "addv s7, v3.4S \n" | |||
| "mov v0.s[0], v4.s[0] \n" | |||
| "mov v0.s[1], v5.s[0] \n" | |||
| "mov v0.s[2], v6.s[0] \n" | |||
| "mov v0.s[3], v7.s[0] \n" | |||
| "add v10.4s, v10.4s, v0.4s \n" | |||
| "b 2b \n" | |||
| "3: \n" | |||
| "mov x12, x11 \n" // tmp inputsm inputsum hw | |||
| "add x11, x11, #64 \n" | |||
| "mov x4, #0 \n" // oc count | |||
| "dup v1.4s, v10.s[0] \n" | |||
| "dup v2.4s, v10.s[1] \n" | |||
| "dup v3.4s, v10.s[2] \n" | |||
| "dup v4.4s, v10.s[3] \n" | |||
| "4: \n" | |||
| "cmp x4, %[oc_div4] \n" | |||
| "beq 6f \n" | |||
| "add x4, x4, #4\n" | |||
| "ld1 {v15.4s}, [x16], #16\n" | |||
| "mul v16.4s, v15.4s, v1.4s \n" | |||
| "mul v17.4s, v15.4s, v2.4s \n" | |||
| "mul v18.4s, v15.4s, v3.4s \n" | |||
| "mul v19.4s, v15.4s, v4.4s \n" | |||
| "st1 {v16.4s}, [x12], #16 \n" | |||
| "st1 {v17.4s}, [x12], #16 \n" | |||
| "st1 {v18.4s}, [x12], #16 \n" | |||
| "st1 {v19.4s}, [x12], #16 \n" | |||
| "add x12, x12, %[inputsun_stride] \n" | |||
| "b 4b \n" | |||
| "6: \n" | |||
| "cmp %[oc_res4], #0\n" | |||
| "beq 1b \n" | |||
| "dup v15.4s, wzr \n" | |||
| "cmp %[oc_res4], #1\n" | |||
| "beq 7f \n" | |||
| "cmp %[oc_res4], #2\n" | |||
| "beq 8f \n" | |||
| "cmp %[oc_res4], #3\n" | |||
| "beq 9f \n" | |||
| "7: \n" | |||
| "ld1 {v15.s}[0], [x16] \n" | |||
| "b 10f \n" | |||
| "8: \n" | |||
| "ld1 {v15.h}[0], [x16] \n" | |||
| "b 10f \n" | |||
| "9: \n" | |||
| "ld1 {v15.h}[0], [x16] \n" | |||
| "add x16, x16, #8 \n" | |||
| "ld1 {v15.s}[2], [x16] \n" | |||
| "b 10f \n" | |||
| "10: \n" | |||
| "mul v16.4s, v15.4s, v1.4s \n" | |||
| "mul v17.4s, v15.4s, v2.4s \n" | |||
| "mul v18.4s, v15.4s, v3.4s \n" | |||
| "mul v19.4s, v15.4s, v4.4s \n" | |||
| "st1 {v16.4s}, [x12], #16 \n" | |||
| "st1 {v17.4s}, [x12], #16 \n" | |||
| "st1 {v18.4s}, [x12], #16 \n" | |||
| "st1 {v19.4s}, [x12], #16 \n" | |||
| "b 1b \n" | |||
| "11: \n" | |||
| : | |||
| : [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr), | |||
| [ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4), | |||
| [ inputsun_stride ] "r"(inputsun_stride) | |||
| : "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15", | |||
| "v16", "v17", "v18", "v19"); | |||
| #else | |||
| for (int ri = 0; ri < plane_size; ri++) { | |||
| int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; | |||
| for (int ci = 0; ci < output_channel; ci++) { | |||
| int32_t tmp_sum_value = 0; | |||
| int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; | |||
| int32_t filter_zp = filter_zp_ptr[ci]; | |||
| for (int di = 0; di < input_channel; di++) { | |||
| size_t di16div = di / C16NUM, di16mod = di % C16NUM; | |||
| int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; | |||
| tmp_sum_value += input_value[src_index]; | |||
| } | |||
| int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; | |||
| input_sum[dst_index] = tmp_sum_value * filter_zp; | |||
| } | |||
| } | |||
| #endif | |||
| return; | |||
| } | |||
| void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, | |||
| size_t plane_size, ConvParameter *conv_param) { | |||
| size_t hw8 = UP_ROUND(plane_size, C8NUM); | |||
| size_t ic4 = UP_ROUND(input_channel, C4NUM); | |||
| void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) { | |||
| size_t hw4 = UP_ROUND(conv_param->input_h_ * conv_param->input_w_, C4NUM); | |||
| size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); | |||
| if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { | |||
| for (int r = 0; r < hw8; r++) { | |||
| int32_t tmp_value = 0; | |||
| for (int c = 0; c < ic4; c++) { | |||
| int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM; | |||
| int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod; | |||
| tmp_value += input_value[src_index]; | |||
| } | |||
| input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| } | |||
| PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); | |||
| } else { | |||
| for (int ri = 0; ri < plane_size; ri++) { | |||
| int ri8div = ri / C8NUM, ri8mod = ri % C8NUM; | |||
| for (int ci = 0; ci < output_channel; ci++) { | |||
| int32_t tmp_sum_value = 0; | |||
| int ci8div = ci / C8NUM, ci8mod = ci % C8NUM; | |||
| int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_; | |||
| for (int di = 0; di < input_channel; di++) { | |||
| size_t di4div = di / C4NUM, di4mod = di % C4NUM; | |||
| int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod; | |||
| tmp_sum_value += input_value[src_index]; | |||
| } | |||
| int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod; | |||
| input_sum[dst_index] = tmp_sum_value * filter_zp; | |||
| } | |||
| } | |||
| PackInputSum16x4PerChannel(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_, | |||
| conv_param->input_channel_, conv_param->output_channel_); | |||
| } | |||
| return; | |||
| } | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_LITE_NNACL_PACK_H_ | |||
| #define MINDSPORE_LITE_NNACL_PACK_H_ | |||
| #include <stdio.h> | |||
| #ifdef ENABLE_NEON | |||
| #include <arm_neon.h> | |||
| #endif | |||
| @@ -41,8 +42,7 @@ void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_pa | |||
| void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); | |||
| void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, | |||
| size_t plane_size, ConvParameter *conv_param); | |||
| void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); | |||
| void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, | |||
| size_t plane_size, ConvParameter *conv_param); | |||
| @@ -316,14 +316,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { | |||
| MS_LOG(ERROR) << "Set Quant Multiplier Failed."; | |||
| return ret; | |||
| } | |||
| // now only consider per tensor for output | |||
| bool relu = conv_param_->act_type_ == ActType_Relu; | |||
| bool relu6 = conv_param_->act_type_ == ActType_Relu6; | |||
| CalculateActivationRangeQuantized(relu, relu6, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_, | |||
| conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, | |||
| &conv_param_->conv_quant_arg_.out_act_min_[0], | |||
| &conv_param_->conv_quant_arg_.out_act_max_[0]); | |||
| return RET_OK; | |||
| } | |||
| int ConvolutionBaseCPUKernel::RestoreFilter(lite::Tensor *input_tensor) { | |||
| @@ -16,6 +16,7 @@ | |||
| #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "src/common/file_utils.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_MEMORY_FAILED; | |||
| @@ -41,6 +42,10 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { | |||
| free(packed_weight_); | |||
| packed_weight_ = nullptr; | |||
| } | |||
| if (filter_peroc_ && filter_zp_ptr_ != nullptr) { | |||
| free(filter_zp_ptr_); | |||
| filter_zp_ptr_ = nullptr; | |||
| } | |||
| FreeResizeBuf(); | |||
| FreeQuantParam(); | |||
| } | |||
| @@ -54,7 +59,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { | |||
| } | |||
| void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { | |||
| support_optimize_ = true; | |||
| support_optimize_ = false; | |||
| matmul_func_ = MatMulInt8_8x8_r; | |||
| #ifdef ENABLE_ARM64 | |||
| void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; | |||
| @@ -73,6 +78,10 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { | |||
| support_optimize_ = false; | |||
| matmul_func_ = nullptr; | |||
| } | |||
| if (filter_peroc_) { | |||
| support_optimize_ = false; | |||
| } | |||
| #endif | |||
| return; | |||
| } | |||
| @@ -118,14 +127,23 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { | |||
| int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_; | |||
| for (int oc = 0; oc < output_channel; oc++) { | |||
| int32_t weight_sum_value = 0; | |||
| int32_t filter_zp = (conv_param_->conv_quant_arg_.filter_arg_num_ == 1) | |||
| ? conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_ | |||
| : conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_; | |||
| int32_t filter_zp = (filter_peroc_) ? conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_ | |||
| : conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| for (int ic = 0; ic < input_channel; ic++) { | |||
| weight_sum_value += weight[oc * input_channel + ic]; | |||
| } | |||
| bias_data[oc] += filter_zp * input_zp * input_channel - weight_sum_value * input_zp; | |||
| } | |||
| if (filter_peroc_) { | |||
| filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t))); | |||
| if (filter_zp_ptr_ == nullptr) { | |||
| return RET_ERROR; | |||
| } | |||
| for (int fi = 0; fi < output_channel; fi++) { | |||
| filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -136,14 +154,16 @@ int Convolution1x1Int8CPUKernel::Init() { | |||
| return RET_ERROR; | |||
| } | |||
| CheckSupportOptimize(); | |||
| auto ret = SetQuantParam(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Set quant param failed."; | |||
| return ret; | |||
| } | |||
| filter_peroc_ = (conv_param_->conv_quant_arg_.filter_arg_num_ != 1); | |||
| CheckSupportOptimize(); | |||
| ret = InitWeightBias(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init weight bias failed."; | |||
| @@ -229,14 +249,17 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out | |||
| ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_); | |||
| } else { | |||
| RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_); | |||
| PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_, | |||
| conv_param_); | |||
| PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_); | |||
| } | |||
| return; | |||
| } | |||
| int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||
| int32_t *cur_input_sum = input_sum_; | |||
| int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_; | |||
| int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_; | |||
| int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_; | |||
| if (support_optimize_) { | |||
| int cur_stride = thread_stride_ * C8NUM; | |||
| int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM; | |||
| @@ -244,10 +267,17 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| if (filter_peroc_) { | |||
| cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; | |||
| cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM; | |||
| cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM; | |||
| cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM; | |||
| } | |||
| Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, | |||
| output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_, | |||
| output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, | |||
| reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_, | |||
| cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_); | |||
| cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, | |||
| matmul_func_); | |||
| } else { | |||
| int cur_stride = thread_stride_ * C4NUM; | |||
| int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM; | |||
| @@ -255,10 +285,16 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { | |||
| if (cur_oc <= 0) { | |||
| return RET_OK; | |||
| } | |||
| if (filter_peroc_) { | |||
| cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; | |||
| cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM; | |||
| cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM; | |||
| cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM; | |||
| } | |||
| Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, | |||
| output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_, | |||
| output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, | |||
| reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc, | |||
| matmul_param_->deep_16_, conv_param_); | |||
| matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -270,10 +306,18 @@ int Convolution1x1Int8CPUKernel::RunPre(int task_id) { | |||
| if (cur_hw <= 0) { | |||
| return RET_OK; | |||
| } | |||
| Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, | |||
| packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, | |||
| input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw, | |||
| conv_param_); | |||
| if (filter_peroc_) { | |||
| Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, | |||
| packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, | |||
| input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_, | |||
| matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM); | |||
| } else { | |||
| Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, | |||
| packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, | |||
| input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -56,7 +56,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { | |||
| void CheckSupportOptimize(); | |||
| private: | |||
| int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ | |||
| int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ | |||
| int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */ | |||
| int8_t *packed_weight_ = nullptr; | |||
| int8_t *packed_input_ = nullptr; | |||
| int8_t *input_ptr_ = nullptr; | |||
| @@ -70,6 +71,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { | |||
| MatMulParameter *matmul_param_ = nullptr; | |||
| MATMUL_OPT_R_FUNC matmul_func_ = nullptr; | |||
| bool support_optimize_ = false; | |||
| bool filter_peroc_ = false; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -397,10 +397,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> & | |||
| int dilation_h = conv_param->dilation_h_; | |||
| int dilation_w = conv_param->dilation_w_; | |||
| kernel::LiteKernel *kernel; | |||
| auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | |||
| if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { | |||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) { | |||
| } else if (kernel_h == 1 && kernel_w == 1) { | |||
| kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else { | |||
| kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||