From 3a83704c386c6eed19316dcfdfeabef51c61ba37 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 28 Dec 2021 23:13:44 +0800 Subject: [PATCH] binary4d, unary4d (#3443) --- .github/workflows/test-coverage.yml | 68 + docs/developer-guide/binaryop-broadcasting.md | 12 +- src/layer/arm/binaryop_arm.cpp | 1766 +++++++++- src/layer/arm/unaryop_arm.cpp | 18 +- src/layer/binaryop.cpp | 256 +- src/layer/mips/binaryop_mips.cpp | 286 +- src/layer/mips/unaryop_mips.cpp | 3 +- src/layer/riscv/binaryop_riscv.cpp | 3121 ++++++++++++----- src/layer/riscv/binaryop_riscv.h | 20 +- src/layer/riscv/convolution_3x3_packn.h | 12 +- src/layer/riscv/convolution_3x3_packn_fp16s.h | 12 +- src/layer/riscv/convolution_packnto1_fp16s.h | 2 +- src/layer/riscv/convolution_sgemm_packn.h | 6 +- .../riscv/convolution_sgemm_packn_fp16s.h | 46 +- src/layer/riscv/convolution_sgemm_packnto1.h | 20 +- .../riscv/convolution_sgemm_packnto1_fp16s.h | 20 +- .../riscv/deconvolution_packnto1_fp16s.h | 2 +- src/layer/riscv/riscv_usability.h | 82 + src/layer/riscv/riscv_v_071_fix.h | 39 + src/layer/riscv/unaryop_riscv.cpp | 6 +- src/layer/vulkan/binaryop_vulkan.cpp | 272 +- src/layer/vulkan/relu_vulkan.cpp | 2 +- .../vulkan/shader/binaryop_broadcast.comp | 175 +- .../shader/binaryop_broadcast_pack4.comp | 173 +- .../shader/binaryop_broadcast_pack8.comp | 173 +- src/layer/vulkan/unaryop_vulkan.cpp | 15 +- src/layer/x86/binaryop_x86.cpp | 584 ++- src/net.cpp | 2 +- tests/test_binaryop.cpp | 94 +- tests/test_unaryop.cpp | 17 +- tests/testutil.h | 4 +- toolchains/c906-v222.toolchain.cmake | 4 +- toolchains/c906-v223.toolchain.cmake | 4 +- toolchains/c906.toolchain.cmake | 4 +- 34 files changed, 6153 insertions(+), 1167 deletions(-) diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index 3383433e1..2c196f27c 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -395,6 +395,40 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} file: 
build-arm82/lcov.info + linux-gcc-arm82-omp: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + + - name: lcov + run: sudo apt-get install lcov + - name: cache-qemu + id: cache-qemu + uses: actions/cache@v2.1.7 + with: + path: qemu-install + key: qemu-aarch64-install-1 + - name: checkout-qemu + if: steps.cache-qemu.outputs.cache-hit != 'true' + uses: actions/checkout@v2 + with: + repository: qemu/qemu + path: qemu + ref: 8746309137ba470d1b2e8f5ce86ac228625db940 + - name: qemu + if: steps.cache-qemu.outputs.cache-hit != 'true' + run: | + cd qemu + ./configure --prefix=install --target-list=aarch64-linux-user --disable-system + make -j2 + make install + cp -r aarch64-linux-user/install $GITHUB_WORKSPACE/qemu-install + + - name: aarch64-gnu-toolchain + run: | + sudo apt-get update + sudo apt-get install g++-aarch64-linux-gnu + - name: build-arm82-omp run: | mkdir build-arm82-omp && cd build-arm82-omp @@ -418,6 +452,40 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} file: build-arm82-omp/lcov.info + linux-gcc-arm82dot-omp: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + + - name: lcov + run: sudo apt-get install lcov + - name: cache-qemu + id: cache-qemu + uses: actions/cache@v2.1.7 + with: + path: qemu-install + key: qemu-aarch64-install-1 + - name: checkout-qemu + if: steps.cache-qemu.outputs.cache-hit != 'true' + uses: actions/checkout@v2 + with: + repository: qemu/qemu + path: qemu + ref: 8746309137ba470d1b2e8f5ce86ac228625db940 + - name: qemu + if: steps.cache-qemu.outputs.cache-hit != 'true' + run: | + cd qemu + ./configure --prefix=install --target-list=aarch64-linux-user --disable-system + make -j2 + make install + cp -r aarch64-linux-user/install $GITHUB_WORKSPACE/qemu-install + + - name: aarch64-gnu-toolchain + run: | + sudo apt-get update + sudo apt-get install g++-aarch64-linux-gnu + - name: build-arm82dot-omp run: | mkdir build-arm82dot-omp && cd build-arm82dot-omp diff --git a/docs/developer-guide/binaryop-broadcasting.md 
b/docs/developer-guide/binaryop-broadcasting.md index 16bb14246..5f69f9ca6 100644 --- a/docs/developer-guide/binaryop-broadcasting.md +++ b/docs/developer-guide/binaryop-broadcasting.md @@ -4,7 +4,7 @@ ncnn BinaryOp accepts blobs with different shape C = BinaryOp(A, B) -shape notation convention is [w], [w,h], [w,h,c] +shape notation convention is [w], [w,h], [w,h,c], [w,h,d,c] |type|A|B|C| |---|---|---|---| @@ -27,6 +27,16 @@ shape notation convention is [w], [w,h], [w,h,c] |17|[2,3,4]|[4]|[2,3,4]| |18|[2,3,4]|[3,4]|[2,3,4]| |19|[2,3,4]|[2,3,4]|[2,3,4]| +|20|[1]|[2,3,4,5]|[2,3,4,5]| +|21|[5]|[2,3,4,5]|[2,3,4,5]| +|22|[4,5]|[2,3,4,5]|[2,3,4,5]| +|23|[3,4,5]|[2,3,4,5]|[2,3,4,5]| +|24|[2,3,4,5]|scalar|[2,3,4,5]| +|25|[2,3,4,5]|[1]|[2,3,4,5]| +|26|[2,3,4,5]|[5]|[2,3,4,5]| +|27|[2,3,4,5]|[4,5]|[2,3,4,5]| +|28|[2,3,4,5]|[3,4,5]|[2,3,4,5]| +|29|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]| some special broadcasting rule exists for model compatibility diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index d7b97ed02..e8bb98249 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -49,20 +49,203 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + 
float32x4_t _p = vld1q_f32(ptr); + float32x4_t _p1 = vld1q_f32(ptr1); + float32x4_t _outp = op(_p, _p1); + vst1q_f32(outptr, _outp); + ptr += 4; + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + float32x4_t _b0 = vld1q_f32(ptr1); + for (int x = 0; x < w; x++) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b0); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } + + ptr1 += 4; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + float32x4_t _b0 = vld1q_f32(ptr1); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b0); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + ptr1 += 4; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + float32x4_t _b0 = vdupq_n_f32(b[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b0); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const 
float* ptr = a.channel(q); + float32x4_t _b0 = vld1q_f32((const float*)b + q * 4); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b0); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + float32x4_t _a0 = vld1q_f32(ptr); + for (int x = 0; x < w1; x++) + { + float32x4_t _p = vld1q_f32(ptr1); + float32x4_t _outp = op(_a0, _p); + vst1q_f32(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + + ptr += 4; + } + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -411,6 +594,42 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + float32x4_t _a0 = vld1q_f32(ptr); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + float32x4_t _p = vld1q_f32(ptr1); + float32x4_t _outp = op(_a0, _p); + vst1q_f32(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + ptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -519,6 +738,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt { 
if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + float32x4_t _a0 = vdupq_n_f32(a[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float32x4_t _p1 = vld1q_f32(ptr1); + float32x4_t _outp = op(_a0, _p1); + vst1q_f32(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -591,6 +837,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + float32x4_t _a0 = vld1q_f32((const float*)a + q * 4); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float32x4_t _p1 = vld1q_f32(ptr1); + float32x4_t _outp = op(_a0, _p1); + vst1q_f32(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -698,8 +971,9 @@ static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; float32x4_t _b = vdupq_n_f32(b); @@ -851,10 +1125,10 @@ int BinaryOp_arm::forward(const std::vector& bottom_blobs, std::vector return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return 
binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); } #endif // __ARM_NEON @@ -920,33 +1194,216 @@ static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { - if (b.dims == 3) + if (b.dims == 4) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < size; i++) { - const __fp16* ptr = a.channel(q); + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr1); + float16x8_t _outp = op(_p, _p1); + vst1q_f16(outptr, _outp); + ptr += 8; + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + float16x8_t _b0 = vld1q_f16(ptr1); + for (int x = 0; x < w; x++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t 
_outp = op(_p, _b0); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + + ptr1 += 8; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + float16x8_t _b0 = vld1q_f16(ptr1); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b0); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + ptr1 += 8; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + float16x8_t _b0 = vdupq_n_f16(((const __fp16*)b)[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b0); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + float16x8_t _b0 = vld1q_f16((const __fp16*)b + q * 8); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b0); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < 
d1; z++) + { + for (int y = 0; y < h1; y++) + { + float16x8_t _a0 = vld1q_f16(ptr); + for (int x = 0; x < w1; x++) + { + float16x8_t _p = vld1q_f16(ptr1); + float16x8_t _outp = op(_a0, _p); + vst1q_f16(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + + ptr += 8; + } + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); __fp16* outptr = c.channel(q); const __fp16* b0 = b.channel(q); float16x8_t _b0 = vld1q_f16(b0); @@ -1282,6 +1739,42 @@ static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + float16x8_t _a0 = vld1q_f16(ptr); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + float16x8_t _p = vld1q_f16(ptr1); + float16x8_t _outp = op(_a0, _p); + vst1q_f16(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + ptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -1390,6 +1883,33 @@ static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + float16x8_t _a0 = vdupq_n_f16(((const __fp16*)a)[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + 
const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float16x8_t _p1 = vld1q_f16(ptr1); + float16x8_t _outp = op(_a0, _p1); + vst1q_f16(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -1462,6 +1982,33 @@ static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + float16x8_t _a0 = vld1q_f16((const __fp16*)a + q * 8); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float16x8_t _p1 = vld1q_f16(ptr1); + float16x8_t _outp = op(_a0, _p1); + vst1q_f16(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -1569,8 +2116,9 @@ static int binary_op_scalar_inplace_pack8_fp16s(Mat& a, float b, const Option& o int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; float16x8_t _b = vdupq_n_f16((__fp16)b); @@ -1672,33 +2220,216 @@ static int binary_op_pack4_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { - if (b.dims == 3) + if (b.dims == 4) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + 
for (int q = 0; q < channels; q++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < size; i++) { - const __fp16* ptr = a.channel(q); + float16x4_t _p = vld1_f16(ptr); + float16x4_t _p1 = vld1_f16(ptr1); + float16x4_t _outp = op(_p, _p1); + vst1_f16(outptr, _outp); + ptr += 4; + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + float16x4_t _b0 = vld1_f16(ptr1); + for (int x = 0; x < w; x++) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b0); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + + ptr1 += 4; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + float16x4_t _b0 = vld1_f16(ptr1); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b0); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + ptr1 += 4; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + float16x4_t _b0 = vdup_n_f16(((const __fp16*)b)[0]); + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b0); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + float16x4_t _b0 = vld1_f16((const __fp16*)b + q * 4); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b0); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + float16x4_t _a0 = vld1_f16(ptr); + for (int x = 0; x < w1; x++) + { + float16x4_t _p = vld1_f16(ptr1); + float16x4_t _outp = op(_a0, _p); + vst1_f16(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + + ptr += 4; + } + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); __fp16* outptr = c.channel(q); const __fp16* b0 = b.channel(q); float16x4_t _b0 = vld1_f16(b0); @@ -2034,6 +2765,42 @@ static int binary_op_pack4_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio } else if 
(a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + float16x4_t _a0 = vld1_f16(ptr); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + float16x4_t _p = vld1_f16(ptr1); + float16x4_t _outp = op(_a0, _p); + vst1_f16(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + ptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -2142,6 +2909,33 @@ static int binary_op_pack4_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + float16x4_t _a0 = vdup_n_f16(((const __fp16*)a)[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float16x4_t _p1 = vld1_f16(ptr1); + float16x4_t _outp = op(_a0, _p1); + vst1_f16(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -2214,6 +3008,33 @@ static int binary_op_pack4_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + float16x4_t _a0 = vld1_f16((const __fp16*)a + q * 4); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float16x4_t _p1 = vld1_f16(ptr1); + 
float16x4_t _outp = op(_a0, _p1); + vst1_f16(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -2321,8 +3142,9 @@ static int binary_op_scalar_inplace_pack4_fp16s(Mat& a, float b, const Option& o int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; float16x4_t _b = vdup_n_f16((__fp16)b); @@ -2422,30 +3244,194 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; - if (a.dims == 3) + if (a.dims == 4) { - if (b.dims == 3) + if (b.dims == 4) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + // type 29 + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // special type 1 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < size; i++) { - const __fp16* ptr = a.channel(q); + outptr[i] = op(ptr[i], ptr1[i]); + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + const __fp16 b0 = ptr1[y]; + for (int x = 0; x < w; x++) + { + 
outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + + ptr1 += h; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + const __fp16 b0 = ptr1[z]; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1) + { + // type 25 + const __fp16 b0 = ((const __fp16*)b)[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16 b0 = ((const __fp16*)b)[q]; + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + const __fp16 a0 = ptr[y]; + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + + ptr += h1; + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, 
elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); const __fp16* b0 = b.channel(q); __fp16* outptr = c.channel(q); for (int i = 0; i < size; i++) @@ -2732,6 +3718,39 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + const __fp16 a0 = ptr[z]; + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -2823,6 +3842,29 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt { if (a.w == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + const __fp16 a0 = ((const __fp16*)a)[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -2883,6 +3925,29 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16 a0 = ((const 
__fp16*)a)[q]; + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -2972,8 +4037,9 @@ static int binary_op_scalar_inplace_fp16s(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; __fp16 b16 = (__fp16)b; @@ -3097,10 +4163,10 @@ int BinaryOp_arm::forward_fp16s(const std::vector& bottom_blobs, std::vecto return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack8_fp16s(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack8_fp16s(bottom_blob1, bottom_blob, top_blob, opt); } if (elempack == 4 || elempack1 == 4) @@ -3127,10 +4193,10 @@ int BinaryOp_arm::forward_fp16s(const std::vector& bottom_blobs, std::vecto return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4_fp16s(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4_fp16s(bottom_blob1, bottom_blob, top_blob, opt); } if (elempack == 1 && elempack1 == 1) @@ -3157,10 +4223,10 @@ int BinaryOp_arm::forward_fp16s(const std::vector& bottom_blobs, std::vecto return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_fp16s(bottom_blob, bottom_blob1, 
top_blob, opt); + return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); } return 0; @@ -3273,55 +4339,238 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { - if (b.dims == 3) + if (b.dims == 4) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < size; i++) { - const unsigned short* ptr = a.channel(q); - unsigned short* outptr = c.channel(q); - const unsigned short* b0 = b.channel(q); - float32x4_t _b0 = vcvt_f32_bf16(vld1_u16(b0)); - for (int i = 0; i < size; i++) - { - float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, vcvt_bf16_f32(_outp)); - ptr += 4; - outptr += 4; - } + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _p1 = vcvt_f32_bf16(vld1_u16(ptr1)); + float32x4_t _outp = op(_p, _p1); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; } - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, 
opt.blob_allocator); - if (c.empty()) - return -100; + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + float32x4_t _b0 = vcvt_f32_bf16(vld1_u16(ptr1)); + for (int x = 0; x < w; x++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b0); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + outptr += 4; + } + + ptr1 += 4; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.row(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + float32x4_t _b0 = vcvt_f32_bf16(vld1_u16(ptr1)); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b0); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + outptr += 4; + } + } + + ptr1 += 4; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + float32x4_t _b0 = vdupq_n_f32(bfloat16_to_float32(((const unsigned short*)b)[0])); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b0); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + // type 26 + #pragma 
omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + float32x4_t _b0 = vcvt_f32_bf16(vld1_u16((const unsigned short*)b + q * 4)); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b0); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + float32x4_t _a0 = vcvt_f32_bf16(vld1_u16(ptr)); + for (int x = 0; x < w1; x++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr1)); + float32x4_t _outp = op(_a0, _p); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr1 += 4; + outptr += 4; + } + + ptr += 4; + } + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); + const unsigned short* b0 = b.channel(q); + float32x4_t _b0 = vcvt_f32_bf16(vld1_u16(b0)); + for (int i = 0; i < size; i++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b0); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + { + 
// special type 2 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -3635,6 +4884,42 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr = a.row(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + float32x4_t _a0 = vcvt_f32_bf16(vld1_u16(ptr)); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr1)); + float32x4_t _outp = op(_a0, _p); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr1 += 4; + outptr += 4; + } + } + + ptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -3743,6 +5028,33 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + float32x4_t _a0 = vdupq_n_f32(bfloat16_to_float32(((const unsigned short*)a)[0])); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float32x4_t _p1 = vcvt_f32_bf16(vld1_u16(ptr1)); + float32x4_t _outp = op(_a0, _p1); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -3815,6 +5127,33 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& 
c, const Optio } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + float32x4_t _a0 = vcvt_f32_bf16(vld1_u16((const unsigned short*)a + q * 4)); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + float32x4_t _p1 = vcvt_f32_bf16(vld1_u16(ptr1)); + float32x4_t _outp = op(_a0, _p1); + vst1_u16(outptr, vcvt_bf16_f32(_outp)); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -3922,8 +5261,9 @@ static int binary_op_scalar_inplace_pack4_bf16s(Mat& a, float b, const Option& o int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; float32x4_t _b = vdupq_n_f32(b); @@ -3952,17 +5292,181 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(ptr1[i]))); + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel 
for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + const float b0 = bfloat16_to_float32(ptr1[y]); + for (int x = 0; x < w; x++) + { + outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); + } + + ptr += w; + outptr += w; + } + + ptr1 += h; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.row(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + const float b0 = bfloat16_to_float32(ptr1[z]); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); + } + + ptr += w; + outptr += w; + } + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1) + { + // type 25 + const float b0 = bfloat16_to_float32(((const unsigned short*)b)[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), b0)); + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const float b0 = bfloat16_to_float32(((const unsigned short*)b)[q]); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), b0)); + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, 
channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + const float a0 = bfloat16_to_float32(ptr[y]); + for (int x = 0; x < w1; x++) + { + outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); + } + + ptr1 += w1; + outptr += w1; + } + + ptr += h1; + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -4262,6 +5766,39 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr = a.row(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + const float a0 = bfloat16_to_float32(ptr[z]); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); + } + + ptr1 += w1; + outptr += w1; + } + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -4353,6 +5890,29 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt { if (a.w == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + const float a0 = bfloat16_to_float32(((const unsigned short*)a)[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = 
c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[i]))); + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -4413,6 +5973,29 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float a0 = bfloat16_to_float32(((const unsigned short*)a)[q]); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[i]))); + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -4502,8 +6085,9 @@ static int binary_op_scalar_inplace_bf16s(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -4626,10 +6210,10 @@ int BinaryOp_arm::forward_bf16s(const std::vector& bottom_blobs, std::vecto return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4_bf16s(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4_bf16s(bottom_blob1, bottom_blob, top_blob, opt); } #endif // __ARM_NEON @@ -4657,10 +6241,10 @@ int BinaryOp_arm::forward_bf16s(const std::vector& bottom_blobs, std::vecto return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + return 
binary_op_bf16s(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_bf16s(bottom_blob1, bottom_blob, top_blob, opt); } return 0; diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index d809bb80a..177574ea5 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -48,8 +48,9 @@ static int unary_op_inplace_pack4(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -341,8 +342,9 @@ static int unary_op_inplace_pack8_fp16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -554,8 +556,9 @@ static int unary_op_inplace_pack4_fp16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -751,8 +754,9 @@ static int unary_op_inplace_fp16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -1082,8 +1086,9 @@ static int unary_op_inplace_pack4_bf16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -1111,8 +1116,9 @@ static int unary_op_inplace_bf16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; 
#pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) diff --git a/src/layer/binaryop.cpp b/src/layer/binaryop.cpp index 48892c926..53eb234bc 100644 --- a/src/layer/binaryop.cpp +++ b/src/layer/binaryop.cpp @@ -49,17 +49,181 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], ptr1[i]); + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + const float b0 = ptr1[y]; + for (int x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + + ptr1 += h; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + const float b0 = ptr1[z]; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { 
+ outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1) + { + // type 25 + const float b0 = b[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float b0 = b[q]; + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + const float a0 = ptr[y]; + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + + ptr += h1; + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -359,6 +523,39 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + const float a0 = ptr[z]; + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; 
x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -445,6 +642,29 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) { if (a.w == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + const float a0 = a[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -501,6 +721,29 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float a0 = a[q]; + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -585,8 +828,9 @@ static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -703,10 +947,10 @@ int BinaryOp::forward(const std::vector& bottom_blobs, std::vector& to return binary_op(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op(bottom_blob1, bottom_blob, 
top_blob, opt); return 0; } diff --git a/src/layer/mips/binaryop_mips.cpp b/src/layer/mips/binaryop_mips.cpp index 79e2c6a1f..c02121261 100644 --- a/src/layer/mips/binaryop_mips.cpp +++ b/src/layer/mips/binaryop_mips.cpp @@ -41,20 +41,203 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_p, _p1); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); + for (int x = 0; x < w; x++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b0); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } + + ptr1 += 4; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // 
type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b0); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } + } + + ptr1 += 4; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + v4f32 _b0 = __msa_fill_w_f32(b[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b0); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + v4f32 _b0 = (v4f32)__msa_ld_w((const float*)b + q * 4, 0); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b0); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); + for (int x = 0; x < w1; x++) + { + v4f32 _p = 
(v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_a0, _p); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr1 += 4; + outptr += 4; + } + + ptr += 4; + } + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -417,6 +600,42 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + v4f32 _p = (v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_a0, _p); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr1 += 4; + outptr += 4; + } + } + + ptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -530,6 +749,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + v4f32 _a0 = __msa_fill_w_f32(a[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_a0, _p1); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -605,6 +851,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, 
elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + v4f32 _a0 = (v4f32)__msa_ld_w((const float*)a + q * 4, 0); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_a0, _p1); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -717,8 +990,9 @@ static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; v4f32 _b = __msa_fill_w_f32(b); @@ -847,10 +1121,10 @@ int BinaryOp_mips::forward(const std::vector& bottom_blobs, std::vector(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); } #endif // __mips_msa diff --git a/src/layer/mips/unaryop_mips.cpp b/src/layer/mips/unaryop_mips.cpp index 69b6d1054..6461db6b6 100644 --- a/src/layer/mips/unaryop_mips.cpp +++ b/src/layer/mips/unaryop_mips.cpp @@ -38,8 +38,9 @@ static int unary_op_inplace_pack4(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) diff --git a/src/layer/riscv/binaryop_riscv.cpp b/src/layer/riscv/binaryop_riscv.cpp index 36db44a33..36270c1a4 100644 --- a/src/layer/riscv/binaryop_riscv.cpp +++ b/src/layer/riscv/binaryop_riscv.cpp @@ -28,6 +28,9 @@ #include "rvv_mathfun.h" #include "rvv_mathfun_fp16s.h" 
#endif // __riscv_vector + +#include "riscv_usability.h" + namespace ncnn { BinaryOp_riscv::BinaryOp_riscv() @@ -42,66 +45,266 @@ BinaryOp_riscv::BinaryOp_riscv() #if __riscv_vector template -static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, - const Option& opt) +static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_p, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr += vl; + ptr1 += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + if (b.dims == 3) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + 
vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + } + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = w * h * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - const float* b0 = b.channel(q); + const float b0 = b[0]; float* outptr = c.channel(q); int n = size * elempack; while (n > 0) { - const float* b_vol = b0; - int n1 = size1 * elempack1; - while (n1 > 0) - { - word_type vl = vsetvl_e32m8(std::min(n1, n)); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, b0, vl); + vse32_v_f32m8(outptr, _outp, vl); - vfloat32m8_t _b = vle32_v_f32m8(b_vol, vl); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b, vl); + ptr += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1((const 
float*)b + q * elempack); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); + + outptr += vl; + ptr += vl; + n -= vl; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); - ptr += vl; - b_vol += vl; + ptr1 += vl; outptr += vl; - - n1 -= vl; n -= vl; } + + ptr += elempack; + } + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; } } + return 0; } @@ -127,12 +330,13 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, word_type vl = vsetvl_e32m8(n); vfloat32m8_t _p = 
vle32_v_f32m8(ptr, vl); vfloat32m8_t _outp = op(_p, *ptr1, vl); - vse32_v_f32m8(outptr, _outp, vl); + n -= vl; ptr += vl; outptr += vl; } + ptr1 += 1; } } @@ -150,31 +354,23 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels1; q++) { - const float* a0 = a.channel(q); + const float* ptr = a.channel(q); const float* ptr1 = b.channel(q); float* outptr = c.channel(q); + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + int n1 = size1 * elempack1; while (n1 > 0) { - const float* a_vol = a0; - int n = size * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(std::min(n1, n)); - - vfloat32m8_t _a0 = vle32_v_f32m8(a_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - a_vol += vl; - outptr += vl; + word_type vl = vsetvl_e32m8(n1); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); - n1 -= vl; - n -= vl; - } + ptr1 += vl; + outptr += vl; + n1 -= vl; } } @@ -204,12 +400,13 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); vfloat32m8_t _p = vfmv_v_f_f32m8(*ptr, vl); vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); + n1 -= vl; ptr1 += vl; outptr += vl; } + ptr += 1; } } @@ -230,24 +427,25 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const float* ptr = a.channel(q); const float* ptr1 = b.channel(q); float* outptr = c.channel(q); + for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = w * elempack; + while (n > 0) { - const float* ptr1_vol = ptr1 + y * elempack; - int n = elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t 
_outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - ptr += vl; - outptr += vl; - n -= vl; - } + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; } + + ptr1 += elempack1; } } @@ -270,23 +468,20 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) + int n = w * elempack; + const float* ptr1_vol = ptr1; + while (n > 0) { - int n = elempack; - const float* ptr1_vol = ptr1 + x * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1_vol, vl); + vfloat32m8_t _outp = op(_p, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); - vse32_v_f32m8(outptr, _outp, vl); - outptr += vl; - ptr += vl; - n -= vl; - ptr1_vol += vl; - } + outptr += vl; + ptr += vl; + n -= vl; + ptr1_vol += vl; } } } @@ -310,23 +505,22 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h1; y++) { - for (int x = 0; x < w1; x++) + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + + int n = w1 * elempack1; + while (n > 0) { - int n = elempack; - const float* ptr_vol = ptr + y * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; } + + ptr += elempack; } } @@ -349,23 
+543,20 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h1; y++) { - for (int x = 0; x < w1; x++) + int n = w1 * elempack1; + const float* ptr_vol = ptr; + while (n > 0) { - int n = elempack; - const float* ptr_vol = ptr + x * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr_vol, vl); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_p, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + ptr1 += vl; + outptr += vl; + ptr_vol += vl; + n -= vl; } } } @@ -420,24 +611,21 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = w * elempack; + while (n > 0) { - const float* ptr1_vol = ptr1; - int n = elempack1; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _b0 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0, vl); - vse32_v_f32m8(outptr, _outp, vl); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); - ptr += vl; - outptr += vl; - ptr1_vol += vl; - n -= vl; - } + ptr += vl; + outptr += vl; + n -= vl; } + ptr1 += elempack1; } } @@ -454,6 +642,7 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); + const float b0 = b[0]; float* outptr = c.channel(q); int n = size * elempack; @@ -461,7 +650,7 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, { word_type vl = vsetvl_e32m8(n); vfloat32m8_t _p = 
vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, b[0], vl); + vfloat32m8_t _outp = op(_p, b0, vl); vse32_v_f32m8(outptr, _outp, vl); ptr += vl; @@ -480,35 +669,65 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const float* ptr = a.channel(q); float* outptr = c.channel(q); - int n = size * elempack; + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1((const float*)b + q * elempack); + int n = size * elempack; while (n > 0) { - int n1 = elempack1; - const float* ptr1_vol = (const float*)b + q * elempack1; - while (n1 > 0) - { - word_type vl = vsetvl_e32m8(n1); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + vse32_v_f32m8(outptr, _outp, vl); - vfloat32m8_t _b0 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0, vl); + outptr += vl; + ptr += vl; + n -= vl; + } + } + + return 0; + } + } + else if (a.dims == 2) + { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + + int n = w1 * h1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); - ptr1_vol += vl; + ptr1 += vl; outptr += vl; - ptr += vl; - n1 -= vl; n -= vl; } + + ptr += elempack; } } return 0; } - } - else if (a.dims == 2) - { + if (b.dims == 3) { // type 14 @@ -525,23 +744,21 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h1; y++) { - for (int x = 0; x < w1; x++) + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + + int n = w1 * 
elempack1; + while (n > 0) { - const float* ptr_vol = ptr; - int n = elempack1; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _a0 = vle32_v_f32m8(ptr_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; } + ptr += elempack; } } @@ -559,11 +776,11 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const float* ptr = a; const float* ptr1 = b; float* outptr = c; + int n = size * elempack; while (n > 0) { word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); vfloat32m8_t _outp = op(_p, _p1, vl); @@ -588,13 +805,15 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, { // type 11 const float* ptr = a; + const float b0 = b[0]; float* outptr = c; + int n = size * elempack; while (n > 0) { word_type vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, b[0], vl); + vfloat32m8_t _outp = op(_p, b0, vl); vse32_v_f32m8(outptr, _outp, vl); ptr += vl; @@ -612,23 +831,19 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) + vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + + int n = w * elempack; + while (n > 0) { - int n = elempack; - const float* ptr1_vol = ptr1; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _b0 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0, vl); - vse32_v_f32m8(outptr, _outp, vl); + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _b0x, vl); + 
vse32_v_f32m8(outptr, _outp, vl); - ptr += vl; - ptr1_vol += vl; - outptr += vl; - n -= vl; - } + ptr += vl; + outptr += vl; + n -= vl; } ptr1 += elempack; @@ -641,6 +856,37 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float a0 = a[0]; + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e32m8(n1); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(a0, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -651,6 +897,7 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels1; q++) { + const float a0 = a[0]; const float* ptr1 = b.channel(q); float* outptr = c.channel(q); @@ -658,9 +905,8 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, while (n1 > 0) { word_type vl = vsetvl_e32m8(n1); - vfloat32m8_t _a0 = vfmv_v_f_f32m8(a[0], vl); vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); + vfloat32m8_t _outp = op(a0, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); ptr1 += vl; @@ -679,6 +925,7 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, if (c.empty()) return -100; + const float a0 = a[0]; const float* ptr1 = b; float* outptr = c; @@ -686,9 +933,8 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, while (n1 > 0) { word_type vl = vsetvl_e32m8(n1); - vfloat32m8_t _a0 = vfmv_v_f_f32m8(a[0], vl); vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); + vfloat32m8_t _outp = 
op(a0, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); ptr1 += vl; @@ -702,31 +948,63 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, if (b.dims == 1) { // type 2 - c.create(w1, elemsize1, elempack1, opt.blob_allocator); if (c.empty()) return -100; + const float a0 = a[0]; const float* ptr1 = b; float* outptr = c; + int n1 = w1 * elempack1; while (n1 > 0) { word_type vl = vsetvl_e32m8(n1); - - vfloat32m8_t _a0 = vfmv_v_f_f32m8(a[0], vl); vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); + vfloat32m8_t _outp = op(a0, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); ptr1 += vl; outptr += vl; n1 -= vl; } + return 0; } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1((const float*)a + q * elempack); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e32m8(n1); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -739,23 +1017,20 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, { const float* ptr1 = b.channel(q); float* outptr = c.channel(q); + + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1((const float*)a + q * elempack); + int n1 = size1 * elempack1; while (n1 > 0) { - int n = elempack; - const float* ptr_vol = (const float*)a + q * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _a0 = vle32_v_f32m8(ptr_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - n1 -= vl; - n -= vl; 
- } + word_type vl = vsetvl_e32m8(n1); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; } } @@ -772,29 +1047,27 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const float* ptr = a; const float* ptr1 = b; float* outptr = c; + for (int y = 0; y < h1; y++) { - for (int x = 0; x < w1; x++) - { - const float* ptr_vol = ptr; - int n = elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); - vfloat32m8_t _a0 = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_a0x, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + ptr1 += vl; + outptr += vl; + n -= vl; } + ptr += elempack; } + return 0; } @@ -808,14 +1081,17 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, { // type 6 const float* ptr = a; + const float b0 = b[0]; float* outptr = c; + int n = w * elempack; while (n > 0) { word_type vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, b[0], vl); + vfloat32m8_t _outp = op(_p, b0, vl); vse32_v_f32m8(outptr, _outp, vl); + ptr += vl; outptr += vl; n -= vl; @@ -849,142 +1125,189 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, return 0; } -struct binary_op_add_rvv -{ - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const +template +static int binary_op_scalar_rvv(Mat& a, float b, const Option& opt) +{ + Op op; + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int size = w * h * d; + int elempack = a.elempack; + + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = a.channel(q); + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + _p = op(_p, b, vl); + vse32_v_f32m8(ptr, _p, vl); + + n -= vl; + ptr += vl; + } + } + + return 0; +} + +struct binary_op_add_rvv +{ + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfadd_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, const float& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const float& y, const word_type& vl) const { return vfadd_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(const float& x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfadd_vf_f32m8(y, x, vl); + } }; struct binary_op_sub_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfsub_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfsub_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfrsub_vf_f32m8(y, x, vl); + } }; struct binary_op_mul_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfmul_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfmul_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const 
word_type& vl) const + { + return vfmul_vf_f32m8(y, x, vl); + } }; struct binary_op_div_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfdiv_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfdiv_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfrdiv_vf_f32m8(y, x, vl); + } }; struct binary_op_max_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfmax_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfmax_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfmax_vf_f32m8(y, x, vl); + } }; struct binary_op_min_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfmin_vv_f32m8(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfmin_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfmin_vf_f32m8(y, x, vl); + } }; struct binary_op_pow_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& 
vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { - return pow_ps(x, y, vl); // rvv_mathfun.h + return pow_ps(x, y, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return pow_ps(x, vfmv_v_f_f32m8(y, vl), vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return pow_ps(vfmv_v_f_f32m8(x, vl), y, vl); + } }; struct binary_op_rsub_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfsub_vv_f32m8(y, x, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, const float& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const float& y, const word_type& vl) const { return vfrsub_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(const float& x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfsub_vf_f32m8(y, x, vl); + } }; struct binary_op_rdiv_rvv { - vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const word_type& vl) const { return vfdiv_vv_f32m8(y, x, vl); } - vfloat32m8_t operator()(const vfloat32m8_t& x, float y, - const word_type& vl) const + vfloat32m8_t operator()(const vfloat32m8_t& x, float y, const word_type& vl) const { return vfrdiv_vf_f32m8(x, y, vl); } + vfloat32m8_t operator()(float x, const vfloat32m8_t& y, const word_type& vl) const + { + return vfdiv_vf_f32m8(y, x, vl); + } }; #endif -int BinaryOp_riscv::forward(const std::vector& bottom_blobs, - std::vector& top_blobs, - const Option& opt) const +int BinaryOp_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const 
Option& opt) const { int elembits = std::max(bottom_blobs[0].elembits(), bottom_blobs[1].elembits()); #if __riscv_vector && __riscv_zfh if (opt.use_fp16_storage && elembits == 16) { - return forward_fp16sa(bottom_blobs, top_blobs, opt); + return forward_fp16s(bottom_blobs, top_blobs, opt); } #endif const Mat& bottom_blob = bottom_blobs[0]; @@ -997,79 +1320,38 @@ int BinaryOp_riscv::forward(const std::vector& bottom_blobs, if (elempack != 1 || elempack1 != 1) { if (op_type == Operation_ADD) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_SUB) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MUL) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_DIV) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MAX) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MIN) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_POW) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_rvv(bottom_blob, bottom_blob1, - top_blob, opt); + return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); } #endif return BinaryOp::forward(bottom_blobs, 
top_blobs, opt); } -#if __riscv_vector -template -static int binary_op_scalar_rvv(Mat& a, float b, const Option& opt) -{ - Op op; - int w = a.w; - int h = a.h; - int channels = a.c; - int size = w * h; - int elempack = a.elempack; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - float* ptr = a.channel(q); - int n = size * elempack; - while (n > 0) - { - word_type vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - _p = op(_p, b, vl); - vse32_v_f32m8(ptr, _p, vl); - - n -= vl; - ptr += vl; - } - } - return 0; -} -#endif - -int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, - const Option& opt) const +int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { #if __riscv_vector int elembits = bottom_top_blob.elembits(); @@ -1077,9 +1359,10 @@ int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, #if __riscv_zfh if (opt.use_fp16_storage && elembits == 16) { - return forward_inplace_fp16sa(bottom_top_blob, opt); + return forward_inplace_fp16s(bottom_top_blob, opt); } #endif + if (op_type == Operation_ADD) return binary_op_scalar_rvv(bottom_top_blob, b, opt); @@ -1114,33 +1397,1454 @@ int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, // fp16sa #if __riscv_vector && __riscv_zfh template -static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, - const Option& opt) +static int binary_op_rvv_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) - { + if (a.dims == 4) + { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if 
(c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + ptr1 += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = w * h * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && 
elempack1 == 1) + { + // type 25 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16 b0 = ((const __fp16*)b)[0]; + __fp16* outptr = c.channel(q); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, b0, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1((const __fp16*)b + q * elempack); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + outptr += vl; + ptr += vl; + n -= vl; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; + } + + ptr += elempack; + } + } + } + + return 0; + } + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + c.create(w, h, channels, elemsize, elempack, 
opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + { + // special type 2 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b; + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + int n = elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, *ptr1, vl); + + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } + ptr1 += 1; + } + } + + return 0; + } + + if (w == 1 && h == 1 && channels1 == channels) + { + // special type 3 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + 
} + + if (w1 == w && h1 == h && channels == 1 && elempack == 1) + { + // special type 4 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a; + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + int n1 = elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _p = vfmv_v_f_f16m8(*ptr, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + n1 -= vl; + ptr1 += vl; + outptr += vl; + } + ptr += 1; + } + } + + return 0; + } + + if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) + { + // special type 5 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h; y++) + { + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + + return 0; + } + + if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) + { + // special type 6 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + 
int n = elempack; + const __fp16* ptr1_vol = ptr1 + x * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1_vol, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + outptr += vl; + ptr += vl; + n -= vl; + ptr1_vol += vl; + } + } + } + } + + return 0; + } + + if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) + { + // special type 7 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h1; y++) + { + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; + } + + ptr += elempack; + } + } + + return 0; + } + + if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) + { + // special type 8 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + int n = elempack; + const __fp16* ptr_vol = ptr + x * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr_vol, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + ptr_vol += vl; + n -= vl; + } + } + } 
+ } + + return 0; + } + + // type 19 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + ptr1 += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 2) + { + // type 18 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h; y++) + { + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack1; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 16 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16 b0 = *(const __fp16*)b; + __fp16* outptr = c.channel(q); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, b0, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + } + + return 0; + } + + // type 17 + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1((const __fp16*)b + q * elempack); + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + outptr += vl; + ptr += vl; + n -= vl; + } + } + + return 0; + } + } + else if (a.dims == 2) + { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n = w1 * h1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; + } + + ptr += elempack; + } + } + + return 0; + } + + if (b.dims == 3) + { + // type 14 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int y = 0; y < h1; y++) + { + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; + } + + ptr += elempack; + } + } + + return 0; + } + + 
c.create(w, h, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 2) + { + // type 13 + const __fp16* ptr = a; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + ptr1 += vl; + outptr += vl; + n -= vl; + } + + return 0; + } + + if (b.dims == 1) + { + c.create(w, h, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.w == 1 && elempack1 == 1) + { + // type 11 + const __fp16* ptr = a; + const __fp16 b0 = ((const __fp16*)b)[0]; + __fp16* outptr = c; + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, b0, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + return 0; + } + + // type 12 + const __fp16* ptr = a; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + for (int y = 0; y < h; y++) + { + vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _b0x, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + ptr1 += elempack; + } + + return 0; + } + } + else if (a.dims == 1) + { + if (a.w == 1 && elempack == 1) + { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16 a0 = ((const __fp16*)a)[0]; + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + 
word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(a0, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + + if (b.dims == 3) + { + // type 4 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16 a0 = ((const __fp16*)a)[0]; + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(a0, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 3 + c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + const __fp16 a0 = ((const __fp16*)a)[0]; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(a0, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + + return 0; + } + + if (b.dims == 1) + { + // type 2 + + c.create(w1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + const __fp16 a0 = ((const __fp16*)a)[0]; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + int n1 = w1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(a0, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + + return 0; + } + } + + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return 
-100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1((const __fp16*)a + q * elempack); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + + if (b.dims == 3) + { + // type 9 + c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1((const __fp16*)a + q * elempack); + + int n1 = size1 * elempack1; + while (n1 > 0) + { + word_type vl = vsetvl_e16m8(n1); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n1 -= vl; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 8 + c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + const __fp16* ptr = a; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + for (int y = 0; y < h1; y++) + { + vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); + + int n = w1 * elempack1; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_a0x, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr1 += vl; + outptr += vl; + n -= vl; + } + + ptr += elempack; + } + + return 0; + } + + if (b.dims == 1) + { + c.create(w, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.w == 1 && elempack1 == 1) + { + // type 6 + const __fp16* ptr = a; + const __fp16 
b0 = ((const __fp16*)b)[0]; + __fp16* outptr = c; + + int n = w * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, b0, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + outptr += vl; + n -= vl; + } + + return 0; + } + + // type 7 + const __fp16* ptr = a; + const __fp16* ptr1 = b; + __fp16* outptr = c; + + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + + ptr += vl; + ptr1 += vl; + outptr += vl; + n -= vl; + } + } + } + + return 0; +} + +template +static int binary_op_scalar_rvv_fp16s(Mat& a, float b, const Option& opt) +{ + Op op; + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int size = w * h * d; + int elempack = a.elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + __fp16* ptr = a.channel(q); + int n = size * elempack; + while (n > 0) + { + word_type vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + _p = op(_p, b, vl); + vse16_v_f16m8(ptr, _p, vl); + + n -= vl; + ptr += vl; + } + } + + return 0; +} + +struct binary_op_add_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfadd_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, const float& y, const word_type& vl) const + { + return vfadd_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const float& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfadd_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_sub_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfsub_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, 
const word_type& vl) const + { + return vfsub_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfrsub_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_mul_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmul_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, const word_type& vl) const + { + return vfmul_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmul_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_div_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfdiv_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, const word_type& vl) const + { + return vfdiv_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfrdiv_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_max_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmax_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, const word_type& vl) const + { + return vfmax_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmax_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_min_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmin_vv_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, const word_type& vl) const + { + return vfmin_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfmin_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_pow_rvv_fp16 +{ + 
vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return pow_ps(x, y, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, const __fp16& y, const word_type& vl) const + { + return pow_ps(x, vfmv_v_f_f16m8(y, vl), vl); + } + vfloat16m8_t operator()(const __fp16& x, const vfloat16m8_t& y, const word_type& vl) const + { + return pow_ps(vfmv_v_f_f16m8(x, vl), y, vl); + } +}; + +struct binary_op_rsub_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfsub_vv_f16m8(y, x, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, const float& y, const word_type& vl) const + { + return vfrsub_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(const float& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfsub_vf_f16m8(y, x, vl); + } +}; + +struct binary_op_rdiv_rvv_fp16 +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfdiv_vv_f16m8(y, x, vl); + } + vfloat16m8_t operator()(const vfloat16m8_t& x, float y, const word_type& vl) const + { + return vfrdiv_vf_f16m8(x, y, vl); + } + vfloat16m8_t operator()(float x, const vfloat16m8_t& y, const word_type& vl) const + { + return vfdiv_vf_f16m8(y, x, vl); + } +}; + +template +static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int size = w * h * d; + size_t elemsize = a.elemsize; + + int w1 = b.w; + int h1 = b.h; + int d1 = b.d; + int channels1 = b.c; + int size1 = w1 * h1 * d1; + + if (a.dims == 4) + { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* 
outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], ptr1[i]); + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + const __fp16 b0 = ptr1[y]; + for (int x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + + ptr1 += h; + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + const __fp16 b0 = ptr1[z]; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1) + { + // type 25 + const __fp16 b0 = ((const __fp16*)b)[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16 b0 = ((const __fp16*)b)[q]; + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + 
return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + const __fp16 a0 = ptr[y]; + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + + ptr += h1; + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) { // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1150,37 +2854,19 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a.channel(q); const __fp16* b0 = b.channel(q); __fp16* outptr = c.channel(q); - - int n = size * elempack; - while (n > 0) + for (int i = 0; i < size; i++) { - const __fp16* b_vol = b0; - int n1 = size1 * elempack1; - while (n1 > 0) - { - word_type vl = vsetvl_e16m8(std::min(n1, n)); - - vfloat16m8_t _b = vle16_v_f16m8(b_vol, vl); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - b_vol += vl; - outptr += vl; - - n1 -= vl; - n -= vl; - } + outptr[i] = op(ptr[i], b0[0]); } } + return 0; } - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + if (w1 == w && h1 == h && channels1 == 1) { // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1190,22 +2876,9 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a.channel(q); const __fp16* ptr1 = b; __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) { - int n = elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = 
vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, *ptr1, vl); - - vse16_v_f16m8(outptr, _outp, vl); - n -= vl; - ptr += vl; - outptr += vl; - } - ptr1 += 1; + outptr[i] = op(ptr[i], ptr1[i]); } } @@ -1215,7 +2888,7 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (w == 1 && h == 1 && channels1 == channels) { // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1225,38 +2898,19 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* a0 = a.channel(q); const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); - - int n1 = size1 * elempack1; - while (n1 > 0) + for (int i = 0; i < size1; i++) { - const __fp16* a_vol = a0; - int n = size * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(std::min(n1, n)); - - vfloat16m8_t _a0 = vle16_v_f16m8(a_vol, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - a_vol += vl; - outptr += vl; - - n1 -= vl; - n -= vl; - } + outptr[i] = op(a0[0], ptr1[i]); } } return 0; } - if (w1 == w && h1 == h && channels == 1 && elempack == 1) + if (w1 == w && h1 == h && channels == 1) { // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1266,23 +2920,9 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a; const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) { - int n1 = elempack1; - while (n1 > 0) - { - word_type vl = vsetvl_e16m8(n1); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _p = vfmv_v_f_f16m8(*ptr, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - - vse16_v_f16m8(outptr, _outp, vl); - n1 -= vl; - ptr1 += vl; - outptr 
+= vl; - } - ptr += 1; + outptr[i] = op(ptr[i], ptr1[i]); } } @@ -1292,7 +2932,7 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) { // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1302,24 +2942,17 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a.channel(q); const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); + for (int y = 0; y < h; y++) { + const __fp16 b0 = ptr1[y]; for (int x = 0; x < w; x++) { - const __fp16* ptr1_vol = ptr1 + y * elempack; - int n = elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - ptr += vl; - outptr += vl; - n -= vl; - } + outptr[x] = op(ptr[x], b0); } + + ptr += w; + outptr += w; } } @@ -1329,7 +2962,7 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) { // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1344,22 +2977,11 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, { for (int x = 0; x < w; x++) { - int n = elempack; - const __fp16* ptr1_vol = ptr1 + x * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - - vse16_v_f16m8(outptr, _outp, vl); - outptr += vl; - ptr += vl; - n -= vl; - ptr1_vol += vl; - } + outptr[x] = op(ptr[x], ptr1[x]); } + + ptr += w; + outptr += w; } } @@ -1369,7 +2991,7 @@ static int 
binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) { // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1382,23 +3004,14 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h1; y++) { + const __fp16 a0 = ptr[y]; for (int x = 0; x < w1; x++) { - int n = elempack; - const __fp16* ptr_vol = ptr + y * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr_vol, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + outptr[x] = op(a0, ptr1[x]); } + + ptr1 += w1; + outptr += w1; } } @@ -1408,7 +3021,7 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) { // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1423,22 +3036,11 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, { for (int x = 0; x < w1; x++) { - int n = elempack; - const __fp16* ptr_vol = ptr + x * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr_vol, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + outptr[x] = op(ptr[x], ptr1[x]); } + + ptr1 += w1; + outptr += w1; } } @@ -1446,7 +3048,7 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, } // type 19 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, 
opt.blob_allocator); if (c.empty()) return -100; @@ -1457,26 +3059,16 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); - int n = size * elempack; - while (n > 0) + for (int i = 0; i < size; i++) { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - ptr1 += vl; - outptr += vl; - n -= vl; + outptr[i] = op(ptr[i], ptr1[i]); } } return 0; } - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + c.create(w, h, channels, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1492,25 +3084,14 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h; y++) { + const __fp16 b0 = ptr1[y]; for (int x = 0; x < w; x++) { - const __fp16* ptr1_vol = ptr1; - int n = elempack1; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _b0 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - ptr1_vol += vl; - n -= vl; - } + outptr[x] = op(ptr[x], b0); } - ptr1 += elempack1; + + ptr += w; + outptr += w; } } @@ -1519,27 +3100,19 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (b.dims == 1) { - if (b.w == 1 && elempack1 == 1) + if (b.w == 1) { // type 16 + const __fp16 b0 = ((const __fp16*)b)[0]; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const __fp16* ptr = a.channel(q); - const __fp16 b0 = *(const __fp16*)b; __fp16* outptr = c.channel(q); - int n = size * elempack; - while (n > 0) + for (int i = 0; i < size; i++) { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, b0, vl); - vse16_v_f16m8(outptr, _outp, vl); - 
- ptr += vl; - outptr += vl; - n -= vl; + outptr[i] = op(ptr[i], b0); } } @@ -1551,41 +3124,57 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, for (int q = 0; q < channels; q++) { const __fp16* ptr = a.channel(q); + const __fp16 b0 = ((const __fp16*)b)[q]; __fp16* outptr = c.channel(q); - int n = size * elempack; + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } - while (n > 0) + return 0; + } + } + else if (a.dims == 2) + { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr = a.row(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) { - int n1 = elempack1; - const __fp16* ptr1_vol = (const __fp16*)b + q * elempack1; - while (n1 > 0) + const __fp16 a0 = ptr[z]; + for (int y = 0; y < h1; y++) { - word_type vl = vsetvl_e16m8(n1); - - vfloat16m8_t _b0 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0, vl); - vse16_v_f16m8(outptr, _outp, vl); + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } - ptr1_vol += vl; - outptr += vl; - ptr += vl; - n1 -= vl; - n -= vl; + ptr1 += w1; + outptr += w1; } } } return 0; } - } - else if (a.dims == 2) - { + if (b.dims == 3) { // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1598,31 +3187,21 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, for (int y = 0; y < h1; y++) { + const __fp16 a0 = ptr[y]; for (int x = 0; x < w1; x++) { - const __fp16* ptr_vol = ptr; - int n = elempack1; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _a0 = vle16_v_f16m8(ptr_vol, vl); - vfloat16m8_t _p1 = 
vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + outptr[x] = op(a0, ptr1[x]); } - ptr += elempack; + + ptr1 += w1; + outptr += w1; } } return 0; } - c.create(w, h, elemsize, elempack, opt.blob_allocator); + c.create(w, h, elemsize, opt.blob_allocator); if (c.empty()) return -100; @@ -1632,20 +3211,9 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a; const __fp16* ptr1 = b; __fp16* outptr = c; - int n = size * elempack; - while (n > 0) + for (int i = 0; i < size; i++) { - word_type vl = vsetvl_e16m8(n); - - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - ptr1 += vl; - outptr += vl; - n -= vl; + outptr[i] = op(ptr[i], ptr1[i]); } return 0; @@ -1653,27 +3221,19 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (b.dims == 1) { - c.create(w, h, elemsize, elempack, opt.blob_allocator); + c.create(w, h, elemsize, opt.blob_allocator); if (c.empty()) return -100; - if (b.w == 1 && elempack1 == 1) - { - // type 11 - const __fp16* ptr = a; - const __fp16 b0 = *(const __fp16*)b; - __fp16* outptr = c; - int n = size * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, b0, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; + if (b.w == 1) + { + // type 11 + const __fp16 b0 = ((const __fp16*)b)[0]; + const __fp16* ptr = a; + __fp16* outptr = c; + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); } return 0; @@ -1681,31 +3241,18 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, // type 12 const __fp16* ptr = a; - const __fp16* ptr1 = b; __fp16* outptr = c; for (int y = 0; y < h; y++) { + const __fp16 b0 = ((const 
__fp16*)b)[y]; for (int x = 0; x < w; x++) { - int n = elempack; - const __fp16* ptr1_vol = ptr1; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _b0 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - ptr1_vol += vl; - outptr += vl; - n -= vl; - } + outptr[x] = op(ptr[x], b0); } - ptr1 += elempack; + ptr += w; + outptr += w; } return 0; @@ -1713,34 +3260,48 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, } else if (a.dims == 1) { - if (a.w == 1 && elempack == 1) + if (a.w == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + const __fp16 a0 = ((const __fp16*)a)[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + if (b.dims == 3) { // type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, opt.blob_allocator); if (c.empty()) return -100; + const __fp16 a0 = ((const __fp16*)a)[0]; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels1; q++) { - const __fp16 a0 = *(const __fp16*)a; const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); - int n1 = size1 * elempack1; - while (n1 > 0) + for (int i = 0; i < size1; i++) { - word_type vl = vsetvl_e16m8(n1); - vfloat16m8_t _a0 = vfmv_v_f_f16m8(a0, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; + outptr[i] = op(a0, ptr1[i]); } } @@ -1750,26 +3311,16 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (b.dims == 2) { 
// type 3 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, elemsize, opt.blob_allocator); if (c.empty()) return -100; - const __fp16 a0 = *(const __fp16*)a; + const __fp16 a0 = ((const __fp16*)a)[0]; const __fp16* ptr1 = b; __fp16* outptr = c; - - int n1 = size1 * elempack1; - while (n1 > 0) + for (int i = 0; i < size1; i++) { - word_type vl = vsetvl_e16m8(n1); - vfloat16m8_t _a0 = vfmv_v_f_f16m8(a0, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; + outptr[i] = op(a0, ptr1[i]); } return 0; @@ -1778,61 +3329,62 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (b.dims == 1) { // type 2 - - c.create(w1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, elemsize, opt.blob_allocator); if (c.empty()) return -100; - const __fp16 a0 = *(const __fp16*)a; + const __fp16 a0 = ((const __fp16*)a)[0]; const __fp16* ptr1 = b; __fp16* outptr = c; - int n1 = w1 * elempack1; - while (n1 > 0) + for (int i = 0; i < w1; i++) { - word_type vl = vsetvl_e16m8(n1); + outptr[i] = op(a0, ptr1[i]); + } - vfloat16m8_t _a0 = vfmv_v_f_f16m8(a0, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); + return 0; + } + } - ptr1 += vl; - outptr += vl; - n1 -= vl; + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const __fp16 a0 = ((const __fp16*)a)[q]; + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); } - return 0; } + + return 0; } if (b.dims == 3) { // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, channels1, elemsize, 
opt.blob_allocator); if (c.empty()) return -100; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels1; q++) { + const __fp16 a0 = ((const __fp16*)a)[q]; const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); - int n1 = size1 * elempack1; - while (n1 > 0) + + for (int i = 0; i < size1; i++) { - int n = elempack; - const __fp16* ptr_vol = (const __fp16*)a + q * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _a0 = vle16_v_f16m8(ptr_vol, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - ptr1 += vl; - outptr += vl; - n1 -= vl; - n -= vl; - } + outptr[i] = op(a0, ptr1[i]); } } @@ -1842,61 +3394,43 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, if (b.dims == 2) { // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); + c.create(w1, h1, elemsize, opt.blob_allocator); if (c.empty()) return -100; - const __fp16* ptr = a; const __fp16* ptr1 = b; __fp16* outptr = c; + for (int y = 0; y < h1; y++) { + const __fp16 a0 = ((const __fp16*)a)[y]; for (int x = 0; x < w1; x++) { - const __fp16* ptr_vol = ptr; - int n = elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - - vfloat16m8_t _a0 = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } + outptr[x] = op(a0, ptr1[x]); } - ptr += elempack; + + ptr1 += w1; + outptr += w1; } + return 0; } if (b.dims == 1) { - c.create(w, elemsize, elempack, opt.blob_allocator); + c.create(w, elemsize, opt.blob_allocator); if (c.empty()) return -100; - if (b.w == 1 && elempack1 == 1) + if (b.w == 1) { // type 6 + const __fp16 b0 = ((const __fp16*)b)[0]; const __fp16* ptr = a; - const __fp16 b0 = *(const __fp16*)b; __fp16* outptr = c; - int n = w * elempack; - while (n > 0) + for (int i = 0; 
i < w; i++) { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, b0, vl); - vse16_v_f16m8(outptr, _outp, vl); - ptr += vl; - outptr += vl; - n -= vl; + outptr[i] = op(ptr[i], b0); } return 0; @@ -1906,20 +3440,9 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, const __fp16* ptr = a; const __fp16* ptr1 = b; __fp16* outptr = c; - - int n = size * elempack; - while (n > 0) + for (int i = 0; i < w; i++) { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - ptr1 += vl; - outptr += vl; - n -= vl; + outptr[i] = op(ptr[i], ptr1[i]); } } } @@ -1927,254 +3450,180 @@ static int binary_op_rvv_fp16sa(const Mat& a, const Mat& b, Mat& c, return 0; } -struct binary_op_add_rvv_fp16 +struct binary_op_add_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfadd_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, const float& y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfadd_vf_f16m8(x, y, vl); + return x + y; } }; -struct binary_op_sub_rvv_fp16 +struct binary_op_sub_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfsub_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfsub_vf_f16m8(x, y, vl); + return x - y; } }; -struct binary_op_mul_rvv_fp16 +struct binary_op_mul_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfmul_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) 
const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfmul_vf_f16m8(x, y, vl); + return x * y; } }; -struct binary_op_div_rvv_fp16 +struct binary_op_div_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfdiv_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) const - { - return vfdiv_vf_f16m8(x, y, vl); + return x / y; } }; -struct binary_op_max_rvv_fp16 +struct binary_op_max_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfmax_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfmax_vf_f16m8(x, y, vl); + return std::max(x, y); } }; -struct binary_op_min_rvv_fp16 +struct binary_op_min_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfmin_vv_f16m8(x, y, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfmin_vf_f16m8(x, y, vl); + return std::min(x, y); } }; -struct binary_op_pow_rvv_fp16 +struct binary_op_pow_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return pow_ps(x, y, vl); // rvv_mathfun_fp16s.h - } - vfloat16m8_t operator()(const vfloat16m8_t& x, const __fp16& y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - vfloat16m8_t _op2 = vfmv_v_f_f16m8(y, vl); - vfloat16m8_t retval = pow_ps(x, _op2, vl); // rvv_mathfun_fp16s.h - return retval; + return (__fp16)pow((float)x, (float)y); } }; -struct binary_op_rsub_rvv_fp16 +struct binary_op_rsub_fp16s { - 
vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfsub_vv_f16m8(y, x, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, const float& y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfrsub_vf_f16m8(x, y, vl); + return y - x; } }; -struct binary_op_rdiv_rvv_fp16 +struct binary_op_rdiv_fp16s { - vfloat16m8_t operator()(const vfloat16m8_t& x, const vfloat16m8_t& y, - const word_type& vl) const - { - return vfdiv_vv_f16m8(y, x, vl); - } - vfloat16m8_t operator()(const vfloat16m8_t& x, float y, - const word_type& vl) const + __fp16 operator()(const __fp16& x, const __fp16& y) const { - return vfrdiv_vf_f16m8(x, y, vl); + return y / x; } }; -#endif -#if __riscv_vector && __riscv_zfh -int BinaryOp_riscv::forward_fp16sa(const std::vector& bottom_blobs, - std::vector& top_blobs, - const Option& opt) const +int BinaryOp_riscv::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& bottom_blob = bottom_blobs[0]; const Mat& bottom_blob1 = bottom_blobs[1]; Mat& top_blob = top_blobs[0]; - if (op_type == Operation_ADD) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + int elempack = bottom_blob.elempack; + int elempack1 = bottom_blob1.elempack; + if (elempack != 1 || elempack1 != 1) + { + if (op_type == Operation_ADD) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_SUB) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_SUB) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_MUL) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_MUL) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_DIV) - return binary_op_rvv_fp16sa(bottom_blob, 
bottom_blob1, - top_blob, opt); + if (op_type == Operation_DIV) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_MAX) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_MAX) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_MIN) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_MIN) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_POW) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_POW) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_RSUB) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_RSUB) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - if (op_type == Operation_RDIV) - return binary_op_rvv_fp16sa(bottom_blob, bottom_blob1, - top_blob, opt); + if (op_type == Operation_RDIV) + return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + } - return 0; -} + if (elempack == 1 && elempack1 == 1) + { + if (op_type == Operation_ADD) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); -#if __riscv_vector && __riscv_zfh -template -static int binary_op_scalar_rvv_fp16sa(Mat& a, float b, const Option& opt) -{ - Op op; - int w = a.w; - int h = a.h; - int channels = a.c; - int size = w * h; - int elempack = a.elempack; + if (op_type == Operation_SUB) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - __fp16* ptr = a.channel(q); - int n = size * elempack; - while (n > 0) - { - word_type vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - _p = op(_p, b, vl); - vse16_v_f16m8(ptr, _p, vl); + if 
(op_type == Operation_MUL) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - n -= vl; - ptr += vl; - } + if (op_type == Operation_DIV) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MAX) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MIN) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_POW) + return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_RSUB) + return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + + if (op_type == Operation_RDIV) + return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); } + return 0; } -#endif -int BinaryOp_riscv::forward_inplace_fp16sa(Mat& bottom_top_blob, - const Option& opt) const + +int BinaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const { if (op_type == Operation_ADD) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_SUB) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_MUL) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_DIV) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_MAX) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_MIN) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_POW) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return 
binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_RSUB) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); if (op_type == Operation_RDIV) - return binary_op_scalar_rvv_fp16sa(bottom_top_blob, b, - opt); + return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); + return 0; } - #endif -} // namespace ncnn \ No newline at end of file +} // namespace ncnn diff --git a/src/layer/riscv/binaryop_riscv.h b/src/layer/riscv/binaryop_riscv.h index c7fd73942..0ecd34d68 100644 --- a/src/layer/riscv/binaryop_riscv.h +++ b/src/layer/riscv/binaryop_riscv.h @@ -1,8 +1,6 @@ -// Xavier Hsinyuan is pleased to support the open source community by making -// ncnn available. +// Xavier Hsinyuan is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2021 Xavier Hsinyuan . All rights -// reserved. +// Copyright (C) 2021 Xavier Hsinyuan . All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this // file except in compliance with the License. You may obtain a copy of the @@ -15,10 +13,12 @@ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations under // the License. 
+ #ifndef LAYER_BINARYOP_RISCV_H #define LAYER_BINARYOP_RISCV_H #include "binaryop.h" + namespace ncnn { class BinaryOp_riscv : virtual public BinaryOp @@ -26,19 +26,17 @@ class BinaryOp_riscv : virtual public BinaryOp public: BinaryOp_riscv(); - virtual int forward(const std::vector& bottom_blobs, - std::vector& top_blobs, const Option& opt) const; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; protected: #if __riscv_vector && __riscv_zfh - int forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const; - - int forward_fp16sa(const std::vector& bottom_blobs, - std::vector& top_blobs, const Option& opt) const; + int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + int forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const; #endif }; + } // namespace ncnn -#endif // LAYER_BINARYOP_RISCV_H \ No newline at end of file +#endif // LAYER_BINARYOP_RISCV_H diff --git a/src/layer/riscv/convolution_3x3_packn.h b/src/layer/riscv/convolution_3x3_packn.h index 1b9f6094d..534d4622a 100644 --- a/src/layer/riscv/convolution_3x3_packn.h +++ b/src/layer/riscv/convolution_3x3_packn.h @@ -324,7 +324,7 @@ static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -365,7 +365,7 @@ static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -398,7 +398,7 @@ static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -999,7 +999,7 @@ static void 
conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -1040,7 +1040,7 @@ static void conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -1073,7 +1073,7 @@ static void conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; diff --git a/src/layer/riscv/convolution_3x3_packn_fp16s.h b/src/layer/riscv/convolution_3x3_packn_fp16s.h index d64dd0c17..26d814b0d 100644 --- a/src/layer/riscv/convolution_3x3_packn_fp16s.h +++ b/src/layer/riscv/convolution_3x3_packn_fp16s.h @@ -324,7 +324,7 @@ static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -365,7 +365,7 @@ static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -398,7 +398,7 @@ static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -999,7 +999,7 @@ static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -1040,7 +1040,7 @@ static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; @@ -1073,7 +1073,7 @@ static void 
conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t for (int q = 0; q < inch; q++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = r0[l]; diff --git a/src/layer/riscv/convolution_packnto1_fp16s.h b/src/layer/riscv/convolution_packnto1_fp16s.h index ce489bac7..47c406cda 100644 --- a/src/layer/riscv/convolution_packnto1_fp16s.h +++ b/src/layer/riscv/convolution_packnto1_fp16s.h @@ -84,7 +84,7 @@ static void convolution_packnto1_fp16s_rvv(const Mat& bottom_blob, Mat& top_blob } } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector ss(packn); vse32_v_f32m2((float*)ss.data(), _sum, vl); diff --git a/src/layer/riscv/convolution_sgemm_packn.h b/src/layer/riscv/convolution_sgemm_packn.h index 66335b273..88518a231 100644 --- a/src/layer/riscv/convolution_sgemm_packn.h +++ b/src/layer/riscv/convolution_sgemm_packn.h @@ -54,7 +54,7 @@ static void im2col_sgemm_packn_rvv(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -103,7 +103,7 @@ static void im2col_sgemm_packn_rvv(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -144,7 +144,7 @@ static void im2col_sgemm_packn_rvv(const Mat& bottom_im2col, Mat& top_blob, cons for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; diff --git a/src/layer/riscv/convolution_sgemm_packn_fp16s.h b/src/layer/riscv/convolution_sgemm_packn_fp16s.h index 82086acea..977dc3820 100644 --- a/src/layer/riscv/convolution_sgemm_packn_fp16s.h +++ b/src/layer/riscv/convolution_sgemm_packn_fp16s.h @@ -54,7 +54,8 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 +#ifdef RVV_SPEC_0_7 asm volatile( "mv t3, %[LEN] \n\t" "mv t1, %[SRC] 
\n\t" @@ -83,7 +84,22 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo img0 += size * packn; tmpptr += packn * 8; +#else + for (int l = 0; l < packn; l++) + { + tmpptr[0] = img0[l]; + tmpptr[1] = img0[l + packn]; + tmpptr[2] = img0[l + packn * 2]; + tmpptr[3] = img0[l + packn * 3]; + tmpptr[4] = img0[l + packn * 4]; + tmpptr[5] = img0[l + packn * 5]; + tmpptr[6] = img0[l + packn * 6]; + tmpptr[7] = img0[l + packn * 7]; + tmpptr += 8; + } + img0 += size * packn; +#endif #else vfloat16m1_t _val0 = vle16_v_f16m1(img0, vl); vfloat16m1_t _val1 = vle16_v_f16m1(img0 + packn, vl); @@ -118,7 +134,8 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 +#ifdef RVV_SPEC_0_7 asm volatile( "mv t3, %[LEN] \n\t" "mv t1, %[SRC] \n\t" @@ -138,6 +155,18 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo img0 += size * packn; tmpptr += packn * 4; +#else + for (int l = 0; l < packn; l++) + { + tmpptr[0] = img0[l]; + tmpptr[1] = img0[l + packn]; + tmpptr[2] = img0[l + packn * 2]; + tmpptr[3] = img0[l + packn * 3]; + tmpptr += 4; + } + + img0 += size * packn; +#endif #else vfloat16m1_t _val0 = vle16_v_f16m1(img0, vl); vfloat16m1_t _val1 = vle16_v_f16m1(img0 + packn, vl); @@ -169,7 +198,8 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 +#ifdef RVV_SPEC_0_7 asm volatile( "mv t3, %[LEN] \n\t" "mv t1, %[SRC] \n\t" @@ -185,6 +215,16 @@ static void im2col_sgemm_packn_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_blo img0 += size * packn; tmpptr += packn * 2; +#else + for (int l = 0; l < packn; l++) + { + tmpptr[0] = img0[l]; + tmpptr[1] = img0[l + packn]; + tmpptr += 2; + } + + img0 += size * packn; +#endif #else vfloat16m1_t _val0 = vle16_v_f16m1(img0, vl); vfloat16m1_t _val1 = vle16_v_f16m1(img0 + packn, vl); diff --git 
a/src/layer/riscv/convolution_sgemm_packnto1.h b/src/layer/riscv/convolution_sgemm_packnto1.h index 828bdcb69..686709e04 100644 --- a/src/layer/riscv/convolution_sgemm_packnto1.h +++ b/src/layer/riscv/convolution_sgemm_packnto1.h @@ -53,7 +53,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -102,7 +102,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -143,7 +143,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -240,7 +240,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse32_v_f32m1(outptr0, top_blob.cstep * sizeof(float), _sum0, vl); vsse32_v_f32m1(outptr0 + 1, top_blob.cstep * sizeof(float), _sum1, vl); vsse32_v_f32m1(outptr0 + 2, top_blob.cstep * sizeof(float), _sum2, vl); @@ -281,7 +281,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse32_v_f32m1(outptr0, top_blob.cstep * sizeof(float), _sum0, vl); vsse32_v_f32m1(outptr0 + 1, top_blob.cstep * sizeof(float), _sum1, vl); vsse32_v_f32m1(outptr0 + 2, top_blob.cstep * sizeof(float), _sum2, vl); @@ -312,7 +312,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse32_v_f32m1(outptr0, top_blob.cstep * sizeof(float), _sum0, vl); vsse32_v_f32m1(outptr0 + 1, top_blob.cstep * sizeof(float), _sum1, vl); #else @@ -393,7 +393,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += 
packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector ss0(packn); std::vector ss1(packn); @@ -473,7 +473,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector ss0(packn); std::vector ss1(packn); @@ -527,7 +527,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector ss0(packn); std::vector ss1(packn); @@ -568,7 +568,7 @@ static void im2col_sgemm_packnto1_rvv(const Mat& bottom_im2col, Mat& top_blob, c kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector ss0(packn); vse32_v_f32m1((float*)ss0.data(), _sum0, vl); diff --git a/src/layer/riscv/convolution_sgemm_packnto1_fp16s.h b/src/layer/riscv/convolution_sgemm_packnto1_fp16s.h index b8c297c21..d4a935615 100644 --- a/src/layer/riscv/convolution_sgemm_packnto1_fp16s.h +++ b/src/layer/riscv/convolution_sgemm_packnto1_fp16s.h @@ -53,7 +53,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -102,7 +102,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -143,7 +143,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ for (int k = 0; k < maxk; k++) { -#if RVV_SPEC_0_7 +#if C906 for (int l = 0; l < packn; l++) { tmpptr[0] = img0[l]; @@ -240,7 +240,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse16_v_f16m1(outptr0, top_blob.cstep * sizeof(__fp16), _sum0, vl); vsse16_v_f16m1(outptr0 + 1, top_blob.cstep * sizeof(__fp16), _sum1, vl); vsse16_v_f16m1(outptr0 + 2, top_blob.cstep * sizeof(__fp16), 
_sum2, vl); @@ -281,7 +281,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse16_v_f16m1(outptr0, top_blob.cstep * sizeof(__fp16), _sum0, vl); vsse16_v_f16m1(outptr0 + 1, top_blob.cstep * sizeof(__fp16), _sum1, vl); vsse16_v_f16m1(outptr0 + 2, top_blob.cstep * sizeof(__fp16), _sum2, vl); @@ -312,7 +312,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#if RVV_SPEC_0_7 +#if C906 vsse16_v_f16m1(outptr0, top_blob.cstep * sizeof(__fp16), _sum0, vl); vsse16_v_f16m1(outptr0 + 1, top_blob.cstep * sizeof(__fp16), _sum1, vl); #else @@ -393,7 +393,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector<__fp16> ss0(packn); std::vector<__fp16> ss1(packn); @@ -473,7 +473,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector<__fp16> ss0(packn); std::vector<__fp16> ss1(packn); @@ -527,7 +527,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector<__fp16> ss0(packn); std::vector<__fp16> ss1(packn); @@ -568,7 +568,7 @@ static void im2col_sgemm_packnto1_fp16sa_rvv(const Mat& bottom_im2col, Mat& top_ kptr0 += packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector<__fp16> ss0(packn); vse16_v_f16m1((__fp16*)ss0.data(), _sum0, vl); diff --git a/src/layer/riscv/deconvolution_packnto1_fp16s.h b/src/layer/riscv/deconvolution_packnto1_fp16s.h index 456ef387a..c42403565 100644 --- a/src/layer/riscv/deconvolution_packnto1_fp16s.h +++ b/src/layer/riscv/deconvolution_packnto1_fp16s.h @@ -91,7 +91,7 @@ static void deconvolution_packnto1_fp16s_rvv(const Mat& bottom_blob, Mat& top_bl kptr += maxk * packn; } -#ifdef RVV_SPEC_0_7 +#if C906 // TODO std::vector 
ss(packn); vse32_v_f32m2((float*)ss.data(), _sum, vl); diff --git a/src/layer/riscv/riscv_usability.h b/src/layer/riscv/riscv_usability.h index ae1ef0792..f6f3efa0e 100644 --- a/src/layer/riscv/riscv_usability.h +++ b/src/layer/riscv/riscv_usability.h @@ -15,6 +15,14 @@ #ifndef RISCV_USABILITY_H #define RISCV_USABILITY_H +#if __riscv_vector +#ifdef RVV_SPEC_0_7 +#include "riscv_v_071_fix.h" +#else +#include <riscv_vector.h> +#endif +#endif // __riscv_vector + #if __riscv_vector static inline int csrr_vl() { @@ -45,6 +53,80 @@ static inline int csrr_vlenb() : "memory"); return a; } + +static inline vfloat32m8_t vle32_v_f32m8_f32m1(const float* ptr) +{ + const int packn = csrr_vlenb() / 4; + const word_type vl = vsetvl_e32m8(packn * 8); + + // NOTE vloxei8_v_f32m8 gets illegal instruction on d1 --- nihui + + // 128bit + static const uint32_t index_128bit[32] = { + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12, + 0, 4, 8, 12 + }; + + // 256bit + static const uint32_t index_256bit[64] = { + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28, + 0, 4, 8, 12, 16, 20, 24, 28 + }; + + const uint32_t* index = packn == 4 ? 
index_128bit : index_256bit; + vuint32m8_t bindex = vle32_v_u32m8(index, vl); + return vloxei32_v_f32m8(ptr, bindex, vl); +} + +#if __riscv_zfh +static inline vfloat16m8_t vle16_v_f16m8_f16m1(const __fp16* ptr) +{ + const int packn = csrr_vlenb() / 2; + const word_type vl = vsetvl_e16m8(packn * 8); + + // NOTE vloxei8_v_f16m8 gets illegal instruction on d1 --- nihui + + // 128bit + static const uint16_t index_128bit[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14 + }; + + // 256bit + static const uint16_t index_256bit[128] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 + }; + + const uint16_t* index = packn == 8 ? 
index_128bit : index_256bit; + vuint16m8_t bindex = vle16_v_u16m8(index, vl); + return vloxei16_v_f16m8(ptr, bindex, vl); +} +#endif // __riscv_zfh #endif // __riscv_vector #endif // RISCV_USABILITY_H diff --git a/src/layer/riscv/riscv_v_071_fix.h b/src/layer/riscv/riscv_v_071_fix.h index 62cc5cf4c..007b4cfc8 100644 --- a/src/layer/riscv/riscv_v_071_fix.h +++ b/src/layer/riscv/riscv_v_071_fix.h @@ -96,6 +96,15 @@ typedef uint16x4xm1_t vuint16m1x4_t; typedef uint16x4xm2_t vuint16m2x4_t; typedef uint16x8xm1_t vuint16m1x8_t; +typedef uint8xm1_t vuint8m1_t; +typedef uint8xm2_t vuint8m2_t; +typedef uint8xm4_t vuint8m4_t; +typedef uint8xm8_t vuint8m8_t; + +typedef uint8x4xm1_t vuint8m1x4_t; +typedef uint8x4xm2_t vuint8m2x4_t; +typedef uint8x8xm1_t vuint8m1x8_t; + #define vsetvl_e32m1(n) vsetvli(n, RVV_E32, RVV_M1) #define vsetvl_e32m2(n) vsetvli(n, RVV_E32, RVV_M2) #define vsetvl_e32m4(n) vsetvli(n, RVV_E32, RVV_M4) @@ -132,6 +141,8 @@ typedef uint16x8xm1_t vuint16m1x8_t; #define vsse32_v_f32m4 vssev_float32xm4 #define vsse32_v_f32m8 vssev_float32xm8 +#define vloxei32_v_f32m8(a, i, vl) vlxev_float32xm8(a, reinterpret_cast(i), vl) + #define vlseg2e32_v_f32m1x2 vlseg2ev_float32x2xm1 #define vsseg2e32_v_f32m1x2 vsseg2ev_float32x2xm1 @@ -617,6 +628,8 @@ static inline vfloat32m1_t vfredmax_vs_f32m8_f32m1(vfloat32m1_t dst, vfloat32m8_ #define vsse16_v_f16m4 vssev_float16xm4 #define vsse16_v_f16m8 vssev_float16xm8 +#define vloxei16_v_f16m8(a, i, vl) vlxev_float16xm8(a, reinterpret_cast(i), vl) + #define vlseg2e16_v_f16m1x2 vlseg2ev_float16x2xm1 #define vsseg2e16_v_f16m1x2 vsseg2ev_float16x2xm1 @@ -1690,6 +1703,32 @@ static inline vuint16m1x8_t vcreate_u16m1x8(vuint16m1_t v0, vuint16m1_t v1, vuin #define vreinterpret_v_f16m4_u16m4(x) reinterpret_cast(x) #define vreinterpret_v_f16m8_u16m8(x) reinterpret_cast(x) +/******************************** uint8 ********************************/ +#define vle8_v_u8m1 vlev_uint8xm1 +#define vle8_v_u8m2 vlev_uint8xm2 +#define vle8_v_u8m4 
vlev_uint8xm4 +#define vle8_v_u8m8 vlev_uint8xm8 + +#define vse8_v_u8m1 vsev_uint8xm1 +#define vse8_v_u8m2 vsev_uint8xm2 +#define vse8_v_u8m4 vsev_uint8xm4 +#define vse8_v_u8m8 vsev_uint8xm8 + +#define vlse8_v_u8m1 vlsev_uint8xm1 +#define vlse8_v_u8m2 vlsev_uint8xm2 +#define vlse8_v_u8m4 vlsev_uint8xm4 +#define vlse8_v_u8m8 vlsev_uint8xm8 + +#define vsse8_v_u8m1 vssev_uint8xm1 +#define vsse8_v_u8m2 vssev_uint8xm2 +#define vsse8_v_u8m4 vssev_uint8xm4 +#define vsse8_v_u8m8 vssev_uint8xm8 + +#define vmv_v_x_u8m1 vmvvx_uint8xm1 +#define vmv_v_x_u8m2 vmvvx_uint8xm2 +#define vmv_v_x_u8m4 vmvvx_uint8xm4 +#define vmv_v_x_u8m8 vmvvx_uint8xm8 + /******************************** mask ********************************/ #define vmxor_mm_b32 vmxormm_e32xm1 #define vmxor_mm_b16 vmxormm_e32xm2 diff --git a/src/layer/riscv/unaryop_riscv.cpp b/src/layer/riscv/unaryop_riscv.cpp index a7ed8f70c..280514c53 100644 --- a/src/layer/riscv/unaryop_riscv.cpp +++ b/src/layer/riscv/unaryop_riscv.cpp @@ -46,8 +46,9 @@ static int unary_op_inplace(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; int elempack = a.elempack; #pragma omp parallel for num_threads(opt.num_threads) @@ -322,8 +323,9 @@ static int unary_op_inplace_fp16s(Mat& a, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; int elempack = a.elempack; #pragma omp parallel for num_threads(opt.num_threads) diff --git a/src/layer/vulkan/binaryop_vulkan.cpp b/src/layer/vulkan/binaryop_vulkan.cpp index e43d351b9..4a1f7141e 100644 --- a/src/layer/vulkan/binaryop_vulkan.cpp +++ b/src/layer/vulkan/binaryop_vulkan.cpp @@ -47,17 +47,17 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) int elempack = 1; if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 4 : 1; if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 
8 : shape.h % 4 == 0 ? 4 : 1; - if (shape.dims == 3) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1; + if (shape.dims == 3 || shape.dims == 4) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1; int elempack1 = 1; if (shape1.dims == 1) elempack1 = opt.use_shader_pack8 && shape1.w % 8 == 0 ? 8 : shape1.w % 4 == 0 ? 4 : 1; if (shape1.dims == 2) elempack1 = opt.use_shader_pack8 && shape1.h % 8 == 0 ? 8 : shape1.h % 4 == 0 ? 4 : 1; - if (shape1.dims == 3) elempack1 = opt.use_shader_pack8 && shape1.c % 8 == 0 ? 8 : shape1.c % 4 == 0 ? 4 : 1; + if (shape1.dims == 3 || shape1.dims == 4) elempack1 = opt.use_shader_pack8 && shape1.c % 8 == 0 ? 8 : shape1.c % 4 == 0 ? 4 : 1; int out_elempack = 1; if (out_shape.dims == 1) out_elempack = opt.use_shader_pack8 && out_shape.w % 8 == 0 ? 8 : out_shape.w % 4 == 0 ? 4 : 1; if (out_shape.dims == 2) out_elempack = opt.use_shader_pack8 && out_shape.h % 8 == 0 ? 8 : out_shape.h % 4 == 0 ? 4 : 1; - if (out_shape.dims == 3) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1; + if (out_shape.dims == 3 || out_shape.dims == 4) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 
4 : 1; size_t elemsize; size_t elemsize1; @@ -85,19 +85,22 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) if (shape.dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack); if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack); if (shape.dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack); + if (shape.dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack); Mat shape1_packed; if (shape1.dims == 1) shape1_packed = Mat(shape1.w / elempack1, (void*)0, elemsize1, elempack1); if (shape1.dims == 2) shape1_packed = Mat(shape1.w, shape1.h / elempack1, (void*)0, elemsize1, elempack1); if (shape1.dims == 3) shape1_packed = Mat(shape1.w, shape1.h, shape1.c / elempack1, (void*)0, elemsize1, elempack1); + if (shape1.dims == 4) shape1_packed = Mat(shape1.w, shape1.h, shape1.d, shape1.c / elempack1, (void*)0, elemsize1, elempack1); Mat out_shape_packed; if (out_shape.dims == 1) out_shape_packed = Mat(out_shape.w / out_elempack, (void*)0, out_elemsize, out_elempack); if (out_shape.dims == 2) out_shape_packed = Mat(out_shape.w, out_shape.h / out_elempack, (void*)0, out_elemsize, out_elempack); if (out_shape.dims == 3) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack); + if (out_shape.dims == 4) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.d, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack); bool broadcast = true; - if (shape.dims == shape1.dims && shape.w == shape1.w && shape.h == shape1.h && shape.c == shape1.c) + if (shape.dims == shape1.dims && shape.w == shape1.w && shape.h == shape1.h && shape.d == shape1.d && shape.c == shape1.c) { broadcast = false; } @@ -111,17 +114,17 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) specializations[2].f = b; specializations[3 + 0].i = shape_packed.dims; 
specializations[3 + 1].i = shape_packed.w; - specializations[3 + 2].i = shape_packed.h; + specializations[3 + 2].i = shape_packed.h * shape_packed.d; specializations[3 + 3].i = shape_packed.c; specializations[3 + 4].i = shape_packed.cstep; specializations[3 + 5].i = shape1_packed.dims; specializations[3 + 6].i = shape1_packed.w; - specializations[3 + 7].i = shape1_packed.h; + specializations[3 + 7].i = shape1_packed.h * shape1_packed.d; specializations[3 + 8].i = shape1_packed.c; specializations[3 + 9].i = shape1_packed.cstep; specializations[3 + 10].i = out_shape_packed.dims; specializations[3 + 11].i = out_shape_packed.w; - specializations[3 + 12].i = out_shape_packed.h; + specializations[3 + 12].i = out_shape_packed.h * out_shape_packed.d; specializations[3 + 13].i = out_shape_packed.c; specializations[3 + 14].i = out_shape_packed.cstep; @@ -144,6 +147,12 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) local_size_xyz.h = std::min(4, out_shape_packed.h); local_size_xyz.c = std::min(4, out_shape_packed.c); } + if (out_shape_packed.dims == 4) + { + local_size_xyz.w = std::min(4, out_shape_packed.w); + local_size_xyz.h = std::min(4, out_shape_packed.h * out_shape_packed.d); + local_size_xyz.c = std::min(4, out_shape_packed.c); + } // pack1 if (shape.dims == 0 || elempack == 1) @@ -173,23 +182,44 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) // broadcast if (shape.dims == 0 || broadcast) { - std::vector specializations(1 + 15); + std::vector specializations(1 + 18); specializations[0].i = op_type; specializations[1 + 0].i = shape_packed.dims; specializations[1 + 1].i = shape_packed.w; specializations[1 + 2].i = shape_packed.h; - specializations[1 + 3].i = shape_packed.c; - specializations[1 + 4].i = shape_packed.cstep; - specializations[1 + 5].i = shape1_packed.dims; - specializations[1 + 6].i = shape1_packed.w; - specializations[1 + 7].i = shape1_packed.h; - specializations[1 + 8].i = shape1_packed.c; - specializations[1 + 9].i = 
shape1_packed.cstep; - specializations[1 + 10].i = out_shape_packed.dims; - specializations[1 + 11].i = out_shape_packed.w; - specializations[1 + 12].i = out_shape_packed.h; - specializations[1 + 13].i = out_shape_packed.c; - specializations[1 + 14].i = out_shape_packed.cstep; + specializations[1 + 3].i = shape_packed.d; + specializations[1 + 4].i = shape_packed.c; + specializations[1 + 5].i = shape_packed.cstep; + specializations[1 + 6].i = shape1_packed.dims; + specializations[1 + 7].i = shape1_packed.w; + specializations[1 + 8].i = shape1_packed.h; + specializations[1 + 9].i = shape1_packed.d; + specializations[1 + 10].i = shape1_packed.c; + specializations[1 + 11].i = shape1_packed.cstep; + specializations[1 + 12].i = out_shape_packed.dims; + specializations[1 + 13].i = out_shape_packed.w; + specializations[1 + 14].i = out_shape_packed.h; + specializations[1 + 15].i = out_shape_packed.d; + specializations[1 + 16].i = out_shape_packed.c; + specializations[1 + 17].i = out_shape_packed.cstep; + + std::vector specializations_broadcast_a1_b1(1 + 15); + specializations_broadcast_a1_b1[0].i = op_type; + specializations_broadcast_a1_b1[1 + 0].i = shape_packed.dims; + specializations_broadcast_a1_b1[1 + 1].i = shape_packed.w; + specializations_broadcast_a1_b1[1 + 2].i = shape_packed.h * shape_packed.d; + specializations_broadcast_a1_b1[1 + 3].i = shape_packed.c; + specializations_broadcast_a1_b1[1 + 4].i = shape_packed.cstep; + specializations_broadcast_a1_b1[1 + 5].i = shape1_packed.dims; + specializations_broadcast_a1_b1[1 + 6].i = shape1_packed.w; + specializations_broadcast_a1_b1[1 + 7].i = shape1_packed.h * shape1_packed.d; + specializations_broadcast_a1_b1[1 + 8].i = shape1_packed.c; + specializations_broadcast_a1_b1[1 + 9].i = shape1_packed.cstep; + specializations_broadcast_a1_b1[1 + 10].i = out_shape_packed.dims; + specializations_broadcast_a1_b1[1 + 11].i = out_shape_packed.w; + specializations_broadcast_a1_b1[1 + 12].i = out_shape_packed.h * 
out_shape_packed.d; + specializations_broadcast_a1_b1[1 + 13].i = out_shape_packed.c; + specializations_broadcast_a1_b1[1 + 14].i = out_shape_packed.cstep; Mat local_size_xyz; if (out_shape_packed.dims == 1) @@ -210,6 +240,12 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) local_size_xyz.h = std::min(4, out_shape_packed.h); local_size_xyz.c = std::min(4, out_shape_packed.c); } + if (out_shape_packed.dims == 4) + { + local_size_xyz.w = std::min(4, out_shape_packed.w); + local_size_xyz.h = std::min(4, out_shape_packed.h * out_shape_packed.d); + local_size_xyz.c = std::min(4, out_shape_packed.c); + } // pack1 if (shape.dims == 0 || (elempack == 1 && elempack1 == 1)) @@ -232,7 +268,7 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) { pipeline_binaryop_broadcast_a1_pack4 = new Pipeline(vkdev); pipeline_binaryop_broadcast_a1_pack4->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_a1_pack4->create(LayerShaderType::binaryop_broadcast_a1_pack4, opt, specializations); + pipeline_binaryop_broadcast_a1_pack4->create(LayerShaderType::binaryop_broadcast_a1_pack4, opt, specializations_broadcast_a1_b1); } if (shape.dims == 0 || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 4) @@ -240,7 +276,7 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) { pipeline_binaryop_broadcast_b1_pack4 = new Pipeline(vkdev); pipeline_binaryop_broadcast_b1_pack4->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_b1_pack4->create(LayerShaderType::binaryop_broadcast_b1_pack4, opt, specializations); + pipeline_binaryop_broadcast_b1_pack4->create(LayerShaderType::binaryop_broadcast_b1_pack4, opt, specializations_broadcast_a1_b1); } // pack8 @@ -256,7 +292,7 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) { pipeline_binaryop_broadcast_a1_pack8 = new Pipeline(vkdev); pipeline_binaryop_broadcast_a1_pack8->set_optimal_local_size_xyz(local_size_xyz); - 
pipeline_binaryop_broadcast_a1_pack8->create(LayerShaderType::binaryop_broadcast_a1_pack8, opt, specializations); + pipeline_binaryop_broadcast_a1_pack8->create(LayerShaderType::binaryop_broadcast_a1_pack8, opt, specializations_broadcast_a1_b1); } if ((opt.use_shader_pack8 && shape.dims == 0) || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 8) @@ -264,7 +300,7 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) { pipeline_binaryop_broadcast_b1_pack8 = new Pipeline(vkdev); pipeline_binaryop_broadcast_b1_pack8->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_b1_pack8->create(LayerShaderType::binaryop_broadcast_b1_pack8, opt, specializations); + pipeline_binaryop_broadcast_b1_pack8->create(LayerShaderType::binaryop_broadcast_b1_pack8, opt, specializations_broadcast_a1_b1); } } @@ -324,7 +360,7 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector } else // if (bottom_blob.dims == bottom_blob1.dims) { - if (bottom_blob.w * bottom_blob.h * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.c * bottom_blob1.elempack) + if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) { top_blob.create_like(bottom_blob, opt.blob_vkallocator); } @@ -343,39 +379,63 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector bindings[1] = bottom_blob1; bindings[2] = top_blob; - std::vector constants(15); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h; - constants[3].i = bottom_blob.c; - constants[4].i = bottom_blob.cstep; - constants[5].i = bottom_blob1.dims; - constants[6].i = bottom_blob1.w; - constants[7].i = bottom_blob1.h; - constants[8].i = bottom_blob1.c; - constants[9].i = bottom_blob1.cstep; - constants[10].i = top_blob.dims; - constants[11].i = top_blob.w; - 
constants[12].i = top_blob.h; - constants[13].i = top_blob.c; - constants[14].i = top_blob.cstep; - bool broadcast = true; if (bottom_blob.dims == bottom_blob1.dims && bottom_blob.w == bottom_blob1.w && bottom_blob.h == bottom_blob1.h + && bottom_blob.d == bottom_blob1.d && bottom_blob.c == bottom_blob1.c && bottom_blob.elempack == bottom_blob1.elempack) { broadcast = false; } - const Pipeline* pipeline = 0; if (broadcast) { + std::vector constants(18); + constants[0].i = bottom_blob.dims; + constants[1].i = bottom_blob.w; + constants[2].i = bottom_blob.h; + constants[3].i = bottom_blob.d; + constants[4].i = bottom_blob.c; + constants[5].i = bottom_blob.cstep; + constants[6].i = bottom_blob1.dims; + constants[7].i = bottom_blob1.w; + constants[8].i = bottom_blob1.h; + constants[9].i = bottom_blob1.d; + constants[10].i = bottom_blob1.c; + constants[11].i = bottom_blob1.cstep; + constants[12].i = top_blob.dims; + constants[13].i = top_blob.w; + constants[14].i = top_blob.h; + constants[15].i = top_blob.d; + constants[16].i = top_blob.c; + constants[17].i = top_blob.cstep; + + std::vector constants_broadcast_a1b1(15); + constants_broadcast_a1b1[0].i = bottom_blob.dims; + constants_broadcast_a1b1[1].i = bottom_blob.w; + constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; + constants_broadcast_a1b1[3].i = bottom_blob.c; + constants_broadcast_a1b1[4].i = bottom_blob.cstep; + constants_broadcast_a1b1[5].i = bottom_blob1.dims; + constants_broadcast_a1b1[6].i = bottom_blob1.w; + constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; + constants_broadcast_a1b1[8].i = bottom_blob1.c; + constants_broadcast_a1b1[9].i = bottom_blob1.cstep; + constants_broadcast_a1b1[10].i = top_blob.dims; + constants_broadcast_a1b1[11].i = top_blob.w; + constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; + constants_broadcast_a1b1[13].i = top_blob.c; + constants_broadcast_a1b1[14].i = top_blob.cstep; + + bool broadcast_a1b1 = true; + + const Pipeline* pipeline = 0; if 
(bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) { pipeline = pipeline_binaryop_broadcast; + broadcast_a1b1 = false; } else { @@ -400,18 +460,38 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector else { pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; + broadcast_a1b1 = false; } } + + cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); } else { - pipeline = out_elempack == 8 ? pipeline_binaryop_pack8 - : out_elempack == 4 ? pipeline_binaryop_pack4 - : pipeline_binaryop; + std::vector constants(15); + constants[0].i = bottom_blob.dims; + constants[1].i = bottom_blob.w; + constants[2].i = bottom_blob.h * bottom_blob.d; + constants[3].i = bottom_blob.c; + constants[4].i = bottom_blob.cstep; + constants[5].i = bottom_blob1.dims; + constants[6].i = bottom_blob1.w; + constants[7].i = bottom_blob1.h * bottom_blob1.d; + constants[8].i = bottom_blob1.c; + constants[9].i = bottom_blob1.cstep; + constants[10].i = top_blob.dims; + constants[11].i = top_blob.w; + constants[12].i = top_blob.h * top_blob.d; + constants[13].i = top_blob.c; + constants[14].i = top_blob.cstep; + + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_pack8 + : out_elempack == 4 ? 
pipeline_binaryop_pack4 + : pipeline_binaryop; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); } - cmd.record_pipeline(pipeline, bindings, constants, top_blob); - return 0; } @@ -427,7 +507,7 @@ int BinaryOp_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, con std::vector constants(15); constants[10].i = bottom_top_blob.dims; constants[11].i = bottom_top_blob.w; - constants[12].i = bottom_top_blob.h; + constants[12].i = bottom_top_blob.h * bottom_top_blob.d; constants[13].i = bottom_top_blob.c; constants[14].i = bottom_top_blob.cstep; @@ -458,7 +538,7 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::v } else // if (bottom_blob.dims == bottom_blob1.dims) { - if (bottom_blob.w * bottom_blob.h * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.c * bottom_blob1.elempack) + if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) { top_blob.create_like(bottom_blob, opt.blob_vkallocator); } @@ -477,39 +557,63 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::v bindings[1] = bottom_blob1; bindings[2] = top_blob; - std::vector constants(15); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h; - constants[3].i = bottom_blob.c; - constants[4].i = 0; //bottom_blob.cstep; - constants[5].i = bottom_blob1.dims; - constants[6].i = bottom_blob1.w; - constants[7].i = bottom_blob1.h; - constants[8].i = bottom_blob1.c; - constants[9].i = 0; //bottom_blob1.cstep; - constants[10].i = top_blob.dims; - constants[11].i = top_blob.w; - constants[12].i = top_blob.h; - constants[13].i = top_blob.c; - constants[14].i = 0; //top_blob.cstep; - bool broadcast = true; if (bottom_blob.dims == bottom_blob1.dims && bottom_blob.w == bottom_blob1.w && bottom_blob.h == bottom_blob1.h + && bottom_blob.d == 
bottom_blob1.d && bottom_blob.c == bottom_blob1.c && bottom_blob.elempack == bottom_blob1.elempack) { broadcast = false; } - const Pipeline* pipeline = 0; if (broadcast) { + std::vector constants(18); + constants[0].i = bottom_blob.dims; + constants[1].i = bottom_blob.w; + constants[2].i = bottom_blob.h; + constants[3].i = bottom_blob.d; + constants[4].i = bottom_blob.c; + constants[5].i = 0; //bottom_blob.cstep; + constants[6].i = bottom_blob1.dims; + constants[7].i = bottom_blob1.w; + constants[8].i = bottom_blob1.h; + constants[9].i = bottom_blob1.d; + constants[10].i = bottom_blob1.c; + constants[11].i = 0; //bottom_blob1.cstep; + constants[12].i = top_blob.dims; + constants[13].i = top_blob.w; + constants[14].i = top_blob.h; + constants[15].i = top_blob.d; + constants[16].i = top_blob.c; + constants[17].i = 0; //top_blob.cstep; + + std::vector constants_broadcast_a1b1(15); + constants_broadcast_a1b1[0].i = bottom_blob.dims; + constants_broadcast_a1b1[1].i = bottom_blob.w; + constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; + constants_broadcast_a1b1[3].i = bottom_blob.c; + constants_broadcast_a1b1[4].i = 0; //bottom_blob.cstep; + constants_broadcast_a1b1[5].i = bottom_blob1.dims; + constants_broadcast_a1b1[6].i = bottom_blob1.w; + constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; + constants_broadcast_a1b1[8].i = bottom_blob1.c; + constants_broadcast_a1b1[9].i = 0; //bottom_blob1.cstep; + constants_broadcast_a1b1[10].i = top_blob.dims; + constants_broadcast_a1b1[11].i = top_blob.w; + constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; + constants_broadcast_a1b1[13].i = top_blob.c; + constants_broadcast_a1b1[14].i = 0; //top_blob.cstep; + + bool broadcast_a1b1 = true; + + const Pipeline* pipeline = 0; if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) { pipeline = pipeline_binaryop_broadcast; + broadcast_a1b1 = false; } else { @@ -534,18 +638,38 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::v 
else { pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; + broadcast_a1b1 = false; } } + + cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); } else { - pipeline = out_elempack == 8 ? pipeline_binaryop_pack8 - : out_elempack == 4 ? pipeline_binaryop_pack4 - : pipeline_binaryop; + std::vector constants(15); + constants[0].i = bottom_blob.dims; + constants[1].i = bottom_blob.w; + constants[2].i = bottom_blob.h * bottom_blob.d; + constants[3].i = bottom_blob.c; + constants[4].i = 0; //bottom_blob.cstep; + constants[5].i = bottom_blob1.dims; + constants[6].i = bottom_blob1.w; + constants[7].i = bottom_blob1.h * bottom_blob1.d; + constants[8].i = bottom_blob1.c; + constants[9].i = 0; //bottom_blob1.cstep; + constants[10].i = top_blob.dims; + constants[11].i = top_blob.w; + constants[12].i = top_blob.h * top_blob.d; + constants[13].i = top_blob.c; + constants[14].i = 0; //top_blob.cstep; + + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_pack8 + : out_elempack == 4 ? 
pipeline_binaryop_pack4 + : pipeline_binaryop; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); } - cmd.record_pipeline(pipeline, bindings, constants, top_blob); - return 0; } @@ -561,7 +685,7 @@ int BinaryOp_vulkan::forward_inplace(VkImageMat& bottom_top_blob, VkCompute& cmd std::vector constants(15); constants[10].i = bottom_top_blob.dims; constants[11].i = bottom_top_blob.w; - constants[12].i = bottom_top_blob.h; + constants[12].i = bottom_top_blob.h * bottom_top_blob.d; constants[13].i = bottom_top_blob.c; constants[14].i = 0; //bottom_top_blob.cstep; diff --git a/src/layer/vulkan/relu_vulkan.cpp b/src/layer/vulkan/relu_vulkan.cpp index 54d02caf7..fdc4a832f 100644 --- a/src/layer/vulkan/relu_vulkan.cpp +++ b/src/layer/vulkan/relu_vulkan.cpp @@ -84,7 +84,7 @@ int ReLU_vulkan::create_pipeline(const Option& opt) local_size_xyz.h = std::min(4, shape_packed.h); local_size_xyz.c = std::min(4, shape_packed.c); } - if (shape_packed.dims == 3) + if (shape_packed.dims == 4) { local_size_xyz.w = std::min(4, shape_packed.w); local_size_xyz.h = std::min(4, shape_packed.h * shape_packed.d); diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp index e3f82583c..64d063e65 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast.comp @@ -27,20 +27,23 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) 
const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -79,7 +85,7 @@ void 
main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; #if NCNN_image_shader @@ -90,8 +96,58 @@ void main() int by = gy; int bz = gz; - if (psc(adims) == 3) + if (psc(adims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + bx = yh; + by = yd; + bz = gz; + } + + if (psc(bdims) == 2) + { + // type 27 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + bx = 0; + by = 0; + bz = 0; + } + else + { + // type 26 + bx = gz; + by = 0; + bz = 0; + } + } + } + else if (psc(adims) == 3) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 23 + ax = yh; + ay = yd; + az = gz; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -173,6 +229,17 @@ void main() } else if (psc(adims) == 2) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ax = yd; + ay = gz; + az = 0; + } + if (psc(bdims) == 3) { // type 14 @@ -203,13 +270,21 @@ void main() { if (psc(aw) == 1) { - // type 2 3 4 + // type 2 3 4 20 ax = 0; ay = 0; az = 0; } else { + if (psc(bdims) == 4) + { + // type 21 + ax = gz; + ay = 0; + az = 0; + } + if (psc(bdims) == 3) { // type 9 @@ -247,8 +322,53 @@ void main() int ai; int bi; - if (psc(adims) == 3) + if (psc(adims) == 4) { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + ai = gi; + bi = gz * psc(bcstep) + yd * psc(bw) + yh; + } + + if (psc(bdims) == 2) + { + // type 27 + ai = gi; + bi = gz * psc(bw) + yd; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + ai = gi; + bi = 0; + } + else + { + // type 26 + ai = gi; + bi = gz; + } + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int 
yh = gy % psc(outh); + + // type 23 + ai = gz * psc(acstep) + yd * psc(aw) + yh; + bi = gi; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -333,6 +453,16 @@ void main() } else if (psc(adims) == 2) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ai = gz * psc(aw) + yd; + bi = gi; + } + if (psc(bdims) == 3) { // type 14 @@ -360,12 +490,19 @@ void main() { if (psc(aw) == 1) { - // type 2 3 4 + // type 2 3 4 20 ai = 0; bi = gi; } else { + if (psc(bdims) == 4) + { + // type 21 + ai = gz; + bi = gi; + } + if (psc(bdims) == 3) { // type 9 diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp index 60520ca38..1f71ae1ea 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp @@ -27,20 +27,23 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id 
= shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -79,7 +85,7 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; #if NCNN_image_shader @@ -90,8 +96,58 @@ void main() int by = gy; int bz = gz; - if (psc(adims) == 3) + if (psc(adims) == 4) { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + bx = yh; + by = yd; + bz = gz; + } 
+ + if (psc(bdims) == 2) + { + // type 27 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + bx = 0; + by = 0; + bz = 0; + } + else + { + // type 26 + bx = gz; + by = 0; + bz = 0; + } + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 23 + ax = yh; + ay = yd; + az = gz; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -173,6 +229,17 @@ void main() } else if (psc(adims) == 2) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ax = yd; + ay = gz; + az = 0; + } + if (psc(bdims) == 3) { // type 14 @@ -203,13 +270,21 @@ void main() { if (psc(aw) == 1) { - // type 2 3 4 + // type 2 3 4 20 ax = 0; ay = 0; az = 0; } else { + if (psc(bdims) == 4) + { + // type 21 + ax = gz; + ay = 0; + az = 0; + } + if (psc(bdims) == 3) { // type 9 @@ -247,8 +322,53 @@ void main() int ai; int bi; - if (psc(adims) == 3) + if (psc(adims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + ai = gi; + bi = gz * psc(bcstep) + yd * psc(bw) + yh; + } + + if (psc(bdims) == 2) + { + // type 27 + ai = gi; + bi = gz * psc(bw) + yd; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + ai = gi; + bi = 0; + } + else + { + // type 26 + ai = gi; + bi = gz; + } + } + } + else if (psc(adims) == 3) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 23 + ai = gz * psc(acstep) + yd * psc(aw) + yh; + bi = gi; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -310,6 +430,16 @@ void main() } else if (psc(adims) == 2) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ai = gz * psc(aw) + yd; + bi = gi; + } + if (psc(bdims) == 3) { // type 14 @@ -326,6 +456,13 @@ void main() } else if (psc(adims) == 1) { + if (psc(bdims) == 4) + { + // 
type 21 + ai = gz; + bi = gi; + } + if (psc(bdims) == 3) { // type 9 diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp index 28621f0f5..41d00199b 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp @@ -28,20 +28,23 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset 
+ 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; @@ -58,18 +61,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -80,7 +86,7 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; #if NCNN_image_shader @@ -91,8 +97,58 @@ void main() int by = gy; int bz = gz; - if (psc(adims) == 3) + if (psc(adims) == 4) { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + bx = yh; + by = yd; + bz = gz; + } + + if (psc(bdims) == 2) + { + // type 27 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + bx = 0; + by = 0; + bz = 0; + } + else + { + // type 26 + bx = gz; + by = 0; + bz = 0; + } + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 23 + ax = yh; + ay = yd; + az = gz; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -174,6 +230,17 @@ void main() } else if (psc(adims) == 2) { + if 
(psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ax = yd; + ay = gz; + az = 0; + } + if (psc(bdims) == 3) { // type 14 @@ -204,13 +271,21 @@ void main() { if (psc(aw) == 1) { - // type 2 3 4 + // type 2 3 4 20 ax = 0; ay = 0; az = 0; } else { + if (psc(bdims) == 4) + { + // type 21 + ax = gz; + ay = 0; + az = 0; + } + if (psc(bdims) == 3) { // type 9 @@ -248,8 +323,53 @@ void main() int ai; int bi; - if (psc(adims) == 3) + if (psc(adims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + if (psc(bdims) == 3) + { + // type 28 + ai = gi; + bi = gz * psc(bcstep) + yd * psc(bw) + yh; + } + + if (psc(bdims) == 2) + { + // type 27 + ai = gi; + bi = gz * psc(bw) + yd; + } + + if (psc(bdims) == 1) + { + if (psc(bw) == 1) + { + // type 25 + ai = gi; + bi = 0; + } + else + { + // type 26 + ai = gi; + bi = gz; + } + } + } + else if (psc(adims) == 3) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 23 + ai = gz * psc(acstep) + yd * psc(aw) + yh; + bi = gi; + } + if (psc(bdims) == 3) { if (psc(bw) == 1 && psc(bh) == 1) @@ -311,6 +431,16 @@ void main() } else if (psc(adims) == 2) { + if (psc(bdims) == 4) + { + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // type 22 + ai = gz * psc(aw) + yd; + bi = gi; + } + if (psc(bdims) == 3) { // type 14 @@ -327,6 +457,13 @@ void main() } else if (psc(adims) == 1) { + if (psc(bdims) == 4) + { + // type 21 + ai = gz; + bi = gi; + } + if (psc(bdims) == 3) { // type 9 diff --git a/src/layer/vulkan/unaryop_vulkan.cpp b/src/layer/vulkan/unaryop_vulkan.cpp index b8e09dba1..e4fe8e9f0 100644 --- a/src/layer/vulkan/unaryop_vulkan.cpp +++ b/src/layer/vulkan/unaryop_vulkan.cpp @@ -35,7 +35,7 @@ int UnaryOp_vulkan::create_pipeline(const Option& opt) int elempack = 1; if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 
4 : 1; if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1; - if (shape.dims == 3) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1; + if (shape.dims == 3 || shape.dims == 4) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1; size_t elemsize; if (opt.use_fp16_storage) @@ -55,12 +55,13 @@ int UnaryOp_vulkan::create_pipeline(const Option& opt) if (shape.dims == 1) shape_packed = Mat(shape.w / elempack, (void*)0, elemsize, elempack); if (shape.dims == 2) shape_packed = Mat(shape.w, shape.h / elempack, (void*)0, elemsize, elempack); if (shape.dims == 3) shape_packed = Mat(shape.w, shape.h, shape.c / elempack, (void*)0, elemsize, elempack); + if (shape.dims == 4) shape_packed = Mat(shape.w, shape.h, shape.d, shape.c / elempack, (void*)0, elemsize, elempack); std::vector specializations(1 + 5); specializations[0].i = op_type; specializations[1 + 0].i = shape_packed.dims; specializations[1 + 1].i = shape_packed.w; - specializations[1 + 2].i = shape_packed.h; + specializations[1 + 2].i = shape_packed.h * shape_packed.d; specializations[1 + 3].i = shape_packed.c; specializations[1 + 4].i = shape_packed.cstep; @@ -83,6 +84,12 @@ int UnaryOp_vulkan::create_pipeline(const Option& opt) local_size_xyz.h = std::min(4, shape_packed.h); local_size_xyz.c = std::min(4, shape_packed.c); } + if (shape_packed.dims == 4) + { + local_size_xyz.w = std::min(4, shape_packed.w); + local_size_xyz.h = std::min(4, shape_packed.h * shape_packed.d); + local_size_xyz.c = std::min(4, shape_packed.c); + } // pack1 if (shape.dims == 0 || elempack == 1) @@ -135,7 +142,7 @@ int UnaryOp_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, cons std::vector constants(5); constants[0].i = bottom_top_blob.dims; constants[1].i = bottom_top_blob.w; - constants[2].i = bottom_top_blob.h; + constants[2].i = bottom_top_blob.h * bottom_top_blob.d; constants[3].i = bottom_top_blob.c; 
constants[4].i = bottom_top_blob.cstep; @@ -159,7 +166,7 @@ int UnaryOp_vulkan::forward_inplace(VkImageMat& bottom_top_blob, VkCompute& cmd, std::vector constants(5); constants[0].i = bottom_top_blob.dims; constants[1].i = bottom_top_blob.w; - constants[2].i = bottom_top_blob.h; + constants[2].i = bottom_top_blob.h * bottom_top_blob.d; constants[3].i = bottom_top_blob.c; constants[4].i = 0; //bottom_top_blob.cstep; diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index 24b5f11f0..0ef73ee8a 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -44,20 +44,203 @@ static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr1); + __m256 _outp = op(_p, _p1); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = 
c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + __m256 _b0 = _mm256_loadu_ps(ptr1); + for (int x = 0; x < w; x++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op(_p, _b0); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } + + ptr1 += 8; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + __m256 _b0 = _mm256_loadu_ps(ptr1); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op(_p, _b0); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + ptr1 += 8; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + __m256 _b0 = _mm256_set1_ps(b[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op(_p, _b0); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + __m256 _b0 = _mm256_loadu_ps((const float*)b + q * 8); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op(_p, _b0); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for 
num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + __m256 _a0 = _mm256_loadu_ps(ptr); + for (int x = 0; x < w1; x++) + { + __m256 _p = _mm256_loadu_ps(ptr1); + __m256 _outp = op(_a0, _p); + _mm256_storeu_ps(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + + ptr += 8; + } + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -406,6 +589,42 @@ static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + __m256 _a0 = _mm256_loadu_ps(ptr); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + __m256 _p = _mm256_loadu_ps(ptr1); + __m256 _outp = op(_a0, _p); + _mm256_storeu_ps(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + ptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -514,6 +733,33 @@ static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + __m256 _a0 = _mm256_set1_ps(a[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + __m256 _p1 = _mm256_loadu_ps(ptr1); + __m256 _outp = op(_a0, _p1); + 
_mm256_storeu_ps(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -586,6 +832,33 @@ static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + __m256 _a0 = _mm256_loadu_ps((const float*)a + q * 8); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + __m256 _p1 = _mm256_loadu_ps(ptr1); + __m256 _outp = op(_a0, _p1); + _mm256_storeu_ps(outptr, _outp); + ptr1 += 8; + outptr += 8; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -693,8 +966,9 @@ static int binary_op_scalar_inplace_pack8(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; __m256 _b = _mm256_set1_ps(b); @@ -795,20 +1069,203 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; size_t elemsize = a.elemsize; int elempack = a.elempack; int w1 = b.w; int h1 = b.h; + int d1 = b.d; int channels1 = b.c; - int size1 = w1 * h1; + int size1 = w1 * h1 * d1; size_t elemsize1 = b.elemsize; int elempack1 = b.elempack; - if (a.dims == 3) + if (a.dims == 4) { + if (b.dims == 4) + { + // type 29 + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 
_outp = op(_p, _p1); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + + c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + // type 28 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + for (int y = 0; y < h; y++) + { + __m128 _b0 = _mm_loadu_ps(ptr1); + for (int x = 0; x < w; x++) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op(_p, _b0); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } + + ptr1 += 4; + } + } + } + + return 0; + } + + if (b.dims == 2) + { + // type 27 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d; z++) + { + __m128 _b0 = _mm_loadu_ps(ptr1); + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op(_p, _b0); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + ptr1 += 4; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1 && elempack1 == 1) + { + // type 25 + __m128 _b0 = _mm_set1_ps(b[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op(_p, _b0); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + + // type 26 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + __m128 _b0 = _mm_loadu_ps((const float*)b + q * 4); + float* outptr = 
c.channel(q); + + for (int i = 0; i < size; i++) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op(_p, _b0); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } + } + + return 0; + } + } + else if (a.dims == 3) + { + if (b.dims == 4) + { + // type 23 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + for (int y = 0; y < h1; y++) + { + __m128 _a0 = _mm_loadu_ps(ptr); + for (int x = 0; x < w1; x++) + { + __m128 _p = _mm_loadu_ps(ptr1); + __m128 _outp = op(_a0, _p); + _mm_storeu_ps(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + + ptr += 4; + } + } + } + + return 0; + } + if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) @@ -1114,7 +1571,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (b.w == 1 && elempack1 == 1) { // type 16 - __m128 _b0 = _mm_set1_ps(((const float*)b)[0]); + __m128 _b0 = _mm_set1_ps(b[0]); #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { @@ -1157,6 +1614,42 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } else if (a.dims == 2) { + if (b.dims == 4) + { + // type 22 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr = a.row(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int z = 0; z < d1; z++) + { + __m128 _a0 = _mm_loadu_ps(ptr); + for (int y = 0; y < h1; y++) + { + for (int x = 0; x < w1; x++) + { + __m128 _p = _mm_loadu_ps(ptr1); + __m128 _outp = op(_a0, _p); + _mm_storeu_ps(outptr, _outp); + 
ptr1 += 4; + outptr += 4; + } + } + + ptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 14 @@ -1223,7 +1716,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (b.w == 1 && elempack1 == 1) { // type 11 - __m128 _b0 = _mm_set1_ps(((const float*)b)[0]); + __m128 _b0 = _mm_set1_ps(b[0]); const float* ptr = a; float* outptr = c; for (int i = 0; i < size; i++) @@ -1265,6 +1758,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt { if (a.w == 1 && elempack == 1) { + if (b.dims == 4) + { + // type 20 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + __m128 _a0 = _mm_set1_ps(a[0]); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 _outp = op(_a0, _p1); + _mm_storeu_ps(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 4 @@ -1272,7 +1792,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (c.empty()) return -100; - __m128 _a0 = _mm_set1_ps(((const float*)a)[0]); + __m128 _a0 = _mm_set1_ps(a[0]); #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels1; q++) { @@ -1299,7 +1819,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (c.empty()) return -100; - __m128 _a0 = _mm_set1_ps(((const float*)a)[0]); + __m128 _a0 = _mm_set1_ps(a[0]); const float* ptr1 = b; float* outptr = c; for (int i = 0; i < size1; i++) @@ -1321,7 +1841,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (c.empty()) return -100; - __m128 _a0 = _mm_set1_ps(((const float*)a)[0]); + __m128 _a0 = _mm_set1_ps(a[0]); const float* ptr1 = b; float* outptr = c; for (int i = 0; i < w1; 
i++) @@ -1337,6 +1857,33 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt } } + if (b.dims == 4) + { + // type 21 + c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); + if (c.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels1; q++) + { + __m128 _a0 = _mm_loadu_ps((const float*)a + q * 4); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 _outp = op(_a0, _p1); + _mm_storeu_ps(outptr, _outp); + ptr1 += 4; + outptr += 4; + } + } + + return 0; + } + if (b.dims == 3) { // type 9 @@ -1402,7 +1949,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (b.w == 1 && elempack1 == 1) { // type 6 - __m128 _b0 = _mm_set1_ps(((const float*)b)[0]); + __m128 _b0 = _mm_set1_ps(b[0]); const float* ptr = a; float* outptr = c; for (int i = 0; i < w; i++) @@ -1444,8 +1991,9 @@ static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt) int w = a.w; int h = a.h; + int d = a.d; int channels = a.c; - int size = w * h; + int size = w * h * d; __m128 _b = _mm_set1_ps((float)b); @@ -1574,10 +2122,10 @@ int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack8(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack8(bottom_blob1, bottom_blob, top_blob, opt); } #endif // __AVX__ @@ -1605,10 +2153,10 @@ int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob, 
bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); } #endif // __SSE2__ diff --git a/src/net.cpp b/src/net.cpp index a64465a95..0fb50f78f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -797,7 +797,7 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio int elemcount = 0; if (dims == 1) elemcount = bottom_blob.elempack * bottom_blob.w; if (dims == 2) elemcount = bottom_blob.elempack * bottom_blob.h; - if (dims == 3) elemcount = bottom_blob.elempack * bottom_blob.c; + if (dims == 3 || dims == 4) elemcount = bottom_blob.elempack * bottom_blob.c; int elembits = bottom_blob.elembits(); diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index e98b08fb9..44e3d1b36 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -50,7 +50,7 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) int ret = test_layer("BinaryOp", pd, weights, ab); if (ret != 0) { - fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d) b.dims=%d b=(%d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.c, b.dims, b.w, b.h, b.c, op_type); + fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c, op_type); } return ret; @@ -76,7 +76,7 @@ static int test_binaryop(const ncnn::Mat& _a, float b) int ret = test_layer("BinaryOp", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d) b=%f op_type=%d\n", a.dims, a.w, a.h, a.c, b, op_type); + fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b=%f op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b, op_type); } return ret; @@ -234,6 +234,86 @@ static int test_binaryop_19() || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); } 
+static int test_binaryop_20() +{ + return 0 + || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) + || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) + || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); +} + +static int test_binaryop_21() +{ + return 0 + || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) + || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) + || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); +} + +static int test_binaryop_22() +{ + return 0 + || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) + || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) + || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); +} + +static int test_binaryop_23() +{ + return 0 + || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) + || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) + || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); +} + +static int test_binaryop_24() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) + || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) + || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); +} + +static int test_binaryop_25() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) + || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) + || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); +} + +static int test_binaryop_26() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) + || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) + || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); +} + +static int test_binaryop_27() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) + || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) + || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); +} + +static int test_binaryop_28() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) + || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) + || 
test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); +} + +static int test_binaryop_29() +{ + return 0 + || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) + || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) + || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); +} + static int test_binaryop_s1() { return 0 @@ -324,6 +404,16 @@ int main() || test_binaryop_17() || test_binaryop_18() || test_binaryop_19() + || test_binaryop_20() + || test_binaryop_21() + || test_binaryop_22() + || test_binaryop_23() + || test_binaryop_24() + || test_binaryop_25() + || test_binaryop_26() + || test_binaryop_27() + || test_binaryop_28() + || test_binaryop_29() || test_binaryop_s1() || test_binaryop_s2() || test_binaryop_s3() diff --git a/tests/test_unaryop.cpp b/tests/test_unaryop.cpp index 42e7115ba..473ab2cab 100644 --- a/tests/test_unaryop.cpp +++ b/tests/test_unaryop.cpp @@ -49,13 +49,21 @@ static int test_unaryop(const ncnn::Mat& _a) int ret = test_layer("UnaryOp", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_unaryop failed a.dims=%d a=(%d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.c, op_type); + fprintf(stderr, "test_unaryop failed a.dims=%d a=(%d %d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, op_type); } return ret; } static int test_unaryop_0() +{ + return 0 + || test_unaryop(RandomMat(11, 3, 2, 16)) + || test_unaryop(RandomMat(10, 2, 2, 12)) + || test_unaryop(RandomMat(6, 1, 5, 13)); +} + +static int test_unaryop_1() { return 0 || test_unaryop(RandomMat(11, 7, 16)) @@ -63,7 +71,7 @@ static int test_unaryop_0() || test_unaryop(RandomMat(6, 5, 13)); } -static int test_unaryop_1() +static int test_unaryop_2() { return 0 || test_unaryop(RandomMat(12, 16)) @@ -71,7 +79,7 @@ static int test_unaryop_1() || test_unaryop(RandomMat(14, 15)); } -static int test_unaryop_2() +static int test_unaryop_3() { return 0 || test_unaryop(RandomMat(128)) @@ -88,7 +96,8 @@ int main() int ret = 0 || test_unaryop_0() || 
test_unaryop_1() - || test_unaryop_2(); + || test_unaryop_2() + || test_unaryop_3(); if (ret != 0) return ret; diff --git a/tests/testutil.h b/tests/testutil.h index b4325da52..9e3bb6327 100644 --- a/tests/testutil.h +++ b/tests/testutil.h @@ -457,7 +457,7 @@ int test_layer_cpu(int typeindex, const ncnn::ParamDict& pd, const std::vector