diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000..7c9a544edf --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,108 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global ConvDwInt8PostAlign4PerChannel +#ifndef __APPLE__ +.type ConvDwInt8PostAlign4PerChannel, %function +#endif + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +ConvDwInt8PostAlign4PerChannel: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + ld1 {v2.4s}, [x5], #16 + ld1 {v3.4s}, [x5], #16 + + ld1 {v4.4s}, [x4], #16 + ld1 {v5.4s}, [x4], #16 + + sqshl v0.4s, v0.4s, v2.4s + sqshl v1.4s, v1.4s, v3.4s + + ld1 {v6.4s}, [x6], #16 + ld1 {v7.4s}, [x6], #16 + + sqrdmulh v0.4s, v0.4s, v4.4s + sqrdmulh v1.4s, v1.4s, v5.4s + + and v16.16b, v6.16b, v0.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + and v17.16b, v7.16b, v1.16b + sshr v17.4s, v17.4s, #31 + sqadd v1.4s, v1.4s, v17.4s + srshl v1.4s, v1.4s, v7.4s + + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + ld1 {v2.4s}, [x5], #16 + + sqshl v0.4s, v0.4s, v2.4s + + ld1 {v4.4s}, [x4], #16 + sqrdmulh v0.4s, v0.4s, v4.4s + + ld1 {v6.4s}, [x6], #16 + and v16.16b, v6.16b, v0.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h index 1e1b965d34..3e79180d19 100644 --- a/mindspore/lite/nnacl/int8/common_func.h +++ b/mindspore/lite/nnacl/int8/common_func.h @@ -53,6 +53,9 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * int output_channel, int input_step, int8_t input_zp); void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, + int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, + int32_t acc_max); #endif #ifdef __cplusplus diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c index 7e7d9d4067..2514bd1bb4 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c @@ -20,7 +20,6 @@ #include "nnacl/int8/common_func.h" /*conv depthwise int8 begin*/ -// only support perlayer #ifndef ENABLE_ARM64 void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, int output_channel, int input_step, int8_t input_zp) { @@ -34,20 +33,46 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * } #endif -void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, - int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max) { - int align_num = 0; +void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, int32_t *out_multiplier, + int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max, bool per_channel) { + if (per_channel) { + // support perchannel + for (int w = 0; w < output_w; w++) { + int channel4 = 0; #ifdef ENABLE_ARM64 - align_num = num_pixels / 4 * 4; - ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + channel4 = channel / 4 * 4; + ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min, + acc_max); #endif - for (int i = align_num; i < num_pixels; i++) { - buffer[i] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); - buffer[i] += output_zp; - buffer[i] = MSMAX(buffer[i], acc_min); - buffer[i] = MSMIN(buffer[i], acc_max); - dst[i] = (buffer[i]); + for (int c = channel4; c < channel; c++) { + buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + buffer[c] += output_zp; + buffer[c] = MSMAX(buffer[c], acc_min); + buffer[c] = MSMIN(buffer[c], acc_max); + dst[c] = (buffer[c]); + } + buffer += channel; + dst += channel; + } + } else { + int num_pixels = output_w * channel; + int align_num = 0; +#ifdef ENABLE_ARM64 + align_num = num_pixels / 4 * 4; + ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, + acc_max); +#endif + for (int i = align_num; i < num_pixels; i++) { + buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + buffer[i] += output_zp; + buffer[i] = MSMAX(buffer[i], acc_min); + buffer[i] = MSMIN(buffer[i], acc_max); + dst[i] = (buffer[i]); + } } } @@ -57,9 +82,10 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da int h_start = h_step * task_id; int h_end = MSMIN(h_start + h_step, conv_param->output_h_); - int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; - int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; - int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int *left_shift = conv_param->conv_quant_arg_.left_shift_; + int *right_shift = conv_param->conv_quant_arg_.right_shift_; int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; @@ -105,8 +131,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da } } // post func, acc int32 -> dst int8 - ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp, - out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc index a0871e196f..e8a28ae3df 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -51,14 +51,25 @@ int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() { PackNCHWToNHWCInt8(origin_weight, tmp_weight, 1, weight_tensor->Height() * weight_tensor->Width(), weight_tensor->Batch()); - int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); if (packed_weight_ == nullptr) { MS_LOG(ERROR) << "Malloc buffer failed."; return RET_ERROR; } - for (int i = 0; i < weight_tensor->ElementsNum(); i++) { - packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); + + bool filter_per_channel = conv_param_->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + if (filter_per_channel) { + for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) { + for (int c = 0; c < channel; c++) { + int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_; + packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp); + } + } + } else { + int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; + for (int i = 0; i < weight_tensor->ElementsNum(); i++) { + packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); + } } free(tmp_weight); @@ -166,14 +177,8 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - kernel::LiteKernel *kernel; - auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); - if (filter_quant_size == 1) { // per tensor - kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } else { // per channel - kernel = - new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } + auto kernel = + new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr;