Browse Source

!26326 matmul int8 sdot

Merge pull request !26326 from zhaozhenlong/lite/issue/matmul-int8
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
dafa363bde
9 changed files with 495 additions and 52 deletions
  1. +264
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/matmul_int8.c
  2. +10
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/matmul_int8.h
  3. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/matmul_parameter.h
  4. +1
    -1
      mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc
  5. +8
    -7
      mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
  6. +7
    -7
      mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc
  7. +2
    -2
      mindspore/lite/micro/coder/wrapper/int8/matmul_int8_wrapper.c
  8. +189
    -31
      mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc
  9. +13
    -0
      mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h

+ 264
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/matmul_int8.c View File

@@ -32,6 +32,69 @@ void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row,
}
}

// Pack a row-major int8 matrix into the "Row4x4" tiled layout: the matrix is
// processed in 4-row bands, each band stored as consecutive 4x4 tiles, with
// the column dimension zero-padded up to a multiple of C4NUM.
//   src: row-major input, row x col
//   dst: packed output, UP_ROUND(row, 4) x UP_ROUND(col, 4) bytes
void RowMajor2Row4x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) {
  int row_div = row / C4NUM * C4NUM;  // rows covered by full 4-row bands
  int col_4 = UP_ROUND(col, C4NUM);   // padded column count
  int col_div = col / C4NUM * C4NUM;  // cols covered by full 4-col tiles

  const int8_t *src_r4 = src;
  int8_t *packed_r4 = dst;
  const int8_t *src_c4 = NULL;
  int8_t *packed_c4 = NULL;
  for (int r = 0; r < row_div; r += C4NUM) {
    src_c4 = src_r4;
    packed_c4 = packed_r4;

    // Full 4x4 tiles of this band.
    for (int c = 0; c < col_div; c += C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        packed_c4[i * C4NUM + 0] = src_c4[i * col + 0];
        packed_c4[i * C4NUM + 1] = src_c4[i * col + 1];
        packed_c4[i * C4NUM + 2] = src_c4[i * col + 2];
        packed_c4[i * C4NUM + 3] = src_c4[i * col + 3];
      }
      src_c4 += C4NUM;
      packed_c4 += C16NUM;
    }

    // Partial column tile: zero-fill, then copy the remaining valid columns.
    if (col != col_div) {
      memset(packed_c4, 0, C16NUM * sizeof(int8_t));
      for (int i = 0; i < C4NUM; ++i) {
        for (int c = 0; c < col - col_div; ++c) {
          packed_c4[i * C4NUM + c] = src_c4[i * col + c];
        }
      }
    }
    // BUGFIX: advance the band pointers unconditionally. The previous code
    // executed `continue` when col == col_div BEFORE these advances, so every
    // band re-packed the first four rows whenever col was a multiple of C4NUM.
    src_r4 += C4NUM * col;
    packed_r4 += C4NUM * col_4;
  }

  // Trailing partial row band (< 4 rows): zero-fill the whole band first.
  if (row == row_div) {
    return;
  }
  memset(packed_r4, 0, C4NUM * col_4);
  src_c4 = src_r4;
  packed_c4 = packed_r4;
  for (int c = 0; c < col_div; c += C4NUM) {
    for (int i = 0; i < row - row_div; ++i) {
      packed_c4[i * C4NUM + 0] = src_c4[i * col + 0];
      packed_c4[i * C4NUM + 1] = src_c4[i * col + 1];
      packed_c4[i * C4NUM + 2] = src_c4[i * col + 2];
      packed_c4[i * C4NUM + 3] = src_c4[i * col + 3];
    }
    src_c4 += C4NUM;
    packed_c4 += C16NUM;
  }
  if (col == col_div) {
    return;
  }
  // Bottom-right corner: both row and column remainders.
  for (int i = 0; i < row - row_div; ++i) {
    for (int c = 0; c < col - col_div; ++c) {
      packed_c4[i * C4NUM + c] = src_c4[i * col + c];
    }
  }
}

void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int row16 = UP_ROUND(row, C16NUM);
int stride = sizeof(int8_t) * C16NUM * C2NUM;
@@ -525,7 +588,139 @@ void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input,
return;
}

