Browse Source

!8611 [MSLITE] add boardcast optimize

From: @ling_qiao_min
Reviewed-by: @hangangqiang
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f1cc7054f8
4 changed files with 232 additions and 41 deletions
  1. +143
    -0
      mindspore/lite/nnacl/int8/add_int8.c
  2. +1
    -0
      mindspore/lite/nnacl/int8/add_int8.h
  3. +82
    -40
      mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc
  4. +6
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h

+ 143
- 0
mindspore/lite/nnacl/int8/add_int8.c View File

@@ -161,3 +161,146 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz
} }
return; return;
} }

void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params) {
int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_);
int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_);
int index = 0;

#ifdef ENABLE_ARM
const int8x16_t in1_src = vdupq_n_s8(element_in);

const int8x16_t min_vec = vdupq_n_s8(params->min_);
const int8x16_t max_vac = vdupq_n_s8(params->max_);

const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_);
const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_);
const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_);

const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift);
const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift);

const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_);
const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_);

const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_);
const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_);

for (; index <= size - 16; index += 16) {
const int8x16_t in0_src = vld1q_s8(ptr_in + index);

const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src));
const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src));
const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src));
const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src));

const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec);
const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec);
const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec);
const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec);

int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low));
int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low));
int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high));
int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high));
int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low));
int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low));
int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high));
int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high));

// Apply left shift
in0_1 = vmulq_s32(in0_1, in0_left_vec);
in0_2 = vmulq_s32(in0_2, in0_left_vec);
in0_3 = vmulq_s32(in0_3, in0_left_vec);
in0_4 = vmulq_s32(in0_4, in0_left_vec);
in1_1 = vmulq_s32(in1_1, in1_left_vec);
in1_2 = vmulq_s32(in1_2, in1_left_vec);
in1_3 = vmulq_s32(in1_3, in1_left_vec);
in1_4 = vmulq_s32(in1_4, in1_left_vec);

// Apply the fixed-point part of the multiplier.
in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_);
in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_);
in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_);
in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_);
in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_);
in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_);
in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_);
in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_);

// Apply right shift
in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31));
in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31));
in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31));
in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31));
in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31));
in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31));
in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31));
in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31));

in0_1 = vrshlq_s32(in0_1, in0_right_vec);
in0_2 = vrshlq_s32(in0_2, in0_right_vec);
in0_3 = vrshlq_s32(in0_3, in0_right_vec);
in0_4 = vrshlq_s32(in0_4, in0_right_vec);
in1_1 = vrshlq_s32(in1_1, in1_right_vec);
in1_2 = vrshlq_s32(in1_2, in1_right_vec);
in1_3 = vrshlq_s32(in1_3, in1_right_vec);
in1_4 = vrshlq_s32(in1_4, in1_right_vec);

/* calculate output */
int32x4_t out1 = vaddq_s32(in0_1, in1_1);
int32x4_t out2 = vaddq_s32(in0_2, in1_2);
int32x4_t out3 = vaddq_s32(in0_3, in1_3);
int32x4_t out4 = vaddq_s32(in0_4, in1_4);

// Apply left shift
out1 = vshlq_s32(out1, out_left_vec);
out2 = vshlq_s32(out2, out_left_vec);
out3 = vshlq_s32(out3, out_left_vec);
out4 = vshlq_s32(out4, out_left_vec);

// Apply the fixed-point part of the multiplier.
out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_);
out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_);
out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_);
out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_);

// Apply right shift
out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31));
out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31));
out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31));
out4 = vqaddq_s32(out4, vshrq_n_s32(vandq_s32(out4, out_right_vec), 31));

out1 = vrshlq_s32(out1, out_right_vec);
out2 = vrshlq_s32(out2, out_right_vec);
out3 = vrshlq_s32(out3, out_right_vec);
out4 = vrshlq_s32(out4, out_right_vec);

const int16x4_t out1_s16 = vmovn_s32(out1);
const int16x4_t out2_s16 = vmovn_s32(out2);
const int16x4_t out3_s16 = vmovn_s32(out3);
const int16x4_t out4_s16 = vmovn_s32(out4);

const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec);
const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec);

const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2));
const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vac, out));

vst1q_s8(output + index, int8_out);
}
#endif

for (; index < size; index++) {
const int32_t in0_left = (ptr_in[index] + params->in0_zp_) * in0_left_shift;
const int32_t in1_left = (element_in + params->in1_zp_) * in1_left_shift;
const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_);
const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_);

