Browse Source

enrich ncnn binary broadcast rules (#4513)

tags/20230223
nihui GitHub 3 years ago
parent
commit
ab4cfbf5b0
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 6387 additions and 14302 deletions
  1. +40
    -39
      docs/developer-guide/binaryop-broadcasting.md
  2. +1
    -0
      docs/developer-guide/operators.md
  3. +869
    -2133
      src/layer/arm/binaryop_arm.cpp
  4. +500
    -2126
      src/layer/arm/binaryop_arm_asimdhp.cpp
  5. +334
    -722
      src/layer/binaryop.cpp
  6. +2
    -1
      src/layer/binaryop.h
  7. +424
    -733
      src/layer/loongarch/binaryop_loongarch.cpp
  8. +426
    -739
      src/layer/mips/binaryop_mips.cpp
  9. +938
    -2372
      src/layer/riscv/binaryop_riscv.cpp
  10. +348
    -337
      src/layer/vulkan/binaryop_vulkan.cpp
  11. +6
    -7
      src/layer/vulkan/binaryop_vulkan.h
  12. +27
    -31
      src/layer/vulkan/shader/binaryop.comp
  13. +0
    -553
      src/layer/vulkan/shader/binaryop_broadcast.comp
  14. +0
    -169
      src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp
  15. +193
    -0
      src/layer/vulkan/shader/binaryop_broadcast_inner.comp
  16. +193
    -0
      src/layer/vulkan/shader/binaryop_broadcast_inner_pack4.comp
  17. +234
    -0
      src/layer/vulkan/shader/binaryop_broadcast_inner_pack8.comp
  18. +43
    -47
      src/layer/vulkan/shader/binaryop_broadcast_outer.comp
  19. +34
    -38
      src/layer/vulkan/shader/binaryop_broadcast_outer_pack4.comp
  20. +59
    -56
      src/layer/vulkan/shader/binaryop_broadcast_outer_pack8.comp
  21. +0
    -502
      src/layer/vulkan/shader/binaryop_broadcast_pack4.comp
  22. +0
    -539
      src/layer/vulkan/shader/binaryop_broadcast_pack8.comp
  23. +21
    -34
      src/layer/vulkan/shader/binaryop_pack4.comp
  24. +62
    -106
      src/layer/vulkan/shader/binaryop_pack8.comp
  25. +633
    -2178
      src/layer/x86/binaryop_x86.cpp
  26. +237
    -281
      tests/test_binaryop.cpp
  27. +237
    -279
      tests/test_binaryop_1.cpp
  28. +238
    -280
      tests/test_binaryop_2.cpp
  29. +1
    -0
      tools/pnnx/src/CMakeLists.txt
  30. +2
    -0
      tools/pnnx/src/pass_ncnn.cpp
  31. +28
    -0
      tools/pnnx/src/pass_ncnn/expand_expression.cpp
  32. +153
    -0
      tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp
  33. +25
    -0
      tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h
  34. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  35. +78
    -0
      tools/pnnx/tests/ncnn/test_ncnn_numpy_binaryop_broadcast.py

+ 40
- 39
docs/developer-guide/binaryop-broadcasting.md View File

@@ -6,47 +6,48 @@ C = BinaryOp(A, B)

shape notation convention is [w], [w,h], [w,h,c], [w,h,d,c]

* binaryop with scalar and scalar-like

|type|A|B|C|
|---|---|---|---|
|0|[2]|scalar / [1]|[2]|
|1|[2,3]|scalar / [1] / [1,1]|[2,3]|
|2|[2,3,4]|scalar / [1] / [1,1] / [1,1,1]|[2,3,4]|
|3|[2,3,4,5]|scalar / [1] / [1,1] / [1,1,1] / [1,1,1,1]|[2,3,4,5]|

* no broadcast

|type|A|B|C|
|---|---|---|---|
|4|[2]|[2]|[2]|
|5|[2,3]|[2,3]|[2,3]|
|6|[2,3,4]|[2,3,4]|[2,3,4]|
|7|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]|

* broadcast B for inner axis

|type|A|B|C|
|---|---|---|---|
|1|[1]|scalar|[1]|
|2|[1]|[2]|[2]|
|3|[1]|[2,3]|[2,3]|
|4|[1]|[2,3,4]|[2,3,4]|
|5|[2]|scalar|[2]|
|6|[2]|[1]|[2]|
|7|[2]|[2]|[2]|
|8|[3]|[2,3]|[2,3]|
|9|[4]|[2,3,4]|[2,3,4]|
|10|[2,3]|scalar|[2,3]|
|11|[2,3]|[1]|[2,3]|
|12|[2,3]|[3]|[2,3]|
|13|[2,3]|[2,3]|[2,3]|
|14|[3,4]|[2,3,4]|[2,3,4]|
|15|[2,3,4]|scalar|[2,3,4]|
|16|[2,3,4]|[1]|[2,3,4]|
|17|[2,3,4]|[4]|[2,3,4]|
|18|[2,3,4]|[3,4]|[2,3,4]|
|19|[2,3,4]|[2,3,4]|[2,3,4]|
|20|[1]|[2,3,4,5]|[2,3,4,5]|
|21|[5]|[2,3,4,5]|[2,3,4,5]|
|22|[4,5]|[2,3,4,5]|[2,3,4,5]|
|23|[3,4,5]|[2,3,4,5]|[2,3,4,5]|
|24|[2,3,4,5]|scalar|[2,3,4,5]|
|25|[2,3,4,5]|[1]|[2,3,4,5]|
|26|[2,3,4,5]|[5]|[2,3,4,5]|
|27|[2,3,4,5]|[4,5]|[2,3,4,5]|
|28|[2,3,4,5]|[3,4,5]|[2,3,4,5]|
|29|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]|

some special broadcasting rule exists for model compatibility
|8|[2,3]|[3] / [1,3]|[2,3]|
|9|[2,3,4]|[4] / [1,1,4]|[2,3,4]|
|10|[2,3,4]|[3,4] / [1,3,4]|[2,3,4]|
|11|[2,3,4,5]|[5] / [1,1,1,5]|[2,3,4,5]|
|12|[2,3,4,5]|[4,5] / [1,1,4,5]|[2,3,4,5]|
|13|[2,3,4,5]|[3,4,5] / [1,3,4,5]|[2,3,4,5]|

* broadcast B for outer axis

|type|A|B|C|
|---|---|---|---|
|14|[2,3]|[2,1]|[2,3]|
|15|[2,3,4]|[2,1,1]|[2,3,4]|
|16|[2,3,4]|[2,3,1]|[2,3,4]|
|17|[2,3,4,5]|[2,1,1,1]|[2,3,4,5]|
|18|[2,3,4,5]|[2,3,1,1]|[2,3,4,5]|
|19|[2,3,4,5]|[2,3,4,1]|[2,3,4,5]|

* some special broadcasting rule exists for model compatibility

|special type|A|B|C|
|---|---|---|---|
|1|[2,3,4]|[1,1,4]|[2,3,4]|
|2|[2,3,4]|[2,3,1]|[2,3,4]|
|3|[1,1,4]|[2,3,4]|[2,3,4]|
|4|[2,3,1]|[2,3,4]|[2,3,4]|
|5|[2,3,4]|[1,3,4]|[2,3,4]|
|6|[2,3,4]|[2,1,4]|[2,3,4]|
|7|[1,3,4]|[2,3,4]|[2,3,4]|
|8|[2,1,4]|[2,3,4]|[2,3,4]|
|20|[2,3,4]|[2,1,4]|[2,3,4]|

+ 1
- 0
docs/developer-guide/operators.md View File

@@ -161,6 +161,7 @@ Operation type:
- 6 = POW
- 7 = RSUB
- 8 = RDIV
- 9 = RPOW