void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) {
#ifdef ENABLE_ARM64
// AArch64 NEON helper for PackInput2Col4x4AndInputSumPert: consumes a 4-column
// slice of a transposed input, 4 rows per iteration, transposing each 4x4 byte
// tile into the packed layout consumed by the sdot matmul kernels.
// NOTE(review): as listed here, input_sum is only loaded (into v12) and the
// running sums accumulated in v10 are never added to v12 or stored back, and
// filter_zp is passed but unused. This view of the asm may be truncated --
// verify against the upstream implementation before relying on the sum logic.
void PackInput2Col4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *packed_ic, int32_t *input_sum, int row,
size_t row_stride, int32_t filter_zp) {
asm volatile(
// Seed: load 4 bytes at input_sum and sign-extend to 32-bit lanes in v12.
"ld1 {v12.s}[0], [%[input_sum]]\n"
"mov w10, %w[row]\n"
"mov x11, %[src_ic]\n"
"mov x12, %[packed_ic]\n"
"sxtl v6.8h, v12.8b\n"
"sxtl v12.4s, v6.4h\n"
// Skip the loop entirely when row == 0.
"cmp w10, wzr\n"
"beq 1f\n"
"2:\n"
"subs w10, w10, #4\n"
// Load four rows of 4 int8 values each; rows are row_stride bytes apart.
"ld1 {v0.s}[0], [x11], %[row_stride]\n"
"ld1 {v1.s}[0], [x11], %[row_stride]\n"
"ld1 {v0.s}[1], [x11], %[row_stride]\n"
"ld1 {v1.s}[1], [x11], %[row_stride]\n"
// Transpose the 4x4 byte tile with zip1/zip2 and store 16 packed bytes.
"zip1 v2.8b, v0.8b, v1.8b\n"
"zip2 v3.8b, v0.8b, v1.8b\n"
"zip1 v4.4h, v2.4h, v3.4h\n"
"zip2 v5.4h, v2.4h, v3.4h\n"
"st1 {v4.4h, v5.4h}, [x12], #16\n"

// Widen the loaded bytes to 32-bit lanes and accumulate sums in v10.
"sxtl v6.8h, v0.8b\n"
"sxtl v7.4s, v6.4h\n"
"sxtl2 v8.4s, v6.8h\n"
"sxtl v9.8h, v1.8b\n"
"sxtl v10.4s, v9.4h\n"
"sxtl2 v11.4s, v9.8h\n"
"add v10.4s, v10.4s, v7.4s\n"
"add v10.4s, v10.4s, v8.4s\n"
"add v10.4s, v10.4s, v10.4s\n"
"add v10.4s, v10.4s, v11.4s\n"
"bgt 2b\n"
"1:\n"

:
: [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ input_sum ] "r"(input_sum), [ row ] "r"(row),
[ row_stride ] "r"(row_stride), [ filter_zp ] "r"(filter_zp)
: "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");

return;
}
#endif

// For matmul input a transpose case
// Packs `col` logical rows x `row` logical columns of a transposed input
// (consecutive elements of one logical row are `row_stride` apart in memory)
// into 4x4 tiles, while accumulating per-column byte sums into input_sum.
// The accumulated sums are scaled by filter_zp at the end (used by the int8
// matmul bias correction).
void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row,
int col, int row_stride, int32_t filter_zp) {
const int row_tile = C4NUM;
int row_align = UP_ROUND(row, row_tile);
int row_div = row / row_tile * row_tile;  // rows covered by full 4-row groups
const int row_res = row - row_div;        // leftover rows (< 4)

const int col_tile = C4NUM;
int col_div = col / col_tile * col_tile;  // cols covered by full 4-col groups
const int col_res = col - col_div;        // leftover cols (< 4)

const int8_t *src_ic = NULL;
int8_t *packed_ic = NULL;
int32_t *tmp_sum = NULL;
// Full 4-column groups.
for (int c = 0; c < col_div; c += C4NUM) {
int r = 0;
src_ic = src_input + c;
packed_ic = packed_input + c * row_align;
tmp_sum = input_sum + c;
#ifdef ENABLE_ARM64
// NEON fast path handles all full 4-row groups of this column group;
// leftover rows fall through to the scalar loop below.
PackInput2Col4x4AndInputSumPert_arm64(src_ic, packed_ic, tmp_sum, row_div, row_stride, filter_zp);
packed_ic += C4NUM * row_div;
src_ic += row_div * row_stride;
#else
// Scalar path: transpose each 4x4 tile and accumulate the column sums.
for (; r < row_div; r += C4NUM) {
for (int i = 0; i < row_tile; i++) {
packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0];
packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1];
packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2];
packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3];

