Merge pull request !26326 from zhaozhenlong/lite/issue/matmul-int8 (tag: v1.6.0)
| @@ -32,6 +32,69 @@ void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, | |||
| } | |||
| } | |||
| void RowMajor2Row4x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { | |||
| int row_div = row / C4NUM * C4NUM; | |||
| int col_4 = UP_ROUND(col, C4NUM); | |||
| int col_div = col / C4NUM * C4NUM; | |||
| const int8_t *src_r4 = src; | |||
| int8_t *packed_r4 = dst; | |||
| const int8_t *src_c4 = NULL; | |||
| int8_t *packed_c4 = NULL; | |||
| for (int r = 0; r < row_div; r += C4NUM) { | |||
| src_c4 = src_r4; | |||
| packed_c4 = packed_r4; | |||
| for (int c = 0; c < col_div; c += C4NUM) { | |||
| for (int i = 0; i < C4NUM; i++) { | |||
| packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; | |||
| packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; | |||
| packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; | |||
| packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; | |||
| } | |||
| src_c4 += C4NUM; | |||
| packed_c4 += C16NUM; | |||
| } | |||
| if (col == col_div) { | |||
| continue; | |||
| } | |||
| memset(packed_c4, 0, C16NUM * sizeof(int8_t)); | |||
| for (int i = 0; i < C4NUM; ++i) { | |||
| for (int c = 0; c < col - col_div; ++c) { | |||
| packed_c4[i * C4NUM + c] = src_c4[i * col + c]; | |||
| } | |||
| } | |||
| src_r4 += C4NUM * col; | |||
| packed_r4 += C4NUM * col_4; | |||
| } | |||
| if (row == row_div) { | |||
| return; | |||
| } | |||
| memset(packed_r4, 0, C4NUM * col_4); | |||
| src_c4 = src_r4; | |||
| packed_c4 = packed_r4; | |||
| for (int c = 0; c < col_div; c += C4NUM) { | |||
| for (int i = 0; i < row - row_div; ++i) { | |||
| packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; | |||
| packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; | |||
| packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; | |||
| packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; | |||
| } | |||
| src_c4 += C4NUM; | |||
| packed_c4 += C16NUM; | |||
| } | |||
| if (col == col_div) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < row - row_div; ++i) { | |||
| for (int c = 0; c < col - col_div; ++c) { | |||
| packed_c4[i * C4NUM + c] = src_c4[i * col + c]; | |||
| } | |||
| } | |||
| } | |||
| void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||
| int row16 = UP_ROUND(row, C16NUM); | |||
| int stride = sizeof(int8_t) * C16NUM * C2NUM; | |||
| @@ -525,7 +588,139 @@ void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, | |||
| return; | |||
| } | |||
| void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) { | |||
#ifdef ENABLE_ARM64
// AArch64 helper for PackInput2Col4x4AndInputSumPert (the "matmul input A
// transposed" packing path). Walks `row` rows (the caller passes row_div, a
// multiple of 4) of a 4-int8-wide column slice whose rows are `row_stride`
// bytes apart, interleaves each 4x4 tile with zip1/zip2 and stores it
// contiguously at packed_ic; per-column partial sums are widened to s32 and
// accumulated in v10.
// NOTE(review): input_sum is loaded into v12 at entry but the accumulated
// sums in v10 are never stored back to memory, filter_zp is never referenced
// inside the asm, and "add v10.4s, v10.4s, v10.4s" doubles the partial sums
// instead of adding another accumulator. Confirm against the upstream
// listing — this transcription looks truncated/garbled.
void PackInput2Col4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *packed_ic, int32_t *input_sum, int row,
                                           size_t row_stride, int32_t filter_zp) {
  asm volatile(
    "ld1 {v12.s}[0], [%[input_sum]]\n"  // seed accumulator with existing sums
    "mov w10, %w[row]\n"                // w10 = remaining rows
    "mov x11, %[src_ic]\n"
    "mov x12, %[packed_ic]\n"
    "sxtl v6.8h, v12.8b\n"              // widen seed s8 -> s16 -> s32
    "sxtl v12.4s, v6.4h\n"
    "cmp w10, wzr\n"
    "beq 1f\n"
    "2:\n"                              // per-4-row tile loop
    "subs w10, w10, #4\n"
    "ld1 {v0.s}[0], [x11], %[row_stride]\n"  // load 4 rows x 4 int8
    "ld1 {v1.s}[0], [x11], %[row_stride]\n"
    "ld1 {v0.s}[1], [x11], %[row_stride]\n"
    "ld1 {v1.s}[1], [x11], %[row_stride]\n"
    "zip1 v2.8b, v0.8b, v1.8b\n"        // 4x4 transpose via zip
    "zip2 v3.8b, v0.8b, v1.8b\n"
    "zip1 v4.4h, v2.4h, v3.4h\n"
    "zip2 v5.4h, v2.4h, v3.4h\n"
    "st1 {v4.4h, v5.4h}, [x12], #16\n"  // store packed 4x4 tile
    "sxtl v6.8h, v0.8b\n"               // widen rows for the column sums
    "sxtl v7.4s, v6.4h\n"
    "sxtl2 v8.4s, v6.8h\n"
    "sxtl v9.8h, v1.8b\n"
    "sxtl v10.4s, v9.4h\n"
    "sxtl2 v11.4s, v9.8h\n"
    "add v10.4s, v10.4s, v7.4s\n"
    "add v10.4s, v10.4s, v8.4s\n"
    "add v10.4s, v10.4s, v10.4s\n"      // NOTE(review): doubles sums — verify
    "add v10.4s, v10.4s, v11.4s\n"
    "bgt 2b\n"
    "1:\n"
    :
    : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ input_sum ] "r"(input_sum), [ row ] "r"(row),
      [ row_stride ] "r"(row_stride), [ filter_zp ] "r"(filter_zp)
    : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
  return;
}
#endif
// For matmul input a transpose case: packs a column-strided int8 source
// (logical [row x col], rows row_stride apart) into 4x4 tiles, column tiles
// laid out one after another with rows padded to row_align, while
// accumulating per-column input sums. On exit every input_sum[c] is
// multiplied by filter_zp.
// NOTE(review): assumes input_sum[] arrives zero-initialized by the caller
// (InitTmpBuffer memsets it) — confirm for other call sites.
void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row,
                                     int col, int row_stride, int32_t filter_zp) {
  const int row_tile = C4NUM;
  int row_align = UP_ROUND(row, row_tile);
  int row_div = row / row_tile * row_tile;      // rows handled in full 4-row tiles
  const int row_res = row - row_div;            // leftover rows (< 4)
  const int col_tile = C4NUM;
  int col_div = col / col_tile * col_tile;      // columns handled in full 4-col slices
  const int col_res = col - col_div;            // leftover columns (< 4)
  const int8_t *src_ic = NULL;
  int8_t *packed_ic = NULL;
  int32_t *tmp_sum = NULL;
  // Full 4-column slices.
  for (int c = 0; c < col_div; c += C4NUM) {
    int r = 0;
    src_ic = src_input + c;
    packed_ic = packed_input + c * row_align;
    tmp_sum = input_sum + c;
#ifdef ENABLE_ARM64
    // Asm packs all full 4-row tiles of this slice and updates the sums.
    PackInput2Col4x4AndInputSumPert_arm64(src_ic, packed_ic, tmp_sum, row_div, row_stride, filter_zp);
    packed_ic += C4NUM * row_div;
    src_ic += row_div * row_stride;
#else
    // Portable path: transpose 4x4 tiles and accumulate the column sums.
    for (; r < row_div; r += C4NUM) {
      for (int i = 0; i < row_tile; i++) {
        packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0];
        packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1];
        packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2];
        packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3];
        tmp_sum[0] += src_ic[i * row_stride + 0];
        tmp_sum[1] += src_ic[i * row_stride + 1];
        tmp_sum[2] += src_ic[i * row_stride + 2];
        tmp_sum[3] += src_ic[i * row_stride + 3];
      }
      packed_ic += C16NUM;
      src_ic += row_tile * row_stride;
    }