int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_,
-params->out_right_shift_);
out += params->out_zp_;
output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_));
}
return;
}

+ 1
- 0
mindspore/lite/nnacl/int8/add_int8.h View File

@@ -46,6 +46,7 @@ extern "C" {
#endif #endif


void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params); void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params);
void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params);


#ifdef __cplusplus #ifdef __cplusplus
} }


+ 82
- 40
mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc View File

@@ -34,7 +34,6 @@ int QuantizedAddCPUKernel::Init() {
auto *input0 = in_tensors_.at(0); auto *input0 = in_tensors_.at(0);
auto *input1 = in_tensors_.at(1); auto *input1 = in_tensors_.at(1);
auto *output = out_tensors_.at(0); auto *output = out_tensors_.at(0);
auto act = arith_para_->activation_type_;


para_.in0_zp_ = input0->GetQuantParams().front().zeroPoint * -1; para_.in0_zp_ = input0->GetQuantParams().front().zeroPoint * -1;
para_.in1_zp_ = input1->GetQuantParams().front().zeroPoint * -1; para_.in1_zp_ = input1->GetQuantParams().front().zeroPoint * -1;
@@ -62,6 +61,7 @@ int QuantizedAddCPUKernel::Init() {
para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0; para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0;
para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0; para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0;


auto act = arith_para_->activation_type_;
CalculateActivationRangeQuantized(act == ActType_Relu, act == ActType_Relu6, 0, 1, &para_.min_, &para_.max_); CalculateActivationRangeQuantized(act == ActType_Relu, act == ActType_Relu6, 0, 1, &para_.min_, &para_.max_);


if (!InferShapeDone()) { if (!InferShapeDone()) {
@@ -71,11 +71,47 @@ int QuantizedAddCPUKernel::Init() {
} }


int QuantizedAddCPUKernel::ReSize() { int QuantizedAddCPUKernel::ReSize() {
elements_num_ = out_tensors_.at(0)->ElementsNum();
arith_para_->broadcasting_ = in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum();
auto *input0 = in_tensors_.at(0);
auto *input1 = in_tensors_.at(1);
auto *output = out_tensors_.at(0);
support_opt_add_ = (input0->ElementsNum() == 1) || (input1->ElementsNum() == 1);
if (support_opt_add_) {
arith_para_->broadcasting_ = false;
}


elements_num_ = output->ElementsNum();
thread_count_ = MSMIN(elements_num_, op_parameter_->thread_num_); thread_count_ = MSMIN(elements_num_, op_parameter_->thread_num_);
thread_stride_ = UP_DIV(elements_num_, thread_count_);

arith_para_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arith_para_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arith_para_->out_elements_num_ = out_tensors_[0]->ElementsNum();

memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int));
memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int));
memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int));

if (arith_para_->broadcasting_) {
size_t break_pos_ = 0;
for (auto i = arith_para_->ndim_ - 1; i >= 0; --i) {
if (arith_para_->in_shape0_[i] != arith_para_->in_shape1_[i]) {
break_pos_ = i;
break;
}
}
in_size_ = 1;
out_size_ = 1;
for (size_t i = 0; i < arith_para_->ndim_; i++) {
if (i > break_pos_) {
in_size_ *= arith_para_->out_shape_[i];
} else {
out_size_ *= arith_para_->out_shape_[i];
}
}

ComputeStrides(arith_para_->in_shape0_, arith_para_->in_strides0_, arith_para_->ndim_);
ComputeStrides(arith_para_->in_shape1_, arith_para_->in_strides1_, arith_para_->ndim_);
ComputeStrides(arith_para_->out_shape_, arith_para_->out_strides_, arith_para_->ndim_);
}
return RET_OK; return RET_OK;
} }


@@ -85,54 +121,60 @@ int AddInt8Run(void *cdata, int task_id) {
return RET_OK; return RET_OK;
} }


void QuantizedAddCPUKernel::BroadcastRun(int task_id) {
int stride = UP_DIV(out_size_, thread_count_);
int real_out_count = MSMIN(stride, out_size_ - stride * task_id);
if (real_out_count <= 0) {
return;
}

int8_t *const_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input1_data_ : input0_data_;
int8_t *offset_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input0_data_ : input1_data_;
offset_in += task_id * stride * in_size_;
int8_t *cur_out = output_data_ + task_id * stride * in_size_;

for (int i = 0; i < real_out_count; i++) {
AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, &para_);
}
return;
}