tmp_sum[0] += src_ic[i * row_stride + 0];
tmp_sum[1] += src_ic[i * row_stride + 1];
tmp_sum[2] += src_ic[i * row_stride + 2];
tmp_sum[3] += src_ic[i * row_stride + 3];
}
packed_ic += C16NUM;
src_ic += row_tile * row_stride;
}
#endif

// Leftover rows of this column group (tile is pre-zeroed by the caller's
// buffer memset -- note only valid entries are written here).
for (r = 0; r < row_res; ++r) {
for (int i = 0; i < C4NUM; ++i) {
packed_ic[i * row_tile + r] = src_ic[r * row_stride + i];
tmp_sum[i] += src_ic[r * row_stride + i];
}
}
}
if (col_res == 0) {
// No leftover columns: just apply the filter zero-point scaling and exit.
for (int i = 0; i < col; ++i) {
input_sum[i] *= filter_zp;
}
return;
}
// Leftover column group (< 4 columns), full 4-row groups first.
src_ic = src_input + col_div;
packed_ic = packed_input + row_align * col_div;
tmp_sum = input_sum + col_div;
for (int r = 0; r < row_div; r += row_tile) {
for (int i = 0; i < col_res; ++i) {
packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i];
packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i];
packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i];
packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i];

tmp_sum[i] += src_ic[0 * row_stride + i];
tmp_sum[i] += src_ic[1 * row_stride + i];
tmp_sum[i] += src_ic[2 * row_stride + i];
tmp_sum[i] += src_ic[3 * row_stride + i];
}
src_ic += row_tile * row_stride;
packed_ic += row_tile * col_tile;
}

// Bottom-right corner: leftover rows of the leftover column group.
for (int r = 0; r < row_res; ++r) {
for (int c = 0; c < col_res; ++c) {
packed_ic[c * row_tile + r] = src_ic[r * row_stride + c];
tmp_sum[c] += src_ic[r * row_stride + c];
}
}

// Scale all accumulated sums by the filter zero-point.
for (int i = 0; i < col; ++i) {
input_sum[i] *= filter_zp;
}
}

void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) {
int row_16 = UP_ROUND(row, C16NUM);
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row_16; ++r) {
@@ -541,7 +736,54 @@ void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst)
}
}

// dst: weight_zp * input_row_sums
// Pack a row-major int8 matrix into column-major 4x4 tiles ("Col4x4"):
// tiles are laid out column-group by column-group, and rows are zero-padded
// up to a multiple of C4NUM.
void RowMajor2Col4x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) {
  int row_4 = UP_ROUND(row, C4NUM);
  int tile_size = C4NUM * C4NUM;
  int tiles_per_col_group = row_4 / C4NUM;
  for (int c = 0; c < col; ++c) {
    for (int r = 0; r < row_4; ++r) {
      int tile_idx = c / C4NUM * tiles_per_col_group + r / C4NUM;
      int dst_idx = tile_size * tile_idx + c % C4NUM * C4NUM + r % C4NUM;
      // Padding rows (r >= row) are written as zero.
      dst[dst_idx] = (r < row) ? src[r * col + c] : 0;
    }
  }
}

// Pack the first cur_oc columns of a row-major int8 matrix into 4x16 tiles
// ("Col4x16" layout). `col` is the full row stride of src; rows are
// zero-padded up to a multiple of C4NUM.
void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc) {
  int row_4 = UP_ROUND(row, C4NUM);
  const int tile_size = C16NUM * C4NUM;
  const int row_tiles = row_4 / C4NUM;
  for (int r = 0; r < row_4; ++r) {
    for (int c = 0; c < cur_oc; ++c) {
      int tile_idx = c / C16NUM * row_tiles + r / C4NUM;
      int dst_idx = tile_size * tile_idx + c % C16NUM * C4NUM + r % C4NUM;
      // Padding rows (r >= row) are written as zero.
      dst[dst_idx] = (r < row) ? src[r * col + c] : 0;
    }
  }
}

