Browse Source

[MS][LITE][Develop] int8 conv 1x1: support per-output-channel weight quantization on x86

tags/v1.0.0
ling 5 years ago
parent
commit
0db75b70d7
10 changed files with 574 additions and 298 deletions
  1. +345
    -215
      mindspore/lite/nnacl/int8/conv_int8.c
  2. +8
    -5
      mindspore/lite/nnacl/int8/conv_int8.h
  3. +10
    -7
      mindspore/lite/nnacl/int8/matmul_int8.c
  4. +1
    -0
      mindspore/lite/nnacl/int8/matmul_int8.h
  5. +143
    -47
      mindspore/lite/nnacl/pack.c
  6. +2
    -2
      mindspore/lite/nnacl/pack.h
  7. +0
    -2
      mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
  8. +61
    -17
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
  9. +3
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h
  10. +1
    -2
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc

+ 345
- 215
mindspore/lite/nnacl/int8/conv_int8.c View File

@@ -371,249 +371,380 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
}
}

void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, ConvParameter *conv_param) {
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride) {
int ic4 = UP_ROUND(input_channel, C4NUM);
int oc8 = UP_ROUND(output_channel, C8NUM);
int hw8 = UP_ROUND(plane_size, C8NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t hw_8res = plane_size - hw_8div;
size_t oc_8div = output_channel / C8NUM * C8NUM;
size_t oc_8res = output_channel - oc_8div;
size_t ic_4div = input_channel / C4NUM * C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;

if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
int32_t *input_sum_r = input_sum;

/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_oc = input_sum_r;

int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}

for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0];
input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1];
input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2];
input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3];
input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4];
input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5];
input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6];
input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci];
}
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
for (int ri = 0; ri < C8NUM; ri++) {
input_sum_oc[ri * C8NUM + oci] = 0;
}
}
} /* oc8 res done */

src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
input_sum_r += C8NUM * C8NUM;
}

if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t *input_sum_oc = input_sum_r;
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v10.4s, wzr \n"
"dup v11.4s, wzr \n"
"mov x20, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"

"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"

"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"

"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"

"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"

"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"

"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 1b \n"

"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"

"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"

"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"

"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"6: \n"
"mul v10.4s, v10.4s, v20.4s \n"
"mul v11.4s, v11.4s, v20.4s \n"

"st1 {v10.4s}, [x20], #16 \n"
"st1 {v11.4s}, [x20], #16 \n"

:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res),
[ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}

for (int oci = 0; oci < oc_8div; oci += C8NUM) {
for (int curoi = 0; curoi < C8NUM; curoi++) {
input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi];
}
input_sum_oc += inputsum_stride;
}
if (oc_8div != output_channel) {
for (int oci = 0; oci < oc_8res; oci += 1) {
input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci];
}
for (int oci = oc_8res; oci < C8NUM; oci += 1) {
input_sum_oc[oci] = 0;
}
} /* oc8 res done */

src_r += input_channel;
pack_r += C4NUM;
input_sum_r += C8NUM;
}

for (int hwi = plane_size; hwi < hw8; hwi++) {
for (int oc = 0; oc < oc8; oc++) {
int oc8div = oc / C8NUM, oc8res = oc % C8NUM;
input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0;
}
}
}
return;
}

void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, ConvParameter *conv_param) {
int ic4 = UP_ROUND(input_channel, C4NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;

const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v10.4s, wzr \n"
"dup v11.4s, wzr \n"
"mov x20, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"

"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"

"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"

"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"

"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"

"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"

"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 1b \n"

"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"

"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"

"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"

"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"

"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"

"6: \n"
"mul v10.4s, v10.4s, v20.4s \n"
"mul v11.4s, v11.4s, v20.4s \n"

"st1 {v10.4s}, [x20], #16 \n"
"st1 {v11.4s}, [x20], #16 \n"

:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}

if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
}

if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) {
input_sum[hwi] = 0;
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
input_sum[hwi] = 0;
}
} else {
/* per channel */
RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel);
PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param);
}
return;
}

void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
MATMUL_OPT_R_FUNC matmul_func) {
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false);
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
return;
}

