Browse Source

Add aarch64 int8 NEON implementation

tags/20181228
BUG1989 nihui 7 years ago
parent
commit
08b39dffb0
5 changed files with 2811 additions and 124 deletions
  1. +1330
    -25
      src/layer/arm/convolution_1x1_int8.h
  2. +1114
    -40
      src/layer/arm/convolution_3x3_int8.h
  3. +265
    -42
      src/layer/arm/convolutiondepthwise_3x3_int8.h
  4. +68
    -11
      src/layer/arm/dequantize_arm.cpp
  5. +34
    -6
      src/layer/arm/quantize_arm.cpp

+ 1330
- 25
src/layer/arm/convolution_1x1_int8.h
File diff suppressed because it is too large
View File


+ 1114
- 40
src/layer/arm/convolution_3x3_int8.h
File diff suppressed because it is too large
View File


+ 265
- 42
src/layer/arm/convolutiondepthwise_3x3_int8.h View File

@@ -30,42 +30,212 @@ static void convdw3x3s1_int8_neon(const Mat &bottom_blob, Mat &top_blob, const M
{
Mat out = top_blob.channel(p);

const signed char *kernel0 = (const signed char *)_kernel + p * 9;
const signed char* kernel = (const signed char *)_kernel + p*9;
int* outptr0 = out;
int* outptr0n = outptr0 + outw;
const signed char* img0 = bottom_blob.channel(p);
const signed char* r0 = img0;
const signed char* r1 = img0 + w;
const signed char* r2 = img0 + w*2;
const signed char* r3 = img0 + w*3;

int i = 0;

int8x8_t _k0 = vdup_n_s8(kernel[0]);
int8x8_t _k1 = vdup_n_s8(kernel[1]);
int8x8_t _k2 = vdup_n_s8(kernel[2]);

int8x8_t _k3 = vdup_n_s8(kernel[3]);
int8x8_t _k4 = vdup_n_s8(kernel[4]);
int8x8_t _k5 = vdup_n_s8(kernel[5]);

int *outptr = out;
int8x8_t _k6 = vdup_n_s8(kernel[6]);
int8x8_t _k7 = vdup_n_s8(kernel[7]);
int8x8_t _k8 = vdup_n_s8(kernel[8]);

const signed char *img0 = bottom_blob.channel(p);
for (; i+1 < outh; i+=2)
{
int nn = outw >> 3;
int remain = outw & 7;

const signed char *r0 = img0;
const signed char *r1 = img0 + w;
const signed char *r2 = img0 + w * 2;
for (; nn >0; nn--)
{
int8x8_t _r0 = vld1_s8(r0);
int8x8_t _r0n = vld1_s8(r0+8);
int8x8_t _r01 = vext_s8(_r0, _r0n, 1);
int8x8_t _r02 = vext_s8(_r0, _r0n, 2);

int16x8_t _sum0 = vmull_s8(_r0, _k0);
_sum0 = vmlal_s8(_sum0, _r01, _k1);
_sum0 = vmlal_s8(_sum0, _r02, _k2);

int8x8_t _r1 = vld1_s8(r1);
int8x8_t _r1n = vld1_s8(r1+8);
int8x8_t _r11 = vext_s8(_r1, _r1n, 1);
int8x8_t _r12 = vext_s8(_r1, _r1n, 2);
_sum0 = vmlal_s8(_sum0, _r1, _k3);
_sum0 = vmlal_s8(_sum0, _r11, _k4);
_sum0 = vmlal_s8(_sum0, _r12, _k5);

int16x8_t _sum1 = vmull_s8(_r1, _k0);
_sum1 = vmlal_s8(_sum1, _r11, _k1);
_sum1 = vmlal_s8(_sum1, _r12, _k2);

int8x8_t _r2 = vld1_s8(r2);
int8x8_t _r2n = vld1_s8(r2+8);
int8x8_t _r21 = vext_s8(_r2, _r2n, 1);
int8x8_t _r22 = vext_s8(_r2, _r2n, 2);
_sum0 = vmlal_s8(_sum0, _r2, _k6);
_sum0 = vmlal_s8(_sum0, _r21, _k7);
_sum0 = vmlal_s8(_sum0, _r22, _k8);

_sum1 = vmlal_s8(_sum1, _r2, _k3);
_sum1 = vmlal_s8(_sum1, _r21, _k4);
_sum1 = vmlal_s8(_sum1, _r22, _k5);

int8x8_t _r3 = vld1_s8(r3);
int8x8_t _r3n = vld1_s8(r3+8);
int8x8_t _r31 = vext_s8(_r3, _r3n, 1);
int8x8_t _r32 = vext_s8(_r3, _r3n, 2);
_sum1 = vmlal_s8(_sum1, _r3, _k6);
_sum1 = vmlal_s8(_sum1, _r31, _k7);
_sum1 = vmlal_s8(_sum1, _r32, _k8);

int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum0));
int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum0));