void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc);

// Full-width variant of RowMajor2Col4x16MajorPartInt8: packs all `col` output
// channels into 4x16 tiles. Delegates to the partial packer (cur_oc == col)
// instead of duplicating the tiling arithmetic -- the two bodies were
// previously identical line for line.
void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col) {
  RowMajor2Col4x16MajorPartInt8(src, dst, row, col, col);
}

void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
for (int r = 0; r < row; ++r) {
int sum = 0;
@@ -576,3 +818,23 @@ void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, co
}
}
}

// Compute the per-output-channel bias correction for a slice of cur_col
// channels:
//   dst[c] = row * input_zp * weight_zp[c] - input_zp * sum(weight column c)
//            (+ bias[c] when bias is provided)
// `stride` is the row stride of `weight` in the RowMajor case; in the
// ColMajor case each channel's `row` values are contiguous.
void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp,
                            const int *weight_zp_ptr, const int *bias, int *dst, DataOrder order,
                            bool filter_per_channel) {
  const bool row_major = (order == RowMajor);
  for (int c = 0; c < cur_col; ++c) {
    int col_sum = 0;
    if (row_major) {
      for (int r = 0; r < row; ++r) {
        col_sum += weight[r * stride + c];
      }
    } else {
      for (int r = 0; r < row; ++r) {
        col_sum += weight[c * row + r];
      }
    }
    const int zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0];
    int acc = row * input_zp * zp - input_zp * col_sum;
    if (bias != NULL) {
      acc += bias[c];
    }
    dst[c] = acc;
  }
}

+ 10
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/int8/matmul_int8.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,14 +25,22 @@
extern "C" {
#endif
/* 4x16 16x4 -> 4x4 */
/* sdot 4x4 4x16 -> 4x16 */
/* matmul */
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst);
void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col);
void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc);
void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row,
int col, int row_stride, int32_t filter_zp);
void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int *weight_zp_ptr, const int *bias,
int *dst, DataOrder order, bool filter_per_channel);
void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp,
const int *weight_zp_ptr, const int *bias, int *dst, DataOrder order,
bool filter_per_channel);
void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier,
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,


+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/matmul_parameter.h View File

@@ -54,6 +54,7 @@ typedef struct MatMulParameter {
int deep_;
int deep_4_;
int deep_16_;
int deep_align_;
int batch;
bool a_transpose_; /* false : row-major */
bool b_transpose_; /* true : col-major */


+ 1
- 1
mindspore/lite/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc View File

@@ -225,7 +225,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) {
for (int i = 0; i < param_->batch; i++) {
std::string current_src_a = a_ptr_str + "+" + std::to_string(i * param_->row_ * param_->deep_);
if (param_->a_transpose_) {
code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, param_->deep_, param_->row_, pack_a_ptr_);
code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_,
ColMajor);
} else {


+ 8
- 7
mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc View File

@@ -73,13 +73,14 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete
}

void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) {
CodeBaseStruct(
"MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, mat_mul_parameter.row_,
mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_, mat_mul_parameter.row_12_,
mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, mat_mul_parameter.col_4_, mat_mul_parameter.col_8_,
mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_,
mat_mul_parameter.b_const_, mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_);
CodeBaseStruct("MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_,
mat_mul_parameter.row_, mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_,
mat_mul_parameter.row_12_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_,
mat_mul_parameter.col_4_, mat_mul_parameter.col_8_, mat_mul_parameter.col_align_,
mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_,
mat_mul_parameter.deep_align_, mat_mul_parameter.batch, mat_mul_parameter.a_transpose_,
mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_, mat_mul_parameter.b_const_,
mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_);
}

void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleParameter &scale_parameter) {


+ 7
- 7
mindspore/lite/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_int8_serializer.cc View File

@@ -62,13 +62,13 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConvParamete
}