void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) {
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
if (conv_param->conv_quant_arg_.filter_arg_num_ > 1) {
return MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum,
bias, left_shift, right_shift, multiplier,
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
}
#ifdef ENABLE_ARM64
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
@@ -622,10 +753,9 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t
conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
#else
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
conv_param->conv_quant_arg_.quant_multiplier_,
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0], false);
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.filter_arg_num_ != 1);
#endif
return;
}


+ 8
- 5
mindspore/lite/nnacl/int8/conv_int8.h View File

@@ -54,13 +54,16 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
ConvParameter *conv_param, GEMM_FUNC gemm_func);

// int8 convolution 1x1
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, ConvParameter *conv_param);
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride);
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param);
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
MATMUL_OPT_R_FUNC matmul_func);
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);

// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,


+ 10
- 7
mindspore/lite/nnacl/int8/matmul_int8.c View File

@@ -186,8 +186,9 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel) {
/* row4x16-major * row16x4-major => (int8)row-major : per-channel */
bool peroc) {
/* support per-layer && weight per-channel */
/* row4x16-major * row16x4-major => (int8)row-major*/
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
@@ -200,12 +201,13 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r];
int32_t cur_input_sum =
peroc ? input_sum[c4div * UP_ROUND(row, C4NUM) * C4NUM + r * C4NUM + c4mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
@@ -232,7 +234,8 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r];
int32_t cur_input_sum =
per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];


+ 1
- 0
mindspore/lite/nnacl/int8/matmul_int8.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
#define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_

#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"


+ 143
- 47
mindspore/lite/nnacl/pack.c View File

@@ -195,7 +195,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
}

void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* optimize normal -> same layout */
/* normal matmul : 4x16 * 16x4 -> 4x4 */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
@@ -267,62 +267,158 @@ void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp
return;
}

void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
#ifdef ENABLE_ARM64
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
asm volatile(
"mov x10, %[input_value] \n"
"mov x11, %[input_sum] \n"
"mov x15, %[filter_zp_ptr] \n"

"mov x0, #0 \n" // row 4 count
"1: \n"
"cmp x0, %[hw4] \n"
"beq 11f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n" // input deep count
"mov x16, x15 \n"

"2: \n"
"cmp x2, %[ic16] \n"
"beq 3f \n"
"add x2, x2, #16 \n"

"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"

"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"

"add v10.4s, v10.4s, v0.4s \n"
"b 2b \n"

"3: \n"
"mov x12, x11 \n" // tmp inputsm inputsum hw
"add x11, x11, #64 \n"
"mov x4, #0 \n" // oc count

"dup v1.4s, v10.s[0] \n"
"dup v2.4s, v10.s[1] \n"
"dup v3.4s, v10.s[2] \n"
"dup v4.4s, v10.s[3] \n"

"4: \n"
"cmp x4, %[oc_div4] \n"
"beq 6f \n"
"add x4, x4, #4\n"
"ld1 {v15.4s}, [x16], #16\n"

"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"add x12, x12, %[inputsun_stride] \n"
"b 4b \n"

"6: \n"
"cmp %[oc_res4], #0\n"
"beq 1b \n"
"dup v15.4s, wzr \n"
"cmp %[oc_res4], #1\n"
"beq 7f \n"
"cmp %[oc_res4], #2\n"
"beq 8f \n"
"cmp %[oc_res4], #3\n"
"beq 9f \n"

"7: \n"
"ld1 {v15.s}[0], [x16] \n"
"b 10f \n"

"8: \n"
"ld1 {v15.h}[0], [x16] \n"
"b 10f \n"

"9: \n"
"ld1 {v15.h}[0], [x16] \n"
"add x16, x16, #8 \n"
"ld1 {v15.s}[2], [x16] \n"
"b 10f \n"

"10: \n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"b 1b \n"

"11: \n"

:
: [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr),
[ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4),
[ inputsun_stride ] "r"(inputsun_stride)
: "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15",
"v16", "v17", "v18", "v19");
#else

for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}

