Merge pull request !4841 from zhanyuan/dev
@@ -18,6 +18,7 @@
 #include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h"
 #include "src/runtime/runtime_api.h"
 #include "src/kernel_registry.h"
+#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -89,14 +90,24 @@ int DeConvInt8CPUKernel::Init() {
 void DeConvInt8CPUKernel::CheckSupportOptimize() {
   matmul_func_ = nullptr;
-  support_optimize_ = false;
+  support_optimize_ = true;
 #ifdef ENABLE_ARM64
-  /* todo */
+  void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
+  if (optimize_op_handler != nullptr) {
+    dlerror();
+    *(reinterpret_cast<void **>(&matmul_func_)) = dlsym(optimize_op_handler, "MatMulR4Int8_optimize_handler");
+    auto dlopen_error = dlerror();
+    if (dlopen_error != nullptr) {
+      MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
+      support_optimize_ = false;
+      matmul_func_ = nullptr;
+    }
+  } else {
+    support_optimize_ = false;
+    matmul_func_ = nullptr;
+  }
 #endif
   support_optimize_ = true;
   matmul_func_ = MatMulOptR4Int8;
 }
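The dlsym probe above follows the usual optional-fast-path pattern: try to resolve the optimized symbol from a separately loaded library, and fall back to the portable kernel on any failure. A minimal standalone sketch of the same pattern (the library name and surrounding error handling here are illustrative, not MindSpore's exact code):

#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

// Hypothetical alias matching the optimized matmul entry point's signature.
using MatmulFunc = void (*)(const int8_t *, const int8_t *, int *, int, int, int, const int *, const int *);

MatmulFunc LoadOptimizedMatmul() {
  void *handle = dlopen("liboptimize.so", RTLD_LAZY);  // hypothetical library name
  if (handle == nullptr) {
    std::printf("dlopen failed: %s\n", dlerror());
    return nullptr;  // caller keeps the portable C kernel
  }
  dlerror();  // clear stale errors: a null dlsym result alone is not proof of failure
  auto func = reinterpret_cast<MatmulFunc>(dlsym(handle, "MatMulR4Int8_optimize_handler"));
  if (const char *err = dlerror()) {
    std::printf("dlsym failed: %s\n", err);
    return nullptr;
  }
  return func;
}

Clearing dlerror() before the dlsym call is the detail the kernel code above gets right: NULL is a legal symbol value, so the error string is the only reliable failure signal.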
 int DeConvInt8CPUKernel::InitParam() {
@@ -109,15 +120,10 @@ int DeConvInt8CPUKernel::InitParam() {
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_;
-  if (support_optimize_) {
-    input_trans_func_ = RowMajor2Row16x4MajorInt8;
-    size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
-    thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
-    thread_stride_ = UP_DIV(oc4, thread_count_);
-  } else {
-    /* todo */
-  }
+  input_trans_func_ = RowMajor2Row16x4MajorInt8;
+  size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
+  thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
+  thread_stride_ = UP_DIV(oc4, thread_count_);
   return RET_OK;
 }
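InitParam now always uses the optimized path's partitioning: output channels are grouped into blocks of four, and UP_DIV/MSMIN split those blocks across threads. A self-contained check of that arithmetic (the channel and thread counts are made-up example values):

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
#define C4NUM 4

int main() {
  int output_channel = 22;                         // assumed example value
  int thread_num = 4;                              // assumed example value
  int oc4 = UP_DIV(output_channel, C4NUM);         // 6 blocks of 4 channels
  int thread_count = MSMIN(thread_num, oc4);       // never more threads than blocks
  int thread_stride = UP_DIV(oc4, thread_count);   // 2 blocks per thread
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int cur = MSMIN(thread_stride, oc4 - task_id * thread_stride);
    if (cur <= 0) continue;  // trailing tasks may get nothing
    std::printf("task %d handles oc4 blocks [%d, %d)\n", task_id,
                task_id * thread_stride, task_id * thread_stride + cur);
  }
  return 0;
}

With 22 channels and 4 threads this prints three busy tasks of 2 blocks each and silently skips the fourth, which is exactly the cur_oc <= 0 early return seen in RunImpl below.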
@@ -47,9 +47,29 @@ int MatmulInt8CPUKernel::ReSize() {
   params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
-  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
-  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+#ifdef ENABLE_ARM64
+  r4_ = UP_ROUND(params_->row_, 4);
+  c4_ = UP_ROUND(params_->col_, 4);
+  d16_ = UP_ROUND(params_->deep_, 16);
+  a_r4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
+  if (!a_r4d16_ptr_) return RET_MEMORY_FAILED;
+  memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
+  b_c4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
+  if (!b_c4d16_ptr_) return RET_MEMORY_FAILED;
+  memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
+  c_r4c4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t)));
+  if (!c_r4c4_ptr_) return RET_MEMORY_FAILED;
+  memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t));
+  a_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
+  if (!a_sums_) return RET_MEMORY_FAILED;
+  memset(a_sums_, 0, r4_ * sizeof(int));
+  b_bias_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
+  if (!b_bias_) return RET_MEMORY_FAILED;
+  memset(b_bias_, 0, c4_ * sizeof(int));
+  thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
+  thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
+#else
   a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t)));
   if (!a_c8_ptr_) {
     return RET_MEMORY_FAILED;
@@ -65,6 +85,9 @@ int MatmulInt8CPUKernel::ReSize() {
     return RET_MEMORY_FAILED;
   }
   memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int));
+  thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
+  thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+#endif
   auto input_tensor = in_tensors_[0];
   auto params = input_tensor->GetQuantParams();
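The ARM64 branch of ReSize rounds each dimension up to the assembly kernel's tile sizes (rows and columns to multiples of 4, depth to a multiple of 16) and zero-fills the packing buffers so the padded depth lanes contribute nothing to the accumulators; padded rows and columns simply never get copied back out. A quick sketch of the sizing with an assumed shape:

#include <cstdio>

#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

int main() {
  // Assumed example shape: a 3x10 (int8) times 10x5 (int8) matmul.
  int row = 3, col = 5, deep = 10;
  int r4 = UP_ROUND(row, 4);     // 4
  int c4 = UP_ROUND(col, 4);     // 8
  int d16 = UP_ROUND(deep, 16);  // 16
  // a: r4 x d16 bytes, b: c4 x d16 bytes, c: r4 x c4 bytes (the asm kernel
  // writes requantized int8 directly, so c is also an int8 buffer).
  std::printf("a=%d B, b=%d B, c=%d B\n", r4 * d16, c4 * d16, r4 * c4);
  return 0;
}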
@@ -89,14 +112,27 @@ int MatmulInt8CPUKernel::ReSize() {
 }
 int MatmulInt8CPUKernel::RunImpl(int task_id) {
+#ifdef ENABLE_ARM64
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
+  auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_;
+  auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_;
+  auto &p = quant_params_;
+  MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, c4_, d16_, a_sums_, b_bias_, INT_MIN, INT_MAX, p.output.zp_,
+                   p.quant_multiplier, p.left_shift, p.right_shift);
+#else
   int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
   auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
   auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
   MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_,
              quant_params_.weight.zp_);
+#endif
   return RET_OK;
 }
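The per-task offsets in the ARM64 branch follow from the packed layouts: packed B is a sequence of panels of 4 columns by d16 bytes, and packed C a sequence of panels of 4 columns by r4 bytes, so one task step skips thread_stride_ panels in each. The same arithmetic in isolation (all sizes are assumed example values):

#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

int main() {
  // Assumed packed geometry: already rounded r4/c4/d16 values.
  int r4 = 8, c4 = 24, d16 = 32, thread_stride = 2, thread_count = 3;
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int cur_oc = MSMIN(thread_stride, UP_DIV(c4, 4) - task_id * thread_stride);
    if (cur_oc <= 0) continue;
    // B panels are 4 columns x d16 bytes; C panels are 4 columns x r4 bytes.
    int b_off = task_id * thread_stride * 4 * d16;
    int c_off = task_id * thread_stride * 4 * r4;
    std::printf("task %d: %d panels, b+%d, c+%d\n", task_id, cur_oc, b_off, c_off);
  }
  return 0;
}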
@@ -127,6 +163,24 @@ int MatmulInt8CPUKernel::Run() {
     auto cur_a_ptr = a_ptr + i * a_stride;
     auto cur_b_ptr = b_ptr + i * b_stride;
     auto cur_c_ptr = c_ptr + i * c_stride;
+#ifdef ENABLE_ARM64
+    if (params_->a_transpose_) {
+      RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_);
+    } else {
+      RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_);
+    }
+    if (params_->b_transpose_) {
+      RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_);
+    } else {
+      RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_);
+    }
+    auto &q = quant_params_;
+    RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_);
+    RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, b_bias_);
+    LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
+    Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_);
+#else
     if (params_->a_transpose_) {
       RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
     } else {
@@ -141,6 +195,7 @@ int MatmulInt8CPUKernel::Run() {
     auto &q = quant_params_;
     SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier,
                        q.left_shift, q.right_shift, q.output.zp_);
+#endif
   }
   return RET_OK;
@@ -39,6 +39,28 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
  private:
   void FreeTmpBuffer() {
+#ifdef ENABLE_ARM64
+    if (a_r4d16_ptr_ != nullptr) {
+      ctx_->allocator->Free(a_r4d16_ptr_);
+      a_r4d16_ptr_ = nullptr;
+    }
+    if (b_c4d16_ptr_ != nullptr) {
+      ctx_->allocator->Free(b_c4d16_ptr_);
+      b_c4d16_ptr_ = nullptr;
+    }
+    if (c_r4c4_ptr_ != nullptr) {
+      ctx_->allocator->Free(c_r4c4_ptr_);
+      c_r4c4_ptr_ = nullptr;
+    }
+    if (a_sums_ != nullptr) {
+      ctx_->allocator->Free(a_sums_);
+      a_sums_ = nullptr;
+    }
+    if (b_bias_ != nullptr) {
+      ctx_->allocator->Free(b_bias_);
+      b_bias_ = nullptr;
+    }
+#else
     if (a_c8_ptr_ != nullptr) {
       ctx_->allocator->Free(a_c8_ptr_);
       a_c8_ptr_ = nullptr;
@@ -51,12 +73,24 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
       ctx_->allocator->Free(c_r8x8_ptr_);
       c_r8x8_ptr_ = nullptr;
     }
+#endif
   }
   MatmulQuantArg quant_params_;
+#ifdef ENABLE_ARM64
+  int8_t *a_r4d16_ptr_ = nullptr;
+  int8_t *b_c4d16_ptr_ = nullptr;
+  int8_t *c_r4c4_ptr_ = nullptr;
+  int *a_sums_ = nullptr;
+  int *b_bias_ = nullptr;
+  int r4_;
+  int c4_;
+  int d16_;
+#else
   int8_t *a_c8_ptr_ = nullptr;
   int8_t *b_r8_ptr_ = nullptr;
   int *c_r8x8_ptr_ = nullptr;
+#endif
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_
@@ -0,0 +1,276 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulInt8Neon64
#ifndef __APPLE__
.type MatmulInt8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------\
// |v4.b[0]  v5.b[0]  v6.b[0]  v7.b[0]      |
// |  ...      ...      ...      ...        |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15]     |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\   /-----------------------------------------\
// |v0.b[0] ... v0.b[15] |   |v16.4s  v17.4s  v18.4s  v19.4s           |
// |v1.b[0] ... v1.b[15] |   |v20.4s  v21.4s  v22.4s  v23.4s           |
// |v2.b[0] ... v2.b[15] |   |v24.4s  v25.4s  v26.4s  v27.4s           |
// |v3.b[0] ... v3.b[15] |   |v28.4s  v29.4s  v30.4s  v31.4s           |
// \---------------------/   \-----------------------------------------/
// int32 accumulators 4x4 block
// void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
//                       const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
//                       int multiplier, int left_shift, int right_shift);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
// w8: act_min
// w9: act_max
// w10: out_zp
// w11: multiplier
// w12: left_shift
// w13: right_shift

MatmulInt8Neon64:
    sub sp, sp, #160
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    stp x19, x20, [sp], #16
    stp x21, x22, [sp], #16

    ldr w8, [sp]
    ldr w9, [sp, #8]
    ldr w10, [sp, #16]
    ldr w11, [sp, #24]
    ldr w12, [sp, #32]
    ldr w13, [sp, #40]

    mov w15, #0       // b col index
    mov w16, #0       // a row index
    mov w17, #4       // sizeof(int8) * 4
    mul w21, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16

L1:
    cmp w15, w4
    beq End1
    mov w16, #0  // reset a row index
    mov x17, x0  // reload a ptr
    mov x22, x6  // reload a_sums ptr
L2:
    cmp w16, w3
    beq End2
    mov x18, x1  // reload b ptr
    mov x19, x7  // reload bias ptr
    mov w20, w5  // reload depth
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
    dup v24.4s, wzr
    dup v25.4s, wzr
    dup v26.4s, wzr
    dup v27.4s, wzr
    dup v28.4s, wzr
    dup v29.4s, wzr
    dup v30.4s, wzr
    dup v31.4s, wzr
L3:
    cmp w20, #0
    beq End3
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b}, [x17], #16
    ld1 {v2.16b}, [x17], #16
    ld1 {v3.16b}, [x17], #16
    ld1 {v4.16b}, [x18], #16
    ld1 {v5.16b}, [x18], #16
    ld1 {v6.16b}, [x18], #16
    ld1 {v7.16b}, [x18], #16
    smull v8.8h, v4.8b, v0.8b
    smull v9.8h, v5.8b, v0.8b
    smull v10.8h, v6.8b, v0.8b
    smull v11.8h, v7.8b, v0.8b
    smull v12.8h, v4.8b, v1.8b
    smull v13.8h, v5.8b, v1.8b
    smull v14.8h, v6.8b, v1.8b
    smull v15.8h, v7.8b, v1.8b
    smlal2 v8.8h, v4.16b, v0.16b
    smlal2 v9.8h, v5.16b, v0.16b
    smlal2 v10.8h, v6.16b, v0.16b
    smlal2 v11.8h, v7.16b, v0.16b
    smlal2 v12.8h, v4.16b, v1.16b
    smlal2 v13.8h, v5.16b, v1.16b
    smlal2 v14.8h, v6.16b, v1.16b
    smlal2 v15.8h, v7.16b, v1.16b
    sadalp v16.4s, v8.8h
    sadalp v17.4s, v9.8h
    sadalp v18.4s, v10.8h
    sadalp v19.4s, v11.8h
    sadalp v20.4s, v12.8h
    sadalp v21.4s, v13.8h
    sadalp v22.4s, v14.8h
    sadalp v23.4s, v15.8h
    smull v8.8h, v4.8b, v2.8b
    smull v9.8h, v5.8b, v2.8b
    smull v10.8h, v6.8b, v2.8b
    smull v11.8h, v7.8b, v2.8b
    smull v12.8h, v4.8b, v3.8b
    smull v13.8h, v5.8b, v3.8b
    smull v14.8h, v6.8b, v3.8b
    smull v15.8h, v7.8b, v3.8b
    smlal2 v8.8h, v4.16b, v2.16b
    smlal2 v9.8h, v5.16b, v2.16b
    smlal2 v10.8h, v6.16b, v2.16b
    smlal2 v11.8h, v7.16b, v2.16b
    smlal2 v12.8h, v4.16b, v3.16b
    smlal2 v13.8h, v5.16b, v3.16b
    smlal2 v14.8h, v6.16b, v3.16b
    smlal2 v15.8h, v7.16b, v3.16b
    sadalp v24.4s, v8.8h
    sadalp v25.4s, v9.8h
    sadalp v26.4s, v10.8h
    sadalp v27.4s, v11.8h
    sadalp v28.4s, v12.8h
    sadalp v29.4s, v13.8h
    sadalp v30.4s, v14.8h
    sadalp v31.4s, v15.8h
    subs w20, w20, #16  // depth -= 16
    b L3

End3:
    addp v16.4s, v16.4s, v17.4s
    addp v18.4s, v18.4s, v19.4s
    addp v20.4s, v20.4s, v21.4s
    addp v22.4s, v22.4s, v23.4s
    addp v24.4s, v24.4s, v25.4s
    addp v26.4s, v26.4s, v27.4s
    addp v28.4s, v28.4s, v29.4s
    addp v30.4s, v30.4s, v31.4s
    addp v16.4s, v16.4s, v18.4s
    addp v17.4s, v20.4s, v22.4s
    addp v18.4s, v24.4s, v26.4s
    addp v19.4s, v28.4s, v30.4s

    // Add (Bias + Depth * Za * Zb - Za * Bsums)
    ld1 {v15.4s}, [x19], #16
    add v16.4s, v16.4s, v15.4s
    add v17.4s, v17.4s, v15.4s
    add v18.4s, v18.4s, v15.4s
    add v19.4s, v19.4s, v15.4s
    // Subtract (Asums * Zb)
    ld1 {v14.4s}, [x22], #16
    dup v20.4s, v14.s[0]
    dup v21.4s, v14.s[1]
    dup v22.4s, v14.s[2]
    dup v23.4s, v14.s[3]
    sub v16.4s, v16.4s, v20.4s
    sub v17.4s, v17.4s, v21.4s
    sub v18.4s, v18.4s, v22.4s
    sub v19.4s, v19.4s, v23.4s
    // Apply left shift
    dup v13.4s, w12
    sqshl v16.4s, v16.4s, v13.4s
    sqshl v17.4s, v17.4s, v13.4s
    sqshl v18.4s, v18.4s, v13.4s
    sqshl v19.4s, v19.4s, v13.4s
    // Apply the fixed-point part of the multiplier
    dup v12.4s, w11
    sqrdmulh v16.4s, v16.4s, v12.4s
    sqrdmulh v17.4s, v17.4s, v12.4s
    sqrdmulh v18.4s, v18.4s, v12.4s
    sqrdmulh v19.4s, v19.4s, v12.4s
    // Apply right shift
    dup v11.4s, w13
    and v20.16b, v11.16b, v16.16b
    sshr v20.4s, v20.4s, #31
    sqadd v16.4s, v16.4s, v20.4s
    srshl v16.4s, v16.4s, v11.4s
    and v21.16b, v11.16b, v17.16b
    sshr v21.4s, v21.4s, #31
    sqadd v17.4s, v17.4s, v21.4s
    srshl v17.4s, v17.4s, v11.4s
    and v22.16b, v11.16b, v18.16b
    sshr v22.4s, v22.4s, #31
    sqadd v18.4s, v18.4s, v22.4s
    srshl v18.4s, v18.4s, v11.4s
    and v23.16b, v11.16b, v19.16b
    sshr v23.4s, v23.4s, #31
    sqadd v19.4s, v19.4s, v23.4s
    srshl v19.4s, v19.4s, v11.4s
    // Add the destination zero point
    dup v10.4s, w10
    add v16.4s, v16.4s, v10.4s
    add v17.4s, v17.4s, v10.4s
    add v18.4s, v18.4s, v10.4s
    add v19.4s, v19.4s, v10.4s
    // Apply the act_min bound
    dup v9.4s, w8
    smax v16.4s, v16.4s, v9.4s
    smax v17.4s, v17.4s, v9.4s
    smax v18.4s, v18.4s, v9.4s
    smax v19.4s, v19.4s, v9.4s
    // Apply the act_max bound
    dup v8.4s, w9
    smin v16.4s, v16.4s, v8.4s
    smin v17.4s, v17.4s, v8.4s
    smin v18.4s, v18.4s, v8.4s
    smin v19.4s, v19.4s, v8.4s
    // int32 -> int16
    sqxtn v13.4h, v16.4s
    sqxtn2 v13.8h, v17.4s
    sqxtn v14.4h, v18.4s
    sqxtn2 v14.8h, v19.4s
    // int16 -> int8
    sqxtn v15.8b, v13.8h
    sqxtn2 v15.16b, v14.8h
    st1 {v15.16b}, [x2], #16
    add w16, w16, #4  // a row index += 4
    b L2

End2:
    add w15, w15, #4  // b col index += 4
    add x1, x1, x21   // b ptr += stride
    add x7, x7, #16   // bias ptr += stride
    b L1

End1:
    sub sp, sp, #160
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ldp x19, x20, [sp], #16
    ldp x21, x22, [sp], #16
    ret
#endif
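The epilogue after End3 is the standard gemmlowp/TFLite fixed-point requantization: a saturating left shift, a saturating rounding doubling high multiply (sqrdmulh), a rounding right shift (the and/sshr/sqadd/srshl quartet), zero-point addition, clamping, then narrowing to int8. A scalar C++ model of one lane, as a reference sketch; it assumes, per the NEON convention, that right_shift arrives as the negative shift amount srshl expects:

#include <cstdint>
#include <algorithm>

// Saturating doubling high multiply: what sqrdmulh computes per lane.
int32_t SqrdmulhModel(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  // the one saturating case
  int64_t ab = static_cast<int64_t>(a) * b;
  int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));   // round to nearest
  return static_cast<int32_t>((ab + nudge) >> 31);
}

// Rounding arithmetic right shift: what the and/sshr/sqadd/srshl fixup implements.
int32_t RoundingRightShiftModel(int32_t x, int exponent) {
  int32_t mask = (1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One output lane of the epilogue: shift, multiply, shift back, add zp, clamp.
int8_t RequantizeModel(int32_t acc, int left_shift, int32_t multiplier, int right_shift,
                       int32_t out_zp, int32_t act_min, int32_t act_max) {
  int32_t v = SqrdmulhModel(acc << left_shift, multiplier);  // asm uses saturating sqshl; plain shift kept for brevity
  v = RoundingRightShiftModel(v, -right_shift);              // negate the (negative) srshl amount
  v += out_zp;
  v = std::min(std::max(v, act_min), act_max);
  return static_cast<int8_t>(v);
}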
@@ -0,0 +1,194 @@
#ifdef __aarch64__
.text
.align 5
.global MatMulR4Int8Neon64
#ifndef __APPLE__
.type MatMulR4Int8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------\
// |v4.b[0]  v5.b[0]  v6.b[0]  v7.b[0]      |
// |  ...      ...      ...      ...        |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15]     |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\   /-----------------------------------------\
// |v0.b[0] ... v0.b[15] |   |v16.4s  v17.4s  v18.4s  v19.4s           |
// |v1.b[0] ... v1.b[15] |   |v20.4s  v21.4s  v22.4s  v23.4s           |
// |v2.b[0] ... v2.b[15] |   |v24.4s  v25.4s  v26.4s  v27.4s           |
// |v3.b[0] ... v3.b[15] |   |v28.4s  v29.4s  v30.4s  v31.4s           |
// \---------------------/   \-----------------------------------------/
// int32 accumulators 4x4 block
// void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
//                         const int *input_sum, const int *bias)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias

MatMulR4Int8Neon64:
    sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

    mov w15, #0       // b col index
    mov w16, #0       // a row index
    mov w17, #4       // sizeof(int8) * 4
    mul w12, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16

L1:
    cmp w15, w4
    beq End1
    mov w16, #0  // reset a row index
    mov x17, x0  // reload a ptr
    mov x13, x6  // reload a_sums ptr
L2:
    cmp w16, w3
    beq End2
    mov x18, x1  // reload b ptr
    mov x10, x7  // reload bias ptr
    mov w11, w5  // reload depth
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
    dup v24.4s, wzr
    dup v25.4s, wzr
    dup v26.4s, wzr
    dup v27.4s, wzr
    dup v28.4s, wzr
    dup v29.4s, wzr
    dup v30.4s, wzr
    dup v31.4s, wzr
L3:
    cmp w11, #0
    beq End3
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b}, [x17], #16
    ld1 {v2.16b}, [x17], #16
    ld1 {v3.16b}, [x17], #16
    ld1 {v4.16b}, [x18], #16
    ld1 {v5.16b}, [x18], #16
    ld1 {v6.16b}, [x18], #16
    ld1 {v7.16b}, [x18], #16
    smull v8.8h, v4.8b, v0.8b
    smull v9.8h, v5.8b, v0.8b
    smull v10.8h, v6.8b, v0.8b
    smull v11.8h, v7.8b, v0.8b
    smull v12.8h, v4.8b, v1.8b
    smull v13.8h, v5.8b, v1.8b
    smull v14.8h, v6.8b, v1.8b
    smull v15.8h, v7.8b, v1.8b
    smlal2 v8.8h, v4.16b, v0.16b
    smlal2 v9.8h, v5.16b, v0.16b
    smlal2 v10.8h, v6.16b, v0.16b
    smlal2 v11.8h, v7.16b, v0.16b
    smlal2 v12.8h, v4.16b, v1.16b
    smlal2 v13.8h, v5.16b, v1.16b
    smlal2 v14.8h, v6.16b, v1.16b
    smlal2 v15.8h, v7.16b, v1.16b
    sadalp v16.4s, v8.8h
    sadalp v17.4s, v9.8h
    sadalp v18.4s, v10.8h
    sadalp v19.4s, v11.8h
    sadalp v20.4s, v12.8h
    sadalp v21.4s, v13.8h
    sadalp v22.4s, v14.8h
    sadalp v23.4s, v15.8h
    smull v8.8h, v4.8b, v2.8b
    smull v9.8h, v5.8b, v2.8b
    smull v10.8h, v6.8b, v2.8b
    smull v11.8h, v7.8b, v2.8b
    smull v12.8h, v4.8b, v3.8b
    smull v13.8h, v5.8b, v3.8b
    smull v14.8h, v6.8b, v3.8b
    smull v15.8h, v7.8b, v3.8b
    smlal2 v8.8h, v4.16b, v2.16b
    smlal2 v9.8h, v5.16b, v2.16b
    smlal2 v10.8h, v6.16b, v2.16b
    smlal2 v11.8h, v7.16b, v2.16b
    smlal2 v12.8h, v4.16b, v3.16b
    smlal2 v13.8h, v5.16b, v3.16b
    smlal2 v14.8h, v6.16b, v3.16b
    smlal2 v15.8h, v7.16b, v3.16b
    sadalp v24.4s, v8.8h
    sadalp v25.4s, v9.8h
    sadalp v26.4s, v10.8h
    sadalp v27.4s, v11.8h
    sadalp v28.4s, v12.8h
    sadalp v29.4s, v13.8h
    sadalp v30.4s, v14.8h
    sadalp v31.4s, v15.8h
    subs w11, w11, #16  // depth -= 16
    b L3

End3:
    addp v16.4s, v16.4s, v17.4s
    addp v18.4s, v18.4s, v19.4s
    addp v20.4s, v20.4s, v21.4s
    addp v22.4s, v22.4s, v23.4s
    addp v24.4s, v24.4s, v25.4s
    addp v26.4s, v26.4s, v27.4s
    addp v28.4s, v28.4s, v29.4s
    addp v30.4s, v30.4s, v31.4s
    addp v16.4s, v16.4s, v18.4s
    addp v17.4s, v20.4s, v22.4s
    addp v18.4s, v24.4s, v26.4s
    addp v19.4s, v28.4s, v30.4s

    // Add (Bias + Depth * Za * Zb - Za * Bsums)
    ld1 {v15.4s}, [x10], #16
    add v16.4s, v16.4s, v15.4s
    add v17.4s, v17.4s, v15.4s
    add v18.4s, v18.4s, v15.4s
    add v19.4s, v19.4s, v15.4s
    // Subtract (Asums * Zb)
    ld1 {v14.4s}, [x13], #16
    dup v20.4s, v14.s[0]
    dup v21.4s, v14.s[1]
    dup v22.4s, v14.s[2]
    dup v23.4s, v14.s[3]
    sub v16.4s, v16.4s, v20.4s
    sub v17.4s, v17.4s, v21.4s
    sub v18.4s, v18.4s, v22.4s
    sub v19.4s, v19.4s, v23.4s
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
    add w16, w16, #4  // a row index += 4
    b L2

End2:
    add w15, w15, #4  // b col index += 4
    add x1, x1, x12   // b ptr += stride
    add x7, x7, #16   // bias ptr += stride
    b L1

End1:
    sub sp, sp, #128
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ret
#endif
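For reference, the two epilogue corrections ("Add (Bias + Depth * Za * Zb - Za * Bsums)" and "Subtract (Asums * Zb)") fall out of expanding the zero-point-adjusted product. With $Z_a$, $Z_b$ the input and weight zero points and $D$ = deep16, each int32 output entry is:

\[
\sum_{k=1}^{D}(a_{rk}-Z_a)(b_{kc}-Z_b) + \mathrm{bias}_c
 = \underbrace{\sum_{k} a_{rk}\,b_{kc}}_{\text{int8 dot products (the L3 loop)}}
 + \underbrace{\Bigl(\mathrm{bias}_c + D\,Z_a Z_b - Z_a\textstyle\sum_{k} b_{kc}\Bigr)}_{\text{folded into the x7 vector by RowMajor2Bbias}}
 - \underbrace{Z_b\textstyle\sum_{k} a_{rk}}_{\text{precomputed per row by RowMajor2Asums}}
\]

Both correction terms depend on only one matrix each, which is why they can be hoisted out of the inner loop and passed in as the x6/x7 vectors.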
@@ -0,0 +1,157 @@
#ifdef __aarch64__
.text
.align 5
.global MatMulOptR4Int8Neon64
#ifndef __APPLE__
.type MatMulOptR4Int8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------\
// |v4.b[0]  v5.b[0]  v6.b[0]  v7.b[0]      |
// |  ...      ...      ...      ...        |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15]     |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\   /-----------------------------------------\
// |v0.b[0] ... v0.b[15] |   |v16.4s  v17.4s  v18.4s  v19.4s           |
// |v1.b[0] ... v1.b[15] |   |v20.4s  v21.4s  v22.4s  v23.4s           |
// |v2.b[0] ... v2.b[15] |   |v24.4s  v25.4s  v26.4s  v27.4s           |
// |v3.b[0] ... v3.b[15] |   |v28.4s  v29.4s  v30.4s  v31.4s           |
// \---------------------/   \-----------------------------------------/
// int32 accumulators 4x4 block
// void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
//                            const int *input_sum, const int *bias)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias

MatMulOptR4Int8Neon64:
    sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

    mov w15, #0       // b col index
    mov w16, #0       // a row index
    mov w17, #4       // sizeof(int8) * 4
    mul w12, w5, w17  // the stride of a/b: sizeof(int8) * 4 * deep16

L1:
    cmp w15, w4
    beq End1
    mov w16, #0  // reset a row index
    mov x17, x0  // reload a ptr
    mov x13, x6  // reload a_sums ptr
L2:
    cmp w16, w3
    beq End2
    mov x18, x1  // reload b ptr
    mov x10, x7  // reload bias ptr
    mov w11, w5  // reload depth
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
    dup v24.4s, wzr
    dup v25.4s, wzr
    dup v26.4s, wzr
    dup v27.4s, wzr
    dup v28.4s, wzr
    dup v29.4s, wzr
    dup v30.4s, wzr
    dup v31.4s, wzr
L3:
    cmp w11, #0
    beq End3
    ld1 {v0.16b}, [x17], #16
    ld1 {v1.16b}, [x17], #16
    ld1 {v2.16b}, [x17], #16
    ld1 {v3.16b}, [x17], #16
    ld1 {v4.16b}, [x18], #16
    ld1 {v5.16b}, [x18], #16
    ld1 {v6.16b}, [x18], #16
    ld1 {v7.16b}, [x18], #16
    sdot v16.4s, v4.16b, v0.16b
    sdot v17.4s, v5.16b, v0.16b
    sdot v18.4s, v6.16b, v0.16b
    sdot v19.4s, v7.16b, v0.16b
    sdot v20.4s, v4.16b, v1.16b
    sdot v21.4s, v5.16b, v1.16b
    sdot v22.4s, v6.16b, v1.16b
    sdot v23.4s, v7.16b, v1.16b
    sdot v24.4s, v4.16b, v2.16b
    sdot v25.4s, v5.16b, v2.16b
    sdot v26.4s, v6.16b, v2.16b
    sdot v27.4s, v7.16b, v2.16b
    sdot v28.4s, v4.16b, v3.16b
    sdot v29.4s, v5.16b, v3.16b
    sdot v30.4s, v6.16b, v3.16b
    sdot v31.4s, v7.16b, v3.16b
    subs w11, w11, #16  // depth -= 16
    b L3

End3:
    addp v16.4s, v16.4s, v17.4s
    addp v18.4s, v18.4s, v19.4s
    addp v20.4s, v20.4s, v21.4s
    addp v22.4s, v22.4s, v23.4s
    addp v24.4s, v24.4s, v25.4s
    addp v26.4s, v26.4s, v27.4s
    addp v28.4s, v28.4s, v29.4s
    addp v30.4s, v30.4s, v31.4s
    addp v16.4s, v16.4s, v18.4s
    addp v17.4s, v20.4s, v22.4s
    addp v18.4s, v24.4s, v26.4s
    addp v19.4s, v28.4s, v30.4s

    // Add (Bias + Depth * Za * Zb - Za * Bsums)
    ld1 {v15.4s}, [x10], #16
    add v16.4s, v16.4s, v15.4s
    add v17.4s, v17.4s, v15.4s
    add v18.4s, v18.4s, v15.4s
    add v19.4s, v19.4s, v15.4s
    // Subtract (Asums * Zb)
    ld1 {v14.4s}, [x13], #16
    dup v20.4s, v14.s[0]
    dup v21.4s, v14.s[1]
    dup v22.4s, v14.s[2]
    dup v23.4s, v14.s[3]
    sub v16.4s, v16.4s, v20.4s
    sub v17.4s, v17.4s, v21.4s
    sub v18.4s, v18.4s, v22.4s
    sub v19.4s, v19.4s, v23.4s
    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
    add w16, w16, #4  // a row index += 4
    b L2

End2:
    add w15, w15, #4  // b col index += 4
    add x1, x1, x12   // b ptr += stride
    add x7, x7, #16   // bias ptr += stride
    b L1

End1:
    sub sp, sp, #128
    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
    ret
#endif
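This kernel is the reason for the dlopen/dlsym machinery above: sdot belongs to the ARMv8.2 dot-product extension, so it ships in a separately loaded library rather than the base build. Each sdot here replaces a smull/smlal2/sadalp triple from the previous two kernels. A scalar model of the vector form used here, as a sketch:

#include <cstdint>
#include <cstdio>

// Models "sdot vd.4s, vn.16b, vm.16b": each of the 4 int32 lanes accumulates
// the dot product of the corresponding group of 4 int8 pairs.
void SdotModel(int32_t acc[4], const int8_t n[16], const int8_t m[16]) {
  for (int lane = 0; lane < 4; ++lane) {
    for (int i = 0; i < 4; ++i) {
      acc[lane] += static_cast<int32_t>(n[lane * 4 + i]) * static_cast<int32_t>(m[lane * 4 + i]);
    }
  }
}

int main() {
  int8_t n[16], m[16];
  int32_t acc[4] = {0, 0, 0, 0};
  for (int i = 0; i < 16; ++i) { n[i] = static_cast<int8_t>(i); m[i] = 1; }
  SdotModel(acc, n, m);  // {0+1+2+3, 4+5+6+7, ...} = {6, 22, 38, 54}
  std::printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]);
  return 0;
}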
@@ -269,7 +269,7 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
             int stride, bool write_nhwc) {
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
   MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
 #else
   MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
@@ -31,7 +31,7 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
-#ifdef __aarch64__
+#ifdef ENABLE_ARM64
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                        int col, size_t stride, bool write_nhwc);
 #endif
@@ -197,7 +197,7 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32
                size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param,
                MATMUL_OPT_R4_FUNC matmul_func) {
   if (matmul_func != NULL) {
-    matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep);
+    matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum);
   } else {
     /* todo normal int8 deconv */
   }
@@ -74,8 +74,8 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
   }
 }
-void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
-                     size_t row_4, size_t col_4, size_t deep_16) {
+void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                     const int *input_sum, const int *bias) {
   /* row4x16-major * row16x4-major => row4x4-major */
   for (int r = 0; r < row_4; r++) {
     for (int c = 0; c < col_4; c++) {
@@ -96,3 +96,61 @@ void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int32
   }
   return;
 }
+#ifdef ENABLE_ARM64
+void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
+  int stride = sizeof(int8_t) * 16 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_n = r / 4 * (col_16 / 16) + c / 16;
+      int src_idx = r * col + c;
+      dst[stride * stride_n + r % 4 * 16 + c % 16] = src[src_idx];
+    }
+  }
+}
+void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) {
+  int stride = sizeof(int8_t) * 16 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_n = c / 4 * (row_16 / 16) + r / 16;
+      int src_idx = r * col + c;
+      dst[stride * stride_n + c % 4 * 16 + r % 16] = src[src_idx];
+    }
+  }
+}
+void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int src_idx = r * col + c;
+      dst[r] += a[src_idx];
+    }
+    dst[r] *= b_zp;
+  }
+}
+void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
+  for (int c = 0; c < col; ++c) {
+    for (int r = 0; r < row; ++r) {
+      int src_idx = r * col + c;
+      dst[c] += b[src_idx];
+    }
+    dst[c] = row * a_zp * b_zp - a_zp * dst[c];
+    if (bias) {
+      dst[c] += bias[c];
+    }
+  }
+}
+void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow) {
+  int stride = sizeof(int8_t) * 4 * 4;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < cow; ++c) {
+      int sride_n = c / 4 * (row4 / 4) + r / 4;
+      int dst_idx = r * cow + c;
+      dst[dst_idx] = src[stride * sride_n + r % 4 * 4 + c % 4];
+    }
+  }
+}
+#endif
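RowMajor2Row4x16Major lays the matrix out as consecutive 4x16 tiles: stride_n picks the tile (advancing along the 16-wide depth axis first, then down by four rows), and r % 4 * 16 + c % 16 is the offset inside it. A small standalone check of that index math (the helper function and the 8x32 shape are hypothetical, written just for this sketch):

#include <cstdio>

// Hypothetical standalone copy of the destination-index computation in
// RowMajor2Row4x16Major, for an 8x32 int8 matrix packed with col_16 = 32.
int PackedIndex(int r, int c, int col_16) {
  int stride = 16 * 4;                             // one 4x16 tile = 64 bytes
  int stride_n = r / 4 * (col_16 / 16) + c / 16;   // tile number
  return stride * stride_n + r % 4 * 16 + c % 16;  // offset inside the tile
}

int main() {
  std::printf("%d\n", PackedIndex(0, 0, 32));   // 0   : first tile, first byte
  std::printf("%d\n", PackedIndex(0, 16, 32));  // 64  : next tile along depth
  std::printf("%d\n", PackedIndex(3, 15, 32));  // 63  : last byte of tile 0
  std::printf("%d\n", PackedIndex(4, 0, 32));   // 128 : tile row below
  return 0;
}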
@@ -23,14 +23,29 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
-                const int32_t a_zp, const int32_t b_zp);
-void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
-                     size_t row_4, size_t col_4, size_t deep_16);
+void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
+                const int a_zp, const int b_zp);
+void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                     const int *input_sum, const int *bias);
 void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
+#ifdef ENABLE_ARM64
+void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
+void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
+void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
+void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
+// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
+void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
+                      const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
+                      int right_shift);
+void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
+                        const int *input_sum, const int *bias);
+#endif
 #ifdef __cplusplus
 }
 #endif
@@ -19,8 +19,8 @@
 #include "nnacl/op_base.h"
-typedef void (*MATMUL_OPT_R4_FUNC)(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias,
-                                   const int32_t *input_sum, size_t row_4, size_t col_4, size_t deep_16);
+typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
+                                   const int *input_sum, const int *bias);
 typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
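Reordering MatMulOptR4Int8's parameters in lockstep with this typedef is what lets one MATMUL_OPT_R4_FUNC pointer hold either the portable C kernel or the dlsym'd handler. A sketch of that dispatch (stub bodies stand in for the real kernels, which live in other translation units):

#include <cstdint>

typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4,
                                   int deep_16, const int *input_sum, const int *bias);

// Stand-ins for the two real kernels; both now have the exact same shape.
static void MatMulOptR4Int8(const int8_t *, const int8_t *, int *, int, int, int, const int *, const int *) {}
static void MatMulR4Int8_optimize_handler(const int8_t *, const int8_t *, int *, int, int, int, const int *,
                                          const int *) {}

MATMUL_OPT_R4_FUNC PickMatmul(bool have_dotprod) {
  // Either assignment type-checks now that the signatures agree; before this
  // change the argument orders differed, so the call sites had to be swapped too.
  return have_dotprod ? MatMulR4Int8_optimize_handler : MatMulOptR4Int8;
}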
@@ -23,6 +23,10 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
                                      size_t ksize, size_t ic4, size_t output_channel, size_t offset,
                                      const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
                                      size_t out_multiplier, size_t shift_before, size_t shift_after);
+extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
+                                  const int *input_sum, const int *bias);
 #ifdef __cplusplus
 }
 #endif
@@ -35,4 +39,9 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
   return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
                                   act_max, out_zp, out_multiplier, shift_before, shift_after);
 }
+void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
+                                   const int *input_sum, const int *bias) {
+  return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
+}
 #endif
@@ -64,7 +64,7 @@ class OptimizeModule {
     optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
 #endif
     if (optimized_op_handler_ == nullptr) {
-      printf("Open optimize shared library failed.\n");
+      printf("Open optimize shared library failed: %s\n", dlerror());
     }
   }
@@ -26,9 +26,7 @@ const double dNormalizer = 0x1p54;
 const int dNormalizerBias = 54;
 const int iMantissaBits = 31;
-void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
-                                      int *right_shift) {
+void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift) {
   if (quantized_multiplier == NULL || right_shift == NULL) {
     return;
   }
@@ -55,10 +53,9 @@ uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return roun
 int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
-void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini,
-                                       int *maxi) {
-  int32_t min = CHAR_MIN;
-  int32_t max = CHAR_MAX;
+void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, int *maxi) {
+  int32_t min = INT8_MIN;
+  int32_t max = INT8_MAX;
   int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
   int32_t quantized_six = QuantizeToInt8(6, scale, zp);
   if (is_relu) {
@@ -77,8 +74,8 @@ void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp,
 void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
     int q = (int)round(input_data[i] / scale + zero_point);
-    q = q > CHAR_MAX ? CHAR_MAX : q;
-    q = q < CHAR_MIN ? CHAR_MIN : q;
+    q = q > SCHAR_MAX ? SCHAR_MAX : q;
+    q = q < SCHAR_MIN ? SCHAR_MIN : q;
     output_data[i] = (int8_t)q;
   }
 }
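The switch from CHAR_MIN/CHAR_MAX to SCHAR_MIN/SCHAR_MAX (and INT8_MIN/INT8_MAX) is not cosmetic: on platforms where plain char is unsigned, which is the default ABI on ARM Linux, CHAR_MIN is 0 and CHAR_MAX is 255, so the old clamp could never produce a negative quantized value. A quick demonstration of the corrected clamp (the scale and zero point are assumed example values):

#include <cstdio>
#include <climits>
#include <cmath>

int main() {
  float scale = 0.05f;   // assumed example value
  int zero_point = 10;   // assumed example value
  float samples[] = {-8.0f, -1.0f, 0.0f, 5.0f, 8.0f};
  for (float v : samples) {
    int q = static_cast<int>(std::round(v / scale + zero_point));
    q = q > SCHAR_MAX ? SCHAR_MAX : q;  // 127 on every platform, unlike CHAR_MAX
    q = q < SCHAR_MIN ? SCHAR_MIN : q;  // -128, whereas CHAR_MIN may be 0
    std::printf("%6.2f -> %4d\n", v, q);  // -8.00 -> -128, 8.00 -> 127, etc.
  }
  return 0;
}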
@@ -270,7 +270,7 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
                                 7894,  -51,   0,     0,    0, 0, -4775,  -29785, 0, 0, -12597, 4088,  0, 0, -17420, 1815,
                                 0,     0,    15796,  3101, 0, 0, -37969, -10818, 0, 0, 12714,  -7827, 0, 0,
                                 0,     0,    0,      0,    0, 0, 0, 0};
-  MatMulOptR4Int8(tmp_output, packed_a, packed_b, weight_sum, input_sum, 12, 24, 16);
+  MatMulOptR4Int8(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum);
   CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0);
 }
@@ -116,7 +116,6 @@ TEST_F(TestMatmulInt8, mmint8) {
   Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
              fout);
   CompareOutputData(fout, correct, 6, 0.3);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;