void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatMulParameter &matmul_parameter) {
CodeBaseStruct<false>("MatMulParameter", name, matmul_parameter.op_parameter_, matmul_parameter.has_bias_,
matmul_parameter.row_, matmul_parameter.col_, matmul_parameter.row_4_, matmul_parameter.row_6_,
matmul_parameter.row_12_, matmul_parameter.row_16_, matmul_parameter.row_align_,
matmul_parameter.col_4_, matmul_parameter.col_8_, matmul_parameter.col_align_,
matmul_parameter.deep_, matmul_parameter.deep_4_, matmul_parameter.deep_16_,
matmul_parameter.batch, matmul_parameter.a_transpose_, matmul_parameter.b_transpose_,
matmul_parameter.a_const_, matmul_parameter.b_const_, matmul_parameter.act_type_);
CodeBaseStruct<false>(
"MatMulParameter", name, matmul_parameter.op_parameter_, matmul_parameter.has_bias_, matmul_parameter.row_,
matmul_parameter.col_, matmul_parameter.row_4_, matmul_parameter.row_6_, matmul_parameter.row_12_,
matmul_parameter.row_16_, matmul_parameter.row_align_, matmul_parameter.col_4_, matmul_parameter.col_8_,
matmul_parameter.col_align_, matmul_parameter.deep_, matmul_parameter.deep_4_, matmul_parameter.deep_16_,
matmul_parameter.deep_align_, matmul_parameter.batch, matmul_parameter.a_transpose_, matmul_parameter.b_transpose_,
matmul_parameter.a_const_, matmul_parameter.b_const_, matmul_parameter.act_type_);
}

void NNaclInt8Serializer::CodeStruct(const std::string &name, const AddQuantParameter &add_quant_parameter) {


+ 2
- 2
mindspore/lite/micro/coder/wrapper/int8/matmul_int8_wrapper.c View File

@@ -21,7 +21,7 @@ void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int
for (int i = 0; i < batch; ++i) {
int8_t *cur_a_ptr = src_ptr + i * row * deep;
if (a_transpose) {
RowMajor2Col16x4MajorInt8(cur_a_ptr, deep, row, dst_ptr);
RowMajor2Col16x4MajorInt8(cur_a_ptr, dst_ptr, deep, row);
CalcInputSums(cur_a_ptr, row, deep, *weight_zp, input_sums, ColMajor);
} else {
RowMajor2Row16x4MajorInt8(cur_a_ptr, dst_ptr, row, deep);
@@ -48,7 +48,7 @@ void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_
#ifdef ENABLE_ARM32
RowMajor2Col16x2MajorInt8(cur_b, cur_b_pack, deep, col);
#else
RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack);
RowMajor2Col16x4MajorInt8(cur_b, cur_b_pack, deep, col);
#endif
CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false);
}


+ 189
- 31
mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc View File

@@ -15,6 +15,7 @@
*/

#include "src/runtime/kernel/arm/int8/matmul_base_int8.h"
#include "src/runtime/kernel/arm/int8/opt_op_handler.h"

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
@@ -32,6 +33,97 @@ int MatmulBaseInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale
return RET_OK;
}

#ifdef ENABLE_ARM64
int Arm64SdotPreRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<MatmulBaseInt8CPUKernel *>(cdata);
auto ret = op->Arm64SdotPre(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}

int Arm64SdotRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto op = reinterpret_cast<MatmulBaseInt8CPUKernel *>(cdata);
auto ret = op->Arm64SdotImpl(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
}
return RET_OK;
}

// Stage 1 of the sdot path, run per ParallelLaunch task: packs this task's
// row slice of the current batch's input matrix A (batch_input_ptr_) into
// pack_a_ptr_, accumulating the per-row input sums into input_sums_.
int MatmulBaseInt8CPUKernel::Arm64SdotPre(int task_id) {
// Split the row-tile-aligned row range evenly across the launched threads.
int row_thread_count = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->row_align_, row_tile_));
int row_stride = UP_DIV(UP_DIV(param_->row_align_, row_tile_), row_thread_count) * row_tile_;

