| @@ -0,0 +1,108 @@ | |||||
#ifdef __aarch64__
    .text
    .align 5
    .global ConvDwInt8PostAlign4PerChannel
#ifndef __APPLE__
    .type ConvDwInt8PostAlign4PerChannel, %function
#endif

// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp,
//                                     int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
//                                     int32_t acc_min, int32_t acc_max);
//
// Per-channel requantization of int32 accumulators to int8 for one output pixel.
// For every channel c in [0, channel4):
//   acc = saturating_shift_left(buffer[c], left_shift[c])
//   acc = saturating_rounding_doubling_high_mul(acc, out_multiplier[c])
//   acc = rounding_shift(acc, right_shift[c])   (a negative count shifts right)
//   dst[c] = saturate_to_int8(min(max(acc + output_zp, acc_min), acc_max))
// Channels are consumed 8 at a time, then 4 at a time; a remainder smaller than 4 is not
// processed here, so the caller is expected to pass a multiple of 4 and handle the tail itself.
//
// x0: dst, x1: buffer, x2: channel4, x3: output_zp, x4: out_multiplier,
// x5: left_shift, x6: right_shift, x7: acc_min, [sp]: acc_max (ninth argument, loaded into x8)
ConvDwInt8PostAlign4PerChannel:
    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
    // x19 ~ x29 should also be preserved.
    // This routine only touches x0 ~ x8, v0 ~ v7, v16, v17 and v29 ~ v31, so nothing has to be spilled.
    // NOTE: the project coding style normally does not permit this many parameters.

    // acc_max does not fit in x0 ~ x7, so it is fetched from the stack; only w8 is used below.
    ldr x8, [sp]

    // Broadcast the per-call scalars once.
    dup v29.4s, w3              // output_zp
    dup v30.4s, w7              // acc_min
    dup v31.4s, w8              // acc_max

LoopDepth8:
    cmp x2, #8
    blt LoopDepth4

    ld1 {v0.4s}, [x1], #16      // accumulators, channels 0..3
    ld1 {v1.4s}, [x1], #16      // accumulators, channels 4..7
    ld1 {v2.4s}, [x5], #16      // left_shift
    ld1 {v3.4s}, [x5], #16
    ld1 {v4.4s}, [x4], #16      // out_multiplier
    ld1 {v5.4s}, [x4], #16

    sqshl v0.4s, v0.4s, v2.4s   // saturating shift by the per-channel left_shift
    sqshl v1.4s, v1.4s, v3.4s

    ld1 {v6.4s}, [x6], #16      // right_shift
    ld1 {v7.4s}, [x6], #16

    sqrdmulh v0.4s, v0.4s, v4.4s
    sqrdmulh v1.4s, v1.4s, v5.4s

    // (acc & right_shift) >> 31 is -1 only in lanes where both acc and right_shift are negative.
    // Adding it before the rounding shift nudges negative values down by one; presumably this
    // reproduces the rounding of the C fallback (RoundingDivideByPOT) -- confirm against it.
    and v16.16b, v6.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v6.4s

    and v17.16b, v7.16b, v1.16b
    sshr v17.4s, v17.4s, #31
    sqadd v1.4s, v1.4s, v17.4s
    srshl v1.4s, v1.4s, v7.4s

    add v0.4s, v0.4s, v29.4s    // + output_zp
    add v1.4s, v1.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s   // clamp to [acc_min, acc_max]
    smax v1.4s, v1.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s
    smin v1.4s, v1.4s, v31.4s

    // Narrow int32 -> int16 -> int8 with saturation; the four results end up in lane s[0].
    sqxtn v0.4h, v0.4s
    sqxtn v1.4h, v1.4s
    sqxtn v0.8b, v0.8h
    sqxtn v1.8b, v1.8h

    st1 {v0.s}[0], [x0], #4
    st1 {v1.s}[0], [x0], #4

    sub x2, x2, #8
    cmp x2, #8
    bge LoopDepth8

LoopDepth4:
    cmp x2, #4
    blt End

    ld1 {v0.4s}, [x1], #16      // accumulators
    ld1 {v2.4s}, [x5], #16      // left_shift
    sqshl v0.4s, v0.4s, v2.4s
    ld1 {v4.4s}, [x4], #16      // out_multiplier
    sqrdmulh v0.4s, v0.4s, v4.4s
    ld1 {v6.4s}, [x6], #16      // right_shift

    // Same negative-value fix-up and rounding shift as in LoopDepth8.
    and v16.16b, v6.16b, v0.16b
    sshr v16.4s, v16.4s, #31
    sqadd v0.4s, v0.4s, v16.4s
    srshl v0.4s, v0.4s, v6.4s

    add v0.4s, v0.4s, v29.4s
    smax v0.4s, v0.4s, v30.4s
    smin v0.4s, v0.4s, v31.4s

    sqxtn v0.4h, v0.4s
    sqxtn v0.8b, v0.8h

    st1 {v0.s}[0], [x0], #4

    sub x2, x2, #4
    // `sub` does not update the condition flags, so branch unconditionally and let the
    // compare at the top of LoopDepth4 decide whether another group of 4 remains.
    b LoopDepth4

End:
    ret