#endif
    // Row tail of this slice (destination tile is pre-zeroed by the caller).
    for (r = 0; r < row_res; ++r) {
      for (int i = 0; i < C4NUM; ++i) {
        packed_ic[i * row_tile + r] = src_ic[r * row_stride + i];
        tmp_sum[i] += src_ic[r * row_stride + i];
      }
    }
  }
  if (col_res == 0) {
    // No column tail: scale all sums by the filter zero point and finish.
    for (int i = 0; i < col; ++i) {
      input_sum[i] *= filter_zp;
    }
    return;
  }
  // Column tail: pack the last (narrow) slice.
  src_ic = src_input + col_div;
  packed_ic = packed_input + row_align * col_div;
  tmp_sum = input_sum + col_div;
  for (int r = 0; r < row_div; r += row_tile) {
    for (int i = 0; i < col_res; ++i) {
      packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i];
      packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i];
      packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i];
      packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i];
      tmp_sum[i] += src_ic[0 * row_stride + i];
      tmp_sum[i] += src_ic[1 * row_stride + i];
      tmp_sum[i] += src_ic[2 * row_stride + i];
      tmp_sum[i] += src_ic[3 * row_stride + i];
    }
    src_ic += row_tile * row_stride;
    packed_ic += row_tile * col_tile;
  }
  // Bottom-right corner: row tail of the column tail.
  for (int r = 0; r < row_res; ++r) {
    for (int c = 0; c < col_res; ++c) {
      packed_ic[c * row_tile + r] = src_ic[r * row_stride + c];
      tmp_sum[c] += src_ic[r * row_stride + c];
    }
  }
  // Scale every column sum by the filter zero point.
  for (int i = 0; i < col; ++i) {
    input_sum[i] *= filter_zp;
  }
}
| void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { | |||
| int row_16 = UP_ROUND(row, C16NUM); | |||
| int stride = sizeof(int8_t) * 16 * 4; | |||
| for (int r = 0; r < row_16; ++r) { | |||
| @@ -541,7 +736,54 @@ void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) | |||
| } | |||
| } | |||
| // dst: weight_zp * input_row_sums | |||
| void RowMajor2Col4x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) { | |||
| int row_4 = UP_ROUND(row, C4NUM); | |||
| int stride = C4NUM * C4NUM; | |||
| for (int r = 0; r < row_4; ++r) { | |||
| for (int c = 0; c < col; ++c) { | |||
| int stride_idx = c / C4NUM * (row_4 / C4NUM) + r / C4NUM; | |||
| if (r >= row) { | |||
| dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = 0; | |||
| } else { | |||
| int src_idx = r * col + c; | |||
| dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = src[src_idx]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc) { | |||
| int row_4 = UP_ROUND(row, C4NUM); | |||
| int stride = C16NUM * C4NUM; | |||
| for (int r = 0; r < row_4; ++r) { | |||
| for (int c = 0; c < cur_oc; ++c) { | |||
| int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; | |||
| if (r >= row) { | |||
| dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; | |||
| } else { | |||
| int src_idx = r * col + c; | |||
| dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc);

// Pack a row-major [row x col] int8 matrix into the Col4x16 layout used by the
// sdot matmul kernels, zero-padding rows up to a multiple of C4NUM. This is
// exactly the whole-matrix case of RowMajor2Col4x16MajorPartInt8
// (cur_oc == col), so delegate to it instead of duplicating the tiling math.
void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col) {
  RowMajor2Col4x16MajorPartInt8(src, dst, row, col, col);
}
| void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) { | |||
| for (int r = 0; r < row; ++r) { | |||
| int sum = 0; | |||
| @@ -576,3 +818,23 @@ void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, co | |||
| } | |||
| } | |||
| } | |||
| void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp, | |||
| const int *weight_zp_ptr, const int *bias, int *dst, DataOrder order, | |||
| bool filter_per_channel) { | |||
| for (int c = 0; c < cur_col; ++c) { | |||
| int sum = 0; | |||
| for (int r = 0; r < row; ++r) { | |||
| if (order == RowMajor) { | |||
| sum += weight[r * stride + c]; | |||
| } else { | |||
| sum += weight[c * row + r]; | |||
| } | |||
| } | |||
| int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0]; | |||
| dst[c] = row * input_zp * weight_zp - input_zp * sum; | |||
| if (bias != NULL) { | |||
| dst[c] += bias[c]; | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,14 +25,22 @@ | |||
| extern "C" { | |||
| #endif | |||
| /* 4x16 16x4 -> 4x4 */ | |||
| /* sdot 4x4 4x16 -> 4x16 */ | |||
| /* matmul */ | |||
| void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | |||
| const int *input_sum, const int *bias); | |||
| void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||
| void RowMajor2Col16x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst); | |||
| void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col); | |||
| void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col); | |||
| void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc); | |||
| void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row, | |||
| int col, int row_stride, int32_t filter_zp); | |||
| void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); | |||
| void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int *weight_zp_ptr, const int *bias, | |||
| int *dst, DataOrder order, bool filter_per_channel); | |||
| void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp, | |||
| const int *weight_zp_ptr, const int *bias, int *dst, DataOrder order, | |||
| bool filter_per_channel); | |||
| void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, | |||
| const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier, | |||
| const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, | |||
| @@ -54,6 +54,7 @@ typedef struct MatMulParameter { | |||
| int deep_; | |||
| int deep_4_; | |||
| int deep_16_; | |||
| int deep_align_; | |||
| int batch; | |||
| bool a_transpose_; /* false : row-major */ | |||
| bool b_transpose_; /* true : col-major */ | |||
| @@ -225,7 +225,7 @@ int MatMulBaseInt8Coder::DoCode(CoderContext *const context) { | |||
| for (int i = 0; i < param_->batch; i++) { | |||
| std::string current_src_a = a_ptr_str + "+" + std::to_string(i * param_->row_ * param_->deep_); | |||
| if (param_->a_transpose_) { | |||
| code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, param_->deep_, param_->row_, pack_a_ptr_); | |||
| code.CodeFunction("RowMajor2Col16x4MajorInt8", current_src_a, pack_a_ptr_, param_->deep_, param_->row_); | |||
| code.CodeFunction("CalcInputSums", current_src_a, param_->row_, param_->deep_, "tmp_weight_zp", input_sums_, | |||
| ColMajor); | |||
| } else { | |||
| @@ -73,13 +73,14 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete | |||
| } | |||
| void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) { | |||
| CodeBaseStruct( | |||
| "MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, mat_mul_parameter.row_, | |||
| mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_, mat_mul_parameter.row_12_, | |||
| mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, mat_mul_parameter.col_4_, mat_mul_parameter.col_8_, | |||
| mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_, | |||
| mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_, | |||
| mat_mul_parameter.b_const_, mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_); | |||
| CodeBaseStruct("MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, | |||
| mat_mul_parameter.row_, mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_6_, | |||
| mat_mul_parameter.row_12_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, | |||
| mat_mul_parameter.col_4_, mat_mul_parameter.col_8_, mat_mul_parameter.col_align_, | |||
| mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, mat_mul_parameter.deep_16_, | |||
| mat_mul_parameter.deep_align_, mat_mul_parameter.batch, mat_mul_parameter.a_transpose_, | |||
| mat_mul_parameter.b_transpose_, mat_mul_parameter.a_const_, mat_mul_parameter.b_const_, | |||
| mat_mul_parameter.act_type_, mat_mul_parameter.use_axis_, mat_mul_parameter.axis_); | |||
| } | |||
| void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleParameter &scale_parameter) { | |||
| @@ -62,13 +62,13 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConvParamete | |||
| } | |||
| void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatMulParameter &matmul_parameter) { | |||
| CodeBaseStruct<false>("MatMulParameter", name, matmul_parameter.op_parameter_, matmul_parameter.has_bias_, | |||
| matmul_parameter.row_, matmul_parameter.col_, matmul_parameter.row_4_, matmul_parameter.row_6_, | |||
| matmul_parameter.row_12_, matmul_parameter.row_16_, matmul_parameter.row_align_, | |||
| matmul_parameter.col_4_, matmul_parameter.col_8_, matmul_parameter.col_align_, | |||
| matmul_parameter.deep_, matmul_parameter.deep_4_, matmul_parameter.deep_16_, | |||
| matmul_parameter.batch, matmul_parameter.a_transpose_, matmul_parameter.b_transpose_, | |||
| matmul_parameter.a_const_, matmul_parameter.b_const_, matmul_parameter.act_type_); | |||
| CodeBaseStruct<false>( | |||
| "MatMulParameter", name, matmul_parameter.op_parameter_, matmul_parameter.has_bias_, matmul_parameter.row_, | |||
| matmul_parameter.col_, matmul_parameter.row_4_, matmul_parameter.row_6_, matmul_parameter.row_12_, | |||
| matmul_parameter.row_16_, matmul_parameter.row_align_, matmul_parameter.col_4_, matmul_parameter.col_8_, | |||
| matmul_parameter.col_align_, matmul_parameter.deep_, matmul_parameter.deep_4_, matmul_parameter.deep_16_, | |||
| matmul_parameter.deep_align_, matmul_parameter.batch, matmul_parameter.a_transpose_, matmul_parameter.b_transpose_, | |||
| matmul_parameter.a_const_, matmul_parameter.b_const_, matmul_parameter.act_type_); | |||
| } | |||
| void NNaclInt8Serializer::CodeStruct(const std::string &name, const AddQuantParameter &add_quant_parameter) { | |||
| @@ -21,7 +21,7 @@ void InitInt8MatrixA(int8_t *src_ptr, int32_t *input_sums, int8_t *dst_ptr, int | |||
| for (int i = 0; i < batch; ++i) { | |||
| int8_t *cur_a_ptr = src_ptr + i * row * deep; | |||
| if (a_transpose) { | |||
| RowMajor2Col16x4MajorInt8(cur_a_ptr, deep, row, dst_ptr); | |||
| RowMajor2Col16x4MajorInt8(cur_a_ptr, dst_ptr, deep, row); | |||
| CalcInputSums(cur_a_ptr, row, deep, *weight_zp, input_sums, ColMajor); | |||
| } else { | |||
| RowMajor2Row16x4MajorInt8(cur_a_ptr, dst_ptr, row, deep); | |||
| @@ -48,7 +48,7 @@ void InitInt8MatrixB(int8_t *weight_ptr, int32_t *weight_bias_sums_batch_, int8_ | |||
| #ifdef ENABLE_ARM32 | |||
| RowMajor2Col16x2MajorInt8(cur_b, cur_b_pack, deep, col); | |||
| #else | |||
| RowMajor2Col16x4MajorInt8(cur_b, deep, col, cur_b_pack); | |||
| RowMajor2Col16x4MajorInt8(cur_b, cur_b_pack, deep, col); | |||
| #endif | |||
| CalcWeightBiasSums(cur_b, deep, col, input_zp, weight_zp, bias_ptr, cur_sums, RowMajor, false); | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "src/runtime/kernel/arm/int8/matmul_base_int8.h" | |||
| #include "src/runtime/kernel/arm/int8/opt_op_handler.h" | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_MEMORY_FAILED; | |||
| @@ -32,6 +33,97 @@ int MatmulBaseInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale | |||
| return RET_OK; | |||
| } | |||
| #ifdef ENABLE_ARM64 | |||
| int Arm64SdotPreRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { | |||
| CHECK_NULL_RETURN(cdata); | |||
| auto op = reinterpret_cast<MatmulBaseInt8CPUKernel *>(cdata); | |||
| auto ret = op->Arm64SdotPre(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int Arm64SdotRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) { | |||
| CHECK_NULL_RETURN(cdata); | |||
| auto op = reinterpret_cast<MatmulBaseInt8CPUKernel *>(cdata); | |||
| auto ret = op->Arm64SdotImpl(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "MatmulInt8Run error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| return ret; | |||
| } | |||
| return RET_OK; | |||
| } | |||
// Per-task packing stage for the ARM64 sdot path: each task packs a
// row-tile-aligned slice of matrix A (batch_input_ptr_) into pack_a_ptr_ and
// accumulates the corresponding input row sums into input_sums_.
int MatmulBaseInt8CPUKernel::Arm64SdotPre(int task_id) {
  // Split row_align_ into row-tile-aligned chunks across the worker threads.
  int row_thread_count = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->row_align_, row_tile_));
  int row_stride = UP_DIV(UP_DIV(param_->row_align_, row_tile_), row_thread_count) * row_tile_;
  int row_current_stride = task_id * row_stride;
  int row_res_stride = param_->row_ - row_current_stride;
  int cur_r = MSMIN(row_res_stride, row_stride);
  if (cur_r <= 0) {
    // This task falls entirely into the padding region: nothing to pack.
    return RET_OK;
  }
  // Per-channel filter zp: pack with a neutral factor of 1 here (the real
  // per-channel zp is presumably applied later via cur_zp in MatmulInt8DpOpt
  // — confirm); per-tensor uses the single zp immediately.
  int tmp_weight_zp = filter_per_channel_ ? 1 : quant_param_->filter_zp_[0];
  auto current_a_pack = pack_a_ptr_ + row_current_stride * param_->deep_align_;
  if (param_->a_transpose_) {
    // A stored transposed (deep x row): slice offset moves along source
    // columns, so the row stride passed to the packer is param_->row_.
    auto current_src_a = batch_input_ptr_ + row_current_stride;
    PackInput2Col4x4AndInputSumPert(current_src_a, current_a_pack, input_sums_ + row_current_stride, param_->deep_,
                                    cur_r, param_->row_, tmp_weight_zp);
  } else {
    // A stored row-major (row x deep): slice offset is whole rows.
    auto current_src_a = batch_input_ptr_ + row_current_stride * param_->deep_;
    PackInput4x4AndInputSumPert(current_src_a, current_a_pack, input_sums_ + row_current_stride, param_->deep_, cur_r,
                                tmp_weight_zp);
  }
  return RET_OK;
}
// Per-task compute stage for the ARM64 sdot path: each task owns a
// col-tile-aligned slice of the output columns. When B is not constant, its
// slice is packed (and the weight bias/zero-point sums recomputed) on the fly
// before invoking the sdot matmul kernel on the slice.
int MatmulBaseInt8CPUKernel::Arm64SdotImpl(int task_id) {
  int stride = thread_stride_ * col_tile_;
  int cur_stride = task_id * stride;
  int res_stride = param_->col_ - cur_stride;
  int cur_oc = MSMIN(stride, res_stride);
  if (cur_oc <= 0) {
    // This task falls entirely into the padding region: nothing to compute.
    return RET_OK;
  }
  if (param_->b_const_ == false) {
    // Non-const weights: pack this batch's B slice and recompute its sums.
    auto current_sums = batch_sums_ + cur_stride;
    auto current_b_pack = batch_b_ptr_ + cur_stride * param_->deep_align_;
    // Per-channel quant params are indexed by output column; per-tensor uses
    // the single element at index 0.
    auto current_filter_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;
    auto current_bias = bias_ptr_ == nullptr ? nullptr : bias_ptr_ + cur_stride;
    if (param_->b_transpose_) {
      // B stored col x deep: the slice offset is whole weight rows.
      auto current_weight = batch_weight_ptr_ + cur_stride * param_->deep_;
      RowMajor2Row4x16MajorInt8(current_weight, current_b_pack, cur_oc, param_->deep_);
      CalcPartWeightBiasSums(current_weight, param_->deep_, param_->col_, cur_oc, quant_param_->input_.zp_,
                             current_filter_zp, current_bias, current_sums, ColMajor, filter_per_channel_);
    } else {
      // B stored deep x col: offset by columns and pack only cur_oc of them.
      auto current_weight = batch_weight_ptr_ + cur_stride;
      RowMajor2Col4x16MajorPartInt8(current_weight, current_b_pack, param_->deep_, param_->col_, cur_oc);
      CalcPartWeightBiasSums(current_weight, param_->deep_, param_->col_, cur_oc, quant_param_->input_.zp_,
                             current_filter_zp, current_bias, current_sums, RowMajor, filter_per_channel_);
    }
  }
  // Slice the requantization parameters the same way as the filter zp above.
  int32_t *cur_left = filter_per_channel_ ? quant_param_->left_shift_ + cur_stride : quant_param_->left_shift_;
  int32_t *cur_right = filter_per_channel_ ? quant_param_->right_shift_ + cur_stride : quant_param_->right_shift_;
  int32_t *cur_mul =
    filter_per_channel_ ? quant_param_->quant_multiplier_ + cur_stride : quant_param_->quant_multiplier_;
  int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_;
  MatmulInt8DpOpt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_,
                  cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_,
                  quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_,
                  filter_per_channel_, cur_zp);
  return RET_OK;
}
| #endif | |||
| int MatmulBaseInt8CPUKernel::RunImpl(int task_id) { | |||
| int stride = thread_stride_ * col_tile_; | |||
| int cur_stride = task_id * stride; | |||
| @@ -47,8 +139,8 @@ int MatmulBaseInt8CPUKernel::RunImpl(int task_id) { | |||
| filter_per_channel_ ? quant_param_->quant_multiplier_ + cur_stride : quant_param_->quant_multiplier_; | |||
| int32_t *cur_zp = filter_per_channel_ ? quant_param_->filter_zp_ + cur_stride : quant_param_->filter_zp_; | |||
| MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_16_, batch_c_ptr_ + cur_stride, param_->row_, | |||
| cur_oc, param_->deep_16_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_, | |||
| MatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, batch_c_ptr_ + cur_stride, param_->row_, | |||
| cur_oc, param_->deep_align_, input_sums_, weight_bias_sums_ + cur_stride, quant_param_->out_act_min_, | |||
| quant_param_->out_act_max_, quant_param_->output_.zp_, cur_mul, cur_left, cur_right, param_->col_, | |||
| filter_per_channel_, cur_zp); | |||
| @@ -59,11 +151,6 @@ MatmulBaseInt8CPUKernel::~MatmulBaseInt8CPUKernel() { | |||
| FreeQuantParam(); | |||
| FreeTmpBuffer(); | |||
| if (bias_ptr_ != nullptr) { | |||
| free(bias_ptr_); | |||
| bias_ptr_ = nullptr; | |||
| } | |||
| } | |||
| void MatmulBaseInt8CPUKernel::FreeQuantParam() { | |||
| @@ -180,17 +267,59 @@ void MatmulBaseInt8CPUKernel::InitParameter() { | |||
| #ifdef ENABLE_ARM32 | |||
| row_tile_ = C4NUM; | |||
| col_tile_ = C2NUM; | |||
| deep_tile_ = C16NUM; | |||
| #elif ENABLE_ARM64 | |||
| support_sdot_ = mindspore::lite::IsSupportSDot(); | |||
| row_tile_ = C4NUM; | |||
| if (support_sdot_) { | |||
| col_tile_ = C16NUM; | |||
| deep_tile_ = C4NUM; | |||
| } else { | |||
| col_tile_ = C4NUM; | |||
| deep_tile_ = C16NUM; | |||
| } | |||
| #else | |||
| row_tile_ = C4NUM; | |||
| col_tile_ = C4NUM; | |||
| deep_tile_ = C16NUM; | |||
| #endif | |||
| if (param_->a_transpose_) { | |||
| a_pack_func_ = RowMajor2Col16x4MajorInt8; | |||
| } else { | |||
| a_pack_func_ = RowMajor2Row16x4MajorInt8; | |||
| } | |||
| if (param_->b_transpose_) { | |||
| #ifdef ENABLE_ARM32 | |||
| b_pack_func_ = RowMajor2Row2x16MajorInt8; | |||
| #elif ENABLE_ARM64 | |||
| if (support_sdot_) { | |||
| b_pack_func_ = RowMajor2Row4x16MajorInt8; | |||
| } else { | |||
| b_pack_func_ = RowMajor2Row16x4MajorInt8; | |||
| } | |||
| #else | |||
| b_pack_func_ = RowMajor2Row16x4MajorInt8; | |||
| #endif | |||
| } else { | |||
| #ifdef ENABLE_ARM32 | |||
| b_pack_func_ = RowMajor2Col16x2MajorInt8; | |||
| #elif ENABLE_ARM64 | |||
| if (support_sdot_) { | |||
| b_pack_func_ = RowMajor2Col4x16MajorInt8; | |||
| } else { | |||
| b_pack_func_ = RowMajor2Col16x4MajorInt8; | |||
| } | |||
| #else | |||
| b_pack_func_ = RowMajor2Col16x4MajorInt8; | |||
| #endif | |||
| } | |||
| return; | |||
| } | |||
| void MatmulBaseInt8CPUKernel::ResizeParameter() { | |||
| param_->row_align_ = UP_ROUND(param_->row_, row_tile_); | |||
| param_->col_align_ = UP_ROUND(param_->col_, col_tile_); | |||
| param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM); | |||
| param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_); | |||
| thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_)); | |||
| thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_); | |||
| @@ -221,37 +350,30 @@ void MatmulBaseInt8CPUKernel::TransferB() { | |||
| auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data()); | |||
| for (int i = 0; i < param_->batch; i++) { | |||
| auto current_weight = weight_data + i * param_->deep_ * param_->col_; | |||
| auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_; | |||
| auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; | |||
| auto current_sums = weight_bias_sums_ + i * param_->col_align_; | |||
| MS_CHECK_PTR_IF_NULL(b_pack_func_); | |||
| if (param_->b_transpose_) { | |||
| #ifdef ENABLE_ARM32 | |||
| RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_); | |||
| #else | |||
| RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_); | |||
| #endif | |||
| b_pack_func_(current_weight, current_b_pack, param_->col_, param_->deep_); | |||
| CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_, | |||
| quant_param_->filter_zp_, bias_ptr_, current_sums, ColMajor, filter_per_channel_); | |||
| } else { | |||
| #ifdef ENABLE_ARM32 | |||
| RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_); | |||
| #else | |||
| RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack); | |||
| #endif | |||
| b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_); | |||
| CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_param_->input_.zp_, | |||
| quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, false); | |||
| quant_param_->filter_zp_, bias_ptr_, current_sums, RowMajor, filter_per_channel_); | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| int MatmulBaseInt8CPUKernel::InitTmpBuffer() { | |||
| pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_16_ * sizeof(int8_t))); | |||
| pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t))); | |||
| if (pack_a_ptr_ == nullptr) { | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| } | |||
| pack_b_ptr_ = | |||
| reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t))); | |||
| reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t))); | |||
| if (pack_b_ptr_ == nullptr) { | |||
| FreeTmpBuffer(); | |||
| return RET_ERROR; | |||
| @@ -267,8 +389,8 @@ int MatmulBaseInt8CPUKernel::InitTmpBuffer() { | |||
| return RET_ERROR; | |||
| } | |||
| memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_16_ * sizeof(int8_t)); | |||
| memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_16_ * sizeof(int8_t)); | |||
| memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t)); | |||
| memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); | |||
| memset(input_sums_, 0, param_->row_align_ * sizeof(int)); | |||
| memset(weight_bias_sums_, 0, param_->batch * param_->col_align_ * sizeof(int)); | |||
| @@ -278,9 +400,7 @@ int MatmulBaseInt8CPUKernel::InitTmpBuffer() { | |||
| int MatmulBaseInt8CPUKernel::InitBias() { | |||
| if (in_tensors_.size() == kInputSize2) { | |||
| auto bias_tensor = in_tensors_[kBiasIndex]; | |||
| MS_CHECK_GT(bias_tensor->ElementsNum(), 0, RET_ERROR); | |||
| int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C4NUM); | |||
| bias_ptr_ = reinterpret_cast<int *>(malloc(max_bias_data * sizeof(int))); | |||
| bias_ptr_ = reinterpret_cast<int *>(bias_tensor->data()); | |||
| if (bias_ptr_ == nullptr) { | |||
| MS_LOG(ERROR) << "Memory allocation failed"; | |||
| FreeTmpBuffer(); | |||
| @@ -332,11 +452,47 @@ int MatmulBaseInt8CPUKernel::ReSize() { | |||
| return RET_OK; | |||
| } | |||
#ifdef ENABLE_ARM64
// Batched driver for the SDOT-capable ARM64 int8 matmul: per batch it first
// packs A and the input sums in parallel (Arm64SdotPreRun), then launches the
// column-sliced compute tasks (Arm64SdotRun), which also pack B when it is
// not constant.
int MatmulBaseInt8CPUKernel::RunArm64Sdot() {
  int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
  int8_t *b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
  int8_t *c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data());
  CHECK_NULL_RETURN(a_ptr);
  CHECK_NULL_RETURN(b_ptr);
  CHECK_NULL_RETURN(c_ptr);
  for (int i = 0; i < param_->batch; i++) {
    // Stage 1: pack this batch's A slice across all threads.
    batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_;
    auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "RunArm64Sdot error: [" << ret << "]";
      return ret;
    }
    // Stage 2: point the batch members at this batch's B/C/sums and compute.
    batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_;
    batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
    batch_sums_ = weight_bias_sums_ + i * param_->col_align_;
    batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
    ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "RunArm64Sdot error: [" << ret << "]";
      return ret;
    }
  }
  return RET_OK;
}
#endif
| int MatmulBaseInt8CPUKernel::Run() { | |||
| #ifdef ENABLE_ARM64 | |||
| if (support_sdot_) { | |||
| return RunArm64Sdot(); | |||
| } | |||
| #endif | |||
| if (param_->b_const_ == false) { | |||
| TransferB(); | |||
| } | |||
| int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data()); | |||
| int8_t *c_ptr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data()); | |||
| CHECK_NULL_RETURN(a_ptr); | |||
| @@ -345,14 +501,16 @@ int MatmulBaseInt8CPUKernel::Run() { | |||
| for (int i = 0; i < param_->batch; i++) { | |||
| auto current_src_a = a_ptr + i * param_->row_ * param_->deep_; | |||
| if (param_->a_transpose_) { | |||
| RowMajor2Col16x4MajorInt8(current_src_a, param_->deep_, param_->row_, pack_a_ptr_); | |||
| MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR); | |||
| a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_); | |||
| CalcInputSums(current_src_a, param_->row_, param_->deep_, tmp_weight_zp, input_sums_, ColMajor); | |||
| } else { | |||
| RowMajor2Row16x4MajorInt8(current_src_a, pack_a_ptr_, param_->row_, param_->deep_); | |||
| MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR); | |||
| a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_); | |||
| CalcInputSums(current_src_a, param_->row_, param_->deep_, tmp_weight_zp, input_sums_, RowMajor); | |||
| } | |||
| batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_; | |||
| batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; | |||
| batch_sums_ = weight_bias_sums_ + i * param_->col_align_; | |||
| batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_; | |||
| @@ -29,6 +29,8 @@ | |||
| namespace mindspore::kernel { | |||
| class MatmulBaseInt8CPUKernel : public InnerKernel { | |||
| typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col); | |||
| public: | |||
| MatmulBaseInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||
| @@ -42,6 +44,11 @@ class MatmulBaseInt8CPUKernel : public InnerKernel { | |||
| public: | |||
| int RunImpl(int task_id); | |||
| #ifdef ENABLE_ARM64 | |||
| int RunArm64Sdot(); | |||
| int Arm64SdotImpl(int task_id); | |||
| int Arm64SdotPre(int task_id); | |||
| #endif | |||
| protected: | |||
| void InitParameter(); | |||
| @@ -72,12 +79,18 @@ class MatmulBaseInt8CPUKernel : public InnerKernel { | |||
| int *weight_bias_sums_ = nullptr; | |||
| int *bias_ptr_ = nullptr; | |||
| bool filter_per_channel_ = true; | |||
| int8_t *batch_input_ptr_ = nullptr; | |||
| int8_t *batch_weight_ptr_ = nullptr; | |||
| int8_t *batch_b_ptr_ = nullptr; | |||
| int8_t *batch_c_ptr_ = nullptr; | |||
| int *batch_sums_ = nullptr; | |||
| int row_tile_ = C4NUM; | |||
| int col_tile_ = C4NUM; | |||
| int deep_tile_ = C16NUM; | |||
| int channel_num_ = 0; | |||
| bool support_sdot_ = false; | |||
| PackFunc a_pack_func_{nullptr}; | |||
| PackFunc b_pack_func_{nullptr}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||