int row_current_stride = task_id * row_stride;
int row_res_stride = param_->row_ - row_current_stride;
int cur_r = MSMIN(row_res_stride, row_stride);
if (cur_r <= 0) {
// This task falls past the last real row -- nothing to pack.
return RET_OK;
}

// With per-channel filter quantization the zero-point is 1 here, leaving the
// packed row sums unscaled -- presumably the per-channel zp is applied later
// in the matmul kernel; TODO(review) confirm against MatmulInt8DpOpt.
int tmp_weight_zp = filter_per_channel_ ? 1 : quant_param_->filter_zp_[0];
auto current_a_pack = pack_a_ptr_ + row_current_stride * param_->deep_align_;

if (param_->a_transpose_) {
// Transposed A: the row slice is selected by column offset; the packer
// reads with stride param_->row_ between depth elements.
auto current_src_a = batch_input_ptr_ + row_current_stride;
PackInput2Col4x4AndInputSumPert(current_src_a, current_a_pack, input_sums_ + row_current_stride, param_->deep_,
cur_r, param_->row_, tmp_weight_zp);
} else {
// Row-major A: the row slice is selected by row offset (deep_ per row).
auto current_src_a = batch_input_ptr_ + row_current_stride * param_->deep_;
PackInput4x4AndInputSumPert(current_src_a, current_a_pack, input_sums_ + row_current_stride, param_->deep_, cur_r,
tmp_weight_zp);
}
return RET_OK;
}

// Stage 2 of the sdot path, run per ParallelLaunch task: for this task's
// output-column slice, (re)packs B and its bias sums when B is not constant,
// then runs the sdot-optimized int8 matmul on the slice.
int MatmulBaseInt8CPUKernel::Arm64SdotImpl(int task_id) {
// Each task owns `thread_stride_ * col_tile_` consecutive output columns.
int stride = thread_stride_ * col_tile_;
int cur_stride = task_id * stride;
int res_stride = param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
if (cur_oc <= 0) {
// This task falls past the last real output column.
return RET_OK;
}
// Non-constant B is packed here, inside the parallel region, one column
// slice per task (the const-B case is packed ahead of time elsewhere).
if (param_->b_const_ == false) {
auto current_sums = batch_sums_ + cur_stride;
auto current_b_pack = batch_b_ptr_ + cur_stride * param_->deep_align_;
// Per-channel quantization advances the zp pointer with the slice;
// per-tensor keeps the single shared value.
auto current_filter_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;
auto current_bias = bias_ptr_ == nullptr ? nullptr : bias_ptr_ + cur_stride;
if (param_->b_transpose_) {
// Transposed B: weight rows are output channels (col x deep).
auto current_weight = batch_weight_ptr_ + cur_stride * param_->deep_;

RowMajor2Row4x16MajorInt8(current_weight, current_b_pack, cur_oc, param_->deep_);
CalcPartWeightBiasSums(current_weight, param_->deep_, param_->col_, cur_oc, quant_param_->input_.zp_,
current_filter_zp, current_bias, current_sums, ColMajor, filter_per_channel_);
} else {
// Row-major B: weight columns are output channels (deep x col).
auto current_weight = batch_weight_ptr_ + cur_stride;
RowMajor2Col4x16MajorPartInt8(current_weight, current_b_pack, param_->deep_, param_->col_, cur_oc);
CalcPartWeightBiasSums(current_weight, param_->deep_, param_->col_, cur_oc, quant_param_->input_.zp_,
current_filter_zp, current_bias, current_sums, RowMajor, filter_per_channel_);
}
}

// Select the quantization parameters for this slice (per-channel arrays are
// offset by the slice start; per-tensor uses element 0).
int32_t *cur_left = filter_per_channel_ ? quant_param_->left_shift_ + cur_stride : quant_param_->left_shift_;
int32_t *cur_right = filter_per_channel_ ? quant_param_->right_shift_ + cur_stride : quant_param_->right_shift_;
int32_t *cur_mul =
filter_per_channel_ ? quant_param_->quant_multiplier_ + cur_stride : quant_param_->quant_multiplier_;
int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;

// Run the sdot matmul for this slice; output stride is the full col_.
MatmulInt8DpOpt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_,
cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_,
quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_,
filter_per_channel_, cur_zp);