void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
size_t hw8 = UP_ROUND(plane_size, C8NUM);
size_t ic4 = UP_ROUND(input_channel, C4NUM);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) {
size_t hw4 = UP_ROUND(conv_param->input_h_ * conv_param->input_w_, C4NUM);
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
for (int r = 0; r < hw8; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < ic4; c++) {
int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM;
int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod;
tmp_value += input_value[src_index];
}
input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
}
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri8div = ri / C8NUM, ri8mod = ri % C8NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci8div = ci / C8NUM, ci8mod = ci % C8NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di4div = di / C4NUM, di4mod = di % C4NUM;
int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
PackInputSum16x4PerChannel(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_,
conv_param->input_channel_, conv_param->output_channel_);
}
return;
}


+ 2
- 2
mindspore/lite/nnacl/pack.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_PACK_H_
#define MINDSPORE_LITE_NNACL_PACK_H_

#include <stdio.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
@@ -41,8 +42,7 @@ void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_pa

void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);

void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);

void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);


+ 0
- 2
mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc View File

@@ -316,14 +316,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
return ret;
}
// now only consider per tensor for output
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
CalculateActivationRangeQuantized(relu, relu6, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_,
&conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);

return RET_OK;
}
int ConvolutionBaseCPUKernel::RestoreFilter(lite::Tensor *input_tensor) {


+ 61
- 17
mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc View File

@@ -16,6 +16,7 @@

#include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/common/file_utils.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
@@ -41,6 +42,10 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (filter_peroc_ && filter_zp_ptr_ != nullptr) {
free(filter_zp_ptr_);
filter_zp_ptr_ = nullptr;
}
FreeResizeBuf();
FreeQuantParam();
}
@@ -54,7 +59,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
}

void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = true;
support_optimize_ = false;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@@ -73,6 +78,10 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
matmul_func_ = nullptr;
}

if (filter_peroc_) {
support_optimize_ = false;
}
#endif
return;
}
@@ -118,14 +127,23 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
for (int oc = 0; oc < output_channel; oc++) {
int32_t weight_sum_value = 0;
int32_t filter_zp = (conv_param_->conv_quant_arg_.filter_arg_num_ == 1)
? conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_
: conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_;
int32_t filter_zp = (filter_peroc_) ? conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_
: conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
for (int ic = 0; ic < input_channel; ic++) {
weight_sum_value += weight[oc * input_channel + ic];
}
bias_data[oc] += filter_zp * input_zp * input_channel - weight_sum_value * input_zp;
}

if (filter_peroc_) {
filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t)));
if (filter_zp_ptr_ == nullptr) {
return RET_ERROR;
}
for (int fi = 0; fi < output_channel; fi++) {
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}
}
return RET_OK;
}

@@ -136,14 +154,16 @@ int Convolution1x1Int8CPUKernel::Init() {
return RET_ERROR;
}

CheckSupportOptimize();

auto ret = SetQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set quant param failed.";
return ret;
}

filter_peroc_ = (conv_param_->conv_quant_arg_.filter_arg_num_ != 1);

CheckSupportOptimize();

ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
@@ -229,14 +249,17 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_);
} else {
RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_,
conv_param_);
PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
}

return;
}

int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
int32_t *cur_input_sum = input_sum_;
int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_;

if (support_optimize_) {
int cur_stride = thread_stride_ * C8NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM;
@@ -244,10 +267,17 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM;
}
Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_,
output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_,
cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_);
cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_,
matmul_func_);
} else {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM;
@@ -255,10 +285,16 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM;
}
Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_,
output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, conv_param_);
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
}
return RET_OK;
}
@@ -270,10 +306,18 @@ int Convolution1x1Int8CPUKernel::RunPre(int task_id) {
if (cur_hw <= 0) {
return RET_OK;
}
Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw,
conv_param_);

if (filter_peroc_) {
Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_,
matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM);
} else {
Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_);
}

return RET_OK;
}



+ 3
- 1
mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h View File

@@ -56,7 +56,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
void CheckSupportOptimize();

private:
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */
int8_t *packed_weight_ = nullptr;
int8_t *packed_input_ = nullptr;
int8_t *input_ptr_ = nullptr;
@@ -70,6 +71,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
MatMulParameter *matmul_param_ = nullptr;
MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
bool support_optimize_ = false;
bool filter_peroc_ = false;
};
} // namespace mindspore::kernel



+ 1
- 2
mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc View File

@@ -397,10 +397,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
kernel::LiteKernel *kernel;
auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) {
} else if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);


Loading…
Cancel
Save