int QuantizedAddCPUKernel::DoExecute(int task_id) { int QuantizedAddCPUKernel::DoExecute(int task_id) {
int rest_count = elements_num_ - task_id * thread_stride_;
int real_count = MSMIN(thread_stride_, rest_count);
if (real_count <= 0) {
/* need broadcast */
if (arith_para_->broadcasting_) {
BroadcastRun(task_id);
return RET_OK; return RET_OK;
} }


int8_t *cur_input0_data = input0_data_ + task_id * thread_stride_;
int8_t *cur_input1_data = input1_data_ + task_id * thread_stride_;
int8_t *cur_output_data = output_data_ + task_id * thread_stride_;
/* no need broadcast */
int stride = UP_DIV(elements_num_, thread_count_);
int rest_count = elements_num_ - task_id * stride;
int real_count = MSMIN(stride, rest_count);
if (real_count <= 0) {
return RET_OK;
}
int8_t *cur_in0 = input0_data_ + stride * task_id;
int8_t *cur_in1 = input1_data_ + stride * task_id;
int8_t *cur_out = output_data_ + stride * task_id;
if (support_opt_add_) {
int8_t *ptr_in = arith_para_->in_elements_num0_ == 1 ? cur_in1 : cur_in0;
int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0];
AddOptInt8(ptr_in, element_in, cur_out, rest_count, &para_);
} else {
AddInt8(cur_in0, cur_in1, cur_out, rest_count, &para_);
}


AddInt8(cur_input0_data, cur_input1_data, cur_output_data, real_count, &para_);
return RET_OK; return RET_OK;
} }


int QuantizedAddCPUKernel::Run() { int QuantizedAddCPUKernel::Run() {
int8_t *src_in0 = static_cast<int8_t *>(in_tensors_.at(0)->data_c());
int8_t *src_in1 = static_cast<int8_t *>(in_tensors_.at(1)->data_c());
input0_data_ = static_cast<int8_t *>(in_tensors_.at(0)->data_c());
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->data_c());
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->data_c()); output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->data_c());


if (arith_para_->broadcasting_) {
input0_data_ = static_cast<int8_t *>(context_->allocator->Malloc(elements_num_ * sizeof(int8_t)));
if (input0_data_ == nullptr) {
MS_LOG(ERROR) << "malloc input0_data_ failed.";
return RET_ERROR;
}
input1_data_ = static_cast<int8_t *>(context_->allocator->Malloc(elements_num_ * sizeof(int8_t)));
if (input1_data_ == nullptr) {
context_->allocator->Free(input0_data_);
input0_data_ = nullptr;
MS_LOG(ERROR) << "malloc input1_data_ failed.";
return RET_ERROR;
}

TileDimensionsInt8(src_in0, src_in1, input0_data_, input1_data_, arith_para_);
auto ret = ParallelLaunch(context_->thread_pool_, AddInt8Run, this, thread_count_);

context_->allocator->Free(input0_data_);
context_->allocator->Free(input1_data_);
input0_data_ = nullptr;
input1_data_ = nullptr;
return ret;
}
ParallelLaunch(this->context_->thread_pool_, AddInt8Run, this, thread_count_);


input0_data_ = src_in0;
input1_data_ = src_in1;
auto ret = ParallelLaunch(this->context_->thread_pool_, AddInt8Run, this, thread_count_);
return ret;
return RET_OK;
} }


kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::Tensor *> &inputs, kernel::LiteKernel *CpuAddInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,


+ 6
- 1
mindspore/lite/src/runtime/kernel/arm/int8/add_int8.h View File

@@ -38,12 +38,17 @@ class QuantizedAddCPUKernel : public LiteKernel {
int Run() override; int Run() override;
int DoExecute(int tId); int DoExecute(int tId);


private:
void BroadcastRun(int task_id);

private: private:
AddQuantParameter para_; AddQuantParameter para_;
ArithmeticParameter *arith_para_ = nullptr; ArithmeticParameter *arith_para_ = nullptr;
int in_size_ = 0;
int out_size_ = 0;
int thread_count_ = 1; int thread_count_ = 1;
int thread_stride_ = 0;
int elements_num_ = 0; int elements_num_ = 0;
bool support_opt_add_ = false;
int8_t *input0_data_ = nullptr; int8_t *input0_data_ = nullptr;
int8_t *input1_data_ = nullptr; int8_t *input1_data_ = nullptr;
int8_t *output_data_ = nullptr; int8_t *output_data_ = nullptr;


Loading…
Cancel
Save