diff --git a/docs/developer-guide/binaryop-broadcasting.md b/docs/developer-guide/binaryop-broadcasting.md index 5f69f9ca6..decdaa3b5 100644 --- a/docs/developer-guide/binaryop-broadcasting.md +++ b/docs/developer-guide/binaryop-broadcasting.md @@ -6,47 +6,48 @@ C = BinaryOp(A, B) shape notation convention is [w], [w,h], [w,h,c], [w,h,d,c] +* binaryop with scalar and scalar-like + +|type|A|B|C| +|---|---|---|---| +|0|[2]|scalar / [1]|[2]| +|1|[2,3]|scalar / [1] / [1,1]|[2,3]| +|2|[2,3,4]|scalar / [1] / [1,1] / [1,1,1]|[2,3,4]| +|3|[2,3,4,5]|scalar / [1] / [1,1] / [1,1,1] / [1,1,1,1]|[2,3,4,5]| + +* no broadcast + +|type|A|B|C| +|---|---|---|---| +|4|[2]|[2]|[2]| +|5|[2,3]|[2,3]|[2,3]| +|6|[2,3,4]|[2,3,4]|[2,3,4]| +|7|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]| + +* broadcast B for inner axis + |type|A|B|C| |---|---|---|---| -|1|[1]|scalar|[1]| -|2|[1]|[2]|[2]| -|3|[1]|[2,3]|[2,3]| -|4|[1]|[2,3,4]|[2,3,4]| -|5|[2]|scalar|[2]| -|6|[2]|[1]|[2]| -|7|[2]|[2]|[2]| -|8|[3]|[2,3]|[2,3]| -|9|[4]|[2,3,4]|[2,3,4]| -|10|[2,3]|scalar|[2,3]| -|11|[2,3]|[1]|[2,3]| -|12|[2,3]|[3]|[2,3]| -|13|[2,3]|[2,3]|[2,3]| -|14|[3,4]|[2,3,4]|[2,3,4]| -|15|[2,3,4]|scalar|[2,3,4]| -|16|[2,3,4]|[1]|[2,3,4]| -|17|[2,3,4]|[4]|[2,3,4]| -|18|[2,3,4]|[3,4]|[2,3,4]| -|19|[2,3,4]|[2,3,4]|[2,3,4]| -|20|[1]|[2,3,4,5]|[2,3,4,5]| -|21|[5]|[2,3,4,5]|[2,3,4,5]| -|22|[4,5]|[2,3,4,5]|[2,3,4,5]| -|23|[3,4,5]|[2,3,4,5]|[2,3,4,5]| -|24|[2,3,4,5]|scalar|[2,3,4,5]| -|25|[2,3,4,5]|[1]|[2,3,4,5]| -|26|[2,3,4,5]|[5]|[2,3,4,5]| -|27|[2,3,4,5]|[4,5]|[2,3,4,5]| -|28|[2,3,4,5]|[3,4,5]|[2,3,4,5]| -|29|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]| - -some special broadcasting rule exists for model compatibility +|8|[2,3]|[3] / [1,3]|[2,3]| +|9|[2,3,4]|[4] / [1,1,4]|[2,3,4]| +|10|[2,3,4]|[3,4] / [1,3,4]|[2,3,4]| +|11|[2,3,4,5]|[5] / [1,1,1,5]|[2,3,4,5]| +|12|[2,3,4,5]|[4,5] / [1,1,4,5]|[2,3,4,5]| +|13|[2,3,4,5]|[3,4,5] / [1,3,4,5]|[2,3,4,5]| + +* broadcast B for outer axis + +|type|A|B|C| +|---|---|---|---| +|14|[2,3]|[2,1]|[2,3]| +|15|[2,3,4]|[2,1,1]|[2,3,4]| +|16|[2,3,4]|[2,3,1]|[2,3,4]| +|17|[2,3,4,5]|[2,1,1,1]|[2,3,4,5]| +|18|[2,3,4,5]|[2,3,1,1]|[2,3,4,5]| +|19|[2,3,4,5]|[2,3,4,1]|[2,3,4,5]| + +* some special broadcasting rule exists for model compatibility |special type|A|B|C| |---|---|---|---| -|1|[2,3,4]|[1,1,4]|[2,3,4]| -|2|[2,3,4]|[2,3,1]|[2,3,4]| -|3|[1,1,4]|[2,3,4]|[2,3,4]| -|4|[2,3,1]|[2,3,4]|[2,3,4]| -|5|[2,3,4]|[1,3,4]|[2,3,4]| -|6|[2,3,4]|[2,1,4]|[2,3,4]| -|7|[1,3,4]|[2,3,4]|[2,3,4]| -|8|[2,1,4]|[2,3,4]|[2,3,4]| +|20|[2,3,4]|[2,1,4]|[2,3,4]| diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index fa7ac4263..9a49070a1 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -161,6 +161,7 @@ Operation type: - 6 = POW - 7 = RSUB - 8 = RDIV +- 9 = RPOW # BNLL ``` diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index aa4e73202..932d1dc82 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -41,93 +41,36 @@ BinaryOp_arm::BinaryOp_arm() } template -static int binary_op_2_3_4_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float a0 = a[0]; - const float* ptr = b.channel(q); - float* outptr = c.channel(q); - - int i = 0; -#if __ARM_NEON - float32x4_t _a0 = vdupq_n_f32(a0); - for (; i + 3 < size; i += 4) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_a0, _p); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } -#endif // __ARM_NEON - for (; i < size; i++) - { - *outptr = op(a0, *ptr); - ptr += 1; - outptr += 1; - } - } - - return 0; -} - -template -static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - const float b0 = b[0]; float* outptr = c.channel(q); int i = 0; #if __ARM_NEON - float32x4_t _b0 = vdupq_n_f32(b0); + float32x4_t _b = vdupq_n_f32(b); for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); + _p = op(_p, _b); + vst1q_f32(outptr, _p); ptr += 4; outptr += 4; } #endif // __ARM_NEON for (; i < size; i++) { - *outptr = op(*ptr, b0); - ptr += 1; - outptr += 1; + *outptr = op(*ptr, b); + ptr++; + outptr++; } } @@ -135,21 +78,12 @@ static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option } template -static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -183,12 +117,8 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option return 0; } -#if __ARM_NEON -// broadcasting rule -// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - template -static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -196,706 +126,362 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.row(y); + float* outptr = c.row(y); - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = b[y]; +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? vld1q_f32((const float*)b + y * 4) : vdupq_n_f32(_b); +#endif // __ARM_NEON - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = vld1q_f32(ptr1); - for (int x = 0; x < w; x++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } + const int size = w * elempack; - ptr1 += 4; - } - } + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b_128); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } - - return 0; } + } - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - for (int z = 0; z < d; z++) - { - float32x4_t _b0 = vld1q_f32(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } - } + const float _b = b[q]; +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? vld1q_f32((const float*)b + q * 4) : vdupq_n_f32(_b); +#endif // __ARM_NEON - ptr1 += 4; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b_128); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + const int size = w * elempack; + + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - float32x4_t _b0 = vld1q_f32((const float*)b + q * 4); - float* outptr = c.channel(q); + const float _b = ptr1[y]; +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? vld1q_f32((const float*)ptr1 + y * 4) : vdupq_n_f32(_b); +#endif // __ARM_NEON - for (int i = 0; i < size; i++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); + float32x4_t _outp = op(_p, _b_128); vst1q_f32(outptr, _outp); ptr += 4; outptr += 4; } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } - - return 0; } } - else if (a.dims == 3) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * h * elempack; + + for (int z = 0; z < d; z++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = ptr1[z]; +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? vld1q_f32((const float*)ptr1 + z * 4) : vdupq_n_f32(_b); +#endif // __ARM_NEON - for (int z = 0; z < d1; z++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { - for (int y = 0; y < h1; y++) - { - float32x4_t _a0 = vld1q_f32(ptr); - for (int x = 0; x < w1; x++) - { - float32x4_t _p = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _outp = op(_p, _b_128); + vst1q_f32(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + const int size = w * elempack; + + for (int z = 0; z < d; z++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q).row(z); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - float32x4_t _b0 = vld1q_f32(b0); - for (int i = 0; i < size; i++) + const float _b = ptr1[y]; +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? vld1q_f32((const float*)ptr1 + y * 4) : vdupq_n_f32(_b); +#endif // __ARM_NEON + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); + float32x4_t _outp = op(_p, _b_128); vst1q_f32(outptr, _outp); ptr += 4; outptr += 4; } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } - - return 0; } + } + } - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _p1 = vld1q_dup_f32(ptr1); - float32x4_t _outp = op(_p, _p1); - vst1q_f32(outptr, _outp); - ptr += 4; - ptr1 += 1; - outptr += 4; - } - } +template +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - return 0; - } + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) +#if __ARM_NEON + if (elempack == 4) + { + for (int x = 0; x < w; x++) { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - float32x4_t _a0 = vld1q_f32(a0); - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p1); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _b = vdupq_n_f32(*ptr1); + float32x4_t _outp = op(_p, _b); + vst1q_f32(outptr, _outp); + ptr += 4; + ptr1 += 1; + outptr += 4; } - - return 0; } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) +#endif // __ARM_NEON + if (elempack == 1) { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int x = 0; x < w; x++) { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - float32x4_t _p = vld1q_dup_f32(ptr); - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_p, _p1); - vst1q_f32(outptr, _outp); - ptr += 1; - ptr1 += 4; - outptr += 4; - } + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - return 0; } + } + } - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + if (a.dims == 3 || a.dims == 4) + { + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) + { + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + int y1 = std::min(y, b.h - 1); + + const float* ptr1 = b.depth(z1).row(y1); - for (int y = 0; y < h; y++) +#if __ARM_NEON + if (elempack == 4) { - float32x4_t _p1 = vld1q_f32(ptr1 + y * 4); for (int x = 0; x < w; x++) { float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _p1); + float32x4_t _b = vdupq_n_f32(*ptr1); + float32x4_t _outp = op(_p, _b); vst1q_f32(outptr, _outp); - ptr += 4; + ptr1 += 1; outptr += 4; } } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) +#endif // __ARM_NEON + if (elempack == 1) { for (int x = 0; x < w; x++) { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _p1 = vld1q_f32(ptr1 + x * 4); - float32x4_t _outp = op(_p, _p1); - vst1q_f32(outptr, _outp); - - ptr += 4; - outptr += 4; + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } } } - - return 0; } + } + } - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + return 0; +} - for (int y = 0; y < h1; y++) - { - float32x4_t _p = vld1q_f32(ptr + y * 4); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_p, _p1); - vst1q_f32(outptr, _outp); +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += 4; - outptr += 4; - } - } - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + for (int y = 0; y < h; y++) + { + const float* ptr1 = b.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const int size = w * elempack; - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float32x4_t _p = vld1q_f32(ptr + x * 4); - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_p, _p1); - vst1q_f32(outptr, _outp); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = vld1q_f32(ptr1); - for (int x = 0; x < w; x++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float32x4_t _b0 = vld1q_f32((const float*)b + q * 4); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - float32x4_t _a0 = vld1q_f32(ptr); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float32x4_t _p = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float32x4_t _a0 = vld1q_f32(ptr); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p1); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = vld1q_f32(ptr1); - for (int x = 0; x < w; x++) - { - float32x4_t _p = vld1q_f32(ptr); - float32x4_t _outp = op(_p, _b0); - vst1q_f32(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float32x4_t _a0 = vld1q_f32((const float*)a + q * 4); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p1); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float32x4_t _a0 = vld1q_f32((const float*)a + q * 4); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p1); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h1; y++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { - float32x4_t _a0 = vld1q_f32(ptr); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = vld1q_f32(ptr1); - float32x4_t _outp = op(_a0, _p1); - vst1q_f32(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - + float32x4_t _p = vld1q_f32(ptr); + float32x4_t _p1 = vld1q_f32(ptr1); + float32x4_t _outp = op(_p, _p1); + vst1q_f32(outptr, _outp); ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) +#endif // __ARM_NEON + for (; i < size; i++) { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); } } return 0; } -#endif // __ARM_NEON template static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -916,7 +502,7 @@ static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) for (; i < size; i++) { *ptr = op(*ptr, b); - ptr += 1; + ptr++; } } @@ -968,6 +554,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vdivq_f32(y, x)) #else MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x)) #endif +MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x)) // *INDENT-ON* // clang-format on @@ -975,6 +562,127 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x)) } // namespace BinaryOp_arm_functor +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} + +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} + +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); + + using namespace BinaryOp_arm_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} + +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} + +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; + + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here + return 0; +} + +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; +} + int BinaryOp_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { int elembits = std::max(bottom_blobs[0].elembits(), bottom_blobs[1].elembits()); @@ -989,48 +697,56 @@ int BinaryOp_arm::forward(const std::vector& bottom_blobs, std::vector return forward_bf16s(bottom_blobs, top_blobs, opt); #endif - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - -#if __ARM_NEON - using namespace BinaryOp_arm_functor; + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (elempack == 4 || elempack1 == 4) + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) { - if (op_type == Operation_ADD) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); + } - if (op_type == Operation_MIN) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) + { + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); } -#endif // __ARM_NEON - return BinaryOp::forward(bottom_blobs, top_blobs, opt); + return 0; } int BinaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -1049,78 +765,52 @@ int BinaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const using namespace BinaryOp_arm_functor; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_SUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); return 0; } #if NCNN_BF16 template -static int binary_op_2_3_4_20_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar_bf16s(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - const float a0 = bfloat16_to_float32(((const unsigned short*)a)[0]); - const unsigned short* ptr = b.channel(q); + const unsigned short* ptr = a.channel(q); unsigned short* outptr = c.channel(q); int i = 0; #if __ARM_NEON - float32x4_t _a0 = vdupq_n_f32(a0); + float32x4_t _b = vdupq_n_f32(b); for (; i + 3 < size; i += 4) { float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_a0, _p); - vst1_u16(outptr, float2bfloat(_outp)); + _p = op(_p, _b); + vst1_u16(outptr, float2bfloat(_p)); ptr += 4; outptr += 4; } #endif // __ARM_NEON for (; i < size; i++) { - *outptr = float32_to_bfloat16(op(a0, bfloat16_to_float32(*ptr))); - ptr += 1; - outptr += 1; + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), b)); + ptr++; + outptr++; } } @@ -1128,45 +818,38 @@ static int binary_op_2_3_4_20_bf16s(const Mat& a, const Mat& b, Mat& c, const Op } template -static int binary_op_6_11_16_25_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const unsigned short* ptr = a.channel(q); - const float b0 = bfloat16_to_float32(((const unsigned short*)b)[0]); + const unsigned short* ptr1 = b.channel(q); unsigned short* outptr = c.channel(q); int i = 0; #if __ARM_NEON - float32x4_t _b0 = vdupq_n_f32(b0); for (; i + 3 < size; i += 4) { float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); + float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); + float32x4_t _outp = op(_p, _p1); vst1_u16(outptr, float2bfloat(_outp)); ptr += 4; + ptr1 += 4; outptr += 4; } #endif // __ARM_NEON for (; i < size; i++) { - *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), b0)); + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); ptr += 1; + ptr1 += 1; outptr += 1; } } @@ -1175,7 +858,7 @@ static int binary_op_6_11_16_25_bf16s(const Mat& a, const Mat& b, Mat& c, const } template -static int binary_op_7_13_19_29_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -1184,1362 +867,348 @@ static int binary_op_7_13_19_29_bf16s(const Mat& a, const Mat& b, Mat& c, const int d = a.d; int channels = a.c; int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + if (a.dims == 2 && b.dims == 1) { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const unsigned short* ptr = a.row(y); + unsigned short* outptr = c.row(y); - int i = 0; + const float _b = bfloat16_to_float32(((const unsigned short*)b)[y]); #if __ARM_NEON - for (; i + 3 < size; i += 4) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - ptr1 += 4; - outptr += 4; - } -#endif // __ARM_NEON - for (; i < size; i++) - { - *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); - ptr += 1; - ptr1 += 1; - outptr += 1; - } - } - - return 0; -} - -#if __ARM_NEON -template -static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - int elempack = a.elempack; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = bfloat2float(vld1_u16(ptr1)); - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.row(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - float32x4_t _b0 = bfloat2float(vld1_u16(ptr1)); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - float32x4_t _b0 = bfloat2float(vld1_u16((const unsigned short*)b + q * 4)); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - float32x4_t _a0 = bfloat2float(vld1_u16(ptr)); - for (int x = 0; x < w1; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - unsigned short* outptr = c.channel(q); - const unsigned short* b0 = b.channel(q); - float32x4_t _b0 = bfloat2float(vld1_u16(b0)); - for (int i = 0; i < size; i++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b; - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _p1 = vdupq_n_f32(bfloat16_to_float32(*ptr1)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - ptr1 += 1; - outptr += 4; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* a0 = a.channel(q); - unsigned short* outptr = c.channel(q); - const unsigned short* ptr1 = b.channel(q); - float32x4_t _a0 = bfloat2float(vld1_u16(a0)); - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a; - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - float32x4_t _p = vdupq_n_f32(bfloat16_to_float32(*ptr)); - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 1; - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1 + y * 4)); - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1 + x * 4)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr + y * 4)); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr + x * 4)); - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_p, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.row(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = bfloat2float(vld1_u16(ptr1)); - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - float32x4_t _b0 = bfloat2float(vld1_u16((const unsigned short*)b + q * 4)); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.row(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - float32x4_t _a0 = bfloat2float(vld1_u16(ptr)); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.row(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float32x4_t _a0 = bfloat2float(vld1_u16(ptr)); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } - - // type 12 - const unsigned short* ptr = a; - const unsigned short* ptr1 = b; - unsigned short* outptr = c; - - for (int y = 0; y < h; y++) - { - float32x4_t _b0 = bfloat2float(vld1_u16(ptr1)); - for (int x = 0; x < w; x++) - { - float32x4_t _p = bfloat2float(vld1_u16(ptr)); - float32x4_t _outp = op(_p, _b0); - vst1_u16(outptr, float2bfloat(_outp)); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_bf16s(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float32x4_t _a0 = bfloat2float(vld1_u16((const unsigned short*)a + q * 4)); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float32x4_t _a0 = bfloat2float(vld1_u16((const unsigned short*)a + q * 4)); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const unsigned short* ptr = a; - const unsigned short* ptr1 = b; - unsigned short* outptr = c; - - for (int y = 0; y < h1; y++) - { - float32x4_t _a0 = bfloat2float(vld1_u16(ptr)); - for (int x = 0; x < w1; x++) - { - float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); - float32x4_t _outp = op(_a0, _p1); - vst1_u16(outptr, float2bfloat(_outp)); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - } - - return 0; -} -#endif // __ARM_NEON - -template -static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - const float b0 = bfloat16_to_float32(ptr1[y]); - for (int x = 0; x < w; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); - } - - ptr += w; - outptr += w; - } - - ptr1 += h; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.row(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - const float b0 = bfloat16_to_float32(ptr1[z]); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); - } - - ptr += w; - outptr += w; - } - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1) - { - // type 25 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const float b0 = bfloat16_to_float32(((const unsigned short*)b)[q]); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), b0)); - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - const float a0 = bfloat16_to_float32(ptr[y]); - for (int x = 0; x < w1; x++) - { - outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); - } - - ptr1 += w1; - outptr += w1; - } - - ptr += h1; - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* b0 = b.channel(q); - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(b0[0]))); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b; - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(ptr1[i]))); - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* a0 = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(a0[0]), bfloat16_to_float32(ptr1[i]))); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a; - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(ptr1[i]))); - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - const float b0 = bfloat16_to_float32(ptr1[y]); - for (int x = 0; x < w; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); - } - - ptr += w; - outptr += w; - } - } + float32x4_t _b_128 = (elempack == 4) ? bfloat2float(vld1_u16((const unsigned short*)b + y * 4)) : vdupq_n_f32(_b); +#endif // __ARM_NEON - return 0; - } + const int size = w * elempack; - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { - // special type 6 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), bfloat16_to_float32(ptr1[x]))); - } - - ptr += w; - outptr += w; - } - } - - return 0; + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b_128); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + outptr += 4; } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) +#endif // __ARM_NEON + for (; i < size; i++) { - // special type 7 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - const float a0 = bfloat16_to_float32(ptr[y]); - for (int x = 0; x < w1; x++) - { - outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); - } - - ptr1 += w1; - outptr += w1; - } - } - - return 0; + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), _b)); + ptr += 1; + outptr += 1; } + } + } - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), bfloat16_to_float32(ptr1[x]))); - } + const float _b = bfloat16_to_float32(((const unsigned short*)b)[q]); +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? bfloat2float(vld1_u16((const unsigned short*)b + q * 4)) : vdupq_n_f32(_b); +#endif // __ARM_NEON - ptr1 += w1; - outptr += w1; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b_128); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + outptr += 4; } - - // type 19 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#endif // __ARM_NEON + for (; i < size; i++) { - const unsigned short* ptr = a.channel(q); - const unsigned short* ptr1 = b.row(q); - unsigned short* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - const float b0 = bfloat16_to_float32(ptr1[y]); - for (int x = 0; x < w; x++) - { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); - } - - ptr += w; - outptr += w; - } + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), _b)); + ptr += 1; + outptr += 1; } - - return 0; } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1) - { - // type 16 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); - } + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.row(q); + unsigned short* outptr = c.channel(q); + + const int size = w * elempack; - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const unsigned short* ptr = a.channel(q); - const float b0 = bfloat16_to_float32(((const unsigned short*)b)[q]); - unsigned short* outptr = c.channel(q); + const float _b = bfloat16_to_float32(ptr1[y]); +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? bfloat2float(vld1_u16(ptr1 + y * 4)) : vdupq_n_f32(_b); +#endif // __ARM_NEON - for (int i = 0; i < size; i++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b_128); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) { - outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), b0)); + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), _b)); + ptr += 1; + outptr += 1; } } - - return 0; } } - else if (a.dims == 2) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b.row(q); + unsigned short* outptr = c.channel(q); + + const int size = w * h * elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const unsigned short* ptr = a.row(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); + const float _b = bfloat16_to_float32(ptr1[z]); +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? bfloat2float(vld1_u16(ptr1 + z * 4)) : vdupq_n_f32(_b); +#endif // __ARM_NEON - for (int z = 0; z < d1; z++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { - const float a0 = bfloat16_to_float32(ptr[z]); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); - } - - ptr1 += w1; - outptr += w1; - } + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b_128); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), _b)); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 14 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * elempack; + + for (int z = 0; z < d; z++) { - const unsigned short* ptr = a.row(q); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); + const unsigned short* ptr1 = b.channel(q).row(z); - for (int y = 0; y < h1; y++) + for (int y = 0; y < h; y++) { - const float a0 = bfloat16_to_float32(ptr[y]); - for (int x = 0; x < w1; x++) + const float _b = bfloat16_to_float32(ptr1[y]); +#if __ARM_NEON + float32x4_t _b_128 = (elempack == 4) ? bfloat2float(vld1_u16(ptr1 + y * 4)) : vdupq_n_f32(_b); +#endif // __ARM_NEON + + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _outp = op(_p, _b_128); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + outptr += 4; + } +#endif // __ARM_NEON + for (; i < size; i++) { - outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), _b)); + ptr += 1; + outptr += 1; } - - ptr1 += w1; - outptr += w1; } } - - return 0; } + } - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_bf16s(a, b, c, opt); - } +template +static int binary_op_broadcast_outer_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - if (b.dims == 1) + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr = a.row(y); + const unsigned short* ptr1 = b; + unsigned short* outptr = c.row(y); - if (b.w == 1) +#if __ARM_NEON + if (elempack == 4) { - // type 11 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); + for (int x = 0; x < w; x++) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _b = vdupq_n_f32(bfloat16_to_float32(*ptr1)); + float32x4_t _outp = op(_p, _b); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + ptr1 += 1; + outptr += 4; + } } - - // type 12 - const unsigned short* ptr = a; - unsigned short* outptr = c; - - for (int y = 0; y < h; y++) +#endif // __ARM_NEON + if (elempack == 1) { - const float b0 = bfloat16_to_float32(((const unsigned short*)b)[y]); for (int x = 0; x < w; x++) { - outptr[x] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[x]), b0)); + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - ptr += w; - outptr += w; } - - return 0; } } - else if (a.dims == 1) - { - if (a.w == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_bf16s(a, b, c, opt); - } - if (b.dims == 4) + if (a.dims == 3 || a.dims == 4) + { + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 21 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float a0 = bfloat16_to_float32(((const unsigned short*)a)[q]); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - outptr[i] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[i]))); + int y1 = std::min(y, b.h - 1); + + const unsigned short* ptr1 = b.depth(z1).row(y1); + +#if __ARM_NEON + if (elempack == 4) + { + for (int x = 0; x < w; x++) + { + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _b = vdupq_n_f32(bfloat16_to_float32(*ptr1)); + float32x4_t _outp = op(_p, _b); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + ptr1 += 1; + outptr += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - - return 0; } + } - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float a0 = bfloat16_to_float32(((const unsigned short*)a)[q]); - const unsigned short* ptr1 = b.channel(q); - unsigned short* outptr = c.channel(q); +template +static int binary_op_broadcast_20_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - for (int i = 0; i < size1; i++) - { - outptr[i] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[i]))); - } - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); - if (b.dims == 2) + for (int y = 0; y < h; y++) { - // type 8 - c.create(w1, h1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const unsigned short* ptr1 = b.channel(q); - const unsigned short* ptr1 = b; - unsigned short* outptr = c; + const int size = w * elempack; - for (int y = 0; y < h1; y++) + int i = 0; +#if __ARM_NEON + for (; i + 3 < size; i += 4) { - const float a0 = bfloat16_to_float32(((const unsigned short*)a)[y]); - for (int x = 0; x < w1; x++) - { - outptr[x] = float32_to_bfloat16(op(a0, bfloat16_to_float32(ptr1[x]))); - } - - ptr1 += w1; - outptr += w1; + float32x4_t _p = bfloat2float(vld1_u16(ptr)); + float32x4_t _p1 = bfloat2float(vld1_u16(ptr1)); + float32x4_t _outp = op(_p, _p1); + vst1_u16(outptr, float2bfloat(_outp)); + ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1) +#endif // __ARM_NEON + for (; i < size; i++) { - // type 6 - return binary_op_6_11_16_25_bf16s(a, b, c, opt); + *outptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), bfloat16_to_float32(*ptr1))); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29_bf16s(a, b, c, opt); } } @@ -2551,12 +1220,8 @@ static int binary_op_scalar_inplace_bf16s(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -2577,85 +1242,172 @@ static int binary_op_scalar_inplace_bf16s(Mat& a, float b, const Option& opt) for (; i < size; i++) { *ptr = float32_to_bfloat16(op(bfloat16_to_float32(*ptr), b)); - ptr += 1; + ptr++; } } return 0; } -int BinaryOp_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar_bf16s(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; + using namespace BinaryOp_arm_functor; - Mat& top_blob = top_blobs[0]; + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_bf16s(a, b, c, opt); + + // should never reach here + return 0; +} +static int binary_op_no_broadcast_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ using namespace BinaryOp_arm_functor; - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_bf16s(a, b, c, opt); + + // should never reach here + return 0; +} + +static int binary_op_broadcast_inner_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); -#if __ARM_NEON - if (elempack == 4 || elempack1 == 4) - { - if (op_type == Operation_ADD) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + using namespace BinaryOp_arm_functor; - if (op_type == Operation_SUB) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_bf16s(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MUL) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; - if (op_type == Operation_DIV) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_bf16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MAX) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_20_bf16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; - if (op_type == Operation_MIN) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_bf16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_bf16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_POW) - return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); +int BinaryOp_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_RSUB) - return binary_op_pack4_bf16s(bottom_blob1, bottom_blob, top_blob, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_RDIV) - return binary_op_pack4_bf16s(bottom_blob1, bottom_blob, top_blob, opt); + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar_bf16s(A, bfloat16_to_float32(((const unsigned short*)B)[0]), top_blob, op_type_r, opt); } -#endif // __ARM_NEON - if (elempack == 1 && elempack1 == 1) + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) { - if (op_type == Operation_ADD) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MIN) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_no_broadcast_bf16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner_bf16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_bf16s(bottom_blob1, bottom_blob, top_blob, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer_bf16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_bf16s(bottom_blob1, bottom_blob, top_blob, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20_bf16s(A, B, top_blob, op_type_r, opt); } return 0; @@ -2665,32 +1417,16 @@ int BinaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) { using namespace BinaryOp_arm_functor; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_SUB) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); return 0; } diff --git a/src/layer/arm/binaryop_arm_asimdhp.cpp b/src/layer/arm/binaryop_arm_asimdhp.cpp index ba00de5ce..d6e93fe99 100644 --- a/src/layer/arm/binaryop_arm_asimdhp.cpp +++ b/src/layer/arm/binaryop_arm_asimdhp.cpp @@ -25,105 +25,42 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template -static int binary_op_2_3_4_20_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar_fp16s(const Mat& a, __fp16 b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16 a0 = ((const __fp16*)a)[0]; - const __fp16* ptr = b.channel(q); - __fp16* outptr = c.channel(q); - - int i = 0; - float16x8_t _a0 = vdupq_n_f16(a0); - for (; i + 7 < size; i += 8) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_a0, _p); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - for (; i + 3 < size; i += 4) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(vget_low_f16(_a0), _p); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - for (; i < size; i++) - { - *outptr = op(a0, *ptr); - ptr += 1; - outptr += 1; - } - } - - return 0; -} - -template -static int binary_op_6_11_16_25_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[0]; __fp16* outptr = c.channel(q); int i = 0; - float16x8_t _b0 = vdupq_n_f16(b0); + float16x8_t _b = vdupq_n_f16(b); for (; i + 7 < size; i += 8) { float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); + _p = op(_p, _b); + vst1q_f16(outptr, _p); ptr += 8; outptr += 8; } for (; i + 3 < size; i += 4) { float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, vget_low_f16(_b0)); - vst1_f16(outptr, _outp); + _p = op(_p, vget_low_f16(_b)); + vst1_f16(outptr, _p); ptr += 4; outptr += 4; } for (; i < size; i++) { - *outptr = op(*ptr, b0); - ptr += 1; - outptr += 1; + *outptr = op(*ptr, b); + ptr++; + outptr++; } } @@ -131,21 +68,12 @@ static int binary_op_6_11_16_25_fp16s(const Mat& a, const Mat& b, Mat& c, const } template -static int binary_op_7_13_19_29_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -188,7 +116,7 @@ static int binary_op_7_13_19_29_fp16s(const Mat& a, const Mat& b, Mat& c, const } template -static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -196,2009 +124,404 @@ static int binary_op_pack8_fp16s(const Mat& a, const Mat& b, Mat& c, const Optio int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - float16x8_t _b0 = vld1q_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - float16x8_t _b0 = vld1q_f16(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - ptr1 += 8; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - float16x8_t _b0 = vld1q_f16((const __fp16*)b + q * 8); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - } - else if (a.dims == 3) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - float16x8_t _a0 = vld1q_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x8_t _p = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - __fp16* outptr = c.channel(q); - const __fp16* b0 = b.channel(q); - float16x8_t _b0 = vld1q_f16(b0); - for (int i = 0; i < size; i++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b; - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _p1 = vdupq_n_f16(*ptr1); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - ptr += 8; - ptr1 += 1; - outptr += 8; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* a0 = a.channel(q); - __fp16* outptr = c.channel(q); - const __fp16* ptr1 = b.channel(q); - float16x8_t _a0 = vld1q_f16(a0); - for (int i = 0; i < size1; i++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p1); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - float16x8_t _p = vdupq_n_f16(*ptr); - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - ptr += 1; - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float16x8_t _p1 = vld1q_f16(ptr1 + y * 8); - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - - ptr += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _p1 = vld1q_f16(ptr1 + x * 8); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - - ptr += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float16x8_t _p = vld1q_f16(ptr + y * 8); - for (int x = 0; x < w1; x++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - - ptr1 += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float16x8_t _p = vld1q_f16(ptr + x * 8); - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_p, _p1); - vst1q_f16(outptr, _outp); - - ptr1 += 8; - outptr += 8; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float16x8_t _b0 = vld1q_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - float16x8_t _b0 = vld1q_f16((const __fp16*)b + q * 8); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - float16x8_t _a0 = vld1q_f16(ptr); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float16x8_t _p = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - ptr += 8; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float16x8_t _a0 = vld1q_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p1); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 12 - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h; y++) - { - float16x8_t _b0 = vld1q_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x8_t _p = vld1q_f16(ptr); - float16x8_t _outp = op(_p, _b0); - vst1q_f16(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_fp16s(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float16x8_t _a0 = vld1q_f16((const __fp16*)a + q * 8); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p1); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float16x8_t _a0 = vld1q_f16((const __fp16*)a + q * 8); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p1); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h1; y++) - { - float16x8_t _a0 = vld1q_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x8_t _p1 = vld1q_f16(ptr1); - float16x8_t _outp = op(_a0, _p1); - vst1q_f16(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - } - - return 0; -} - -template -static int binary_op_pack4_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - int elempack = a.elempack; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - float16x4_t _b0 = vld1_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - float16x4_t _b0 = vld1_f16(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - float16x4_t _b0 = vld1_f16((const __fp16*)b + q * 4); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - float16x4_t _a0 = vld1_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x4_t _p = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - __fp16* outptr = c.channel(q); - const __fp16* b0 = b.channel(q); - float16x4_t _b0 = vld1_f16(b0); - for (int i = 0; i < size; i++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b; - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _p1 = vdup_n_f16(*ptr1); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - ptr += 4; - ptr1 += 1; - outptr += 4; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* a0 = a.channel(q); - __fp16* outptr = c.channel(q); - const __fp16* ptr1 = b.channel(q); - float16x4_t _a0 = vld1_f16(a0); - for (int i = 0; i < size1; i++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p1); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - float16x4_t _p = vdup_n_f16(*ptr); - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - ptr += 1; - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float16x4_t _p1 = vld1_f16(ptr1 + y * 4); - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _p1 = vld1_f16(ptr1 + x * 4); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float16x4_t _p = vld1_f16(ptr + y * 4); - for (int x = 0; x < w1; x++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float16x4_t _p = vld1_f16(ptr + x * 4); - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_p, _p1); - vst1_f16(outptr, _outp); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - float16x4_t _b0 = vld1_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - float16x4_t _b0 = vld1_f16((const __fp16*)b + q * 4); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - float16x4_t _a0 = vld1_f16(ptr); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - float16x4_t _p = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - float16x4_t _a0 = vld1_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p1); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 12 - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h; y++) - { - float16x4_t _b0 = vld1_f16(ptr1); - for (int x = 0; x < w; x++) - { - float16x4_t _p = vld1_f16(ptr); - float16x4_t _outp = op(_p, _b0); - vst1_f16(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_fp16s(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float16x4_t _a0 = vld1_f16((const __fp16*)a + q * 4); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p1); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - float16x4_t _a0 = vld1_f16((const __fp16*)a + q * 4); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p1); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h1; y++) - { - float16x4_t _a0 = vld1_f16(ptr); - for (int x = 0; x < w1; x++) - { - float16x4_t _p1 = vld1_f16(ptr1); - float16x4_t _outp = op(_a0, _p1); - vst1_f16(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - } - - return 0; -} - -template -static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - - ptr1 += h; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - const __fp16 b0 = ptr1[z]; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1) - { - // type 25 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[q]; - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } - - ptr += h1; - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* b0 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0[0]); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b; - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* a0 = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0[0], ptr1[i]); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } - - ptr += w; - outptr += w; - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } + const __fp16* ptr = a.row(y); + __fp16* outptr = c.row<__fp16>(y); - ptr1 += w1; - outptr += w1; - } - } + const __fp16 _b = ((const __fp16*)b)[y]; + float16x4_t _b_128 = (elempack == 4) ? vld1_f16((const __fp16*)b + y * 4) : vdup_n_f16(_b); + float16x8_t _b_256 = (elempack == 8) ? vld1q_f16((const __fp16*)b + y * 8) : vcombine_f16(_b_128, _b_128); - return 0; - } + const int size = w * elempack; - // type 19 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b_256); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b_128); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } + const __fp16 _b = ((const __fp16*)b)[q]; + float16x4_t _b_128 = (elempack == 4) ? vld1_f16((const __fp16*)b + q * 4) : vdup_n_f16(_b); + float16x8_t _b_256 = (elempack == 8) ? vld1q_f16((const __fp16*)b + q * 8) : vcombine_f16(_b_128, _b_128); - ptr += w; - outptr += w; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b_256); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b_128); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1) - { - // type 16 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + const int size = w * elempack; - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[q]; - __fp16* outptr = c.channel(q); + const __fp16 _b = ptr1[y]; + float16x4_t _b_128 = (elempack == 4) ? vld1_f16((const __fp16*)ptr1 + y * 4) : vdup_n_f16(_b); + float16x8_t _b_256 = (elempack == 8) ? vld1q_f16((const __fp16*)ptr1 + y * 8) : vcombine_f16(_b_128, _b_128); - for (int i = 0; i < size; i++) + int i = 0; + for (; i + 7 < size; i += 8) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b_256); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b_128); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + for (; i < size; i++) { - outptr[i] = op(ptr[i], b0); + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } } - else if (a.dims == 2) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); + + const int size = w * h * elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + const __fp16 _b = ptr1[z]; + float16x4_t _b_128 = (elempack == 4) ? vld1_f16((const __fp16*)ptr1 + z * 4) : vdup_n_f16(_b); + float16x8_t _b_256 = (elempack == 8) ? vld1q_f16((const __fp16*)ptr1 + z * 8) : vcombine_f16(_b_128, _b_128); - for (int z = 0; z < d1; z++) + int i = 0; + for (; i + 7 < size; i += 8) { - const __fp16 a0 = ptr[z]; - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b_256); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b_128); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 14 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * elempack; + + for (int z = 0; z < d; z++) { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + const __fp16* ptr1 = b.channel(q).row(z); - for (int y = 0; y < h1; y++) + for (int y = 0; y < h; y++) { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) + const __fp16 _b = ptr1[y]; + float16x4_t _b_128 = (elempack == 4) ? vld1_f16((const __fp16*)ptr1 + y * 4) : vdup_n_f16(_b); + float16x8_t _b_256 = (elempack == 8) ? vld1q_f16((const __fp16*)ptr1 + y * 8) : vcombine_f16(_b_128, _b_128); + + int i = 0; + for (; i + 7 < size; i += 8) { - outptr[x] = op(a0, ptr1[x]); + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _outp = op(_p, _b_256); + vst1q_f16(outptr, _outp); + ptr += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _outp = op(_p, _b_128); + vst1_f16(outptr, _outp); + ptr += 4; + outptr += 4; + } + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } - - ptr1 += w1; - outptr += w1; } } - - return 0; } + } - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } +template +static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - if (b.dims == 1) + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; + + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.row(y); + const __fp16* ptr1 = b; + __fp16* outptr = c.row<__fp16>(y); - if (b.w == 1) + if (elempack == 8) { - // type 11 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); + for (int x = 0; x < w; x++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _b = vdupq_n_f16(*ptr1); + float16x8_t _outp = op(_p, _b); + vst1q_f16(outptr, _outp); + ptr += 8; + ptr1 += 1; + outptr += 8; + } } - - // type 12 - const __fp16* ptr = a; - __fp16* outptr = c; - - for (int y = 0; y < h; y++) + if (elempack == 4) { - const __fp16 b0 = ((const __fp16*)b)[y]; for (int x = 0; x < w; x++) { - outptr[x] = op(ptr[x], b0); + float16x4_t _p = vld1_f16(ptr); + float16x4_t _b = vdup_n_f16(*ptr1); + float16x4_t _outp = op(_p, _b); + vst1_f16(outptr, _outp); + ptr += 4; + ptr1 += 1; + outptr += 4; } - - ptr += w; - outptr += w; } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_fp16s(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + if (elempack == 1) { - const __fp16 a0 = ((const __fp16*)a)[q]; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) + for (int x = 0; x < w; x++) { - outptr[i] = op(a0, ptr1[i]); + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 3 || a.dims == 4) + { + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 9 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const __fp16 a0 = ((const __fp16*)a)[q]; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - outptr[i] = op(a0, ptr1[i]); + int y1 = std::min(y, b.h - 1); + + const __fp16* ptr1 = b.depth(z1).row(y1); + + if (elempack == 8) + { + for (int x = 0; x < w; x++) + { + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _b = vdupq_n_f16(*ptr1); + float16x8_t _outp = op(_p, _b); + vst1q_f16(outptr, _outp); + ptr += 8; + ptr1 += 1; + outptr += 8; + } + } + if (elempack == 4) + { + for (int x = 0; x < w; x++) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _b = vdup_n_f16(*ptr1); + float16x4_t _outp = op(_p, _b); + vst1_f16(outptr, _outp); + ptr += 4; + ptr1 += 1; + outptr += 4; + } + } + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - - return 0; } + } - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const __fp16* ptr1 = b; - __fp16* outptr = c; + return 0; +} - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ((const __fp16*)a)[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } +template +static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += w1; - outptr += w1; - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - if (b.dims == 1) + for (int y = 0; y < h; y++) { - c.create(w, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr1 = b.channel(q); - if (b.w == 1) + const int size = w * elempack; + + int i = 0; + for (; i + 7 < size; i += 8) { - // type 6 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); + float16x8_t _p = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr1); + float16x8_t _outp = op(_p, _p1); + vst1q_f16(outptr, _outp); + ptr += 8; + ptr1 += 8; + outptr += 8; + } + for (; i + 3 < size; i += 4) + { + float16x4_t _p = vld1_f16(ptr); + float16x4_t _p1 = vld1_f16(ptr1); + float16x4_t _outp = op(_p, _p1); + vst1_f16(outptr, _outp); + ptr += 4; + ptr1 += 4; + outptr += 4; + } + for (; i < size; i++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29_fp16s(a, b, c, opt); } } @@ -2206,16 +529,12 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } template -static int binary_op_scalar_inplace_fp16s(Mat& a, float b, const Option& opt) +static int binary_op_scalar_inplace_fp16s(Mat& a, __fp16 b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -2223,7 +542,7 @@ static int binary_op_scalar_inplace_fp16s(Mat& a, float b, const Option& opt) __fp16* ptr = a.channel(q); int i = 0; - float16x8_t _b = vdupq_n_f16((__fp16)b); + float16x8_t _b = vdupq_n_f16(b); for (; i + 7 < size; i += 8) { float16x8_t _p = vld1q_f16(ptr); @@ -2240,8 +559,8 @@ static int binary_op_scalar_inplace_fp16s(Mat& a, float b, const Option& opt) } for (; i < size; i++) { - *ptr = op(*ptr, (__fp16)b); - ptr += 1; + *ptr = op(*ptr, b); + ptr++; } } @@ -2278,6 +597,7 @@ MAKE_FUNCTION(binary_op_min_fp16s, std::min(x, y), vmin_f16(x, y), vminq_f16(x, MAKE_FUNCTION(binary_op_pow_fp16s, (__fp16)pow(x, y), vcvt_f16_f32(pow_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y)))))) MAKE_FUNCTION(binary_op_rsub_fp16s, y - x, vsub_f16(y, x), vsubq_f16(y, x)) MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x)) +MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x)))))) // *INDENT-ON* // clang-format on @@ -2285,106 +605,176 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x)) } // namespace BinaryOp_arm_functor -int BinaryOp_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar_fp16s(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - - Mat& top_blob = top_blobs[0]; - using namespace BinaryOp_arm_functor; - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; - - if (elempack == 8 || elempack1 == 8) - { - if (op_type == Operation_ADD) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; - if (op_type == Operation_POW) - return binary_op_pack8_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_pack8_fp16s(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_RDIV) - return binary_op_pack8_fp16s(bottom_blob1, bottom_blob, top_blob, opt); - } + using namespace BinaryOp_arm_functor; - if (elempack == 4 || elempack1 == 4) - { - if (op_type == Operation_ADD) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_SUB) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; - if (op_type == Operation_MUL) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_DIV) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_arm_functor; - if (op_type == Operation_MAX) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; +} - if (op_type == Operation_POW) - return binary_op_pack4_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +int BinaryOp_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_RSUB) - return binary_op_pack4_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_RDIV) - return binary_op_pack4_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + // B is A scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar_fp16s(A, ((const __fp16*)B)[0], top_blob, op_type_r, opt); } - if (elempack == 1 && elempack1 == 1) + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) { - if (op_type == Operation_ADD) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MIN) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_no_broadcast_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20_fp16s(A, B, top_blob, op_type_r, opt); } return 0; @@ -2394,32 +784,16 @@ int BinaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) { using namespace BinaryOp_arm_functor; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_SUB) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace_fp16s(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); return 0; } diff --git a/src/layer/binaryop.cpp b/src/layer/binaryop.cpp index 53eb234bc..b35ff82fd 100644 --- a/src/layer/binaryop.cpp +++ b/src/layer/binaryop.cpp @@ -43,7 +43,7 @@ int BinaryOp::load_param(const ParamDict& pd) // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting template -static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; @@ -52,424 +52,130 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) int d = a.d; int channels = a.c; int size = w * h * d; - size_t elemsize = a.elemsize; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - - if (a.dims == 4) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.dims == 4) - { - // type 29 - c.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b); } + } - c.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + return 0; +} - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - const float b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } +template +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += h; - } - } + int channels = a.c; + int size = a.w * a.h * a.d; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); - if (b.dims == 2) + for (int i = 0; i < size; i++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + outptr[i] = op(ptr[i], ptr1[i]); + } + } - for (int z = 0; z < d; z++) - { - const float b0 = ptr1[z]; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } + return 0; +} - ptr += w; - outptr += w; - } - } - } +template +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - return 0; - } + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int size = w * h * d; - if (b.dims == 1) + if (a.dims == 2 && b.dims == 1) + { + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - if (b.w == 1) - { - // type 25 - const float b0 = b[0]; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } - } + const float* ptr = a.row(y); + const float b0 = b[y]; + float* outptr = c.row(y); - return 0; - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int x = 0; x < w; x++) { - const float* ptr = a.channel(q); - const float b0 = b[q]; - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } + outptr[x] = op(ptr[x], b0); } - - return 0; } } - else if (a.dims == 3) + + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) { - if (b.dims == 4) + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float b0 = b[q]; + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int i = 0; i < size; i++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - const float a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } - - ptr += h1; - } + outptr[i] = op(ptr[i], b0); } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0[0]); - } - } - - return 0; - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - if (w1 == w && h1 == h && channels1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0[0], ptr1[i]); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - const float b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } - - ptr += w; - outptr += w; - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - const float a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) + for (int y = 0; y < h; y++) { - // special type 8 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const float b0 = ptr1[y]; + for (int x = 0; x < w; x++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } + outptr[x] = op(ptr[x], b0); } - return 0; - } - - // type 19 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } + ptr += w; + outptr += w; } - - return 0; } + } - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) + if (a.dims == 4 && b.dims == 2) + { + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + for (int z = 0; z < d; z++) + { + const float b0 = ptr1[z]; for (int y = 0; y < h; y++) { - const float b0 = ptr1[y]; for (int x = 0; x < w; x++) { outptr[x] = op(ptr[x], b0); @@ -479,342 +185,125 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) outptr += w; } } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1) - { - // type 16 - const float b0 = b[0]; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } - } - - return 0; - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float b0 = b[q]; - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } - } - - return 0; } } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - const float a0 = ptr[z]; - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - ptr1 += w1; - outptr += w1; - } - } - } - - return 0; - } - - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 14 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) + for (int y = 0; y < h; y++) { - const float a0 = ptr[y]; - for (int x = 0; x < w1; x++) + const float b0 = ptr1[y]; + for (int x = 0; x < w; x++) { - outptr[x] = op(a0, ptr1[x]); + outptr[x] = op(ptr[x], b0); } - ptr1 += w1; - outptr += w1; + ptr += w; + outptr += w; } - } - - return 0; - } - - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - if (b.dims == 2) - { - // type 13 - for (int i = 0; i < size; i++) - { - c[i] = op(a[i], b[i]); + ptr1 += h; } - - return 0; } + } - if (b.dims == 1) - { - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - if (b.w == 1) - { - // type 11 - const float b0 = b[0]; - for (int i = 0; i < size; i++) - { - c[i] = op(a[i], b0); - } +template +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - return 0; - } + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; - // type 12 - const float* ptr = a; - float* outptr = c; + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - for (int y = 0; y < h; y++) + for (int x = 0; x < w; x++) { - const float b0 = b[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; + outptr[x] = op(ptr[x], ptr1[x]); } - - return 0; } } - else if (a.dims == 1) + + if (a.dims == 3 || a.dims == 4) { - if (a.w == 1) + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.dims == 4) - { - // type 20 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const float a0 = a[0]; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0, ptr1[i]); - } - } - - return 0; - } + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - if (b.dims == 3) + for (int z = 0; z < d; z++) { - // type 4 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const float a0 = a[0]; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + int y1 = std::min(y, b.h - 1); - for (int i = 0; i < size1; i++) + const float* ptr1 = b.depth(z1).row(y1); + for (int x = 0; x < w; x++) { - outptr[i] = op(a0, ptr1[i]); + outptr[x] = op(ptr[x], ptr1[x]); } - } - - return 0; - } - - if (b.dims == 2) - { - // type 3 - c.create(w1, h1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const float a0 = a[0]; - for (int i = 0; i < size1; i++) - { - c[i] = op(a0, b[i]); - } - - return 0; - } - - if (b.dims == 1) - { - // type 2 - c.create(w1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const float a0 = a[0]; - for (int i = 0; i < w1; i++) - { - c[i] = op(a0, b[i]); - } - - return 0; - } - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float a0 = a[q]; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0, ptr1[i]); - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float a0 = a[q]; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0, ptr1[i]); + ptr += w; + outptr += w; } } - - return 0; } + } - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr1 = b; - float* outptr = c; + return 0; +} - for (int y = 0; y < h1; y++) - { - const float a0 = a[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += w1; - outptr += w1; - } + int w = a.w; + int h = a.h; + int channels = a.c; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + const float* ptr1 = b.channel(q); + float* outptr = c.channel(q); - if (b.dims == 1) + for (int y = 0; y < h; y++) { - c.create(w, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1) + for (int x = 0; x < w; x++) { - // type 6 - const float b0 = b[0]; - for (int i = 0; i < w; i++) - { - c[i] = op(a[i], b0); - } - - return 0; + outptr[x] = op(ptr[x], ptr1[x]); } - // type 7 - for (int i = 0; i < w; i++) - { - c[i] = op(a[i], b[i]); - } + ptr += w; + outptr += w; } } @@ -826,11 +315,8 @@ static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; int channels = a.c; - int size = w * h * d; + int size = a.w * a.h * a.d; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -918,72 +404,198 @@ struct binary_op_rdiv } }; -int BinaryOp::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +struct binary_op_rpow { - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - - Mat& top_blob = top_blobs[0]; - - if (op_type == Operation_ADD) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); + float operator()(const float& x, const float& y) const + { + return (float)pow(y, x); + } +}; - if (op_type == Operation_MIN) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) +{ + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_POW) - return binary_op(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); + + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RDIV) - return binary_op(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here return 0; } -int BinaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +static int get_reverse_op_type(int op_type) { - if (op_type == Operation_ADD) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; +} - if (op_type == Operation_SUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); +int BinaryOp::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_MUL) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_DIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // B is a scalar + if (B.w * B.h * B.d * B.c == 1) + { + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); + } - if (op_type == Operation_MAX) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c) + { + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_MIN) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for outer axis + if ((A.dims == 2 && B.w == A.w && B.h == 1) + || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) + || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) + || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) + || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) + || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1)) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + return 0; +} +int BinaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + + // should nerver reach here return 0; } diff --git a/src/layer/binaryop.h b/src/layer/binaryop.h index 74798f906..f1027535d 100644 --- a/src/layer/binaryop.h +++ b/src/layer/binaryop.h @@ -42,7 +42,8 @@ public: Operation_MIN = 5, Operation_POW = 6, Operation_RSUB = 7, - Operation_RDIV = 8 + Operation_RDIV = 8, + Operation_RPOW = 9 }; public: diff --git a/src/layer/loongarch/binaryop_loongarch.cpp b/src/layer/loongarch/binaryop_loongarch.cpp index 7832c9ca7..89e31c51c 100644 --- a/src/layer/loongarch/binaryop_loongarch.cpp +++ b/src/layer/loongarch/binaryop_loongarch.cpp @@ -31,95 +31,37 @@ BinaryOp_loongarch::BinaryOp_loongarch() } template -static int binary_op_2_3_4_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float a0 = a[0]; - const float* ptr = b.channel(q); - float* outptr = c.channel(q); - - int i = 0; -#if __loongarch_sx - __m128 _a0 = __lsx_vreplfr2vr_s(a0); - for (; i + 3 < size; i += 4) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_a0, _p); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; - } -#endif // __loongarch_sx - for (; i < size; i++) - { - *outptr = op(a0, *ptr); - ptr += 1; - outptr += 1; - } - } - - return 0; -} - -template -static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - const float b0 = b[0]; float* outptr = c.channel(q); int i = 0; #if __loongarch_sx - __m128 _b0 = __lsx_vreplfr2vr_s(b0); + __m128 _b = __lsx_vreplfr2vr_s(b); for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); + _p = op(_p, _b); + __lsx_vst(_p, outptr, 0); ptr += 4; outptr += 4; } #endif // __loongarch_sx for (; i < size; i++) { - *outptr = op(*ptr, b0); - ptr += 1; - outptr += 1; + *outptr = op(*ptr, b); + ptr++; + outptr++; } } @@ -127,21 +69,12 @@ static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option } template -static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -177,12 +110,8 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option return 0; } -#if __loongarch_sx -// broadcasting rule -// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - template -static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -190,715 +119,363 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } + const float* ptr = a.row(y); + float* outptr = c.row(y); - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float _b = b[y]; +#if __loongarch_sx + __m128 _b_128 = (elempack == 4) ? (__m128)__lsx_vld((const float*)b + y * 4, 0) : __lsx_vreplfr2vr_s(_b); +#endif // __loongarch_sx - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const int size = w * elempack; - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - __m128 _b0 = (__m128)__lsx_vld(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + __m128 _p = (__m128)__lsx_vld(ptr, 0); + __m128 _outp = op(_p, _b_128); + __lsx_vst(_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __loongarch_sx + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } - - return 0; } + } - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - for (int z = 0; z < d; z++) - { - __m128 _b0 = (__m128)__lsx_vld(ptr1, 0); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - } + const float _b = b[q]; +#if __loongarch_sx + __m128 _b_128 = (elempack == 4) ? (__m128)__lsx_vld((const float*)b + q * 4, 0) : __lsx_vreplfr2vr_s(_b); +#endif // __loongarch_sx - ptr1 += 4; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + __m128 _p = (__m128)__lsx_vld(ptr, 0); + __m128 _outp = op(_p, _b_128); + __lsx_vst(_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __loongarch_sx + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + const int size = w * elempack; - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - __m128 _b0 = (__m128)__lsx_vld((const float*)b + q * 4, 0); - float* outptr = c.channel(q); + const float _b = ptr1[y]; +#if __loongarch_sx + __m128 _b_128 = (elempack == 4) ? (__m128)__lsx_vld((const float*)ptr1 + y * 4, 0) : __lsx_vreplfr2vr_s(_b); +#endif // __loongarch_sx - for (int i = 0; i < size; i++) + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); + __m128 _outp = op(_p, _b_128); __lsx_vst(_outp, outptr, 0); ptr += 4; outptr += 4; } +#endif // __loongarch_sx + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } - - return 0; } } - else if (a.dims == 3) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + const int size = w * h * elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = ptr1[z]; +#if __loongarch_sx + __m128 _b_128 = (elempack == 4) ? (__m128)__lsx_vld((const float*)ptr1 + z * 4, 0) : __lsx_vreplfr2vr_s(_b); +#endif // __loongarch_sx - for (int z = 0; z < d1; z++) + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) { - for (int y = 0; y < h1; y++) - { - __m128 _a0 = (__m128)__lsx_vld(ptr, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } + __builtin_prefetch(ptr + 16); + __m128 _p = (__m128)__lsx_vld(ptr, 0); + __m128 _outp = op(_p, _b_128); + __lsx_vst(_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __loongarch_sx + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - __m128 _b0 = (__m128)__lsx_vld(b0, 0); - for (int i = 0; i < size; i++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - } + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - return 0; - } + const int size = w * elempack; - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + for (int z = 0; z < d; z++) { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q).row(z); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) + const float _b = ptr1[y]; +#if __loongarch_sx + __m128 _b_128 = (elempack == 4) ? (__m128)__lsx_vld((const float*)ptr1 + y * 4, 0) : __lsx_vreplfr2vr_s(_b); +#endif // __loongarch_sx + + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _p1 = __lsx_vreplfr2vr_s(ptr1[0]); - __m128 _outp = op(_p, _p1); + __m128 _outp = op(_p, _b_128); __lsx_vst(_outp, outptr, 0); ptr += 4; - ptr1 += 1; outptr += 4; } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - __m128 _a0 = (__m128)__lsx_vld(a0, 0); - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p1); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) +#endif // __loongarch_sx + for (; i < size; i++) { - __builtin_prefetch(ptr + 16); - __builtin_prefetch(ptr1 + 16); - __m128 _p = __lsx_vreplfr2vr_s(ptr[0]); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_p, _p1); - __lsx_vst(_outp, outptr, 0); + *outptr = op(*ptr, _b); ptr += 1; - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m128 _p1 = (__m128)__lsx_vld(ptr1 + y * 4, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _p1); - __lsx_vst(_outp, outptr, 0); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _p1 = (__m128)__lsx_vld(ptr1 + x * 4, 0); - __m128 _outp = op(_p, _p1); - __lsx_vst(_outp, outptr, 0); - - ptr += 4; - outptr += 4; - } + outptr += 1; } } - - return 0; } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m128 _p = (__m128)__lsx_vld(ptr + y * 4, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_p, _p1); - __lsx_vst(_outp, outptr, 0); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p = (__m128)__lsx_vld(ptr + x * 4, 0); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_p, _p1); - __lsx_vst(_outp, outptr, 0); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); } + } - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m128 _b0 = (__m128)__lsx_vld(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; - } + return 0; +} - ptr1 += 4; - } - } +template +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - return 0; - } + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - if (b.dims == 1) + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#if __loongarch_sx + if (elempack == 4) { - const float* ptr = a.channel(q); - __m128 _b0 = (__m128)__lsx_vld((const float*)b + q * 4, 0); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) + for (int x = 0; x < w; x++) { __builtin_prefetch(ptr + 16); __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); + __m128 _b = __lsx_vreplfr2vr_s(*ptr1); + __m128 _outp = op(_p, _b); __lsx_vst(_outp, outptr, 0); ptr += 4; + ptr1 += 1; outptr += 4; } } - - return 0; +#endif // __loongarch_sx + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - else if (a.dims == 2) + + if (a.dims == 3 || a.dims == 4) { - if (b.dims == 4) + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - __m128 _a0 = (__m128)__lsx_vld(ptr, 0); - for (int y = 0; y < h1; y++) + int y1 = std::min(y, b.h - 1); + + const float* ptr1 = b.depth(z1).row(y1); + +#if __loongarch_sx + if (elempack == 4) { - for (int x = 0; x < w1; x++) + for (int x = 0; x < w; x++) { - __builtin_prefetch(ptr1 + 16); - __m128 _p = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p); + __builtin_prefetch(ptr + 16); + __m128 _p = (__m128)__lsx_vld(ptr, 0); + __m128 _b = __lsx_vreplfr2vr_s(*ptr1); + __m128 _outp = op(_p, _b); __lsx_vst(_outp, outptr, 0); - ptr1 += 4; + ptr += 4; + ptr1 += 1; outptr += 4; } } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m128 _a0 = (__m128)__lsx_vld(ptr, 0); - for (int x = 0; x < w1; x++) +#endif // __loongarch_sx + if (elempack == 1) { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p1); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } } - - ptr += 4; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - __m128 _b0 = (__m128)__lsx_vld(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - __m128 _p = (__m128)__lsx_vld(ptr, 0); - __m128 _outp = op(_p, _b0); - __lsx_vst(_outp, outptr, 0); - ptr += 4; - outptr += 4; } - - ptr1 += 4; } - - return 0; } } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m128 _a0 = (__m128)__lsx_vld((const float*)a + q * 4, 0); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p1); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m128 _a0 = (__m128)__lsx_vld((const float*)a + q * 4, 0); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p1); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - if (b.dims == 2) + for (int y = 0; y < h; y++) { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q); - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; + const int size = w * elempack; - for (int y = 0; y < h1; y++) + int i = 0; +#if __loongarch_sx + for (; i + 3 < size; i += 4) { - __m128 _a0 = (__m128)__lsx_vld(ptr, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); - __m128 _outp = op(_a0, _p1); - __lsx_vst(_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - + __builtin_prefetch(ptr + 16); + __builtin_prefetch(ptr1 + 16); + __m128 _p = (__m128)__lsx_vld(ptr, 0); + __m128 _p1 = (__m128)__lsx_vld(ptr1, 0); + __m128 _outp = op(_p, _p1); + __lsx_vst(_outp, outptr, 0); ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) +#endif // __loongarch_sx + for (; i < size; i++) { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); } } return 0; } -#endif // __loongarch_sx template static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) @@ -976,6 +553,7 @@ MAKE_FUNCTION(binary_op_min, std::min(x, y), __lsx_vfmin_s(x, y)) MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y)) MAKE_FUNCTION(binary_op_rsub, y - x, __lsx_vfsub_s(y, x)) MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x)) +MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x)) // *INDENT-ON* // clang-format on @@ -983,82 +561,195 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x)) } // namespace BinaryOp_loongarch_functor -int BinaryOp_loongarch::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { -#if __loongarch_sx using namespace BinaryOp_loongarch_functor; - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} - if (elempack == 4 || elempack1 == 4) - { - if (op_type == Operation_ADD) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_loongarch_functor; - if (op_type == Operation_SUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MUL) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_DIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + using namespace BinaryOp_loongarch_functor; - if (op_type == Operation_MAX) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_loongarch_functor; - if (op_type == Operation_POW) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_loongarch_functor; - if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); - } -#endif // __loongarch_sx + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here + return 0; +} - return BinaryOp::forward(bottom_blobs, top_blobs, opt); +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; } -int BinaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +int BinaryOp_loongarch::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - using namespace BinaryOp_loongarch_functor; + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_SUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); + } - if (op_type == Operation_MUL) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) + { + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_DIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_MAX) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_MIN) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); +int BinaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + using namespace BinaryOp_loongarch_functor; - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); return 0; } diff --git a/src/layer/mips/binaryop_mips.cpp b/src/layer/mips/binaryop_mips.cpp index 241618762..cfcdbee5c 100644 --- a/src/layer/mips/binaryop_mips.cpp +++ b/src/layer/mips/binaryop_mips.cpp @@ -31,95 +31,37 @@ BinaryOp_mips::BinaryOp_mips() } template -static int binary_op_2_3_4_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float a0 = a[0]; - const float* ptr = b.channel(q); - float* outptr = c.channel(q); - - int i = 0; -#if __mips_msa - v4f32 _a0 = __msa_fill_w_f32(a0); - for (; i + 3 < size; i += 4) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_a0, _p); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } -#endif // __mips_msa - for (; i < size; i++) - { - *outptr = op(a0, *ptr); - ptr += 1; - outptr += 1; - } - } - - return 0; -} - -template -static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - const float b0 = b[0]; float* outptr = c.channel(q); int i = 0; #if __mips_msa - v4f32 _b0 = __msa_fill_w_f32(b0); + v4f32 _b = __msa_fill_w_f32(b); for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); + _p = op(_p, _b); + __msa_st_w((v4i32)_p, outptr, 0); ptr += 4; outptr += 4; } #endif // __mips_msa for (; i < size; i++) { - *outptr = op(*ptr, b0); - ptr += 1; - outptr += 1; + *outptr = op(*ptr, b); + ptr++; + outptr++; } } @@ -127,21 +69,12 @@ static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option } template -static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -177,12 +110,8 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option return 0; } -#if __mips_msa -// broadcasting rule -// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - template -static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -190,727 +119,371 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } + const float* ptr = a.row(y); + float* outptr = c.row(y); - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = b[y]; +#if __mips_msa + v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w((const float*)b + y * 4, 0) : __msa_fill_w_f32(_b); +#endif // __mips_msa - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } + const int size = w * elempack; - ptr1 += 4; - } - } + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b_128); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __mips_msa + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } - - return 0; } + } - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - for (int z = 0; z < d; z++) - { - v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - } + const float _b = b[q]; +#if __mips_msa + v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w((const float*)b + q * 4, 0) : __msa_fill_w_f32(_b); +#endif // __mips_msa - ptr1 += 4; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) + { + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b_128); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __mips_msa + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); + + const int size = w * elempack; - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - v4f32 _b0 = (v4f32)__msa_ld_w((const float*)b + q * 4, 0); - float* outptr = c.channel(q); + const float _b = ptr1[y]; +#if __mips_msa + v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w((const float*)ptr1 + y * 4, 0) : __msa_fill_w_f32(_b); +#endif // __mips_msa - for (int i = 0; i < size; i++) + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); + v4f32 _outp = op(_p, _b_128); __msa_st_w((v4i32)_outp, outptr, 0); ptr += 4; outptr += 4; } +#endif // __mips_msa + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } } - - return 0; } } - else if (a.dims == 3) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * h * elempack; + + for (int z = 0; z < d; z++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = ptr1[z]; +#if __mips_msa + v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w((const float*)ptr1 + z * 4, 0) : __msa_fill_w_f32(_b); +#endif // __mips_msa - for (int z = 0; z < d1; z++) + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) { - for (int y = 0; y < h1; y++) - { - v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _outp = op(_p, _b_128); + __msa_st_w((v4i32)_outp, outptr, 0); + ptr += 4; + outptr += 4; + } +#endif // __mips_msa + for (; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - v4f32 _b0 = (v4f32)__msa_ld_w(b0, 0); - for (int i = 0; i < size; i++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - } + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - return 0; - } + const int size = w * elempack; - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + for (int z = 0; z < d; z++) { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q).row(z); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) + const float _b = ptr1[y]; +#if __mips_msa + v4f32 _b_128 = (elempack == 4) ? (v4f32)__msa_ld_w((const float*)ptr1 + y * 4, 0) : __msa_fill_w_f32(_b); +#endif // __mips_msa + + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) { __builtin_prefetch(ptr + 16); v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _p1 = __msa_fill_w_f32(ptr1[0]); - v4f32 _outp = op(_p, _p1); + v4f32 _outp = op(_p, _b_128); __msa_st_w((v4i32)_outp, outptr, 0); ptr += 4; - ptr1 += 1; - outptr += 4; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - v4f32 _a0 = (v4f32)__msa_ld_w(a0, 0); - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; outptr += 4; } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) +#endif // __mips_msa + for (; i < size; i++) { - __builtin_prefetch(ptr + 16); - __builtin_prefetch(ptr1 + 16); - v4f32 _p = __msa_fill_w_f32(ptr[0]); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_p, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); + *outptr = op(*ptr, _b); ptr += 1; - ptr1 += 4; - outptr += 4; + outptr += 1; } } - - return 0; } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1 + y * 4, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1 + x * 4, 0); - v4f32 _outp = op(_p, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - v4f32 _p = (v4f32)__msa_ld_w(ptr + y * 4, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_p, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr + x * 4, 0); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_p, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); } + } - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } + return 0; +} - ptr1 += 4; - } - } +template +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - return 0; - } + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - if (b.dims == 1) + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#if __mips_msa + if (elempack == 4) { - const float* ptr = a.channel(q); - v4f32 _b0 = (v4f32)__msa_ld_w((const float*)b + q * 4, 0); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) + for (int x = 0; x < w; x++) { __builtin_prefetch(ptr + 16); v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); + v4f32 _b = __msa_fill_w_f32(*ptr1); + v4f32 _outp = op(_p, _b); __msa_st_w((v4i32)_outp, outptr, 0); ptr += 4; + ptr1 += 1; outptr += 4; } } - - return 0; +#endif // __mips_msa + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - else if (a.dims == 2) + + if (a.dims == 3 || a.dims == 4) { - if (b.dims == 4) + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); - for (int y = 0; y < h1; y++) + int y1 = std::min(y, b.h - 1); + + const float* ptr1 = b.depth(z1).row(y1); + +#if __mips_msa + if (elempack == 4) { - for (int x = 0; x < w1; x++) + for (int x = 0; x < w; x++) { - __builtin_prefetch(ptr1 + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p); + __builtin_prefetch(ptr + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _b = __msa_fill_w_f32(*ptr1); + v4f32 _outp = op(_p, _b); __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; + ptr += 4; + ptr1 += 1; outptr += 4; } } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); - for (int x = 0; x < w1; x++) +#endif // __mips_msa + if (elempack == 1) { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; - outptr += 4; + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } } - - ptr += 4; } } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - v4f32 _b0 = (v4f32)__msa_ld_w(ptr1, 0); - for (int x = 0; x < w; x++) - { - __builtin_prefetch(ptr + 16); - v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); - v4f32 _outp = op(_p, _b0); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - - return 0; } } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - v4f32 _a0 = (v4f32)__msa_ld_w((const float*)a + q * 4, 0); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - v4f32 _a0 = (v4f32)__msa_ld_w((const float*)a + q * 4, 0); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - for (int i = 0; i < size1; i++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - if (b.dims == 2) + for (int y = 0; y < h; y++) { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q); - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; + const int size = w * elempack; - for (int y = 0; y < h1; y++) + int i = 0; +#if __mips_msa + for (; i + 3 < size; i += 4) { - v4f32 _a0 = (v4f32)__msa_ld_w(ptr, 0); - for (int x = 0; x < w1; x++) - { - __builtin_prefetch(ptr1 + 16); - v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); - v4f32 _outp = op(_a0, _p1); - __msa_st_w((v4i32)_outp, outptr, 0); - ptr1 += 4; - outptr += 4; - } - + __builtin_prefetch(ptr + 16); + __builtin_prefetch(ptr1 + 16); + v4f32 _p = (v4f32)__msa_ld_w(ptr, 0); + v4f32 _p1 = (v4f32)__msa_ld_w(ptr1, 0); + v4f32 _outp = op(_p, _p1); + __msa_st_w((v4i32)_outp, outptr, 0); ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) +#endif // __mips_msa + for (; i < size; i++) { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); } } return 0; } -#endif // __mips_msa template static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -976,6 +549,7 @@ MAKE_FUNCTION(binary_op_min, std::min(x, y), __msa_fmin_w(x, y)) MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y)) MAKE_FUNCTION(binary_op_rsub, y - x, __msa_fsub_w(y, x)) MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x)) +MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x)) // *INDENT-ON* // clang-format on @@ -983,82 +557,195 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x)) } // namespace BinaryOp_mips_functor -int BinaryOp_mips::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { -#if __mips_msa using namespace BinaryOp_mips_functor; - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} - if (elempack == 4 || elempack1 == 4) - { - if (op_type == Operation_ADD) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_mips_functor; - if (op_type == Operation_SUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MUL) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_DIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + using namespace BinaryOp_mips_functor; - if (op_type == Operation_MAX) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_mips_functor; - if (op_type == Operation_POW) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_mips_functor; - if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); - } -#endif // __mips_msa + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here + return 0; +} - return BinaryOp::forward(bottom_blobs, top_blobs, opt); +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; } -int BinaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +int BinaryOp_mips::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { - using namespace BinaryOp_mips_functor; + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_SUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); + } - if (op_type == Operation_MUL) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) + { + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_DIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_MAX) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_MIN) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); +int BinaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + using namespace BinaryOp_mips_functor; - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); return 0; } diff --git a/src/layer/riscv/binaryop_riscv.cpp b/src/layer/riscv/binaryop_riscv.cpp index 9858e6548..4464c1cc1 100644 --- a/src/layer/riscv/binaryop_riscv.cpp +++ b/src/layer/riscv/binaryop_riscv.cpp @@ -39,105 +39,52 @@ BinaryOp_riscv::BinaryOp_riscv() #endif } -#if __riscv_vector template -static int binary_op_2_3_4_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - const float a0 = a[0]; - const float* ptr = b.channel(q); + const float* ptr = a.channel(q); float* outptr = c.channel(q); +#if __riscv_vector int n = size; while (n > 0) { size_t vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(a0, _p, vl); - vse32_v_f32m8(outptr, _outp, vl); - + _p = op(_p, b, vl); + vse32_v_f32m8(outptr, _p, vl); + n -= vl; ptr += vl; outptr += vl; - n -= vl; } - } - - return 0; -} - -template -static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float b0 = b[0]; - float* outptr = c.channel(q); - - int n = size; - while (n > 0) +#else + for (int i = 0; i < size; i++) { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, b0, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; + *outptr = op(*ptr, b); + ptr++; + outptr++; } +#endif } return 0; } template -static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -146,6 +93,7 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option const float* ptr1 = b.channel(q); float* outptr = c.channel(q); +#if __riscv_vector int n = size; while (n > 0) { @@ -154,19 +102,27 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); vfloat32m8_t _outp = op(_p, _p1, vl); vse32_v_f32m8(outptr, _outp, vl); - + n -= vl; ptr += vl; ptr1 += vl; outptr += vl; - n -= vl; } +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } +#endif } return 0; } template -static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -174,793 +130,367 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const Option& opt) int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float* ptr = a.row(y); + float* outptr = c.row(y); - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + const float _b = b[y]; - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } + const int size = w * elempack; - ptr1 += elempack1; - } - } +#if __riscv_vector + int n = size; + vfloat32m8_t _bx = (elempack == 1) ? vfmv_v_f_f32m8(_b, vsetvl_e32m8(n)) : vle32_v_f32m8_f32m1((const float*)b + y * elempack); + while (n > 0) + { + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _bx, vl); + vse32_v_f32m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } - - return 0; +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } +#endif } + } - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - int n = w * h * elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); + const float _b = b[q]; - ptr += vl; - outptr += vl; - n -= vl; - } + const int size = w * h * d * elempack; - ptr1 += elempack1; - } +#if __riscv_vector + int n = size; + vfloat32m8_t _bx = (elempack == 1) ? vfmv_v_f_f32m8(_b, vsetvl_e32m8(n)) : vle32_v_f32m8_f32m1((const float*)b + q * elempack); + while (n > 0) + { + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _bx, vl); + vse32_v_f32m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } - - return 0; +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } +#endif } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); + const int size = w * elempack; - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1((const float*)b + q * elempack); + for (int y = 0; y < h; y++) + { + const float _b = ptr1[y]; - int n = size * elempack; +#if __riscv_vector + int n = size; + vfloat32m8_t _bx = (elempack == 1) ? vfmv_v_f_f32m8(_b, vsetvl_e32m8(n)) : vle32_v_f32m8_f32m1(ptr1 + y * elempack); while (n > 0) { size_t vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); + vfloat32m8_t _outp = op(_p, _bx, vl); vse32_v_f32m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; n -= vl; + ptr += vl; + outptr += vl; } +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; + } +#endif } - - return 0; } } - else if (a.dims == 3) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * h * elempack; + + for (int z = 0; z < d; z++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = ptr1[z]; - for (int z = 0; z < d1; z++) +#if __riscv_vector + int n = size; + vfloat32m8_t _bx = (elempack == 1) ? vfmv_v_f_f32m8(_b, vsetvl_e32m8(n)) : vle32_v_f32m8_f32m1(ptr1 + z * elempack); + while (n > 0) { - for (int y = 0; y < h1; y++) - { - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, _bx, vl); + vse32_v_f32m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } +#endif } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) + const float* ptr = a.channel(q); + float* outptr = c.channel(q); + + const int size = w * elempack; + + for (int z = 0; z < d; z++) { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q).row(z); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + const float _b = ptr1[y]; - int n = size * elempack; +#if __riscv_vector + int n = size; + vfloat32m8_t _bx = (elempack == 1) ? vfmv_v_f_f32m8(_b, vsetvl_e32m8(n)) : vle32_v_f32m8_f32m1(ptr1 + y * elempack); while (n > 0) { size_t vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); + vfloat32m8_t _outp = op(_p, _bx, vl); vse32_v_f32m8(outptr, _outp, vl); - + n -= vl; ptr += vl; outptr += vl; - n -= vl; } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - +#else for (int i = 0; i < size; i++) { - int n = elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, *ptr1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - n -= vl; - ptr += vl; - outptr += vl; - } - - ptr1 += 1; + *outptr = op(*ptr, _b); + ptr += 1; + outptr += 1; } +#endif } - - return 0; } + } + } - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); +template +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - int n1 = size1 * elempack1; - while (n1 > 0) +#if __riscv_vector + if (elempack != 1) + { + for (int x = 0; x < w; x++) + { + int n = elempack; + while (n > 0) { - size_t vl = vsetvl_e32m8(n1); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, *ptr1, vl); vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; + n -= vl; + ptr += vl; outptr += vl; - n1 -= vl; } + ptr1 += 1; } - - return 0; } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) +#endif + if (elempack == 1) { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int x = 0; x < w; x++) { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - int n1 = elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e32m8(n1); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _p = vfmv_v_f_f32m8(*ptr, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - n1 -= vl; - ptr1 += vl; - outptr += vl; - } - - ptr += 1; - } + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - return 0; } + } + } - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + if (a.dims == 3 || a.dims == 4) + { + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) + { + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + int y1 = std::min(y, b.h - 1); - for (int y = 0; y < h; y++) - { - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); + const float* ptr1 = b.depth(z1).row(y1); - int n = w * elempack; - while (n > 0) +#if __riscv_vector + if (elempack != 1) + { + for (int x = 0; x < w; x++) { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; + int n = elempack; + while (n > 0) + { + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _outp = op(_p, *ptr1, vl); + vse32_v_f32m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } + ptr1 += 1; } - - ptr1 += elempack1; } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) +#endif + if (elempack == 1) { - int n = w * elempack; - const float* ptr1_vol = ptr1; - while (n > 0) + for (int x = 0; x < w; x++) { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1_vol, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; - n -= vl; - ptr1_vol += vl; + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } } } - - return 0; } + } + } - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - for (int y = 0; y < h1; y++) - { - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - ptr += elempack; - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - int n = w1 * elempack1; - const float* ptr_vol = ptr; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr_vol, vl); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_p, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack1; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); - - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1((const float*)b + q * elempack); - - int n = size * elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; - n -= vl; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); - - int n = w1 * h1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - vfloat32m8_t _b0x = vle32_v_f32m8_f32m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); - vfloat32m8_t _outp = op(_p, _b0x, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1((const float*)a + q * elempack); - - int n1 = size1 * elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e32m8(n1); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1((const float*)a + q * elempack); - - int n1 = size1 * elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e32m8(n1); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; - } - } - - return 0; - } - - if (b.dims == 2) + for (int y = 0; y < h; y++) { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h1; y++) - { - vfloat32m8_t _a0x = vle32_v_f32m8_f32m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e32m8(n); - vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); - vfloat32m8_t _outp = op(_a0x, _p1, vl); - vse32_v_f32m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } + const float* ptr1 = b.channel(q); - ptr += elempack; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const int size = w * elempack; - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); +#if __riscv_vector + int n = size; + while (n > 0) + { + size_t vl = vsetvl_e32m8(n); + vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); + vfloat32m8_t _p1 = vle32_v_f32m8(ptr1, vl); + vfloat32m8_t _outp = op(_p, _p1, vl); + vse32_v_f32m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + ptr1 += vl; + outptr += vl; + } +#else + for (int i = 0; i < size; i++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); +#endif } } @@ -968,31 +498,36 @@ static int binary_op_rvv(const Mat& a, const Mat& b, Mat& c, const Option& opt) } template -static int binary_op_scalar_rvv(Mat& a, float b, const Option& opt) +static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - int elempack = a.elempack; + + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { float* ptr = a.channel(q); - int n = size * elempack; + +#if __riscv_vector + int n = size; while (n > 0) { size_t vl = vsetvl_e32m8(n); vfloat32m8_t _p = vle32_v_f32m8(ptr, vl); _p = op(_p, b, vl); vse32_v_f32m8(ptr, _p, vl); - n -= vl; ptr += vl; } +#else + for (int i = 0; i < size; i++) + { + *ptr = op(*ptr, b); + ptr++; + } +#endif } return 0; @@ -1000,242 +535,281 @@ static int binary_op_scalar_rvv(Mat& a, float b, const Option& opt) namespace BinaryOp_riscv_functor { -#define MAKE_FUNCTION(NAME, IMPLVV, IMPLVS, IMPLSV) \ +#if __riscv_vector +#define MAKE_FUNCTION(NAME, IMPL, IMPLVV, IMPLVS, IMPLSV) \ struct NAME \ { \ + float operator()(const float& x, const float& y) const \ + { \ + return IMPL; \ + } \ vfloat32m8_t operator()(const vfloat32m8_t& x, const vfloat32m8_t& y, const size_t vl) const \ { \ return IMPLVV; \ } \ - vfloat32m8_t operator()(const vfloat32m8_t& x, const float y, const size_t vl) const \ + vfloat32m8_t operator()(const vfloat32m8_t& x, const float& y, const size_t vl) const \ { \ return IMPLVS; \ } \ - vfloat32m8_t operator()(const float x, const vfloat32m8_t& y, const size_t vl) const \ + vfloat32m8_t operator()(const float& x, const vfloat32m8_t& y, const size_t vl) const \ { \ return IMPLSV; \ } \ }; +#else +#define MAKE_FUNCTION(NAME, IMPL, IMPLVV, IMPLVS, IMPLSV) \ + struct NAME \ + { \ + float operator()(const float& x, const float& y) const \ + { \ + return IMPL; \ + } \ + }; +#endif -MAKE_FUNCTION(binary_op_add_rvv, vfadd_vv_f32m8(x, y, vl), vfadd_vf_f32m8(x, y, vl), vfadd_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_sub_rvv, vfsub_vv_f32m8(x, y, vl), vfsub_vf_f32m8(x, y, vl), vfrsub_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_mul_rvv, vfmul_vv_f32m8(x, y, vl), vfmul_vf_f32m8(x, y, vl), vfmul_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_div_rvv, vfdiv_vv_f32m8(x, y, vl), vfdiv_vf_f32m8(x, y, vl), vfrdiv_vf_f32m8(y, x, vl)) - -MAKE_FUNCTION(binary_op_max_rvv, vfmax_vv_f32m8(x, y, vl), vfmax_vf_f32m8(x, y, vl), vfmax_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_min_rvv, vfmin_vv_f32m8(x, y, vl), vfmin_vf_f32m8(x, y, vl), vfmin_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_pow_rvv, pow_ps(x, y, vl), pow_ps(x, vfmv_v_f_f32m8(y, vl), vl), pow_ps(vfmv_v_f_f32m8(x, vl), y, vl)) -MAKE_FUNCTION(binary_op_rsub_rvv, vfsub_vv_f32m8(y, x, vl), vfrsub_vf_f32m8(x, y, vl), vfsub_vf_f32m8(y, x, vl)) -MAKE_FUNCTION(binary_op_rdiv_rvv, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x, y, vl), vfdiv_vf_f32m8(y, x, vl)) +// clang-format off +// *INDENT-OFF* +MAKE_FUNCTION(binary_op_add, x + y, vfadd_vv_f32m8(x, y, vl), vfadd_vf_f32m8(x, y, vl), vfadd_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_sub, x - y, vfsub_vv_f32m8(x, y, vl), vfsub_vf_f32m8(x, y, vl), vfrsub_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_mul, x * y, vfmul_vv_f32m8(x, y, vl), vfmul_vf_f32m8(x, y, vl), vfmul_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_div, x / y, vfdiv_vv_f32m8(x, y, vl), vfdiv_vf_f32m8(x, y, vl), vfrdiv_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_max, std::max(x, y), vfmax_vv_f32m8(x, y, vl), vfmax_vf_f32m8(x, y, vl), vfmax_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_min, std::min(x, y), vfmin_vv_f32m8(x, y, vl), vfmin_vf_f32m8(x, y, vl), vfmin_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y, vl), pow_ps(x, vfmv_v_f_f32m8(y, vl), vl), pow_ps(vfmv_v_f_f32m8(x, vl), y, vl)) +MAKE_FUNCTION(binary_op_rsub, y - x, vfsub_vv_f32m8(y, x, vl), vfrsub_vf_f32m8(x, y, vl), vfsub_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_rdiv, y / x, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x, y, vl), vfdiv_vf_f32m8(y, x, vl)) +MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f32m8(x, vl), vl)) +// *INDENT-ON* +// clang-format on #undef MAKE_FUNCTION } // namespace BinaryOp_riscv_functor -#endif -int BinaryOp_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) { - int elembits = std::max(bottom_blobs[0].elembits(), bottom_blobs[1].elembits()); -#if __riscv_vector && __riscv_zfh - if (opt.use_fp16_storage && elembits == 16) - { - return forward_fp16s(bottom_blobs, top_blobs, opt); - } -#endif - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - -#if __riscv_vector using namespace BinaryOp_riscv_functor; - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; - if (elempack != 1 || elempack1 != 1) - { - if (op_type == Operation_ADD) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_SUB) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_MUL) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_DIV) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_MAX) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_MIN) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_POW) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_RSUB) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RDIV) - return binary_op_rvv(bottom_blob, bottom_blob1, top_blob, opt); - } -#endif +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - return BinaryOp::forward(bottom_blobs, top_blobs, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here + return 0; } -int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +static int get_reverse_op_type(int op_type) { -#if __riscv_vector - int elembits = bottom_top_blob.elembits(); + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; +} -#if __riscv_zfh +int BinaryOp_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int elembits = std::max(bottom_blobs[0].elembits(), bottom_blobs[1].elembits()); + +#if __riscv_vector && __riscv_zfh if (opt.use_fp16_storage && elembits == 16) { - return forward_inplace_fp16s(bottom_top_blob, opt); + return forward_fp16s(bottom_blobs, top_blobs, opt); } #endif - using namespace BinaryOp_riscv_functor; - - if (op_type == Operation_ADD) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_SUB) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_rvv(bottom_top_blob, b, opt); - -#endif - return BinaryOp::forward_inplace(bottom_top_blob, opt); -} - -// fp16sa -#if __riscv_vector && __riscv_zfh -template -static int binary_op_2_3_4_20_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); + } - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) + { + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) { - const __fp16 a0 = ((const __fp16*)a)[0]; - const __fp16* ptr = b.channel(q); - __fp16* outptr = c.channel(q); + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - int n = size; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(a0, _p, vl); - vse16_v_f16m8(outptr, _outp, vl); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - ptr += vl; - outptr += vl; - n -= vl; - } + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); } return 0; } -template -static int binary_op_6_11_16_25_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + int elembits = bottom_top_blob.elembits(); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#if __riscv_zfh + if (opt.use_fp16_storage && elembits == 16) { - const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[0]; - __fp16* outptr = c.channel(q); + return forward_inplace_fp16s(bottom_top_blob, opt); + } +#endif - int n = size; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, b0, vl); - vse16_v_f16m8(outptr, _outp, vl); + using namespace BinaryOp_riscv_functor; - ptr += vl; - outptr += vl; - n -= vl; - } - } + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); return 0; } +#if __riscv_vector && __riscv_zfh template -static int binary_op_7_13_19_29_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar_fp16s(const Mat& a, __fp16 b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); __fp16* outptr = c.channel(q); int n = size; @@ -1243,1444 +817,361 @@ static int binary_op_7_13_19_29_fp16s(const Mat& a, const Mat& b, Mat& c, const { size_t vl = vsetvl_e16m8(n); vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - ptr1 += vl; - outptr += vl; + _p = op(_p, b, vl); + vse16_v_f16m8(outptr, _p, vl); n -= vl; - } - } - - return 0; -} - -template -static int binary_op_rvv_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - int elempack = a.elempack; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack1; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = w * h * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack1; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1((const __fp16*)b + q * elempack); - - int n = size * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; - n -= vl; - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = size * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b; - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - int n = elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, *ptr1, vl); - - vse16_v_f16m8(outptr, _outp, vl); - n -= vl; - ptr += vl; - outptr += vl; - } - ptr1 += 1; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n1 = size1 * elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e16m8(n1); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - int n1 = elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e16m8(n1); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _p = vfmv_v_f_f16m8(*ptr, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - n1 -= vl; - ptr1 += vl; - outptr += vl; - } - ptr += 1; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack1; - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - int n = elempack; - const __fp16* ptr1_vol = ptr1 + x * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1_vol, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; - n -= vl; - ptr1_vol += vl; - } - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - int n = elempack; - const __fp16* ptr_vol = ptr + x * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr_vol, vl); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_p, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - ptr_vol += vl; - n -= vl; - } - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack1; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1((const __fp16*)b + q * elempack); - - int n = size * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - outptr += vl; - ptr += vl; - n -= vl; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n = w1 * h1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 12 - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h; y++) - { - vfloat16m8_t _b0x = vle16_v_f16m8_f16m1(ptr1); - - int n = w * elempack; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); - vfloat16m8_t _outp = op(_p, _b0x, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr += vl; - outptr += vl; - n -= vl; - } - - ptr1 += elempack; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_fp16s(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1((const __fp16*)a + q * elempack); - - int n1 = size1 * elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e16m8(n1); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1((const __fp16*)a + q * elempack); - - int n1 = size1 * elempack1; - while (n1 > 0) - { - size_t vl = vsetvl_e16m8(n1); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n1 -= vl; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const __fp16* ptr = a; - const __fp16* ptr1 = b; - __fp16* outptr = c; - - for (int y = 0; y < h1; y++) - { - vfloat16m8_t _a0x = vle16_v_f16m8_f16m1(ptr); - - int n = w1 * elempack1; - while (n > 0) - { - size_t vl = vsetvl_e16m8(n); - vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); - vfloat16m8_t _outp = op(_a0x, _p1, vl); - vse16_v_f16m8(outptr, _outp, vl); - - ptr1 += vl; - outptr += vl; - n -= vl; - } - - ptr += elempack; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - } - - return 0; -} - -template -static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - - ptr1 += h; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - const __fp16 b0 = ptr1[z]; - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1) - { - // type 25 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[q]; - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0); - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } - - ptr += h1; - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* b0 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], b0[0]); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b; - __fp16* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* a0 = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0[0], ptr1[i]); - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(ptr[i], ptr1[i]); - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } - - ptr += w; - outptr += w; - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } + ptr += vl; + outptr += vl; + } + } - ptr += w; - outptr += w; - } - } + return 0; +} - return 0; - } +template +static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.channel(q); + __fp16* outptr = c.channel(q); - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } + int n = size; + while (n > 0) + { + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + ptr1 += vl; + outptr += vl; + } + } - ptr1 += w1; - outptr += w1; - } - } + return 0; +} - return 0; - } +template +static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + if (a.dims == 2 && b.dims == 1) + { + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) + { + const __fp16* ptr = a.row(y); + __fp16* outptr = c.row<__fp16>(y); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(ptr[x], ptr1[x]); - } + const __fp16 _b = ((const __fp16*)b)[y]; - ptr1 += w1; - outptr += w1; - } - } + const int size = w * elempack; - return 0; + int n = size; + vfloat16m8_t _bx = (elempack == 1) ? vfmv_v_f_f16m8(_b, vsetvl_e16m8(n)) : vle16_v_f16m8_f16m1((const __fp16*)b + y * elempack); + while (n > 0) + { + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _bx, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } - - // type 19 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); } + } - c.create(w, h, channels, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const __fp16* ptr = a.channel(q); - const __fp16* ptr1 = b.row(q); - __fp16* outptr = c.channel(q); + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - for (int y = 0; y < h; y++) - { - const __fp16 b0 = ptr1[y]; - for (int x = 0; x < w; x++) - { - outptr[x] = op(ptr[x], b0); - } + const __fp16 _b = ((const __fp16*)b)[q]; - ptr += w; - outptr += w; - } - } + const int size = w * h * d * elempack; - return 0; + int n = size; + vfloat16m8_t _bx = (elempack == 1) ? vfmv_v_f_f16m8(_b, vsetvl_e16m8(n)) : vle16_v_f16m8_f16m1((const __fp16*)b + q * elempack); + while (n > 0) + { + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _bx, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1) - { - // type 16 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); - } + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + const int size = w * elempack; + + for (int y = 0; y < h; y++) { - const __fp16* ptr = a.channel(q); - const __fp16 b0 = ((const __fp16*)b)[q]; - __fp16* outptr = c.channel(q); + const __fp16 _b = ptr1[y]; - for (int i = 0; i < size; i++) + int n = size; + vfloat16m8_t _bx = (elempack == 1) ? vfmv_v_f_f16m8(_b, vsetvl_e16m8(n)) : vle16_v_f16m8_f16m1(ptr1 + y * elempack); + while (n > 0) { - outptr[i] = op(ptr[i], b0); + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _bx, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } } - - return 0; } } - else if (a.dims == 2) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + const __fp16* ptr1 = b.row(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * h * elempack; + + for (int z = 0; z < d; z++) { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + const __fp16 _b = ptr1[z]; - for (int z = 0; z < d1; z++) + int n = size; + vfloat16m8_t _bx = (elempack == 1) ? vfmv_v_f_f16m8(_b, vsetvl_e16m8(n)) : vle16_v_f16m8_f16m1(ptr1 + z * elempack); + while (n > 0) { - const __fp16 a0 = ptr[z]; - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } - - ptr1 += w1; - outptr += w1; - } + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _bx, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 14 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * elempack; + + for (int z = 0; z < d; z++) { - const __fp16* ptr = a.row(q); - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + const __fp16* ptr1 = b.channel(q).row(z); - for (int y = 0; y < h1; y++) + for (int y = 0; y < h; y++) { - const __fp16 a0 = ptr[y]; - for (int x = 0; x < w1; x++) + const __fp16 _b = ptr1[y]; + + int n = size; + vfloat16m8_t _bx = (elempack == 1) ? vfmv_v_f_f16m8(_b, vsetvl_e16m8(n)) : vle16_v_f16m8_f16m1(ptr1 + y * elempack); + while (n > 0) { - outptr[x] = op(a0, ptr1[x]); + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, _bx, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; } - - ptr1 += w1; - outptr += w1; } } - - return 0; } + } - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + return 0; +} - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29_fp16s(a, b, c, opt); - } +template +static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int d = a.d; + int channels = a.c; + int elempack = a.elempack; - if (b.dims == 1) + if (a.dims == 2) + { + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - c.create(w, h, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.row(y); + const __fp16* ptr1 = b; + __fp16* outptr = c.row<__fp16>(y); - if (b.w == 1) + if (elempack != 1) { - // type 11 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); + for (int x = 0; x < w; x++) + { + int n = elempack; + while (n > 0) + { + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, *ptr1, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } + ptr1 += 1; + } } - - // type 12 - const __fp16* ptr = a; - __fp16* outptr = c; - - for (int y = 0; y < h; y++) + if (elempack == 1) { - const __fp16 b0 = ((const __fp16*)b)[y]; for (int x = 0; x < w; x++) { - outptr[x] = op(ptr[x], b0); + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - ptr += w; - outptr += w; } - - return 0; } } - else if (a.dims == 1) - { - if (a.w == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20_fp16s(a, b, c, opt); - } - if (b.dims == 4) + if (a.dims == 3 || a.dims == 4) + { + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 21 - c.create(w1, h1, d1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const __fp16 a0 = ((const __fp16*)a)[q]; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - outptr[i] = op(a0, ptr1[i]); - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + int y1 = std::min(y, b.h - 1); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const __fp16 a0 = ((const __fp16*)a)[q]; - const __fp16* ptr1 = b.channel(q); - __fp16* outptr = c.channel(q); + const __fp16* ptr1 = b.depth(z1).row(y1); - for (int i = 0; i < size1; i++) - { - outptr[i] = op(a0, ptr1[i]); + if (elempack != 1) + { + for (int x = 0; x < w; x++) + { + int n = elempack; + while (n > 0) + { + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _outp = op(_p, *ptr1, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + outptr += vl; + } + ptr1 += 1; + } + } + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = op(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - - return 0; } + } - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; - - const __fp16* ptr1 = b; - __fp16* outptr = c; + return 0; +} - for (int y = 0; y < h1; y++) - { - const __fp16 a0 = ((const __fp16*)a)[y]; - for (int x = 0; x < w1; x++) - { - outptr[x] = op(a0, ptr1[x]); - } +template +static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += w1; - outptr += w1; - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const __fp16* ptr = a.channel(q); + __fp16* outptr = c.channel(q); - if (b.dims == 1) + for (int y = 0; y < h; y++) { - c.create(w, elemsize, opt.blob_allocator); - if (c.empty()) - return -100; + const __fp16* ptr1 = b.channel(q); + + const int size = w * elempack; - if (b.w == 1) + int n = size; + while (n > 0) { - // type 6 - return binary_op_6_11_16_25_fp16s(a, b, c, opt); + size_t vl = vsetvl_e16m8(n); + vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); + vfloat16m8_t _p1 = vle16_v_f16m8(ptr1, vl); + vfloat16m8_t _outp = op(_p, _p1, vl); + vse16_v_f16m8(outptr, _outp, vl); + n -= vl; + ptr += vl; + ptr1 += vl; + outptr += vl; } - - // type 7 - binary_op_7_13_19_29_fp16s(a, b, c, opt); } } @@ -2688,7 +1179,7 @@ static int binary_op_fp16s(const Mat& a, const Mat& b, Mat& c, const Option& opt } template -static int binary_op_scalar_rvv_fp16s(Mat& a, float b, const Option& opt) +static int binary_op_scalar_inplace_fp16s(Mat& a, __fp16 b, const Option& opt) { Op op; int w = a.w; @@ -2710,7 +1201,6 @@ static int binary_op_scalar_rvv_fp16s(Mat& a, float b, const Option& opt) vfloat16m8_t _p = vle16_v_f16m8(ptr, vl); _p = op(_p, b, vl); vse16_v_f16m8(ptr, _p, vl); - n -= vl; ptr += vl; } @@ -2732,11 +1222,11 @@ namespace BinaryOp_riscv_functor { { \ return IMPLVV; \ } \ - vfloat16m8_t operator()(const vfloat16m8_t& x, const float y, const size_t vl) const \ + vfloat16m8_t operator()(const vfloat16m8_t& x, const __fp16& y, const size_t vl) const \ { \ return IMPLVS; \ } \ - vfloat16m8_t operator()(const float x, const vfloat16m8_t& y, const size_t vl) const \ + vfloat16m8_t operator()(const __fp16& x, const vfloat16m8_t& y, const size_t vl) const \ { \ return IMPLSV; \ } \ @@ -2753,6 +1243,7 @@ MAKE_FUNCTION(binary_op_min_fp16s, std::min(x, y), vfmin_vv_f16m8(x, y, vl), vfm MAKE_FUNCTION(binary_op_pow_fp16s, (__fp16)pow((float)x, (float)y), pow_ps(x, y, vl), pow_ps(x, vfmv_v_f_f16m8(y, vl), vl), pow_ps(vfmv_v_f_f16m8(x, vl), y, vl)) MAKE_FUNCTION(binary_op_rsub_fp16s, y - x, vfsub_vv_f16m8(y, x, vl), vfrsub_vf_f16m8(x, y, vl), vfsub_vf_f16m8(y, x, vl)) MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f16m8(x, y, vl), vfdiv_vf_f16m8(y, x, vl)) +MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow((float)y, (float)x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f16m8(x, vl), vl)) // *INDENT-ON* // clang-format on @@ -2760,74 +1251,165 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f } // namespace BinaryOp_riscv_functor -int BinaryOp_riscv::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +static int binary_op_scalar_fp16s(const Mat& a, __fp16 b, Mat& c, int op_type, const Option& opt) { - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - using namespace BinaryOp_riscv_functor; - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; - if (elempack != 1 || elempack1 != 1) - { - if (op_type == Operation_ADD) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MUL) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_DIV) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MAX) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_MIN) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_POW) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_fp16s(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_RDIV) - return binary_op_rvv_fp16s(bottom_blob, bottom_blob1, top_blob, opt); - } + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (elempack == 1 && elempack1 == 1) - { - if (op_type == Operation_ADD) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_riscv_functor; - if (op_type == Operation_SUB) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_fp16s(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_fp16s(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MUL) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); +int BinaryOp_riscv::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_DIV) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_MAX) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + // B is A scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar_fp16s(A, ((const __fp16*)B)[0], top_blob, op_type_r, opt); + } - if (op_type == Operation_MIN) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) + { + return binary_op_no_broadcast_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_fp16s(bottom_blob, bottom_blob1, top_blob, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer_fp16s(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_fp16s(bottom_blob1, bottom_blob, top_blob, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20_fp16s(A, B, top_blob, op_type_r, opt); } return 0; @@ -2837,32 +1419,16 @@ int BinaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& op { using namespace BinaryOp_riscv_functor; - if (op_type == Operation_ADD) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_SUB) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_rvv_fp16s(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace_fp16s(bottom_top_blob, (__fp16)b, opt); return 0; } diff --git a/src/layer/vulkan/binaryop_vulkan.cpp b/src/layer/vulkan/binaryop_vulkan.cpp index 4a1f7141e..855a76c54 100644 --- a/src/layer/vulkan/binaryop_vulkan.cpp +++ b/src/layer/vulkan/binaryop_vulkan.cpp @@ -29,13 +29,29 @@ BinaryOp_vulkan::BinaryOp_vulkan() pipeline_binaryop_pack4 = 0; pipeline_binaryop_pack8 = 0; - pipeline_binaryop_broadcast = 0; - pipeline_binaryop_broadcast_pack4 = 0; - pipeline_binaryop_broadcast_a1_pack4 = 0; - pipeline_binaryop_broadcast_b1_pack4 = 0; - pipeline_binaryop_broadcast_pack8 = 0; - pipeline_binaryop_broadcast_a1_pack8 = 0; - pipeline_binaryop_broadcast_b1_pack8 = 0; + pipeline_binaryop_broadcast_inner[0] = 0; + pipeline_binaryop_broadcast_inner[1] = 0; + pipeline_binaryop_broadcast_inner_pack4[0] = 0; + pipeline_binaryop_broadcast_inner_pack4[1] = 0; + pipeline_binaryop_broadcast_inner_pack8[0] = 0; + pipeline_binaryop_broadcast_inner_pack8[1] = 0; + pipeline_binaryop_broadcast_outer[0] = 0; + pipeline_binaryop_broadcast_outer[1] = 0; + pipeline_binaryop_broadcast_outer_pack4[0] = 0; + pipeline_binaryop_broadcast_outer_pack4[1] = 0; + pipeline_binaryop_broadcast_outer_pack8[0] = 0; + pipeline_binaryop_broadcast_outer_pack8[1] = 0; +} + +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; } int BinaryOp_vulkan::create_pipeline(const Option& opt) @@ -182,20 +198,33 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) // broadcast if (shape.dims == 0 || broadcast) { + bool a_is_lower = false; + if (shape.dims != 0 && shape1.dims != 0) + { + const bool b_is_scalar = shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack == 1; + const bool a_rank_is_lower = shape_packed.dims < shape1_packed.dims && !b_is_scalar; + const bool a_size_is_lower = shape_packed.w * shape_packed.h * shape_packed.d * shape_packed.c * shape_packed.elempack < shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack; + a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + } + const Mat& A_shape_packed = a_is_lower ? shape1_packed : shape_packed; + const Mat& B_shape_packed = a_is_lower ? shape_packed : shape1_packed; + + const int op_type_r = get_reverse_op_type(op_type); + std::vector specializations(1 + 18); specializations[0].i = op_type; - specializations[1 + 0].i = shape_packed.dims; - specializations[1 + 1].i = shape_packed.w; - specializations[1 + 2].i = shape_packed.h; - specializations[1 + 3].i = shape_packed.d; - specializations[1 + 4].i = shape_packed.c; - specializations[1 + 5].i = shape_packed.cstep; - specializations[1 + 6].i = shape1_packed.dims; - specializations[1 + 7].i = shape1_packed.w; - specializations[1 + 8].i = shape1_packed.h; - specializations[1 + 9].i = shape1_packed.d; - specializations[1 + 10].i = shape1_packed.c; - specializations[1 + 11].i = shape1_packed.cstep; + specializations[1 + 0].i = A_shape_packed.dims; + specializations[1 + 1].i = A_shape_packed.w; + specializations[1 + 2].i = A_shape_packed.h; + specializations[1 + 3].i = A_shape_packed.d; + specializations[1 + 4].i = A_shape_packed.c; + specializations[1 + 5].i = A_shape_packed.cstep; + specializations[1 + 6].i = B_shape_packed.dims; + specializations[1 + 7].i = B_shape_packed.w; + specializations[1 + 8].i = B_shape_packed.h; + specializations[1 + 9].i = B_shape_packed.d; + specializations[1 + 10].i = B_shape_packed.c; + specializations[1 + 11].i = B_shape_packed.cstep; specializations[1 + 12].i = out_shape_packed.dims; specializations[1 + 13].i = out_shape_packed.w; specializations[1 + 14].i = out_shape_packed.h; @@ -203,23 +232,26 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) specializations[1 + 16].i = out_shape_packed.c; specializations[1 + 17].i = out_shape_packed.cstep; - std::vector specializations_broadcast_a1_b1(1 + 15); - specializations_broadcast_a1_b1[0].i = op_type; - specializations_broadcast_a1_b1[1 + 0].i = shape_packed.dims; - specializations_broadcast_a1_b1[1 + 1].i = shape_packed.w; - specializations_broadcast_a1_b1[1 + 2].i = shape_packed.h * shape_packed.d; - specializations_broadcast_a1_b1[1 + 3].i = shape_packed.c; - specializations_broadcast_a1_b1[1 + 4].i = shape_packed.cstep; - specializations_broadcast_a1_b1[1 + 5].i = shape1_packed.dims; - specializations_broadcast_a1_b1[1 + 6].i = shape1_packed.w; - specializations_broadcast_a1_b1[1 + 7].i = shape1_packed.h * shape1_packed.d; - specializations_broadcast_a1_b1[1 + 8].i = shape1_packed.c; - specializations_broadcast_a1_b1[1 + 9].i = shape1_packed.cstep; - specializations_broadcast_a1_b1[1 + 10].i = out_shape_packed.dims; - specializations_broadcast_a1_b1[1 + 11].i = out_shape_packed.w; - specializations_broadcast_a1_b1[1 + 12].i = out_shape_packed.h * out_shape_packed.d; - specializations_broadcast_a1_b1[1 + 13].i = out_shape_packed.c; - specializations_broadcast_a1_b1[1 + 14].i = out_shape_packed.cstep; + std::vector specializations_r(1 + 18); + specializations_r[0].i = op_type_r; + specializations_r[1 + 0].i = A_shape_packed.dims; + specializations_r[1 + 1].i = A_shape_packed.w; + specializations_r[1 + 2].i = A_shape_packed.h; + specializations_r[1 + 3].i = A_shape_packed.d; + specializations_r[1 + 4].i = A_shape_packed.c; + specializations_r[1 + 5].i = A_shape_packed.cstep; + specializations_r[1 + 6].i = B_shape_packed.dims; + specializations_r[1 + 7].i = B_shape_packed.w; + specializations_r[1 + 8].i = B_shape_packed.h; + specializations_r[1 + 9].i = B_shape_packed.d; + specializations_r[1 + 10].i = B_shape_packed.c; + specializations_r[1 + 11].i = B_shape_packed.cstep; + specializations_r[1 + 12].i = out_shape_packed.dims; + specializations_r[1 + 13].i = out_shape_packed.w; + specializations_r[1 + 14].i = out_shape_packed.h; + specializations_r[1 + 15].i = out_shape_packed.d; + specializations_r[1 + 16].i = out_shape_packed.c; + specializations_r[1 + 17].i = out_shape_packed.cstep; Mat local_size_xyz; if (out_shape_packed.dims == 1) @@ -248,59 +280,75 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) } // pack1 - if (shape.dims == 0 || (elempack == 1 && elempack1 == 1)) + if (shape.dims == 0 || (out_elempack == 1)) { - pipeline_binaryop_broadcast = new Pipeline(vkdev); - pipeline_binaryop_broadcast->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast->create(LayerShaderType::binaryop_broadcast, opt, specializations); + pipeline_binaryop_broadcast_inner[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner[0]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations); + + pipeline_binaryop_broadcast_outer[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer[0]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations); + + if (op_type_r != op_type) + { + // sub div pow ... + pipeline_binaryop_broadcast_inner[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner[1]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations_r); + + pipeline_binaryop_broadcast_outer[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer[1]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations_r); + } } // pack4 - if (shape.dims == 0 || (elempack == 4 && elempack1 == 4)) + if (shape.dims == 0 || (out_elempack == 4)) { - pipeline_binaryop_broadcast_pack4 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_pack4->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_pack4->create(LayerShaderType::binaryop_broadcast_pack4, opt, specializations); - } + pipeline_binaryop_broadcast_inner_pack4[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner_pack4[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner_pack4[0]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations); - if (shape.dims == 0 || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 4) - || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 4)) - { - pipeline_binaryop_broadcast_a1_pack4 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_a1_pack4->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_a1_pack4->create(LayerShaderType::binaryop_broadcast_a1_pack4, opt, specializations_broadcast_a1_b1); - } + pipeline_binaryop_broadcast_outer_pack4[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer_pack4[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer_pack4[0]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations); - if (shape.dims == 0 || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 4) - || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 4)) - { - pipeline_binaryop_broadcast_b1_pack4 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_b1_pack4->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_b1_pack4->create(LayerShaderType::binaryop_broadcast_b1_pack4, opt, specializations_broadcast_a1_b1); + if (op_type_r != op_type) + { + // sub div pow ... + pipeline_binaryop_broadcast_inner_pack4[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner_pack4[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner_pack4[1]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations_r); + + pipeline_binaryop_broadcast_outer_pack4[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer_pack4[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer_pack4[1]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations_r); + } } // pack8 - if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && elempack1 == 8)) + if ((opt.use_shader_pack8 && shape.dims == 0) || (out_elempack == 8)) { - pipeline_binaryop_broadcast_pack8 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_pack8->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_pack8->create(LayerShaderType::binaryop_broadcast_pack8, opt, specializations); - } + pipeline_binaryop_broadcast_inner_pack8[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner_pack8[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner_pack8[0]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations); - if ((opt.use_shader_pack8 && shape.dims == 0) || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 8) - || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 8)) - { - pipeline_binaryop_broadcast_a1_pack8 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_a1_pack8->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_a1_pack8->create(LayerShaderType::binaryop_broadcast_a1_pack8, opt, specializations_broadcast_a1_b1); - } + pipeline_binaryop_broadcast_outer_pack8[0] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer_pack8[0]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer_pack8[0]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations); - if ((opt.use_shader_pack8 && shape.dims == 0) || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 8) - || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 8)) - { - pipeline_binaryop_broadcast_b1_pack8 = new Pipeline(vkdev); - pipeline_binaryop_broadcast_b1_pack8->set_optimal_local_size_xyz(local_size_xyz); - pipeline_binaryop_broadcast_b1_pack8->create(LayerShaderType::binaryop_broadcast_b1_pack8, opt, specializations_broadcast_a1_b1); + if (op_type_r != op_type) + { + // sub div pow ... + pipeline_binaryop_broadcast_inner_pack8[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_inner_pack8[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_inner_pack8[1]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations_r); + + pipeline_binaryop_broadcast_outer_pack8[1] = new Pipeline(vkdev); + pipeline_binaryop_broadcast_outer_pack8[1]->set_optimal_local_size_xyz(local_size_xyz); + pipeline_binaryop_broadcast_outer_pack8[1]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations_r); + } } } @@ -318,167 +366,75 @@ int BinaryOp_vulkan::destroy_pipeline(const Option& /*opt*/) delete pipeline_binaryop_pack8; pipeline_binaryop_pack8 = 0; - delete pipeline_binaryop_broadcast; - pipeline_binaryop_broadcast = 0; - - delete pipeline_binaryop_broadcast_pack4; - pipeline_binaryop_broadcast_pack4 = 0; + delete pipeline_binaryop_broadcast_inner[0]; + delete pipeline_binaryop_broadcast_inner[1]; + pipeline_binaryop_broadcast_inner[0] = 0; + pipeline_binaryop_broadcast_inner[1] = 0; - delete pipeline_binaryop_broadcast_a1_pack4; - pipeline_binaryop_broadcast_a1_pack4 = 0; + delete pipeline_binaryop_broadcast_inner_pack4[0]; + delete pipeline_binaryop_broadcast_inner_pack4[1]; + pipeline_binaryop_broadcast_inner_pack4[0] = 0; + pipeline_binaryop_broadcast_inner_pack4[1] = 0; - delete pipeline_binaryop_broadcast_b1_pack4; - pipeline_binaryop_broadcast_b1_pack4 = 0; + delete pipeline_binaryop_broadcast_inner_pack8[0]; + delete pipeline_binaryop_broadcast_inner_pack8[1]; + pipeline_binaryop_broadcast_inner_pack8[0] = 0; + pipeline_binaryop_broadcast_inner_pack8[1] = 0; - delete pipeline_binaryop_broadcast_pack8; - pipeline_binaryop_broadcast_pack8 = 0; + delete pipeline_binaryop_broadcast_outer[0]; + delete pipeline_binaryop_broadcast_outer[1]; + pipeline_binaryop_broadcast_outer[0] = 0; + pipeline_binaryop_broadcast_outer[1] = 0; - delete pipeline_binaryop_broadcast_a1_pack8; - pipeline_binaryop_broadcast_a1_pack8 = 0; + delete pipeline_binaryop_broadcast_outer_pack4[0]; + delete pipeline_binaryop_broadcast_outer_pack4[1]; + pipeline_binaryop_broadcast_outer_pack4[0] = 0; + pipeline_binaryop_broadcast_outer_pack4[1] = 0; - delete pipeline_binaryop_broadcast_b1_pack8; - pipeline_binaryop_broadcast_b1_pack8 = 0; + delete pipeline_binaryop_broadcast_outer_pack8[0]; + delete pipeline_binaryop_broadcast_outer_pack8[1]; + pipeline_binaryop_broadcast_outer_pack8[0] = 0; + pipeline_binaryop_broadcast_outer_pack8[1] = 0; return 0; } int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const { - const VkMat& bottom_blob = bottom_blobs[0]; - const VkMat& bottom_blob1 = bottom_blobs[1]; + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const VkMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const VkMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; VkMat& top_blob = top_blobs[0]; - - // broadcast - if (bottom_blob.dims > bottom_blob1.dims) - { - top_blob.create_like(bottom_blob, opt.blob_vkallocator); - } - else if (bottom_blob.dims < bottom_blob1.dims) - { - top_blob.create_like(bottom_blob1, opt.blob_vkallocator); - } - else // if (bottom_blob.dims == bottom_blob1.dims) - { - if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) - { - top_blob.create_like(bottom_blob, opt.blob_vkallocator); - } - else - { - top_blob.create_like(bottom_blob1, opt.blob_vkallocator); - } - } + top_blob.create_like(A, opt.blob_vkallocator); if (top_blob.empty()) return -100; int out_elempack = top_blob.elempack; std::vector bindings(3); - bindings[0] = bottom_blob; - bindings[1] = bottom_blob1; + bindings[0] = A; + bindings[1] = B; bindings[2] = top_blob; - bool broadcast = true; - if (bottom_blob.dims == bottom_blob1.dims - && bottom_blob.w == bottom_blob1.w - && bottom_blob.h == bottom_blob1.h - && bottom_blob.d == bottom_blob1.d - && bottom_blob.c == bottom_blob1.c - && bottom_blob.elempack == bottom_blob1.elempack) - { - broadcast = false; - } - - if (broadcast) - { - std::vector constants(18); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h; - constants[3].i = bottom_blob.d; - constants[4].i = bottom_blob.c; - constants[5].i = bottom_blob.cstep; - constants[6].i = bottom_blob1.dims; - constants[7].i = bottom_blob1.w; - constants[8].i = bottom_blob1.h; - constants[9].i = bottom_blob1.d; - constants[10].i = bottom_blob1.c; - constants[11].i = bottom_blob1.cstep; - constants[12].i = top_blob.dims; - constants[13].i = top_blob.w; - constants[14].i = top_blob.h; - constants[15].i = top_blob.d; - constants[16].i = top_blob.c; - constants[17].i = top_blob.cstep; - - std::vector constants_broadcast_a1b1(15); - constants_broadcast_a1b1[0].i = bottom_blob.dims; - constants_broadcast_a1b1[1].i = bottom_blob.w; - constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; - constants_broadcast_a1b1[3].i = bottom_blob.c; - constants_broadcast_a1b1[4].i = bottom_blob.cstep; - constants_broadcast_a1b1[5].i = bottom_blob1.dims; - constants_broadcast_a1b1[6].i = bottom_blob1.w; - constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; - constants_broadcast_a1b1[8].i = bottom_blob1.c; - constants_broadcast_a1b1[9].i = bottom_blob1.cstep; - constants_broadcast_a1b1[10].i = top_blob.dims; - constants_broadcast_a1b1[11].i = top_blob.w; - constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; - constants_broadcast_a1b1[13].i = top_blob.c; - constants_broadcast_a1b1[14].i = top_blob.cstep; - - bool broadcast_a1b1 = true; - - const Pipeline* pipeline = 0; - if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) - { - pipeline = pipeline_binaryop_broadcast; - broadcast_a1b1 = false; - } - else - { - if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1) - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; - } - else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1) - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; - } - else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1) - { - // special type 2 - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; - } - else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1) - { - // special type 4 - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; - } - else - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; - broadcast_a1b1 = false; - } - } - - cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); - } - else + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) { std::vector constants(15); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h * bottom_blob.d; - constants[3].i = bottom_blob.c; - constants[4].i = bottom_blob.cstep; - constants[5].i = bottom_blob1.dims; - constants[6].i = bottom_blob1.w; - constants[7].i = bottom_blob1.h * bottom_blob1.d; - constants[8].i = bottom_blob1.c; - constants[9].i = bottom_blob1.cstep; + constants[0].i = A.dims; + constants[1].i = A.w; + constants[2].i = A.h * A.d; + constants[3].i = A.c; + constants[4].i = A.cstep; + constants[5].i = B.dims; + constants[6].i = B.w; + constants[7].i = B.h * B.d; + constants[8].i = B.c; + constants[9].i = B.cstep; constants[10].i = top_blob.dims; constants[11].i = top_blob.w; constants[12].i = top_blob.h * top_blob.d; @@ -490,8 +446,86 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector : pipeline_binaryop; cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + std::vector constants(18); + constants[0].i = A.dims; + constants[1].i = A.w; + constants[2].i = A.h; + constants[3].i = A.d; + constants[4].i = A.c; + constants[5].i = A.cstep; + constants[6].i = B.dims; + constants[7].i = B.w; + constants[8].i = B.h; + constants[9].i = B.d; + constants[10].i = B.c; + constants[11].i = B.cstep; + constants[12].i = top_blob.dims; + constants[13].i = top_blob.w; + constants[14].i = top_blob.h; + constants[15].i = top_blob.d; + constants[16].i = top_blob.c; + constants[17].i = top_blob.cstep; + + const int ri = op_type_r == op_type ? 0 : 1; + + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] + : pipeline_binaryop_broadcast_outer[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] + : pipeline_binaryop_broadcast_inner[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] + : pipeline_binaryop_broadcast_outer[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; } + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] + : pipeline_binaryop_broadcast_inner[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // should never reach here return 0; } @@ -522,141 +556,40 @@ int BinaryOp_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, con int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector& top_blobs, VkCompute& cmd, const Option& opt) const { - const VkImageMat& bottom_blob = bottom_blobs[0]; - const VkImageMat& bottom_blob1 = bottom_blobs[1]; + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || a_size_is_lower; + const VkImageMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const VkImageMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; VkImageMat& top_blob = top_blobs[0]; - - // broadcast - if (bottom_blob.dims > bottom_blob1.dims) - { - top_blob.create_like(bottom_blob, opt.blob_vkallocator); - } - else if (bottom_blob.dims < bottom_blob1.dims) - { - top_blob.create_like(bottom_blob1, opt.blob_vkallocator); - } - else // if (bottom_blob.dims == bottom_blob1.dims) - { - if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) - { - top_blob.create_like(bottom_blob, opt.blob_vkallocator); - } - else - { - top_blob.create_like(bottom_blob1, opt.blob_vkallocator); - } - } + top_blob.create_like(A, opt.blob_vkallocator); if (top_blob.empty()) return -100; int out_elempack = top_blob.elempack; std::vector bindings(3); - bindings[0] = bottom_blob; - bindings[1] = bottom_blob1; + bindings[0] = A; + bindings[1] = B; bindings[2] = top_blob; - bool broadcast = true; - if (bottom_blob.dims == bottom_blob1.dims - && bottom_blob.w == bottom_blob1.w - && bottom_blob.h == bottom_blob1.h - && bottom_blob.d == bottom_blob1.d - && bottom_blob.c == bottom_blob1.c - && bottom_blob.elempack == bottom_blob1.elempack) - { - broadcast = false; - } - - if (broadcast) - { - std::vector constants(18); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h; - constants[3].i = bottom_blob.d; - constants[4].i = bottom_blob.c; - constants[5].i = 0; //bottom_blob.cstep; - constants[6].i = bottom_blob1.dims; - constants[7].i = bottom_blob1.w; - constants[8].i = bottom_blob1.h; - constants[9].i = bottom_blob1.d; - constants[10].i = bottom_blob1.c; - constants[11].i = 0; //bottom_blob1.cstep; - constants[12].i = top_blob.dims; - constants[13].i = top_blob.w; - constants[14].i = top_blob.h; - constants[15].i = top_blob.d; - constants[16].i = top_blob.c; - constants[17].i = 0; //top_blob.cstep; - - std::vector constants_broadcast_a1b1(15); - constants_broadcast_a1b1[0].i = bottom_blob.dims; - constants_broadcast_a1b1[1].i = bottom_blob.w; - constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; - constants_broadcast_a1b1[3].i = bottom_blob.c; - constants_broadcast_a1b1[4].i = 0; //bottom_blob.cstep; - constants_broadcast_a1b1[5].i = bottom_blob1.dims; - constants_broadcast_a1b1[6].i = bottom_blob1.w; - constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; - constants_broadcast_a1b1[8].i = bottom_blob1.c; - constants_broadcast_a1b1[9].i = 0; //bottom_blob1.cstep; - constants_broadcast_a1b1[10].i = top_blob.dims; - constants_broadcast_a1b1[11].i = top_blob.w; - constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; - constants_broadcast_a1b1[13].i = top_blob.c; - constants_broadcast_a1b1[14].i = 0; //top_blob.cstep; - - bool broadcast_a1b1 = true; - - const Pipeline* pipeline = 0; - if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) - { - pipeline = pipeline_binaryop_broadcast; - broadcast_a1b1 = false; - } - else - { - if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1) - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; - } - else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1) - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; - } - else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1) - { - // special type 2 - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; - } - else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1) - { - // special type 4 - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; - } - else - { - pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; - broadcast_a1b1 = false; - } - } - - cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); - } - else + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) { std::vector constants(15); - constants[0].i = bottom_blob.dims; - constants[1].i = bottom_blob.w; - constants[2].i = bottom_blob.h * bottom_blob.d; - constants[3].i = bottom_blob.c; - constants[4].i = 0; //bottom_blob.cstep; - constants[5].i = bottom_blob1.dims; - constants[6].i = bottom_blob1.w; - constants[7].i = bottom_blob1.h * bottom_blob1.d; - constants[8].i = bottom_blob1.c; - constants[9].i = 0; //bottom_blob1.cstep; + constants[0].i = A.dims; + constants[1].i = A.w; + constants[2].i = A.h * A.d; + constants[3].i = A.c; + constants[4].i = 0; //A.cstep; + constants[5].i = B.dims; + constants[6].i = B.w; + constants[7].i = B.h * B.d; + constants[8].i = B.c; + constants[9].i = 0; //B.cstep; constants[10].i = top_blob.dims; constants[11].i = top_blob.w; constants[12].i = top_blob.h * top_blob.d; @@ -668,8 +601,86 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::v : pipeline_binaryop; cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + std::vector constants(18); + constants[0].i = A.dims; + constants[1].i = A.w; + constants[2].i = A.h; + constants[3].i = A.d; + constants[4].i = A.c; + constants[5].i = 0; //A.cstep; + constants[6].i = B.dims; + constants[7].i = B.w; + constants[8].i = B.h; + constants[9].i = B.d; + constants[10].i = B.c; + constants[11].i = 0; //B.cstep; + constants[12].i = top_blob.dims; + constants[13].i = top_blob.w; + constants[14].i = top_blob.h; + constants[15].i = top_blob.d; + constants[16].i = top_blob.c; + constants[17].i = 0; //top_blob.cstep; + + const int ri = op_type_r == op_type ? 0 : 1; + + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] + : pipeline_binaryop_broadcast_outer[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] + : pipeline_binaryop_broadcast_inner[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] + : pipeline_binaryop_broadcast_outer[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; + } + + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] + : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] + : pipeline_binaryop_broadcast_inner[ri]; + + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + + return 0; } + // should never reach here return 0; } diff --git a/src/layer/vulkan/binaryop_vulkan.h b/src/layer/vulkan/binaryop_vulkan.h index f33cc7d54..da3cf5096 100644 --- a/src/layer/vulkan/binaryop_vulkan.h +++ b/src/layer/vulkan/binaryop_vulkan.h @@ -43,13 +43,12 @@ public: Pipeline* pipeline_binaryop_pack8; // broadcast - Pipeline* pipeline_binaryop_broadcast; - Pipeline* pipeline_binaryop_broadcast_pack4; - Pipeline* pipeline_binaryop_broadcast_a1_pack4; - Pipeline* pipeline_binaryop_broadcast_b1_pack4; - Pipeline* pipeline_binaryop_broadcast_pack8; - Pipeline* pipeline_binaryop_broadcast_a1_pack8; - Pipeline* pipeline_binaryop_broadcast_b1_pack8; + Pipeline* pipeline_binaryop_broadcast_inner[2]; + Pipeline* pipeline_binaryop_broadcast_inner_pack4[2]; + Pipeline* pipeline_binaryop_broadcast_inner_pack8[2]; + Pipeline* pipeline_binaryop_broadcast_outer[2]; + Pipeline* pipeline_binaryop_broadcast_outer_pack4[2]; + Pipeline* pipeline_binaryop_broadcast_outer_pack8[2]; }; } // namespace ncnn diff --git a/src/layer/vulkan/shader/binaryop.comp b/src/layer/vulkan/shader/binaryop.comp index 97189beaf..0c2046d2d 100644 --- a/src/layer/vulkan/shader/binaryop.comp +++ b/src/layer/vulkan/shader/binaryop.comp @@ -92,52 +92,48 @@ void main() afp v1 = buffer_ld1(a_blob_data, gi); #endif - afp res; + afp v2; if (with_scalar == 1) { - // type 5 10 15 - afp b = afp(const_b); - - if (op_type == 0) res = v1 + b; - if (op_type == 1) res = v1 - b; - if (op_type == 2) res = v1 * b; - if (op_type == 3) res = v1 / b; - if (op_type == 4) res = max(v1, b); - if (op_type == 5) res = min(v1, b); - if (op_type == 6) res = pow(v1, b); - if (op_type == 7) res = b - v1; - if (op_type == 8) res = b / v1; - + // type 0 1 2 3 + v2 = afp(const_b); + } + else if (psc(bdims) == 1 && psc(bw) == 1) + { + // type 0 1 2 3 #if NCNN_image_shader - image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); + v2 = image3d_ld1(b_blob_3d, ivec3(0, 0, 0)); #else - buffer_st1(a_blob_data, gi, res); + v2 = buffer_ld1(b_blob_data, 0); #endif } else { - // type 7 13 19 + // type 4 5 6 7 #if NCNN_image_shader - afp v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz)); + v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz)); #else - afp v2 = buffer_ld1(b_blob_data, gi); + v2 = buffer_ld1(b_blob_data, gi); #endif + } + + afp res; - if (op_type == 0) res = v1 + v2; - if (op_type == 1) res = v1 - v2; - if (op_type == 2) res = v1 * v2; - if (op_type == 3) res = v1 / v2; - if (op_type == 4) res = max(v1, v2); - if (op_type == 5) res = min(v1, v2); - if (op_type == 6) res = pow(v1, v2); - if (op_type == 7) res = v2 - v1; - if (op_type == 8) res = v2 / v1; + if (op_type == 0) res = v1 + v2; + if (op_type == 1) res = v1 - v2; + if (op_type == 2) res = v1 * v2; + if (op_type == 3) res = v1 / v2; + if (op_type == 4) res = max(v1, v2); + if (op_type == 5) res = min(v1, v2); + if (op_type == 6) res = pow(v1, v2); + if (op_type == 7) res = v2 - v1; + if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); #if NCNN_image_shader - image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); + image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); #else - buffer_st1(top_blob_data, gi, res); + buffer_st1(top_blob_data, gi, res); #endif - } } diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp deleted file mode 100644 index 64d063e65..000000000 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ /dev/null @@ -1,553 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#version 450 - -#if NCNN_fp16_storage -#extension GL_EXT_shader_16bit_storage: require -#endif -#if NCNN_fp16_arithmetic -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#endif - -layout (constant_id = 0) const int op_type = 0; - -#define shape_constant_id_offset 1 -layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; -layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; -layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; -layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; -layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; -layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; - -#if NCNN_image_shader -layout (binding = 0) uniform unfp sampler3D a_blob_3d; -layout (binding = 1) uniform unfp sampler3D b_blob_3d; -layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d; -#else -layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; -layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; }; -layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; }; -#endif - -layout (push_constant) uniform parameter -{ - int adims; - int aw; - int ah; - int ad; - int ac; - int acstep; - - int bdims; - int bw; - int bh; - int bd; - int bc; - int bcstep; - - int outdims; - int outw; - int outh; - int outd; - int outc; - int outcstep; -} p; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x); - int gy = int(gl_GlobalInvocationID.y); - int gz = int(gl_GlobalInvocationID.z); - - if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) - return; - -#if NCNN_image_shader - int ax = gx; - int ay = gy; - int az = gz; - int bx = gx; - int by = gy; - int bz = gz; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - bx = yh; - by = yd; - bz = gz; - } - - if (psc(bdims) == 2) - { - // type 27 - bx = yd; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 26 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ax = yh; - ay = yd; - az = gz; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - bx = 0; - by = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - bz = 0; - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ax = 0; - ay = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - az = 0; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - bx = 0; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - by = 0; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ax = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ay = 0; - } - } - - if (psc(bdims) == 2) - { - // type 18 - bx = gy; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 16 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 17 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ax = yd; - ay = gz; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 14 - ax = gy; - ay = gz; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 11 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 12 - bx = gy; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 1) - { - if (psc(aw) == 1) - { - // type 2 3 4 20 - ax = 0; - ay = 0; - az = 0; - } - else - { - if (psc(bdims) == 4) - { - // type 21 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 9 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 2) - { - // type 8 - ax = gy; - ay = 0; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 6 - bx = 0; - by = 0; - bz = 0; - } - } - } - } - - afp v1 = image3d_ld1(a_blob_3d, ivec3(ax, ay, az)); - afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); -#else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int ai; - int bi; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - ai = gi; - bi = gz * psc(bcstep) + yd * psc(bw) + yh; - } - - if (psc(bdims) == 2) - { - // type 27 - ai = gi; - bi = gz * psc(bw) + yd; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - ai = gi; - bi = 0; - } - else - { - // type 26 - ai = gi; - bi = gz; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ai = gz * psc(acstep) + yd * psc(aw) + yh; - bi = gi; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - ai = gi; - bi = gz * psc(bcstep); - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - ai = gi; - bi = gy * psc(bw) + gx; - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ai = gz * psc(acstep); - bi = gi; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - ai = gy * psc(aw) + gx; - bi = gi; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - ai = gi; - bi = gz * psc(bcstep) + gy; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - ai = gi; - bi = gz * psc(bcstep) + gx; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ai = gz * psc(acstep) + gy; - bi = gi; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ai = gz * psc(acstep) + gx; - bi = gi; - } - } - - if (psc(bdims) == 2) - { - // type 18 - ai = gi; - bi = gz * psc(bw) + gy; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 16 - ai = gi; - bi = 0; - } - else - { - // type 17 - ai = gi; - bi = gz; - } - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ai = gz * psc(aw) + yd; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 14 - ai = gz * psc(aw) + gy; - bi = gi; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 11 - ai = gi; - bi = 0; - } - else - { - // type 12 - ai = gi; - bi = gy; - } - } - } - else if (psc(adims) == 1) - { - if (psc(aw) == 1) - { - // type 2 3 4 20 - ai = 0; - bi = gi; - } - else - { - if (psc(bdims) == 4) - { - // type 21 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 9 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 2) - { - // type 8 - ai = gy; - bi = gi; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 6 - ai = gi; - bi = 0; - } - } - } - } - - afp v1 = buffer_ld1(a_blob_data, ai); - afp v2 = buffer_ld1(b_blob_data, bi); -#endif - - afp res; - - if (op_type == 0) res = v1 + v2; - if (op_type == 1) res = v1 - v2; - if (op_type == 2) res = v1 * v2; - if (op_type == 3) res = v1 / v2; - if (op_type == 4) res = max(v1, v2); - if (op_type == 5) res = min(v1, v2); - if (op_type == 6) res = pow(v1, v2); - if (op_type == 7) res = v2 - v1; - if (op_type == 8) res = v2 / v1; - -#if NCNN_image_shader - image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st1(top_blob_data, gi, res); -#endif -} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp deleted file mode 100644 index f44174803..000000000 --- a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp +++ /dev/null @@ -1,169 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#version 450 - -#if NCNN_fp16_storage -#extension GL_EXT_shader_16bit_storage: require -struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; -#endif -#if NCNN_fp16_arithmetic -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#endif - -layout (constant_id = 0) const int op_type = 0; - -#define shape_constant_id_offset 1 -layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; -layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; -layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; - -#if NCNN_image_shader -layout (binding = 0) uniform unfp sampler3D a_blob_3d; -layout (binding = 1) uniform unfp sampler3D b_blob_3d; -layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; -#else -layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; -layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; }; -layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; -#endif - -layout (push_constant) uniform parameter -{ - int adims; - int aw; - int ah; - int ac; - int acstep; - - int bdims; - int bw; - int bh; - int bc; - int bcstep; - - int outdims; - int outw; - int outh; - int outc; - int outcstep; -} p; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x); - int gy = int(gl_GlobalInvocationID.y); - int gz = int(gl_GlobalInvocationID.z); - - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) - return; - -#if NCNN_image_shader - afpvec4 v1; - - if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(gx, gy, 0))); - } - else - { - v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(0, 0, 0))); - } - - afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); -#else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int ai = 0; - - if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 2 - ai = gy * psc(bw) + gx; - } - - // type 2 3 4 - afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); - afpvec8 v2 = buffer_ld8(b_blob_data, gi); -#endif - - afpvec8 res; - - if (op_type == 0) - { - res[0] = v1 + v2[0]; - res[1] = v1 + v2[1]; - } - if (op_type == 1) - { - res[0] = v1 - v2[0]; - res[1] = v1 - v2[1]; - } - if (op_type == 2) - { - res[0] = v1 * v2[0]; - res[1] = v1 * v2[1]; - } - if (op_type == 3) - { - res[0] = v1 / v2[0]; - res[1] = v1 / v2[1]; - } - if (op_type == 4) - { - res[0] = max(v1, v2[0]); - res[1] = max(v1, v2[1]); - } - if (op_type == 5) - { - res[0] = min(v1, v2[0]); - res[1] = min(v1, v2[1]); - } - if (op_type == 6) - { - res[0] = pow(v1, v2[0]); - res[1] = pow(v1, v2[1]); - } - if (op_type == 7) - { - res[0] = v2[0] - v1; - res[1] = v2[1] - v1; - } - if (op_type == 8) - { - res[0] = v2[0] / v1; - res[1] = v2[1] / v1; - } - -#if NCNN_image_shader - image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st8(top_blob_data, gi, res); -#endif -} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_inner.comp b/src/layer/vulkan/shader/binaryop_broadcast_inner.comp new file mode 100644 index 000000000..f1d850107 --- /dev/null +++ b/src/layer/vulkan/shader/binaryop_broadcast_inner.comp @@ -0,0 +1,193 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int op_type = 0; + +#define shape_constant_id_offset 1 +layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; +layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D a_blob_3d; +layout (binding = 1) uniform unfp sampler3D b_blob_3d; +layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d; +#else +layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; +layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; }; +layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int adims; + int aw; + int ah; + int ad; + int ac; + int acstep; + + int bdims; + int bw; + int bh; + int bd; + int bc; + int bcstep; + + int outdims; + int outw; + int outh; + int outd; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) + return; + + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + int bx = gx; + int by = gy; + int bz = gz; + + if (psc(adims) == psc(bdims)) + { + // explicit broadcast + bx = min(gx, psc(bw) - 1); + by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + bz = min(gz, psc(bc) - 1); + } + else + { + // implicit broadcast + if (psc(adims) == 4) + { + if (psc(bdims) == 3) + { + // type 13 + bx = yh; + by = yd; + bz = gz; + } + + if (psc(bdims) == 2) + { + // type 12 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 11 + bx = gz; + by = 0; + bz = 0; + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 2) + { + // type 10 + bx = gy; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 9 + bx = gz; + by = 0; + bz = 0; + } + } + else // if (psc(adims) == 2) + { + // if (psc(bdims) == 1) + { + // type 8 + bx = gy; + by = 0; + bz = 0; + } + } + } + +#if NCNN_image_shader + afp v1 = image3d_ld1(a_blob_3d, ivec3(gx, gy, gz)); + afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); +#else + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; + + afp v1 = buffer_ld1(a_blob_data, gi); + afp v2 = buffer_ld1(b_blob_data, bi); +#endif + + afp res; + + if (op_type == 0) res = v1 + v2; + if (op_type == 1) res = v1 - v2; + if (op_type == 2) res = v1 * v2; + if (op_type == 3) res = v1 / v2; + if (op_type == 4) res = max(v1, v2); + if (op_type == 5) res = min(v1, v2); + if (op_type == 6) res = pow(v1, v2); + if (op_type == 7) res = v2 - v1; + if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); + +#if NCNN_image_shader + image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); +#else + buffer_st1(top_blob_data, gi, res); +#endif +} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_inner_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_inner_pack4.comp new file mode 100644 index 000000000..b3e502161 --- /dev/null +++ b/src/layer/vulkan/shader/binaryop_broadcast_inner_pack4.comp @@ -0,0 +1,193 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int op_type = 0; + +#define shape_constant_id_offset 1 +layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; +layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D a_blob_3d; +layout (binding = 1) uniform unfp sampler3D b_blob_3d; +layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; +#else +layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; }; +layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; }; +layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int adims; + int aw; + int ah; + int ad; + int ac; + int acstep; + + int bdims; + int bw; + int bh; + int bd; + int bc; + int bcstep; + + int outdims; + int outw; + int outh; + int outd; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) + return; + + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + int bx = gx; + int by = gy; + int bz = gz; + + if (psc(adims) == psc(bdims)) + { + // explicit broadcast + bx = min(gx, psc(bw) - 1); + by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + bz = min(gz, psc(bc) - 1); + } + else + { + // implicit broadcast + if (psc(adims) == 4) + { + if (psc(bdims) == 3) + { + // type 13 + bx = yh; + by = yd; + bz = gz; + } + + if (psc(bdims) == 2) + { + // type 12 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 11 + bx = gz; + by = 0; + bz = 0; + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 2) + { + // type 10 + bx = gy; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 9 + bx = gz; + by = 0; + bz = 0; + } + } + else // if (psc(adims) == 2) + { + // if (psc(bdims) == 1) + { + // type 8 + bx = gy; + by = 0; + bz = 0; + } + } + } + +#if NCNN_image_shader + afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(gx, gy, gz)); + afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(bx, by, bz)); +#else + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; + + afpvec4 v1 = buffer_ld4(a_blob_data, gi); + afpvec4 v2 = buffer_ld4(b_blob_data, bi); +#endif + + afpvec4 res; + + if (op_type == 0) res = v1 + v2; + if (op_type == 1) res = v1 - v2; + if (op_type == 2) res = v1 * v2; + if (op_type == 3) res = v1 / v2; + if (op_type == 4) res = max(v1, v2); + if (op_type == 5) res = min(v1, v2); + if (op_type == 6) res = pow(v1, v2); + if (op_type == 7) res = v2 - v1; + if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); + +#if NCNN_image_shader + image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); +#else + buffer_st4(top_blob_data, gi, res); +#endif +} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_inner_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_inner_pack8.comp new file mode 100644 index 000000000..1297f1630 --- /dev/null +++ b/src/layer/vulkan/shader/binaryop_broadcast_inner_pack8.comp @@ -0,0 +1,234 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +layout (constant_id = 0) const int op_type = 0; + +#define shape_constant_id_offset 1 +layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; +layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D a_blob_3d; +layout (binding = 1) uniform unfp sampler3D b_blob_3d; +layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; +#else +layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; }; +layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; }; +layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int adims; + int aw; + int ah; + int ad; + int ac; + int acstep; + + int bdims; + int bw; + int bh; + int bd; + int bc; + int bcstep; + + int outdims; + int outw; + int outh; + int outd; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x); + int gy = int(gl_GlobalInvocationID.y); + int gz = int(gl_GlobalInvocationID.z); + + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) + return; + + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + int bx = gx; + int by = gy; + int bz = gz; + + if (psc(adims) == psc(bdims)) + { + // explicit broadcast + bx = min(gx, psc(bw) - 1); + by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + bz = min(gz, psc(bc) - 1); + } + else + { + // implicit broadcast + if (psc(adims) == 4) + { + if (psc(bdims) == 3) + { + // type 13 + bx = yh; + by = yd; + bz = gz; + } + + if (psc(bdims) == 2) + { + // type 12 + bx = yd; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 11 + bx = gz; + by = 0; + bz = 0; + } + } + else if (psc(adims) == 3) + { + if (psc(bdims) == 2) + { + // type 10 + bx = gy; + by = gz; + bz = 0; + } + + if (psc(bdims) == 1) + { + // type 9 + bx = gz; + by = 0; + bz = 0; + } + } + else // if (psc(adims) == 2) + { + // if (psc(bdims) == 1) + { + // type 8 + bx = gy; + by = 0; + bz = 0; + } + } + } + +#if NCNN_image_shader + afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(gx, gy, gz)); + afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(bx, by, bz)); +#else + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; + + afpvec8 v1 = buffer_ld8(a_blob_data, gi); + afpvec8 v2 = buffer_ld8(b_blob_data, bi); +#endif + + afpvec8 res; + + if (op_type == 0) + { + res[0] = v1[0] + v2[0]; + res[1] = v1[1] + v2[1]; + } + if (op_type == 1) + { + res[0] = v1[0] - v2[0]; + res[1] = v1[1] - v2[1]; + } + if (op_type == 2) + { + res[0] = v1[0] * v2[0]; + res[1] = v1[1] * v2[1]; + } + if (op_type == 3) + { + res[0] = v1[0] / v2[0]; + res[1] = v1[1] / v2[1]; + } + if (op_type == 4) + { + res[0] = max(v1[0], v2[0]); + res[1] = max(v1[1], v2[1]); + } + if (op_type == 5) + { + res[0] = min(v1[0], v2[0]); + res[1] = min(v1[1], v2[1]); + } + if (op_type == 6) + { + res[0] = pow(v1[0], v2[0]); + res[1] = pow(v1[1], v2[1]); + } + if (op_type == 7) + { + res[0] = v2[0] - v1[0]; + res[1] = v2[1] - v1[1]; + } + if (op_type == 8) + { + res[0] = v2[0] / v1[0]; + res[1] = v2[1] / v1[1]; + } + if (op_type == 9) + { + res[0] = pow(v2[0], v1[0]); + res[1] = pow(v2[1], v1[1]); + } + +#if NCNN_image_shader + image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); +#else + buffer_st8(top_blob_data, gi, res); +#endif +} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_outer.comp similarity index 58% rename from src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp rename to src/layer/vulkan/shader/binaryop_broadcast_outer.comp index aabe12d93..241ecd3f9 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_outer.comp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -27,29 +27,32 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; layout (binding = 1) uniform unfp sampler3D b_blob_3d; -layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; +layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d; #else layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; -layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; }; -layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; }; +layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; }; #endif layout (push_constant) uniform parameter @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -79,40 +85,29 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; -#if NCNN_image_shader - afpvec4 v1; - - if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(gx, gy, 0))); - } - else - { - v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(0, 0, 0))); - } - - afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); -#else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int yd = gy / psc(outh); + int yh = gy % psc(outh); - int ai = 0; + // explicit broadcast + int bx = min(gx, psc(bw) - 1); + int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + int bz = min(gz, psc(bc) - 1); - if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - ai = gy * psc(aw) + gx; - } +#if NCNN_image_shader + afp v1 = image3d_ld1(a_blob_3d, ivec3(gx, gy, gz)); + afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); +#else + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; - // type 2 3 4 - afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); - afpvec4 v2 = buffer_ld4(b_blob_data, gi); + afp v1 = buffer_ld1(a_blob_data, gi); + afp v2 = buffer_ld1(b_blob_data, bi); #endif - afpvec4 res; + afp res; if (op_type == 0) res = v1 + v2; if (op_type == 1) res = v1 - v2; @@ -123,10 +118,11 @@ void main() if (op_type == 6) res = pow(v1, v2); if (op_type == 7) res = v2 - v1; if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); #if NCNN_image_shader - image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); + image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); #else - buffer_st4(top_blob_data, gi, res); + buffer_st1(top_blob_data, gi, res); #endif } diff --git a/src/layer/vulkan/shader/binaryop_broadcast_b1_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_outer_pack4.comp similarity index 69% rename from src/layer/vulkan/shader/binaryop_broadcast_b1_pack4.comp rename to src/layer/vulkan/shader/binaryop_broadcast_outer_pack4.comp index 9a1c31b76..9b65e81a5 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_b1_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_outer_pack4.comp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -27,20 +27,23 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -79,35 +85,24 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; -#if NCNN_image_shader - afpvec4 v2; - - if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0))); - } - else - { - v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0))); - } + int yd = gy / psc(outh); + int yh = gy % psc(outh); + + // explicit broadcast + int bx = min(gx, psc(bw) - 1); + int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + int bz = min(gz, psc(bc) - 1); +#if NCNN_image_shader afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(gx, gy, gz)); + afpvec4 v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(bx, by, bz))); #else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int bi = 0; - - if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - bi = gy * psc(bw) + gx; - } + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; - // type 6 11 16 afpvec4 v1 = buffer_ld4(a_blob_data, gi); afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi)); #endif @@ -123,6 +118,7 @@ void main() if (op_type == 6) res = pow(v1, v2); if (op_type == 7) res = v2 - v1; if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/binaryop_broadcast_b1_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_outer_pack8.comp similarity index 58% rename from src/layer/vulkan/shader/binaryop_broadcast_b1_pack8.comp rename to src/layer/vulkan/shader/binaryop_broadcast_outer_pack8.comp index e2c98395f..d6a6a91dd 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_b1_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_outer_pack8.comp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -28,20 +28,23 @@ layout (constant_id = 0) const int op_type = 0; layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; +layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; +layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; +layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; + +layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; +layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; +layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; +layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; +layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; +layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; + +layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; +layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; #if NCNN_image_shader layout (binding = 0) uniform unfp sampler3D a_blob_3d; @@ -58,18 +61,21 @@ layout (push_constant) uniform parameter int adims; int aw; int ah; + int ad; int ac; int acstep; int bdims; int bw; int bh; + int bd; int bc; int bcstep; int outdims; int outw; int outh; + int outd; int outc; int outcstep; } p; @@ -80,85 +86,82 @@ void main() int gy = int(gl_GlobalInvocationID.y); int gz = int(gl_GlobalInvocationID.z); - if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) + if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) return; -#if NCNN_image_shader - afpvec4 v2; + int yd = gy / psc(outh); + int yh = gy % psc(outh); - if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0))); - } - else - { - v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0))); - } + // explicit broadcast + int bx = min(gx, psc(bw) - 1); + int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); + int bz = min(gz, psc(bc) - 1); +#if NCNN_image_shader afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(gx, gy, gz)); + afp b = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); #else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int bi = 0; + int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int bi = bz * psc(bcstep) + by * psc(bw) + bx; - if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - bi = gy * psc(bw) + gx; - } - - // type 6 11 16 afpvec8 v1 = buffer_ld8(a_blob_data, gi); - afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi)); + afp b = buffer_ld1(b_blob_data, bi); #endif + afpvec8 v2; + v2[0] = afpvec4(b); + v2[1] = afpvec4(b); afpvec8 res; if (op_type == 0) { - res[0] = v1[0] + v2; - res[1] = v1[1] + v2; + res[0] = v1[0] + v2[0]; + res[1] = v1[1] + v2[1]; } if (op_type == 1) { - res[0] = v1[0] - v2; - res[1] = v1[1] - v2; + res[0] = v1[0] - v2[0]; + res[1] = v1[1] - v2[1]; } if (op_type == 2) { - res[0] = v1[0] * v2; - res[1] = v1[1] * v2; + res[0] = v1[0] * v2[0]; + res[1] = v1[1] * v2[1]; } if (op_type == 3) { - res[0] = v1[0] / v2; - res[1] = v1[1] / v2; + res[0] = v1[0] / v2[0]; + res[1] = v1[1] / v2[1]; } if (op_type == 4) { - res[0] = max(v1[0], v2); - res[1] = max(v1[1], v2); + res[0] = max(v1[0], v2[0]); + res[1] = max(v1[1], v2[1]); } if (op_type == 5) { - res[0] = min(v1[0], v2); - res[1] = min(v1[1], v2); + res[0] = min(v1[0], v2[0]); + res[1] = min(v1[1], v2[1]); } if (op_type == 6) { - res[0] = pow(v1[0], v2); - res[1] = pow(v1[1], v2); + res[0] = pow(v1[0], v2[0]); + res[1] = pow(v1[1], v2[1]); } if (op_type == 7) { - res[0] = v2 - v1[0]; - res[1] = v2 - v1[1]; + res[0] = v2[0] - v1[0]; + res[1] = v2[1] - v1[1]; } if (op_type == 8) { - res[0] = v2 / v1[0]; - res[1] = v2 / v1[1]; + res[0] = v2[0] / v1[0]; + res[1] = v2[1] / v1[1]; + } + if (op_type == 9) + { + res[0] = pow(v2[0], v1[0]); + res[1] = pow(v2[1], v1[1]); } #if NCNN_image_shader diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp deleted file mode 100644 index 1f71ae1ea..000000000 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ /dev/null @@ -1,502 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#version 450 - -#if NCNN_fp16_storage -#extension GL_EXT_shader_16bit_storage: require -#endif -#if NCNN_fp16_arithmetic -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#endif - -layout (constant_id = 0) const int op_type = 0; - -#define shape_constant_id_offset 1 -layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; -layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; -layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; -layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; -layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; -layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; - -#if NCNN_image_shader -layout (binding = 0) uniform unfp sampler3D a_blob_3d; -layout (binding = 1) uniform unfp sampler3D b_blob_3d; -layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; -#else -layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; }; -layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; }; -layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; -#endif - -layout (push_constant) uniform parameter -{ - int adims; - int aw; - int ah; - int ad; - int ac; - int acstep; - - int bdims; - int bw; - int bh; - int bd; - int bc; - int bcstep; - - int outdims; - int outw; - int outh; - int outd; - int outc; - int outcstep; -} p; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x); - int gy = int(gl_GlobalInvocationID.y); - int gz = int(gl_GlobalInvocationID.z); - - if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) - return; - -#if NCNN_image_shader - int ax = gx; - int ay = gy; - int az = gz; - int bx = gx; - int by = gy; - int bz = gz; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - bx = yh; - by = yd; - bz = gz; - } - - if (psc(bdims) == 2) - { - // type 27 - bx = yd; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 26 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ax = yh; - ay = yd; - az = gz; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - bx = 0; - by = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - bz = 0; - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ax = 0; - ay = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - az = 0; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - bx = 0; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - by = 0; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ax = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ay = 0; - } - } - - if (psc(bdims) == 2) - { - // type 18 - bx = gy; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 16 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 17 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ax = yd; - ay = gz; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 14 - ax = gy; - ay = gz; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 11 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 12 - bx = gy; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 1) - { - if (psc(aw) == 1) - { - // type 2 3 4 20 - ax = 0; - ay = 0; - az = 0; - } - else - { - if (psc(bdims) == 4) - { - // type 21 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 9 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 2) - { - // type 8 - ax = gy; - ay = 0; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 6 - bx = 0; - by = 0; - bz = 0; - } - } - } - } - - afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(ax, ay, az)); - afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(bx, by, bz)); -#else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int ai; - int bi; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - ai = gi; - bi = gz * psc(bcstep) + yd * psc(bw) + yh; - } - - if (psc(bdims) == 2) - { - // type 27 - ai = gi; - bi = gz * psc(bw) + yd; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - ai = gi; - bi = 0; - } - else - { - // type 26 - ai = gi; - bi = gz; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ai = gz * psc(acstep) + yd * psc(aw) + yh; - bi = gi; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - ai = gi; - bi = gz * psc(bcstep); - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ai = gz * psc(acstep); - bi = gi; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - ai = gi; - bi = gz * psc(bcstep) + gy; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - ai = gi; - bi = gz * psc(bcstep) + gx; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ai = gz * psc(acstep) + gy; - bi = gi; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ai = gz * psc(acstep) + gx; - bi = gi; - } - } - - if (psc(bdims) == 2) - { - // type 18 - ai = gi; - bi = gz * psc(bw) + gy; - } - - if (psc(bdims) == 1) - { - // type 17 - ai = gi; - bi = gz; - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ai = gz * psc(aw) + yd; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 14 - ai = gz * psc(aw) + gy; - bi = gi; - } - - if (psc(bdims) == 1) - { - // type 12 - ai = gi; - bi = gy; - } - } - else if (psc(adims) == 1) - { - if (psc(bdims) == 4) - { - // type 21 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 9 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 2) - { - // type 8 - ai = gy; - bi = gi; - } - } - - afpvec4 v1 = buffer_ld4(a_blob_data, ai); - afpvec4 v2 = buffer_ld4(b_blob_data, bi); -#endif - - afpvec4 res; - - if (op_type == 0) res = v1 + v2; - if (op_type == 1) res = v1 - v2; - if (op_type == 2) res = v1 * v2; - if (op_type == 3) res = v1 / v2; - if (op_type == 4) res = max(v1, v2); - if (op_type == 5) res = min(v1, v2); - if (op_type == 6) res = pow(v1, v2); - if (op_type == 7) res = v2 - v1; - if (op_type == 8) res = v2 / v1; - -#if NCNN_image_shader - image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st4(top_blob_data, gi, res); -#endif -} diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp deleted file mode 100644 index 41d00199b..000000000 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp +++ /dev/null @@ -1,539 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#version 450 - -#if NCNN_fp16_storage -#extension GL_EXT_shader_16bit_storage: require -struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; -#endif -#if NCNN_fp16_arithmetic -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#endif - -layout (constant_id = 0) const int op_type = 0; - -#define shape_constant_id_offset 1 -layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; -layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; -layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; -layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; -layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; -layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; - -layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; -layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; -layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; -layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; -layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; -layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; - -layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; -layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; -layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; -layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; -layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; -layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; - -#if NCNN_image_shader -layout (binding = 0) uniform unfp sampler3D a_blob_3d; -layout (binding = 1) uniform unfp sampler3D b_blob_3d; -layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; -#else -layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; }; -layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; }; -layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; -#endif - -layout (push_constant) uniform parameter -{ - int adims; - int aw; - int ah; - int ad; - int ac; - int acstep; - - int bdims; - int bw; - int bh; - int bd; - int bc; - int bcstep; - - int outdims; - int outw; - int outh; - int outd; - int outc; - int outcstep; -} p; - -void main() -{ - int gx = int(gl_GlobalInvocationID.x); - int gy = int(gl_GlobalInvocationID.y); - int gz = int(gl_GlobalInvocationID.z); - - if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) - return; - -#if NCNN_image_shader - int ax = gx; - int ay = gy; - int az = gz; - int bx = gx; - int by = gy; - int bz = gz; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - bx = yh; - by = yd; - bz = gz; - } - - if (psc(bdims) == 2) - { - // type 27 - bx = yd; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 26 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ax = yh; - ay = yd; - az = gz; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - bx = 0; - by = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) - { - // special type 2 - bz = 0; - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ax = 0; - ay = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) - { - // special type 4 - az = 0; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - bx = 0; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - by = 0; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ax = 0; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ay = 0; - } - } - - if (psc(bdims) == 2) - { - // type 18 - bx = gy; - by = gz; - bz = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 16 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 17 - bx = gz; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ax = yd; - ay = gz; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 14 - ax = gy; - ay = gz; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 11 - bx = 0; - by = 0; - bz = 0; - } - else - { - // type 12 - bx = gy; - by = 0; - bz = 0; - } - } - } - else if (psc(adims) == 1) - { - if (psc(aw) == 1) - { - // type 2 3 4 20 - ax = 0; - ay = 0; - az = 0; - } - else - { - if (psc(bdims) == 4) - { - // type 21 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 3) - { - // type 9 - ax = gz; - ay = 0; - az = 0; - } - - if (psc(bdims) == 2) - { - // type 8 - ax = gy; - ay = 0; - az = 0; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 6 - bx = 0; - by = 0; - bz = 0; - } - } - } - } - - afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(ax, ay, az)); - afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(bx, by, bz)); -#else - const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; - - int ai; - int bi; - - if (psc(adims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - if (psc(bdims) == 3) - { - // type 28 - ai = gi; - bi = gz * psc(bcstep) + yd * psc(bw) + yh; - } - - if (psc(bdims) == 2) - { - // type 27 - ai = gi; - bi = gz * psc(bw) + yd; - } - - if (psc(bdims) == 1) - { - if (psc(bw) == 1) - { - // type 25 - ai = gi; - bi = 0; - } - else - { - // type 26 - ai = gi; - bi = gz; - } - } - } - else if (psc(adims) == 3) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 23 - ai = gz * psc(acstep) + yd * psc(aw) + yh; - bi = gi; - } - - if (psc(bdims) == 3) - { - if (psc(bw) == 1 && psc(bh) == 1) - { - // special type 1 - ai = gi; - bi = gz * psc(bcstep); - } - - if (psc(aw) == 1 && psc(ah) == 1) - { - // special type 3 - ai = gz * psc(acstep); - bi = gi; - } - - if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 5 - bi = gz * psc(bcstep) + gy; - ai = gi; - } - - if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) - { - // special type 6 - bi = gz * psc(bcstep) + gx; - ai = gi; - } - - if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) - { - // special type 7 - ai = gz * psc(acstep) + gy; - bi = gi; - } - - if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) - { - // special type 8 - ai = gz * psc(acstep) + gx; - bi = gi; - } - } - - if (psc(bdims) == 2) - { - // type 18 - ai = gi; - bi = gz * psc(bw) + gy; - } - - if (psc(bdims) == 1) - { - // type 17 - ai = gi; - bi = gz; - } - } - else if (psc(adims) == 2) - { - if (psc(bdims) == 4) - { - int yd = gy / psc(outh); - int yh = gy % psc(outh); - - // type 22 - ai = gz * psc(aw) + yd; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 14 - ai = gz * psc(aw) + gy; - bi = gi; - } - - if (psc(bdims) == 1) - { - // type 12 - ai = gi; - bi = gy; - } - } - else if (psc(adims) == 1) - { - if (psc(bdims) == 4) - { - // type 21 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 3) - { - // type 9 - ai = gz; - bi = gi; - } - - if (psc(bdims) == 2) - { - // type 8 - ai = gy; - bi = gi; - } - } - - afpvec8 v1 = buffer_ld8(a_blob_data, ai); - afpvec8 v2 = buffer_ld8(b_blob_data, bi); -#endif - - afpvec8 res; - - if (op_type == 0) - { - res[0] = v1[0] + v2[0]; - res[1] = v1[1] + v2[1]; - } - if (op_type == 1) - { - res[0] = v1[0] - v2[0]; - res[1] = v1[1] - v2[1]; - } - if (op_type == 2) - { - res[0] = v1[0] * v2[0]; - res[1] = v1[1] * v2[1]; - } - if (op_type == 3) - { - res[0] = v1[0] / v2[0]; - res[1] = v1[1] / v2[1]; - } - if (op_type == 4) - { - res[0] = max(v1[0], v2[0]); - res[1] = max(v1[1], v2[1]); - } - if (op_type == 5) - { - res[0] = min(v1[0], v2[0]); - res[1] = min(v1[1], v2[1]); - } - if (op_type == 6) - { - res[0] = pow(v1[0], v2[0]); - res[1] = pow(v1[1], v2[1]); - } - if (op_type == 7) - { - res[0] = v2[0] - v1[0]; - res[1] = v2[1] - v1[1]; - } - if (op_type == 8) - { - res[0] = v2[0] / v1[0]; - res[1] = v2[1] / v1[1]; - } - -#if NCNN_image_shader - image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st8(top_blob_data, gi, res); -#endif -} diff --git a/src/layer/vulkan/shader/binaryop_pack4.comp b/src/layer/vulkan/shader/binaryop_pack4.comp index b19505520..ea55e302f 100644 --- a/src/layer/vulkan/shader/binaryop_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_pack4.comp @@ -92,52 +92,39 @@ void main() afpvec4 v1 = buffer_ld4(a_blob_data, gi); #endif - afpvec4 res; + afpvec4 v2; if (with_scalar == 1) { - // type 5 10 15 - afp b = afp(const_b); - - if (op_type == 0) res = v1 + b; - if (op_type == 1) res = v1 - b; - if (op_type == 2) res = v1 * b; - if (op_type == 3) res = v1 / b; - if (op_type == 4) res = max(v1, b); - if (op_type == 5) res = min(v1, b); - if (op_type == 6) res = pow(v1, afpvec4(b)); - if (op_type == 7) res = b - v1; - if (op_type == 8) res = b / v1; - -#if NCNN_image_shader - image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st4(a_blob_data, gi, res); -#endif + // type 0 1 2 3 + v2 = afpvec4(const_b); } else { - // type 7 13 19 + // type 4 5 6 7 #if NCNN_image_shader - afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); + v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); #else - afpvec4 v2 = buffer_ld4(b_blob_data, gi); + v2 = buffer_ld4(b_blob_data, gi); #endif + } + + afpvec4 res; - if (op_type == 0) res = v1 + v2; - if (op_type == 1) res = v1 - v2; - if (op_type == 2) res = v1 * v2; - if (op_type == 3) res = v1 / v2; - if (op_type == 4) res = max(v1, v2); - if (op_type == 5) res = min(v1, v2); - if (op_type == 6) res = pow(v1, v2); - if (op_type == 7) res = v2 - v1; - if (op_type == 8) res = v2 / v1; + if (op_type == 0) res = v1 + v2; + if (op_type == 1) res = v1 - v2; + if (op_type == 2) res = v1 * v2; + if (op_type == 3) res = v1 / v2; + if (op_type == 4) res = max(v1, v2); + if (op_type == 5) res = min(v1, v2); + if (op_type == 6) res = pow(v1, v2); + if (op_type == 7) res = v2 - v1; + if (op_type == 8) res = v2 / v1; + if (op_type == 9) res = pow(v2, v1); #if NCNN_image_shader - image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); + image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); #else - buffer_st4(top_blob_data, gi, res); + buffer_st4(top_blob_data, gi, res); #endif - } } diff --git a/src/layer/vulkan/shader/binaryop_pack8.comp b/src/layer/vulkan/shader/binaryop_pack8.comp index 010187918..67e59b6e3 100644 --- a/src/layer/vulkan/shader/binaryop_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_pack8.comp @@ -93,124 +93,80 @@ void main() afpvec8 v1 = buffer_ld8(a_blob_data, gi); #endif - afpvec8 res; + afpvec8 v2; if (with_scalar == 1) { - // type 5 10 15 - afp b = afp(const_b); - - if (op_type == 0) - { - res[0] = v1[0] + b; - res[1] = v1[1] + b; - } - if (op_type == 1) - { - res[0] = v1[0] - b; - res[1] = v1[1] - b; - } - if (op_type == 2) - { - res[0] = v1[0] * b; - res[1] = v1[1] * b; - } - if (op_type == 3) - { - res[0] = v1[0] / b; - res[1] = v1[1] / b; - } - if (op_type == 4) - { - res[0] = max(v1[0], b); - res[1] = max(v1[1], b); - } - if (op_type == 5) - { - res[0] = min(v1[0], b); - res[1] = min(v1[1], b); - } - if (op_type == 6) - { - res[0] = pow(v1[0], afpvec4(b)); - res[1] = pow(v1[1], afpvec4(b)); - } - if (op_type == 7) - { - res[0] = b - v1[0]; - res[1] = b - v1[1]; - } - if (op_type == 8) - { - res[0] = b / v1[0]; - res[1] = b / v1[1]; - } - -#if NCNN_image_shader - image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); -#else - buffer_st8(a_blob_data, gi, res); -#endif + // type 0 1 2 3 + v2[0] = afpvec4(const_b); + v2[1] = afpvec4(const_b); } else { - // type 7 13 19 + // type 4 5 6 7 #if NCNN_image_shader - afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); + v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); #else - afpvec8 v2 = buffer_ld8(b_blob_data, gi); + v2 = buffer_ld8(b_blob_data, gi); #endif + } - if (op_type == 0) - { - res[0] = v1[0] + v2[0]; - res[1] = v1[1] + v2[1]; - } - if (op_type == 1) - { - res[0] = v1[0] - v2[0]; - res[1] = v1[1] - v2[1]; - } - if (op_type == 2) - { - res[0] = v1[0] * v2[0]; - res[1] = v1[1] * v2[1]; - } - if (op_type == 3) - { - res[0] = v1[0] / v2[0]; - res[1] = v1[1] / v2[1]; - } - if (op_type == 4) - { - res[0] = max(v1[0], v2[0]); - res[1] = max(v1[1], v2[1]); - } - if (op_type == 5) - { - res[0] = min(v1[0], v2[0]); - res[1] = min(v1[1], v2[1]); - } - if (op_type == 6) - { - res[0] = pow(v1[0], v2[0]); - res[1] = pow(v1[1], v2[1]); - } - if (op_type == 7) - { - res[0] = v2[0] - v1[0]; - res[1] = v2[1] - v1[1]; - } - if (op_type == 8) - { - res[0] = v2[0] / v1[0]; - res[1] = v2[1] / v1[1]; - } + afpvec8 res; + + if (op_type == 0) + { + res[0] = v1[0] + v2[0]; + res[1] = v1[1] + v2[1]; + } + if (op_type == 1) + { + res[0] = v1[0] - v2[0]; + res[1] = v1[1] - v2[1]; + } + if (op_type == 2) + { + res[0] = v1[0] * v2[0]; + res[1] = v1[1] * v2[1]; + } + if (op_type == 3) + { + res[0] = v1[0] / v2[0]; + res[1] = v1[1] / v2[1]; + } + if (op_type == 4) + { + res[0] = max(v1[0], v2[0]); + res[1] = max(v1[1], v2[1]); + } + if (op_type == 5) + { + res[0] = min(v1[0], v2[0]); + res[1] = min(v1[1], v2[1]); + } + if (op_type == 6) + { + res[0] = pow(v1[0], v2[0]); + res[1] = pow(v1[1], v2[1]); + } + if (op_type == 7) + { + res[0] = v2[0] - v1[0]; + res[1] = v2[1] - v1[1]; + } + if (op_type == 8) + { + res[0] = v2[0] / v1[0]; + res[1] = v2[1] / v1[1]; + } + if (op_type == 9) + { + res[0] = pow(v2[0], v1[0]); + res[1] = pow(v2[1], v1[1]); + } #if NCNN_image_shader - image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); + image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); #else - buffer_st8(top_blob_data, gi, res); + buffer_st8(top_blob_data, gi, res); #endif - } } diff --git a/src/layer/x86/binaryop_x86.cpp b/src/layer/x86/binaryop_x86.cpp index 6d6a09fd7..4dbe4e03b 100644 --- a/src/layer/x86/binaryop_x86.cpp +++ b/src/layer/x86/binaryop_x86.cpp @@ -38,137 +38,58 @@ BinaryOp_x86::BinaryOp_x86() } template -static int binary_op_2_3_4_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_scalar(const Mat& a, float b, Mat& c, const Option& opt) { Op op; - int w = b.w; - int h = b.h; - int d = b.d; - int channels = b.c; - int elempack = b.elempack; - int size = w * h * d * elempack; - - // type 2 3 4 20 - c.create_like(b, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float a0 = a[0]; - const float* ptr = b.channel(q); - float* outptr = c.channel(q); - - int i = 0; -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - __m512 _a0_avx512 = _mm512_set1_ps(a0); - for (; i + 15 < size; i += 16) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_a0_avx512, _p); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } -#endif // __AVX512F__ - __m256 _a0_avx = _mm256_set1_ps(a0); - for (; i + 7 < size; i += 8) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_a0_avx, _p); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } -#endif // __AVX__ - __m128 _a0 = _mm_set1_ps(a0); - for (; i + 3 < size; i += 4) - { - __m128 _p = _mm_load_ps(ptr); - __m128 _outp = op.func_pack4(_a0, _p); - _mm_store_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } -#endif // __SSE2__ - for (; i < size; i++) - { - *outptr = op.func(a0, *ptr); - ptr += 1; - outptr += 1; - } - } - - return 0; -} - -template -static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 6 11 16 25 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - const float b0 = b[0]; float* outptr = c.channel(q); int i = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ - __m512 _b0_avx512 = _mm512_set1_ps(b0); + __m512 _b_avx512 = _mm512_set1_ps(b); for (; i + 15 < size; i += 16) { __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0_avx512); - _mm512_storeu_ps(outptr, _outp); + _p = op.func_pack16(_p, _b_avx512); + _mm512_storeu_ps(outptr, _p); ptr += 16; outptr += 16; } #endif // __AVX512F__ - __m256 _b0_avx = _mm256_set1_ps(b0); + __m256 _b_avx = _mm256_set1_ps(b); for (; i + 7 < size; i += 8) { __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0_avx); - _mm256_storeu_ps(outptr, _outp); + _p = op.func_pack8(_p, _b_avx); + _mm256_storeu_ps(outptr, _p); ptr += 8; outptr += 8; } #endif // __AVX__ - __m128 _b0 = _mm_set1_ps(b0); + __m128 _b = _mm_set1_ps(b); for (; i + 3 < size; i += 4) { __m128 _p = _mm_load_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_store_ps(outptr, _outp); + _p = op.func_pack4(_p, _b); + _mm_store_ps(outptr, _p); ptr += 4; outptr += 4; } #endif // __SSE2__ for (; i < size; i++) { - *outptr = op.func(*ptr, b0); - ptr += 1; - outptr += 1; + *outptr = op.func(*ptr, b); + ptr++; + outptr++; } } @@ -176,21 +97,12 @@ static int binary_op_6_11_16_25(const Mat& a, const Mat& b, Mat& c, const Option } template -static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; - - // type 7 13 19 29 - c.create_like(a, opt.blob_allocator); - if (c.empty()) - return -100; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -248,14 +160,8 @@ static int binary_op_7_13_19_29(const Mat& a, const Mat& b, Mat& c, const Option return 0; } -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ -// broadcasting rule -// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - template -static int binary_op_pack16(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -263,697 +169,339 @@ static int binary_op_pack16(const Mat& a, const Mat& b, Mat& c, const Option& op int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2 && b.dims == 1) { - if (b.dims == 4) + // type 8 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.row(y); + float* outptr = c.row(y); - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = b[y]; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps((const float*)b + y * 4) : _mm_set1_ps(_b); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps((const float*)b + y * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps((const float*)b + y * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - __m512 _b0 = _mm512_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } + const int size = w * elempack; - ptr1 += 16; - } - } + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _outp = op.func_pack16(_p, _b_512); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op.func_pack8(_p, _b_256); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op.func_pack4(_p, _b_128); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = op.func(*ptr, _b); + ptr += 1; + outptr += 1; } - - return 0; } + } - if (b.dims == 2) + if ((a.dims == 3 || a.dims == 4) && b.dims == 1) + { + // type 9 11 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - for (int z = 0; z < d; z++) - { - __m512 _b0 = _mm512_loadu_ps(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } - } + const float _b = b[q]; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps((const float*)b + q * 4) : _mm_set1_ps(_b); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps((const float*)b + q * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps((const float*)b + q * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ - ptr1 += 16; - } - } + const int size = w * h * d * elempack; - return 0; + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _outp = op.func_pack16(_p, _b_512); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op.func_pack8(_p, _b_256); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op.func_pack4(_p, _b_128); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = op.func(*ptr, _b); + ptr += 1; + outptr += 1; + } } + } - if (b.dims == 1) + if (a.dims == 3 && b.dims == 2) + { + // type 10 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + const int size = w * elempack; + + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - __m512 _b0 = _mm512_loadu_ps((const float*)b + q * 16); - float* outptr = c.channel(q); + const float _b = ptr1[y]; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps((const float*)ptr1 + y * 4) : _mm_set1_ps(_b); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps((const float*)ptr1 + y * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps((const float*)ptr1 + y * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ - for (int i = 0; i < size; i++) + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) { __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); + __m512 _outp = op.func_pack16(_p, _b_512); _mm512_storeu_ps(outptr, _outp); ptr += 16; outptr += 16; } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op.func_pack8(_p, _b_256); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op.func_pack4(_p, _b_128); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = op.func(*ptr, _b); + ptr += 1; + outptr += 1; + } } - - return 0; } } - else if (a.dims == 3) + + if (a.dims == 4 && b.dims == 2) { - if (b.dims == 4) + // type 12 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + const float* ptr1 = b.row(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + const int size = w * h * elempack; + + for (int z = 0; z < d; z++) { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float _b = ptr1[z]; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps((const float*)ptr1 + z * 4) : _mm_set1_ps(_b); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps((const float*)ptr1 + z * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps((const float*)ptr1 + z * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ - for (int z = 0; z < d1; z++) + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) { - for (int y = 0; y < h1; y++) - { - __m512 _a0 = _mm512_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m512 _p = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; - } - - ptr += 16; - } + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _outp = op.func_pack16(_p, _b_512); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + outptr += 16; + } +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op.func_pack8(_p, _b_256); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; + } +#endif // __AVX__ + for (; i + 3 < size; i += 4) + { + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op.func_pack4(_p, _b_128); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; + } +#endif // __SSE2__ + for (; i < size; i++) + { + *outptr = op.func(*ptr, _b); + ptr += 1; + outptr += 1; } } - - return 0; } + } - if (b.dims == 3) + if (a.dims == 4 && b.dims == 3) + { + // type 13 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - __m512 _b0 = _mm512_loadu_ps(b0); - for (int i = 0; i < size; i++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } - } + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - return 0; - } + const int size = w * elempack; - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) + for (int z = 0; z < d; z++) { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q).row(z); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int y = 0; y < h; y++) { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) + const float _b = ptr1[y]; +#if __SSE2__ + __m128 _b_128 = (elempack == 4) ? _mm_loadu_ps((const float*)ptr1 + y * 4) : _mm_set1_ps(_b); +#if __AVX__ + __m256 _b_256 = (elempack == 8) ? _mm256_loadu_ps((const float*)ptr1 + y * 8) : _mm256_insertf128_ps(_mm256_castps128_ps256(_b_128), _b_128, 1); +#if __AVX512F__ + __m512 _b_512 = (elempack == 16) ? _mm512_loadu_ps((const float*)ptr1 + y * 16) : _mm512_insertf32x8(_mm512_castps256_ps512(_b_256), _b_256, 1); +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) { __m512 _p = _mm512_loadu_ps(ptr); - __m512 _p1 = _mm512_set1_ps(*ptr1); - __m512 _outp = op.func_pack16(_p, _p1); + __m512 _outp = op.func_pack16(_p, _b_512); _mm512_storeu_ps(outptr, _outp); ptr += 16; - ptr1 += 1; outptr += 16; } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - __m512 _a0 = _mm512_loadu_ps(a0); - for (int i = 0; i < size1; i++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - __m512 _p = _mm512_set1_ps(*ptr); - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_p, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr += 1; - ptr1 += 16; - outptr += 16; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1 + y * 16); - for (int x = 0; x < w; x++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _p1); - _mm512_storeu_ps(outptr, _outp); - - ptr += 16; - outptr += 16; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _p1 = _mm512_loadu_ps(ptr1 + x * 16); - __m512 _outp = op.func_pack16(_p, _p1); - _mm512_storeu_ps(outptr, _outp); - - ptr += 16; - outptr += 16; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m512 _p = _mm512_loadu_ps(ptr + y * 16); - for (int x = 0; x < w1; x++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_p, _p1); - _mm512_storeu_ps(outptr, _outp); - - ptr1 += 16; - outptr += 16; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __m512 _p = _mm512_loadu_ps(ptr + x * 16); - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_p, _p1); - _mm512_storeu_ps(outptr, _outp); - - ptr1 += 16; - outptr += 16; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m512 _b0 = _mm512_loadu_ps(ptr1); - for (int x = 0; x < w; x++) +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _outp = op.func_pack8(_p, _b_256); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + outptr += 8; } - - ptr1 += 16; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - __m512 _b0 = _mm512_loadu_ps((const float*)b + q * 16); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - __m512 _a0 = _mm512_loadu_ps(ptr); - for (int y = 0; y < h1; y++) +#endif // __AVX__ + for (; i + 3 < size; i += 4) { - for (int x = 0; x < w1; x++) - { - __m512 _p = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; - } + __m128 _p = _mm_loadu_ps(ptr); + __m128 _outp = op.func_pack4(_p, _b_128); + _mm_storeu_ps(outptr, _outp); + ptr += 4; + outptr += 4; } - - ptr += 16; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m512 _a0 = _mm512_loadu_ps(ptr); - for (int x = 0; x < w1; x++) +#endif // __SSE2__ + for (; i < size; i++) { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; + *outptr = op.func(*ptr, _b); + ptr += 1; + outptr += 1; } - - ptr += 16; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - __m512 _b0 = _mm512_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m512 _p = _mm512_loadu_ps(ptr); - __m512 _outp = op.func_pack16(_p, _b0); - _mm512_storeu_ps(outptr, _outp); - ptr += 16; - outptr += 16; - } - - ptr1 += 16; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m512 _a0 = _mm512_loadu_ps((const float*)a + q * 16); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m512 _a0 = _mm512_loadu_ps((const float*)a + q * 16); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h1; y++) - { - __m512 _a0 = _mm512_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m512 _p1 = _mm512_loadu_ps(ptr1); - __m512 _outp = op.func_pack16(_a0, _p1); - _mm512_storeu_ps(outptr, _outp); - ptr1 += 16; - outptr += 16; } - - ptr += 16; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); } } return 0; } -#endif // __AVX512F__ template -static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -961,1404 +509,236 @@ static int binary_op_pack8(const Mat& a, const Mat& b, Mat& c, const Option& opt int h = a.h; int d = a.d; int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; int elempack = a.elempack; - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) + if (a.dims == 2) { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) + // type 14 + #pragma omp parallel for num_threads(opt.num_threads) + for (int y = 0; y < h; y++) { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const float* ptr = a.row(y); + const float* ptr1 = b; + float* outptr = c.row(y); - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - __m256 _b0 = _mm256_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - __m256 _b0 = _mm256_loadu_ps(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - ptr1 += 8; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - __m256 _b0 = _mm256_loadu_ps((const float*)b + q * 8); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - __m256 _a0 = _mm256_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m256 _p = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* b0 = b.channel(q); - float* outptr = c.channel(q); - __m256 _b0 = _mm256_loadu_ps(b0); - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _p1 = _mm256_broadcast_ss(ptr1); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - ptr1 += 1; - outptr += 8; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - __m256 _a0 = _mm256_loadu_ps(a0); - for (int i = 0; i < size1; i++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - __m256 _p = _mm256_broadcast_ss(ptr); - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr += 1; - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1 + y * 8); - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - - ptr += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _p1 = _mm256_loadu_ps(ptr1 + x * 8); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - - ptr += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m256 _p = _mm256_loadu_ps(ptr + y * 8); - for (int x = 0; x < w1; x++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - - ptr1 += 8; - outptr += 8; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __m256 _p = _mm256_loadu_ps(ptr + x * 8); - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_p, _p1); - _mm256_storeu_ps(outptr, _outp); - - ptr1 += 8; - outptr += 8; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m256 _b0 = _mm256_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - __m256 _b0 = _mm256_loadu_ps((const float*)b + q * 8); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - } - - return 0; - } - } - else if (a.dims == 2) - { - if (b.dims == 4) - { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - __m256 _a0 = _mm256_loadu_ps(ptr); - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __m256 _p = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - ptr += 8; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m256 _a0 = _mm256_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); - } - - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h; y++) - { - __m256 _b0 = _mm256_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m256 _p = _mm256_loadu_ps(ptr); - __m256 _outp = op.func_pack8(_p, _b0); - _mm256_storeu_ps(outptr, _outp); - ptr += 8; - outptr += 8; - } - - ptr1 += 8; - } - - return 0; - } - } - else if (a.dims == 1) - { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } - - if (b.dims == 4) - { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m256 _a0 = _mm256_loadu_ps((const float*)a + q * 8); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m256 _a0 = _mm256_loadu_ps((const float*)a + q * 8); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h1; y++) - { - __m256 _a0 = _mm256_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m256 _p1 = _mm256_loadu_ps(ptr1); - __m256 _outp = op.func_pack8(_a0, _p1); - _mm256_storeu_ps(outptr, _outp); - ptr1 += 8; - outptr += 8; - } - - ptr += 8; - } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); - } - } - - return 0; -} -#endif // __AVX__ - -template -static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt) -{ - Op op; - - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; - size_t elemsize = a.elemsize; - int elempack = a.elempack; - - int w1 = b.w; - int h1 = b.h; - int d1 = b.d; - int channels1 = b.c; - int size1 = w1 * h1 * d1; - size_t elemsize1 = b.elemsize; - int elempack1 = b.elempack; - - if (a.dims == 4) - { - if (b.dims == 4) - { - // type 29 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, d, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 3) - { - // type 28 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - for (int y = 0; y < h; y++) - { - __m128 _b0 = _mm_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; - } - } - } - - return 0; - } - - if (b.dims == 2) - { - // type 27 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d; z++) - { - __m128 _b0 = _mm_loadu_ps(ptr1); - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - ptr1 += 4; - } - } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) - { - // type 25 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 26 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - __m128 _b0 = _mm_loadu_ps((const float*)b + q * 4); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - } - else if (a.dims == 3) - { - if (b.dims == 4) - { - // type 23 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) - { - for (int y = 0; y < h1; y++) - { - __m128 _a0 = _mm_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m128 _p = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - - ptr += 4; - } - } - } - - return 0; - } - - if (b.dims == 3) - { - if (w1 == 1 && h1 == 1 && channels1 == channels) - { - // special type 1 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* outptr = c.channel(q); - const float* b0 = b.channel(q); - __m128 _b0 = _mm_loadu_ps(b0); - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) - { - // special type 2 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b; - float* outptr = c.channel(q); - for (int i = 0; i < size; i++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _p1 = _mm_set1_ps(*ptr1); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - ptr1 += 1; - outptr += 4; - } - } - - return 0; - } - - if (w == 1 && h == 1 && channels1 == channels) - { - // special type 3 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* a0 = a.channel(q); - float* outptr = c.channel(q); - const float* ptr1 = b.channel(q); - __m128 _a0 = _mm_loadu_ps(a0); - for (int i = 0; i < size1; i++) - { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p1); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w1 == w && h1 == h && channels == 1 && elempack == 1) - { - // special type 4 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a; - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - for (int i = 0; i < size1; i++) - { - __m128 _p = _mm_set1_ps(*ptr); - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - ptr += 1; - ptr1 += 4; - outptr += 4; - } - } - - return 0; - } - - if (w != 1 && w1 == 1 && h1 == h && channels1 == channels) - { - // special type 5 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - __m128 _p1 = _mm_loadu_ps(ptr1 + y * 4); - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h != 1 && h1 == 1 && channels1 == channels) - { - // special type 6 - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) - { - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _p1 = _mm_loadu_ps(ptr1 + x * 4); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - - ptr += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 != 1 && w == 1 && h1 == h && channels1 == channels) - { - // special type 7 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m128 _p = _mm_loadu_ps(ptr + y * 4); - for (int x = 0; x < w1; x++) - { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - if (w1 == w && h1 != 1 && h == 1 && channels1 == channels) - { - // special type 8 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.channel(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - for (int x = 0; x < w1; x++) - { - __m128 _p = _mm_loadu_ps(ptr + x * 4); - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_p, _p1); - _mm_storeu_ps(outptr, _outp); - - ptr1 += 4; - outptr += 4; - } - } - } - - return 0; - } - - // type 19 - return binary_op_7_13_19_29(a, b, c, opt); - } - - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 18 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) { - const float* ptr = a.channel(q); - const float* ptr1 = b.row(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h; y++) + for (int x = 0; x < w; x++) { - __m128 _b0 = _mm_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } - - ptr1 += 4; + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _b = _mm512_set1_ps(*ptr1); + __m512 _outp = op.func_pack16(_p, _b); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + ptr1 += 1; + outptr += 16; } } - - return 0; - } - - if (b.dims == 1) - { - if (b.w == 1 && elempack1 == 1) +#endif // __AVX512F__ + if (elempack == 8) { - // type 16 - return binary_op_6_11_16_25(a, b, c, opt); + for (int x = 0; x < w; x++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _b = _mm256_set1_ps(*ptr1); + __m256 _outp = op.func_pack8(_p, _b); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + ptr1 += 1; + outptr += 8; + } } - - // type 17 - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) +#endif // __AVX__ + if (elempack == 4) { - const float* ptr = a.channel(q); - __m128 _b0 = _mm_loadu_ps((const float*)b + q * 4); - float* outptr = c.channel(q); - - for (int i = 0; i < size; i++) + for (int x = 0; x < w; x++) { __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); + __m128 _b = _mm_set1_ps(*ptr1); + __m128 _outp = op.func_pack4(_p, _b); _mm_storeu_ps(outptr, _outp); ptr += 4; + ptr1 += 1; outptr += 4; } } - - return 0; +#endif // __SSE2__ + if (elempack == 1) + { + for (int x = 0; x < w; x++) + { + *outptr = op.func(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } + } } } - else if (a.dims == 2) + + if (a.dims == 3 || a.dims == 4) { - if (b.dims == 4) + // type 15 16 17 18 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - // type 22 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) + for (int z = 0; z < d; z++) { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int z = 0; z < d1; z++) + int z1 = std::min(z, b.d - 1); + for (int y = 0; y < h; y++) { - __m128 _a0 = _mm_loadu_ps(ptr); - for (int y = 0; y < h1; y++) + int y1 = std::min(y, b.h - 1); + + const float* ptr1 = b.depth(z1).row(y1); + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + for (int x = 0; x < w; x++) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _b = _mm512_set1_ps(*ptr1); + __m512 _outp = op.func_pack16(_p, _b); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + ptr1 += 1; + outptr += 16; + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + for (int x = 0; x < w; x++) + { + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _b = _mm256_set1_ps(*ptr1); + __m256 _outp = op.func_pack8(_p, _b); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + ptr1 += 1; + outptr += 8; + } + } +#endif // __AVX__ + if (elempack == 4) { - for (int x = 0; x < w1; x++) + for (int x = 0; x < w; x++) { - __m128 _p = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p); + __m128 _p = _mm_loadu_ps(ptr); + __m128 _b = _mm_set1_ps(*ptr1); + __m128 _outp = op.func_pack4(_p, _b); _mm_storeu_ps(outptr, _outp); - ptr1 += 4; + ptr += 4; + ptr1 += 1; outptr += 4; } } - - ptr += 4; - } - } - - return 0; - } - - if (b.dims == 3) - { - // type 14 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - const float* ptr = a.row(q); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int y = 0; y < h1; y++) - { - __m128 _a0 = _mm_loadu_ps(ptr); - for (int x = 0; x < w1; x++) +#endif // __SSE2__ + if (elempack == 1) { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p1); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; + for (int x = 0; x < w; x++) + { + *outptr = op.func(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; + } } - - ptr += 4; } } - - return 0; - } - - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.dims == 2) - { - // type 13 - return binary_op_7_13_19_29(a, b, c, opt); } + } - if (b.dims == 1) - { - c.create(w, h, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) - { - // type 11 - return binary_op_6_11_16_25(a, b, c, opt); - } - - // type 12 - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; + return 0; +} - for (int y = 0; y < h; y++) - { - __m128 _b0 = _mm_loadu_ps(ptr1); - for (int x = 0; x < w; x++) - { - __m128 _p = _mm_loadu_ps(ptr); - __m128 _outp = op.func_pack4(_p, _b0); - _mm_storeu_ps(outptr, _outp); - ptr += 4; - outptr += 4; - } +template +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; - ptr1 += 4; - } + int w = a.w; + int h = a.h; + int channels = a.c; + int elempack = a.elempack; - return 0; - } - } - else if (a.dims == 1) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) { - if (a.w == 1 && elempack == 1) - { - // type 2 3 4 20 - return binary_op_2_3_4_20(a, b, c, opt); - } + const float* ptr = a.channel(q); + float* outptr = c.channel(q); - if (b.dims == 4) + for (int y = 0; y < h; y++) { - // type 21 - c.create(w1, h1, d1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; + const float* ptr1 = b.channel(q); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) - { - __m128 _a0 = _mm_loadu_ps((const float*)a + q * 4); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); + const int size = w * elempack; - for (int i = 0; i < size1; i++) - { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p1); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; - } + int i = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + for (; i + 15 < size; i += 16) + { + __m512 _p = _mm512_loadu_ps(ptr); + __m512 _p1 = _mm512_loadu_ps(ptr1); + __m512 _outp = op.func_pack16(_p, _p1); + _mm512_storeu_ps(outptr, _outp); + ptr += 16; + ptr1 += 16; + outptr += 16; } - - return 0; - } - - if (b.dims == 3) - { - // type 9 - c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels1; q++) +#endif // __AVX512F__ + for (; i + 7 < size; i += 8) { - __m128 _a0 = _mm_loadu_ps((const float*)a + q * 4); - const float* ptr1 = b.channel(q); - float* outptr = c.channel(q); - - for (int i = 0; i < size1; i++) - { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p1); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; - } + __m256 _p = _mm256_loadu_ps(ptr); + __m256 _p1 = _mm256_loadu_ps(ptr1); + __m256 _outp = op.func_pack8(_p, _p1); + _mm256_storeu_ps(outptr, _outp); + ptr += 8; + ptr1 += 8; + outptr += 8; } - - return 0; - } - - if (b.dims == 2) - { - // type 8 - c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); - if (c.empty()) - return -100; - - const float* ptr = a; - const float* ptr1 = b; - float* outptr = c; - - for (int y = 0; y < h1; y++) +#endif // __AVX__ + for (; i + 3 < size; i += 4) { - __m128 _a0 = _mm_loadu_ps(ptr); - for (int x = 0; x < w1; x++) - { - __m128 _p1 = _mm_loadu_ps(ptr1); - __m128 _outp = op.func_pack4(_a0, _p1); - _mm_storeu_ps(outptr, _outp); - ptr1 += 4; - outptr += 4; - } - + __m128 _p = _mm_loadu_ps(ptr); + __m128 _p1 = _mm_loadu_ps(ptr1); + __m128 _outp = op.func_pack4(_p, _p1); + _mm_storeu_ps(outptr, _outp); ptr += 4; + ptr1 += 4; + outptr += 4; } - - return 0; - } - - if (b.dims == 1) - { - c.create(w, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - - if (b.w == 1 && elempack1 == 1) +#endif // __SSE2__ + for (; i < size; i++) { - // type 6 - return binary_op_6_11_16_25(a, b, c, opt); + *outptr = op.func(*ptr, *ptr1); + ptr += 1; + ptr1 += 1; + outptr += 1; } - - // type 7 - binary_op_7_13_19_29(a, b, c, opt); } } return 0; } -#endif // __SSE2__ template static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) { Op op; - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int elempack = a.elempack; - int size = w * h * d * elempack; + const int channels = a.c; + const int size = a.w * a.h * a.d * a.elempack; #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) @@ -2642,148 +1022,223 @@ struct binary_op_rdiv #endif // __SSE2__ }; -} // namespace BinaryOp_x86_functor - -int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +struct binary_op_rpow { + float func(const float& x, const float& y) const + { + return (float)pow(y, x); + } #if __SSE2__ - using namespace BinaryOp_x86_functor; - - const Mat& bottom_blob = bottom_blobs[0]; - const Mat& bottom_blob1 = bottom_blobs[1]; - Mat& top_blob = top_blobs[0]; - - int elempack = bottom_blob.elempack; - int elempack1 = bottom_blob1.elempack; - + __m128 func_pack4(const __m128& x, const __m128& y) const + { + return pow_ps(y, x); + } #if __AVX__ + __m256 func_pack8(const __m256& x, const __m256& y) const + { + return pow256_ps(y, x); + } #if __AVX512F__ - if (elempack == 16 || elempack1 == 16) + __m512 func_pack16(const __m512& x, const __m512& y) const { - if (op_type == Operation_ADD) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); + return pow512_ps(y, x); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; - if (op_type == Operation_MUL) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); +} // namespace BinaryOp_x86_functor - if (op_type == Operation_DIV) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_x86_functor; - if (op_type == Operation_MAX) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_x86_functor; - if (op_type == Operation_POW) - return binary_op_pack16(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_RSUB) - return binary_op_pack16(bottom_blob1, bottom_blob, top_blob, opt); +static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + // squeeze inner axes + Mat b2 = b; + if (b.dims == 2 && b.w == 1) + b2 = b.reshape(b.h); + else if (b.dims == 3 && b.h == 1) + b2 = b.reshape(b.c); + else if (b.dims == 3 && b.w == 1) + b2 = b.reshape(b.h, b.c); + else if (b.dims == 4 && b.d == 1) + b2 = b.reshape(b.c); + else if (b.dims == 4 && b.h == 1) + b2 = b.reshape(b.d, b.c); + else if (b.dims == 4 && b.w == 1) + b2 = b.reshape(b.h, b.d, b.c); - if (op_type == Operation_RDIV) - return binary_op_pack16(bottom_blob1, bottom_blob, top_blob, opt); - } -#endif // __AVX512F__ + using namespace BinaryOp_x86_functor; - if (elempack == 8 || elempack1 == 8) - { - if (op_type == Operation_ADD) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner(a, b2, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner(a, b2, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_SUB) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_x86_functor; - if (op_type == Operation_MUL) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_DIV) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); +static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt) +{ + using namespace BinaryOp_x86_functor; - if (op_type == Operation_MAX) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); + if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_SUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MUL) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_DIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MAX) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_MIN) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_POW) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20(a, b, c, opt); + if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20(a, b, c, opt); + + // should never reach here + return 0; +} - if (op_type == Operation_MIN) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); +static int get_reverse_op_type(int op_type) +{ + if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; + if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; + if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; + if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; + if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; + if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; + return op_type; +} - if (op_type == Operation_POW) - return binary_op_pack8(bottom_blob, bottom_blob1, top_blob, opt); +int BinaryOp_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; + const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; + const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; + const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); + const Mat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; + const Mat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; + const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; - if (op_type == Operation_RSUB) - return binary_op_pack8(bottom_blob1, bottom_blob, top_blob, opt); + Mat& top_blob = top_blobs[0]; + top_blob.create_like(A, opt.blob_allocator); + if (top_blob.empty()) + return -100; - if (op_type == Operation_RDIV) - return binary_op_pack8(bottom_blob1, bottom_blob, top_blob, opt); + // B is a scalar + if (B.w * B.h * B.d * B.c * B.elempack == 1) + { + return binary_op_scalar(A, B[0], top_blob, op_type_r, opt); } -#endif // __AVX__ - if (elempack == 4 || elempack1 == 4) + // no broadcast + if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) { - if (op_type == Operation_ADD) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_SUB) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MUL) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_DIV) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MAX) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); - - if (op_type == Operation_MIN) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_no_broadcast(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_POW) - return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); + // broadcast B for inner axis + if ((B.dims < A.dims) + || (A.dims == 2 && B.w == 1 && B.h == A.h) + || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) + || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) + || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) + { + return binary_op_broadcast_inner(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RSUB) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); + // broadcast B for outer axis + if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) + { + return binary_op_broadcast_outer(A, B, top_blob, op_type_r, opt); + } - if (op_type == Operation_RDIV) - return binary_op_pack4(bottom_blob1, bottom_blob, top_blob, opt); + // some special broadcast rule here + if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) + { + return binary_op_broadcast_20(A, B, top_blob, op_type_r, opt); } -#endif // __SSE2__ - return BinaryOp::forward(bottom_blobs, top_blobs, opt); + return 0; } int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { using namespace BinaryOp_x86_functor; - if (op_type == Operation_ADD) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_SUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MUL) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_DIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MAX) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_MIN) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_POW) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_RSUB) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); - - if (op_type == Operation_RDIV) - return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_ADD) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_SUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MUL) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_DIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MAX) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_MIN) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_POW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RSUB) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RDIV) return binary_op_scalar_inplace(bottom_top_blob, b, opt); + if (op_type == Operation_RPOW) return binary_op_scalar_inplace(bottom_top_blob, b, opt); return 0; } diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index f79ec024b..2fa9faace 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 9 +#define OP_TYPE_MAX 10 static int op_type = 0; @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) { ncnn::Mat a = _a; ncnn::Mat b = _b; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { - // value must be positive for pow + // value must be positive for pow/rpow + a = a.clone(); + b = b.clone(); Randomize(a, 0.001f, 2.f); Randomize(b, 0.001f, 2.f); } if (op_type == 3 || op_type == 8) { - // value must be positive for pow + // value must be positive for div/rdiv + a = a.clone(); + b = b.clone(); Randomize(a, 0.1f, 10.f); Randomize(b, 0.1f, 10.f); } @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) static int test_binaryop(const ncnn::Mat& _a, float b) { ncnn::Mat a = _a; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { // value must be positive for pow Randomize(a, 0.001f, 2.f); b = RandomFloat(0.001f, 2.f); } + if (op_type == 3 || op_type == 8) + { + // value must be positive for div/rdiv + a = a.clone(); + Randomize(a, 0.1f, 10.f); + } ncnn::ParamDict pd; pd.set(0, op_type); @@ -82,300 +92,274 @@ static int test_binaryop(const ncnn::Mat& _a, float b) return ret; } -// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting - static int test_binaryop_1() { - return 0 - || test_binaryop(RandomMat(1), 1.f); -} - -static int test_binaryop_2() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(1)) - || test_binaryop(RandomMat(1), RandomMat(4)) - || test_binaryop(RandomMat(1), RandomMat(16)); -} - -static int test_binaryop_3() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3)) - || test_binaryop(RandomMat(1), RandomMat(11, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 16)); -} - -static int test_binaryop_4() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); -} - -static int test_binaryop_5() -{ - return 0 - || test_binaryop(RandomMat(2), 1.f) - || test_binaryop(RandomMat(4), 1.f) - || test_binaryop(RandomMat(16), 1.f); -} - -static int test_binaryop_6() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(1)) - || test_binaryop(RandomMat(4), RandomMat(1)) - || test_binaryop(RandomMat(16), RandomMat(1)); -} - -static int test_binaryop_7() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(2)) - || test_binaryop(RandomMat(4), RandomMat(4)) - || test_binaryop(RandomMat(16), RandomMat(16)); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + ncnn::Mat b[] = { + RandomMat(1), + RandomMat(1, 1), + RandomMat(1, 1, 1), + RandomMat(1, 1, 1, 1) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) + { + int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[i], 0.2f); + if (ret != 0) + return ret; + } -static int test_binaryop_8() -{ - return 0 - || test_binaryop(RandomMat(3), RandomMat(11, 3)) - || test_binaryop(RandomMat(4), RandomMat(11, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 16)); + return 0; } -static int test_binaryop_9() +static int test_binaryop_2() { - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b; + b.create_like(a[i]); + Randomize(b); -static int test_binaryop_10() -{ - return 0 - || test_binaryop(RandomMat(11, 3), 1.f) - || test_binaryop(RandomMat(11, 4), 1.f) - || test_binaryop(RandomMat(11, 16), 1.f); -} + int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_11() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(1)) - || test_binaryop(RandomMat(11, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 16), RandomMat(1)); + return 0; } -static int test_binaryop_12() +static int test_binaryop_3() { - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(3)) - || test_binaryop(RandomMat(11, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 16), RandomMat(16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_13() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) - || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) - || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].h); + ncnn::Mat b1(1, a[i].h); + Randomize(b0); + Randomize(b1); -static int test_binaryop_14() -{ - return 0 - || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_15() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), 1.f) - || test_binaryop(RandomMat(11, 6, 4), 1.f) - || test_binaryop(RandomMat(11, 6, 16), 1.f); + return 0; } -static int test_binaryop_16() +static int test_binaryop_4() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_17() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); -} - -static int test_binaryop_18() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, a[i].c); + ncnn::Mat b2(a[i].h, a[i].c); + ncnn::Mat b3(1, a[i].h, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_19() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_20() +static int test_binaryop_5() { - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_21() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, 1, a[i].c); + ncnn::Mat b2(a[i].d, a[i].c); + ncnn::Mat b3(1, 1, a[i].d, a[i].c); + ncnn::Mat b4(a[i].h, a[i].d, a[i].c); + ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + Randomize(b4); + Randomize(b5); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) + || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) + || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_22() -{ - return 0 - || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); + return 0; } -static int test_binaryop_23() +static int test_binaryop_6() { - return 0 - || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_24() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1); + Randomize(b0); -static int test_binaryop_25() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_26() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); + return 0; } -static int test_binaryop_27() +static int test_binaryop_7() { - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_28() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1); + Randomize(b0); + Randomize(b1); -static int test_binaryop_29() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s1() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); + return 0; } -static int test_binaryop_s2() +static int test_binaryop_8() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_s3() -{ - return 0 - || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1, 1); + ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); + Randomize(b0); + Randomize(b1); + Randomize(b2); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s4() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_s5() +static int test_binaryop_9() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_s6() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, a[i].c); + Randomize(b0); -static int test_binaryop_s7() -{ - return 0 - || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s8() -{ - return 0 - || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); + return 0; } int main() @@ -393,35 +377,7 @@ int main() || test_binaryop_6() || test_binaryop_7() || test_binaryop_8() - || test_binaryop_9() - || test_binaryop_10() - || test_binaryop_11() - || test_binaryop_12() - || test_binaryop_13() - || test_binaryop_14() - || test_binaryop_15() - || test_binaryop_16() - || test_binaryop_17() - || test_binaryop_18() - || test_binaryop_19() - || test_binaryop_20() - || test_binaryop_21() - || test_binaryop_22() - || test_binaryop_23() - || test_binaryop_24() - || test_binaryop_25() - || test_binaryop_26() - || test_binaryop_27() - || test_binaryop_28() - || test_binaryop_29() - || test_binaryop_s1() - || test_binaryop_s2() - || test_binaryop_s3() - || test_binaryop_s4() - || test_binaryop_s5() - || test_binaryop_s6() - || test_binaryop_s7() - || test_binaryop_s8(); + || test_binaryop_9(); if (ret != 0) return ret; diff --git a/tests/test_binaryop_1.cpp b/tests/test_binaryop_1.cpp index bc0ec9c89..6f0e2e87a 100644 --- a/tests/test_binaryop_1.cpp +++ b/tests/test_binaryop_1.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 9 +#define OP_TYPE_MAX 10 static int op_type = 0; @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) { ncnn::Mat a = _a; ncnn::Mat b = _b; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { - // value must be positive for pow + // value must be positive for pow/rpow + a = a.clone(); + b = b.clone(); Randomize(a, 0.001f, 2.f); Randomize(b, 0.001f, 2.f); } if (op_type == 3 || op_type == 8) { - // value must be positive for pow + // value must be positive for div/rdiv + a = a.clone(); + b = b.clone(); Randomize(a, 0.1f, 10.f); Randomize(b, 0.1f, 10.f); } @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) static int test_binaryop(const ncnn::Mat& _a, float b) { ncnn::Mat a = _a; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { // value must be positive for pow Randomize(a, 0.001f, 2.f); b = RandomFloat(0.001f, 2.f); } + if (op_type == 3 || op_type == 8) + { + // value must be positive for div/rdiv + a = a.clone(); + Randomize(a, 0.1f, 10.f); + } ncnn::ParamDict pd; pd.set(0, op_type); @@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b) static int test_binaryop_1() { - return 0 - || test_binaryop(RandomMat(1), 1.f); -} - -static int test_binaryop_2() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(1)) - || test_binaryop(RandomMat(1), RandomMat(4)) - || test_binaryop(RandomMat(1), RandomMat(16)); -} - -static int test_binaryop_3() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3)) - || test_binaryop(RandomMat(1), RandomMat(11, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 16)); -} - -static int test_binaryop_4() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); -} - -static int test_binaryop_5() -{ - return 0 - || test_binaryop(RandomMat(2), 1.f) - || test_binaryop(RandomMat(4), 1.f) - || test_binaryop(RandomMat(16), 1.f); -} - -static int test_binaryop_6() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(1)) - || test_binaryop(RandomMat(4), RandomMat(1)) - || test_binaryop(RandomMat(16), RandomMat(1)); -} - -static int test_binaryop_7() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(2)) - || test_binaryop(RandomMat(4), RandomMat(4)) - || test_binaryop(RandomMat(16), RandomMat(16)); -} - -static int test_binaryop_8() -{ - return 0 - || test_binaryop(RandomMat(3), RandomMat(11, 3)) - || test_binaryop(RandomMat(4), RandomMat(11, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 16)); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + ncnn::Mat b[] = { + RandomMat(1), + RandomMat(1, 1), + RandomMat(1, 1, 1), + RandomMat(1, 1, 1, 1) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) + { + int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[i], 0.2f); + if (ret != 0) + return ret; + } -static int test_binaryop_9() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_10() +static int test_binaryop_2() { - return 0 - || test_binaryop(RandomMat(11, 3), 1.f) - || test_binaryop(RandomMat(11, 4), 1.f) - || test_binaryop(RandomMat(11, 16), 1.f); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b; + b.create_like(a[i]); + Randomize(b); -static int test_binaryop_11() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(1)) - || test_binaryop(RandomMat(11, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 16), RandomMat(1)); -} + int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_12() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(3)) - || test_binaryop(RandomMat(11, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 16), RandomMat(16)); + return 0; } -static int test_binaryop_13() +static int test_binaryop_3() { - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) - || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) - || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_14() -{ - return 0 - || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].h); + ncnn::Mat b1(1, a[i].h); + Randomize(b0); + Randomize(b1); -static int test_binaryop_15() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), 1.f) - || test_binaryop(RandomMat(11, 6, 4), 1.f) - || test_binaryop(RandomMat(11, 6, 16), 1.f); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_16() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); + return 0; } -static int test_binaryop_17() +static int test_binaryop_4() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_18() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, a[i].c); + ncnn::Mat b2(a[i].h, a[i].c); + ncnn::Mat b3(1, a[i].h, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_19() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_20() +static int test_binaryop_5() { - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_21() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, 1, a[i].c); + ncnn::Mat b2(a[i].d, a[i].c); + ncnn::Mat b3(1, 1, a[i].d, a[i].c); + ncnn::Mat b4(a[i].h, a[i].d, a[i].c); + ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + Randomize(b4); + Randomize(b5); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) + || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) + || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_22() -{ - return 0 - || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); + return 0; } -static int test_binaryop_23() +static int test_binaryop_6() { - return 0 - || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_24() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1); + Randomize(b0); -static int test_binaryop_25() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_26() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); + return 0; } -static int test_binaryop_27() +static int test_binaryop_7() { - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_28() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1); + Randomize(b0); + Randomize(b1); -static int test_binaryop_29() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s1() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); + return 0; } -static int test_binaryop_s2() +static int test_binaryop_8() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_s3() -{ - return 0 - || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1, 1); + ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); + Randomize(b0); + Randomize(b1); + Randomize(b2); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s4() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_s5() +static int test_binaryop_9() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_s6() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, a[i].c); + Randomize(b0); -static int test_binaryop_s7() -{ - return 0 - || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s8() -{ - return 0 - || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); + return 0; } int main() @@ -393,35 +379,7 @@ int main() || test_binaryop_6() || test_binaryop_7() || test_binaryop_8() - || test_binaryop_9() - || test_binaryop_10() - || test_binaryop_11() - || test_binaryop_12() - || test_binaryop_13() - || test_binaryop_14() - || test_binaryop_15() - || test_binaryop_16() - || test_binaryop_17() - || test_binaryop_18() - || test_binaryop_19() - || test_binaryop_20() - || test_binaryop_21() - || test_binaryop_22() - || test_binaryop_23() - || test_binaryop_24() - || test_binaryop_25() - || test_binaryop_26() - || test_binaryop_27() - || test_binaryop_28() - || test_binaryop_29() - || test_binaryop_s1() - || test_binaryop_s2() - || test_binaryop_s3() - || test_binaryop_s4() - || test_binaryop_s5() - || test_binaryop_s6() - || test_binaryop_s7() - || test_binaryop_s8(); + || test_binaryop_9(); if (ret != 0) return ret; diff --git a/tests/test_binaryop_2.cpp b/tests/test_binaryop_2.cpp index 1608a2880..1a744fef2 100644 --- a/tests/test_binaryop_2.cpp +++ b/tests/test_binaryop_2.cpp @@ -15,7 +15,7 @@ #include "layer/binaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 9 +#define OP_TYPE_MAX 10 static int op_type = 0; @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) { ncnn::Mat a = _a; ncnn::Mat b = _b; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { - // value must be positive for pow + // value must be positive for pow/rpow + a = a.clone(); + b = b.clone(); Randomize(a, 0.001f, 2.f); Randomize(b, 0.001f, 2.f); } if (op_type == 3 || op_type == 8) { - // value must be positive for pow + // value must be positive for div/rdiv + a = a.clone(); + b = b.clone(); Randomize(a, 0.1f, 10.f); Randomize(b, 0.1f, 10.f); } @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) static int test_binaryop(const ncnn::Mat& _a, float b) { ncnn::Mat a = _a; - if (op_type == 6) + if (op_type == 6 || op_type == 9) { - // value must be positive for pow + // value must be positive for pow/rpow Randomize(a, 0.001f, 2.f); b = RandomFloat(0.001f, 2.f); } + if (op_type == 3 || op_type == 8) + { + // value must be positive for div/rdiv + a = a.clone(); + Randomize(a, 0.1f, 10.f); + } ncnn::ParamDict pd; pd.set(0, op_type); @@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b) static int test_binaryop_1() { - return 0 - || test_binaryop(RandomMat(1), 1.f); -} - -static int test_binaryop_2() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(1)) - || test_binaryop(RandomMat(1), RandomMat(4)) - || test_binaryop(RandomMat(1), RandomMat(16)); -} - -static int test_binaryop_3() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3)) - || test_binaryop(RandomMat(1), RandomMat(11, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 16)); -} - -static int test_binaryop_4() -{ - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); -} - -static int test_binaryop_5() -{ - return 0 - || test_binaryop(RandomMat(2), 1.f) - || test_binaryop(RandomMat(4), 1.f) - || test_binaryop(RandomMat(16), 1.f); -} - -static int test_binaryop_6() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(1)) - || test_binaryop(RandomMat(4), RandomMat(1)) - || test_binaryop(RandomMat(16), RandomMat(1)); -} - -static int test_binaryop_7() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(2)) - || test_binaryop(RandomMat(4), RandomMat(4)) - || test_binaryop(RandomMat(16), RandomMat(16)); -} - -static int test_binaryop_8() -{ - return 0 - || test_binaryop(RandomMat(3), RandomMat(11, 3)) - || test_binaryop(RandomMat(4), RandomMat(11, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 16)); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + ncnn::Mat b[] = { + RandomMat(1), + RandomMat(1, 1), + RandomMat(1, 1, 1), + RandomMat(1, 1, 1, 1) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) + { + int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); + if (ret != 0) + return ret; + } + + int ret = test_binaryop(a[i], 0.2f); + if (ret != 0) + return ret; + } -static int test_binaryop_9() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_10() +static int test_binaryop_2() { - return 0 - || test_binaryop(RandomMat(11, 3), 1.f) - || test_binaryop(RandomMat(11, 4), 1.f) - || test_binaryop(RandomMat(11, 16), 1.f); -} + ncnn::Mat a[] = { + RandomMat(31), + RandomMat(28), + RandomMat(24), + RandomMat(32), + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32), + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32), + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; + + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b; + b.create_like(a[i]); + Randomize(b); -static int test_binaryop_11() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(1)) - || test_binaryop(RandomMat(11, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 16), RandomMat(1)); -} + int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_12() -{ - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(3)) - || test_binaryop(RandomMat(11, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 16), RandomMat(16)); + return 0; } -static int test_binaryop_13() +static int test_binaryop_3() { - return 0 - || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) - || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) - || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_14() -{ - return 0 - || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].h); + ncnn::Mat b1(1, a[i].h); + Randomize(b0); + Randomize(b1); -static int test_binaryop_15() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), 1.f) - || test_binaryop(RandomMat(11, 6, 4), 1.f) - || test_binaryop(RandomMat(11, 6, 16), 1.f); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_16() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); + return 0; } -static int test_binaryop_17() +static int test_binaryop_4() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_18() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, a[i].c); + ncnn::Mat b2(a[i].h, a[i].c); + ncnn::Mat b3(1, a[i].h, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_19() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_20() +static int test_binaryop_5() { - return 0 - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_21() -{ - return 0 - || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].c); + ncnn::Mat b1(1, 1, 1, a[i].c); + ncnn::Mat b2(a[i].d, a[i].c); + ncnn::Mat b3(1, 1, a[i].d, a[i].c); + ncnn::Mat b4(a[i].h, a[i].d, a[i].c); + ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); + Randomize(b0); + Randomize(b1); + Randomize(b2); + Randomize(b3); + Randomize(b4); + Randomize(b5); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) + || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) + || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) + || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_22() -{ - return 0 - || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); + return 0; } -static int test_binaryop_23() +static int test_binaryop_6() { - return 0 - || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); -} + ncnn::Mat a[] = { + RandomMat(13, 31), + RandomMat(14, 28), + RandomMat(15, 24), + RandomMat(16, 32) + }; -static int test_binaryop_24() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) - || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1); + Randomize(b0); -static int test_binaryop_25() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_26() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); + return 0; } -static int test_binaryop_27() +static int test_binaryop_7() { - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_28() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1); + Randomize(b0); + Randomize(b1); -static int test_binaryop_29() -{ - return 0 - || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) - || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) - || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s1() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); + return 0; } -static int test_binaryop_s2() +static int test_binaryop_8() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); -} + ncnn::Mat a[] = { + RandomMat(2, 7, 3, 31), + RandomMat(3, 6, 4, 28), + RandomMat(4, 5, 5, 24), + RandomMat(5, 4, 6, 32) + }; -static int test_binaryop_s3() -{ - return 0 - || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, 1, 1); + ncnn::Mat b1(a[i].w, a[i].h, 1, 1); + ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); + Randomize(b0); + Randomize(b1); + Randomize(b2); + + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) + || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) + || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s4() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); + return 0; } -static int test_binaryop_s5() +static int test_binaryop_9() { - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); -} + ncnn::Mat a[] = { + RandomMat(7, 3, 31), + RandomMat(6, 4, 28), + RandomMat(5, 5, 24), + RandomMat(4, 6, 32) + }; -static int test_binaryop_s6() -{ - return 0 - || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) - || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) - || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); -} + for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) + { + ncnn::Mat b0(a[i].w, 1, a[i].c); + Randomize(b0); -static int test_binaryop_s7() -{ - return 0 - || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); -} + int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); + if (ret != 0) + return ret; + } -static int test_binaryop_s8() -{ - return 0 - || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) - || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) - || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); + return 0; } int main() @@ -393,35 +379,7 @@ int main() || test_binaryop_6() || test_binaryop_7() || test_binaryop_8() - || test_binaryop_9() - || test_binaryop_10() - || test_binaryop_11() - || test_binaryop_12() - || test_binaryop_13() - || test_binaryop_14() - || test_binaryop_15() - || test_binaryop_16() - || test_binaryop_17() - || test_binaryop_18() - || test_binaryop_19() - || test_binaryop_20() - || test_binaryop_21() - || test_binaryop_22() - || test_binaryop_23() - || test_binaryop_24() - || test_binaryop_25() - || test_binaryop_26() - || test_binaryop_27() - || test_binaryop_28() - || test_binaryop_29() - || test_binaryop_s1() - || test_binaryop_s2() - || test_binaryop_s3() - || test_binaryop_s4() - || test_binaryop_s5() - || test_binaryop_s6() - || test_binaryop_s7() - || test_binaryop_s8(); + || test_binaryop_9(); if (ret != 0) return ret; diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 6da7c01ee..7f141bb16 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -366,6 +366,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/fuse_innerproduct_activation.cpp pass_ncnn/fuse_transpose_matmul.cpp pass_ncnn/fuse_binaryop_eltwise.cpp + pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp pass_ncnn/insert_reshape_linear.cpp pass_ncnn/insert_reshape_pooling.cpp diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp index 14dedf0e8..5daac8f4f 100644 --- a/tools/pnnx/src/pass_ncnn.cpp +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -43,6 +43,7 @@ #include "pass_ncnn/fuse_innerproduct_activation.h" #include "pass_ncnn/fuse_transpose_matmul.h" #include "pass_ncnn/fuse_binaryop_eltwise.h" +#include "pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h" #include "pass_ncnn/insert_reshape_linear.h" #include "pass_ncnn/insert_reshape_pooling.h" @@ -85,6 +86,7 @@ void pass_ncnn(Graph& g) ncnn::convert_half_to_float(g); + ncnn::insert_reshape_numpy_binaryop_broadcast(g); ncnn::insert_reshape_pooling(g); ncnn::insert_reshape_linear(g); diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index baec8795c..88308a776 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -169,6 +169,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx Operand* op_unary_out = graph.new_operand(op->name + "_" + r); op_unary_out->producer = op_unary; + op_unary_out->shape = op_unary_in->shape; + op_unary->inputs.push_back(op_unary_in); op_unary->outputs.push_back(op_unary_out); } @@ -204,6 +206,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx Operand* op_binary_out = graph.new_operand(op->name + "_" + r); op_binary_out->producer = op_binary; + op_binary_out->shape = op_binary_inb->shape; + op_binary->inputs.push_back(op_binary_inb); op_binary->outputs.push_back(op_binary_out); } @@ -218,6 +222,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx Operand* op_binary_out = graph.new_operand(op->name + "_" + r); op_binary_out->producer = op_binary; + op_binary_out->shape = op_binary_ina->shape; + op_binary->inputs.push_back(op_binary_ina); op_binary->outputs.push_back(op_binary_out); } @@ -232,6 +238,28 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx Operand* op_binary_out = graph.new_operand(op->name + "_" + r); op_binary_out->producer = op_binary; + // resolve out shape + std::vector out_shape; + { + std::vector a_shape = op_binary_ina->shape; + std::vector b_shape = op_binary_inb->shape; + int outrank = (int)std::max(a_shape.size(), b_shape.size()); + for (int k = (int)a_shape.size(); k < outrank; k++) + { + a_shape.insert(a_shape.begin(), 1); + } + for (int k = (int)b_shape.size(); k < outrank; k++) + { + b_shape.insert(b_shape.begin(), 1); + } + out_shape.resize(outrank); + for (int k = 0; k < outrank; k++) + { + out_shape[k] = std::max(a_shape[k], b_shape[k]); + } + } + op_binary_out->shape = out_shape; + op_binary->inputs.push_back(op_binary_ina); op_binary->inputs.push_back(op_binary_inb); op_binary->outputs.push_back(op_binary_out); diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp new file mode 100644 index 000000000..50b39f8d7 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp @@ -0,0 +1,153 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "insert_reshape_numpy_binaryop_broadcast.h" +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_reshape_numpy_binaryop_broadcast(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "BinaryOp") + continue; + + if (op->inputs.size() != 2) + continue; + + if (op->inputs[0]->shape.empty() || op->inputs[1]->shape.empty()) + continue; + + int batch_index0 = op->inputs[0]->params["__batch_index"].i; + int batch_index1 = op->inputs[1]->params["__batch_index"].i; + if (batch_index0 != batch_index1) + { + fprintf(stderr, "binaryop broadcast across batch axis %d and %d is not supported\n", batch_index0, batch_index1); + continue; + } + + if (op->inputs[0]->shape.size() == 5 && batch_index0 == 233) + { + if (op->inputs[0]->shape[0] == 1) + { + fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n"); + batch_index0 = 0; + } + } + if (op->inputs[1]->shape.size() == 5 && batch_index1 == 233) + { + if (op->inputs[1]->shape[0] == 1) + { + fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n"); + batch_index1 = 0; + } + } + + // drop shape batch index + std::vector new_shape0; + std::vector new_shape1; + for (int j = 0; j < (int)op->inputs[0]->shape.size(); j++) + { + if (j == batch_index0 && (op->inputs[0]->shape[j] == 1 || op->inputs[0]->shape[j] == op->inputs[1]->shape[j])) + continue; + + new_shape0.push_back(op->inputs[0]->shape[j]); + } + for (int j = 0; j < (int)op->inputs[1]->shape.size(); j++) + { + if (j == batch_index1 && (op->inputs[1]->shape[j] == 1 || op->inputs[1]->shape[j] == op->inputs[0]->shape[j])) + continue; + + new_shape1.push_back(op->inputs[1]->shape[j]); + } + + const int input_rank0 = (int)new_shape0.size(); + const int input_rank1 = (int)new_shape1.size(); + + if (input_rank0 >= 5) + { + fprintf(stderr, "binaryop tensor0 with rank %d is not supported yet!\n", (int)op->inputs[0]->shape.size()); + } + + if (input_rank1 >= 5) + { + fprintf(stderr, "binaryop tensor1 with rank %d is not supported yet!\n", (int)op->inputs[1]->shape.size()); + } + + if (input_rank0 == input_rank1) + { + // no broadcast after ignoring batch index + continue; + } + + // fprintf(stderr, "insert_reshape_numpy_binaryop_broadcast %d %d\n", input_rank0, input_rank1); + + matched = true; + + const int binaryop_lower_rank_in_index = input_rank0 < input_rank1 ? 0 : 1; + + Operand* binaryop_lower_rank_in = op->inputs[binaryop_lower_rank_in_index]; + + Operator* reshape0 = graph.new_operator_before("Tensor.reshape", op->name + "_ncnnreshape0", op); + + Operand* reshape0_out = graph.new_operand(op->name + "_ncnnreshape0_out"); + + reshape0->inputs.push_back(binaryop_lower_rank_in); + reshape0->outputs.push_back(reshape0_out); + + for (size_t j = 0; j < binaryop_lower_rank_in->consumers.size(); j++) + { + if (binaryop_lower_rank_in->consumers[j] == op) + { + binaryop_lower_rank_in->consumers[j] = reshape0; + break; + } + } + + op->inputs[binaryop_lower_rank_in_index] = reshape0_out; + + reshape0_out->producer = reshape0; + reshape0_out->consumers.push_back(op); + + reshape0_out->params["__batch_index"] = input_rank0 < input_rank1 ? batch_index0 : batch_index1; + + // insert explicit broadcast index for missing ranks + std::vector reshape0_shape = input_rank0 < input_rank1 ? new_shape0 : new_shape1; + for (int j = 0; j < std::abs(input_rank0 - input_rank1); j++) + { + reshape0_shape.insert(reshape0_shape.begin(), 1); + } + + reshape0->params["shape"] = reshape0_shape; + + break; + } + + if (!matched) + break; + } +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h new file mode 100644 index 000000000..469a14eaf --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void insert_reshape_numpy_binaryop_broadcast(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 272551a8a..b3aec911c 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -187,6 +187,7 @@ pnnx_ncnn_add_test(vit_b_32) pnnx_ncnn_add_test(ncnn_fuse_transpose_matmul) pnnx_ncnn_add_test(ncnn_fuse_shufflechannel_slice) pnnx_ncnn_add_test(ncnn_fuse_binaryop_eltwise) +pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast) if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") pnnx_ncnn_add_test(F_mish) diff --git a/tools/pnnx/tests/ncnn/test_ncnn_numpy_binaryop_broadcast.py b/tools/pnnx/tests/ncnn/test_ncnn_numpy_binaryop_broadcast.py new file mode 100644 index 000000000..25bc659c2 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_ncnn_numpy_binaryop_broadcast.py @@ -0,0 +1,78 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w, u, v): + a = x + y + b = x - z + c = x * w + d = y / z + e = y + w + f = z - w + g = y + x + h = z - x + i = w * x + j = z / y + k = w + y + l = w - z + m = (x - z) * w + n = (x + y) - (z + w) + o = x.view(1, 1, 5) + y.view(1, 7, 5) - z + p = u * y + q = z / v + return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(5) + y = torch.rand(7, 5) + z = torch.rand(4, 7, 5) + w = torch.rand(6, 4, 7, 5) + u = torch.rand(7, 1) + v = torch.rand(4, 1, 1) + + a = net(x, y, z, w, u, v) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w, u, v)) + mod.save("test_ncnn_numpy_binaryop_broadcast.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_ncnn_numpy_binaryop_broadcast.pt inputshape=[5],[7,5],[4,7,5],[6,4,7,5],[7,1],[4,1,1]") + + # ncnn inference + import test_ncnn_numpy_binaryop_broadcast_ncnn + b = test_ncnn_numpy_binaryop_broadcast_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)