| @@ -0,0 +1,108 @@ | |||
| #ifdef __aarch64__ | |||
| .text | |||
| .align 5 | |||
| .global ConvDwInt8PostAlign4PerChannel | |||
| #ifndef __APPLE__ | |||
| .type ConvDwInt8PostAlign4PerChannel, %function | |||
| #endif | |||
| // void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, | |||
| // int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); | |||
| // x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, | |||
| // x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max | |||
| ConvDwInt8PostAlign4PerChannel: | |||
| // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to | |||
| // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers | |||
| // x19 ~ x29 should be also preserved | |||
| // whereas our coding style do not permit such amount of parameters | |||
| ldr x8, [sp] | |||
| dup v29.4s, w3 | |||
| dup v30.4s, w7 | |||
| dup v31.4s, w8 | |||
| LoopDepth8: | |||
| cmp x2, #8 | |||
| blt LoopDepth4 | |||
| ld1 {v0.4s}, [x1], #16 | |||
| ld1 {v1.4s}, [x1], #16 | |||
| ld1 {v2.4s}, [x5], #16 | |||
| ld1 {v3.4s}, [x5], #16 | |||
| ld1 {v4.4s}, [x4], #16 | |||
| ld1 {v5.4s}, [x4], #16 | |||
| sqshl v0.4s, v0.4s, v2.4s | |||
| sqshl v1.4s, v1.4s, v3.4s | |||
| ld1 {v6.4s}, [x6], #16 | |||
| ld1 {v7.4s}, [x6], #16 | |||
| sqrdmulh v0.4s, v0.4s, v4.4s | |||
| sqrdmulh v1.4s, v1.4s, v5.4s | |||
| and v16.16b, v6.16b, v0.16b | |||
| sshr v16.4s, v16.4s, #31 | |||
| sqadd v0.4s, v0.4s, v16.4s | |||
| srshl v0.4s, v0.4s, v6.4s | |||
| and v17.16b, v7.16b, v1.16b | |||
| sshr v17.4s, v17.4s, #31 | |||
| sqadd v1.4s, v1.4s, v17.4s | |||
| srshl v1.4s, v1.4s, v7.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| add v1.4s, v1.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smax v1.4s, v1.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| smin v1.4s, v1.4s, v31.4s | |||
| sqxtn v0.4h, v0.4s | |||
| sqxtn v1.4h, v1.4s | |||
| sqxtn v0.8b, v0.8h | |||
| sqxtn v1.8b, v1.8h | |||
| st1 {v0.s}[0], [x0], #4 | |||
| st1 {v1.s}[0], [x0], #4 | |||
| sub x2, x2, #8 | |||
| cmp x2, #8 | |||
| bge LoopDepth8 | |||
| LoopDepth4: | |||
| cmp x2, #4 | |||
| blt End | |||
| ld1 {v0.4s}, [x1], #16 | |||
| ld1 {v2.4s}, [x5], #16 | |||
| sqshl v0.4s, v0.4s, v2.4s | |||
| ld1 {v4.4s}, [x4], #16 | |||
| sqrdmulh v0.4s, v0.4s, v4.4s | |||
| ld1 {v6.4s}, [x6], #16 | |||
| and v16.16b, v6.16b, v0.16b | |||
| sshr v16.4s, v16.4s, #31 | |||
| sqadd v0.4s, v0.4s, v16.4s | |||
| srshl v0.4s, v0.4s, v6.4s | |||
| add v0.4s, v0.4s, v29.4s | |||
| smax v0.4s, v0.4s, v30.4s | |||
| smin v0.4s, v0.4s, v31.4s | |||
| sqxtn v0.4h, v0.4s | |||
| sqxtn v0.8b, v0.8h | |||
| st1 {v0.s}[0], [x0], #4 | |||
| sub x2, x2, #4 | |||
| bge LoopDepth4 | |||
| End: | |||
| ret | |||
| #endif | |||
| @@ -53,6 +53,9 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * | |||
| int output_channel, int input_step, int8_t input_zp); | |||
| void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); | |||
| void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, | |||
| int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, | |||
| int32_t acc_max); | |||
| #endif | |||
| #ifdef __cplusplus | |||
| @@ -20,7 +20,6 @@ | |||
| #include "nnacl/int8/common_func.h" | |||
| /*conv depthwise int8 begin*/ | |||
| // only support perlayer | |||
| #ifndef ENABLE_ARM64 | |||
| void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | |||
| int output_channel, int input_step, int8_t input_zp) { | |||
| @@ -34,20 +33,46 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * | |||
| } | |||
| #endif | |||
| void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | |||
| int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max) { | |||
| int align_num = 0; | |||
| void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, int32_t *out_multiplier, | |||
| int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max, bool per_channel) { | |||
| if (per_channel) { | |||
| // support perchannel | |||
| for (int w = 0; w < output_w; w++) { | |||
| int channel4 = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| align_num = num_pixels / 4 * 4; | |||
| ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| channel4 = channel / 4 * 4; | |||
| ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min, | |||
| acc_max); | |||
| #endif | |||
| for (int i = align_num; i < num_pixels; i++) { | |||
| buffer[i] = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); | |||
| buffer[i] += output_zp; | |||
| buffer[i] = MSMAX(buffer[i], acc_min); | |||
| buffer[i] = MSMIN(buffer[i], acc_max); | |||
| dst[i] = (buffer[i]); | |||
| for (int c = channel4; c < channel; c++) { | |||
| buffer[c] = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), | |||
| -right_shift[c]); | |||
| buffer[c] += output_zp; | |||
| buffer[c] = MSMAX(buffer[c], acc_min); | |||
| buffer[c] = MSMIN(buffer[c], acc_max); | |||
| dst[c] = (buffer[c]); | |||
| } | |||
| buffer += channel; | |||
| dst += channel; | |||
| } | |||
| } else { | |||
| int num_pixels = output_w * channel; | |||
| int align_num = 0; | |||
| #ifdef ENABLE_ARM64 | |||
| align_num = num_pixels / 4 * 4; | |||
| ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, | |||
| acc_max); | |||
| #endif | |||
| for (int i = align_num; i < num_pixels; i++) { | |||
| buffer[i] = RoundingDivideByPOT( | |||
| SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), | |||
| -right_shift[0]); | |||
| buffer[i] += output_zp; | |||
| buffer[i] = MSMAX(buffer[i], acc_min); | |||
| buffer[i] = MSMIN(buffer[i], acc_max); | |||
| dst[i] = (buffer[i]); | |||
| } | |||
| } | |||
| } | |||
| @@ -57,9 +82,10 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||
| int h_start = h_step * task_id; | |||
| int h_end = MSMIN(h_start + h_step, conv_param->output_h_); | |||
| int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; | |||
| int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; | |||
| int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; | |||
| bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; | |||
| int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; | |||
| int *left_shift = conv_param->conv_quant_arg_.left_shift_; | |||
| int *right_shift = conv_param->conv_quant_arg_.right_shift_; | |||
| int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | |||
| int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | |||
| @@ -105,8 +131,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||
| } | |||
| } | |||
| // post func, acc int32 -> dst int8 | |||
| ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||
| ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp, | |||
| out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); | |||
| } | |||
| } | |||
| } | |||
| @@ -51,14 +51,25 @@ int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() { | |||
| PackNCHWToNHWCInt8(origin_weight, tmp_weight, 1, weight_tensor->Height() * weight_tensor->Width(), | |||
| weight_tensor->Batch()); | |||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t))); | |||
| if (packed_weight_ == nullptr) { | |||
| MS_LOG(ERROR) << "Malloc buffer failed."; | |||
| return RET_ERROR; | |||
| } | |||
| for (int i = 0; i < weight_tensor->ElementsNum(); i++) { | |||
| packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); | |||
| bool filter_per_channel = conv_param_->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; | |||
| if (filter_per_channel) { | |||
| for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) { | |||
| for (int c = 0; c < channel; c++) { | |||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_; | |||
| packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp); | |||
| } | |||
| } | |||
| } else { | |||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; | |||
| for (int i = 0; i < weight_tensor->ElementsNum(); i++) { | |||
| packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); | |||
| } | |||
| } | |||
| free(tmp_weight); | |||
| @@ -166,14 +177,8 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | |||
| kernel::LiteKernel *kernel; | |||
| auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | |||
| if (filter_quant_size == 1) { // per tensor | |||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } else { // per channel | |||
| kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| } | |||
| auto kernel = | |||
| new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| return nullptr; | |||