From f93a055c3043bbe2ceba3e1b292d79fedb4c7f05 Mon Sep 17 00:00:00 2001 From: ling Date: Mon, 16 Nov 2020 09:15:12 +0800 Subject: [PATCH] [MSLITE] add boardcast optimize --- mindspore/lite/nnacl/int8/add_int8.c | 143 ++++++++++++++++++ mindspore/lite/nnacl/int8/add_int8.h | 1 + .../src/runtime/kernel/arm/int8/add_int8.cc | 122 ++++++++++----- .../src/runtime/kernel/arm/int8/add_int8.h | 7 +- 4 files changed, 232 insertions(+), 41 deletions(-) diff --git a/mindspore/lite/nnacl/int8/add_int8.c b/mindspore/lite/nnacl/int8/add_int8.c index e828ec6d29..a4341939e2 100644 --- a/mindspore/lite/nnacl/int8/add_int8.c +++ b/mindspore/lite/nnacl/int8/add_int8.c @@ -161,3 +161,146 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz } return; } + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params) { + int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_); + int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_); + int index = 0; + +#ifdef ENABLE_ARM + const int8x16_t in1_src = vdupq_n_s8(element_in); + + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vac = vdupq_n_s8(params->max_); + + const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_); + const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); + const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); + + const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_); + const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + for (; index <= size - 16; index += 16) { + const int8x16_t in0_src = vld1q_s8(ptr_in + index); + + const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src)); + const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src)); + const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src)); + const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src)); + + const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec); + const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec); + const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec); + const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec); + + int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low)); + int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low)); + int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high)); + int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high)); + int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low)); + int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low)); + int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); + int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); + + // Apply left shift + in0_1 = vmulq_s32(in0_1, in0_left_vec); + in0_2 = vmulq_s32(in0_2, in0_left_vec); + in0_3 = vmulq_s32(in0_3, in0_left_vec); + in0_4 = vmulq_s32(in0_4, in0_left_vec); + in1_1 = vmulq_s32(in1_1, in1_left_vec); + in1_2 = vmulq_s32(in1_2, in1_left_vec); + in1_3 = vmulq_s32(in1_3, in1_left_vec); + in1_4 = vmulq_s32(in1_4, in1_left_vec); + + // Apply the fixed-point part of the multiplier. + in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_); + in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_); + in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_); + in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_); + in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_); + in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_); + in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_); + in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_); + + // Apply right shift + in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31)); + in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31)); + in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31)); + in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31)); + in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31)); + in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31)); + in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31)); + in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31)); + + in0_1 = vrshlq_s32(in0_1, in0_right_vec); + in0_2 = vrshlq_s32(in0_2, in0_right_vec); + in0_3 = vrshlq_s32(in0_3, in0_right_vec); + in0_4 = vrshlq_s32(in0_4, in0_right_vec); + in1_1 = vrshlq_s32(in1_1, in1_right_vec); + in1_2 = vrshlq_s32(in1_2, in1_right_vec); + in1_3 = vrshlq_s32(in1_3, in1_right_vec); + in1_4 = vrshlq_s32(in1_4, in1_right_vec); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(in0_1, in1_1); + int32x4_t out2 = vaddq_s32(in0_2, in1_2); + int32x4_t out3 = vaddq_s32(in0_3, in1_3); + int32x4_t out4 = vaddq_s32(in0_4, in1_4); + + // Apply left shift + out1 = vshlq_s32(out1, out_left_vec); + out2 = vshlq_s32(out2, out_left_vec); + out3 = vshlq_s32(out3, out_left_vec); + out4 = vshlq_s32(out4, out_left_vec); + + // Apply the fixed-point part of the multiplier. + out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); + out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); + out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); + out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); + + // Apply right shift + out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); + out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); + out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); + out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31)); + + out1 = vrshlq_s32(out1, out_right_vec); + out2 = vrshlq_s32(out2, out_right_vec); + out3 = vrshlq_s32(out3, out_right_vec); + out4 = vrshlq_s32(out4, out_right_vec); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + + for (; index < size; index++) { + const int32_t in0_left = (ptr_in[index] + params->in0_zp_) * in0_left_shift; + const int32_t in1_left = (element_in + params->in1_zp_) * in1_left_shift; + const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_); + const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} diff --git a/mindspore/lite/nnacl/int8/add_int8.h b/mindspore/lite/nnacl/int8/add_int8.h index 127be10383..44e0d066d0 100644 --- a/mindspore/lite/nnacl/int8/add_int8.h +++ b/mindspore/lite/nnacl/int8/add_int8.h @@ -46,6 +46,7 @@ extern "C" { #endif void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params); +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params); #ifdef __cplusplus } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc index ff3114293c..84c31c38a3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -34,7 +34,6 @@ int QuantizedAddCPUKernel::Init() { auto *input0 = in_tensors_.at(0); auto *input1 = in_tensors_.at(1); auto *output = out_tensors_.at(0); - auto act = arith_para_->activation_type_; para_.in0_zp_ = input0->GetQuantParams().front().zeroPoint * -1; para_.in1_zp_ = input1->GetQuantParams().front().zeroPoint * -1; @@ -62,6 +61,7 @@ int QuantizedAddCPUKernel::Init() { para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0; para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0; + auto act = arith_para_->activation_type_; CalculateActivationRangeQuantized(act == ActType_Relu, act == ActType_Relu6, 0, 1, ¶_.min_, ¶_.max_); if (!InferShapeDone()) { @@ -71,11 +71,47 @@ int QuantizedAddCPUKernel::Init() { } int QuantizedAddCPUKernel::ReSize() { - elements_num_ = out_tensors_.at(0)->ElementsNum(); - arith_para_->broadcasting_ = in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum(); + auto *input0 = in_tensors_.at(0); + auto *input1 = in_tensors_.at(1); + auto *output = out_tensors_.at(0); + support_opt_add_ = (input0->ElementsNum() == 1) || (input1->ElementsNum() == 1); + if (support_opt_add_) { + arith_para_->broadcasting_ = false; + } + elements_num_ = output->ElementsNum(); thread_count_ = MSMIN(elements_num_, op_parameter_->thread_num_); - thread_stride_ = UP_DIV(elements_num_, thread_count_); + + arith_para_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); + arith_para_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); + arith_para_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + + memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int)); + memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int)); + memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int)); + + if (arith_para_->broadcasting_) { + size_t break_pos_ = 0; + for (auto i = arith_para_->ndim_ - 1; i >= 0; --i) { + if (arith_para_->in_shape0_[i] != arith_para_->in_shape1_[i]) { + break_pos_ = i; + break; + } + } + in_size_ = 1; + out_size_ = 1; + for (size_t i = 0; i < arith_para_->ndim_; i++) { + if (i > break_pos_) { + in_size_ *= arith_para_->out_shape_[i]; + } else { + out_size_ *= arith_para_->out_shape_[i]; + } + } + + ComputeStrides(arith_para_->in_shape0_, arith_para_->in_strides0_, arith_para_->ndim_); + ComputeStrides(arith_para_->in_shape1_, arith_para_->in_strides1_, arith_para_->ndim_); + ComputeStrides(arith_para_->out_shape_, arith_para_->out_strides_, arith_para_->ndim_); + } return RET_OK; } @@ -85,54 +121,60 @@ int AddInt8Run(void *cdata, int task_id) { return RET_OK; } +void QuantizedAddCPUKernel::BroadcastRun(int task_id) { + int stride = UP_DIV(out_size_, thread_count_); + int real_out_count = MSMIN(stride, out_size_ - stride * task_id); + if (real_out_count <= 0) { + return; + } + + int8_t *const_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input1_data_ : input0_data_; + int8_t *offset_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input0_data_ : input1_data_; + offset_in += task_id * stride * in_size_; + int8_t *cur_out = output_data_ + task_id * stride * in_size_; + + for (int i = 0; i < real_out_count; i++) { + AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, ¶_); + } + return; +} + int QuantizedAddCPUKernel::DoExecute(int task_id) { - int rest_count = elements_num_ - task_id * thread_stride_; - int real_count = MSMIN(thread_stride_, rest_count); - if (real_count <= 0) { + /* need broadcast */ + if (arith_para_->broadcasting_) { + BroadcastRun(task_id); return RET_OK; } - int8_t *cur_input0_data = input0_data_ + task_id * thread_stride_; - int8_t *cur_input1_data = input1_data_ + task_id * thread_stride_; - int8_t *cur_output_data = output_data_ + task_id * thread_stride_; + /* no need broadcast */ + int stride = UP_DIV(elements_num_, thread_count_); + int rest_count = elements_num_ - task_id * stride; + int real_count = MSMIN(stride, rest_count); + if (real_count <= 0) { + return RET_OK; + } + int8_t *cur_in0 = input0_data_ + stride * task_id; + int8_t *cur_in1 = input1_data_ + stride * task_id; + int8_t *cur_out = output_data_ + stride * task_id; + if (support_opt_add_) { + int8_t *ptr_in = arith_para_->in_elements_num0_ == 1 ? cur_in1 : cur_in0; + int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0]; + AddOptInt8(ptr_in, element_in, cur_out, rest_count, ¶_); + } else { + AddInt8(cur_in0, cur_in1, cur_out, rest_count, ¶_); + } - AddInt8(cur_input0_data, cur_input1_data, cur_output_data, real_count, ¶_); return RET_OK; } int QuantizedAddCPUKernel::Run() { - int8_t *src_in0 = static_cast(in_tensors_.at(0)->data_c()); - int8_t *src_in1 = static_cast(in_tensors_.at(1)->data_c()); + input0_data_ = static_cast(in_tensors_.at(0)->data_c()); + input1_data_ = static_cast(in_tensors_.at(1)->data_c()); output_data_ = static_cast(out_tensors_.at(0)->data_c()); - if (arith_para_->broadcasting_) { - input0_data_ = static_cast(context_->allocator->Malloc(elements_num_ * sizeof(int8_t))); - if (input0_data_ == nullptr) { - MS_LOG(ERROR) << "malloc input0_data_ failed."; - return RET_ERROR; - } - input1_data_ = static_cast(context_->allocator->Malloc(elements_num_ * sizeof(int8_t))); - if (input1_data_ == nullptr) { - context_->allocator->Free(input0_data_); - input0_data_ = nullptr; - MS_LOG(ERROR) << "malloc input1_data_ failed."; - return RET_ERROR; - } - - TileDimensionsInt8(src_in0, src_in1, input0_data_, input1_data_, arith_para_); - auto ret = ParallelLaunch(context_->thread_pool_, AddInt8Run, this, thread_count_); - - context_->allocator->Free(input0_data_); - context_->allocator->Free(input1_data_); - input0_data_ = nullptr; - input1_data_ = nullptr; - return ret; - } + ParallelLaunch(this->context_->thread_pool_, AddInt8Run, this, thread_count_); - input0_data_ = src_in0; - input1_data_ = src_in1; - auto ret = ParallelLaunch(this->context_->thread_pool_, AddInt8Run, this, thread_count_); - return ret; + return RET_OK; } kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h index 4324bb3d04..8834387949 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h @@ -38,12 +38,17 @@ class QuantizedAddCPUKernel : public LiteKernel { int Run() override; int DoExecute(int tId); + private: + void BroadcastRun(int task_id); + private: AddQuantParameter para_; ArithmeticParameter *arith_para_ = nullptr; + int in_size_ = 0; + int out_size_ = 0; int thread_count_ = 1; - int thread_stride_ = 0; int elements_num_ = 0; + bool support_opt_add_ = false; int8_t *input0_data_ = nullptr; int8_t *input1_data_ = nullptr; int8_t *output_data_ = nullptr;