return RET_OK;
}
#endif

int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
int stride = thread_stride_ * col_tile_;
int cur_stride = task_id * stride;
@@ -47,8 +139,8 @@ int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
filter_per_channel_ ? quant_param_->quant_multiplier_ + cur_stride : quant_param_->quant_multiplier_;
int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;

MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_16_, batch_c_ptr_ + cur_stride, param_->row_,
cur_oc, param_->deep_16_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_,
MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_,
cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_,
quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_,
filter_per_channel_, cur_zp);

@@ -59,11 +151,6 @@ MatmulBaseInt8CPUKernel::~MatmulBaseInt8CPUKernel() {
FreeQuantParam();

FreeTmpBuffer();

if (bias_ptr_ != nullptr) {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
}

void MatmulBaseInt8CPUKernel::FreeQuantParam() {
@@ -180,17 +267,59 @@ void MatmulBaseInt8CPUKernel::InitParameter() {
#ifdef ENABLE_ARM32
row_tile_ = C4NUM;
col_tile_ = C2NUM;
deep_tile_ = C16NUM;
#elif ENABLE_ARM64
support_sdot_ = mindspore::lite::IsSupportSDot();
row_tile_ = C4NUM;
if (support_sdot_) {
col_tile_ = C16NUM;
deep_tile_ = C4NUM;
} else {
col_tile_ = C4NUM;
deep_tile_ = C16NUM;
}
#else
row_tile_ = C4NUM;
col_tile_ = C4NUM;
deep_tile_ = C16NUM;
#endif
if (param_->a_transpose_) {
a_pack_func_ = RowMajor2Col16x4MajorInt8;
} else {
a_pack_func_ = RowMajor2Row16x4MajorInt8;
}
if (param_->b_transpose_) {
#ifdef ENABLE_ARM32
b_pack_func_ = RowMajor2Row2x16MajorInt8;
#elif ENABLE_ARM64
if (support_sdot_) {
b_pack_func_ = RowMajor2Row4x16MajorInt8;
} else {
b_pack_func_ = RowMajor2Row16x4MajorInt8;
}
#else
b_pack_func_ = RowMajor2Row16x4MajorInt8;
#endif
} else {
#ifdef ENABLE_ARM32
b_pack_func_ = RowMajor2Col16x2MajorInt8;
#elif ENABLE_ARM64
if (support_sdot_) {
b_pack_func_ = RowMajor2Col4x16MajorInt8;
} else {
b_pack_func_ = RowMajor2Col16x4MajorInt8;
}
#else
b_pack_func_ = RowMajor2Col16x4MajorInt8;
#endif
}
return;
}

void MatmulBaseInt8CPUKernel::ResizeParameter() {
param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);

thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
@@ -221,37 +350,30 @@ void MatmulBaseInt8CPUKernel::TransferB() {
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
for (int i = 0; i < param_->batch; i++) {
auto current_weight = weight_data + i * param_->deep_ * param_->col_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
auto current_sums = weight_bias_sums_ + i * param_->col_align_;
MS_CHECK_PTR_IF_NULL(b_pack_func_);
if (param_->b_transpose_) {
#ifdef ENABLE_ARM32
RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
#else
RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
#endif
b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_);
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_,
quant_param_->filter_zp_, bias_ptr_, current_sums, ColMajor, filter_per_channel_);
} else {
#ifdef ENABLE_ARM32
RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_);
#else
RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack);
#endif
b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_,
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, false);
quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_);
}
}
return;
}

int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_16_ * sizeof(int8_t)));
pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_a_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
}
pack_b_ptr_ =
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t)));
reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
if (pack_b_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_ERROR;
@@ -267,8 +389,8 @@ int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}

memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_16_ * sizeof(int8_t));
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t));
memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
memset(input_sums_, 0, param_->row_align_ * sizeof(int));
memset(weight_bias_sums_, 0, param_->batch * param_->col_align_ * sizeof(int));

