|
|
|
@@ -60,21 +60,21 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
|
|
|
|
if (a.dims == 3) |
|
|
|
{ |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
if (w1 == 1 && h1 == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 1 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels; q++) |
|
|
|
{ |
|
|
|
const float* ptr = a.channel(q); |
|
|
|
float* outptr = c.channel(q); |
|
|
|
const float* b0 = b.channel(q); |
|
|
|
float* outptr = c.channel(q); |
|
|
|
float32x4_t _b0 = vld1q_f32(b0); |
|
|
|
for (int i = 0; i < size; i++) |
|
|
|
{ |
|
|
|
@@ -92,6 +92,10 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) |
|
|
|
{ |
|
|
|
// special type 2 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels; q++) |
|
|
|
{ |
|
|
|
@@ -113,7 +117,66 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w == 1 && h == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 3 |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels1; q++) |
|
|
|
{ |
|
|
|
const float* a0 = a.channel(q); |
|
|
|
const float* ptr1 = b.channel(q); |
|
|
|
float* outptr = c.channel(q); |
|
|
|
float32x4_t _a0 = vld1q_f32(a0); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
float32x4_t _p1 = vld1q_f32(ptr1); |
|
|
|
float32x4_t _outp = op(_a0, _p1); |
|
|
|
vst1q_f32(outptr, _outp); |
|
|
|
ptr1 += 4; |
|
|
|
outptr += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w1 == w && h1 == h && channels == 1 && elempack == 1) |
|
|
|
{ |
|
|
|
// special type 4 |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels1; q++) |
|
|
|
{ |
|
|
|
const float* ptr = a; |
|
|
|
const float* ptr1 = b.channel(q); |
|
|
|
float* outptr = c.channel(q); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
float32x4_t _p = vld1q_dup_f32(ptr); |
|
|
|
float32x4_t _p1 = vld1q_f32(ptr1); |
|
|
|
float32x4_t _outp = op(_p, _p1); |
|
|
|
vst1q_f32(outptr, _outp); |
|
|
|
ptr += 1; |
|
|
|
ptr1 += 4; |
|
|
|
outptr += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
// type 19 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels; q++) |
|
|
|
{ |
|
|
|
@@ -136,6 +199,10 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 2) |
|
|
|
{ |
|
|
|
// type 18 |
|
|
|
@@ -216,7 +283,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
// type 14 |
|
|
|
c.create(w1, h1, channels1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -396,7 +463,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
// type 9 |
|
|
|
c.create(w1, h1, channels1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -423,7 +490,7 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
if (b.dims == 2) |
|
|
|
{ |
|
|
|
// type 8 |
|
|
|
c.create(w1, h1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -681,15 +748,15 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
|
|
|
|
if (a.dims == 3) |
|
|
|
{ |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
if (w1 == 1 && h1 == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 1 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels; q++) |
|
|
|
{ |
|
|
|
@@ -713,6 +780,10 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
if (w1 == w && h1 == h && channels1 == 1 && elempack1 == 1) |
|
|
|
{ |
|
|
|
// special type 2 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels; q++) |
|
|
|
{ |
|
|
|
@@ -734,7 +805,66 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w == 1 && h == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 3 |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels1; q++) |
|
|
|
{ |
|
|
|
const unsigned short* a0 = a.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
const unsigned short* ptr1 = b.channel(q); |
|
|
|
float32x4_t _a0 = vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(a0), 16)); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
float32x4_t _p1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(ptr1), 16)); |
|
|
|
float32x4_t _outp = op(_a0, _p1); |
|
|
|
vst1_u16(outptr, vshrn_n_u32(vreinterpretq_u32_f32(_outp), 16)); |
|
|
|
ptr1 += 4; |
|
|
|
outptr += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w1 == w && h1 == h && channels == 1 && elempack == 1) |
|
|
|
{ |
|
|
|
// special type 4 |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels1; q++) |
|
|
|
{ |
|
|
|
const unsigned short* ptr = a; |
|
|
|
const unsigned short* ptr1 = b.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
float32x4_t _p = vdupq_n_f32(bfloat16_to_float32(*ptr)); |
|
|
|
float32x4_t _p1 = vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(ptr1), 16)); |
|
|
|
float32x4_t _outp = op(_p, _p1); |
|
|
|
vst1_u16(outptr, vshrn_n_u32(vreinterpretq_u32_f32(_outp), 16)); |
|
|
|
ptr += 1; |
|
|
|
ptr1 += 4; |
|
|
|
outptr += 4; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
// type 19 |
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels; q++) |
|
|
|
{ |
|
|
|
@@ -757,6 +887,10 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 2) |
|
|
|
{ |
|
|
|
// type 18 |
|
|
|
@@ -837,7 +971,7 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
// type 14 |
|
|
|
c.create(w1, h1, channels1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -1017,7 +1151,7 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
// type 9 |
|
|
|
c.create(w1, h1, channels1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, channels1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -1044,7 +1178,7 @@ static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Optio |
|
|
|
if (b.dims == 2) |
|
|
|
{ |
|
|
|
// type 8 |
|
|
|
c.create(w1, h1, elemsize, elempack, opt.blob_allocator); |
|
|
|
c.create(w1, h1, elemsize1, elempack1, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
@@ -1162,21 +1296,21 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
|
|
|
|
if (a.dims == 3) |
|
|
|
{ |
|
|
|
c.create(w, h, channels, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 3) |
|
|
|
{ |
|
|
|
if (w1 == 1 && h1 == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 1 |
|
|
|
c.create(w, h, channels, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels; q++) |
|
|
|
{ |
|
|
|
const unsigned short* ptr = a.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
const unsigned short* b0 = b.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
for (int i = 0; i < size; i++) |
|
|
|
{ |
|
|
|
outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(b0[0]))); |
|
|
|
@@ -1189,6 +1323,10 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
if (w1 == w && h1 == h && channels1 == 1) |
|
|
|
{ |
|
|
|
// special type 2 |
|
|
|
c.create(w, h, channels, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels; q++) |
|
|
|
{ |
|
|
|
@@ -1204,7 +1342,55 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w == 1 && h == 1 && channels1 == channels) |
|
|
|
{ |
|
|
|
// special type 3 |
|
|
|
c.create(w1, h1, channels1, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels1; q++) |
|
|
|
{ |
|
|
|
const unsigned short* a0 = a.channel(q); |
|
|
|
const unsigned short* ptr1 = b.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(a0[0]), bfloat16_to_float32(ptr1[i]))); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
if (w1 == w && h1 == h && channels == 1) |
|
|
|
{ |
|
|
|
// special type 4 |
|
|
|
c.create(w1, h1, channels1, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q = 0; q < channels1; q++) |
|
|
|
{ |
|
|
|
const unsigned short* ptr = a; |
|
|
|
const unsigned short* ptr1 = b.channel(q); |
|
|
|
unsigned short* outptr = c.channel(q); |
|
|
|
for (int i = 0; i < size1; i++) |
|
|
|
{ |
|
|
|
outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(ptr1[i]))); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
// type 19 |
|
|
|
c.create(w, h, channels, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
#pragma omp parallel for num_threads(opt.num_threads) |
|
|
|
for (int q=0; q<channels; q++) |
|
|
|
{ |
|
|
|
@@ -1221,6 +1407,10 @@ static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
c.create(w, h, channels, elemsize, opt.blob_allocator); |
|
|
|
if (c.empty()) |
|
|
|
return -100; |
|
|
|
|
|
|
|
if (b.dims == 2) |
|
|
|
{ |
|
|
|
// type 18 |
|
|
|
|