Browse Source

[MSLITE][Develop] arm cpu int conv depthwise support weight per channel

tags/v1.0.0
yangruoqi713 5 years ago
parent
commit
27de06dbc8
4 changed files with 171 additions and 29 deletions
  1. +108
    -0
      mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
  2. +3
    -0
      mindspore/lite/nnacl/int8/common_func.h
  3. +44
    -18
      mindspore/lite/nnacl/int8/conv_depthwise_int8.c
  4. +16
    -11
      mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc

+ 108
- 0
mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S View File

@@ -0,0 +1,108 @@
#ifdef __aarch64__

.text
.align 5
.global ConvDwInt8PostAlign4PerChannel
#ifndef __APPLE__
.type ConvDwInt8PostAlign4PerChannel, %function
#endif

// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier,
// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max);
// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier,
// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max

ConvDwInt8PostAlign4PerChannel:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
ldr x8, [sp]

dup v29.4s, w3
dup v30.4s, w7
dup v31.4s, w8

LoopDepth8:
cmp x2, #8
blt LoopDepth4
ld1 {v0.4s}, [x1], #16
ld1 {v1.4s}, [x1], #16

ld1 {v2.4s}, [x5], #16
ld1 {v3.4s}, [x5], #16

ld1 {v4.4s}, [x4], #16
ld1 {v5.4s}, [x4], #16

sqshl v0.4s, v0.4s, v2.4s
sqshl v1.4s, v1.4s, v3.4s

ld1 {v6.4s}, [x6], #16
ld1 {v7.4s}, [x6], #16

sqrdmulh v0.4s, v0.4s, v4.4s
sqrdmulh v1.4s, v1.4s, v5.4s

and v16.16b, v6.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s
and v17.16b, v7.16b, v1.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v7.4s

add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s

smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s

smin v0.4s, v0.4s, v31.4s
smin v1.4s, v1.4s, v31.4s

sqxtn v0.4h, v0.4s
sqxtn v1.4h, v1.4s

sqxtn v0.8b, v0.8h
sqxtn v1.8b, v1.8h

st1 {v0.s}[0], [x0], #4
st1 {v1.s}[0], [x0], #4

sub x2, x2, #8
cmp x2, #8
bge LoopDepth8

LoopDepth4:
cmp x2, #4
blt End
ld1 {v0.4s}, [x1], #16
ld1 {v2.4s}, [x5], #16

sqshl v0.4s, v0.4s, v2.4s

ld1 {v4.4s}, [x4], #16
sqrdmulh v0.4s, v0.4s, v4.4s

ld1 {v6.4s}, [x6], #16
and v16.16b, v6.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v6.4s

add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s

sqxtn v0.4h, v0.4s
sqxtn v0.8b, v0.8h

st1 {v0.s}[0], [x0], #4

sub x2, x2, #4
bge LoopDepth4
End:
ret
#endif

+ 3
- 0
mindspore/lite/nnacl/int8/common_func.h View File

@@ -53,6 +53,9 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
int32_t acc_max);
#endif

#ifdef __cplusplus


+ 44
- 18
mindspore/lite/nnacl/int8/conv_depthwise_int8.c View File

@@ -20,7 +20,6 @@
#include "nnacl/int8/common_func.h"

/*conv depthwise int8 begin*/
// only support perlayer
#ifndef ENABLE_ARM64
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp) {
@@ -34,20 +33,46 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *
}
#endif

void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max) {
int align_num = 0;
void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, int32_t *out_multiplier,
int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max, bool per_channel) {
if (per_channel) {
// support perchannel
for (int w = 0; w < output_w; w++) {
int channel4 = 0;
#ifdef ENABLE_ARM64
align_num = num_pixels / 4 * 4;
ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
channel4 = channel / 4 * 4;
ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min,
acc_max);
#endif
for (int i = align_num; i < num_pixels; i++) {
buffer[i] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift);
buffer[i] += output_zp;
buffer[i] = MSMAX(buffer[i], acc_min);
buffer[i] = MSMIN(buffer[i], acc_max);
dst[i] = (buffer[i]);
for (int c = channel4; c < channel; c++) {
buffer[c] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]),
-right_shift[c]);
buffer[c] += output_zp;
buffer[c] = MSMAX(buffer[c], acc_min);
buffer[c] = MSMIN(buffer[c], acc_max);
dst[c] = (buffer[c]);
}
buffer += channel;
dst += channel;
}
} else {
int num_pixels = output_w * channel;
int align_num = 0;
#ifdef ENABLE_ARM64
align_num = num_pixels / 4 * 4;
ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min,
acc_max);
#endif
for (int i = align_num; i < num_pixels; i++) {
buffer[i] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]),
-right_shift[0]);
buffer[i] += output_zp;
buffer[i] = MSMAX(buffer[i], acc_min);
buffer[i] = MSMIN(buffer[i], acc_max);
dst[i] = (buffer[i]);
}
}
}

@@ -57,9 +82,10 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
int right_shift = conv_param->conv_quant_arg_.right_shift_[0];
bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int *left_shift = conv_param->conv_quant_arg_.left_shift_;
int *right_shift = conv_param->conv_quant_arg_.right_shift_;

int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
@@ -105,8 +131,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
}
}
// post func, acc int32 -> dst int8
ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel);
}
}
}


+ 16
- 11
mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc View File

@@ -51,14 +51,25 @@ int ConvolutionDepthwiseInt8CPUKernel::InitWeightBias() {
PackNCHWToNHWCInt8(origin_weight, tmp_weight, 1, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());

int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
packed_weight_ = reinterpret_cast<int16_t *>(malloc(pack_weight_size * sizeof(int16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
for (int i = 0; i < weight_tensor->ElementsNum(); i++) {
packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp);

bool filter_per_channel = conv_param_->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
if (filter_per_channel) {
for (int i = 0; i < weight_tensor->Height() * weight_tensor->Width(); i++) {
for (int c = 0; c < channel; c++) {
int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[c].zp_;
packed_weight_[i * channel + c] = (int16_t)(tmp_weight[i * channel + c] - weight_zp);
}
}
} else {
int weight_zp = conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
for (int i = 0; i < weight_tensor->ElementsNum(); i++) {
packed_weight_[i] = (int16_t)(tmp_weight[i] - weight_zp);
}
}
free(tmp_weight);

@@ -166,14 +177,8 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
kernel::LiteKernel *kernel;
auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (filter_quant_size == 1) { // per tensor
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else { // per channel
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
auto kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
return nullptr;


Loading…
Cancel
Save