diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index 657abe159..5ca8582e8 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -31,6 +31,8 @@ BinaryOp_arm::BinaryOp_arm() #if __ARM_NEON support_packing = true; #endif // __ARM_NEON + + support_bf16_storage = true; } #if __ARM_NEON @@ -38,7 +40,7 @@ BinaryOp_arm::BinaryOp_arm() // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting template -static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) +static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt) { Op op; @@ -492,7 +494,7 @@ static int binary_op(const Mat& a, const Mat& b, Mat& c, const Option& opt) } template -static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) +static int binary_op_scalar_inplace_pack4(Mat& a, float b, const Option& opt) { Op op; @@ -520,24 +522,20 @@ static int binary_op_scalar_inplace(Mat& a, float b, const Option& opt) return 0; } -template -struct binary_op_add { - T operator() (const T& x, const T& y) const { return vaddq_f32(x, y); } +struct binary_op_add_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vaddq_f32(x, y); } }; -template -struct binary_op_sub { - T operator() (const T& x, const T& y) const { return vsubq_f32(x, y); } +struct binary_op_sub_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vsubq_f32(x, y); } }; -template -struct binary_op_mul { - T operator() (const T& x, const T& y) const { return vmulq_f32(x, y); } +struct binary_op_mul_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vmulq_f32(x, y); } }; -template -struct binary_op_div { - T operator() (const T& x, const T& y) const +struct binary_op_div_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const #if __aarch64__ { return vdivq_f32(x, y); } #else @@ -545,29 +543,24 @@ struct binary_op_div { #endif }; -template -struct binary_op_max { - T operator() (const T& x, const T& y) const { return vmaxq_f32(x, y); } +struct binary_op_max_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vmaxq_f32(x, y); } }; -template -struct binary_op_min { - T operator() (const T& x, const T& y) const { return vminq_f32(x, y); } +struct binary_op_min_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vminq_f32(x, y); } }; -template -struct binary_op_pow { - T operator() (const T& x, const T& y) const { return pow_ps(x, y); } +struct binary_op_pow_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return pow_ps(x, y); } }; -template -struct binary_op_rsub { - T operator() (const T& x, const T& y) const { return vsubq_f32(y, x); } +struct binary_op_rsub_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const { return vsubq_f32(y, x); } }; -template -struct binary_op_rdiv { - T operator() (const T& x, const T& y) const +struct binary_op_rdiv_pack4 { + float32x4_t operator() (const float32x4_t& x, const float32x4_t& y) const #if __aarch64__ { return vdivq_f32(y, x); } #else @@ -578,51 +571,47 @@ struct binary_op_rdiv { int BinaryOp_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { + if (opt.use_bf16_storage) + return forward_bf16s(bottom_blobs, top_blobs, opt); + const Mat& bottom_blob = bottom_blobs[0]; const Mat& bottom_blob1 = bottom_blobs[1]; Mat& top_blob = top_blobs[0]; #if __ARM_NEON - if (opt.use_packing_layout) - { - int elempack = bottom_blob.elempack; int elempack1 = bottom_blob1.elempack; if (elempack == 4 || elempack1 == 4) { - if (op_type == Operation_ADD) - return binary_op< binary_op_add >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_SUB) - return binary_op< binary_op_sub >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MUL) - return binary_op< binary_op_mul >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_DIV) - return binary_op< binary_op_div >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MAX) - return binary_op< binary_op_max >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_MIN) - return binary_op< binary_op_min >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_POW) - return binary_op< binary_op_pow >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RSUB) - return binary_op< binary_op_rsub >(bottom_blob, bottom_blob1, top_blob, opt); + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); if (op_type == Operation_RDIV) - return binary_op< binary_op_rdiv >(bottom_blob, bottom_blob1, top_blob, opt); - + return binary_op_pack4(bottom_blob, bottom_blob1, top_blob, opt); } - - } // opt.use_packing_layout #endif // __ARM_NEON return BinaryOp::forward(bottom_blobs, top_blobs, opt); @@ -630,47 +619,1093 @@ int BinaryOp_arm::forward(const std::vector& bottom_blobs, std::vector int BinaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { -#if __ARM_NEON - if (opt.use_packing_layout) - { + if (opt.use_bf16_storage) + return forward_inplace_bf16s(bottom_top_blob, opt); +#if __ARM_NEON int elempack = bottom_top_blob.elempack; if (elempack == 4) { if (op_type == Operation_ADD) - return binary_op_scalar_inplace< binary_op_add >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_SUB) - return binary_op_scalar_inplace< binary_op_sub >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_MUL) - return binary_op_scalar_inplace< binary_op_mul >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_DIV) - return binary_op_scalar_inplace< binary_op_div >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_MAX) - return binary_op_scalar_inplace< binary_op_max >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_MIN) - return binary_op_scalar_inplace< binary_op_min >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_POW) - return binary_op_scalar_inplace< binary_op_pow >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_RSUB) - return binary_op_scalar_inplace< binary_op_rsub >(bottom_top_blob, b, opt); + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); if (op_type == Operation_RDIV) - return binary_op_scalar_inplace< binary_op_rdiv >(bottom_top_blob, b, opt); - + return binary_op_scalar_inplace_pack4(bottom_top_blob, b, opt); } - - } // opt.use_packing_layout #endif // __ARM_NEON return BinaryOp::forward_inplace(bottom_top_blob, opt); } +#if __ARM_NEON +template +static int binary_op_pack4_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + size_t elemsize = a.elemsize; + int elempack = a.elempack; + + int w1 = b.w; + int h1 = b.h; + int channels1 = b.c; + int size1 = w1 * h1; + size_t elemsize1 = b.elemsize; + int elempack1 = b.elempack; + + if (a.dims == 3) + { + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q=0; q(q); + unsigned short* outptr = c.channel(q); + + for (int y=0; y(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int y=0; y +static int binary_op_scalar_inplace_pack4_bf16s(Mat& a, float b, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + + float32x4_t _b = vdupq_n_f32(b); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q=0; q +static int binary_op_bf16s(const Mat& a, const Mat& b, Mat& c, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + size_t elemsize = a.elemsize; + + int w1 = b.w; + int h1 = b.h; + int channels1 = b.c; + int size1 = w1 * h1; + + if (a.dims == 3) + { + c.create(w, h, channels, elemsize, opt.blob_allocator); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + if (w1 == 1 && h1 == 1 && channels1 == channels) + { + // special type 1 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + unsigned short* outptr = c.channel(q); + const unsigned short* b0 = b.channel(q); + for (int i = 0; i < size; i++) + { + outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(b0[0]))); + } + } + + return 0; + } + + if (w1 == w && h1 == h && channels1 == 1) + { + // special type 2 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const unsigned short* ptr = a.channel(q); + const unsigned short* ptr1 = b; + unsigned short* outptr = c.channel(q); + for (int i = 0; i < size; i++) + { + outptr[i] = float32_to_bfloat16(op(bfloat16_to_float32(ptr[i]), bfloat16_to_float32(ptr1[i]))); + } + } + + return 0; + } + + // type 19 + #pragma omp parallel for num_threads(opt.num_threads) + for (int q=0; q(q); + unsigned short* outptr = c.channel(q); + + for (int y=0; y(q); + const unsigned short* ptr1 = b.channel(q); + unsigned short* outptr = c.channel(q); + + for (int y=0; y +static int binary_op_scalar_inplace_bf16s(Mat& a, float b, const Option& opt) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q=0; q& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& bottom_blob1 = bottom_blobs[1]; + + Mat& top_blob = top_blobs[0]; + +#if __ARM_NEON + int elempack = bottom_blob.elempack; + int elempack1 = bottom_blob1.elempack; + + if (elempack == 4 || elempack1 == 4) + { + if (op_type == Operation_ADD) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_SUB) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MUL) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_DIV) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MAX) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MIN) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_POW) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_RSUB) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_RDIV) + return binary_op_pack4_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + } +#endif // __ARM_NEON + + if (elempack == 1 && elempack1 == 1) + { + if (op_type == Operation_ADD) + return binary_op_bf16s< std::plus >(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_SUB) + return binary_op_bf16s< std::minus >(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MUL) + return binary_op_bf16s< std::multiplies >(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_DIV) + return binary_op_bf16s< std::divides >(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MAX) + return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_MIN) + return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_POW) + return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_RSUB) + return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + + if (op_type == Operation_RDIV) + return binary_op_bf16s(bottom_blob, bottom_blob1, top_blob, opt); + } + + return 0; +} + +int BinaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const +{ +#if __ARM_NEON + int elempack = bottom_top_blob.elempack; + + if (elempack == 4) + { + if (op_type == Operation_ADD) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_SUB) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_MUL) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_DIV) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_MAX) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_MIN) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_POW) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_RSUB) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_RDIV) + return binary_op_scalar_inplace_pack4_bf16s(bottom_top_blob, b, opt); + } +#endif // __ARM_NEON + + if (elempack == 1) + { + if (op_type == Operation_ADD) + return binary_op_scalar_inplace_bf16s< std::plus >(bottom_top_blob, b, opt); + + if (op_type == Operation_SUB) + return binary_op_scalar_inplace_bf16s< std::minus >(bottom_top_blob, b, opt); + + if (op_type == Operation_MUL) + return binary_op_scalar_inplace_bf16s< std::multiplies >(bottom_top_blob, b, opt); + + if (op_type == Operation_DIV) + return binary_op_scalar_inplace_bf16s< std::divides >(bottom_top_blob, b, opt); + + if (op_type == Operation_MAX) + return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_MIN) + return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_POW) + return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_RSUB) + return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + + if (op_type == Operation_RDIV) + return binary_op_scalar_inplace_bf16s(bottom_top_blob, b, opt); + } + + return 0; +} + } // namespace ncnn diff --git a/src/layer/arm/binaryop_arm.h b/src/layer/arm/binaryop_arm.h index dc74c7c2f..860e3ba53 100644 --- a/src/layer/arm/binaryop_arm.h +++ b/src/layer/arm/binaryop_arm.h @@ -27,6 +27,11 @@ public: virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +protected: + int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + int forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const; + }; } // namespace ncnn diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index 725bf85bb..9cdd06422 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -39,6 +39,7 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b, int op_type) ncnn::Option opt; opt.num_threads = 1; opt.use_vulkan_compute = true; + opt.use_int8_inference = false; opt.use_fp16_packed = false; opt.use_fp16_storage = false; opt.use_fp16_arithmetic = false; @@ -78,6 +79,7 @@ static int test_binaryop(const ncnn::Mat& _a, float b, int op_type) ncnn::Option opt; opt.num_threads = 1; opt.use_vulkan_compute = true; + opt.use_int8_inference = false; opt.use_fp16_packed = false; opt.use_fp16_storage = false; opt.use_fp16_arithmetic = false;