vst1q_s32(outptr0, sum0_s32);
vst1q_s32(outptr0+4, sum0n_s32);

int32x4_t sum1_s32 = vmovl_s16(vget_low_s16(_sum1));
int32x4_t sum1n_s32 = vmovl_s16(vget_high_s16(_sum1));

vst1q_s32(outptr0n, sum1_s32);
vst1q_s32(outptr0n+4, sum1n_s32);

r0 += 8;
r1 += 8;
r2 += 8;
r3 += 8;
outptr0 += 8;
outptr0n += 8;
}

for (; remain>0; remain--)
{
//Todo Neon

int sum0 = 0;
int sum0n = 0;

sum0 += (int)r0[0] * kernel[0];
sum0 += (int)r0[1] * kernel[1];
sum0 += (int)r0[2] * kernel[2];
sum0 += (int)r1[0] * kernel[3];
sum0 += (int)r1[1] * kernel[4];
sum0 += (int)r1[2] * kernel[5];
sum0 += (int)r2[0] * kernel[6];
sum0 += (int)r2[1] * kernel[7];
sum0 += (int)r2[2] * kernel[8];

sum0n += (int)r1[0] * kernel[0];
sum0n += (int)r1[1] * kernel[1];
sum0n += (int)r1[2] * kernel[2];
sum0n += (int)r2[0] * kernel[3];
sum0n += (int)r2[1] * kernel[4];
sum0n += (int)r2[2] * kernel[5];
sum0n += (int)r3[0] * kernel[6];
sum0n += (int)r3[1] * kernel[7];
sum0n += (int)r3[2] * kernel[8];

*outptr0 = sum0;
*outptr0n = sum0n;

r0++;
r1++;
r2++;
r3++;
outptr0++;
outptr0n++;
}

r0 += 2 + w;
r1 += 2 + w;
r2 += 2 + w;
r3 += 2 + w;

outptr0 += outw;
outptr0n += outw;
}

int i = 0;
for (; i < outh; i++)
{
int remain = outw;
int nn = outw >> 3;
int remain = outw & 7;

for (; nn >0; nn--)
{
int8x8_t _r0 = vld1_s8(r0);
int8x8_t _r0n = vld1_s8(r0+8);
int8x8_t _r01 = vext_s8(_r0, _r0n, 1);
int8x8_t _r02 = vext_s8(_r0, _r0n, 2);

int16x8_t _sum0 = vmull_s8(_r0, _k0);
_sum0 = vmlal_s8(_sum0, _r01, _k1);
_sum0 = vmlal_s8(_sum0, _r02, _k2);

int8x8_t _r1 = vld1_s8(r1);
int8x8_t _r1n = vld1_s8(r1+8);
int8x8_t _r11 = vext_s8(_r1, _r1n, 1);
int8x8_t _r12 = vext_s8(_r1, _r1n, 2);
_sum0 = vmlal_s8(_sum0, _r1, _k3);
_sum0 = vmlal_s8(_sum0, _r11, _k4);
_sum0 = vmlal_s8(_sum0, _r12, _k5);

int8x8_t _r2 = vld1_s8(r2);
int8x8_t _r2n = vld1_s8(r2+8);
int8x8_t _r21 = vext_s8(_r2, _r2n, 1);
int8x8_t _r22 = vext_s8(_r2, _r2n, 2);
_sum0 = vmlal_s8(_sum0, _r2, _k6);
_sum0 = vmlal_s8(_sum0, _r21, _k7);
_sum0 = vmlal_s8(_sum0, _r22, _k8);

int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum0));
int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum0));

