From 4707f1b38022392c75ddd13955307f63d48d9bc5 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Thu, 20 Aug 2020 16:51:58 +0800
Subject: [PATCH] Optimize int8 MatMul with ARM64 NEON assembly

---
 .../kernel/arm/int8/deconvolution_int8.cc     |  34 ++-
 .../runtime/kernel/arm/int8/matmul_int8.cc    |  59 +++-
 .../src/runtime/kernel/arm/int8/matmul_int8.h |  36 ++-
 .../arm/nnacl/assembly/arm64/MatmulInt8.S     | 276 ++++++++++++++++++
 .../arm/nnacl/assembly/arm64/MatmulR4Int8.S   | 194 ++++++++++++
 .../arm/nnacl/assembly/opt/MatmulOptR4Int8.S  | 157 ++++++++++
 .../runtime/kernel/arm/nnacl/fp32/matmul.c    |   2 +-
 .../runtime/kernel/arm/nnacl/fp32/matmul.h    |   2 +-
 .../runtime/kernel/arm/nnacl/int8/deconv.c    |   2 +-
 .../kernel/arm/nnacl/int8/matmul_int8.c       |  62 +++-
 .../kernel/arm/nnacl/int8/matmul_int8.h       |  25 +-
 .../kernel/arm/nnacl/matmul_parameter.h       |   4 +-
 .../runtime/kernel/arm/nnacl/opt_op_handler.c |   9 +
 .../kernel/arm/nnacl/optimized_kernel.h       |   2 +-
 .../kernel/arm/nnacl/quantization/quantize.c  |  15 +-
 .../kernel/arm/int8/deconv_int8_tests.cc      |   2 +-
 .../kernel/arm/int8/matmul_int8_tests.cc      |   1 -
 17 files changed, 841 insertions(+), 41 deletions(-)
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulInt8.S
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulR4Int8.S
 create mode 100644 mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulOptR4Int8.S

diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
index ca69f7c20f..c5f4eb5c2e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc
@@ -18,6 +18,7 @@
 #include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h"
 #include "src/runtime/runtime_api.h"
 #include "src/kernel_registry.h"
+#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
 
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -89,14 +90,24 @@ int DeConvInt8CPUKernel::Init() {
 
 void DeConvInt8CPUKernel::CheckSupportOptimize() {
   matmul_func_ = nullptr;
-  support_optimize_ = false;
+  support_optimize_ = true;
 
 #ifdef ENABLE_ARM64
-  /* todo */
+  void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
+  if (optimize_op_handler != nullptr) {
+    dlerror();
+    *(reinterpret_cast<void **>(&matmul_func_)) = dlsym(optimize_op_handler, "MatMulR4Int8_optimize_handler");
+    auto dlopen_error = dlerror();
+    if (dlopen_error != nullptr) {
+      MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
+      support_optimize_ = false;
+      matmul_func_ = nullptr;
+    }
+  } else {
+    support_optimize_ = false;
+    matmul_func_ = nullptr;
+  }
 #endif
-
-  support_optimize_ = true;
-  matmul_func_ = MatMulOptR4Int8;
 }
 
 int DeConvInt8CPUKernel::InitParam() {
@@ -109,15 +120,10 @@ int DeConvInt8CPUKernel::InitParam() {
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_;
 
-  if (support_optimize_) {
-    input_trans_func_ = RowMajor2Row16x4MajorInt8;
-    size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
-    thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
-    thread_stride_ = UP_DIV(oc4, thread_count_);
-  } else {
-    /*todo */
-  }
-
+  input_trans_func_ = RowMajor2Row16x4MajorInt8;
+  size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
+  thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
+  thread_stride_ = UP_DIV(oc4, thread_count_);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index 1f5c8d49a5..3ae90eaab7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -47,9 +47,29 @@ int MatmulInt8CPUKernel::ReSize() {
   params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
-  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
-  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+#ifdef ENABLE_ARM64
+  r4_ = UP_ROUND(params_->row_, 4);
+  c4_ = UP_ROUND(params_->col_, 4);
+  d16_ = UP_ROUND(params_->deep_, 16);
+  a_r4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
+  if (!a_r4d16_ptr_) return RET_MEMORY_FAILED;
+  memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
+  b_c4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
+  if (!b_c4d16_ptr_) return RET_MEMORY_FAILED;
+  memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
+  c_r4c4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t)));
+  if (!c_r4c4_ptr_) return RET_MEMORY_FAILED;
+  memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t));
+  a_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
+  if (!a_sums_) return RET_MEMORY_FAILED;
+  memset(a_sums_, 0, r4_ * sizeof(int));
+  b_bias_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
+  if (!b_bias_) return RET_MEMORY_FAILED;
+  memset(b_bias_, 0, c4_ * sizeof(int));
+  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
+  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+#else
   a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t)));
   if (!a_c8_ptr_) {
     return RET_MEMORY_FAILED;
   }
@@ -65,6 +85,9 @@ int MatmulInt8CPUKernel::ReSize() {
     return RET_MEMORY_FAILED;
   }
   memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int));
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+#endif
 
   auto input_tensor = in_tensors_[0];
   auto params = input_tensor->GetQuantParams();
@@ -89,14 +112,27 @@ int MatmulInt8CPUKernel::ReSize() {
 }
 
 int MatmulInt8CPUKernel::RunImpl(int task_id) {
+#ifdef ENABLE_ARM64
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
+  auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_;
+  auto cur_bias = b_bias_ + task_id * thread_stride_ * 4;
+  auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_;
+  auto &p = quant_params_;
+  MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, cur_oc * 4, d16_, a_sums_, cur_bias, INT_MIN, INT_MAX,
+                   p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift);
+#else
   int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
   auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
   auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
+
   MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_,
              quant_params_.weight.zp_);
+#endif
 
   return RET_OK;
 }
@@ -127,6 +163,24 @@ int MatmulInt8CPUKernel::Run() {
     auto cur_a_ptr = a_ptr + i * a_stride;
     auto cur_b_ptr = b_ptr + i * b_stride;
     auto cur_c_ptr = c_ptr + i * c_stride;
+
+#ifdef ENABLE_ARM64
+    if (params_->a_transpose_) {
+      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_);
+    } else {
+      RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_);
+    }
+    if (params_->b_transpose_) {
+      RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_);
+    } else {
+      RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_);
+    }
+    auto &q = quant_params_;
+    RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_);
+    RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, nullptr, b_bias_);
+    LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
+    Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_);
+#else
     if (params_->a_transpose_) {
       RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
     } else {
@@ -141,6 +195,7 @@ int MatmulInt8CPUKernel::Run() {
     auto &q = quant_params_;
     SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier,
                        q.left_shift, q.right_shift, q.output.zp_);
+#endif
   }
 
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
index 193495163b..7dde20d2e7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.h
@@ -39,6 +39,28 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
 
  private:
   void FreeTmpBuffer() {
+#ifdef ENABLE_ARM64
+    if (a_r4d16_ptr_ != nullptr) {
+      ctx_->allocator->Free(a_r4d16_ptr_);
+      a_r4d16_ptr_ = nullptr;
+    }
+    if (b_c4d16_ptr_ != nullptr) {
+      ctx_->allocator->Free(b_c4d16_ptr_);
+      b_c4d16_ptr_ = nullptr;
+    }
+    if (c_r4c4_ptr_ != nullptr) {
+      ctx_->allocator->Free(c_r4c4_ptr_);
+      c_r4c4_ptr_ = nullptr;
+    }
+    if (a_sums_ != nullptr) {
+      ctx_->allocator->Free(a_sums_);
+      a_sums_ = nullptr;
+    }
+    if (b_bias_ != nullptr) {
+      ctx_->allocator->Free(b_bias_);
+      b_bias_ = nullptr;
+    }
+#else
     if (a_c8_ptr_ != nullptr) {
       ctx_->allocator->Free(a_c8_ptr_);
       a_c8_ptr_ = nullptr;
@@ -51,12 +73,24 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
       ctx_->allocator->Free(c_r8x8_ptr_);
       c_r8x8_ptr_ = nullptr;
     }
+#endif
   }
   MatmulQuantArg quant_params_;
+#ifdef ENABLE_ARM64
+  int8_t *a_r4d16_ptr_ = nullptr;
+  int8_t *b_c4d16_ptr_ = nullptr;
+  int8_t *c_r4c4_ptr_ = nullptr;
+  int *a_sums_ = nullptr;
+  int *b_bias_ = nullptr;
+  int r4_;
+  int c4_;
+  int d16_;
+#else
   int8_t *a_c8_ptr_ = nullptr;
   int8_t *b_r8_ptr_ = nullptr;
   int *c_r8x8_ptr_ = nullptr;
-};
+#endif
+};  // class MatmulInt8CPUKernel
 }  // namespace mindspore::kernel
 
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulInt8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulInt8.S
new file mode 100644
index 0000000000..9f1c11a3e9
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulInt8.S
@@ -0,0 +1,276 @@
+#ifdef __aarch64__
+    .text
+    .align 5
+    .global MatmulInt8Neon64
+#ifndef __APPLE__
+    .type MatmulInt8Neon64, %function
+#endif
+
+// int8 RM 16x4 block
+// /-----------------------------------------\
+// |v4.b[0]    v5.b[0]    v6.b[0]    v7.b[0] |
+// |  ...        ...        ...        ...   |
+// |v4.b[15]   v5.b[15]   v6.b[15]   v7.b[15]|
+// \-----------------------------------------/
+// int8 LM 4x16 block
+// /---------------------\ /-----------------------------------------\
+// |v0.b[0] ... v0.b[15] | |v16.4s    v17.4s    v18.4s    v19.4s     |
+// |v1.b[0] ... v1.b[15] | |v20.4s    v21.4s    v22.4s    v23.4s     |
+// |v2.b[0] ... v2.b[15] | |v24.4s    v25.4s    v26.4s    v27.4s     |
+// |v3.b[0] ... v3.b[15] | |v28.4s    v29.4s    v30.4s    v31.4s     |
+// \---------------------/ \-----------------------------------------/
+// int32 accumulators 4x4 block
+
+// void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
+//                       const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
+//                       int multiplier, int left_shift, int right_shift);
+
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// w3: row4
+// w4: col4
+// w5: deep16
+// x6: a_sums
+// x7: bias
+// w8: act_min
+// w9: act_max
+// w10: out_zp
+// w11: multiplier
+// w12: left_shift
+// w13: right_shift
+
+MatmulInt8Neon64:
+    sub sp, sp, #160
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+
+    ldr w8, [sp]
+    ldr w9, [sp, #8]
+    ldr w10, [sp, #16]
+    ldr w11, [sp, #24]
+    ldr w12, [sp, #32]
+    ldr w13, [sp, #40]
+
+    mov w15, #0       // b col index
+    mov w16, #0       // a row index
+    mov w17, #4       // sizeof(int8) * 4
+    mul w21, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16
+
+L1:
+    cmp w15, w4
+    beq End1
+
+    mov w16, #0  // reset a row index
+    mov x17, x0  // reload a ptr
+    mov x22, x6  // reload a_sums ptr
+L2:
+    cmp w16, w3
+    beq End2
+
+    mov x18, x1  // reload b ptr
+    mov x19, x7  // reload bias ptr
+    mov w20, w5  // reload depth
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+L3:
+    cmp w20, #0
+    beq End3
+
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x17], #16
+    ld1 {v2.16b}, [x17], #16
+    ld1 {v3.16b}, [x17], #16
+    ld1 {v4.16b}, [x18], #16
+    ld1 {v5.16b}, [x18], #16
+    ld1 {v6.16b}, [x18], #16
+    ld1 {v7.16b}, [x18], #16
+
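+    // One iteration consumes a 16-deep slice: smull multiplies the low 8 int8
+    // lanes into int16 products, smlal2 accumulates the high 8 lanes, and
+    // sadalp pairwise-adds the int16 pairs into the int32 accumulators.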
+    smull v8.8h, v4.8b, v0.8b
+    smull v9.8h, v5.8b, v0.8b
+    smull v10.8h, v6.8b, v0.8b
+    smull v11.8h, v7.8b, v0.8b
+    smull v12.8h, v4.8b, v1.8b
+    smull v13.8h, v5.8b, v1.8b
+    smull v14.8h, v6.8b, v1.8b
+    smull v15.8h, v7.8b, v1.8b
+
+    smlal2 v8.8h, v4.16b, v0.16b
+    smlal2 v9.8h, v5.16b, v0.16b
+    smlal2 v10.8h, v6.16b, v0.16b
+    smlal2 v11.8h, v7.16b, v0.16b
+    smlal2 v12.8h, v4.16b, v1.16b
+    smlal2 v13.8h, v5.16b, v1.16b
+    smlal2 v14.8h, v6.16b, v1.16b
+    smlal2 v15.8h, v7.16b, v1.16b
+
+    sadalp v16.4s, v8.8h
+    sadalp v17.4s, v9.8h
+    sadalp v18.4s, v10.8h
+    sadalp v19.4s, v11.8h
+    sadalp v20.4s, v12.8h
+    sadalp v21.4s, v13.8h
+    sadalp v22.4s, v14.8h
+    sadalp v23.4s, v15.8h
+
+    smull v8.8h, v4.8b, v2.8b
+    smull v9.8h, v5.8b, v2.8b
+    smull v10.8h, v6.8b, v2.8b
+    smull v11.8h, v7.8b, v2.8b
+    smull v12.8h, v4.8b, v3.8b
+    smull v13.8h, v5.8b, v3.8b
+    smull v14.8h, v6.8b, v3.8b
+    smull v15.8h, v7.8b, v3.8b
+
+    smlal2 v8.8h, v4.16b, v2.16b
+    smlal2 v9.8h, v5.16b, v2.16b
+    smlal2 v10.8h, v6.16b, v2.16b
+    smlal2 v11.8h, v7.16b, v2.16b
+    smlal2 v12.8h, v4.16b, v3.16b
+    smlal2 v13.8h, v5.16b, v3.16b
+    smlal2 v14.8h, v6.16b, v3.16b
+    smlal2 v15.8h, v7.16b, v3.16b
+
+    sadalp v24.4s, v8.8h
+    sadalp v25.4s, v9.8h
+    sadalp v26.4s, v10.8h
+    sadalp v27.4s, v11.8h
+    sadalp v28.4s, v12.8h
+    sadalp v29.4s, v13.8h
+    sadalp v30.4s, v14.8h
+    sadalp v31.4s, v15.8h
+    subs w20, w20, #16  // depth -= 16
+    b L3
+
+End3:
+    addp v16.4s, v16.4s, v17.4s
+    addp v18.4s, v18.4s, v19.4s
+    addp v20.4s, v20.4s, v21.4s
+    addp v22.4s, v22.4s, v23.4s
+    addp v24.4s, v24.4s, v25.4s
+    addp v26.4s, v26.4s, v27.4s
+    addp v28.4s, v28.4s, v29.4s
+    addp v30.4s, v30.4s, v31.4s
+
+    addp v16.4s, v16.4s, v18.4s
+    addp v17.4s, v20.4s, v22.4s
+    addp v18.4s, v24.4s, v26.4s
+    addp v19.4s, v28.4s, v30.4s
+
+    // Add (Bias + Depth * Za * Zb - Za * Bsums)
+    ld1 {v15.4s}, [x19], #16
+    add v16.4s, v16.4s, v15.4s
+    add v17.4s, v17.4s, v15.4s
+    add v18.4s, v18.4s, v15.4s
+    add v19.4s, v19.4s, v15.4s
+
+    // Subtract (Asums * Zb)
+    ld1 {v14.4s}, [x22], #16
+    dup v20.4s, v14.s[0]
+    dup v21.4s, v14.s[1]
+    dup v22.4s, v14.s[2]
+    dup v23.4s, v14.s[3]
+    sub v16.4s, v16.4s, v20.4s
+    sub v17.4s, v17.4s, v21.4s
+    sub v18.4s, v18.4s, v22.4s
+    sub v19.4s, v19.4s, v23.4s
+
+    // Apply left shift
+    dup v13.4s, w12
+    sqshl v16.4s, v16.4s, v13.4s
+    sqshl v17.4s, v17.4s, v13.4s
+    sqshl v18.4s, v18.4s, v13.4s
+    sqshl v19.4s, v19.4s, v13.4s
+
+    // Apply the fixed-point part of the multiplier.
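+    // sqrdmulh is a saturating rounding doubling multiply-high:
+    // lane = clamp_int32((2 * acc * multiplier + (1 << 30)) >> 31).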
+    dup v12.4s, w11
+    sqrdmulh v16.4s, v16.4s, v12.4s
+    sqrdmulh v17.4s, v17.4s, v12.4s
+    sqrdmulh v18.4s, v18.4s, v12.4s
+    sqrdmulh v19.4s, v19.4s, v12.4s
+
+    // Apply right shift
+    dup v11.4s, w13
+    and v20.16b, v11.16b, v16.16b
+    sshr v20.4s, v20.4s, #31
+    sqadd v16.4s, v16.4s, v20.4s
+    srshl v16.4s, v16.4s, v11.4s
+    and v21.16b, v11.16b, v17.16b
+    sshr v21.4s, v21.4s, #31
+    sqadd v17.4s, v17.4s, v21.4s
+    srshl v17.4s, v17.4s, v11.4s
+    and v22.16b, v11.16b, v18.16b
+    sshr v22.4s, v22.4s, #31
+    sqadd v18.4s, v18.4s, v22.4s
+    srshl v18.4s, v18.4s, v11.4s
+    and v23.16b, v11.16b, v19.16b
+    sshr v23.4s, v23.4s, #31
+    sqadd v19.4s, v19.4s, v23.4s
+    srshl v19.4s, v19.4s, v11.4s
+
+    // Add the destination zero point
+    dup v10.4s, w10
+    add v16.4s, v16.4s, v10.4s
+    add v17.4s, v17.4s, v10.4s
+    add v18.4s, v18.4s, v10.4s
+    add v19.4s, v19.4s, v10.4s
+
+    // Apply the act_min bound
+    dup v9.4s, w8
+    smax v16.4s, v16.4s, v9.4s
+    smax v17.4s, v17.4s, v9.4s
+    smax v18.4s, v18.4s, v9.4s
+    smax v19.4s, v19.4s, v9.4s
+
+    // Apply the act_max bound
+    dup v8.4s, w9
+    smin v16.4s, v16.4s, v8.4s
+    smin v17.4s, v17.4s, v8.4s
+    smin v18.4s, v18.4s, v8.4s
+    smin v19.4s, v19.4s, v8.4s
+
+    // int32 -> int16
+    sqxtn v13.4h, v16.4s
+    sqxtn2 v13.8h, v17.4s
+    sqxtn v14.4h, v18.4s
+    sqxtn2 v14.8h, v19.4s
+
+    // int16 -> int8
+    sqxtn v15.8b, v13.8h
+    sqxtn2 v15.16b, v14.8h
+
+    st1 {v15.16b}, [x2], #16
+    add w16, w16, #4  // a row index += 4
+    b L2
+
+End2:
+    add w15, w15, #4  // b col index += 4
+    add x1, x1, x21   // b ptr += stride
+    add x7, x7, #16   // bias ptr += stride
+    b L1
+
+End1:
+    sub sp, sp, #160
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulR4Int8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulR4Int8.S
new file mode 100644
index 0000000000..7378612110
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/MatmulR4Int8.S
@@ -0,0 +1,194 @@
+#ifdef __aarch64__
+    .text
+    .align 5
+    .global MatMulR4Int8Neon64
+#ifndef __APPLE__
+    .type MatMulR4Int8Neon64, %function
+#endif
+
+// int8 RM 16x4 block
+// /-----------------------------------------\
+// |v4.b[0]    v5.b[0]    v6.b[0]    v7.b[0] |
+// |  ...        ...        ...        ...   |
+// |v4.b[15]   v5.b[15]   v6.b[15]   v7.b[15]|
+// \-----------------------------------------/
+// int8 LM 4x16 block
+// /---------------------\ /-----------------------------------------\
+// |v0.b[0] ... v0.b[15] | |v16.4s    v17.4s    v18.4s    v19.4s     |
+// |v1.b[0] ... v1.b[15] | |v20.4s    v21.4s    v22.4s    v23.4s     |
+// |v2.b[0] ... v2.b[15] | |v24.4s    v25.4s    v26.4s    v27.4s     |
+// |v3.b[0] ... v3.b[15] | |v28.4s    v29.4s    v30.4s    v31.4s     |
+// \---------------------/ \-----------------------------------------/
+// int32 accumulators 4x4 block
+
+// void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
+//                         const int *input_sum, const int *bias)
+
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// w3: row4
+// w4: col4
+// w5: deep16
+// x6: a_sums
+// x7: bias
+
+MatMulR4Int8Neon64:
+    sub sp, sp, #128
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+
+    mov w15, #0       // b col index
+    mov w16, #0       // a row index
+    mov w17, #4       // sizeof(int8) * 4
+    mul w12, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16
+
+L1:
+    cmp w15, w4
+    beq End1
+
+    mov w16, #0  // reset a row index
+    mov x17, x0  // reload a ptr
+    mov x13, x6  // reload a_sums ptr
+L2:
+    cmp w16, w3
+    beq End2
+
+    mov x18, x1  // reload b ptr
+    mov x10, x7  // reload bias ptr
+    mov w11, w5  // reload depth
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+L3:
+    cmp w11, #0
+    beq End3
+
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x17], #16
+    ld1 {v2.16b}, [x17], #16
+    ld1 {v3.16b}, [x17], #16
+    ld1 {v4.16b}, [x18], #16
+    ld1 {v5.16b}, [x18], #16
+    ld1 {v6.16b}, [x18], #16
+    ld1 {v7.16b}, [x18], #16
+
+    smull v8.8h, v4.8b, v0.8b
+    smull v9.8h, v5.8b, v0.8b
+    smull v10.8h, v6.8b, v0.8b
+    smull v11.8h, v7.8b, v0.8b
+    smull v12.8h, v4.8b, v1.8b
+    smull v13.8h, v5.8b, v1.8b
+    smull v14.8h, v6.8b, v1.8b
+    smull v15.8h, v7.8b, v1.8b
+
+    smlal2 v8.8h, v4.16b, v0.16b
+    smlal2 v9.8h, v5.16b, v0.16b
+    smlal2 v10.8h, v6.16b, v0.16b
+    smlal2 v11.8h, v7.16b, v0.16b
+    smlal2 v12.8h, v4.16b, v1.16b
+    smlal2 v13.8h, v5.16b, v1.16b
+    smlal2 v14.8h, v6.16b, v1.16b
+    smlal2 v15.8h, v7.16b, v1.16b
+
+    sadalp v16.4s, v8.8h
+    sadalp v17.4s, v9.8h
+    sadalp v18.4s, v10.8h
+    sadalp v19.4s, v11.8h
+    sadalp v20.4s, v12.8h
+    sadalp v21.4s, v13.8h
+    sadalp v22.4s, v14.8h
+    sadalp v23.4s, v15.8h
+
+    smull v8.8h, v4.8b, v2.8b
+    smull v9.8h, v5.8b, v2.8b
+    smull v10.8h, v6.8b, v2.8b
+    smull v11.8h, v7.8b, v2.8b
+    smull v12.8h, v4.8b, v3.8b
+    smull v13.8h, v5.8b, v3.8b
+    smull v14.8h, v6.8b, v3.8b
+    smull v15.8h, v7.8b, v3.8b
+
+    smlal2 v8.8h, v4.16b, v2.16b
+    smlal2 v9.8h, v5.16b, v2.16b
+    smlal2 v10.8h, v6.16b, v2.16b
+    smlal2 v11.8h, v7.16b, v2.16b
+    smlal2 v12.8h, v4.16b, v3.16b
+    smlal2 v13.8h, v5.16b, v3.16b
+    smlal2 v14.8h, v6.16b, v3.16b
+    smlal2 v15.8h, v7.16b, v3.16b
+
+    sadalp v24.4s, v8.8h
+    sadalp v25.4s, v9.8h
+    sadalp v26.4s, v10.8h
+    sadalp v27.4s, v11.8h
+    sadalp v28.4s, v12.8h
+    sadalp v29.4s, v13.8h
+    sadalp v30.4s, v14.8h
+    sadalp v31.4s, v15.8h
+    subs w11, w11, #16  // depth -= 16
+    b L3
+
+End3:
+    addp v16.4s, v16.4s, v17.4s
+    addp v18.4s, v18.4s, v19.4s
+    addp v20.4s, v20.4s, v21.4s
+    addp v22.4s, v22.4s, v23.4s
+    addp v24.4s, v24.4s, v25.4s
+    addp v26.4s, v26.4s, v27.4s
+    addp v28.4s, v28.4s, v29.4s
+    addp v30.4s, v30.4s, v31.4s
+
+    addp v16.4s, v16.4s, v18.4s
+    addp v17.4s, v20.4s, v22.4s
+    addp v18.4s, v24.4s, v26.4s
+    addp v19.4s, v28.4s, v30.4s
+
+    // Add (Bias + Depth * Za * Zb - Za * Bsums)
+    ld1 {v15.4s}, [x10], #16
+    add v16.4s, v16.4s, v15.4s
+    add v17.4s, v17.4s, v15.4s
+    add v18.4s, v18.4s, v15.4s
+    add v19.4s, v19.4s, v15.4s
+
+    // Subtract (Asums * Zb)
+    ld1 {v14.4s}, [x13], #16
+    dup v20.4s, v14.s[0]
+    dup v21.4s, v14.s[1]
+    dup v22.4s, v14.s[2]
+    dup v23.4s, v14.s[3]
+    sub v16.4s, v16.4s, v20.4s
+    sub v17.4s, v17.4s, v21.4s
+    sub v18.4s, v18.4s, v22.4s
+    sub v19.4s, v19.4s, v23.4s
+
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
+    add w16, w16, #4  // a row index += 4
+    b L2
+
+End2:
+    add w15, w15, #4  // b col index += 4
+    add x1, x1, x12   // b ptr += stride
+    add x7, x7, #16   // bias ptr += stride
+    b L1
+
+End1:
+    sub sp, sp, #128
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ret
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulOptR4Int8.S b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulOptR4Int8.S
new file mode 100644
index 0000000000..7a56cd2994
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/opt/MatmulOptR4Int8.S
@@ -0,0 +1,157 @@
+#ifdef __aarch64__
+    .text
+    .align 5
+    .global MatMulOptR4Int8Neon64
+#ifndef __APPLE__
+    .type MatMulOptR4Int8Neon64, %function
+#endif
+
+// int8 RM 16x4 block
+// /-----------------------------------------\
+// |v4.b[0]    v5.b[0]    v6.b[0]    v7.b[0] |
+// |  ...        ...        ...        ...   |
+// |v4.b[15]   v5.b[15]   v6.b[15]   v7.b[15]|
+// \-----------------------------------------/
+// int8 LM 4x16 block
+// /---------------------\ /-----------------------------------------\
+// |v0.b[0] ... v0.b[15] | |v16.4s    v17.4s    v18.4s    v19.4s     |
+// |v1.b[0] ... v1.b[15] | |v20.4s    v21.4s    v22.4s    v23.4s     |
+// |v2.b[0] ... v2.b[15] | |v24.4s    v25.4s    v26.4s    v27.4s     |
+// |v3.b[0] ... v3.b[15] | |v28.4s    v29.4s    v30.4s    v31.4s     |
+// \---------------------/ \-----------------------------------------/
+// int32 accumulators 4x4 block
+
+// void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
+//                            const int *input_sum, const int *bias)
+
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// w3: row4
+// w4: col4
+// w5: deep16
+// x6: a_sums
+// x7: bias
+
+MatMulOptR4Int8Neon64:
+    sub sp, sp, #128
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+
+    mov w15, #0       // b col index
+    mov w16, #0       // a row index
+    mov w17, #4       // sizeof(int8) * 4
+    mul w12, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16
+
+L1:
+    cmp w15, w4
+    beq End1
+
+    mov w16, #0  // reset a row index
+    mov x17, x0  // reload a ptr
+    mov x13, x6  // reload a_sums ptr
+L2:
+    cmp w16, w3
+    beq End2
+
+    mov x18, x1  // reload b ptr
+    mov x10, x7  // reload bias ptr
+    mov w11, w5  // reload depth
+    dup v16.4s, wzr
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+L3:
+    cmp w11, #0
+    beq End3
+
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x17], #16
+    ld1 {v2.16b}, [x17], #16
+    ld1 {v3.16b}, [x17], #16
+    ld1 {v4.16b}, [x18], #16
+    ld1 {v5.16b}, [x18], #16
+    ld1 {v6.16b}, [x18], #16
+    ld1 {v7.16b}, [x18], #16
+
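+    // sdot (ARMv8.2-A dot-product extension) multiplies four int8 pairs and
+    // accumulates into one int32 lane per instruction, replacing the
+    // smull/smlal2/sadalp chain above; that is why this kernel lives in
+    // assembly/opt and is resolved at runtime via dlsym (see opt_op_handler.c).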
+    sdot v16.4s, v4.16b, v0.16b
+    sdot v17.4s, v5.16b, v0.16b
+    sdot v18.4s, v6.16b, v0.16b
+    sdot v19.4s, v7.16b, v0.16b
+    sdot v20.4s, v4.16b, v1.16b
+    sdot v21.4s, v5.16b, v1.16b
+    sdot v22.4s, v6.16b, v1.16b
+    sdot v23.4s, v7.16b, v1.16b
+    sdot v24.4s, v4.16b, v2.16b
+    sdot v25.4s, v5.16b, v2.16b
+    sdot v26.4s, v6.16b, v2.16b
+    sdot v27.4s, v7.16b, v2.16b
+    sdot v28.4s, v4.16b, v3.16b
+    sdot v29.4s, v5.16b, v3.16b
+    sdot v30.4s, v6.16b, v3.16b
+    sdot v31.4s, v7.16b, v3.16b
+    subs w11, w11, #16  // depth -= 16
+    b L3
+
+End3:
+    addp v16.4s, v16.4s, v17.4s
+    addp v18.4s, v18.4s, v19.4s
+    addp v20.4s, v20.4s, v21.4s
+    addp v22.4s, v22.4s, v23.4s
+    addp v24.4s, v24.4s, v25.4s
+    addp v26.4s, v26.4s, v27.4s
+    addp v28.4s, v28.4s, v29.4s
+    addp v30.4s, v30.4s, v31.4s
+
+    addp v16.4s, v16.4s, v18.4s
+    addp v17.4s, v20.4s, v22.4s
+    addp v18.4s, v24.4s, v26.4s
+    addp v19.4s, v28.4s, v30.4s
+
+    // Add (Bias + Depth * Za * Zb - Za * Bsums)
+    ld1 {v15.4s}, [x10], #16
+    add v16.4s, v16.4s, v15.4s
+    add v17.4s, v17.4s, v15.4s
+    add v18.4s, v18.4s, v15.4s
+    add v19.4s, v19.4s, v15.4s
+
+    // Subtract (Asums * Zb)
+    ld1 {v14.4s}, [x13], #16
+    dup v20.4s, v14.s[0]
+    dup v21.4s, v14.s[1]
+    dup v22.4s, v14.s[2]
+    dup v23.4s, v14.s[3]
+    sub v16.4s, v16.4s, v20.4s
+    sub v17.4s, v17.4s, v21.4s
+    sub v18.4s, v18.4s, v22.4s
+    sub v19.4s, v19.4s, v23.4s
+
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
+    add w16, w16, #4  // a row index += 4
+    b L2
+
+End2:
+    add w15, w15, #4  // b col index += 4
+    add x1, x1, x12   // b ptr += stride
+    add x7, x7, #16   // bias ptr += stride
+    b L1
+
+End1:
+    sub sp, sp, #128
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ret
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c
index e5f879e0c3..03c65ef23f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.c
@@ -269,7 +269,7 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
             int stride, bool write_nhwc) {
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
   MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
 #else
   MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
index ce50f2ff56..e612d7c5a3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h
@@ -31,7 +31,7 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                        int col, size_t stride, bool write_nhwc);
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c
index 3ad6cf904d..30d3a8f734 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c
@@ -197,7 +197,7 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32
                size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param,
                MATMUL_OPT_R4_FUNC matmul_func) {
   if (matmul_func != NULL) {
-    matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep);
+    matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum);
   } else {
     /* todo normal int8 deconv */
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c
index bca1f639fc..b300d4d148 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c
@@ -74,8 +74,8 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
   }
 }
 
-void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
-                     size_t row_4, size_t col_4, size_t deep_16) {
+void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                     const int *input_sum, const int *bias) {
   /* row4x16-major * row16x4-major => row4x4-major */
   for (int r = 0; r < row_4; r++) {
     for (int c = 0; c < col_4; c++) {
@@ -96,3 +96,61 @@ void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int
     }
   }
   return;
 }
+
+#ifdef ENABLE_ARM64
+void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
+  int stride = sizeof(int8_t) * 16 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_n = r / 4 * (col_16 / 16) + c / 16;
+      int src_idx = r * col + c;
+      dst[stride * stride_n + r % 4 * 16 + c % 16] = src[src_idx];
+    }
+  }
+}
+
+void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) {
+  int stride = sizeof(int8_t) * 16 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_n = c / 4 * (row_16 / 16) + r / 16;
+      int src_idx = r * col + c;
+      dst[stride * stride_n + c % 4 * 16 + r % 16] = src[src_idx];
+    }
+  }
+}
+
+// dst[r] = Zb * sum(a[r, :]); written rather than accumulated, so repeated
+// calls (once per batch in Run()) do not compound earlier results.
+void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
+  for (int r = 0; r < row; ++r) {
+    int sum = 0;
+    for (int c = 0; c < col; ++c) {
+      sum += a[r * col + c];
+    }
+    dst[r] = sum * b_zp;
+  }
+}
+
+// dst[c] = bias[c] + depth * Za * Zb - Za * sum(b[:, c])
+void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
+  for (int c = 0; c < col; ++c) {
+    int sum = 0;
+    for (int r = 0; r < row; ++r) {
+      sum += b[r * col + c];
+    }
+    dst[c] = row * a_zp * b_zp - a_zp * sum;
+    if (bias != NULL) {
+      dst[c] += bias[c];
+    }
+  }
+}
+
+void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int col) {
+  int stride = sizeof(int8_t) * 4 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_n = c / 4 * (row4 / 4) + r / 4;
+      int dst_idx = r * col + c;
+      dst[dst_idx] = src[stride * stride_n + r % 4 * 4 + c % 4];
+    }
+  }
+}
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h
index 860863f633..0d54b3222f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h
@@ -23,14 +23,29 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
-                const int32_t a_zp, const int32_t b_zp);
-void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
-                     size_t row_4, size_t col_4, size_t deep_16);
-
+void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
+                const int a_zp, const int b_zp);
+void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                     const int *input_sum, const int *bias);
 void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
+
+#ifdef ENABLE_ARM64
+void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
+void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
+void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
+void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int col);
+
+// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
+void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
+                      const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
+                      int right_shift);
+
+void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
+                        const int *input_sum, const int *bias);
+#endif
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
index 54c6d1c9a6..f7061f83fd 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h
@@ -19,8 +19,8 @@
 
 #include "nnacl/op_base.h"
 
-typedef void (*MATMUL_OPT_R4_FUNC)(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias,
-                                   const int32_t *input_sum, size_t row_4, size_t col_4, size_t deep_16);
+typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                                   const int *input_sum, const int *bias);
 
 typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
index bb38c99070..47149e6653 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/opt_op_handler.c
@@ -23,6 +23,10 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
                                      size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                      const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
                                      size_t out_multiplier, size_t shift_before, size_t shift_after);
+
+extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
+                                  const int *input_sum, const int *bias);
+
 #ifdef __cplusplus
 }
 #endif
@@ -35,4 +39,9 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
                                   act_max, out_zp, out_multiplier, shift_before, shift_after);
 }
+
+void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
+                                   const int *input_sum, const int *bias) {
+  return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
+}
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h
index 29c56bf04b..87f0c03dfd 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/optimized_kernel.h
@@ -64,7 +64,7 @@ class OptimizeModule {
     optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
 #endif
     if (optimized_op_handler_ == nullptr) {
-      printf("Open optimize shared library failed.\n");
+      printf("Open optimize shared library failed: %s\n", dlerror());
     }
   }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.c
index 169d064778..4701719564 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.c
@@ -26,9 +26,7 @@
 const double dNormalizer = 0x1p54;
 const int dNormalizerBias = 54;
 const int iMantissaBits = 31;
-
-void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
-                                      int *right_shift) {
+void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift) {
   if (quantized_multiplier == NULL || right_shift == NULL) {
     return;
   }
@@ -55,10 +53,9 @@ uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return roun
 
 int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
 
-void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini,
-                                       int *maxi) {
-  int32_t min = CHAR_MIN;
-  int32_t max = CHAR_MAX;
+void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, int *maxi) {
+  int32_t min = INT8_MIN;
+  int32_t max = INT8_MAX;
   int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
   int32_t quantized_six = QuantizeToInt8(6, scale, zp);
   if (is_relu) {
@@ -77,8 +74,8 @@ void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp,
 void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
     int q = (int)round(input_data[i] / scale + zero_point);
-    q = q > CHAR_MAX ? CHAR_MAX : q;
-    q = q < CHAR_MIN ? CHAR_MIN : q;
+    q = q > SCHAR_MAX ? SCHAR_MAX : q;
+    q = q < SCHAR_MIN ? SCHAR_MIN : q;
     output_data[i] = (int8_t)q;
   }
 }
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
index a4c32b1778..f90ea33975 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc
@@ -270,7 +270,7 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
                                7894,   -51,    0, 0, -4775,  -29785, 0, 0, -12597, 4088,   0, 0, -17420, 1815,  0, 0,
                                15796,  3101,   0, 0, -37969, -10818, 0, 0, 12714,  -7827,  0, 0, 0,      0,     0, 0,
                                0,      0,      0, 0};
-  MatMulOptR4Int8(tmp_output, packed_a, packed_b, weight_sum, input_sum, 12, 24, 16);
+  MatMulOptR4Int8(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum);
 
   CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0);
 }
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
index d556da8d1a..22a1af6a77 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
@@ -116,7 +116,6 @@ TEST_F(TestMatmulInt8, mmint8) {
   Dequantize(reinterpret_cast<float *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
              fout);
   CompareOutputData(fout, correct, 6, 0.3);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
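
--
Reviewer note (not part of the patch): the sqshl/sqrdmulh/srshl sequence in
MatmulInt8Neon64 is the usual gemmlowp-style fixed-point requantization. Below
is an illustrative scalar C sketch for reference; the helper names are ours,
and it assumes right_shift is passed as the negative shift amount that srshl
consumes, so the rounding divide uses -right_shift:

  #include <stdint.h>

  // out = clamp(RoundingDivideByPOT(SatRoundingDoublingHighMul(acc << left_shift,
  //       multiplier), -right_shift) + out_zp, act_min, act_max)
  static int32_t SatRoundingDoublingHighMul(int32_t a, int32_t b) {
    if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  // the only overflow case
    int64_t ab = (int64_t)a * (int64_t)b;
    int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));   // round to nearest
    return (int32_t)((ab + nudge) / (1ll << 31));
  }

  static int32_t RoundingDivideByPOT(int32_t x, int exponent) {
    int32_t mask = (int32_t)((1ll << exponent) - 1);
    int32_t remainder = x & mask;
    int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
    return (x >> exponent) + (remainder > threshold ? 1 : 0);
  }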