| @@ -15,33 +15,7 @@ | |||||
| */ | */ | ||||
| #include "nnacl/int8/matmul_int8.h" | #include "nnacl/int8/matmul_int8.h" | ||||
| #include <limits.h> | |||||
| #include "nnacl/quantization/fixed_point.h" | #include "nnacl/quantization/fixed_point.h" | ||||
| void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||||
| for (int r = 0; r < row; r++) { | |||||
| int8_t *src = src_ptr + r * col; | |||||
| for (int c = 0; c < col; c++) { | |||||
| int cd8 = c / 8; | |||||
| int cm8 = c % 8; | |||||
| dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||||
| int col16 = UP_ROUND(col, C16NUM); | |||||
| for (int r = 0; r < row; r++) { | |||||
| int rd4 = r / C4NUM; | |||||
| int rm4 = r % C4NUM; | |||||
| for (int c = 0; c < col; c++) { | |||||
| int cd16 = c / C16NUM; | |||||
| int cm16 = c % C16NUM; | |||||
| int dst_index = rd4 * col16 * C4NUM + cd16 * C4NUM * C16NUM + rm4 * C16NUM + cm16; | |||||
| int src_index = r * col + c; | |||||
| dst_ptr[dst_index] = src_ptr[src_index]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | ||||
| int col16 = UP_ROUND(col, C16NUM); | int col16 = UP_ROUND(col, C16NUM); | ||||
| @@ -90,22 +64,7 @@ void MatrixEmptyInt8(int8_t *dst, int row, int col) { | |||||
| return; | return; | ||||
| } | } | ||||
| void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||||
| /* Row-major to row16x4-major (block row-major) */ | |||||
| int col4 = UP_ROUND(col, C4NUM); | |||||
| for (int r = 0; r < row; r++) { | |||||
| int rd8 = r / C8NUM, rm8 = r % C8NUM; | |||||
| for (int c = 0; c < col; c++) { | |||||
| int cd4 = c / C4NUM, cm4 = c % C4NUM; | |||||
| int src_index = r * col + c; | |||||
| int dst_index = rd8 * col4 * C8NUM + cd4 * C4NUM * C8NUM + rm8 * C4NUM + cm4; | |||||
| dst_ptr[dst_index] = src_ptr[src_index]; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { | |||||
| void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||||
| /* Row-major to row16x4-major (block row-major) */ | /* Row-major to row16x4-major (block row-major) */ | ||||
| int col16 = UP_ROUND(col, C16NUM); | int col16 = UP_ROUND(col, C16NUM); | ||||
| size_t row_4div = row / C4NUM * C4NUM; | size_t row_4div = row / C4NUM * C4NUM; | ||||
| @@ -185,16 +144,6 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { | |||||
| return; | return; | ||||
| } | } | ||||
| void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { | |||||
| for (int r = 0; r < row; r++) { | |||||
| int rd8 = r / 8; | |||||
| int rm8 = r % 8; | |||||
| for (int c = 0; c < col; c++) { | |||||
| dst_ptr[rd8 * col * 8 + c * 8 + rm8] = src_ptr[r * col + c]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | ||||
| const int *input_sum, const int *bias) { | const int *input_sum, const int *bias) { | ||||
| /* row4x16-major * row16x4-major => row4x4-major */ | /* row4x16-major * row16x4-major => row4x4-major */ | ||||
| @@ -319,47 +268,6 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, | |||||
| return; | return; | ||||
| } | } | ||||
| /* row4x16-major * col16x4-major => row4x4-major */ | |||||
| void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min, | |||||
| int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16, | |||||
| int stride) { | |||||
| int8_t *output = dst; | |||||
| for (int r = 0; r < row; r++) { | |||||
| for (int c = 0; c < col; c++) { | |||||
| int r4div = r / C4NUM; | |||||
| int r4mod = r % C4NUM; | |||||
| int c4div = c / C4NUM; | |||||
| int c4mod = c % C4NUM; | |||||
| int value = 0; | |||||
| for (int d = 0; d < deep16; d++) { | |||||
| int d16div = d / C16NUM; | |||||
| int d16mod = d % C16NUM; | |||||
| size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; | |||||
| size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; | |||||
| value += a[ai] * b[bi]; | |||||
| } | |||||
| value -= a_sums[r]; | |||||
| value += bias[c]; | |||||
| value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp; | |||||
| value = MSMIN(INT8_MAX, value); | |||||
| value = MSMAX(INT8_MIN, value); | |||||
| output[c] = (int8_t)value; | |||||
| } | |||||
| output += stride; | |||||
| } | |||||
| } | |||||
| void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) { | |||||
| int stride = sizeof(int8_t) * 16 * 4; | |||||
| for (int r = 0; r < row; ++r) { | |||||
| for (int c = 0; c < col; ++c) { | |||||
| int stride_n = r / 4 * (col_16 / 16) + c / 16; | |||||
| int src_idx = r * col + c; | |||||
| dst[stride * stride_n + r % 4 * 16 + c % 16] = src[src_idx]; | |||||
| } | |||||
| } | |||||
| } | |||||
| void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) { | void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) { | ||||
| int stride = sizeof(int8_t) * 16 * 4; | int stride = sizeof(int8_t) * 16 * 4; | ||||
| for (int r = 0; r < row; ++r) { | for (int r = 0; r < row; ++r) { | ||||
| @@ -405,14 +313,3 @@ void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weig | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow) { | |||||
| int stride = sizeof(int8_t) * 4 * 4; | |||||
| for (int r = 0; r < row; ++r) { | |||||
| for (int c = 0; c < cow; ++c) { | |||||
| int sride_n = c / 4 * (row4 / 4) + r / 4; | |||||
| int dst_idx = r * cow + c; | |||||
| dst[dst_idx] = src[stride * sride_n + r % 4 * 4 + c % 4]; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -25,32 +25,27 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| /* 4x16 16x4 -> 4x4 */ | |||||
| void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, | ||||
| const int *input_sum, const int *bias); | const int *input_sum, const int *bias); | ||||
| void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, | void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | ||||
| bool per_channel); | bool per_channel); | ||||
| void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); | |||||
| void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16); | |||||
| void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); | |||||
| void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst, | |||||
| DataOrder order); | |||||
| /* 8x4 4x8 -> 8x8 */ | |||||
| void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, | ||||
| size_t per_channel); | size_t per_channel); | ||||
| void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | |||||
| void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); | |||||
| void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16); | |||||
| void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); | |||||
| void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst, | |||||
| DataOrder order); | |||||
| void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min, | |||||
| int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16, | |||||
| int stride); | |||||
| /* 4x16 16x2 -> 4x2 */ | |||||
| void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); | ||||
| void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, | void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, | ||||
| size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, | ||||
| @@ -27,8 +27,6 @@ typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, | |||||
| int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, | ||||
| int32_t maxi, size_t per_channel); | int32_t maxi, size_t per_channel); | ||||
| typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col); | |||||
| typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType; | typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType; | ||||
| typedef struct MatMulParameter { | typedef struct MatMulParameter { | ||||
| @@ -158,7 +158,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { | |||||
| RowMajor2Row8x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel, | RowMajor2Row8x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel, | ||||
| input_channel); | input_channel); | ||||
| } else { | } else { | ||||
| RowMajor2Row4x16MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel, | |||||
| RowMajor2Row16x4MajorInt8(reinterpret_cast<int8_t *>(filter_tensor->MutableData()), packed_weight_, output_channel, | |||||
| input_channel); | input_channel); | ||||
| } | } | ||||
| @@ -207,7 +207,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() { | |||||
| memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); | memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); | ||||
| } | } | ||||
| InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, UP_ROUND(output_channel, C2NUM)); | |||||
| InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, col2); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -125,8 +125,6 @@ int DeConvInt8CPUKernel::InitParam() { | |||||
| matmul_param_->deep_ = conv_param_->input_channel_; | matmul_param_->deep_ = conv_param_->input_channel_; | ||||
| matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; | matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; | ||||
| /* optimize normal -> same data layout */ | |||||
| input_trans_func_ = RowMajor2Row16x4MajorInt8; | |||||
| int oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); | int oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); | ||||
| thread_count_ = MSMIN(op_parameter_->thread_num_, oc4); | thread_count_ = MSMIN(op_parameter_->thread_num_, oc4); | ||||
| thread_stride_ = UP_DIV(oc4, thread_count_); | thread_stride_ = UP_DIV(oc4, thread_count_); | ||||
| @@ -275,8 +273,8 @@ int DeConvInt8CPUKernel::Run() { | |||||
| } | } | ||||
| for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { | for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { | ||||
| input_trans_func_(src_in + batch_index * matmul_param_->row_ * conv_param_->input_channel_, input_ptr_, | |||||
| matmul_param_->row_, matmul_param_->deep_); | |||||
| RowMajor2Row16x4MajorInt8(src_in + batch_index * matmul_param_->row_ * conv_param_->input_channel_, input_ptr_, | |||||
| matmul_param_->row_, matmul_param_->deep_); | |||||
| output_ptr_ = src_out + batch_index * matmul_param_->col_; | output_ptr_ = src_out + batch_index * matmul_param_->col_; | ||||
| DeConvPackInputSum(input_ptr_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, | DeConvPackInputSum(input_ptr_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, | ||||
| @@ -65,7 +65,6 @@ class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel { | |||||
| size_t thread_count_ = 1; | size_t thread_count_ = 1; | ||||
| size_t thread_stride_ = 0; | size_t thread_stride_ = 0; | ||||
| MATMUL_OPT_R4_FUNC matmul_func_; | MATMUL_OPT_R4_FUNC matmul_func_; | ||||
| MAT_TRANS_FUNC input_trans_func_; | |||||
| MatMulParameter *matmul_param_ = nullptr; | MatMulParameter *matmul_param_ = nullptr; | ||||
| bool support_optimize_ = true; | bool support_optimize_ = true; | ||||
| }; | }; | ||||
| @@ -57,7 +57,7 @@ int FullconnectionInt8CPUKernel::ReSize() { | |||||
| if (!weight_bias_sums_) return RET_MEMORY_FAILED; | if (!weight_bias_sums_) return RET_MEMORY_FAILED; | ||||
| memset(weight_bias_sums_, 0, c4_ * sizeof(int)); | memset(weight_bias_sums_, 0, c4_ * sizeof(int)); | ||||
| auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData()); | auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->MutableData()); | ||||
| RowMajor2Row4x16Major(weight_data, fc_param_->col_, fc_param_->deep_, b_c16x4_ptr_, d16_); | |||||
| RowMajor2Row16x4MajorInt8(weight_data, b_c16x4_ptr_, fc_param_->col_, fc_param_->deep_); | |||||
| if (in_tensors_.size() == 3) { | if (in_tensors_.size() == 3) { | ||||
| auto bias_len = fc_param_->col_8_ * sizeof(int); | auto bias_len = fc_param_->col_8_ * sizeof(int); | ||||
| bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len)); | bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len)); | ||||
| @@ -111,8 +111,8 @@ int FullconnectionInt8CPUKernel::RunImpl(int task_id) { | |||||
| q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res, | q.out_act_max, q.output.zp_, &q.quant_multiplier, &q.left_shift, &q.right_shift, p->row_, cur_oc_res, | ||||
| p->col_ * sizeof(int8_t), 0); | p->col_ * sizeof(int8_t), 0); | ||||
| #else | #else | ||||
| MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_, | |||||
| q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_); | |||||
| MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, p->row_, cur_oc_res, d16_, p->col_, input_sums_, cur_bias, | |||||
| &q.left_shift, &q.right_shift, &q.quant_multiplier, q.output.zp_, INT8_MIN, INT8_MAX, false); | |||||
| #endif | #endif | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -135,7 +135,7 @@ int FullconnectionInt8CPUKernel::Run() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData()); | auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData()); | ||||
| RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_); | |||||
| RowMajor2Row16x4MajorInt8(input_ptr, a_r4x16_ptr_, fc_param_->row_, fc_param_->deep_); | |||||
| CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor); | CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor); | ||||
| ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_); | ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -104,8 +104,8 @@ int MatmulInt8CPUKernel::RunImpl(int task_id) { | |||||
| p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res, | p.output.zp_, &p.quant_multiplier, &p.left_shift, &p.right_shift, params_->row_, cur_oc_res, | ||||
| params_->col_ * sizeof(int8_t), false); | params_->col_ * sizeof(int8_t), false); | ||||
| #else | #else | ||||
| MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier, | |||||
| p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_); | |||||
| MatMulInt8_16x4_r(a_r4x16_ptr_, cur_b, cur_c, params_->row_, cur_oc_res, d16_, params_->col_, input_sums_, cur_bias, | |||||
| &p.left_shift, &p.right_shift, &p.quant_multiplier, p.output.zp_, INT8_MIN, INT8_MAX, false); | |||||
| #endif | #endif | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -142,11 +142,11 @@ int MatmulInt8CPUKernel::Run() { | |||||
| RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_); | RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_); | ||||
| CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor); | CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor); | ||||
| } else { | } else { | ||||
| RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_); | |||||
| RowMajor2Row16x4MajorInt8(cur_a_ptr, a_r4x16_ptr_, params_->row_, params_->deep_); | |||||
| CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor); | CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor); | ||||
| } | } | ||||
| if (params_->b_transpose_) { | if (params_->b_transpose_) { | ||||
| RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_); | |||||
| RowMajor2Row16x4MajorInt8(cur_b_ptr, b_c16x4_ptr_, params_->col_, params_->deep_); | |||||
| CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, | CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_, | ||||
| NULL, weight_bias_sums_, ColMajor); | NULL, weight_bias_sums_, ColMajor); | ||||
| } else { | } else { | ||||
| @@ -113,7 +113,7 @@ TEST_F(TestMatmulInt8, simple) { | |||||
| memset(a_r4x16, 0, ROW4 * DEPTH16); | memset(a_r4x16, 0, ROW4 * DEPTH16); | ||||
| int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16]; | int8_t *b_c16x4 = new int8_t[COL4 * DEPTH16]; | ||||
| memset(b_c16x4, 0, COL4 * DEPTH16); | memset(b_c16x4, 0, COL4 * DEPTH16); | ||||
| RowMajor2Row4x16Major(a, ROW, DEPTH, a_r4x16, DEPTH16); | |||||
| RowMajor2Row16x4MajorInt8(a, a_r4x16, ROW, DEPTH); | |||||
| RowMajor2Col16x4Major(b, DEPTH, COL, b_c16x4, DEPTH16); | RowMajor2Col16x4Major(b, DEPTH, COL, b_c16x4, DEPTH16); | ||||
| int a_sums[ROW4] = {0}; | int a_sums[ROW4] = {0}; | ||||
| int bias[COL4] = {0}; | int bias[COL4] = {0}; | ||||
| @@ -123,7 +123,8 @@ TEST_F(TestMatmulInt8, simple) { | |||||
| MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, | MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier, &ls, | ||||
| &rs, ROW, COL, COL, false); | &rs, ROW, COL, COL, false); | ||||
| #else | #else | ||||
| MatmulInt8(a_r4x16, b_c16x4, output, a_sums, bias, INT8_MIN, INT8_MAX, 0, multiplier, ls, rs, ROW, COL, DEPTH16, COL); | |||||
| MatMulInt8_16x4_r(a_r4x16, b_c16x4, output, ROW, COL, DEPTH16, COL, a_sums, bias, &ls, &rs, &multiplier, 0, INT8_MIN, | |||||
| INT8_MAX, false); | |||||
| #endif | #endif | ||||
| CompareOutputData(output, correct, ROW * COL, 0.1); | CompareOutputData(output, correct, ROW * COL, 0.1); | ||||
| delete[] a_r4x16; | delete[] a_r4x16; | ||||
| @@ -54,6 +54,7 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | tflite_tensors.size(), schema::Format::Format_NHWC); | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | ||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | tflite_tensors.size(), schema::Format::Format_NHWC); | ||||
| return RET_OK; | |||||
| } | } | ||||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | ||||
| } // namespace lite | } // namespace lite | ||||