diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c index 1977f49e7c..9d8d88acb5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.c @@ -21,8 +21,9 @@ /*conv depthwise int8 begin*/ void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, - int width, int in_kh_step, int in_kw_step, int kernel_w, int out_multiplier, - int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier, + int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, + bool per_channel) { int tmp_buffer[C4NUM]; for (int i = 0; i < C4NUM; i++) { tmp_buffer[i] = 0; @@ -42,10 +43,18 @@ void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *we src_kh += in_kh_step; weight_kh += kernel_w * C4NUM; } // kernel_h loop + int32_t left = left_shift[0]; + int32_t right = right_shift[0]; + int32_t multiplier = out_multiplier[0]; for (int c = 0; c < C4NUM; c++) { + if (per_channel) { + left = left_shift[c]; + right = right_shift[c]; + multiplier = out_multiplier[c]; + } tmp_buffer[c] += bias[c]; tmp_buffer[c] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), right); tmp_buffer[c] += out_zp; tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); @@ -55,7 +64,8 @@ void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *we void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top, int bottom, int left, int right, const ConvParameter *conv_param, - const SlidingWindowParam *sliding) { + const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift, + bool per_channel) { int8_t *dst_h = dst + top * sliding->out_h_step_; for (int oh = top; oh < bottom; oh++) { int ih = oh * conv_param->stride_h_ - conv_param->pad_h_; @@ -73,12 +83,11 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; - DepthwiseBorderPixelInt8( - dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_, - sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], - conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], - conv_param->conv_quant_arg_.out_act_max_[0]); + DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier, + left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + per_channel); dst_kernel += sliding->block_channel_; } // width loop @@ -89,8 +98,8 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, #ifndef ENABLE_ARM64 void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, - int in_sw_step, int in_kh_step, int in_kw_step, int out_multiplier, int left_shift, - int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max) { + int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift, + int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) { int tmp_buffer[C4NUM]; int8_t *dst_h = dst; const int16_t *src_h = src; @@ -118,11 +127,18 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, weight_kh += kernel_w * C4NUM; } // kernel_h loop // add bias relu + int32_t left = left_shift[0]; + int32_t right = right_shift[0]; + int32_t multiplier = out_multiplier[0]; for (int c = 0; c < C4NUM; c++) { + if (per_channel) { + left = left_shift[c]; + right = right_shift[c]; + multiplier = out_multiplier[c]; + } tmp_buffer[c] += bias[c]; tmp_buffer[c] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift), out_multiplier), - -right_shift); + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right); tmp_buffer[c] += out_zp; tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); @@ -141,20 +157,33 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { const int16_t *src = input_data; int8_t *dst = output_data; + bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int *left_shift = conv_param->conv_quant_arg_.left_shift_; + int *right_shift = conv_param->conv_quant_arg_.right_shift_; for (int b = 0; b < conv_param->output_batch_; b++) { for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { const int16_t *src_data = src + oc * C4NUM; int8_t *dst_data = dst + oc * C4NUM; const int16_t *weight = weight_data + oc * sliding->kernel_step_; const int32_t *bias = bias_data + oc * C4NUM; + + if (per_channel) { + out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc; + left_shift = conv_param->conv_quant_arg_.left_shift_ + oc; + right_shift = conv_param->conv_quant_arg_.right_shift_ + oc; + } + DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, - sliding); + sliding, out_multiplier, left_shift, right_shift, per_channel); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, - conv_param->output_w_, conv_param, sliding); + conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift, + per_channel); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, - conv_param, sliding); + conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, - conv_param->output_w_, conv_param, sliding); + conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift, + per_channel); if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; @@ -171,13 +200,13 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); #else + DepthwiseCenterInt8( out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, - sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, - conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier, + left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel); #endif } } // output C4 loop diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c index dbc713dde6..43fbb068a2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/pack.c @@ -847,6 +847,9 @@ void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight int weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; int unit = conv_param->kernel_h_ * conv_param->kernel_w_; for (int c = 0; c < conv_param->output_channel_; c++) { + if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[c].zp_; + } int c4_block_num = c / C4NUM; int c4_block_rem = c % C4NUM; int8_t *src_c = origin_weight + c * unit;