diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index 756bd13850..bc6f3c3143 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -371,249 +371,380 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight } } -void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t output_channel, size_t plane_size, ConvParameter *conv_param) { +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) { int ic4 = UP_ROUND(input_channel, C4NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); + int hw8 = UP_ROUND(plane_size, C8NUM); size_t hw_8div = plane_size / C8NUM * C8NUM; - size_t hw_8res = plane_size - hw_8div; + size_t oc_8div = output_channel / C8NUM * C8NUM; + size_t oc_8res = output_channel - oc_8div; size_t ic_4div = input_channel / C4NUM * C4NUM; - int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; - if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - const int8_t *src_r = src_input; - int8_t *pack_r = packed_input; - /* per layer */ - for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + int32_t *input_sum_r = input_sum; + + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_oc = input_sum_r; + + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; + input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; + input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; + input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; + input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; + input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; + input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; + input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; + } + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = 0; + } + } + } /* oc8 res done */ + + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + input_sum_r += C8NUM * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t *input_sum_oc = input_sum_r; + int32_t tmp_sum_value = 0; const int8_t *src_ic = src_r; int8_t *pack_ic = pack_r; - int32_t *input_sum_r = input_sum + hwi; -#ifdef ENABLE_ARM64 - size_t src_stride = input_channel; - size_t ic_4res = input_channel - ic_4div; - asm volatile( - "dup v10.4s, wzr \n" - "dup v11.4s, wzr \n" - "mov x20, %[input_sum_r] \n" - "dup v20.4s, %w[filter_zp] \n" - - "mov x10, %[src_ic] \n" - "mov x11, %[pack_ic] \n" - - "mov x0, #0 \n" - "1: \n" - "cmp x0, %[ic_4div] \n" - "add x0, x0, #4\n" - "mov x12, x10 \n" - "add x10, x10, #4\n" - "blt 2f \n" - "cmp %[ic_4res], #0\n" - "beq 6f \n" - "cmp %[ic_4res], #1\n" - "beq 3f \n" - "cmp %[ic_4res], #2\n" - "beq 4f \n" - "cmp %[ic_4res], #3\n" - "beq 5f \n" - - "2: \n" - "ld1 {v0.s}[0], [x12], %[src_stride]\n" - "ld1 {v0.s}[1], [x12], %[src_stride]\n" - "ld1 {v0.s}[2], [x12], %[src_stride]\n" - "ld1 {v0.s}[3], [x12], %[src_stride]\n" - "ld1 {v1.s}[0], [x12], %[src_stride]\n" - "ld1 {v1.s}[1], [x12], %[src_stride]\n" - "ld1 {v1.s}[2], [x12], %[src_stride]\n" - "ld1 {v1.s}[3], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" - "b 1b \n" - - "3: \n" /* col res 1 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.b}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[8], [x12], %[src_stride]\n" - "ld1 {v0.b}[12], [x12], %[src_stride]\n" - "ld1 {v1.b}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[8], [x12], %[src_stride]\n" - "ld1 {v1.b}[12], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" - "b 6f \n" - - "4: \n" /* col res 2 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.h}[6], [x12], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" - "b 6f \n" - - "5: \n" /* col res 3 */ - "dup v0.4s, wzr \n" - "dup v1.4s, wzr \n" - "add x13, x12, #2 \n" - - "ld1 {v0.h}[0], [x12], %[src_stride]\n" - "ld1 {v0.b}[2], [x13], %[src_stride]\n" - "ld1 {v0.h}[2], [x12], %[src_stride]\n" - "ld1 {v0.b}[6], [x13], %[src_stride]\n" - "ld1 {v0.h}[4], [x12], %[src_stride]\n" - "ld1 {v0.b}[10], [x13], %[src_stride]\n" - "ld1 {v0.h}[6], [x12], %[src_stride]\n" - "ld1 {v0.b}[14], [x13], %[src_stride]\n" - "ld1 {v1.h}[0], [x12], %[src_stride]\n" - "ld1 {v1.b}[2], [x13], %[src_stride]\n" - "ld1 {v1.h}[2], [x12], %[src_stride]\n" - "ld1 {v1.b}[6], [x13], %[src_stride]\n" - "ld1 {v1.h}[4], [x12], %[src_stride]\n" - "ld1 {v1.b}[10], [x13], %[src_stride]\n" - "ld1 {v1.h}[6], [x12], %[src_stride]\n" - "ld1 {v1.b}[14], [x13], %[src_stride]\n" - - "st1 {v0.16b}, [x11], #16\n" - "st1 {v1.16b}, [x11], #16\n" - "saddlp v4.8h, v0.16b \n" - "saddlp v5.8h, v1.16b \n" - "saddlp v0.4s, v4.8h \n" - "saddlp v1.4s, v5.8h \n" - "add v10.4s, v10.4s, v0.4s \n" - "add v11.4s, v11.4s, v1.4s \n" - "b 6f \n" - - "6: \n" - "mul v10.4s, v10.4s, v20.4s \n" - "mul v11.4s, v11.4s, v20.4s \n" - - "st1 {v10.4s}, [x20], #16 \n" - "st1 {v11.4s}, [x20], #16 \n" - - : - : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), - [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), - [ filter_zp ] "r"(filter_zp) - : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", - "v20"); -#else - int32_t tmp_sum_value[8] = {0}; for (int ici = 0; ici < ic_4div; ici += C4NUM) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[0 + i * input_channel]; - tmp_sum_value[i] += src_ic[1 + i * input_channel]; - tmp_sum_value[i] += src_ic[2 + i * input_channel]; - tmp_sum_value[i] += src_ic[3 + i * input_channel]; - pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; - pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; - pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; - pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; - } + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; src_ic += C4NUM; pack_ic += C4NUM * C8NUM; } for (int ici = ic_4div; ici < input_channel; ici += 1) { - for (int i = 0; i < C8NUM; i++) { - tmp_sum_value[i] += src_ic[i * input_channel]; - pack_ic[i * C4NUM] = src_ic[i * input_channel]; - } + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; src_ic += 1; pack_ic += 1; } + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int curoi = 0; curoi < C8NUM; curoi++) { + input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + input_sum_oc[oci] = 0; + } + } /* oc8 res done */ + + src_r += input_channel; + pack_r += C4NUM; + input_sum_r += C8NUM; + } + + for (int hwi = plane_size; hwi < hw8; hwi++) { + for (int oc = 0; oc < oc8; oc++) { + int oc8div = oc / C8NUM, oc8res = oc % C8NUM; + input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; + } + } + } + return; +} + +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v10.4s, wzr \n" + "dup v11.4s, wzr \n" + "mov x20, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v10.4s, v10.4s, v20.4s \n" + "mul v11.4s, v11.4s, v20.4s \n" + + "st1 {v10.4s}, [x20], #16 \n" + "st1 {v11.4s}, [x20], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { for (int i = 0; i < C8NUM; i++) { - input_sum_r[i] = tmp_sum_value[i] * filter_zp; + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; } -#endif - src_r += input_channel * C8NUM; - pack_r += ic4 * C8NUM; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; } - if (hw_8div != plane_size) { - memset(pack_r, 0, C8NUM * ic4); - for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { - int32_t tmp_sum_value = 0; - const int8_t *src_ic = src_r; - int8_t *pack_ic = pack_r; - for (int ici = 0; ici < ic_4div; ici += C4NUM) { - tmp_sum_value += src_ic[0]; - tmp_sum_value += src_ic[1]; - tmp_sum_value += src_ic[2]; - tmp_sum_value += src_ic[3]; - pack_ic[0] = src_ic[0]; - pack_ic[1] = src_ic[1]; - pack_ic[2] = src_ic[2]; - pack_ic[3] = src_ic[3]; - src_ic += C4NUM; - pack_ic += C4NUM * C8NUM; - } - for (int ici = ic_4div; ici < input_channel; ici += 1) { - tmp_sum_value += src_ic[0]; - pack_ic[0] = src_ic[0]; - src_ic += 1; - pack_ic += 1; - } - input_sum[hwi] = tmp_sum_value * filter_zp; - src_r += input_channel; - pack_r += C4NUM; + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; } - for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) { - input_sum[hwi] = 0; + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { + input_sum[hwi] = 0; } - } else { - /* per channel */ - RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel); - PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param); } return; } void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, - MATMUL_OPT_R_FUNC matmul_func) { + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, - conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, - conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false); + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.filter_arg_num_ != 1); return; } void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) { + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param) { + if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) { + return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, + bias, left_shift, right_shift, multiplier, + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.filter_arg_num_ != 1); + } #ifdef ENABLE_ARM64 MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], @@ -622,10 +753,9 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_); #else MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, - conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, - conv_param->conv_quant_arg_.quant_multiplier_, - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], - conv_param->conv_quant_arg_.out_act_max_[0], false); + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.filter_arg_num_ != 1); #endif return; } diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index 5741ee3117..60d84e27f7 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -54,13 +54,16 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight ConvParameter *conv_param, GEMM_FUNC gemm_func); // int8 convolution 1x1 -void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, - size_t output_channel, size_t plane_size, ConvParameter *conv_param); +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride); +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, ConvParameter *conv_param); void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param); + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param); void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, - MATMUL_OPT_R_FUNC matmul_func); + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func); // int8 convolution 3x3 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 26aa3269df..1e1241712c 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -186,8 +186,9 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, - bool per_channel) { - /* row4x16-major * row16x4-major => (int8)row-major : per-channel */ + bool peroc) { + /* support per-layer && weight per-channel */ + /* row4x16-major * row16x4-major => (int8)row-major*/ for (int r = 0; r < row; r++) { for (int c = 0; c < col; c++) { int r4div = r / C4NUM, r4mod = r % C4NUM; @@ -200,12 +201,13 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; value = value + a[ai] * b[bi]; } - int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r]; + int32_t cur_input_sum = + peroc ? input_sum[c4div * UP_ROUND(row, C4NUM) * C4NUM + r * C4NUM + c4mod] : input_sum[r]; value -= cur_input_sum; value += bias[c]; - int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; - int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; - int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0]; value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; value = MSMIN(maxi, value); value = MSMAX(mini, value); @@ -232,7 +234,8 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; value = value + a[ai] * b[bi]; } - int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r]; + int32_t cur_input_sum = + per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r]; value -= cur_input_sum; value += bias[c]; int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 03028c49ec..fe20548b8d 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ #define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_ +#include #include #include "nnacl/op_base.h" #include "nnacl/matmul_parameter.h" diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 89dbc78788..5ceb003d9b 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -195,7 +195,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam } void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { - /* optimize normal -> same layout */ + /* normal matmul : 4x16 * 16x4 -> 4x4 */ #ifdef ENABLE_ARM64 asm volatile( "mov x10, %[src] \n" @@ -267,62 +267,158 @@ void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp return; } -void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, - size_t plane_size, ConvParameter *conv_param) { +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { size_t hw4 = UP_ROUND(plane_size, C4NUM); size_t ic16 = UP_ROUND(input_channel, C16NUM); - if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); - } else { - for (int ri = 0; ri < plane_size; ri++) { - int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; - for (int ci = 0; ci < output_channel; ci++) { - int32_t tmp_sum_value = 0; - int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; - int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_; - for (int di = 0; di < input_channel; di++) { - size_t di16div = di / C16NUM, di16mod = di % C16NUM; - int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; - tmp_sum_value += input_value[src_index]; - } - int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; - input_sum[dst_index] = tmp_sum_value * filter_zp; +#ifdef ENABLE_ARM64 + size_t oc_div4 = output_channel / C4NUM * C4NUM; + size_t oc_res4 = output_channel - oc_div4; + size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; + asm volatile( + "mov x10, %[input_value] \n" + "mov x11, %[input_sum] \n" + "mov x15, %[filter_zp_ptr] \n" + + "mov x0, #0 \n" // row 4 count + "1: \n" + "cmp x0, %[hw4] \n" + "beq 11f \n" + "add x0, x0, #4\n" + "dup v10.4s, wzr \n" + "mov x2, #0 \n" // input deep count + "mov x16, x15 \n" + + "2: \n" + "cmp x2, %[ic16] \n" + "beq 3f \n" + "add x2, x2, #16 \n" + + "ld1 {v0.16b}, [x10], #16\n" + "ld1 {v1.16b}, [x10], #16\n" + "ld1 {v2.16b}, [x10], #16\n" + "ld1 {v3.16b}, [x10], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v6.8h, v2.16b \n" + "saddlp v7.8h, v3.16b \n" + "saddlp v0.4S, v4.8h \n" + "saddlp v1.4S, v5.8h \n" + "saddlp v2.4S, v6.8h \n" + "saddlp v3.4S, v7.8h \n" + "addv s4, v0.4S \n" + "addv s5, v1.4S \n" + "addv s6, v2.4S \n" + "addv s7, v3.4S \n" + "mov v0.s[0], v4.s[0] \n" + "mov v0.s[1], v5.s[0] \n" + "mov v0.s[2], v6.s[0] \n" + "mov v0.s[3], v7.s[0] \n" + + "add v10.4s, v10.4s, v0.4s \n" + "b 2b \n" + + "3: \n" + "mov x12, x11 \n" // tmp inputsm inputsum hw + "add x11, x11, #64 \n" + "mov x4, #0 \n" // oc count + + "dup v1.4s, v10.s[0] \n" + "dup v2.4s, v10.s[1] \n" + "dup v3.4s, v10.s[2] \n" + "dup v4.4s, v10.s[3] \n" + + "4: \n" + "cmp x4, %[oc_div4] \n" + "beq 6f \n" + "add x4, x4, #4\n" + "ld1 {v15.4s}, [x16], #16\n" + + "mul v16.4s, v15.4s, v1.4s \n" + "mul v17.4s, v15.4s, v2.4s \n" + "mul v18.4s, v15.4s, v3.4s \n" + "mul v19.4s, v15.4s, v4.4s \n" + "st1 {v16.4s}, [x12], #16 \n" + "st1 {v17.4s}, [x12], #16 \n" + "st1 {v18.4s}, [x12], #16 \n" + "st1 {v19.4s}, [x12], #16 \n" + "add x12, x12, %[inputsun_stride] \n" + "b 4b \n" + + "6: \n" + "cmp %[oc_res4], #0\n" + "beq 1b \n" + "dup v15.4s, wzr \n" + "cmp %[oc_res4], #1\n" + "beq 7f \n" + "cmp %[oc_res4], #2\n" + "beq 8f \n" + "cmp %[oc_res4], #3\n" + "beq 9f \n" + + "7: \n" + "ld1 {v15.s}[0], [x16] \n" + "b 10f \n" + + "8: \n" + "ld1 {v15.h}[0], [x16] \n" + "b 10f \n" + + "9: \n" + "ld1 {v15.h}[0], [x16] \n" + "add x16, x16, #8 \n" + "ld1 {v15.s}[2], [x16] \n" + "b 10f \n" + + "10: \n" + "mul v16.4s, v15.4s, v1.4s \n" + "mul v17.4s, v15.4s, v2.4s \n" + "mul v18.4s, v15.4s, v3.4s \n" + "mul v19.4s, v15.4s, v4.4s \n" + "st1 {v16.4s}, [x12], #16 \n" + "st1 {v17.4s}, [x12], #16 \n" + "st1 {v18.4s}, [x12], #16 \n" + "st1 {v19.4s}, [x12], #16 \n" + "b 1b \n" + + "11: \n" + + : + : [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr), + [ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4), + [ inputsun_stride ] "r"(inputsun_stride) + : "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15", + "v16", "v17", "v18", "v19"); +#else + + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; } + int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; } } +#endif return; } -void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, - size_t plane_size, ConvParameter *conv_param) { - size_t hw8 = UP_ROUND(plane_size, C8NUM); - size_t ic4 = UP_ROUND(input_channel, C4NUM); +void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) { + size_t hw4 = UP_ROUND(conv_param->input_h_ * conv_param->input_w_, C4NUM); + size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - for (int r = 0; r < hw8; r++) { - int32_t tmp_value = 0; - for (int c = 0; c < ic4; c++) { - int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM; - int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod; - tmp_value += input_value[src_index]; - } - input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; - } + PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); } else { - for (int ri = 0; ri < plane_size; ri++) { - int ri8div = ri / C8NUM, ri8mod = ri % C8NUM; - for (int ci = 0; ci < output_channel; ci++) { - int32_t tmp_sum_value = 0; - int ci8div = ci / C8NUM, ci8mod = ci % C8NUM; - int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_; - for (int di = 0; di < input_channel; di++) { - size_t di4div = di / C4NUM, di4mod = di % C4NUM; - int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod; - tmp_sum_value += input_value[src_index]; - } - int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod; - input_sum[dst_index] = tmp_sum_value * filter_zp; - } - } + PackInputSum16x4PerChannel(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_, + conv_param->input_channel_, conv_param->output_channel_); } return; } diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index b05083c52d..903057d710 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_NNACL_PACK_H_ #define MINDSPORE_LITE_NNACL_PACK_H_ +#include #ifdef ENABLE_NEON #include #endif @@ -41,8 +42,7 @@ void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_pa void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); -void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, - size_t plane_size, ConvParameter *conv_param); +void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, size_t plane_size, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 1d305d2ac1..3c905ba8f8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -316,14 +316,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { MS_LOG(ERROR) << "Set Quant Multiplier Failed."; return ret; } - // now only consider per tensor for output bool relu = conv_param_->act_type_ == ActType_Relu; bool relu6 = conv_param_->act_type_ == ActType_Relu6; CalculateActivationRangeQuantized(relu, relu6, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_, conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0], &conv_param_->conv_quant_arg_.out_act_max_[0]); - return RET_OK; } int ConvolutionBaseCPUKernel::RestoreFilter(lite::Tensor *input_tensor) { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc index ed8cabb85f..d4ac9fcdbc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc @@ -16,6 +16,7 @@ #include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h" #include "src/runtime/runtime_api.h" +#include "src/common/file_utils.h" using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; @@ -41,6 +42,10 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { free(packed_weight_); packed_weight_ = nullptr; } + if (filter_peroc_ && filter_zp_ptr_ != nullptr) { + free(filter_zp_ptr_); + filter_zp_ptr_ = nullptr; + } FreeResizeBuf(); FreeQuantParam(); } @@ -54,7 +59,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { } void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { - support_optimize_ = true; + support_optimize_ = false; matmul_func_ = MatMulInt8_8x8_r; #ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; @@ -73,6 +78,10 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { support_optimize_ = false; matmul_func_ = nullptr; } + + if (filter_peroc_) { + support_optimize_ = false; + } #endif return; } @@ -118,14 +127,23 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_; for (int oc = 0; oc < output_channel; oc++) { int32_t weight_sum_value = 0; - int32_t filter_zp = (conv_param_->conv_quant_arg_.filter_arg_num_ == 1) - ? conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_ - : conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_; + int32_t filter_zp = (filter_peroc_) ? conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_ + : conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; for (int ic = 0; ic < input_channel; ic++) { weight_sum_value += weight[oc * input_channel + ic]; } bias_data[oc] += filter_zp * input_zp * input_channel - weight_sum_value * input_zp; } + + if (filter_peroc_) { + filter_zp_ptr_ = reinterpret_cast(malloc(output_channel * sizeof(int32_t))); + if (filter_zp_ptr_ == nullptr) { + return RET_ERROR; + } + for (int fi = 0; fi < output_channel; fi++) { + filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; + } + } return RET_OK; } @@ -136,14 +154,16 @@ int Convolution1x1Int8CPUKernel::Init() { return RET_ERROR; } - CheckSupportOptimize(); - auto ret = SetQuantParam(); if (ret != RET_OK) { MS_LOG(ERROR) << "Set quant param failed."; return ret; } + filter_peroc_ = (conv_param_->conv_quant_arg_.filter_arg_num_ != 1); + + CheckSupportOptimize(); + ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init weight bias failed."; @@ -229,14 +249,17 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_); } else { RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_); - PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_, - conv_param_); + PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_); } - return; } int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { + int32_t *cur_input_sum = input_sum_; + int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_; + int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_; + int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_; + if (support_optimize_) { int cur_stride = thread_stride_ * C8NUM; int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM; @@ -244,10 +267,17 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { if (cur_oc <= 0) { return RET_OK; } + if (filter_peroc_) { + cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; + cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM; + cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM; + cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM; + } Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, - output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_, + output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, reinterpret_cast(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_, - cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_); + cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, + matmul_func_); } else { int cur_stride = thread_stride_ * C4NUM; int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM; @@ -255,10 +285,16 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { if (cur_oc <= 0) { return RET_OK; } + if (filter_peroc_) { + cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; + cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM; + cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM; + cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM; + } Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, - output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_, + output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, reinterpret_cast(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc, - matmul_param_->deep_16_, conv_param_); + matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_); } return RET_OK; } @@ -270,10 +306,18 @@ int Convolution1x1Int8CPUKernel::RunPre(int task_id) { if (cur_hw <= 0) { return RET_OK; } - Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, - packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, - input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw, - conv_param_); + + if (filter_peroc_) { + Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, + packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, + input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_, + matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM); + } else { + Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, + packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, + input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_); + } + return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index 634aa29ff3..342aa3eff7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -56,7 +56,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { void CheckSupportOptimize(); private: - int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ + int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */ + int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */ int8_t *packed_weight_ = nullptr; int8_t *packed_input_ = nullptr; int8_t *input_ptr_ = nullptr; @@ -70,6 +71,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { MatMulParameter *matmul_param_ = nullptr; MATMUL_OPT_R_FUNC matmul_func_ = nullptr; bool support_optimize_ = false; + bool filter_peroc_ = false; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 99e7588845..0440309b52 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -397,10 +397,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector & int dilation_h = conv_param->dilation_h_; int dilation_w = conv_param->dilation_w_; kernel::LiteKernel *kernel; - auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) { + } else if (kernel_h == 1 && kernel_w == 1) { kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);