# BNLL
```


+ 869
- 2133
src/layer/arm/binaryop_arm.cpp
File diff suppressed because it is too large
View File


+ 500
- 2126
src/layer/arm/binaryop_arm_asimdhp.cpp
File diff suppressed because it is too large
View File


+ 334
- 722
src/layer/binaryop.cpp
File diff suppressed because it is too large
View File


+ 2
- 1
src/layer/binaryop.h View File

@@ -42,7 +42,8 @@ public:
Operation_MIN = 5,
Operation_POW = 6,
Operation_RSUB = 7,
Operation_RDIV = 8
Operation_RDIV = 8,
Operation_RPOW = 9
};

public:


+ 424
- 733
src/layer/loongarch/binaryop_loongarch.cpp
File diff suppressed because it is too large
View File


+ 426
- 739
src/layer/mips/binaryop_mips.cpp
File diff suppressed because it is too large
View File


+ 938
- 2372
src/layer/riscv/binaryop_riscv.cpp
File diff suppressed because it is too large
View File


+ 348
- 337
src/layer/vulkan/binaryop_vulkan.cpp View File

@@ -29,13 +29,29 @@ BinaryOp_vulkan::BinaryOp_vulkan()
pipeline_binaryop_pack4 = 0;
pipeline_binaryop_pack8 = 0;

pipeline_binaryop_broadcast = 0;
pipeline_binaryop_broadcast_pack4 = 0;
pipeline_binaryop_broadcast_a1_pack4 = 0;
pipeline_binaryop_broadcast_b1_pack4 = 0;
pipeline_binaryop_broadcast_pack8 = 0;
pipeline_binaryop_broadcast_a1_pack8 = 0;
pipeline_binaryop_broadcast_b1_pack8 = 0;
pipeline_binaryop_broadcast_inner[0] = 0;
pipeline_binaryop_broadcast_inner[1] = 0;
pipeline_binaryop_broadcast_inner_pack4[0] = 0;
pipeline_binaryop_broadcast_inner_pack4[1] = 0;
pipeline_binaryop_broadcast_inner_pack8[0] = 0;
pipeline_binaryop_broadcast_inner_pack8[1] = 0;
pipeline_binaryop_broadcast_outer[0] = 0;
pipeline_binaryop_broadcast_outer[1] = 0;
pipeline_binaryop_broadcast_outer_pack4[0] = 0;
pipeline_binaryop_broadcast_outer_pack4[1] = 0;
pipeline_binaryop_broadcast_outer_pack8[0] = 0;
pipeline_binaryop_broadcast_outer_pack8[1] = 0;
}

// Return the counterpart of op_type with the "reverse" flavour toggled.
//
// The mapping is symmetric:
//   SUB <-> RSUB, DIV <-> RDIV, POW <-> RPOW
// Every other operation type has no reversed variant here and is
// returned unchanged, so callers can detect "no distinct reverse op"
// by comparing the result with the input.
static int get_reverse_op_type(int op_type)
{
    switch (op_type)
    {
    // forward op -> reversed op
    case BinaryOp::Operation_SUB:
        return BinaryOp::Operation_RSUB;
    case BinaryOp::Operation_DIV:
        return BinaryOp::Operation_RDIV;
    case BinaryOp::Operation_POW:
        return BinaryOp::Operation_RPOW;

    // reversed op -> forward op
    case BinaryOp::Operation_RSUB:
        return BinaryOp::Operation_SUB;
    case BinaryOp::Operation_RDIV:
        return BinaryOp::Operation_DIV;
    case BinaryOp::Operation_RPOW:
        return BinaryOp::Operation_POW;

    // no reversed variant: hand the op type back as-is
    default:
        return op_type;
    }
}

int BinaryOp_vulkan::create_pipeline(const Option& opt)
@@ -182,20 +198,33 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt)
// broadcast
if (shape.dims == 0 || broadcast)
{
bool a_is_lower = false;
if (shape.dims != 0 && shape1.dims != 0)
{
const bool b_is_scalar = shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack == 1;
const bool a_rank_is_lower = shape_packed.dims < shape1_packed.dims && !b_is_scalar;
const bool a_size_is_lower = shape_packed.w * shape_packed.h * shape_packed.d * shape_packed.c * shape_packed.elempack < shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack;
a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower);
}
const Mat& A_shape_packed = a_is_lower ? shape1_packed : shape_packed;
const Mat& B_shape_packed = a_is_lower ? shape_packed : shape1_packed;

const int op_type_r = get_reverse_op_type(op_type);

std::vector<vk_specialization_type> specializations(1 + 18);
specializations[0].i = op_type;
specializations[1 + 0].i = shape_packed.dims;
specializations[1 + 1].i = shape_packed.w;
specializations[1 + 2].i = shape_packed.h;
specializations[1 + 3].i = shape_packed.d;
specializations[1 + 4].i = shape_packed.c;
specializations[1 + 5].i = shape_packed.cstep;
specializations[1 + 6].i = shape1_packed.dims;
specializations[1 + 7].i = shape1_packed.w;
specializations[1 + 8].i = shape1_packed.h;
specializations[1 + 9].i = shape1_packed.d;
specializations[1 + 10].i = shape1_packed.c;
specializations[1 + 11].i = shape1_packed.cstep;
specializations[1 + 0].i = A_shape_packed.dims;
specializations[1 + 1].i = A_shape_packed.w;
specializations[1 + 2].i = A_shape_packed.h;
specializations[1 + 3].i = A_shape_packed.d;
specializations[1 + 4].i = A_shape_packed.c;
specializations[1 + 5].i = A_shape_packed.cstep;
specializations[1 + 6].i = B_shape_packed.dims;
specializations[1 + 7].i = B_shape_packed.w;
specializations[1 + 8].i = B_shape_packed.h;
specializations[1 + 9].i = B_shape_packed.d;
specializations[1 + 10].i = B_shape_packed.c;
specializations[1 + 11].i = B_shape_packed.cstep;
specializations[1 + 12].i = out_shape_packed.dims;
specializations[1 + 13].i = out_shape_packed.w;
specializations[1 + 14].i = out_shape_packed.h;
@@ -203,23 +232,26 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt)
specializations[1 + 16].i = out_shape_packed.c;
specializations[1 + 17].i = out_shape_packed.cstep;

std::vector<vk_specialization_type> specializations_broadcast_a1_b1(1 + 15);
specializations_broadcast_a1_b1[0].i = op_type;
specializations_broadcast_a1_b1[1 + 0].i = shape_packed.dims;
specializations_broadcast_a1_b1[1 + 1].i = shape_packed.w;
specializations_broadcast_a1_b1[1 + 2].i = shape_packed.h * shape_packed.d;
specializations_broadcast_a1_b1[1 + 3].i = shape_packed.c;
specializations_broadcast_a1_b1[1 + 4].i = shape_packed.cstep;
specializations_broadcast_a1_b1[1 + 5].i = shape1_packed.dims;
specializations_broadcast_a1_b1[1 + 6].i = shape1_packed.w;
specializations_broadcast_a1_b1[1 + 7].i = shape1_packed.h * shape1_packed.d;
specializations_broadcast_a1_b1[1 + 8].i = shape1_packed.c;
specializations_broadcast_a1_b1[1 + 9].i = shape1_packed.cstep;
specializations_broadcast_a1_b1[1 + 10].i = out_shape_packed.dims;
specializations_broadcast_a1_b1[1 + 11].i = out_shape_packed.w;
specializations_broadcast_a1_b1[1 + 12].i = out_shape_packed.h * out_shape_packed.d;
specializations_broadcast_a1_b1[1 + 13].i = out_shape_packed.c;
specializations_broadcast_a1_b1[1 + 14].i = out_shape_packed.cstep;
std::vector<vk_specialization_type> specializations_r(1 + 18);
specializations_r[0].i = op_type_r;
specializations_r[1 + 0].i = A_shape_packed.dims;
specializations_r[1 + 1].i = A_shape_packed.w;
specializations_r[1 + 2].i = A_shape_packed.h;
specializations_r[1 + 3].i = A_shape_packed.d;
specializations_r[1 + 4].i = A_shape_packed.c;
specializations_r[1 + 5].i = A_shape_packed.cstep;
specializations_r[1 + 6].i = B_shape_packed.dims;
specializations_r[1 + 7].i = B_shape_packed.w;
specializations_r[1 + 8].i = B_shape_packed.h;
specializations_r[1 + 9].i = B_shape_packed.d;
specializations_r[1 + 10].i = B_shape_packed.c;
specializations_r[1 + 11].i = B_shape_packed.cstep;
specializations_r[1 + 12].i = out_shape_packed.dims;
specializations_r[1 + 13].i = out_shape_packed.w;
specializations_r[1 + 14].i = out_shape_packed.h;
specializations_r[1 + 15].i = out_shape_packed.d;
specializations_r[1 + 16].i = out_shape_packed.c;
specializations_r[1 + 17].i = out_shape_packed.cstep;

Mat local_size_xyz;
if (out_shape_packed.dims == 1)
@@ -248,59 +280,75 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt)
}

// pack1
if (shape.dims == 0 || (elempack == 1 && elempack1 == 1))
if (shape.dims == 0 || (out_elempack == 1))
{
pipeline_binaryop_broadcast = new Pipeline(vkdev);
pipeline_binaryop_broadcast->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast->create(LayerShaderType::binaryop_broadcast, opt, specializations);
pipeline_binaryop_broadcast_inner[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner[0]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations);

pipeline_binaryop_broadcast_outer[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer[0]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations);

if (op_type_r != op_type)
{
// sub div pow ...
pipeline_binaryop_broadcast_inner[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner[1]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations_r);

pipeline_binaryop_broadcast_outer[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer[1]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations_r);
}
}

// pack4
if (shape.dims == 0 || (elempack == 4 && elempack1 == 4))
if (shape.dims == 0 || (out_elempack == 4))
{
pipeline_binaryop_broadcast_pack4 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_pack4->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_pack4->create(LayerShaderType::binaryop_broadcast_pack4, opt, specializations);
}
pipeline_binaryop_broadcast_inner_pack4[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner_pack4[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner_pack4[0]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations);

if (shape.dims == 0 || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 4)
|| (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 4))
{
pipeline_binaryop_broadcast_a1_pack4 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_a1_pack4->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_a1_pack4->create(LayerShaderType::binaryop_broadcast_a1_pack4, opt, specializations_broadcast_a1_b1);
}
pipeline_binaryop_broadcast_outer_pack4[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer_pack4[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer_pack4[0]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations);

if (shape.dims == 0 || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 4)
|| (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 4))
{
pipeline_binaryop_broadcast_b1_pack4 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_b1_pack4->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_b1_pack4->create(LayerShaderType::binaryop_broadcast_b1_pack4, opt, specializations_broadcast_a1_b1);
if (op_type_r != op_type)
{
// sub div pow ...
pipeline_binaryop_broadcast_inner_pack4[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner_pack4[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner_pack4[1]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations_r);

pipeline_binaryop_broadcast_outer_pack4[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer_pack4[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer_pack4[1]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations_r);
}
}

// pack8
if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && elempack1 == 8))
if ((opt.use_shader_pack8 && shape.dims == 0) || (out_elempack == 8))
{
pipeline_binaryop_broadcast_pack8 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_pack8->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_pack8->create(LayerShaderType::binaryop_broadcast_pack8, opt, specializations);
}
pipeline_binaryop_broadcast_inner_pack8[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner_pack8[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner_pack8[0]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations);

if ((opt.use_shader_pack8 && shape.dims == 0) || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 8)
|| (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 8))
{
pipeline_binaryop_broadcast_a1_pack8 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_a1_pack8->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_a1_pack8->create(LayerShaderType::binaryop_broadcast_a1_pack8, opt, specializations_broadcast_a1_b1);
}
pipeline_binaryop_broadcast_outer_pack8[0] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer_pack8[0]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer_pack8[0]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations);

if ((opt.use_shader_pack8 && shape.dims == 0) || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 8)
|| (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 8))
{
pipeline_binaryop_broadcast_b1_pack8 = new Pipeline(vkdev);
pipeline_binaryop_broadcast_b1_pack8->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_b1_pack8->create(LayerShaderType::binaryop_broadcast_b1_pack8, opt, specializations_broadcast_a1_b1);
if (op_type_r != op_type)
{
// sub div pow ...
pipeline_binaryop_broadcast_inner_pack8[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_inner_pack8[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_inner_pack8[1]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations_r);

pipeline_binaryop_broadcast_outer_pack8[1] = new Pipeline(vkdev);
pipeline_binaryop_broadcast_outer_pack8[1]->set_optimal_local_size_xyz(local_size_xyz);
pipeline_binaryop_broadcast_outer_pack8[1]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations_r);
}
}
}

@@ -318,167 +366,75 @@ int BinaryOp_vulkan::destroy_pipeline(const Option& /*opt*/)
delete pipeline_binaryop_pack8;
pipeline_binaryop_pack8 = 0;

delete pipeline_binaryop_broadcast;
pipeline_binaryop_broadcast = 0;

delete pipeline_binaryop_broadcast_pack4;
pipeline_binaryop_broadcast_pack4 = 0;
delete pipeline_binaryop_broadcast_inner[0];
delete pipeline_binaryop_broadcast_inner[1];
pipeline_binaryop_broadcast_inner[0] = 0;
pipeline_binaryop_broadcast_inner[1] = 0;

delete pipeline_binaryop_broadcast_a1_pack4;
pipeline_binaryop_broadcast_a1_pack4 = 0;
delete pipeline_binaryop_broadcast_inner_pack4[0];
delete pipeline_binaryop_broadcast_inner_pack4[1];
pipeline_binaryop_broadcast_inner_pack4[0] = 0;
pipeline_binaryop_broadcast_inner_pack4[1] = 0;

delete pipeline_binaryop_broadcast_b1_pack4;
pipeline_binaryop_broadcast_b1_pack4 = 0;
delete pipeline_binaryop_broadcast_inner_pack8[0];
delete pipeline_binaryop_broadcast_inner_pack8[1];
pipeline_binaryop_broadcast_inner_pack8[0] = 0;
pipeline_binaryop_broadcast_inner_pack8[1] = 0;

delete pipeline_binaryop_broadcast_pack8;
pipeline_binaryop_broadcast_pack8 = 0;
delete pipeline_binaryop_broadcast_outer[0];
delete pipeline_binaryop_broadcast_outer[1];
pipeline_binaryop_broadcast_outer[0] = 0;
pipeline_binaryop_broadcast_outer[1] = 0;

delete pipeline_binaryop_broadcast_a1_pack8;
pipeline_binaryop_broadcast_a1_pack8 = 0;
delete pipeline_binaryop_broadcast_outer_pack4[0];
delete pipeline_binaryop_broadcast_outer_pack4[1];
pipeline_binaryop_broadcast_outer_pack4[0] = 0;
pipeline_binaryop_broadcast_outer_pack4[1] = 0;

delete pipeline_binaryop_broadcast_b1_pack8;
pipeline_binaryop_broadcast_b1_pack8 = 0;
delete pipeline_binaryop_broadcast_outer_pack8[0];
delete pipeline_binaryop_broadcast_outer_pack8[1];
pipeline_binaryop_broadcast_outer_pack8[0] = 0;
pipeline_binaryop_broadcast_outer_pack8[1] = 0;

return 0;
}

int BinaryOp_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
const VkMat& bottom_blob = bottom_blobs[0];
const VkMat& bottom_blob1 = bottom_blobs[1];
const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1;
const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar;
const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack;
const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower);
const VkMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0];
const VkMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1];
const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type;

VkMat& top_blob = top_blobs[0];

// broadcast
if (bottom_blob.dims > bottom_blob1.dims)
{
top_blob.create_like(bottom_blob, opt.blob_vkallocator);
}
else if (bottom_blob.dims < bottom_blob1.dims)
{
top_blob.create_like(bottom_blob1, opt.blob_vkallocator);
}
else // if (bottom_blob.dims == bottom_blob1.dims)
{
if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack)
{
top_blob.create_like(bottom_blob, opt.blob_vkallocator);
}
else
{
top_blob.create_like(bottom_blob1, opt.blob_vkallocator);
}
}
top_blob.create_like(A, opt.blob_vkallocator);
if (top_blob.empty())
return -100;

int out_elempack = top_blob.elempack;

std::vector<VkMat> bindings(3);
bindings[0] = bottom_blob;
bindings[1] = bottom_blob1;
bindings[0] = A;
bindings[1] = B;
bindings[2] = top_blob;

bool broadcast = true;
if (bottom_blob.dims == bottom_blob1.dims
&& bottom_blob.w == bottom_blob1.w
&& bottom_blob.h == bottom_blob1.h
&& bottom_blob.d == bottom_blob1.d
&& bottom_blob.c == bottom_blob1.c
&& bottom_blob.elempack == bottom_blob1.elempack)
{
broadcast = false;
}

if (broadcast)
{
std::vector<vk_constant_type> constants(18);
constants[0].i = bottom_blob.dims;
constants[1].i = bottom_blob.w;
constants[2].i = bottom_blob.h;
constants[3].i = bottom_blob.d;
constants[4].i = bottom_blob.c;
constants[5].i = bottom_blob.cstep;
constants[6].i = bottom_blob1.dims;
constants[7].i = bottom_blob1.w;
constants[8].i = bottom_blob1.h;
constants[9].i = bottom_blob1.d;
constants[10].i = bottom_blob1.c;
constants[11].i = bottom_blob1.cstep;
constants[12].i = top_blob.dims;
constants[13].i = top_blob.w;
constants[14].i = top_blob.h;
constants[15].i = top_blob.d;
constants[16].i = top_blob.c;
constants[17].i = top_blob.cstep;

std::vector<vk_constant_type> constants_broadcast_a1b1(15);
constants_broadcast_a1b1[0].i = bottom_blob.dims;
constants_broadcast_a1b1[1].i = bottom_blob.w;
constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d;
constants_broadcast_a1b1[3].i = bottom_blob.c;
constants_broadcast_a1b1[4].i = bottom_blob.cstep;
constants_broadcast_a1b1[5].i = bottom_blob1.dims;
constants_broadcast_a1b1[6].i = bottom_blob1.w;
constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d;
constants_broadcast_a1b1[8].i = bottom_blob1.c;
constants_broadcast_a1b1[9].i = bottom_blob1.cstep;
constants_broadcast_a1b1[10].i = top_blob.dims;
constants_broadcast_a1b1[11].i = top_blob.w;
constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d;
constants_broadcast_a1b1[13].i = top_blob.c;
constants_broadcast_a1b1[14].i = top_blob.cstep;

bool broadcast_a1b1 = true;

const Pipeline* pipeline = 0;
if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1)
{
pipeline = pipeline_binaryop_broadcast;
broadcast_a1b1 = false;
}
else
{
if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1)
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4;
}
else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1)
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4;
}
else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1)
{
// special type 2
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4;
}
else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1)
{
// special type 4
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4;
}
else
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4;
broadcast_a1b1 = false;
}
}

cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob);
}
else
// no broadcast
if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack)
{
std::vector<vk_constant_type> constants(15);
constants[0].i = bottom_blob.dims;
constants[1].i = bottom_blob.w;
constants[2].i = bottom_blob.h * bottom_blob.d;
constants[3].i = bottom_blob.c;
constants[4].i = bottom_blob.cstep;
constants[5].i = bottom_blob1.dims;
constants[6].i = bottom_blob1.w;
constants[7].i = bottom_blob1.h * bottom_blob1.d;
constants[8].i = bottom_blob1.c;
constants[9].i = bottom_blob1.cstep;
constants[0].i = A.dims;
constants[1].i = A.w;
constants[2].i = A.h * A.d;
constants[3].i = A.c;
constants[4].i = A.cstep;
constants[5].i = B.dims;
constants[6].i = B.w;
constants[7].i = B.h * B.d;
constants[8].i = B.c;
constants[9].i = B.cstep;
constants[10].i = top_blob.dims;
constants[11].i = top_blob.w;
constants[12].i = top_blob.h * top_blob.d;
@@ -490,8 +446,86 @@ int BinaryOp_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector
: pipeline_binaryop;

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

std::vector<vk_constant_type> constants(18);
constants[0].i = A.dims;
constants[1].i = A.w;
constants[2].i = A.h;
constants[3].i = A.d;
constants[4].i = A.c;
constants[5].i = A.cstep;
constants[6].i = B.dims;
constants[7].i = B.w;
constants[8].i = B.h;
constants[9].i = B.d;
constants[10].i = B.c;
constants[11].i = B.cstep;
constants[12].i = top_blob.dims;
constants[13].i = top_blob.w;
constants[14].i = top_blob.h;
constants[15].i = top_blob.d;
constants[16].i = top_blob.c;
constants[17].i = top_blob.cstep;

const int ri = op_type_r == op_type ? 0 : 1;

if (B.w * B.h * B.d * B.c * B.elempack == 1)
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri]
: pipeline_binaryop_broadcast_outer[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// broadcast B for inner axis
if ((B.dims < A.dims)
|| (A.dims == 2 && B.w == 1 && B.h == A.h)
|| (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c)
|| (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c))
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri]
: pipeline_binaryop_broadcast_inner[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// broadcast B for outer axis
if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1)))
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri]
: pipeline_binaryop_broadcast_outer[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// some special broadcast rule here
if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c)
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri]
: pipeline_binaryop_broadcast_inner[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// should never reach here
return 0;
}

@@ -522,141 +556,40 @@ int BinaryOp_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, con

int BinaryOp_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
{
const VkImageMat& bottom_blob = bottom_blobs[0];
const VkImageMat& bottom_blob1 = bottom_blobs[1];
const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1;
const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar;
const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack;
const bool a_is_lower = a_rank_is_lower || a_size_is_lower;
const VkImageMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0];
const VkImageMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1];
const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type;

VkImageMat& top_blob = top_blobs[0];

// broadcast
if (bottom_blob.dims > bottom_blob1.dims)
{
top_blob.create_like(bottom_blob, opt.blob_vkallocator);
}
else if (bottom_blob.dims < bottom_blob1.dims)
{
top_blob.create_like(bottom_blob1, opt.blob_vkallocator);
}
else // if (bottom_blob.dims == bottom_blob1.dims)
{
if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack)
{
top_blob.create_like(bottom_blob, opt.blob_vkallocator);
}
else
{
top_blob.create_like(bottom_blob1, opt.blob_vkallocator);
}
}
top_blob.create_like(A, opt.blob_vkallocator);
if (top_blob.empty())
return -100;

int out_elempack = top_blob.elempack;

std::vector<VkImageMat> bindings(3);
bindings[0] = bottom_blob;
bindings[1] = bottom_blob1;
bindings[0] = A;
bindings[1] = B;
bindings[2] = top_blob;

bool broadcast = true;
if (bottom_blob.dims == bottom_blob1.dims
&& bottom_blob.w == bottom_blob1.w
&& bottom_blob.h == bottom_blob1.h
&& bottom_blob.d == bottom_blob1.d
&& bottom_blob.c == bottom_blob1.c
&& bottom_blob.elempack == bottom_blob1.elempack)
{
broadcast = false;
}

if (broadcast)
{
std::vector<vk_constant_type> constants(18);
constants[0].i = bottom_blob.dims;
constants[1].i = bottom_blob.w;
constants[2].i = bottom_blob.h;
constants[3].i = bottom_blob.d;
constants[4].i = bottom_blob.c;
constants[5].i = 0; //bottom_blob.cstep;
constants[6].i = bottom_blob1.dims;
constants[7].i = bottom_blob1.w;
constants[8].i = bottom_blob1.h;
constants[9].i = bottom_blob1.d;
constants[10].i = bottom_blob1.c;
constants[11].i = 0; //bottom_blob1.cstep;
constants[12].i = top_blob.dims;
constants[13].i = top_blob.w;
constants[14].i = top_blob.h;
constants[15].i = top_blob.d;
constants[16].i = top_blob.c;
constants[17].i = 0; //top_blob.cstep;

std::vector<vk_constant_type> constants_broadcast_a1b1(15);
constants_broadcast_a1b1[0].i = bottom_blob.dims;
constants_broadcast_a1b1[1].i = bottom_blob.w;
constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d;
constants_broadcast_a1b1[3].i = bottom_blob.c;
constants_broadcast_a1b1[4].i = 0; //bottom_blob.cstep;
constants_broadcast_a1b1[5].i = bottom_blob1.dims;
constants_broadcast_a1b1[6].i = bottom_blob1.w;
constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d;
constants_broadcast_a1b1[8].i = bottom_blob1.c;
constants_broadcast_a1b1[9].i = 0; //bottom_blob1.cstep;
constants_broadcast_a1b1[10].i = top_blob.dims;
constants_broadcast_a1b1[11].i = top_blob.w;
constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d;
constants_broadcast_a1b1[13].i = top_blob.c;
constants_broadcast_a1b1[14].i = 0; //top_blob.cstep;

bool broadcast_a1b1 = true;

const Pipeline* pipeline = 0;
if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1)
{
pipeline = pipeline_binaryop_broadcast;
broadcast_a1b1 = false;
}
else
{
if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1)
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4;
}
else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1)
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4;
}
else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1)
{
// special type 2
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4;
}
else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1)
{
// special type 4
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4;
}
else
{
pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4;
broadcast_a1b1 = false;
}
}

cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob);
}
else
// no broadcast
if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack)
{
std::vector<vk_constant_type> constants(15);
constants[0].i = bottom_blob.dims;
constants[1].i = bottom_blob.w;
constants[2].i = bottom_blob.h * bottom_blob.d;
constants[3].i = bottom_blob.c;
constants[4].i = 0; //bottom_blob.cstep;
constants[5].i = bottom_blob1.dims;
constants[6].i = bottom_blob1.w;
constants[7].i = bottom_blob1.h * bottom_blob1.d;
constants[8].i = bottom_blob1.c;
constants[9].i = 0; //bottom_blob1.cstep;
constants[0].i = A.dims;
constants[1].i = A.w;
constants[2].i = A.h * A.d;
constants[3].i = A.c;
constants[4].i = 0; //A.cstep;
constants[5].i = B.dims;
constants[6].i = B.w;
constants[7].i = B.h * B.d;
constants[8].i = B.c;
constants[9].i = 0; //B.cstep;
constants[10].i = top_blob.dims;
constants[11].i = top_blob.w;
constants[12].i = top_blob.h * top_blob.d;
@@ -668,8 +601,86 @@ int BinaryOp_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::v
: pipeline_binaryop;

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

std::vector<vk_constant_type> constants(18);
constants[0].i = A.dims;
constants[1].i = A.w;
constants[2].i = A.h;
constants[3].i = A.d;
constants[4].i = A.c;
constants[5].i = 0; //A.cstep;
constants[6].i = B.dims;
constants[7].i = B.w;
constants[8].i = B.h;
constants[9].i = B.d;
constants[10].i = B.c;
constants[11].i = 0; //B.cstep;
constants[12].i = top_blob.dims;
constants[13].i = top_blob.w;
constants[14].i = top_blob.h;
constants[15].i = top_blob.d;
constants[16].i = top_blob.c;
constants[17].i = 0; //top_blob.cstep;

const int ri = op_type_r == op_type ? 0 : 1;

if (B.w * B.h * B.d * B.c * B.elempack == 1)
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri]
: pipeline_binaryop_broadcast_outer[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// broadcast B for inner axis
if ((B.dims < A.dims)
|| (A.dims == 2 && B.w == 1 && B.h == A.h)
|| (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c)
|| (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c)
|| (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c))
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri]
: pipeline_binaryop_broadcast_inner[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// broadcast B for outer axis
if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1)))
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri]
: pipeline_binaryop_broadcast_outer[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// some special broadcast rule here
if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c)
{
const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri]
: out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri]
: pipeline_binaryop_broadcast_inner[ri];

cmd.record_pipeline(pipeline, bindings, constants, top_blob);

return 0;
}

// should never reach here
return 0;
}



+ 6
- 7
src/layer/vulkan/binaryop_vulkan.h View File

@@ -43,13 +43,12 @@ public:
Pipeline* pipeline_binaryop_pack8;

// broadcast
Pipeline* pipeline_binaryop_broadcast;
Pipeline* pipeline_binaryop_broadcast_pack4;
Pipeline* pipeline_binaryop_broadcast_a1_pack4;
Pipeline* pipeline_binaryop_broadcast_b1_pack4;
Pipeline* pipeline_binaryop_broadcast_pack8;
Pipeline* pipeline_binaryop_broadcast_a1_pack8;
Pipeline* pipeline_binaryop_broadcast_b1_pack8;
Pipeline* pipeline_binaryop_broadcast_inner[2];
Pipeline* pipeline_binaryop_broadcast_inner_pack4[2];
Pipeline* pipeline_binaryop_broadcast_inner_pack8[2];
Pipeline* pipeline_binaryop_broadcast_outer[2];
Pipeline* pipeline_binaryop_broadcast_outer_pack4[2];
Pipeline* pipeline_binaryop_broadcast_outer_pack8[2];
};

} // namespace ncnn


+ 27
- 31
src/layer/vulkan/shader/binaryop.comp View File

@@ -92,52 +92,48 @@ void main()
afp v1 = buffer_ld1(a_blob_data, gi);
#endif

afp res;
afp v2;

if (with_scalar == 1)
{
// type 5 10 15
afp b = afp(const_b);

if (op_type == 0) res = v1 + b;
if (op_type == 1) res = v1 - b;
if (op_type == 2) res = v1 * b;
if (op_type == 3) res = v1 / b;
if (op_type == 4) res = max(v1, b);
if (op_type == 5) res = min(v1, b);
if (op_type == 6) res = pow(v1, b);
if (op_type == 7) res = b - v1;
if (op_type == 8) res = b / v1;

// type 0 1 2 3
v2 = afp(const_b);
}
else if (psc(bdims) == 1 && psc(bw) == 1)
{
// type 0 1 2 3
#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
v2 = image3d_ld1(b_blob_3d, ivec3(0, 0, 0));
#else
buffer_st1(a_blob_data, gi, res);
v2 = buffer_ld1(b_blob_data, 0);
#endif
}
else
{
// type 7 13 19
// type 4 5 6 7
#if NCNN_image_shader
afp v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz));
v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz));
#else
afp v2 = buffer_ld1(b_blob_data, gi);
v2 = buffer_ld1(b_blob_data, gi);
#endif
}

afp res;

if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
if (op_type == 2) res = v1 * v2;
if (op_type == 3) res = v1 / v2;
if (op_type == 4) res = max(v1, v2);
if (op_type == 5) res = min(v1, v2);
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
if (op_type == 2) res = v1 * v2;
if (op_type == 3) res = v1 / v2;
if (op_type == 4) res = max(v1, v2);
if (op_type == 5) res = min(v1, v2);
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st1(top_blob_data, gi, res);
buffer_st1(top_blob_data, gi, res);
#endif
}
}

+ 0
- 553
src/layer/vulkan/shader/binaryop_broadcast.comp View File

@@ -1,553 +0,0 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int op_type = 0;

#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Elementwise binary op with broadcasting for unpacked (elempack 1) blobs.
//
// One invocation produces one output element. gx walks the output width,
// gy walks the fused (depth * height) axis and gz walks the channels.
// The (adims, bdims) pair selects which operand is broadcast and how the
// output coordinate maps onto that operand; the "type N" / "special type N"
// labels are the broadcast-rule identifiers used by this layer.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // discard the padding invocations of the dispatch grid
    if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
        return;

#if NCNN_image_shader
    // Both operands default to the output coordinate (plain elementwise);
    // the branches below only override the operand that is broadcast.
    int ax = gx;
    int ay = gy;
    int az = gz;
    int bx = gx;
    int by = gy;
    int bz = gz;

    if (psc(adims) == 4)
    {
        // split the fused y coordinate into depth and row
        int yd = gy / psc(outh);
        int yh = gy % psc(outh);

        if (psc(bdims) == 3)
        {
            // type 28
            bx = yh;
            by = yd;
            bz = gz;
        }

        if (psc(bdims) == 2)
        {
            // type 27
            bx = yd;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 25
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 26
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 3)
    {
        if (psc(bdims) == 4)
        {
            int yd = gy / psc(outh);
            int yh = gy % psc(outh);

            // type 23
            ax = yh;
            ay = yd;
            az = gz;
        }

        if (psc(bdims) == 3)
        {
            // The special types are tested in sequence, not as an else-if
            // chain: each match only zeroes the coordinate it broadcasts.
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                bx = 0;
                by = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2
                bz = 0;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ax = 0;
                ay = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4
                az = 0;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                bx = 0;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                by = 0;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ax = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ay = 0;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18
            bx = gy;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 17
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 4)
        {
            // only the depth part of the fused y coordinate is needed here
            int yd = gy / psc(outh);

            // type 22
            ax = yd;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 3)
        {
            // type 14
            ax = gy;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 12
                bx = gy;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4 20
            ax = 0;
            ay = 0;
            az = 0;
        }
        else
        {
            if (psc(bdims) == 4)
            {
                // type 21
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 3)
            {
                // type 9
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 2)
            {
                // type 8
                ax = gy;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6
                    bx = 0;
                    by = 0;
                    bz = 0;
                }
            }
        }
    }

    afp v1 = image3d_ld1(a_blob_3d, ivec3(ax, ay, az));
    afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz));
#else
    // linear offset of this element in the output buffer
    const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

    // Default both operands to the output offset, mirroring the image path.
    // Previously these were left uninitialized, so a shape pair that matched
    // none of the branches below indexed the buffers with undefined values.
    int ai = gi;
    int bi = gi;

    if (psc(adims) == 4)
    {
        // split the fused y coordinate into depth and row
        int yd = gy / psc(outh);
        int yh = gy % psc(outh);

        if (psc(bdims) == 3)
        {
            // type 28
            ai = gi;
            bi = gz * psc(bcstep) + yd * psc(bw) + yh;
        }

        if (psc(bdims) == 2)
        {
            // type 27
            ai = gi;
            bi = gz * psc(bw) + yd;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 25
                ai = gi;
                bi = 0;
            }
            else
            {
                // type 26
                ai = gi;
                bi = gz;
            }
        }
    }
    else if (psc(adims) == 3)
    {
        if (psc(bdims) == 4)
        {
            int yd = gy / psc(outh);
            int yh = gy % psc(outh);

            // type 23
            ai = gz * psc(acstep) + yd * psc(aw) + yh;
            bi = gi;
        }

        if (psc(bdims) == 3)
        {
            // Tested in sequence; when several conditions hold, the last
            // matching special type decides both offsets.
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                ai = gi;
                bi = gz * psc(bcstep);
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2
                ai = gi;
                bi = gy * psc(bw) + gx;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ai = gz * psc(acstep);
                bi = gi;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4
                ai = gy * psc(aw) + gx;
                bi = gi;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                ai = gi;
                bi = gz * psc(bcstep) + gy;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                ai = gi;
                bi = gz * psc(bcstep) + gx;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ai = gz * psc(acstep) + gy;
                bi = gi;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ai = gz * psc(acstep) + gx;
                bi = gi;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18
            ai = gi;
            bi = gz * psc(bw) + gy;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16
                ai = gi;
                bi = 0;
            }
            else
            {
                // type 17
                ai = gi;
                bi = gz;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 4)
        {
            // only the depth part of the fused y coordinate is needed here
            int yd = gy / psc(outh);

            // type 22
            ai = gz * psc(aw) + yd;
            bi = gi;
        }

        if (psc(bdims) == 3)
        {
            // type 14
            ai = gz * psc(aw) + gy;
            bi = gi;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11
                ai = gi;
                bi = 0;
            }
            else
            {
                // type 12
                ai = gi;
                bi = gy;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4 20
            ai = 0;
            bi = gi;
        }
        else
        {
            if (psc(bdims) == 4)
            {
                // type 21
                ai = gz;
                bi = gi;
            }

            if (psc(bdims) == 3)
            {
                // type 9
                ai = gz;
                bi = gi;
            }

            if (psc(bdims) == 2)
            {
                // type 8
                ai = gy;
                bi = gi;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6
                    ai = gi;
                    bi = 0;
                }
            }
        }
    }

    afp v1 = buffer_ld1(a_blob_data, ai);
    afp v2 = buffer_ld1(b_blob_data, bi);
#endif

    afp res;

    // op_type is a specialization constant, so exactly one line survives
    // pipeline compilation; 7 and 8 are the operand-swapped sub and div.
    if (op_type == 0) res = v1 + v2;
    if (op_type == 1) res = v1 - v2;
    if (op_type == 2) res = v1 * v2;
    if (op_type == 3) res = v1 / v2;
    if (op_type == 4) res = max(v1, v2);
    if (op_type == 5) res = min(v1, v2);
    if (op_type == 6) res = pow(v1, v2);
    if (op_type == 7) res = v2 - v1;
    if (op_type == 8) res = v2 / v1;

#if NCNN_image_shader
    image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st1(top_blob_data, gi, res);
#endif
}

+ 0
- 169
src/layer/vulkan/shader/binaryop_broadcast_a1_pack8.comp View File

@@ -1,169 +0,0 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int op_type = 0;

#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
int adims;
int aw;
int ah;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outc;
int outcstep;
} p;

// Binary op between an unpacked operand A and a pack8 operand B.
//
// A contributes a single scalar per invocation, splatted across four lanes
// and applied to both vec4 halves of the pack8 value loaded from B.
// A is either one value for the whole blob, or (special type 4) one value
// per (x, y) position shared by every channel.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // discard the padding invocations of the dispatch grid
    if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
        return;

    // special type 4: A is a 3-dim blob with B's width and height but a single channel
    bool a_is_plane = psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1;

#if NCNN_image_shader
    // per-position value for special type 4, otherwise the single element at the origin
    ivec3 a_pos = a_is_plane ? ivec3(gx, gy, 0) : ivec3(0, 0, 0);

    afpvec4 v1 = afpvec4(image3d_ld1(a_blob_3d, a_pos));
    afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz));
#else
    const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

    // type 2 3 4 read element 0; special type 4 reads the row-major (x, y) element
    int ai = a_is_plane ? gy * psc(bw) + gx : 0;

    afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai));
    afpvec8 v2 = buffer_ld8(b_blob_data, gi);
#endif

    afpvec8 res;

    // op_type is a specialization constant; 7 and 8 are the operand-swapped sub and div
    if (op_type == 0)
    {
        res[0] = v1 + v2[0];
        res[1] = v1 + v2[1];
    }
    else if (op_type == 1)
    {
        res[0] = v1 - v2[0];
        res[1] = v1 - v2[1];
    }
    else if (op_type == 2)
    {
        res[0] = v1 * v2[0];
        res[1] = v1 * v2[1];
    }
    else if (op_type == 3)
    {
        res[0] = v1 / v2[0];
        res[1] = v1 / v2[1];
    }
    else if (op_type == 4)
    {
        res[0] = max(v1, v2[0]);
        res[1] = max(v1, v2[1]);
    }
    else if (op_type == 5)
    {
        res[0] = min(v1, v2[0]);
        res[1] = min(v1, v2[1]);
    }
    else if (op_type == 6)
    {
        res[0] = pow(v1, v2[0]);
        res[1] = pow(v1, v2[1]);
    }
    else if (op_type == 7)
    {
        res[0] = v2[0] - v1;
        res[1] = v2[1] - v1;
    }
    else if (op_type == 8)
    {
        res[0] = v2[0] / v1;
        res[1] = v2[1] / v1;
    }

#if NCNN_image_shader
    image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st8(top_blob_data, gi, res);
#endif
}

+ 193
- 0
src/layer/vulkan/shader/binaryop_broadcast_inner.comp View File

@@ -0,0 +1,193 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int op_type = 0;

#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Binary op where operand B is broadcast along the inner axes of A.
//
// One invocation produces one output element: gx is the column, gy the
// fused (depth * height) row and gz the channel. A is always read at the
// output coordinate; only the coordinate into B is remapped.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // discard the padding invocations of the dispatch grid
    if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
        return;

    // split the fused y coordinate into its depth and row parts
    int depth_idx = gy / psc(outh);
    int row_idx = gy % psc(outh);

    // coordinate of the B element paired with this output element
    ivec3 bpos = ivec3(gx, gy, gz);

    if (psc(adims) == psc(bdims))
    {
        // explicit broadcast: same rank, so clamp every axis into B's extent
        // (an axis of size 1 in B always resolves to index 0)
        bpos.x = min(gx, psc(bw) - 1);
        bpos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(row_idx, psc(bh) - 1);
        bpos.z = min(gz, psc(bc) - 1);
    }
    else if (psc(adims) == 4)
    {
        // implicit broadcast onto a 4-dim A
        if (psc(bdims) == 3)
        {
            // type 13
            bpos = ivec3(row_idx, depth_idx, gz);
        }
        else if (psc(bdims) == 2)
        {
            // type 12
            bpos = ivec3(depth_idx, gz, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 11
            bpos = ivec3(gz, 0, 0);
        }
    }
    else if (psc(adims) == 3)
    {
        // implicit broadcast onto a 3-dim A
        if (psc(bdims) == 2)
        {
            // type 10
            bpos = ivec3(gy, gz, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 9
            bpos = ivec3(gz, 0, 0);
        }
    }
    else
    {
        // adims == 2 with bdims == 1
        // type 8
        bpos = ivec3(gy, 0, 0);
    }

#if NCNN_image_shader
    afp v1 = image3d_ld1(a_blob_3d, ivec3(gx, gy, gz));
    afp v2 = image3d_ld1(b_blob_3d, bpos);
#else
    // linear offsets into the output/A buffer and into B
    int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
    int bi = bpos.z * psc(bcstep) + bpos.y * psc(bw) + bpos.x;

    afp v1 = buffer_ld1(a_blob_data, gi);
    afp v2 = buffer_ld1(b_blob_data, bi);
#endif

    afp res;

    // op_type is a specialization constant; 7, 8 and 9 are the operand-swapped sub, div and pow
    if (op_type == 0)
        res = v1 + v2;
    else if (op_type == 1)
        res = v1 - v2;
    else if (op_type == 2)
        res = v1 * v2;
    else if (op_type == 3)
        res = v1 / v2;
    else if (op_type == 4)
        res = max(v1, v2);
    else if (op_type == 5)
        res = min(v1, v2);
    else if (op_type == 6)
        res = pow(v1, v2);
    else if (op_type == 7)
        res = v2 - v1;
    else if (op_type == 8)
        res = v2 / v1;
    else if (op_type == 9)
        res = pow(v2, v1);

#if NCNN_image_shader
    image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st1(top_blob_data, gi, res);
#endif
}

+ 193
- 0
src/layer/vulkan/shader/binaryop_broadcast_inner_pack4.comp View File

@@ -0,0 +1,193 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int op_type = 0;

#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Pack4 variant: binary op where operand B is broadcast along the inner
// axes of A, with both operands and the output holding four lanes per element.
//
// One invocation produces one packed output element: gx is the column, gy
// the fused (depth * height) row and gz the packed channel. A is always read
// at the output coordinate; only the coordinate into B is remapped.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // discard the padding invocations of the dispatch grid
    if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
        return;

    // split the fused y coordinate into its depth and row parts
    int depth_idx = gy / psc(outh);
    int row_idx = gy % psc(outh);

    // coordinate of the B element paired with this output element
    ivec3 bpos = ivec3(gx, gy, gz);

    if (psc(adims) == psc(bdims))
    {
        // explicit broadcast: same rank, so clamp every axis into B's extent
        // (an axis of size 1 in B always resolves to index 0)
        bpos.x = min(gx, psc(bw) - 1);
        bpos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(row_idx, psc(bh) - 1);
        bpos.z = min(gz, psc(bc) - 1);
    }
    else if (psc(adims) == 4)
    {
        // implicit broadcast onto a 4-dim A
        if (psc(bdims) == 3)
        {
            // type 13
            bpos = ivec3(row_idx, depth_idx, gz);
        }
        else if (psc(bdims) == 2)
        {
            // type 12
            bpos = ivec3(depth_idx, gz, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 11
            bpos = ivec3(gz, 0, 0);
        }
    }
    else if (psc(adims) == 3)
    {
        // implicit broadcast onto a 3-dim A
        if (psc(bdims) == 2)
        {
            // type 10
            bpos = ivec3(gy, gz, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 9
            bpos = ivec3(gz, 0, 0);
        }
    }
    else
    {
        // adims == 2 with bdims == 1
        // type 8
        bpos = ivec3(gy, 0, 0);
    }

#if NCNN_image_shader
    afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(gx, gy, gz));
    afpvec4 v2 = image3d_ld4(b_blob_3d, bpos);
#else
    // linear offsets into the output/A buffer and into B
    int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
    int bi = bpos.z * psc(bcstep) + bpos.y * psc(bw) + bpos.x;

    afpvec4 v1 = buffer_ld4(a_blob_data, gi);
    afpvec4 v2 = buffer_ld4(b_blob_data, bi);
#endif

    afpvec4 res;

    // op_type is a specialization constant; 7, 8 and 9 are the operand-swapped sub, div and pow
    if (op_type == 0)
        res = v1 + v2;
    else if (op_type == 1)
        res = v1 - v2;
    else if (op_type == 2)
        res = v1 * v2;
    else if (op_type == 3)
        res = v1 / v2;
    else if (op_type == 4)
        res = max(v1, v2);
    else if (op_type == 5)
        res = min(v1, v2);
    else if (op_type == 6)
        res = pow(v1, v2);
    else if (op_type == 7)
        res = v2 - v1;
    else if (op_type == 8)
        res = v2 / v1;
    else if (op_type == 9)
        res = pow(v2, v1);

#if NCNN_image_shader
    image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st4(top_blob_data, gi, res);
#endif
}

+ 234
- 0
src/layer/vulkan/shader/binaryop_broadcast_inner_pack8.comp View File

@@ -0,0 +1,234 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional half-precision support, enabled by build-time switches.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
// Storage form of one 8-lane element when fp16 storage is on: two f16vec4 halves.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the arithmetic applied in main(); values 0..9 are handled there.
layout (constant_id = 0) const int op_type = 0;

// Shape constants use constant ids 1..18: six ints each for a, b and the output
// (rank, width, height, depth, channel count, channel stride). All default to 0.
// NOTE(review): main() reads them through psc(), defined outside this file;
// presumably a value of 0 means "use the push-constant field instead" - confirm.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

// Resource bindings: input a at binding 0, input b at binding 1, output at binding 2.
// NCNN_image_shader chooses between 3D image resources and storage buffers of
// 8-lane elements.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

// Runtime shape description; field names match the shape constants above one-to-one.
layout (push_constant) uniform parameter
{
// input a
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

// input b
int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

// output
int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Entry point: each invocation produces one 8-lane output element, handled as
// two halves ([0] and [1]). a is read at the output position; b is read at a
// position derived from the broadcast rule selected by the ranks of a and b.
void main()
{
    ivec3 gpos = ivec3(gl_GlobalInvocationID);

    // The dispatch may be larger than the output; drop the surplus invocations.
    // The y extent is outh * outd because depth and height share the y axis.
    bool inside = gpos.x < psc(outw) && gpos.y < psc(outh) * psc(outd) && gpos.z < psc(outc);
    if (!inside)
        return;

    // Split the folded y coordinate back into (depth, height).
    int depth_idx = gpos.y / psc(outh);
    int height_idx = gpos.y % psc(outh);

    // Unless a rule below says otherwise, b is read at the same position as a.
    ivec3 bpos = gpos;

    if (psc(adims) == psc(bdims))
    {
        // explicit broadcast: clamp every coordinate to b's last valid index,
        // so an axis of extent 1 in b always reads index 0
        bpos.x = min(gpos.x, psc(bw) - 1);
        bpos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(height_idx, psc(bh) - 1);
        bpos.z = min(gpos.z, psc(bc) - 1);
    }
    else if (psc(adims) == 4)
    {
        // implicit broadcast, a has rank 4
        if (psc(bdims) == 3)
        {
            // type 13
            bpos = ivec3(height_idx, depth_idx, gpos.z);
        }
        else if (psc(bdims) == 2)
        {
            // type 12
            bpos = ivec3(depth_idx, gpos.z, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 11
            bpos = ivec3(gpos.z, 0, 0);
        }
    }
    else if (psc(adims) == 3)
    {
        // implicit broadcast, a has rank 3
        if (psc(bdims) == 2)
        {
            // type 10
            bpos = ivec3(gpos.y, gpos.z, 0);
        }
        else if (psc(bdims) == 1)
        {
            // type 9
            bpos = ivec3(gpos.z, 0, 0);
        }
    }
    else
    {
        // remaining implicit case (a of rank 2 with b of rank 1): type 8
        bpos = ivec3(gpos.y, 0, 0);
    }

#if NCNN_image_shader
    afpvec8 lhs = image3d_ld8(a_blob_3d, gpos);
    afpvec8 rhs = image3d_ld8(b_blob_3d, bpos);
#else
    // Linear element offsets; a and the output share the same offset.
    int out_offset = gpos.z * psc(outcstep) + gpos.y * psc(outw) + gpos.x;
    int b_offset = bpos.z * psc(bcstep) + bpos.y * psc(bw) + bpos.x;

    afpvec8 lhs = buffer_ld8(a_blob_data, out_offset);
    afpvec8 rhs = buffer_ld8(b_blob_data, b_offset);
#endif

    afpvec8 res;

    // The same operator is applied to both halves.
    // op_type 7, 8 and 9 are the operand-swapped forms of 1, 3 and 6.
    if (op_type == 0)
    {
        res[0] = lhs[0] + rhs[0];
        res[1] = lhs[1] + rhs[1];
    }
    else if (op_type == 1)
    {
        res[0] = lhs[0] - rhs[0];
        res[1] = lhs[1] - rhs[1];
    }
    else if (op_type == 2)
    {
        res[0] = lhs[0] * rhs[0];
        res[1] = lhs[1] * rhs[1];
    }
    else if (op_type == 3)
    {
        res[0] = lhs[0] / rhs[0];
        res[1] = lhs[1] / rhs[1];
    }
    else if (op_type == 4)
    {
        res[0] = max(lhs[0], rhs[0]);
        res[1] = max(lhs[1], rhs[1]);
    }
    else if (op_type == 5)
    {
        res[0] = min(lhs[0], rhs[0]);
        res[1] = min(lhs[1], rhs[1]);
    }
    else if (op_type == 6)
    {
        res[0] = pow(lhs[0], rhs[0]);
        res[1] = pow(lhs[1], rhs[1]);
    }
    else if (op_type == 7)
    {
        res[0] = rhs[0] - lhs[0];
        res[1] = rhs[1] - lhs[1];
    }
    else if (op_type == 8)
    {
        res[0] = rhs[0] / lhs[0];
        res[1] = rhs[1] / lhs[1];
    }
    else if (op_type == 9)
    {
        res[0] = pow(rhs[0], lhs[0]);
        res[1] = pow(rhs[1], lhs[1]);
    }

#if NCNN_image_shader
    image3d_st8(top_blob_3d, gpos, res);
#else
    buffer_st8(top_blob_data, out_offset, res);
#endif
}

src/layer/vulkan/shader/binaryop_broadcast_a1_pack4.comp → src/layer/vulkan/shader/binaryop_broadcast_outer.comp View File

@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
@@ -27,29 +27,32 @@ layout (constant_id = 0) const int op_type = 0;
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
@@ -57,18 +60,21 @@ layout (push_constant) uniform parameter
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;
@@ -79,40 +85,29 @@ void main()
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
return;

#if NCNN_image_shader
afpvec4 v1;

if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
{
// special type 4
v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(gx, gy, 0)));
}
else
{
v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(0, 0, 0)));
}

afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz));
#else
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
int yd = gy / psc(outh);
int yh = gy % psc(outh);

int ai = 0;
// explicit broadcast
int bx = min(gx, psc(bw) - 1);
int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1);
int bz = min(gz, psc(bc) - 1);

if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
{
// special type 4
ai = gy * psc(aw) + gx;
}
#if NCNN_image_shader
afp v1 = image3d_ld1(a_blob_3d, ivec3(gx, gy, gz));
afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz));
#else
int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
int bi = bz * psc(bcstep) + by * psc(bw) + bx;

// type 2 3 4
afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai));
afpvec4 v2 = buffer_ld4(b_blob_data, gi);
afp v1 = buffer_ld1(a_blob_data, gi);
afp v2 = buffer_ld1(b_blob_data, bi);
#endif

afpvec4 res;
afp res;

if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
@@ -123,10 +118,11 @@ void main()
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st4(top_blob_data, gi, res);
buffer_st1(top_blob_data, gi, res);
#endif
}

src/layer/vulkan/shader/binaryop_broadcast_b1_pack4.comp → src/layer/vulkan/shader/binaryop_broadcast_outer_pack4.comp View File

@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
@@ -27,20 +27,23 @@ layout (constant_id = 0) const int op_type = 0;
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
@@ -57,18 +60,21 @@ layout (push_constant) uniform parameter
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;
@@ -79,35 +85,24 @@ void main()
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
return;

#if NCNN_image_shader
afpvec4 v2;

if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
{
// special type 2
v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0)));
}
else
{
v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0)));
}
int yd = gy / psc(outh);
int yh = gy % psc(outh);

// explicit broadcast
int bx = min(gx, psc(bw) - 1);
int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1);
int bz = min(gz, psc(bc) - 1);

#if NCNN_image_shader
afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(gx, gy, gz));
afpvec4 v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(bx, by, bz)));
#else
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

int bi = 0;

if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
{
// special type 2
bi = gy * psc(bw) + gx;
}
int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
int bi = bz * psc(bcstep) + by * psc(bw) + bx;

// type 6 11 16
afpvec4 v1 = buffer_ld4(a_blob_data, gi);
afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi));
#endif
@@ -123,6 +118,7 @@ void main()
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);

src/layer/vulkan/shader/binaryop_broadcast_b1_pack8.comp → src/layer/vulkan/shader/binaryop_broadcast_outer_pack8.comp View File

@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
@@ -28,20 +28,23 @@ layout (constant_id = 0) const int op_type = 0;
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
@@ -58,18 +61,21 @@ layout (push_constant) uniform parameter
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;
@@ -80,85 +86,82 @@ void main()
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
return;

#if NCNN_image_shader
afpvec4 v2;
int yd = gy / psc(outh);
int yh = gy % psc(outh);

if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
{
// special type 2
v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0)));
}
else
{
v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0)));
}
// explicit broadcast
int bx = min(gx, psc(bw) - 1);
int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1);
int bz = min(gz, psc(bc) - 1);

#if NCNN_image_shader
afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(gx, gy, gz));
afp b = image3d_ld1(b_blob_3d, ivec3(bx, by, bz));
#else
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

int bi = 0;
int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
int bi = bz * psc(bcstep) + by * psc(bw) + bx;

if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
{
// special type 2
bi = gy * psc(bw) + gx;
}

// type 6 11 16
afpvec8 v1 = buffer_ld8(a_blob_data, gi);
afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi));
afp b = buffer_ld1(b_blob_data, bi);
#endif
afpvec8 v2;
v2[0] = afpvec4(b);
v2[1] = afpvec4(b);

afpvec8 res;

if (op_type == 0)
{
res[0] = v1[0] + v2;
res[1] = v1[1] + v2;
res[0] = v1[0] + v2[0];
res[1] = v1[1] + v2[1];
}
if (op_type == 1)
{
res[0] = v1[0] - v2;
res[1] = v1[1] - v2;
res[0] = v1[0] - v2[0];
res[1] = v1[1] - v2[1];
}
if (op_type == 2)
{
res[0] = v1[0] * v2;
res[1] = v1[1] * v2;
res[0] = v1[0] * v2[0];
res[1] = v1[1] * v2[1];
}
if (op_type == 3)
{
res[0] = v1[0] / v2;
res[1] = v1[1] / v2;
res[0] = v1[0] / v2[0];
res[1] = v1[1] / v2[1];
}
if (op_type == 4)
{
res[0] = max(v1[0], v2);
res[1] = max(v1[1], v2);
res[0] = max(v1[0], v2[0]);
res[1] = max(v1[1], v2[1]);
}
if (op_type == 5)
{
res[0] = min(v1[0], v2);
res[1] = min(v1[1], v2);
res[0] = min(v1[0], v2[0]);
res[1] = min(v1[1], v2[1]);
}
if (op_type == 6)
{
res[0] = pow(v1[0], v2);
res[1] = pow(v1[1], v2);
res[0] = pow(v1[0], v2[0]);
res[1] = pow(v1[1], v2[1]);
}
if (op_type == 7)
{
res[0] = v2 - v1[0];
res[1] = v2 - v1[1];
res[0] = v2[0] - v1[0];
res[1] = v2[1] - v1[1];
}
if (op_type == 8)
{
res[0] = v2 / v1[0];
res[1] = v2 / v1[1];
res[0] = v2[0] / v1[0];
res[1] = v2[1] / v1[1];
}
if (op_type == 9)
{
res[0] = pow(v2[0], v1[0]);
res[1] = pow(v2[1], v1[1]);
}

#if NCNN_image_shader

+ 0
- 502
src/layer/vulkan/shader/binaryop_broadcast_pack4.comp View File

@@ -1,502 +0,0 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional half-precision support, enabled by build-time switches.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the arithmetic applied in main(); values 0..8 are handled there.
layout (constant_id = 0) const int op_type = 0;

// Shape constants use constant ids 1..18: six ints each for a, b and the output
// (rank, width, height, depth, channel count, channel stride). All default to 0.
// NOTE(review): main() reads them through psc(), defined outside this file;
// presumably a value of 0 means "use the push-constant field instead" - confirm.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

// Resource bindings: input a at binding 0, input b at binding 1, output at binding 2.
// NCNN_image_shader chooses between 3D image resources and storage buffers of
// 4-lane elements.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
#endif

// Runtime shape description; field names match the shape constants above one-to-one.
layout (push_constant) uniform parameter
{
// input a
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

// input b
int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

// output
int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Entry point: each invocation produces one 4-lane output element at (gx, gy, gz).
// The source element of a and of b is chosen from the ranks (adims, bdims) and
// extents of the two inputs, then op_type selects the arithmetic.
// The image path and the buffer path implement the selection separately and do
// not cover exactly the same set of cases (see the notes in each path).
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);

// Drop invocations outside the output; the y extent is outh * outd because
// depth and height share the y axis.
if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
return;

#if NCNN_image_shader
// Image path: both inputs default to the output coordinate; the branches below
// override the coordinates of whichever side is being broadcast.
int ax = gx;
int ay = gy;
int az = gz;
int bx = gx;
int by = gy;
int bz = gz;

if (psc(adims) == 4)
{
// split the folded y coordinate into (depth, height)
int yd = gy / psc(outh);
int yh = gy % psc(outh);

if (psc(bdims) == 3)
{
// type 28
bx = yh;
by = yd;
bz = gz;
}

if (psc(bdims) == 2)
{
// type 27
bx = yd;
by = gz;
bz = 0;
}

if (psc(bdims) == 1)
{
if (psc(bw) == 1)
{
// type 25
bx = 0;
by = 0;
bz = 0;
}
else
{
// type 26
bx = gz;
by = 0;
bz = 0;
}
}
}
else if (psc(adims) == 3)
{
if (psc(bdims) == 4)
{
int yd = gy / psc(outh);
int yh = gy % psc(outh);

// type 23
ax = yh;
ay = yd;
az = gz;
}

if (psc(bdims) == 3)
{
// Same-rank special cases. These are independent ifs, not an else-if chain,
// so more than one may apply to the same shape pair.
if (psc(bw) == 1 && psc(bh) == 1)
{
// special type 1
bx = 0;
by = 0;
}

if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
{
// special type 2
bz = 0;
}

if (psc(aw) == 1 && psc(ah) == 1)
{
// special type 3
ax = 0;
ay = 0;
}

if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
{
// special type 4
az = 0;
}

if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
{
// special type 5
bx = 0;
}

if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
{
// special type 6
by = 0;
}

if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
{
// special type 7
ax = 0;
}

if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
{
// special type 8
ay = 0;
}
}

if (psc(bdims) == 2)
{
// type 18
bx = gy;
by = gz;
bz = 0;
}

if (psc(bdims) == 1)
{
if (psc(bw) == 1)
{
// type 16
bx = 0;
by = 0;
bz = 0;
}
else
{
// type 17
bx = gz;
by = 0;
bz = 0;
}
}
}
else if (psc(adims) == 2)
{
if (psc(bdims) == 4)
{
// yh is computed but not used in this branch
int yd = gy / psc(outh);
int yh = gy % psc(outh);

// type 22
ax = yd;
ay = gz;
az = 0;
}

if (psc(bdims) == 3)
{
// type 14
ax = gy;
ay = gz;
az = 0;
}

if (psc(bdims) == 1)
{
if (psc(bw) == 1)
{
// type 11
bx = 0;
by = 0;
bz = 0;
}
else
{
// type 12
bx = gy;
by = 0;
bz = 0;
}
}
}
else if (psc(adims) == 1)
{
if (psc(aw) == 1)
{
// type 2 3 4 20
ax = 0;
ay = 0;
az = 0;
}
else
{
if (psc(bdims) == 4)
{
// type 21
ax = gz;
ay = 0;
az = 0;
}

if (psc(bdims) == 3)
{
// type 9
ax = gz;
ay = 0;
az = 0;
}

if (psc(bdims) == 2)
{
// type 8
ax = gy;
ay = 0;
az = 0;
}

if (psc(bdims) == 1)
{
if (psc(bw) == 1)
{
// type 6
bx = 0;
by = 0;
bz = 0;
}
}
}
}

afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(ax, ay, az));
afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(bx, by, bz));
#else
// Buffer path: gi is the linear offset of the output element; ai and bi are the
// linear offsets read from a and b.
const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

// ai and bi have no initial value and are assigned only in the cases listed
// below. Several cases handled by the image path (special types 2 and 4, and
// the bw == 1 / aw == 1 scalar cases) have no counterpart here.
// NOTE(review): presumably those shapes are routed to other shaders by the
// host code - confirm against the caller.
int ai;
int bi;

if (psc(adims) == 4)
{
int yd = gy / psc(outh);
int yh = gy % psc(outh);

if (psc(bdims) == 3)
{
// type 28
ai = gi;
bi = gz * psc(bcstep) + yd * psc(bw) + yh;
}

if (psc(bdims) == 2)
{
// type 27
ai = gi;
bi = gz * psc(bw) + yd;
}

if (psc(bdims) == 1)
{
if (psc(bw) == 1)
{
// type 25
ai = gi;
bi = 0;
}
else
{
// type 26
ai = gi;
bi = gz;
}
}
}
else if (psc(adims) == 3)
{
if (psc(bdims) == 4)
{
int yd = gy / psc(outh);
int yh = gy % psc(outh);

// type 23
ai = gz * psc(acstep) + yd * psc(aw) + yh;
bi = gi;
}

if (psc(bdims) == 3)
{
// Independent ifs: when several conditions hold, the last matching one
// determines ai and bi.
if (psc(bw) == 1 && psc(bh) == 1)
{
// special type 1
ai = gi;
bi = gz * psc(bcstep);
}

if (psc(aw) == 1 && psc(ah) == 1)
{
// special type 3
ai = gz * psc(acstep);
bi = gi;
}

if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
{
// special type 5
ai = gi;
bi = gz * psc(bcstep) + gy;
}

if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
{
// special type 6
ai = gi;
bi = gz * psc(bcstep) + gx;
}

if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
{
// special type 7
ai = gz * psc(acstep) + gy;
bi = gi;
}

if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
{
// special type 8
ai = gz * psc(acstep) + gx;
bi = gi;
}
}

if (psc(bdims) == 2)
{
// type 18
ai = gi;
bi = gz * psc(bw) + gy;
}

if (psc(bdims) == 1)
{
// type 17
ai = gi;
bi = gz;
}
}
else if (psc(adims) == 2)
{
if (psc(bdims) == 4)
{
// yh is computed but not used in this branch
int yd = gy / psc(outh);
int yh = gy % psc(outh);

// type 22
ai = gz * psc(aw) + yd;
bi = gi;
}

if (psc(bdims) == 3)
{
// type 14
ai = gz * psc(aw) + gy;
bi = gi;
}

if (psc(bdims) == 1)
{
// type 12
ai = gi;
bi = gy;
}
}
else if (psc(adims) == 1)
{
if (psc(bdims) == 4)
{
// type 21
ai = gz;
bi = gi;
}

if (psc(bdims) == 3)
{
// type 9
ai = gz;
bi = gi;
}

if (psc(bdims) == 2)
{
// type 8
ai = gy;
bi = gi;
}
}

afpvec4 v1 = buffer_ld4(a_blob_data, ai);
afpvec4 v2 = buffer_ld4(b_blob_data, bi);
#endif

// res is written only for op_type 0..8; op_type 7 and 8 are the
// operand-swapped forms of 1 and 3.
afpvec4 res;

if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
if (op_type == 2) res = v1 * v2;
if (op_type == 3) res = v1 / v2;
if (op_type == 4) res = max(v1, v2);
if (op_type == 5) res = min(v1, v2);
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st4(top_blob_data, gi, res);
#endif
}

+ 0
- 539
src/layer/vulkan/shader/binaryop_broadcast_pack8.comp View File

@@ -1,539 +0,0 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Optional half-precision support, enabled by build-time switches.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
// Storage form of one 8-lane element when fp16 storage is on: two f16vec4 halves.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the arithmetic applied in main().
layout (constant_id = 0) const int op_type = 0;

// Shape constants use constant ids 1..18: six ints each for a, b and the output
// (rank, width, height, depth, channel count, channel stride). All default to 0.
// NOTE(review): main() reads them through psc(), defined outside this file;
// presumably a value of 0 means "use the push-constant field instead" - confirm.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

// Resource bindings: input a at binding 0, input b at binding 1, output at binding 2.
// NCNN_image_shader chooses between 3D image resources and storage buffers of
// 8-lane elements.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

// Runtime shape description; field names match the shape constants above one-to-one.
layout (push_constant) uniform parameter
{
// input a
int adims;
int aw;
int ah;
int ad;
int ac;
int acstep;

// input b
int bdims;
int bw;
int bh;
int bd;
int bc;
int bcstep;

// output
int outdims;
int outw;
int outh;
int outd;
int outc;
int outcstep;
} p;

// Broadcasting binary op on 8-lane packed data: top = op(a, b).
//
// Each invocation produces one packed output element at (gx, gy, gz):
//   gx in [0, outw), gy in [0, outh * outd), gz in [0, outc).
// For 4-D blobs the depth axis is folded into y, so gy is split into
// yd = gy / outh and yh = gy % outh where a lower-rank operand needs it.
//
// The function first maps the output position to a read position in a and in
// b according to the (adims, bdims) pair and the operand extents, then applies
// the operation selected by op_type:
//   0 add, 1 sub, 2 mul, 3 div, 4 max, 5 min, 6 pow,
//   7 reversed sub (b - a), 8 reversed div (b / a).
//
// The "type N" / "special type N" tags are kept from the original source.
// NOTE(review): they presumably refer to the rows of ncnn's binaryop
// broadcasting table - confirm against that document.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Discard invocations outside the output extent (grid is rounded up).
    if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc))
        return;

#if NCNN_image_shader
    // Default: both operands are read at the output coordinate. The cases
    // below override only the coordinates of the operand being broadcast.
    int ax = gx;
    int ay = gy;
    int az = gz;
    int bx = gx;
    int by = gy;
    int bz = gz;

    if (psc(adims) == 4)
    {
        int yd = gy / psc(outh);
        int yh = gy % psc(outh);

        if (psc(bdims) == 3)
        {
            // type 28
            bx = yh;
            by = yd;
            bz = gz;
        }

        if (psc(bdims) == 2)
        {
            // type 27
            bx = yd;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 25
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 26
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 3)
    {
        if (psc(bdims) == 4)
        {
            int yd = gy / psc(outh);
            int yh = gy % psc(outh);

            // type 23
            ax = yh;
            ay = yd;
            az = gz;
        }

        if (psc(bdims) == 3)
        {
            // Same-rank broadcasts: the axis whose extent is 1 is pinned to 0.
            // The conditions are tested independently, not as an else-chain.
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                bx = 0;
                by = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2
                bz = 0;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ax = 0;
                ay = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4
                az = 0;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                bx = 0;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                by = 0;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ax = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ay = 0;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18
            bx = gy;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 17
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 4)
        {
            // Only the depth index is needed here; the in-plane row index
            // is not used by this case.
            int yd = gy / psc(outh);

            // type 22
            ax = yd;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 3)
        {
            // type 14
            ax = gy;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 12
                bx = gy;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4 20
            ax = 0;
            ay = 0;
            az = 0;
        }
        else
        {
            if (psc(bdims) == 4)
            {
                // type 21
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 3)
            {
                // type 9
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 2)
            {
                // type 8
                ax = gy;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6
                    bx = 0;
                    by = 0;
                    bz = 0;
                }
            }
        }
    }

    afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(ax, ay, az));
    afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(bx, by, bz));
#else
    // Linear index of the output element.
    const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

    // Default both read indices to the output index, mirroring the identity
    // default of the image path. Without this, any (adims, bdims) pair not
    // matched below would index the buffers with undefined values.
    int ai = gi;
    int bi = gi;

    if (psc(adims) == 4)
    {
        int yd = gy / psc(outh);
        int yh = gy % psc(outh);

        if (psc(bdims) == 3)
        {
            // type 28
            ai = gi;
            bi = gz * psc(bcstep) + yd * psc(bw) + yh;
        }

        if (psc(bdims) == 2)
        {
            // type 27
            ai = gi;
            bi = gz * psc(bw) + yd;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 25
                ai = gi;
                bi = 0;
            }
            else
            {
                // type 26
                ai = gi;
                bi = gz;
            }
        }
    }
    else if (psc(adims) == 3)
    {
        if (psc(bdims) == 4)
        {
            int yd = gy / psc(outh);
            int yh = gy % psc(outh);

            // type 23
            ai = gz * psc(acstep) + yd * psc(aw) + yh;
            bi = gi;
        }

        if (psc(bdims) == 3)
        {
            // Same-rank broadcasts; conditions are tested independently.
            // NOTE(review): unlike the image path there is no special type
            // 2 / 4 case here - presumably those shapes never reach the
            // buffer variant of this shader; confirm against the host code.
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                ai = gi;
                bi = gz * psc(bcstep);
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ai = gz * psc(acstep);
                bi = gi;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                bi = gz * psc(bcstep) + gy;
                ai = gi;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                bi = gz * psc(bcstep) + gx;
                ai = gi;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ai = gz * psc(acstep) + gy;
                bi = gi;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ai = gz * psc(acstep) + gx;
                bi = gi;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18
            ai = gi;
            bi = gz * psc(bw) + gy;
        }

        if (psc(bdims) == 1)
        {
            // type 17
            ai = gi;
            bi = gz;
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 4)
        {
            // Only the depth index is needed for this case.
            int yd = gy / psc(outh);

            // type 22
            ai = gz * psc(aw) + yd;
            bi = gi;
        }

        if (psc(bdims) == 3)
        {
            // type 14
            ai = gz * psc(aw) + gy;
            bi = gi;
        }

        if (psc(bdims) == 1)
        {
            // type 12
            ai = gi;
            bi = gy;
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(bdims) == 4)
        {
            // type 21
            ai = gz;
            bi = gi;
        }

        if (psc(bdims) == 3)
        {
            // type 9
            ai = gz;
            bi = gi;
        }

        if (psc(bdims) == 2)
        {
            // type 8
            ai = gy;
            bi = gi;
        }
    }

    afpvec8 v1 = buffer_ld8(a_blob_data, ai);
    afpvec8 v2 = buffer_ld8(b_blob_data, bi);
#endif

    // The packed value is processed as two halves, [0] and [1].
    // op_type is a specialization constant, so only one branch below applies
    // for a given pipeline.
    afpvec8 res;

    if (op_type == 0)
    {
        // add
        res[0] = v1[0] + v2[0];
        res[1] = v1[1] + v2[1];
    }
    if (op_type == 1)
    {
        // sub
        res[0] = v1[0] - v2[0];
        res[1] = v1[1] - v2[1];
    }
    if (op_type == 2)
    {
        // mul
        res[0] = v1[0] * v2[0];
        res[1] = v1[1] * v2[1];
    }
    if (op_type == 3)
    {
        // div
        res[0] = v1[0] / v2[0];
        res[1] = v1[1] / v2[1];
    }
    if (op_type == 4)
    {
        // max
        res[0] = max(v1[0], v2[0]);
        res[1] = max(v1[1], v2[1]);
    }
    if (op_type == 5)
    {
        // min
        res[0] = min(v1[0], v2[0]);
        res[1] = min(v1[1], v2[1]);
    }
    if (op_type == 6)
    {
        // pow
        res[0] = pow(v1[0], v2[0]);
        res[1] = pow(v1[1], v2[1]);
    }
    if (op_type == 7)
    {
        // reversed sub: b - a
        res[0] = v2[0] - v1[0];
        res[1] = v2[1] - v1[1];
    }
    if (op_type == 8)
    {
        // reversed div: b / a
        res[0] = v2[0] / v1[0];
        res[1] = v2[1] / v1[1];
    }

#if NCNN_image_shader
    image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st8(top_blob_data, gi, res);
#endif
}

+ 21
- 34
src/layer/vulkan/shader/binaryop_pack4.comp View File

@@ -92,52 +92,39 @@ void main()
afpvec4 v1 = buffer_ld4(a_blob_data, gi);
#endif

afpvec4 res;
afpvec4 v2;

if (with_scalar == 1)
{
// type 5 10 15
afp b = afp(const_b);

if (op_type == 0) res = v1 + b;
if (op_type == 1) res = v1 - b;
if (op_type == 2) res = v1 * b;
if (op_type == 3) res = v1 / b;
if (op_type == 4) res = max(v1, b);
if (op_type == 5) res = min(v1, b);
if (op_type == 6) res = pow(v1, afpvec4(b));
if (op_type == 7) res = b - v1;
if (op_type == 8) res = b / v1;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st4(a_blob_data, gi, res);
#endif
// type 0 1 2 3
v2 = afpvec4(const_b);
}
else
{
// type 7 13 19
// type 4 5 6 7
#if NCNN_image_shader
afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz));
v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz));
#else
afpvec4 v2 = buffer_ld4(b_blob_data, gi);
v2 = buffer_ld4(b_blob_data, gi);
#endif
}

afpvec4 res;

if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
if (op_type == 2) res = v1 * v2;
if (op_type == 3) res = v1 / v2;
if (op_type == 4) res = max(v1, v2);
if (op_type == 5) res = min(v1, v2);
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 0) res = v1 + v2;
if (op_type == 1) res = v1 - v2;
if (op_type == 2) res = v1 * v2;
if (op_type == 3) res = v1 / v2;
if (op_type == 4) res = max(v1, v2);
if (op_type == 5) res = min(v1, v2);
if (op_type == 6) res = pow(v1, v2);
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st4(top_blob_data, gi, res);
buffer_st4(top_blob_data, gi, res);
#endif
}
}

+ 62
- 106
src/layer/vulkan/shader/binaryop_pack8.comp View File

@@ -93,124 +93,80 @@ void main()
afpvec8 v1 = buffer_ld8(a_blob_data, gi);
#endif

afpvec8 res;
afpvec8 v2;

if (with_scalar == 1)
{
// type 5 10 15
afp b = afp(const_b);

if (op_type == 0)
{
res[0] = v1[0] + b;
res[1] = v1[1] + b;
}
if (op_type == 1)
{
res[0] = v1[0] - b;
res[1] = v1[1] - b;
}
if (op_type == 2)
{
res[0] = v1[0] * b;
res[1] = v1[1] * b;
}
if (op_type == 3)
{
res[0] = v1[0] / b;
res[1] = v1[1] / b;
}
if (op_type == 4)
{
res[0] = max(v1[0], b);
res[1] = max(v1[1], b);
}
if (op_type == 5)
{
res[0] = min(v1[0], b);
res[1] = min(v1[1], b);
}
if (op_type == 6)
{
res[0] = pow(v1[0], afpvec4(b));
res[1] = pow(v1[1], afpvec4(b));
}
if (op_type == 7)
{
res[0] = b - v1[0];
res[1] = b - v1[1];
}
if (op_type == 8)
{
res[0] = b / v1[0];
res[1] = b / v1[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st8(a_blob_data, gi, res);
#endif
// type 0 1 2 3
v2[0] = afpvec4(const_b);
v2[1] = afpvec4(const_b);
}
else
{
// type 7 13 19
// type 4 5 6 7
#if NCNN_image_shader
afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz));
v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz));
#else
afpvec8 v2 = buffer_ld8(b_blob_data, gi);
v2 = buffer_ld8(b_blob_data, gi);
#endif
}

if (op_type == 0)
{
res[0] = v1[0] + v2[0];
res[1] = v1[1] + v2[1];
}
if (op_type == 1)
{
res[0] = v1[0] - v2[0];
res[1] = v1[1] - v2[1];
}
if (op_type == 2)
{
res[0] = v1[0] * v2[0];
res[1] = v1[1] * v2[1];
}
if (op_type == 3)
{
res[0] = v1[0] / v2[0];
res[1] = v1[1] / v2[1];
}
if (op_type == 4)
{
res[0] = max(v1[0], v2[0]);
res[1] = max(v1[1], v2[1]);
}
if (op_type == 5)
{
res[0] = min(v1[0], v2[0]);
res[1] = min(v1[1], v2[1]);
}
if (op_type == 6)
{
res[0] = pow(v1[0], v2[0]);
res[1] = pow(v1[1], v2[1]);
}
if (op_type == 7)
{
res[0] = v2[0] - v1[0];
res[1] = v2[1] - v1[1];
}
if (op_type == 8)
{
res[0] = v2[0] / v1[0];
res[1] = v2[1] / v1[1];
}
afpvec8 res;

if (op_type == 0)
{
res[0] = v1[0] + v2[0];
res[1] = v1[1] + v2[1];
}
if (op_type == 1)
{
res[0] = v1[0] - v2[0];
res[1] = v1[1] - v2[1];
}
if (op_type == 2)
{
res[0] = v1[0] * v2[0];
res[1] = v1[1] * v2[1];
}
if (op_type == 3)
{
res[0] = v1[0] / v2[0];
res[1] = v1[1] / v2[1];
}
if (op_type == 4)
{
res[0] = max(v1[0], v2[0]);
res[1] = max(v1[1], v2[1]);
}
if (op_type == 5)
{
res[0] = min(v1[0], v2[0]);
res[1] = min(v1[1], v2[1]);
}
if (op_type == 6)
{
res[0] = pow(v1[0], v2[0]);
res[1] = pow(v1[1], v2[1]);
}
if (op_type == 7)
{
res[0] = v2[0] - v1[0];
res[1] = v2[1] - v1[1];
}
if (op_type == 8)
{
res[0] = v2[0] / v1[0];
res[1] = v2[1] / v1[1];
}
if (op_type == 9)
{
res[0] = pow(v2[0], v1[0]);
res[1] = pow(v2[1], v1[1]);
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
#else
buffer_st8(top_blob_data, gi, res);
buffer_st8(top_blob_data, gi, res);
#endif
}
}

+ 633
- 2178
src/layer/x86/binaryop_x86.cpp
File diff suppressed because it is too large
View File


+ 237
- 281
tests/test_binaryop.cpp View File

@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 9
#define OP_TYPE_MAX 10

static int op_type = 0;

@@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
{
ncnn::Mat a = _a;
ncnn::Mat b = _b;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
// value must be positive for pow/rpow
a = a.clone();
b = b.clone();
Randomize(a, 0.001f, 2.f);
Randomize(b, 0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for pow
// value must be positive for div/rdiv
a = a.clone();
b = b.clone();
Randomize(a, 0.1f, 10.f);
Randomize(b, 0.1f, 10.f);
}
@@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
static int test_binaryop(const ncnn::Mat& _a, float b)
{
ncnn::Mat a = _a;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
Randomize(a, 0.001f, 2.f);
b = RandomFloat(0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for div/rdiv
a = a.clone();
Randomize(a, 0.1f, 10.f);
}

ncnn::ParamDict pd;
pd.set(0, op_type);
@@ -82,300 +92,274 @@ static int test_binaryop(const ncnn::Mat& _a, float b)
return ret;
}

// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting

static int test_binaryop_1()
{
return 0
|| test_binaryop(RandomMat(1), 1.f);
}

static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(1))
|| test_binaryop(RandomMat(1), RandomMat(4))
|| test_binaryop(RandomMat(1), RandomMat(16));
}

static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3))
|| test_binaryop(RandomMat(1), RandomMat(11, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 16));
}

static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 16));
}

static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(2), 1.f)
|| test_binaryop(RandomMat(4), 1.f)
|| test_binaryop(RandomMat(16), 1.f);
}

static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(1))
|| test_binaryop(RandomMat(4), RandomMat(1))
|| test_binaryop(RandomMat(16), RandomMat(1));
}

static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(2))
|| test_binaryop(RandomMat(4), RandomMat(4))
|| test_binaryop(RandomMat(16), RandomMat(16));
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

ncnn::Mat b[] = {
RandomMat(1),
RandomMat(1, 1),
RandomMat(1, 1, 1),
RandomMat(1, 1, 1, 1)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++)
{
int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]);
if (ret != 0)
return ret;
}

int ret = test_binaryop(a[i], 0.2f);
if (ret != 0)
return ret;
}

static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(3), RandomMat(11, 3))
|| test_binaryop(RandomMat(4), RandomMat(11, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 16));
return 0;
}

static int test_binaryop_9()
static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 6, 16));
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b;
b.create_like(a[i]);
Randomize(b);

static int test_binaryop_10()
{
return 0
|| test_binaryop(RandomMat(11, 3), 1.f)
|| test_binaryop(RandomMat(11, 4), 1.f)
|| test_binaryop(RandomMat(11, 16), 1.f);
}
int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_11()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(1))
|| test_binaryop(RandomMat(11, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 16), RandomMat(1));
return 0;
}

static int test_binaryop_12()
static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(3))
|| test_binaryop(RandomMat(11, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 16), RandomMat(16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_13()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(11, 3))
|| test_binaryop(RandomMat(11, 4), RandomMat(11, 4))
|| test_binaryop(RandomMat(11, 16), RandomMat(11, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].h);
ncnn::Mat b1(1, a[i].h);
Randomize(b0);
Randomize(b1);

static int test_binaryop_14()
{
return 0
|| test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_15()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), 1.f)
|| test_binaryop(RandomMat(11, 6, 4), 1.f)
|| test_binaryop(RandomMat(11, 6, 16), 1.f);
return 0;
}

static int test_binaryop_16()
static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_17()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(16));
}

static int test_binaryop_18()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, a[i].c);
ncnn::Mat b2(a[i].h, a[i].c);
ncnn::Mat b3(1, a[i].h, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_19()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_20()
static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_21()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, 1, a[i].c);
ncnn::Mat b2(a[i].d, a[i].c);
ncnn::Mat b3(1, 1, a[i].d, a[i].c);
ncnn::Mat b4(a[i].h, a[i].d, a[i].c);
ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);
Randomize(b4);
Randomize(b5);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i])
|| test_binaryop(a[i], b4) || test_binaryop(b4, a[i])
|| test_binaryop(a[i], b5) || test_binaryop(b5, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_22()
{
return 0
|| test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16));
return 0;
}

static int test_binaryop_23()
static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_24()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 4), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 16), 1.f);
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1);
Randomize(b0);

static int test_binaryop_25()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_26()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16));
return 0;
}

static int test_binaryop_27()
static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_28()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1);
Randomize(b0);
Randomize(b1);

static int test_binaryop_29()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s1()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16));
return 0;
}

static int test_binaryop_s2()
static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_s3()
{
return 0
|| test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1, 1);
ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1);
Randomize(b0);
Randomize(b1);
Randomize(b2);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_s5()
static int test_binaryop_9()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_s6()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, a[i].c);
Randomize(b0);

static int test_binaryop_s7()
{
return 0
|| test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s8()
{
return 0
|| test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16));
return 0;
}

int main()
@@ -393,35 +377,7 @@ int main()
|| test_binaryop_6()
|| test_binaryop_7()
|| test_binaryop_8()
|| test_binaryop_9()
|| test_binaryop_10()
|| test_binaryop_11()
|| test_binaryop_12()
|| test_binaryop_13()
|| test_binaryop_14()
|| test_binaryop_15()
|| test_binaryop_16()
|| test_binaryop_17()
|| test_binaryop_18()
|| test_binaryop_19()
|| test_binaryop_20()
|| test_binaryop_21()
|| test_binaryop_22()
|| test_binaryop_23()
|| test_binaryop_24()
|| test_binaryop_25()
|| test_binaryop_26()
|| test_binaryop_27()
|| test_binaryop_28()
|| test_binaryop_29()
|| test_binaryop_s1()
|| test_binaryop_s2()
|| test_binaryop_s3()
|| test_binaryop_s4()
|| test_binaryop_s5()
|| test_binaryop_s6()
|| test_binaryop_s7()
|| test_binaryop_s8();
|| test_binaryop_9();

if (ret != 0)
return ret;


+ 237
- 279
tests/test_binaryop_1.cpp View File

@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 9
#define OP_TYPE_MAX 10

static int op_type = 0;

@@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
{
ncnn::Mat a = _a;
ncnn::Mat b = _b;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
// value must be positive for pow/rpow
a = a.clone();
b = b.clone();
Randomize(a, 0.001f, 2.f);
Randomize(b, 0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for pow
// value must be positive for div/rdiv
a = a.clone();
b = b.clone();
Randomize(a, 0.1f, 10.f);
Randomize(b, 0.1f, 10.f);
}
@@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
static int test_binaryop(const ncnn::Mat& _a, float b)
{
ncnn::Mat a = _a;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
Randomize(a, 0.001f, 2.f);
b = RandomFloat(0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for div/rdiv
a = a.clone();
Randomize(a, 0.1f, 10.f);
}

ncnn::ParamDict pd;
pd.set(0, op_type);
@@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b)

static int test_binaryop_1()
{
return 0
|| test_binaryop(RandomMat(1), 1.f);
}

static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(1))
|| test_binaryop(RandomMat(1), RandomMat(4))
|| test_binaryop(RandomMat(1), RandomMat(16));
}

static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3))
|| test_binaryop(RandomMat(1), RandomMat(11, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 16));
}

static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 16));
}

static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(2), 1.f)
|| test_binaryop(RandomMat(4), 1.f)
|| test_binaryop(RandomMat(16), 1.f);
}

static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(1))
|| test_binaryop(RandomMat(4), RandomMat(1))
|| test_binaryop(RandomMat(16), RandomMat(1));
}

static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(2))
|| test_binaryop(RandomMat(4), RandomMat(4))
|| test_binaryop(RandomMat(16), RandomMat(16));
}

static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(3), RandomMat(11, 3))
|| test_binaryop(RandomMat(4), RandomMat(11, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 16));
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

ncnn::Mat b[] = {
RandomMat(1),
RandomMat(1, 1),
RandomMat(1, 1, 1),
RandomMat(1, 1, 1, 1)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++)
{
int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]);
if (ret != 0)
return ret;
}

int ret = test_binaryop(a[i], 0.2f);
if (ret != 0)
return ret;
}

static int test_binaryop_9()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_10()
static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(11, 3), 1.f)
|| test_binaryop(RandomMat(11, 4), 1.f)
|| test_binaryop(RandomMat(11, 16), 1.f);
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b;
b.create_like(a[i]);
Randomize(b);

static int test_binaryop_11()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(1))
|| test_binaryop(RandomMat(11, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 16), RandomMat(1));
}
int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_12()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(3))
|| test_binaryop(RandomMat(11, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 16), RandomMat(16));
return 0;
}

static int test_binaryop_13()
static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(11, 3))
|| test_binaryop(RandomMat(11, 4), RandomMat(11, 4))
|| test_binaryop(RandomMat(11, 16), RandomMat(11, 16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_14()
{
return 0
|| test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].h);
ncnn::Mat b1(1, a[i].h);
Randomize(b0);
Randomize(b1);

static int test_binaryop_15()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), 1.f)
|| test_binaryop(RandomMat(11, 6, 4), 1.f)
|| test_binaryop(RandomMat(11, 6, 16), 1.f);
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_16()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1));
return 0;
}

static int test_binaryop_17()
static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_18()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, a[i].c);
ncnn::Mat b2(a[i].h, a[i].c);
ncnn::Mat b3(1, a[i].h, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_19()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_20()
static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_21()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, 1, a[i].c);
ncnn::Mat b2(a[i].d, a[i].c);
ncnn::Mat b3(1, 1, a[i].d, a[i].c);
ncnn::Mat b4(a[i].h, a[i].d, a[i].c);
ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);
Randomize(b4);
Randomize(b5);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i])
|| test_binaryop(a[i], b4) || test_binaryop(b4, a[i])
|| test_binaryop(a[i], b5) || test_binaryop(b5, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_22()
{
return 0
|| test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16));
return 0;
}

static int test_binaryop_23()
static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_24()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 4), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 16), 1.f);
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1);
Randomize(b0);

static int test_binaryop_25()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_26()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16));
return 0;
}

static int test_binaryop_27()
static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_28()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1);
Randomize(b0);
Randomize(b1);

static int test_binaryop_29()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s1()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16));
return 0;
}

static int test_binaryop_s2()
static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_s3()
{
return 0
|| test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1, 1);
ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1);
Randomize(b0);
Randomize(b1);
Randomize(b2);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_s5()
static int test_binaryop_9()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_s6()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, a[i].c);
Randomize(b0);

static int test_binaryop_s7()
{
return 0
|| test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s8()
{
return 0
|| test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16));
return 0;
}

int main()
@@ -393,35 +379,7 @@ int main()
|| test_binaryop_6()
|| test_binaryop_7()
|| test_binaryop_8()
|| test_binaryop_9()
|| test_binaryop_10()
|| test_binaryop_11()
|| test_binaryop_12()
|| test_binaryop_13()
|| test_binaryop_14()
|| test_binaryop_15()
|| test_binaryop_16()
|| test_binaryop_17()
|| test_binaryop_18()
|| test_binaryop_19()
|| test_binaryop_20()
|| test_binaryop_21()
|| test_binaryop_22()
|| test_binaryop_23()
|| test_binaryop_24()
|| test_binaryop_25()
|| test_binaryop_26()
|| test_binaryop_27()
|| test_binaryop_28()
|| test_binaryop_29()
|| test_binaryop_s1()
|| test_binaryop_s2()
|| test_binaryop_s3()
|| test_binaryop_s4()
|| test_binaryop_s5()
|| test_binaryop_s6()
|| test_binaryop_s7()
|| test_binaryop_s8();
|| test_binaryop_9();

if (ret != 0)
return ret;


+ 238
- 280
tests/test_binaryop_2.cpp View File

@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 9
#define OP_TYPE_MAX 10

static int op_type = 0;

@@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
{
ncnn::Mat a = _a;
ncnn::Mat b = _b;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
// value must be positive for pow/rpow
a = a.clone();
b = b.clone();
Randomize(a, 0.001f, 2.f);
Randomize(b, 0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for pow
// value must be positive for div/rdiv
a = a.clone();
b = b.clone();
Randomize(a, 0.1f, 10.f);
Randomize(b, 0.1f, 10.f);
}
@@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
static int test_binaryop(const ncnn::Mat& _a, float b)
{
ncnn::Mat a = _a;
if (op_type == 6)
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow
// value must be positive for pow/rpow
Randomize(a, 0.001f, 2.f);
b = RandomFloat(0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for div/rdiv
a = a.clone();
Randomize(a, 0.1f, 10.f);
}

ncnn::ParamDict pd;
pd.set(0, op_type);
@@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b)

static int test_binaryop_1()
{
return 0
|| test_binaryop(RandomMat(1), 1.f);
}

static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(1))
|| test_binaryop(RandomMat(1), RandomMat(4))
|| test_binaryop(RandomMat(1), RandomMat(16));
}

static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3))
|| test_binaryop(RandomMat(1), RandomMat(11, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 16));
}

static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 6, 16));
}

static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(2), 1.f)
|| test_binaryop(RandomMat(4), 1.f)
|| test_binaryop(RandomMat(16), 1.f);
}

static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(1))
|| test_binaryop(RandomMat(4), RandomMat(1))
|| test_binaryop(RandomMat(16), RandomMat(1));
}

static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(2))
|| test_binaryop(RandomMat(4), RandomMat(4))
|| test_binaryop(RandomMat(16), RandomMat(16));
}

static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(3), RandomMat(11, 3))
|| test_binaryop(RandomMat(4), RandomMat(11, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 16));
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

ncnn::Mat b[] = {
RandomMat(1),
RandomMat(1, 1),
RandomMat(1, 1, 1),
RandomMat(1, 1, 1, 1)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++)
{
int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]);
if (ret != 0)
return ret;
}

int ret = test_binaryop(a[i], 0.2f);
if (ret != 0)
return ret;
}

static int test_binaryop_9()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_10()
static int test_binaryop_2()
{
return 0
|| test_binaryop(RandomMat(11, 3), 1.f)
|| test_binaryop(RandomMat(11, 4), 1.f)
|| test_binaryop(RandomMat(11, 16), 1.f);
}
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b;
b.create_like(a[i]);
Randomize(b);

static int test_binaryop_11()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(1))
|| test_binaryop(RandomMat(11, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 16), RandomMat(1));
}
int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_12()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(3))
|| test_binaryop(RandomMat(11, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 16), RandomMat(16));
return 0;
}

static int test_binaryop_13()
static int test_binaryop_3()
{
return 0
|| test_binaryop(RandomMat(11, 3), RandomMat(11, 3))
|| test_binaryop(RandomMat(11, 4), RandomMat(11, 4))
|| test_binaryop(RandomMat(11, 16), RandomMat(11, 16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_14()
{
return 0
|| test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].h);
ncnn::Mat b1(1, a[i].h);
Randomize(b0);
Randomize(b1);

static int test_binaryop_15()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), 1.f)
|| test_binaryop(RandomMat(11, 6, 4), 1.f)
|| test_binaryop(RandomMat(11, 6, 16), 1.f);
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_16()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1));
return 0;
}

static int test_binaryop_17()
static int test_binaryop_4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_18()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, a[i].c);
ncnn::Mat b2(a[i].h, a[i].c);
ncnn::Mat b3(1, a[i].h, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_19()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_20()
static int test_binaryop_5()
{
return 0
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_21()
{
return 0
|| test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, 1, a[i].c);
ncnn::Mat b2(a[i].d, a[i].c);
ncnn::Mat b3(1, 1, a[i].d, a[i].c);
ncnn::Mat b4(a[i].h, a[i].d, a[i].c);
ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);
Randomize(b4);
Randomize(b5);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i])
|| test_binaryop(a[i], b4) || test_binaryop(b4, a[i])
|| test_binaryop(a[i], b5) || test_binaryop(b5, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_22()
{
return 0
|| test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16));
return 0;
}

static int test_binaryop_23()
static int test_binaryop_6()
{
return 0
|| test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16));
}
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};

static int test_binaryop_24()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 4), 1.f)
|| test_binaryop(RandomMat(11, 3, 4, 16), 1.f);
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1);
Randomize(b0);

static int test_binaryop_25()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_26()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16));
return 0;
}

static int test_binaryop_27()
static int test_binaryop_7()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_28()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1);
Randomize(b0);
Randomize(b1);

static int test_binaryop_29()
{
return 0
|| test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2))
|| test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4))
|| test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s1()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16));
return 0;
}

static int test_binaryop_s2()
static int test_binaryop_8()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1));
}
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};

static int test_binaryop_s3()
{
return 0
|| test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1, 1);
ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1);
Randomize(b0);
Randomize(b1);
Randomize(b2);

int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s4()
{
return 0
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16));
return 0;
}

static int test_binaryop_s5()
static int test_binaryop_9()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16));
}
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};

static int test_binaryop_s6()
{
return 0
|| test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2))
|| test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4))
|| test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16));
}
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, a[i].c);
Randomize(b0);

static int test_binaryop_s7()
{
return 0
|| test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16));
}
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}

static int test_binaryop_s8()
{
return 0
|| test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2))
|| test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4))
|| test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16));
return 0;
}

int main()
@@ -393,35 +379,7 @@ int main()
|| test_binaryop_6()
|| test_binaryop_7()
|| test_binaryop_8()
|| test_binaryop_9()
|| test_binaryop_10()
|| test_binaryop_11()
|| test_binaryop_12()
|| test_binaryop_13()
|| test_binaryop_14()
|| test_binaryop_15()
|| test_binaryop_16()
|| test_binaryop_17()
|| test_binaryop_18()
|| test_binaryop_19()
|| test_binaryop_20()
|| test_binaryop_21()
|| test_binaryop_22()
|| test_binaryop_23()
|| test_binaryop_24()
|| test_binaryop_25()
|| test_binaryop_26()
|| test_binaryop_27()
|| test_binaryop_28()
|| test_binaryop_29()
|| test_binaryop_s1()
|| test_binaryop_s2()
|| test_binaryop_s3()
|| test_binaryop_s4()
|| test_binaryop_s5()
|| test_binaryop_s6()
|| test_binaryop_s7()
|| test_binaryop_s8();
|| test_binaryop_9();

if (ret != 0)
return ret;


+ 1
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -366,6 +366,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/fuse_innerproduct_activation.cpp
pass_ncnn/fuse_transpose_matmul.cpp
pass_ncnn/fuse_binaryop_eltwise.cpp
pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp
pass_ncnn/insert_reshape_linear.cpp
pass_ncnn/insert_reshape_pooling.cpp



+ 2
- 0
tools/pnnx/src/pass_ncnn.cpp View File

@@ -43,6 +43,7 @@
#include "pass_ncnn/fuse_innerproduct_activation.h"
#include "pass_ncnn/fuse_transpose_matmul.h"
#include "pass_ncnn/fuse_binaryop_eltwise.h"
#include "pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h"
#include "pass_ncnn/insert_reshape_linear.h"
#include "pass_ncnn/insert_reshape_pooling.h"

@@ -85,6 +86,7 @@ void pass_ncnn(Graph& g)

ncnn::convert_half_to_float(g);

ncnn::insert_reshape_numpy_binaryop_broadcast(g);
ncnn::insert_reshape_pooling(g);
ncnn::insert_reshape_linear(g);



+ 28
- 0
tools/pnnx/src/pass_ncnn/expand_expression.cpp View File

@@ -169,6 +169,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
Operand* op_unary_out = graph.new_operand(op->name + "_" + r);
op_unary_out->producer = op_unary;

op_unary_out->shape = op_unary_in->shape;

op_unary->inputs.push_back(op_unary_in);
op_unary->outputs.push_back(op_unary_out);
}
@@ -204,6 +206,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
Operand* op_binary_out = graph.new_operand(op->name + "_" + r);
op_binary_out->producer = op_binary;

op_binary_out->shape = op_binary_inb->shape;

op_binary->inputs.push_back(op_binary_inb);
op_binary->outputs.push_back(op_binary_out);
}
@@ -218,6 +222,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
Operand* op_binary_out = graph.new_operand(op->name + "_" + r);
op_binary_out->producer = op_binary;

op_binary_out->shape = op_binary_ina->shape;

op_binary->inputs.push_back(op_binary_ina);
op_binary->outputs.push_back(op_binary_out);
}
@@ -232,6 +238,28 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
Operand* op_binary_out = graph.new_operand(op->name + "_" + r);
op_binary_out->producer = op_binary;

// resolve out shape
std::vector<int> out_shape;
{
std::vector<int> a_shape = op_binary_ina->shape;
std::vector<int> b_shape = op_binary_inb->shape;
int outrank = (int)std::max(a_shape.size(), b_shape.size());
for (int k = (int)a_shape.size(); k < outrank; k++)
{
a_shape.insert(a_shape.begin(), 1);
}
for (int k = (int)b_shape.size(); k < outrank; k++)
{
b_shape.insert(b_shape.begin(), 1);
}
out_shape.resize(outrank);
for (int k = 0; k < outrank; k++)
{
out_shape[k] = std::max(a_shape[k], b_shape[k]);
}
}
op_binary_out->shape = out_shape;

op_binary->inputs.push_back(op_binary_ina);
op_binary->inputs.push_back(op_binary_inb);
op_binary->outputs.push_back(op_binary_out);


+ 153
- 0
tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp View File

@@ -0,0 +1,153 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "insert_reshape_numpy_binaryop_broadcast.h"
#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

// Return a copy of shape with the batch axis removed.
//
// The axis at batch_index is dropped only when its extent is 1 or matches the
// extent of the same axis in other_shape; every other axis is copied unchanged.
static std::vector<int> binaryop_shape_without_batch(const std::vector<int>& shape, const std::vector<int>& other_shape, int batch_index)
{
    std::vector<int> new_shape;
    for (int j = 0; j < (int)shape.size(); j++)
    {
        if (j == batch_index)
        {
            // other_shape may be shorter than shape, so only compare axis j when it exists there
            const bool other_has_axis = j < (int)other_shape.size();
            if (shape[j] == 1 || (other_has_axis && shape[j] == other_shape[j]))
                continue;
        }

        new_shape.push_back(shape[j]);
    }

    return new_shape;
}

// Make numpy-style rank broadcasting explicit for two-input BinaryOp operators.
//
// For each BinaryOp whose two inputs have known shapes and differ in rank once
// the batch axis is ignored, a Tensor.reshape operator is inserted in front of
// the lower-rank input. Its "shape" param is the batch-less shape of that input
// with size-1 axes prepended until both ranks are equal.
void insert_reshape_numpy_binaryop_broadcast(Graph& graph)
{
    // the inner scan stops right after one reshape has been inserted,
    // so keep rescanning until a full pass finds nothing left to fix
    while (1)
    {
        bool matched = false;

        for (size_t i = 0; i < graph.ops.size(); i++)
        {
            Operator* op = graph.ops[i];

            if (op->type != "BinaryOp")
                continue;

            if (op->inputs.size() != 2)
                continue;

            // ranks cannot be compared without shape information
            if (op->inputs[0]->shape.empty() || op->inputs[1]->shape.empty())
                continue;

            int batch_index0 = op->inputs[0]->params["__batch_index"].i;
            int batch_index1 = op->inputs[1]->params["__batch_index"].i;
            if (batch_index0 != batch_index1)
            {
                fprintf(stderr, "binaryop broadcast across batch axis %d and %d is not supported\n", batch_index0, batch_index1);
                continue;
            }

            // NOTE(review): 233 looks like the "no batch axis" marker used elsewhere in pnnx - confirm
            if (op->inputs[0]->shape.size() == 5 && batch_index0 == 233)
            {
                if (op->inputs[0]->shape[0] == 1)
                {
                    fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n");
                    batch_index0 = 0;
                }
            }
            if (op->inputs[1]->shape.size() == 5 && batch_index1 == 233)
            {
                if (op->inputs[1]->shape[0] == 1)
                {
                    fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n");
                    batch_index1 = 0;
                }
            }

            // drop shape batch index
            const std::vector<int> new_shape0 = binaryop_shape_without_batch(op->inputs[0]->shape, op->inputs[1]->shape, batch_index0);
            const std::vector<int> new_shape1 = binaryop_shape_without_batch(op->inputs[1]->shape, op->inputs[0]->shape, batch_index1);

            const int input_rank0 = (int)new_shape0.size();
            const int input_rank1 = (int)new_shape1.size();

            // NOTE(review): these two checks only warn, the operator is still processed below - confirm this is intended
            if (input_rank0 >= 5)
            {
                fprintf(stderr, "binaryop tensor0 with rank %d is not supported yet!\n", (int)op->inputs[0]->shape.size());
            }

            if (input_rank1 >= 5)
            {
                fprintf(stderr, "binaryop tensor1 with rank %d is not supported yet!\n", (int)op->inputs[1]->shape.size());
            }

            if (input_rank0 == input_rank1)
            {
                // no broadcast after ignoring batch index
                continue;
            }

            // fprintf(stderr, "insert_reshape_numpy_binaryop_broadcast %d %d\n", input_rank0, input_rank1);

            matched = true;

            const int binaryop_lower_rank_in_index = input_rank0 < input_rank1 ? 0 : 1;

            Operand* binaryop_lower_rank_in = op->inputs[binaryop_lower_rank_in_index];

            Operator* reshape0 = graph.new_operator_before("Tensor.reshape", op->name + "_ncnnreshape0", op);

            Operand* reshape0_out = graph.new_operand(op->name + "_ncnnreshape0_out");

            reshape0->inputs.push_back(binaryop_lower_rank_in);
            reshape0->outputs.push_back(reshape0_out);

            // the lower-rank operand now feeds the reshape instead of the binaryop
            for (size_t j = 0; j < binaryop_lower_rank_in->consumers.size(); j++)
            {
                if (binaryop_lower_rank_in->consumers[j] == op)
                {
                    binaryop_lower_rank_in->consumers[j] = reshape0;
                    break;
                }
            }

            op->inputs[binaryop_lower_rank_in_index] = reshape0_out;

            reshape0_out->producer = reshape0;
            reshape0_out->consumers.push_back(op);

            reshape0_out->params["__batch_index"] = input_rank0 < input_rank1 ? batch_index0 : batch_index1;

            // insert explicit broadcast index for missing ranks
            std::vector<int> reshape0_shape = input_rank0 < input_rank1 ? new_shape0 : new_shape1;
            for (int j = 0; j < std::abs(input_rank0 - input_rank1); j++)
            {
                reshape0_shape.insert(reshape0_shape.begin(), 1);
            }

            reshape0->params["shape"] = reshape0_shape;

            break;
        }

        if (!matched)
            break;
    }
}

} // namespace ncnn

} // namespace pnnx

+ 25
- 0
tools/pnnx/src/pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h View File

@@ -0,0 +1,25 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

// Insert a Tensor.reshape in front of the lower-rank input of every two-input
// BinaryOp whose inputs differ in rank once the batch axis is ignored, so that
// the missing leading axes become explicit size-1 axes.
// Operators whose inputs have no shape information are left untouched.
void insert_reshape_numpy_binaryop_broadcast(Graph& graph);

} // namespace ncnn

} // namespace pnnx

+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -187,6 +187,7 @@ pnnx_ncnn_add_test(vit_b_32)
pnnx_ncnn_add_test(ncnn_fuse_transpose_matmul)
pnnx_ncnn_add_test(ncnn_fuse_shufflechannel_slice)
pnnx_ncnn_add_test(ncnn_fuse_binaryop_eltwise)
pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast)

if(Torch_VERSION VERSION_GREATER_EQUAL "1.9")
pnnx_ncnn_add_test(F_mish)


+ 78
- 0
tools/pnnx/tests/ncnn/test_ncnn_numpy_binaryop_broadcast.py View File

@@ -0,0 +1,78 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Parameter-free module exercising numpy-style broadcasting in binary ops."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x, y, z, w, u, v):
        """Combine the inputs pairwise with +, -, * and / and return all 17 results.

        The local names and the statement order are kept stable on purpose:
        tracing may record them in the exported graph.
        """
        # lower-rank tensor on the left-hand side
        a = x + y
        b = x - z
        c = x * w
        d = y / z
        e = y + w
        f = z - w
        # lower-rank tensor on the right-hand side
        g = y + x
        h = z - x
        i = w * x
        j = z / y
        k = w + y
        l = w - z
        # chained expressions
        m = (x - z) * w
        n = (x + y) - (z + w)
        # operands already reshaped to equal rank with explicit size-1 axes
        o = x.view(1, 1, 5) + y.view(1, 7, 5) - z
        # trailing size-1 axes on one operand
        p = u * y
        q = z / v
        return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q

def test():
    """Check that the pnnx-converted ncnn model reproduces the PyTorch outputs.

    The model is traced and saved as TorchScript, converted by the pnnx
    executable, and the generated ncnn inference script is then imported and
    run. Returns True when every output matches within tolerance, else False.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(5)
    y = torch.rand(7, 5)
    z = torch.rand(4, 7, 5)
    w = torch.rand(6, 4, 7, 5)
    u = torch.rand(7, 1)
    v = torch.rand(4, 1, 1)

    a = net(x, y, z, w, u, v)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w, u, v))
    mod.save("test_ncnn_numpy_binaryop_broadcast.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_ncnn_numpy_binaryop_broadcast.pt inputshape=[5],[7,5],[4,7,5],[6,4,7,5],[7,1],[4,1,1]")

    # ncnn inference
    import test_ncnn_numpy_binaryop_broadcast_ncnn
    b = test_ncnn_numpy_binaryop_broadcast_ncnn.test_inference()

    # zip() stops at the shorter sequence, so a missing output must be caught explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True

if __name__ == "__main__":
    # exit status 0 on success and 1 on mismatch; SystemExit is raised directly
    # because the exit() helper is only installed by the site module
    raise SystemExit(0 if test() else 1)

Loading…
Cancel
Save