From f11bebdcf57541283ef0ef1cc29a8ba7a135bdcd Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 23 Aug 2020 10:54:24 +0800 Subject: [PATCH] batchnorm convdw3x3 arm fp16sa --- src/layer/arm/batchnorm_arm.cpp | 752 ++++++++++++++++-- src/layer/arm/batchnorm_arm.h | 7 + .../arm/convolutiondepthwise_3x3_fp16s.h | 457 +++++++++++ src/layer/arm/convolutiondepthwise_arm.cpp | 20 + 4 files changed, 1177 insertions(+), 59 deletions(-) create mode 100644 src/layer/arm/convolutiondepthwise_3x3_fp16s.h diff --git a/src/layer/arm/batchnorm_arm.cpp b/src/layer/arm/batchnorm_arm.cpp index a8dae4f8d..09c2d03e1 100644 --- a/src/layer/arm/batchnorm_arm.cpp +++ b/src/layer/arm/batchnorm_arm.cpp @@ -24,11 +24,31 @@ BatchNorm_arm::BatchNorm_arm() { #if __ARM_NEON support_packing = true; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + support_fp16_storage = true; +#endif #endif // __ARM_NEON + + support_bf16_storage = true; } int BatchNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { + int elembits = bottom_top_blob.elembits(); + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (opt.use_fp16_storage && elembits == 16) + { + if (opt.use_fp16_arithmetic) + return forward_inplace_fp16sa(bottom_top_blob, opt); + else + return forward_inplace_fp16s(bottom_top_blob, opt); + } +#endif + + if (opt.use_bf16_storage && elembits == 16) + return forward_inplace_bf16s(bottom_top_blob, opt); + int dims = bottom_top_blob.dims; int elempack = bottom_top_blob.elempack; @@ -107,81 +127,695 @@ int BatchNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons } #endif // __ARM_NEON - if (dims != 3) - return BatchNorm::forward_inplace(bottom_top_blob, opt); + if (dims == 1) + { + int w = bottom_top_blob.w; + + float* ptr = bottom_top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + ptr[i] = b_data[i] * ptr[i] + a_data[i]; + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.row(i); + + float a = a_data[i]; + float b = b_data[i]; + + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + + int j = 0; + for (; j + 3 < w; j += 4) + { + float32x4_t _p = vld1q_f32(ptr); + _p = vmlaq_f32(_a, _p, _b); + vst1q_f32(ptr, _p); + + ptr += 4; + } + for (; j < w; j++) + { + *ptr = b * *ptr + a; - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - // int c = bottom_top_blob.c; - int size = w * h; + ptr++; + } + } + } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + if (dims == 3) { - float* ptr = bottom_top_blob.channel(q); + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; - float a = a_data[q]; - float b = b_data[q]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + float* ptr = bottom_top_blob.channel(q); + + float a = a_data[q]; + float b = b_data[q]; #if __ARM_NEON - int nn = size >> 2; - int remain = size - (nn << 2); + int nn = size >> 2; + int remain = size - (nn << 2); #else - int remain = size; + int remain = size; #endif // __ARM_NEON #if __ARM_NEON #if __aarch64__ - if (nn > 0) - { - asm volatile( - "dup v1.4s, %w4 \n" - "dup v2.4s, %w5 \n" - "0: \n" - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v0.4s}, [%1] \n" - "orr v3.16b, v1.16b, v1.16b \n" - "fmla v3.4s, v0.4s, v2.4s \n" - "subs %w0, %w0, #1 \n" - "st1 {v3.4s}, [%1], #16 \n" - "bne 0b \n" - : "=r"(nn), // %0 - "=r"(ptr) // %1 - : "0"(nn), - "1"(ptr), - "r"(a), // %4 - "r"(b) // %5 - : "cc", "memory", "v0", "v1", "v2", "v3"); - } + if (nn > 0) + { + asm volatile( + "dup v1.4s, %w4 \n" + "dup v2.4s, %w5 \n" + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s}, [%1] \n" + "orr v3.16b, v1.16b, v1.16b \n" + "fmla v3.4s, v0.4s, v2.4s \n" + "subs %w0, %w0, #1 \n" + "st1 {v3.4s}, [%1], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), + "1"(ptr), + "r"(a), // %4 + "r"(b) // %5 + : "cc", "memory", "v0", "v1", "v2", "v3"); + } #else - if (nn > 0) - { - asm volatile( - "vdup.f32 q1, %4 \n" - "vdup.f32 q2, %5 \n" - "0: \n" - "pld [%1, #128] \n" - "vld1.f32 {d0-d1}, [%1 :128] \n" - "vorr.32 q3, q1, q1 \n" - "vmla.f32 q3, q0, q2 \n" - "subs %0, #1 \n" - "vst1.f32 {d6-d7}, [%1 :128]! \n" - "bne 0b \n" - : "=r"(nn), // %0 - "=r"(ptr) // %1 - : "0"(nn), - "1"(ptr), - "r"(a), // %4 - "r"(b) // %5 - : "cc", "memory", "q0", "q1", "q2", "q3"); - } + if (nn > 0) + { + asm volatile( + "vdup.f32 q1, %4 \n" + "vdup.f32 q2, %5 \n" + "0: \n" + "pld [%1, #128] \n" + "vld1.f32 {d0-d1}, [%1 :128] \n" + "vorr.32 q3, q1, q1 \n" + "vmla.f32 q3, q0, q2 \n" + "subs %0, #1 \n" + "vst1.f32 {d6-d7}, [%1 :128]! \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), + "1"(ptr), + "r"(a), // %4 + "r"(b) // %5 + : "cc", "memory", "q0", "q1", "q2", "q3"); + } #endif // __aarch64__ #endif // __ARM_NEON - for (; remain > 0; remain--) + for (; remain > 0; remain--) + { + *ptr = b * *ptr + a; + + ptr++; + } + } + } + + return 0; +} + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +int BatchNorm_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + int elempack = bottom_top_blob.elempack; + + if (elempack == 4) + { + if (dims == 1) + { + int w = bottom_top_blob.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + __fp16* ptr = (__fp16*)bottom_top_blob + i * 4; + + float32x4_t _a = vld1q_f32((const float*)a_data + i * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + i * 4); + + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_f16(ptr, vcvt_f16_f32(_p)); + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float32x4_t _a = vld1q_f32((const float*)a_data + i * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + i * 4); + + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + + for (int j = 0; j < w; j++) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_f16(ptr, vcvt_f16_f32(_p)); + + ptr += 4; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + float32x4_t _a = vld1q_f32((const float*)a_data + q * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + q * 4); + + __fp16* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_f16(ptr, vcvt_f16_f32(_p)); + + ptr += 4; + } + } + } + + return 0; + } + + if (dims == 1) + { + int w = bottom_top_blob.w; + + __fp16* ptr = bottom_top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + ptr[i] = b_data[i] * ptr[i] + a_data[i]; + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + + float a = a_data[i]; + float b = b_data[i]; + + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + + int j = 0; + for (; j + 3 < w; j += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_f16(ptr, vcvt_f16_f32(_p)); + + ptr += 4; + } + for (; j < w; j++) + { + *ptr = b * *ptr + a; + + ptr++; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + __fp16* ptr = bottom_top_blob.channel(q); + + float a = a_data[q]; + float b = b_data[q]; + + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + + int j = 0; + for (; j + 3 < size; j += 4) + { + float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_f16(ptr, vcvt_f16_f32(_p)); + + ptr += 4; + } + for (; j < size; j++) + { + *ptr = b * *ptr + a; + + ptr++; + } + } + } + + return 0; +} + +int BatchNorm_arm::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + int elempack = bottom_top_blob.elempack; + + if (elempack == 8) + { + if (dims == 1) + { + int w = bottom_top_blob.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + __fp16* ptr = (__fp16*)bottom_top_blob + i * 8; + + float16x8_t _a = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 8)), vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 8 + 4))); + float16x8_t _b = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 8)), vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 8 + 4))); + + float16x8_t _p = vld1q_f16(ptr); + _p = vfmaq_f16(_a, _p, _b); + vst1q_f16(ptr, _p); + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float16x8_t _a = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 8)), vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 8 + 4))); + float16x8_t _b = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 8)), vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 8 + 4))); + + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + + for (int j = 0; j < w; j++) + { + float16x8_t _p = vld1q_f16(ptr); + _p = vfmaq_f16(_a, _p, _b); + vst1q_f16(ptr, _p); + + ptr += 8; + } + } + } + + if (dims == 3) { - *ptr = b * *ptr + a; + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + float16x8_t _a = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)a_data + q * 8)), vcvt_f16_f32(vld1q_f32((const float*)a_data + q * 8 + 4))); + float16x8_t _b = vcombine_f16(vcvt_f16_f32(vld1q_f32((const float*)b_data + q * 8)), vcvt_f16_f32(vld1q_f32((const float*)b_data + q * 8 + 4))); - ptr++; + __fp16* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + float16x8_t _p = vld1q_f16(ptr); + _p = vfmaq_f16(_a, _p, _b); + vst1q_f16(ptr, _p); + + ptr += 8; + } + } + } + + return 0; + } + + if (elempack == 4) + { + if (dims == 1) + { + int w = bottom_top_blob.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + __fp16* ptr = (__fp16*)bottom_top_blob + i * 4; + + float16x4_t _a = vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 4)); + float16x4_t _b = vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 4)); + + float16x4_t _p = vld1_f16(ptr); + _p = vfma_f16(_a, _p, _b); + vst1_f16(ptr, _p); + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float16x4_t _a = vcvt_f16_f32(vld1q_f32((const float*)a_data + i * 4)); + float16x4_t _b = vcvt_f16_f32(vld1q_f32((const float*)b_data + i * 4)); + + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + + for (int j = 0; j < w; j++) + { + float16x4_t _p = vld1_f16(ptr); + _p = vfma_f16(_a, _p, _b); + vst1_f16(ptr, _p); + + ptr += 4; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + float16x4_t _a = vcvt_f16_f32(vld1q_f32((const float*)a_data + q * 4)); + float16x4_t _b = vcvt_f16_f32(vld1q_f32((const float*)b_data + q * 4)); + + __fp16* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + float16x4_t _p = vld1_f16(ptr); + _p = vfma_f16(_a, _p, _b); + vst1_f16(ptr, _p); + + ptr += 4; + } + } + } + + return 0; + } + + if (dims == 1) + { + int w = bottom_top_blob.w; + + __fp16* ptr = bottom_top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + ptr[i] = (__fp16)b_data[i] * ptr[i] + (__fp16)a_data[i]; + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + __fp16* ptr = bottom_top_blob.row<__fp16>(i); + + __fp16 a = (__fp16)a_data[i]; + __fp16 b = (__fp16)b_data[i]; + + float16x4_t _a = vdup_n_f16(a); + float16x4_t _b = vdup_n_f16(b); + + int j = 0; + for (; j + 3 < w; j += 4) + { + float16x4_t _p = vld1_f16(ptr); + _p = vfma_f16(_a, _p, _b); + vst1_f16(ptr, _p); + + ptr += 4; + } + for (; j < w; j++) + { + *ptr = b * *ptr + a; + + ptr++; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + __fp16* ptr = bottom_top_blob.channel(q); + + __fp16 a = (__fp16)a_data[q]; + __fp16 b = (__fp16)b_data[q]; + + float16x4_t _a = vdup_n_f16(a); + float16x4_t _b = vdup_n_f16(b); + + int j = 0; + for (; j + 3 < size; j += 4) + { + float16x4_t _p = vld1_f16(ptr); + _p = vfma_f16(_a, _p, _b); + vst1_f16(ptr, _p); + + ptr += 4; + } + for (; j < size; j++) + { + *ptr = b * *ptr + a; + + ptr++; + } + } + } + + return 0; +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +int BatchNorm_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ + int dims = bottom_top_blob.dims; + int elempack = bottom_top_blob.elempack; + + if (elempack == 4) + { + if (dims == 1) + { + int w = bottom_top_blob.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + unsigned short* ptr = (unsigned short*)bottom_top_blob + i * 4; + + float32x4_t _a = vld1q_f32((const float*)a_data + i * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + i * 4); + + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_u16(ptr, vcvt_bf16_f32(_p)); + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float32x4_t _a = vld1q_f32((const float*)a_data + i * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + i * 4); + + unsigned short* ptr = bottom_top_blob.row(i); + + for (int j = 0; j < w; j++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_u16(ptr, vcvt_bf16_f32(_p)); + + ptr += 4; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + float32x4_t _a = vld1q_f32((const float*)a_data + q * 4); + float32x4_t _b = vld1q_f32((const float*)b_data + q * 4); + + unsigned short* ptr = bottom_top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_u16(ptr, vcvt_bf16_f32(_p)); + + ptr += 4; + } + } + } + + return 0; + } + + if (dims == 1) + { + int w = bottom_top_blob.w; + + unsigned short* ptr = bottom_top_blob; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < w; i++) + { + ptr[i] = float32_to_bfloat16(b_data[i] * bfloat16_to_float32(ptr[i]) + a_data[i]); + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + unsigned short* ptr = bottom_top_blob.row(i); + + float a = a_data[i]; + float b = b_data[i]; + + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + + int j = 0; + for (; j + 3 < w; j += 4) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_u16(ptr, vcvt_bf16_f32(_p)); + + ptr += 4; + } + for (; j < w; j++) + { + *ptr = float32_to_bfloat16(b * bfloat16_to_float32(*ptr) + a); + + ptr++; + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int c = bottom_top_blob.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < c; q++) + { + unsigned short* ptr = bottom_top_blob.channel(q); + + float a = a_data[q]; + float b = b_data[q]; + + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + + int j = 0; + for (; j + 3 < size; j += 4) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + _p = vfmaq_f32(_a, _p, _b); + vst1_u16(ptr, vcvt_bf16_f32(_p)); + + ptr += 4; + } + for (; j < size; j++) + { + *ptr = float32_to_bfloat16(b * bfloat16_to_float32(*ptr) + a); + + ptr++; + } } } diff --git a/src/layer/arm/batchnorm_arm.h b/src/layer/arm/batchnorm_arm.h index fb557bab4..c50cd47af 100644 --- a/src/layer/arm/batchnorm_arm.h +++ b/src/layer/arm/batchnorm_arm.h @@ -25,6 +25,13 @@ public: BatchNorm_arm(); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const; + int forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const; +#endif + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; }; } // namespace ncnn diff --git a/src/layer/arm/convolutiondepthwise_3x3_fp16s.h b/src/layer/arm/convolutiondepthwise_3x3_fp16s.h new file mode 100644 index 000000000..118914c98 --- /dev/null +++ b/src/layer/arm/convolutiondepthwise_3x3_fp16s.h @@ -0,0 +1,457 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void convdw3x3s1_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt) +{ + int w = bottom_blob.w; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int group = bottom_blob.c; + + const __fp16* kernel = _kernel; + const __fp16* bias = _bias; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat out = top_blob.channel(g); + + const __fp16 bias0 = bias ? bias[g] : 0.f; + + const __fp16* kernel0 = kernel + g * 9; + + __fp16* outptr0 = out; + __fp16* outptr1 = outptr0 + outw; + + const __fp16* img0 = bottom_blob.channel(g); + + const __fp16* r0 = img0; + const __fp16* r1 = img0 + w; + const __fp16* r2 = img0 + w * 2; + const __fp16* r3 = img0 + w * 3; + + float16x4_t _k012x = vld1_f16(kernel0); + float16x4_t _k345x = vld1_f16(kernel0 + 3); + float16x4_t _k678x = vld1_f16(kernel0 + 6); + + _k012x = vset_lane_f16(0.f, _k012x, 3); + _k345x = vset_lane_f16(0.f, _k345x, 3); + _k678x = vset_lane_f16(0.f, _k678x, 3); + + float16x8_t _bias0 = vdupq_n_f16(bias0); + + int i = 0; + for (; i + 1 < outh; i += 2) + { + int j = 0; + for (; j + 7 < outw; j += 8) + { + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r20 = vld1q_f16(r2); + float16x8_t _r30 = vld1q_f16(r3); + + float16x8_t _r0n = vld1q_f16(r0 + 8); + float16x8_t _r1n = vld1q_f16(r1 + 8); + float16x8_t _r2n = vld1q_f16(r2 + 8); + float16x8_t _r3n = vld1q_f16(r3 + 8); + + float16x8_t _r01 = vextq_f16(_r00, _r0n, 1); + float16x8_t _r11 = vextq_f16(_r10, _r1n, 1); + float16x8_t _r21 = vextq_f16(_r20, _r2n, 1); + float16x8_t _r31 = vextq_f16(_r30, _r3n, 1); + + float16x8_t _r02 = vextq_f16(_r00, _r0n, 2); + float16x8_t _r12 = vextq_f16(_r10, _r1n, 2); + float16x8_t _r22 = vextq_f16(_r20, _r2n, 2); + float16x8_t _r32 = vextq_f16(_r30, _r3n, 2); + + float16x8_t _sum0 = _bias0; + float16x8_t _sum1 = _bias0; + + _sum0 = vfmaq_lane_f16(_sum0, _r00, _k012x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r01, _k012x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r02, _k012x, 2); + _sum1 = vfmaq_lane_f16(_sum1, _r10, _k012x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _r11, _k012x, 1); + _sum1 = vfmaq_lane_f16(_sum1, _r12, _k012x, 2); + + _sum0 = vfmaq_lane_f16(_sum0, _r10, _k345x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r11, _k345x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r12, _k345x, 2); + _sum1 = vfmaq_lane_f16(_sum1, _r20, _k345x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _r21, _k345x, 1); + _sum1 = vfmaq_lane_f16(_sum1, _r22, _k345x, 2); + + _sum0 = vfmaq_lane_f16(_sum0, _r20, _k678x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r21, _k678x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r22, _k678x, 2); + _sum1 = vfmaq_lane_f16(_sum1, _r30, _k678x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _r31, _k678x, 1); + _sum1 = vfmaq_lane_f16(_sum1, _r32, _k678x, 2); + + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr1, _sum1); + + r0 += 8; + r1 += 8; + r2 += 8; + r3 += 8; + outptr0 += 8; + outptr1 += 8; + } + for (; j + 3 < outw; j += 4) + { + float16x4_t _r00 = vld1_f16(r0); + float16x4_t _r10 = vld1_f16(r1); + float16x4_t _r20 = vld1_f16(r2); + float16x4_t _r30 = vld1_f16(r3); + + float16x4_t _r0n = vld1_f16(r0 + 4); + float16x4_t _r1n = vld1_f16(r1 + 4); + float16x4_t _r2n = vld1_f16(r2 + 4); + float16x4_t _r3n = vld1_f16(r3 + 4); + + float16x4_t _r01 = vext_f16(_r00, _r0n, 1); + float16x4_t _r11 = vext_f16(_r10, _r1n, 1); + float16x4_t _r21 = vext_f16(_r20, _r2n, 1); + float16x4_t _r31 = vext_f16(_r30, _r3n, 1); + + float16x4_t _r02 = vext_f16(_r00, _r0n, 2); + float16x4_t _r12 = vext_f16(_r10, _r1n, 2); + float16x4_t _r22 = vext_f16(_r20, _r2n, 2); + float16x4_t _r32 = vext_f16(_r30, _r3n, 2); + + float16x4_t _sum0 = vget_low_f16(_bias0); + float16x4_t _sum1 = vget_low_f16(_bias0); + + _sum0 = vfma_lane_f16(_sum0, _r00, _k012x, 0); + _sum0 = vfma_lane_f16(_sum0, _r01, _k012x, 1); + _sum0 = vfma_lane_f16(_sum0, _r02, _k012x, 2); + _sum1 = vfma_lane_f16(_sum1, _r10, _k012x, 0); + _sum1 = vfma_lane_f16(_sum1, _r11, _k012x, 1); + _sum1 = vfma_lane_f16(_sum1, _r12, _k012x, 2); + + _sum0 = vfma_lane_f16(_sum0, _r10, _k345x, 0); + _sum0 = vfma_lane_f16(_sum0, _r11, _k345x, 1); + _sum0 = vfma_lane_f16(_sum0, _r12, _k345x, 2); + _sum1 = vfma_lane_f16(_sum1, _r20, _k345x, 0); + _sum1 = vfma_lane_f16(_sum1, _r21, _k345x, 1); + _sum1 = vfma_lane_f16(_sum1, _r22, _k345x, 2); + + _sum0 = vfma_lane_f16(_sum0, _r20, _k678x, 0); + _sum0 = vfma_lane_f16(_sum0, _r21, _k678x, 1); + _sum0 = vfma_lane_f16(_sum0, _r22, _k678x, 2); + _sum1 = vfma_lane_f16(_sum1, _r30, _k678x, 0); + _sum1 = vfma_lane_f16(_sum1, _r31, _k678x, 1); + _sum1 = vfma_lane_f16(_sum1, _r32, _k678x, 2); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr1, _sum1); + + r0 += 4; + r1 += 4; + r2 += 4; + r3 += 4; + outptr0 += 4; + outptr1 += 4; + } + for (; j < outw; j++) + { + float16x4_t _r0 = vld1_f16(r0); + float16x4_t _r1 = vld1_f16(r1); + float16x4_t _r2 = vld1_f16(r2); + float16x4_t _r3 = vld1_f16(r3); + + float16x4_t _sum0 = vmul_f16(_r0, _k012x); + _sum0 = vfma_f16(_sum0, _r1, _k345x); + _sum0 = vfma_f16(_sum0, _r2, _k678x); + + float16x4_t _sum1 = vmul_f16(_r1, _k012x); + _sum1 = vfma_f16(_sum1, _r2, _k345x); + _sum1 = vfma_f16(_sum1, _r3, _k678x); + + _sum0 = vset_lane_f16(bias0, _sum0, 3); + _sum1 = vset_lane_f16(bias0, _sum1, 3); + + *outptr0 = (__fp16)vaddvq_f32(vcvt_f32_f16(_sum0)); + *outptr1 = (__fp16)vaddvq_f32(vcvt_f32_f16(_sum1)); + + r0++; + r1++; + r2++; + r3++; + outptr0++; + outptr1++; + } + + r0 += 2 + w; + r1 += 2 + w; + r2 += 2 + w; + r3 += 2 + w; + + outptr0 += outw; + outptr1 += outw; + } + for (; i < outh; i++) + { + int j = 0; + for (; j + 7 < outw; j += 8) + { + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r10 = vld1q_f16(r1); + float16x8_t _r20 = vld1q_f16(r2); + + float16x8_t _r0n = vld1q_f16(r0 + 8); + float16x8_t _r1n = vld1q_f16(r1 + 8); + float16x8_t _r2n = vld1q_f16(r2 + 8); + + float16x8_t _r01 = vextq_f16(_r00, _r0n, 1); + float16x8_t _r11 = vextq_f16(_r10, _r1n, 1); + float16x8_t _r21 = vextq_f16(_r20, _r2n, 1); + + float16x8_t _r02 = vextq_f16(_r00, _r0n, 2); + float16x8_t _r12 = vextq_f16(_r10, _r1n, 2); + float16x8_t _r22 = vextq_f16(_r20, _r2n, 2); + + float16x8_t _sum0 = _bias0; + + _sum0 = vfmaq_lane_f16(_sum0, _r00, _k012x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r01, _k012x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r02, _k012x, 2); + + _sum0 = vfmaq_lane_f16(_sum0, _r10, _k345x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r11, _k345x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r12, _k345x, 2); + + _sum0 = vfmaq_lane_f16(_sum0, _r20, _k678x, 0); + _sum0 = vfmaq_lane_f16(_sum0, _r21, _k678x, 1); + _sum0 = vfmaq_lane_f16(_sum0, _r22, _k678x, 2); + + vst1q_f16(outptr0, _sum0); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr0 += 8; + } + for (; j + 3 < outw; j += 4) + { + float16x4_t _r00 = vld1_f16(r0); + float16x4_t _r10 = vld1_f16(r1); + float16x4_t _r20 = vld1_f16(r2); + + float16x4_t _r0n = vld1_f16(r0 + 4); + float16x4_t _r1n = vld1_f16(r1 + 4); + float16x4_t _r2n = vld1_f16(r2 + 4); + + float16x4_t _r01 = vext_f16(_r00, _r0n, 1); + float16x4_t _r11 = vext_f16(_r10, _r1n, 1); + float16x4_t _r21 = vext_f16(_r20, _r2n, 1); + + float16x4_t _r02 = vext_f16(_r00, _r0n, 2); + float16x4_t _r12 = vext_f16(_r10, _r1n, 2); + float16x4_t _r22 = vext_f16(_r20, _r2n, 2); + + float16x4_t _sum0 = vget_low_f16(_bias0); + + _sum0 = vfma_lane_f16(_sum0, _r00, _k012x, 0); + _sum0 = vfma_lane_f16(_sum0, _r01, _k012x, 1); + _sum0 = vfma_lane_f16(_sum0, _r02, _k012x, 2); + + _sum0 = vfma_lane_f16(_sum0, _r10, _k345x, 0); + _sum0 = vfma_lane_f16(_sum0, _r11, _k345x, 1); + _sum0 = vfma_lane_f16(_sum0, _r12, _k345x, 2); + + _sum0 = vfma_lane_f16(_sum0, _r20, _k678x, 0); + _sum0 = vfma_lane_f16(_sum0, _r21, _k678x, 1); + _sum0 = vfma_lane_f16(_sum0, _r22, _k678x, 2); + + vst1_f16(outptr0, _sum0); + + r0 += 4; + r1 += 4; + r2 += 4; + outptr0 += 4; + } + for (; j < outw; j++) + { + float16x4_t _r0 = vld1_f16(r0); + float16x4_t _r1 = vld1_f16(r1); + float16x4_t _r2 = vld1_f16(r2); + + float16x4_t _sum = vmul_f16(_r0, _k012x); + _sum = vfma_f16(_sum, _r1, _k345x); + _sum = vfma_f16(_sum, _r2, _k678x); + + _sum = vset_lane_f16(bias0, _sum, 3); + + *outptr0 = (__fp16)vaddvq_f32(vcvt_f32_f16(_sum)); + + r0++; + r1++; + r2++; + outptr0++; + } + + r0 += 2; + r1 += 2; + r2 += 2; + } + } +} + +static void convdw3x3s2_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt) +{ + int w = bottom_blob.w; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int group = bottom_blob.c; + + const int tailstep = w - 2 * outw + w; + + const __fp16* kernel = _kernel; + const __fp16* bias = _bias; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat out = top_blob.channel(g); + + const __fp16 bias0 = bias ? bias[g] : 0.f; + + const __fp16* kernel0 = kernel + g * 9; + + __fp16* outptr = out; + + const __fp16* img0 = bottom_blob.channel(g); + + const __fp16* r0 = img0; + const __fp16* r1 = img0 + w; + const __fp16* r2 = img0 + w * 2; + + float16x4_t _k012x = vld1_f16(kernel0); + float16x4_t _k345x = vld1_f16(kernel0 + 3); + float16x4_t _k678x = vld1_f16(kernel0 + 6); + + _k012x = vset_lane_f16(0.f, _k012x, 3); + _k345x = vset_lane_f16(0.f, _k345x, 3); + _k678x = vset_lane_f16(0.f, _k678x, 3); + + float16x8_t _bias0 = vdupq_n_f16(bias0); + + int i = 0; + for (; i < outh; i++) + { + int j = 0; + for (; j + 7 < outw; j += 8) + { + float16x8x2_t _r00 = vld2q_f16(r0); + float16x8x2_t _r10 = vld2q_f16(r1); + float16x8x2_t _r20 = vld2q_f16(r2); + + float16x8x2_t _r0n = vld2q_f16(r0 + 16); + float16x8x2_t _r1n = vld2q_f16(r1 + 16); + float16x8x2_t _r2n = vld2q_f16(r2 + 16); + + float16x8_t _r02 = vextq_f16(_r00.val[0], _r0n.val[0], 1); + float16x8_t _r12 = vextq_f16(_r10.val[0], _r1n.val[0], 1); + float16x8_t _r22 = vextq_f16(_r20.val[0], _r2n.val[0], 1); + + float16x8_t _sum = _bias0; + + _sum = vfmaq_lane_f16(_sum, _r00.val[0], _k012x, 0); + _sum = vfmaq_lane_f16(_sum, _r00.val[1], _k012x, 1); + _sum = vfmaq_lane_f16(_sum, _r02, _k012x, 2); + + _sum = vfmaq_lane_f16(_sum, _r10.val[0], _k345x, 0); + _sum = vfmaq_lane_f16(_sum, _r10.val[1], _k345x, 1); + _sum = vfmaq_lane_f16(_sum, _r12, _k345x, 2); + + _sum = vfmaq_lane_f16(_sum, _r20.val[0], _k678x, 0); + _sum = vfmaq_lane_f16(_sum, _r20.val[1], _k678x, 1); + _sum = vfmaq_lane_f16(_sum, _r22, _k678x, 2); + + vst1q_f16(outptr, _sum); + + r0 += 16; + r1 += 16; + r2 += 16; + outptr += 8; + } + for (; j + 3 < outw; j += 4) + { + float16x4x2_t _r00 = vld2_f16(r0); + float16x4x2_t _r10 = vld2_f16(r1); + float16x4x2_t _r20 = vld2_f16(r2); + + float16x4x2_t _r0n = vld2_f16(r0 + 8); + float16x4x2_t _r1n = vld2_f16(r1 + 8); + float16x4x2_t _r2n = vld2_f16(r2 + 8); + + float16x4_t _r02 = vext_f16(_r00.val[0], _r0n.val[0], 1); + float16x4_t _r12 = vext_f16(_r10.val[0], _r1n.val[0], 1); + float16x4_t _r22 = vext_f16(_r20.val[0], _r2n.val[0], 1); + + float16x4_t _sum = vget_low_f16(_bias0); + + _sum = vfma_lane_f16(_sum, _r00.val[0], _k012x, 0); + _sum = vfma_lane_f16(_sum, _r00.val[1], _k012x, 1); + _sum = vfma_lane_f16(_sum, _r02, _k012x, 2); + + _sum = vfma_lane_f16(_sum, _r10.val[0], _k345x, 0); + _sum = vfma_lane_f16(_sum, _r10.val[1], _k345x, 1); + _sum = vfma_lane_f16(_sum, _r12, _k345x, 2); + + _sum = vfma_lane_f16(_sum, _r20.val[0], _k678x, 0); + _sum = vfma_lane_f16(_sum, _r20.val[1], _k678x, 1); + _sum = vfma_lane_f16(_sum, _r22, _k678x, 2); + + vst1_f16(outptr, _sum); + + r0 += 8; + r1 += 8; + r2 += 8; + outptr += 4; + } + for (; j < outw; j++) + { + float16x4_t _r0 = vld1_f16(r0); + float16x4_t _r1 = vld1_f16(r1); + float16x4_t _r2 = vld1_f16(r2); + + float16x4_t _sum = vmul_f16(_r0, _k012x); + _sum = vfma_f16(_sum, _r1, _k345x); + _sum = vfma_f16(_sum, _r2, _k678x); + + _sum = vset_lane_f16(bias0, _sum, 3); + + *outptr = (__fp16)vaddvq_f32(vcvt_f32_f16(_sum)); + + r0 += 2; + r1 += 2; + r2 += 2; + outptr++; + } + + r0 += tailstep; + r1 += tailstep; + r2 += tailstep; + } + } +} diff --git a/src/layer/arm/convolutiondepthwise_arm.cpp b/src/layer/arm/convolutiondepthwise_arm.cpp index e29e1e23b..b30e7a7df 100644 --- a/src/layer/arm/convolutiondepthwise_arm.cpp +++ b/src/layer/arm/convolutiondepthwise_arm.cpp @@ -38,6 +38,7 @@ namespace ncnn { #include "convolutiondepthwise_5x5_pack4.h" #include "convolutiondepthwise_5x5_pack4_bf16s.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "convolutiondepthwise_3x3_fp16s.h" #include "convolutiondepthwise_3x3_pack8_fp16s.h" #include "convolutiondepthwise_5x5_pack8_fp16s.h" #endif @@ -991,6 +992,25 @@ int ConvolutionDepthWise_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_bl if (elempack == 1) { + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convdw3x3s1_fp16sa_neon(bottom_blob_bordered, top_blob, weight_data_fp16, bias_data_fp16, opt); + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convdw3x3s2_fp16sa_neon(bottom_blob_bordered, top_blob, weight_data_fp16, bias_data_fp16, opt); + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else { const int maxk = kernel_w * kernel_h;