@@ -278,9 +400,7 @@ int MatmulBaseInt8CPUKernel::InitTmpBuffer() {
int MatmulBaseInt8CPUKernel::InitBias() {
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
MS_CHECK_GT(bias_tensor->ElementsNum(), 0, RET_ERROR);
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C4NUM);
bias_ptr_ = reinterpret_cast<int *>(malloc(max_bias_data * sizeof(int)));
bias_ptr_ = reinterpret_cast<int *>(bias_tensor->data());
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
FreeTmpBuffer();
@@ -332,11 +452,47 @@ int MatmulBaseInt8CPUKernel::ReSize() {
return RET_OK;
}

#ifdef ENABLE_ARM64
// Drives the sdot-optimized matmul batch by batch: a parallel input-pack
// pass (Arm64SdotPreRun) followed by a parallel matmul pass (Arm64SdotRun).
int MatmulBaseInt8CPUKernel::RunArm64Sdot() {
  auto *lhs_data = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
  auto *rhs_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
  auto *out_data = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data());
  CHECK_NULL_RETURN(lhs_data);
  CHECK_NULL_RETURN(rhs_data);
  CHECK_NULL_RETURN(out_data);

  for (int batch = 0; batch < param_->batch; ++batch) {
    // Stage 1: pack this batch's A matrix (and input sums) in parallel.
    batch_input_ptr_ = lhs_data + batch * param_->row_ * param_->deep_;
    auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "RunArm64Sdot error: [" << ret << "]";
      return ret;
    }

    // Stage 2: aim the per-batch cursors at B, the bias sums, and C, then
    // run the sdot matmul in parallel over output-column slices.
    batch_weight_ptr_ = rhs_data + batch * param_->col_ * param_->deep_;
    batch_b_ptr_ = pack_b_ptr_ + batch * param_->col_align_ * param_->deep_align_;
    batch_sums_ = weight_bias_sums_ + batch * param_->col_align_;
    batch_c_ptr_ = out_data + batch * param_->row_ * param_->col_;

    ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "RunArm64Sdot error: [" << ret << "]";
      return ret;
    }
  }
  return RET_OK;
}
#endif

int MatmulBaseInt8CPUKernel::Run() {
#ifdef ENABLE_ARM64
if (support_sdot_) {
return RunArm64Sdot();
}
#endif
if (param_->b_const_ == false) {
TransferB();
}

int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
int8_t *c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data());
CHECK_NULL_RETURN(a_ptr);
@@ -345,14 +501,16 @@ int MatmulBaseInt8CPUKernel::Run() {
for (int i = 0; i < param_->batch; i++) {
auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
if (param_->a_transpose_) {
RowMajor2Col16x4MajorInt8(current_src_a, param_->deep_, param_->row_, pack_a_ptr_);
MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
CalcInputSums(current_src_a, param_->row_, param_->deep_, tmp_weight_zp, input_sums_, ColMajor);
} else {
RowMajor2Row16x4MajorInt8(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
CalcInputSums(current_src_a, param_->row_, param_->deep_, tmp_weight_zp, input_sums_, RowMajor);
}

batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
batch_sums_ = weight_bias_sums_ + i * param_->col_align_;
batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;



+ 13
- 0
mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h View File

@@ -29,6 +29,8 @@

namespace mindspore::kernel {
class MatmulBaseInt8CPUKernel : public InnerKernel {
typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);

public:
MatmulBaseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
@@ -42,6 +44,11 @@ class MatmulBaseInt8CPUKernel : public InnerKernel {

public:
int RunImpl(int task_id);
#ifdef ENABLE_ARM64
int RunArm64Sdot();
int Arm64SdotImpl(int task_id);
int Arm64SdotPre(int task_id);
#endif

protected:
void InitParameter();
@@ -72,12 +79,18 @@ class MatmulBaseInt8CPUKernel : public InnerKernel {
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
bool filter_per_channel_ = true;
int8_t *batch_input_ptr_ = nullptr;
int8_t *batch_weight_ptr_ = nullptr;
int8_t *batch_b_ptr_ = nullptr;
int8_t *batch_c_ptr_ = nullptr;
int *batch_sums_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;
int deep_tile_ = C16NUM;
int channel_num_ = 0;
bool support_sdot_ = false;
PackFunc a_pack_func_{nullptr};
PackFunc b_pack_func_{nullptr};
};
} // namespace mindspore::kernel



Loading…
Cancel
Save