vst1q_s32(outptr0, sum0_s32);
vst1q_s32(outptr0+4, sum0n_s32);

r0 += 8;
r1 += 8;
r2 += 8;
outptr0 += 8;
}

for (; remain > 0; remain--)
for (; remain>0; remain--)
{
int sum = 0;

sum += (int)r0[0] * (int)kernel0[0];
sum += (int)r0[1] * (int)kernel0[1];
sum += (int)r0[2] * (int)kernel0[2];
sum += (int)r1[0] * (int)kernel0[3];
sum += (int)r1[1] * (int)kernel0[4];
sum += (int)r1[2] * (int)kernel0[5];
sum += (int)r2[0] * (int)kernel0[6];
sum += (int)r2[1] * (int)kernel0[7];
sum += (int)r2[2] * (int)kernel0[8];
sum += (int)r0[0] * kernel[0];
sum += (int)r0[1] * kernel[1];
sum += (int)r0[2] * kernel[2];
sum += (int)r1[0] * kernel[3];
sum += (int)r1[1] * kernel[4];
sum += (int)r1[2] * kernel[5];
sum += (int)r2[0] * kernel[6];
sum += (int)r2[1] * kernel[7];
sum += (int)r2[2] * kernel[8];

*outptr = sum;
*outptr0 = sum;

r0++;
r1++;
r2++;
outptr++;
}
outptr0++;
}

