From 6077066b025ff1c07b992aed327f4d94c6ee574e Mon Sep 17 00:00:00 2001 From: nihuini Date: Sun, 5 Apr 2020 14:02:32 +0800 Subject: [PATCH] binaryop broadcasting special type 3 4 for lhs --- docs/developer-guide/binaryop-broadcasting.md | 2 + src/layer/arm/binaryop_arm.cpp | 230 ++++++++++++++++-- src/layer/binaryop.cpp | 66 ++++- src/layer/vulkan/binaryop_vulkan.cpp | 11 +- .../vulkan/shader/binaryop_broadcast.comp | 14 ++ .../shader/binaryop_broadcast_a1_pack4.comp | 10 +- .../shader/binaryop_broadcast_a1_pack8.comp | 10 +- .../shader/binaryop_broadcast_pack4.comp | 7 + .../shader/binaryop_broadcast_pack8.comp | 7 + tests/test_binaryop.cpp | 36 +++ 10 files changed, 364 insertions(+), 29 deletions(-) diff --git a/docs/developer-guide/binaryop-broadcasting.md b/docs/developer-guide/binaryop-broadcasting.md index 9278afa45..0fc683b59 100644 --- a/docs/developer-guide/binaryop-broadcasting.md +++ b/docs/developer-guide/binaryop-broadcasting.md @@ -34,3 +34,5 @@ some special broadcasting rule exists for model compatibility |---|---|---|---| |1|[2,3,4]|[1,1,4]|[2,3,4]| |2|[2,3,4]|[2,3,1]|[2,3,4]| +|3|[1,1,4]|[2,3,4]|[2,3,4]| +|4|[2,3,1]|[2,3,4]|[2,3,4]| diff --git a/src/layer/arm/binaryop_arm.cpp b/src/layer/arm/binaryop_arm.cpp index 287fce8d0..424e2aad4 100644 --- a/src/layer/arm/binaryop_arm.cpp +++ b/src/layer/arm/binaryop_arm.cpp @@ -60,21 +60,21 @@ static int binary_op_pack4(const Mat& a, const Mat& b, Mat& c, const Option& opt if (a.dims == 3) { - c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); - if (c.empty()) - return -100; - if (b.dims == 3) { if (w1 == 1 && h1 == 1 && channels1 == channels) { // special type 1 + c.create(w, h, channels, elemsize, elempack, opt.blob_allocator); + if (c.empty()) + return -100; + #pragma omp parallel for num_threads(opt.num_threads) for (int q=0; qcreate(LayerShaderType::binaryop_broadcast_pack4, opt, specializations); } - if (shape.dims == 0 || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 4)) + if (shape.dims == 0 || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 4) + || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 4)) { pipeline_binaryop_broadcast_a1_pack4 = new Pipeline(vkdev); pipeline_binaryop_broadcast_a1_pack4->set_optimal_local_size_xyz(local_size_xyz); @@ -251,7 +252,8 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) pipeline_binaryop_broadcast_pack8->create(LayerShaderType::binaryop_broadcast_pack8, opt, specializations); } - if ((opt.use_shader_pack8 && shape.dims == 0) || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 8)) + if ((opt.use_shader_pack8 && shape.dims == 0) || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 8) + || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 8)) { pipeline_binaryop_broadcast_a1_pack8 = new Pipeline(vkdev); pipeline_binaryop_broadcast_a1_pack8->set_optimal_local_size_xyz(local_size_xyz); @@ -391,6 +393,11 @@ int BinaryOp_vulkan::forward(const std::vector& bottom_blobs, std::vector // special type 2 pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1) + { + // special type 4 + pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; + } else { pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; diff --git a/src/layer/vulkan/shader/binaryop_broadcast.comp b/src/layer/vulkan/shader/binaryop_broadcast.comp index a3005c22d..379a8f812 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast.comp @@ -102,6 +102,20 @@ void main() ai = gi; bi = gy * psc(bw) + gx; } + + if (psc(aw) == 1 && psc(ah) == 1) + { + // special type 3 + ai = gz * psc(acstep); + bi = gi; + } + + if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) + { + // special type 4 + ai = gy * psc(aw) + gx; + bi = gi; + } } if (psc(bdims) == 2) diff --git a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp index 85b54263c..fb892876a 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp @@ -82,8 +82,16 @@ void main() const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int ai = 0; + + if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) + { + // special type 4 + ai = gy * psc(aw) + gx; + } + // type 2 3 4 - afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, 0)); + afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); afpvec4 v2 = buffer_ld4(b_blob_data, gi); afpvec4 res; diff --git a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp index b84ddfd2c..c5c321f2d 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp @@ -83,8 +83,16 @@ void main() const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; + int ai = 0; + + if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) + { + // special type 2 + ai = gy * psc(bw) + gx; + } + // type 2 3 4 - afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, 0)); + afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); afpvec8 v2 = buffer_ld8(b_blob_data, gi); afpvec8 res; diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp index 61d8c8b01..6f4dfa167 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack4.comp @@ -95,6 +95,13 @@ void main() ai = gi; bi = gz * psc(bcstep); } + + if (psc(aw) == 1 && psc(ah) == 1) + { + // special type 3 + ai = gz * psc(acstep); + bi = gi; + } } if (psc(bdims) == 2) diff --git a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp index ec91e1973..958f42848 100644 --- a/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp +++ b/src/layer/vulkan/shader/binaryop_broadcast_pack8.comp @@ -96,6 +96,13 @@ void main() ai = gi; bi = gz * psc(bcstep); } + + if (psc(aw) == 1 && psc(ah) == 1) + { + // special type 3 + ai = gz * psc(acstep); + bi = gi; + } } if (psc(bdims) == 2) diff --git a/tests/test_binaryop.cpp b/tests/test_binaryop.cpp index 9cdd06422..6dd2d28c8 100644 --- a/tests/test_binaryop.cpp +++ b/tests/test_binaryop.cpp @@ -452,6 +452,40 @@ static int test_binaryop_s2() return 0; } +static int test_binaryop_s3() +{ + for (int op_type=0; op_type