#endif
| @@ -53,6 +53,9 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * | |||||
| int output_channel, int input_step, int8_t input_zp); | int output_channel, int input_step, int8_t input_zp); | ||||
| void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, | ||||
| int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); | int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); | ||||
| void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, | |||||
| int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, | |||||
| int32_t acc_max); | |||||
| #endif | #endif | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "nnacl/int8/common_func.h" | #include "nnacl/int8/common_func.h" | ||||
| /*conv depthwise int8 begin*/ | /*conv depthwise int8 begin*/ | ||||
| // only support perlayer | |||||
| #ifndef ENABLE_ARM64 | #ifndef ENABLE_ARM64 | ||||
| void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, | ||||
| int output_channel, int input_step, int8_t input_zp) { | int output_channel, int input_step, int8_t input_zp) { | ||||
| @@ -34,20 +33,46 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t * | |||||
| } | } | ||||
| #endif | #endif | ||||
// Converts one output row of int32 accumulators in `buffer` to int8 values in `dst`.
//
// The row holds output_w pixels of `channel` values each. Every accumulator is multiplied by
// (1 << left_shift), passed through SaturatingRoundingDoublingHighMul with the multiplier, then
// through RoundingDivideByPOT with the negated right_shift, offset by output_zp and clamped to
// [acc_min, acc_max]. The clamped value is written back to `buffer` and stored (narrowed) to `dst`.
//
// per_channel == true : out_multiplier / left_shift / right_shift are indexed by channel.
// per_channel == false: only element [0] of each quantization array is used for every value.
//
// On ENABLE_ARM64 the leading multiple-of-4 part is delegated to the assembly routines and only
// the remaining tail is handled by the scalar loops below.
void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, int32_t *out_multiplier,
                    int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max, bool per_channel) {
  if (!per_channel) {
    // Per-layer quantization: the whole row is one flat array sharing a single parameter set.
    const int total = output_w * channel;
    int tail_begin = 0;
#ifdef ENABLE_ARM64
    tail_begin = total / 4 * 4;
    ConvDwInt8PostAlign4(dst, buffer, tail_begin, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min,
                         acc_max);
#endif
    for (int idx = tail_begin; idx < total; ++idx) {
      int32_t acc = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(buffer[idx] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]),
        -right_shift[0]);
      acc += output_zp;
      acc = MSMAX(acc, acc_min);
      acc = MSMIN(acc, acc_max);
      buffer[idx] = acc;
      dst[idx] = acc;
    }
    return;
  }

  // Per-channel quantization: walk the row pixel by pixel so the parameter index restarts at 0.
  for (int col = 0; col < output_w; ++col) {
    int32_t *acc_row = buffer + col * channel;
    int8_t *out_row = dst + col * channel;
    int tail_begin = 0;
#ifdef ENABLE_ARM64
    tail_begin = channel / 4 * 4;
    ConvDwInt8PostAlign4PerChannel(out_row, acc_row, tail_begin, output_zp, out_multiplier, left_shift, right_shift,
                                   acc_min, acc_max);
#endif
    for (int ch = tail_begin; ch < channel; ++ch) {
      int32_t acc = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(acc_row[ch] * (1 << (unsigned int)left_shift[ch]), out_multiplier[ch]),
        -right_shift[ch]);
      acc += output_zp;
      acc = MSMAX(acc, acc_min);
      acc = MSMIN(acc, acc_max);
      acc_row[ch] = acc;
      out_row[ch] = acc;
    }
  }
}
| @@ -57,9 +82,10 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||||
| int h_start = h_step * task_id; | int h_start = h_step * task_id; | ||||
| int h_end = MSMIN(h_start + h_step, conv_param->output_h_); | int h_end = MSMIN(h_start + h_step, conv_param->output_h_); | ||||
| int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0]; | |||||
| int left_shift = conv_param->conv_quant_arg_.left_shift_[0]; | |||||
| int right_shift = conv_param->conv_quant_arg_.right_shift_[0]; | |||||
| bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; | |||||
| int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; | |||||
| int *left_shift = conv_param->conv_quant_arg_.left_shift_; | |||||
| int *right_shift = conv_param->conv_quant_arg_.right_shift_; | |||||
| int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; | ||||
| int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; | ||||
| @@ -105,8 +131,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da | |||||
| } | } | ||||
| } | } | ||||
| // post func, acc int32 -> dst int8 | // post func, acc int32 -> dst int8 | ||||
| ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp, | |||||
| out_multiplier, left_shift, right_shift, acc_min, acc_max); | |||||
| ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp, | |||||
| out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -51,14 +51,25 @@ int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() { | |||||
| PackNCHWToNHWCInt8(origin_weight, tmp_weight, 1, weight_tensor->Height() * weight_tensor->Width(), | PackNCHWToNHWCInt8(origin_weight, tmp_weight, 1, weight_tensor->Height() * weight_tensor->Width(), | ||||
| weight_tensor->Batch()); | weight_tensor->Batch()); | ||||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; | |||||
| packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t))); | packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t))); | ||||
| if (packed_weight_ == nullptr) { | if (packed_weight_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Malloc buffer failed."; | MS_LOG(ERROR) << "Malloc buffer failed."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| for (int i = 0; i < weight_tensor->ElementsNum(); i++) { | |||||
| packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); | |||||
| bool filter_per_channel = conv_param_->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; | |||||
| if (filter_per_channel) { | |||||
| for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) { | |||||
| for (int c = 0; c < channel; c++) { | |||||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_; | |||||
| packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_; | |||||
| for (int i = 0; i < weight_tensor->ElementsNum(); i++) { | |||||
| packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp); | |||||
| } | |||||
| } | } | ||||
| free(tmp_weight); | free(tmp_weight); | ||||
| @@ -166,14 +177,8 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *> | |||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| MS_ASSERT(opParameter != nullptr); | MS_ASSERT(opParameter != nullptr); | ||||
| MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); | ||||
| kernel::LiteKernel *kernel; | |||||
| auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); | |||||
| if (filter_quant_size == 1) { // per tensor | |||||
| kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } else { // per channel | |||||
| kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| } | |||||
| auto kernel = | |||||
| new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "kernel is nullptr."; | MS_LOG(ERROR) << "kernel is nullptr."; | ||||
| return nullptr; | return nullptr; | ||||