r0 += 2;
r1 += 2;
@@ -82,42 +252,95 @@ static void convdw3x3s2_int8_neon(const Mat &bottom_blob, Mat &top_blob, const M
int outh = top_blob.h;
int outch = top_blob.c;

const int tailstep = w - 2 * outw + w;
const int tailstep = w - 2*outw + w;

#pragma omp parallel for num_threads(opt.num_threads)
for (int p = 0; p < outch; p++)
for (int p=0; p<outch; p++)
{
Mat out = top_blob.channel(p);

const signed char *kernel0 = (const signed char *)_kernel + p * 9;
const signed char* kernel = (const signed char*)_kernel + p*9;

int *outptr = out;
int* outptr = out;

const signed char *img0 = bottom_blob.channel(p);
const signed char* img = bottom_blob.channel(p);

const signed char *r0 = img0;
const signed char *r1 = img0 + w;
const signed char *r2 = img0 + w * 2;
const signed char* r0 = img;
const signed char* r1 = img + w;
const signed char* r2 = img + w*2;

int i = 0;

int8x8_t _k0 = vdup_n_s8(kernel[0]);
int8x8_t _k1 = vdup_n_s8(kernel[1]);
int8x8_t _k2 = vdup_n_s8(kernel[2]);
int8x8_t _k3 = vdup_n_s8(kernel[3]);
int8x8_t _k4 = vdup_n_s8(kernel[4]);
int8x8_t _k5 = vdup_n_s8(kernel[5]);
int8x8_t _k6 = vdup_n_s8(kernel[6]);
int8x8_t _k7 = vdup_n_s8(kernel[7]);
int8x8_t _k8 = vdup_n_s8(kernel[8]);

for (; i < outh; i++)
{
int remain = outw;
{
int nn = outw >> 3;
int remain = outw & 7;

for (; remain > 0; remain--)
for (; nn > 0; nn--)
{
int sum = 0;
int8x8x2_t _r0 = vld2_s8(r0);
int8x8x2_t _r0n = vld2_s8(r0+16);
int8x8_t _r00 = _r0.val[0];
int8x8_t _r01 = _r0.val[1];
int8x8_t _r02 = vext_s8(_r00, _r0n.val[0], 1);

int16x8_t _sum = vmull_s8(_r00, _k0);
_sum = vmlal_s8(_sum, _r01, _k1);
_sum = vmlal_s8(_sum, _r02, _k2);

int8x8x2_t _r1 = vld2_s8(r1);
int8x8x2_t _r1n = vld2_s8(r1+16);
int8x8_t _r10 = _r1.val[0];
int8x8_t _r11 = _r1.val[1];
int8x8_t _r12 = vext_s8(_r10, _r1n.val[0], 1);
_sum = vmlal_s8(_sum, _r10, _k3);
_sum = vmlal_s8(_sum, _r11, _k4);
_sum = vmlal_s8(_sum, _r12, _k5);

int8x8x2_t _r2 = vld2_s8(r2);
int8x8x2_t _r2n = vld2_s8(r2+16);
int8x8_t _r20 = _r2.val[0];
int8x8_t _r21 = _r2.val[1];
int8x8_t _r22 = vext_s8(_r20, _r2n.val[0], 1);
_sum = vmlal_s8(_sum, _r20, _k6);
_sum = vmlal_s8(_sum, _r21, _k7);
_sum = vmlal_s8(_sum, _r22, _k8);

int32x4_t sum0_s32 = vmovl_s16(vget_low_s16(_sum));
int32x4_t sum0n_s32 = vmovl_s16(vget_high_s16(_sum));

vst1q_s32(outptr, sum0_s32);
vst1q_s32(outptr+4, sum0n_s32);

r0 += 16;
r1 += 16;
r2 += 16;
outptr += 8;
}

sum += (int)r0[0] * (int)kernel0[0];
sum += (int)r0[1] * (int)kernel0[1];
sum += (int)r0[2] * (int)kernel0[2];
sum += (int)r1[0] * (int)kernel0[3];
sum += (int)r1[1] * (int)kernel0[4];
sum += (int)r1[2] * (int)kernel0[5];
sum += (int)r2[0] * (int)kernel0[6];
sum += (int)r2[1] * (int)kernel0[7];
sum += (int)r2[2] * (int)kernel0[8];
for (; remain>0; remain--)
{
int sum = 0;
sum += (int)r0[0] * kernel[0];
sum += (int)r0[1] * kernel[1];
sum += (int)r0[2] * kernel[2];
sum += (int)r1[0] * kernel[3];
sum += (int)r1[1] * kernel[4];
sum += (int)r1[2] * kernel[5];
sum += (int)r2[0] * kernel[6];
sum += (int)r2[1] * kernel[7];
sum += (int)r2[2] * kernel[8];

*outptr = sum;



+ 68
- 11
src/layer/arm/dequantize_arm.cpp View File

@@ -21,11 +21,6 @@ DEFINE_LAYER_CREATOR(Dequantize_arm)

int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
#if __aarch64__
// TODO port to aarch64
return Dequantize::forward_inplace(bottom_top_blob, opt);
#endif // __aarch64__

int dims = bottom_top_blob.dims;

if (dims == 1)
@@ -116,7 +111,41 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con

#if __ARM_NEON
#if __aarch64__
// TODO
float32x4_t _bias = vdupq_n_f32(bias);
float32x4_t _scale = vdupq_n_f32(scale);

if (nn > 0)
{
asm volatile(
"dup v2.4s, %w6 \n" // scale
"dup v3.4s, %w7 \n" // bias
"0: \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v0.4s, v1.4s}, [%1], #32 \n" // data
// top_s32 -> top_f32
"scvtf v5.4s, v0.4s \n"
"scvtf v6.4s, v1.4s \n"
// top_f32 = top_f32 * scale_out
"fmul v5.4s, v5.4s, v2.4s \n"
"fmul v6.4s, v6.4s, v2.4s \n"
// top_f32 = top_f32 + bias_tm
"fadd v5.4s, v5.4s, v3.4s \n"
"fadd v6.4s, v6.4s, v3.4s \n"
// save top_f32
"st1 {v5.4s, v6.4s}, [%2], #32 \n"
"subs %w0, %w0, #1 \n"
"bne 0b \n"
: "=r"(nn), // %0
"=r"(intptr), // %1
"=r"(ptr) // %2
: "0"(nn),
"1"(intptr),
"2"(ptr),
"r"(_scale), // %6
"r"(_bias) // %7
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6"
);
}
#else
if (nn > 0)
{
@@ -127,8 +156,8 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con
"vdup.f32 q12, %7 \n" //q12 bias

"0: \n"
"vcvt.f32.s32 q0, q0 \n"
"vcvt.f32.s32 q1, q1 \n"
"vcvt.f32.s32 q0, q0 \n"
"vcvt.f32.s32 q1, q1 \n"

"vmul.f32 q0,q0,q10 \n"
"vmul.f32 q1,q1,q10 \n"
@@ -183,7 +212,35 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con

#if __ARM_NEON
#if __aarch64__
// TODO
float32x4_t _scale = vdupq_n_f32(scale);

if (nn > 0)
{
asm volatile(
"dup v2.4s, %w6 \n" // scale
"0: \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v0.4s, v1.4s}, [%1], #32 \n" // data
// top_s32 -> top_f32
"scvtf v5.4s, v0.4s \n"
"scvtf v6.4s, v1.4s \n"
// top_f32 = top_f32 * scale_out
"fmul v5.4s, v5.4s, v2.4s \n"
"fmul v6.4s, v6.4s, v2.4s \n"
// save top_f32
"st1 {v5.4s, v6.4s}, [%2], #32 \n"
"subs %w0, %w0, #1 \n"
"bne 0b \n"
: "=r"(nn), // %0
"=r"(intptr), // %1
"=r"(ptr) // %2
: "0"(nn),
"1"(intptr),
"2"(ptr),
"r"(_scale) // %6
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6"
);
}
#else
if (nn > 0)
{
@@ -193,8 +250,8 @@ int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) con
"vdup.f32 q10, %6 \n" //q10 scale

"0: \n"
"vcvt.f32.s32 q0, q0 \n"
"vcvt.f32.s32 q1, q1 \n"
"vcvt.f32.s32 q0, q0 \n"
"vcvt.f32.s32 q1, q1 \n"

"vmul.f32 q2,q0,q10 \n"
"vmul.f32 q3,q1,q10 \n"


+ 34
- 6
src/layer/arm/quantize_arm.cpp View File

@@ -31,11 +31,6 @@ static inline signed char float2int8(float v)

int Quantize_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
#if __aarch64__
// TODO port to aarch64 fcvtas
return Quantize::forward(bottom_blob, top_blob, opt);
#endif // __aarch64__

#if !__aarch64__
int FPSCR_value = 0;

@@ -115,7 +110,40 @@ int Quantize_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& o

#if __ARM_NEON
#if __aarch64__
// TODO
float32x4_t _scale = vdupq_n_f32(scale);

if (nn > 0)
{
asm volatile(
"dup v2.4s, %w6 \n" //scale
"0: \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v0.4s, v1.4s}, [%1], #32 \n" //data
// bottom_f32 = bottom_f32 * scale
"fmul v3.4s, v0.4s, v2.4s \n"
"fmul v4.4s, v1.4s, v2.4s \n"
// top_f32 -> top_s32
"fcvtas v5.4s, v3.4s \n"
"fcvtas v6.4s, v4.4s \n"
// top_s32 -> top_s16
"sqxtn v7.4h, v5.4s \n"
"sqxtn2 v7.8h, v6.4s \n"
// top_s16 -> top_s8
"sqxtn v8.8b, v7.8h \n"
// save top_s8
"st1 {v8.8b}, [%2], #8 \n"
"subs %w0, %w0, #1 \n"
"bne 0b \n"
: "=r"(nn), // %0
"=r"(ptr), // %1
"=r"(outptr) // %2
: "0"(nn),
"1"(ptr),
"2"(outptr),
"r"(_scale) // %6
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"
);
}
#else
if (nn > 0)
{


Loading…
Cancel
Save