| @@ -6,47 +6,48 @@ C = BinaryOp(A, B) | |||
| shape notation convention is [w], [w,h], [w,h,c], [w,h,d,c] | |||
| * binaryop with scalar and scalar-like | |||
| |type|A|B|C| | |||
| |---|---|---|---| | |||
| |0|[2]|scalar / [1]|[2]| | |||
| |1|[2,3]|scalar / [1] / [1,1]|[2,3]| | |||
| |2|[2,3,4]|scalar / [1] / [1,1] / [1,1,1]|[2,3,4]| | |||
| |3|[2,3,4,5]|scalar / [1] / [1,1] / [1,1,1] / [1,1,1,1]|[2,3,4,5]| | |||
| * no broadcast | |||
| |type|A|B|C| | |||
| |---|---|---|---| | |||
| |4|[2]|[2]|[2]| | |||
| |5|[2,3]|[2,3]|[2,3]| | |||
| |6|[2,3,4]|[2,3,4]|[2,3,4]| | |||
| |7|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]| | |||
| * broadcast B for inner axis | |||
| |type|A|B|C| | |||
| |---|---|---|---| | |||
| |1|[1]|scalar|[1]| | |||
| |2|[1]|[2]|[2]| | |||
| |3|[1]|[2,3]|[2,3]| | |||
| |4|[1]|[2,3,4]|[2,3,4]| | |||
| |5|[2]|scalar|[2]| | |||
| |6|[2]|[1]|[2]| | |||
| |7|[2]|[2]|[2]| | |||
| |8|[3]|[2,3]|[2,3]| | |||
| |9|[4]|[2,3,4]|[2,3,4]| | |||
| |10|[2,3]|scalar|[2,3]| | |||
| |11|[2,3]|[1]|[2,3]| | |||
| |12|[2,3]|[3]|[2,3]| | |||
| |13|[2,3]|[2,3]|[2,3]| | |||
| |14|[3,4]|[2,3,4]|[2,3,4]| | |||
| |15|[2,3,4]|scalar|[2,3,4]| | |||
| |16|[2,3,4]|[1]|[2,3,4]| | |||
| |17|[2,3,4]|[4]|[2,3,4]| | |||
| |18|[2,3,4]|[3,4]|[2,3,4]| | |||
| |19|[2,3,4]|[2,3,4]|[2,3,4]| | |||
| |20|[1]|[2,3,4,5]|[2,3,4,5]| | |||
| |21|[5]|[2,3,4,5]|[2,3,4,5]| | |||
| |22|[4,5]|[2,3,4,5]|[2,3,4,5]| | |||
| |23|[3,4,5]|[2,3,4,5]|[2,3,4,5]| | |||
| |24|[2,3,4,5]|scalar|[2,3,4,5]| | |||
| |25|[2,3,4,5]|[1]|[2,3,4,5]| | |||
| |26|[2,3,4,5]|[5]|[2,3,4,5]| | |||
| |27|[2,3,4,5]|[4,5]|[2,3,4,5]| | |||
| |28|[2,3,4,5]|[3,4,5]|[2,3,4,5]| | |||
| |29|[2,3,4,5]|[2,3,4,5]|[2,3,4,5]| | |||
| some special broadcasting rules exist for model compatibility | |||
| |8|[2,3]|[3] / [1,3]|[2,3]| | |||
| |9|[2,3,4]|[4] / [1,1,4]|[2,3,4]| | |||
| |10|[2,3,4]|[3,4] / [1,3,4]|[2,3,4]| | |||
| |11|[2,3,4,5]|[5] / [1,1,1,5]|[2,3,4,5]| | |||
| |12|[2,3,4,5]|[4,5] / [1,1,4,5]|[2,3,4,5]| | |||
| |13|[2,3,4,5]|[3,4,5] / [1,3,4,5]|[2,3,4,5]| | |||
| * broadcast B for outer axis | |||
| |type|A|B|C| | |||
| |---|---|---|---| | |||
| |14|[2,3]|[2,1]|[2,3]| | |||
| |15|[2,3,4]|[2,1,1]|[2,3,4]| | |||
| |16|[2,3,4]|[2,3,1]|[2,3,4]| | |||
| |17|[2,3,4,5]|[2,1,1,1]|[2,3,4,5]| | |||
| |18|[2,3,4,5]|[2,3,1,1]|[2,3,4,5]| | |||
| |19|[2,3,4,5]|[2,3,4,1]|[2,3,4,5]| | |||
| * some special broadcasting rules exist for model compatibility | |||
| |special type|A|B|C| | |||
| |---|---|---|---| | |||
| |1|[2,3,4]|[1,1,4]|[2,3,4]| | |||
| |2|[2,3,4]|[2,3,1]|[2,3,4]| | |||
| |3|[1,1,4]|[2,3,4]|[2,3,4]| | |||
| |4|[2,3,1]|[2,3,4]|[2,3,4]| | |||
| |5|[2,3,4]|[1,3,4]|[2,3,4]| | |||
| |6|[2,3,4]|[2,1,4]|[2,3,4]| | |||
| |7|[1,3,4]|[2,3,4]|[2,3,4]| | |||
| |8|[2,1,4]|[2,3,4]|[2,3,4]| | |||
| |20|[2,3,4]|[2,1,4]|[2,3,4]| | |||
| @@ -161,6 +161,7 @@ Operation type: | |||
| - 6 = POW | |||
| - 7 = RSUB | |||
| - 8 = RDIV | |||
| - 9 = RPOW | |||
| # BNLL | |||
| ``` | |||
| @@ -42,7 +42,8 @@ public: | |||
| Operation_MIN = 5, | |||
| Operation_POW = 6, | |||
| Operation_RSUB = 7, | |||
| Operation_RDIV = 8 | |||
| Operation_RDIV = 8, | |||
| Operation_RPOW = 9 | |||
| }; | |||
| public: | |||
| @@ -29,13 +29,29 @@ BinaryOp_vulkan::BinaryOp_vulkan() | |||
| pipeline_binaryop_pack4 = 0; | |||
| pipeline_binaryop_pack8 = 0; | |||
| pipeline_binaryop_broadcast = 0; | |||
| pipeline_binaryop_broadcast_pack4 = 0; | |||
| pipeline_binaryop_broadcast_a1_pack4 = 0; | |||
| pipeline_binaryop_broadcast_b1_pack4 = 0; | |||
| pipeline_binaryop_broadcast_pack8 = 0; | |||
| pipeline_binaryop_broadcast_a1_pack8 = 0; | |||
| pipeline_binaryop_broadcast_b1_pack8 = 0; | |||
| pipeline_binaryop_broadcast_inner[0] = 0; | |||
| pipeline_binaryop_broadcast_inner[1] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack4[0] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack4[1] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack8[0] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack8[1] = 0; | |||
| pipeline_binaryop_broadcast_outer[0] = 0; | |||
| pipeline_binaryop_broadcast_outer[1] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack4[0] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack4[1] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack8[0] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack8[1] = 0; | |||
| } | |||
| static int get_reverse_op_type(int op_type) | |||
| { | |||
| if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB; | |||
| if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV; | |||
| if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW; | |||
| if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB; | |||
| if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV; | |||
| if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW; | |||
| return op_type; | |||
| } | |||
| int BinaryOp_vulkan::create_pipeline(const Option& opt) | |||
| @@ -182,20 +198,33 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) | |||
| // broadcast | |||
| if (shape.dims == 0 || broadcast) | |||
| { | |||
| bool a_is_lower = false; | |||
| if (shape.dims != 0 && shape1.dims != 0) | |||
| { | |||
| const bool b_is_scalar = shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack == 1; | |||
| const bool a_rank_is_lower = shape_packed.dims < shape1_packed.dims && !b_is_scalar; | |||
| const bool a_size_is_lower = shape_packed.w * shape_packed.h * shape_packed.d * shape_packed.c * shape_packed.elempack < shape1_packed.w * shape1_packed.h * shape1_packed.d * shape1_packed.c * shape1_packed.elempack; | |||
| a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); | |||
| } | |||
| const Mat& A_shape_packed = a_is_lower ? shape1_packed : shape_packed; | |||
| const Mat& B_shape_packed = a_is_lower ? shape_packed : shape1_packed; | |||
| const int op_type_r = get_reverse_op_type(op_type); | |||
| std::vector<vk_specialization_type> specializations(1 + 18); | |||
| specializations[0].i = op_type; | |||
| specializations[1 + 0].i = shape_packed.dims; | |||
| specializations[1 + 1].i = shape_packed.w; | |||
| specializations[1 + 2].i = shape_packed.h; | |||
| specializations[1 + 3].i = shape_packed.d; | |||
| specializations[1 + 4].i = shape_packed.c; | |||
| specializations[1 + 5].i = shape_packed.cstep; | |||
| specializations[1 + 6].i = shape1_packed.dims; | |||
| specializations[1 + 7].i = shape1_packed.w; | |||
| specializations[1 + 8].i = shape1_packed.h; | |||
| specializations[1 + 9].i = shape1_packed.d; | |||
| specializations[1 + 10].i = shape1_packed.c; | |||
| specializations[1 + 11].i = shape1_packed.cstep; | |||
| specializations[1 + 0].i = A_shape_packed.dims; | |||
| specializations[1 + 1].i = A_shape_packed.w; | |||
| specializations[1 + 2].i = A_shape_packed.h; | |||
| specializations[1 + 3].i = A_shape_packed.d; | |||
| specializations[1 + 4].i = A_shape_packed.c; | |||
| specializations[1 + 5].i = A_shape_packed.cstep; | |||
| specializations[1 + 6].i = B_shape_packed.dims; | |||
| specializations[1 + 7].i = B_shape_packed.w; | |||
| specializations[1 + 8].i = B_shape_packed.h; | |||
| specializations[1 + 9].i = B_shape_packed.d; | |||
| specializations[1 + 10].i = B_shape_packed.c; | |||
| specializations[1 + 11].i = B_shape_packed.cstep; | |||
| specializations[1 + 12].i = out_shape_packed.dims; | |||
| specializations[1 + 13].i = out_shape_packed.w; | |||
| specializations[1 + 14].i = out_shape_packed.h; | |||
| @@ -203,23 +232,26 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) | |||
| specializations[1 + 16].i = out_shape_packed.c; | |||
| specializations[1 + 17].i = out_shape_packed.cstep; | |||
| std::vector<vk_specialization_type> specializations_broadcast_a1_b1(1 + 15); | |||
| specializations_broadcast_a1_b1[0].i = op_type; | |||
| specializations_broadcast_a1_b1[1 + 0].i = shape_packed.dims; | |||
| specializations_broadcast_a1_b1[1 + 1].i = shape_packed.w; | |||
| specializations_broadcast_a1_b1[1 + 2].i = shape_packed.h * shape_packed.d; | |||
| specializations_broadcast_a1_b1[1 + 3].i = shape_packed.c; | |||
| specializations_broadcast_a1_b1[1 + 4].i = shape_packed.cstep; | |||
| specializations_broadcast_a1_b1[1 + 5].i = shape1_packed.dims; | |||
| specializations_broadcast_a1_b1[1 + 6].i = shape1_packed.w; | |||
| specializations_broadcast_a1_b1[1 + 7].i = shape1_packed.h * shape1_packed.d; | |||
| specializations_broadcast_a1_b1[1 + 8].i = shape1_packed.c; | |||
| specializations_broadcast_a1_b1[1 + 9].i = shape1_packed.cstep; | |||
| specializations_broadcast_a1_b1[1 + 10].i = out_shape_packed.dims; | |||
| specializations_broadcast_a1_b1[1 + 11].i = out_shape_packed.w; | |||
| specializations_broadcast_a1_b1[1 + 12].i = out_shape_packed.h * out_shape_packed.d; | |||
| specializations_broadcast_a1_b1[1 + 13].i = out_shape_packed.c; | |||
| specializations_broadcast_a1_b1[1 + 14].i = out_shape_packed.cstep; | |||
| std::vector<vk_specialization_type> specializations_r(1 + 18); | |||
| specializations_r[0].i = op_type_r; | |||
| specializations_r[1 + 0].i = A_shape_packed.dims; | |||
| specializations_r[1 + 1].i = A_shape_packed.w; | |||
| specializations_r[1 + 2].i = A_shape_packed.h; | |||
| specializations_r[1 + 3].i = A_shape_packed.d; | |||
| specializations_r[1 + 4].i = A_shape_packed.c; | |||
| specializations_r[1 + 5].i = A_shape_packed.cstep; | |||
| specializations_r[1 + 6].i = B_shape_packed.dims; | |||
| specializations_r[1 + 7].i = B_shape_packed.w; | |||
| specializations_r[1 + 8].i = B_shape_packed.h; | |||
| specializations_r[1 + 9].i = B_shape_packed.d; | |||
| specializations_r[1 + 10].i = B_shape_packed.c; | |||
| specializations_r[1 + 11].i = B_shape_packed.cstep; | |||
| specializations_r[1 + 12].i = out_shape_packed.dims; | |||
| specializations_r[1 + 13].i = out_shape_packed.w; | |||
| specializations_r[1 + 14].i = out_shape_packed.h; | |||
| specializations_r[1 + 15].i = out_shape_packed.d; | |||
| specializations_r[1 + 16].i = out_shape_packed.c; | |||
| specializations_r[1 + 17].i = out_shape_packed.cstep; | |||
| Mat local_size_xyz; | |||
| if (out_shape_packed.dims == 1) | |||
| @@ -248,59 +280,75 @@ int BinaryOp_vulkan::create_pipeline(const Option& opt) | |||
| } | |||
| // pack1 | |||
| if (shape.dims == 0 || (elempack == 1 && elempack1 == 1)) | |||
| if (shape.dims == 0 || (out_elempack == 1)) | |||
| { | |||
| pipeline_binaryop_broadcast = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast->create(LayerShaderType::binaryop_broadcast, opt, specializations); | |||
| pipeline_binaryop_broadcast_inner[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner[0]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations); | |||
| pipeline_binaryop_broadcast_outer[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer[0]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations); | |||
| if (op_type_r != op_type) | |||
| { | |||
| // sub div pow ... | |||
| pipeline_binaryop_broadcast_inner[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner[1]->create(LayerShaderType::binaryop_broadcast_inner, opt, specializations_r); | |||
| pipeline_binaryop_broadcast_outer[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer[1]->create(LayerShaderType::binaryop_broadcast_outer, opt, specializations_r); | |||
| } | |||
| } | |||
| // pack4 | |||
| if (shape.dims == 0 || (elempack == 4 && elempack1 == 4)) | |||
| if (shape.dims == 0 || (out_elempack == 4)) | |||
| { | |||
| pipeline_binaryop_broadcast_pack4 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_pack4->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_pack4->create(LayerShaderType::binaryop_broadcast_pack4, opt, specializations); | |||
| } | |||
| pipeline_binaryop_broadcast_inner_pack4[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner_pack4[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner_pack4[0]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations); | |||
| if (shape.dims == 0 || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 4) | |||
| || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 4)) | |||
| { | |||
| pipeline_binaryop_broadcast_a1_pack4 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_a1_pack4->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_a1_pack4->create(LayerShaderType::binaryop_broadcast_a1_pack4, opt, specializations_broadcast_a1_b1); | |||
| } | |||
| pipeline_binaryop_broadcast_outer_pack4[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer_pack4[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer_pack4[0]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations); | |||
| if (shape.dims == 0 || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 4) | |||
| || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 4)) | |||
| { | |||
| pipeline_binaryop_broadcast_b1_pack4 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_b1_pack4->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_b1_pack4->create(LayerShaderType::binaryop_broadcast_b1_pack4, opt, specializations_broadcast_a1_b1); | |||
| if (op_type_r != op_type) | |||
| { | |||
| // sub div pow ... | |||
| pipeline_binaryop_broadcast_inner_pack4[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner_pack4[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner_pack4[1]->create(LayerShaderType::binaryop_broadcast_inner_pack4, opt, specializations_r); | |||
| pipeline_binaryop_broadcast_outer_pack4[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer_pack4[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer_pack4[1]->create(LayerShaderType::binaryop_broadcast_outer_pack4, opt, specializations_r); | |||
| } | |||
| } | |||
| // pack8 | |||
| if ((opt.use_shader_pack8 && shape.dims == 0) || (elempack == 8 && elempack1 == 8)) | |||
| if ((opt.use_shader_pack8 && shape.dims == 0) || (out_elempack == 8)) | |||
| { | |||
| pipeline_binaryop_broadcast_pack8 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_pack8->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_pack8->create(LayerShaderType::binaryop_broadcast_pack8, opt, specializations); | |||
| } | |||
| pipeline_binaryop_broadcast_inner_pack8[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner_pack8[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner_pack8[0]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations); | |||
| if ((opt.use_shader_pack8 && shape.dims == 0) || (shape.dims == 1 && shape.w == 1 && elempack == 1 && elempack1 == 8) | |||
| || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape.c == 1 && elempack == 1 && elempack1 == 8)) | |||
| { | |||
| pipeline_binaryop_broadcast_a1_pack8 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_a1_pack8->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_a1_pack8->create(LayerShaderType::binaryop_broadcast_a1_pack8, opt, specializations_broadcast_a1_b1); | |||
| } | |||
| pipeline_binaryop_broadcast_outer_pack8[0] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer_pack8[0]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer_pack8[0]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations); | |||
| if ((opt.use_shader_pack8 && shape.dims == 0) || (shape1.dims == 1 && shape1.w == 1 && elempack1 == 1 && elempack == 8) | |||
| || (shape.dims == 3 && shape1.dims == 3 && shape1.w == shape.w && shape1.h == shape.h && shape1.c == 1 && elempack1 == 1 && elempack == 8)) | |||
| { | |||
| pipeline_binaryop_broadcast_b1_pack8 = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_b1_pack8->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_b1_pack8->create(LayerShaderType::binaryop_broadcast_b1_pack8, opt, specializations_broadcast_a1_b1); | |||
| if (op_type_r != op_type) | |||
| { | |||
| // sub div pow ... | |||
| pipeline_binaryop_broadcast_inner_pack8[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_inner_pack8[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_inner_pack8[1]->create(LayerShaderType::binaryop_broadcast_inner_pack8, opt, specializations_r); | |||
| pipeline_binaryop_broadcast_outer_pack8[1] = new Pipeline(vkdev); | |||
| pipeline_binaryop_broadcast_outer_pack8[1]->set_optimal_local_size_xyz(local_size_xyz); | |||
| pipeline_binaryop_broadcast_outer_pack8[1]->create(LayerShaderType::binaryop_broadcast_outer_pack8, opt, specializations_r); | |||
| } | |||
| } | |||
| } | |||
| @@ -318,167 +366,75 @@ int BinaryOp_vulkan::destroy_pipeline(const Option& /*opt*/) | |||
| delete pipeline_binaryop_pack8; | |||
| pipeline_binaryop_pack8 = 0; | |||
| delete pipeline_binaryop_broadcast; | |||
| pipeline_binaryop_broadcast = 0; | |||
| delete pipeline_binaryop_broadcast_pack4; | |||
| pipeline_binaryop_broadcast_pack4 = 0; | |||
| delete pipeline_binaryop_broadcast_inner[0]; | |||
| delete pipeline_binaryop_broadcast_inner[1]; | |||
| pipeline_binaryop_broadcast_inner[0] = 0; | |||
| pipeline_binaryop_broadcast_inner[1] = 0; | |||
| delete pipeline_binaryop_broadcast_a1_pack4; | |||
| pipeline_binaryop_broadcast_a1_pack4 = 0; | |||
| delete pipeline_binaryop_broadcast_inner_pack4[0]; | |||
| delete pipeline_binaryop_broadcast_inner_pack4[1]; | |||
| pipeline_binaryop_broadcast_inner_pack4[0] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack4[1] = 0; | |||
| delete pipeline_binaryop_broadcast_b1_pack4; | |||
| pipeline_binaryop_broadcast_b1_pack4 = 0; | |||
| delete pipeline_binaryop_broadcast_inner_pack8[0]; | |||
| delete pipeline_binaryop_broadcast_inner_pack8[1]; | |||
| pipeline_binaryop_broadcast_inner_pack8[0] = 0; | |||
| pipeline_binaryop_broadcast_inner_pack8[1] = 0; | |||
| delete pipeline_binaryop_broadcast_pack8; | |||
| pipeline_binaryop_broadcast_pack8 = 0; | |||
| delete pipeline_binaryop_broadcast_outer[0]; | |||
| delete pipeline_binaryop_broadcast_outer[1]; | |||
| pipeline_binaryop_broadcast_outer[0] = 0; | |||
| pipeline_binaryop_broadcast_outer[1] = 0; | |||
| delete pipeline_binaryop_broadcast_a1_pack8; | |||
| pipeline_binaryop_broadcast_a1_pack8 = 0; | |||
| delete pipeline_binaryop_broadcast_outer_pack4[0]; | |||
| delete pipeline_binaryop_broadcast_outer_pack4[1]; | |||
| pipeline_binaryop_broadcast_outer_pack4[0] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack4[1] = 0; | |||
| delete pipeline_binaryop_broadcast_b1_pack8; | |||
| pipeline_binaryop_broadcast_b1_pack8 = 0; | |||
| delete pipeline_binaryop_broadcast_outer_pack8[0]; | |||
| delete pipeline_binaryop_broadcast_outer_pack8[1]; | |||
| pipeline_binaryop_broadcast_outer_pack8[0] = 0; | |||
| pipeline_binaryop_broadcast_outer_pack8[1] = 0; | |||
| return 0; | |||
| } | |||
| int BinaryOp_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const | |||
| { | |||
| const VkMat& bottom_blob = bottom_blobs[0]; | |||
| const VkMat& bottom_blob1 = bottom_blobs[1]; | |||
| const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; | |||
| const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; | |||
| const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; | |||
| const bool a_is_lower = a_rank_is_lower || (!a_rank_is_lower && a_size_is_lower); | |||
| const VkMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; | |||
| const VkMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; | |||
| const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; | |||
| VkMat& top_blob = top_blobs[0]; | |||
| // broadcast | |||
| if (bottom_blob.dims > bottom_blob1.dims) | |||
| { | |||
| top_blob.create_like(bottom_blob, opt.blob_vkallocator); | |||
| } | |||
| else if (bottom_blob.dims < bottom_blob1.dims) | |||
| { | |||
| top_blob.create_like(bottom_blob1, opt.blob_vkallocator); | |||
| } | |||
| else // if (bottom_blob.dims == bottom_blob1.dims) | |||
| { | |||
| if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) | |||
| { | |||
| top_blob.create_like(bottom_blob, opt.blob_vkallocator); | |||
| } | |||
| else | |||
| { | |||
| top_blob.create_like(bottom_blob1, opt.blob_vkallocator); | |||
| } | |||
| } | |||
| top_blob.create_like(A, opt.blob_vkallocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| int out_elempack = top_blob.elempack; | |||
| std::vector<VkMat> bindings(3); | |||
| bindings[0] = bottom_blob; | |||
| bindings[1] = bottom_blob1; | |||
| bindings[0] = A; | |||
| bindings[1] = B; | |||
| bindings[2] = top_blob; | |||
| bool broadcast = true; | |||
| if (bottom_blob.dims == bottom_blob1.dims | |||
| && bottom_blob.w == bottom_blob1.w | |||
| && bottom_blob.h == bottom_blob1.h | |||
| && bottom_blob.d == bottom_blob1.d | |||
| && bottom_blob.c == bottom_blob1.c | |||
| && bottom_blob.elempack == bottom_blob1.elempack) | |||
| { | |||
| broadcast = false; | |||
| } | |||
| if (broadcast) | |||
| { | |||
| std::vector<vk_constant_type> constants(18); | |||
| constants[0].i = bottom_blob.dims; | |||
| constants[1].i = bottom_blob.w; | |||
| constants[2].i = bottom_blob.h; | |||
| constants[3].i = bottom_blob.d; | |||
| constants[4].i = bottom_blob.c; | |||
| constants[5].i = bottom_blob.cstep; | |||
| constants[6].i = bottom_blob1.dims; | |||
| constants[7].i = bottom_blob1.w; | |||
| constants[8].i = bottom_blob1.h; | |||
| constants[9].i = bottom_blob1.d; | |||
| constants[10].i = bottom_blob1.c; | |||
| constants[11].i = bottom_blob1.cstep; | |||
| constants[12].i = top_blob.dims; | |||
| constants[13].i = top_blob.w; | |||
| constants[14].i = top_blob.h; | |||
| constants[15].i = top_blob.d; | |||
| constants[16].i = top_blob.c; | |||
| constants[17].i = top_blob.cstep; | |||
| std::vector<vk_constant_type> constants_broadcast_a1b1(15); | |||
| constants_broadcast_a1b1[0].i = bottom_blob.dims; | |||
| constants_broadcast_a1b1[1].i = bottom_blob.w; | |||
| constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; | |||
| constants_broadcast_a1b1[3].i = bottom_blob.c; | |||
| constants_broadcast_a1b1[4].i = bottom_blob.cstep; | |||
| constants_broadcast_a1b1[5].i = bottom_blob1.dims; | |||
| constants_broadcast_a1b1[6].i = bottom_blob1.w; | |||
| constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; | |||
| constants_broadcast_a1b1[8].i = bottom_blob1.c; | |||
| constants_broadcast_a1b1[9].i = bottom_blob1.cstep; | |||
| constants_broadcast_a1b1[10].i = top_blob.dims; | |||
| constants_broadcast_a1b1[11].i = top_blob.w; | |||
| constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; | |||
| constants_broadcast_a1b1[13].i = top_blob.c; | |||
| constants_broadcast_a1b1[14].i = top_blob.cstep; | |||
| bool broadcast_a1b1 = true; | |||
| const Pipeline* pipeline = 0; | |||
| if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| pipeline = pipeline_binaryop_broadcast; | |||
| broadcast_a1b1 = false; | |||
| } | |||
| else | |||
| { | |||
| if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1) | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; | |||
| } | |||
| else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; | |||
| } | |||
| else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| // special type 2 | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; | |||
| } | |||
| else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1) | |||
| { | |||
| // special type 4 | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; | |||
| } | |||
| else | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; | |||
| broadcast_a1b1 = false; | |||
| } | |||
| } | |||
| cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); | |||
| } | |||
| else | |||
| // no broadcast | |||
| if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) | |||
| { | |||
| std::vector<vk_constant_type> constants(15); | |||
| constants[0].i = bottom_blob.dims; | |||
| constants[1].i = bottom_blob.w; | |||
| constants[2].i = bottom_blob.h * bottom_blob.d; | |||
| constants[3].i = bottom_blob.c; | |||
| constants[4].i = bottom_blob.cstep; | |||
| constants[5].i = bottom_blob1.dims; | |||
| constants[6].i = bottom_blob1.w; | |||
| constants[7].i = bottom_blob1.h * bottom_blob1.d; | |||
| constants[8].i = bottom_blob1.c; | |||
| constants[9].i = bottom_blob1.cstep; | |||
| constants[0].i = A.dims; | |||
| constants[1].i = A.w; | |||
| constants[2].i = A.h * A.d; | |||
| constants[3].i = A.c; | |||
| constants[4].i = A.cstep; | |||
| constants[5].i = B.dims; | |||
| constants[6].i = B.w; | |||
| constants[7].i = B.h * B.d; | |||
| constants[8].i = B.c; | |||
| constants[9].i = B.cstep; | |||
| constants[10].i = top_blob.dims; | |||
| constants[11].i = top_blob.w; | |||
| constants[12].i = top_blob.h * top_blob.d; | |||
| @@ -490,8 +446,86 @@ int BinaryOp_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector | |||
| : pipeline_binaryop; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| std::vector<vk_constant_type> constants(18); | |||
| constants[0].i = A.dims; | |||
| constants[1].i = A.w; | |||
| constants[2].i = A.h; | |||
| constants[3].i = A.d; | |||
| constants[4].i = A.c; | |||
| constants[5].i = A.cstep; | |||
| constants[6].i = B.dims; | |||
| constants[7].i = B.w; | |||
| constants[8].i = B.h; | |||
| constants[9].i = B.d; | |||
| constants[10].i = B.c; | |||
| constants[11].i = B.cstep; | |||
| constants[12].i = top_blob.dims; | |||
| constants[13].i = top_blob.w; | |||
| constants[14].i = top_blob.h; | |||
| constants[15].i = top_blob.d; | |||
| constants[16].i = top_blob.c; | |||
| constants[17].i = top_blob.cstep; | |||
| const int ri = op_type_r == op_type ? 0 : 1; | |||
| if (B.w * B.h * B.d * B.c * B.elempack == 1) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] | |||
| : pipeline_binaryop_broadcast_outer[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // broadcast B for inner axis | |||
| if ((B.dims < A.dims) | |||
| || (A.dims == 2 && B.w == 1 && B.h == A.h) | |||
| || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) | |||
| || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] | |||
| : pipeline_binaryop_broadcast_inner[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // broadcast B for outer axis | |||
| if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] | |||
| : pipeline_binaryop_broadcast_outer[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // some special broadcast rule here | |||
| if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] | |||
| : pipeline_binaryop_broadcast_inner[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // should never reach here | |||
| return 0; | |||
| } | |||
| @@ -522,141 +556,40 @@ int BinaryOp_vulkan::forward_inplace(VkMat& bottom_top_blob, VkCompute& cmd, con | |||
| int BinaryOp_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const | |||
| { | |||
| const VkImageMat& bottom_blob = bottom_blobs[0]; | |||
| const VkImageMat& bottom_blob1 = bottom_blobs[1]; | |||
| const bool b_is_scalar = bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack == 1; | |||
| const bool a_rank_is_lower = bottom_blobs[0].dims < bottom_blobs[1].dims && !b_is_scalar; | |||
| const bool a_size_is_lower = bottom_blobs[0].w * bottom_blobs[0].h * bottom_blobs[0].d * bottom_blobs[0].c * bottom_blobs[0].elempack < bottom_blobs[1].w * bottom_blobs[1].h * bottom_blobs[1].d * bottom_blobs[1].c * bottom_blobs[1].elempack; | |||
| const bool a_is_lower = a_rank_is_lower || a_size_is_lower; | |||
| const VkImageMat& A = a_is_lower ? bottom_blobs[1] : bottom_blobs[0]; | |||
| const VkImageMat& B = a_is_lower ? bottom_blobs[0] : bottom_blobs[1]; | |||
| const int op_type_r = a_is_lower ? get_reverse_op_type(op_type) : op_type; | |||
| VkImageMat& top_blob = top_blobs[0]; | |||
| // broadcast | |||
| if (bottom_blob.dims > bottom_blob1.dims) | |||
| { | |||
| top_blob.create_like(bottom_blob, opt.blob_vkallocator); | |||
| } | |||
| else if (bottom_blob.dims < bottom_blob1.dims) | |||
| { | |||
| top_blob.create_like(bottom_blob1, opt.blob_vkallocator); | |||
| } | |||
| else // if (bottom_blob.dims == bottom_blob1.dims) | |||
| { | |||
| if (bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * bottom_blob.elempack >= bottom_blob1.w * bottom_blob1.h * bottom_blob1.d * bottom_blob1.c * bottom_blob1.elempack) | |||
| { | |||
| top_blob.create_like(bottom_blob, opt.blob_vkallocator); | |||
| } | |||
| else | |||
| { | |||
| top_blob.create_like(bottom_blob1, opt.blob_vkallocator); | |||
| } | |||
| } | |||
| top_blob.create_like(A, opt.blob_vkallocator); | |||
| if (top_blob.empty()) | |||
| return -100; | |||
| int out_elempack = top_blob.elempack; | |||
| std::vector<VkImageMat> bindings(3); | |||
| bindings[0] = bottom_blob; | |||
| bindings[1] = bottom_blob1; | |||
| bindings[0] = A; | |||
| bindings[1] = B; | |||
| bindings[2] = top_blob; | |||
| bool broadcast = true; | |||
| if (bottom_blob.dims == bottom_blob1.dims | |||
| && bottom_blob.w == bottom_blob1.w | |||
| && bottom_blob.h == bottom_blob1.h | |||
| && bottom_blob.d == bottom_blob1.d | |||
| && bottom_blob.c == bottom_blob1.c | |||
| && bottom_blob.elempack == bottom_blob1.elempack) | |||
| { | |||
| broadcast = false; | |||
| } | |||
| if (broadcast) | |||
| { | |||
| std::vector<vk_constant_type> constants(18); | |||
| constants[0].i = bottom_blob.dims; | |||
| constants[1].i = bottom_blob.w; | |||
| constants[2].i = bottom_blob.h; | |||
| constants[3].i = bottom_blob.d; | |||
| constants[4].i = bottom_blob.c; | |||
| constants[5].i = 0; //bottom_blob.cstep; | |||
| constants[6].i = bottom_blob1.dims; | |||
| constants[7].i = bottom_blob1.w; | |||
| constants[8].i = bottom_blob1.h; | |||
| constants[9].i = bottom_blob1.d; | |||
| constants[10].i = bottom_blob1.c; | |||
| constants[11].i = 0; //bottom_blob1.cstep; | |||
| constants[12].i = top_blob.dims; | |||
| constants[13].i = top_blob.w; | |||
| constants[14].i = top_blob.h; | |||
| constants[15].i = top_blob.d; | |||
| constants[16].i = top_blob.c; | |||
| constants[17].i = 0; //top_blob.cstep; | |||
| std::vector<vk_constant_type> constants_broadcast_a1b1(15); | |||
| constants_broadcast_a1b1[0].i = bottom_blob.dims; | |||
| constants_broadcast_a1b1[1].i = bottom_blob.w; | |||
| constants_broadcast_a1b1[2].i = bottom_blob.h * bottom_blob.d; | |||
| constants_broadcast_a1b1[3].i = bottom_blob.c; | |||
| constants_broadcast_a1b1[4].i = 0; //bottom_blob.cstep; | |||
| constants_broadcast_a1b1[5].i = bottom_blob1.dims; | |||
| constants_broadcast_a1b1[6].i = bottom_blob1.w; | |||
| constants_broadcast_a1b1[7].i = bottom_blob1.h * bottom_blob1.d; | |||
| constants_broadcast_a1b1[8].i = bottom_blob1.c; | |||
| constants_broadcast_a1b1[9].i = 0; //bottom_blob1.cstep; | |||
| constants_broadcast_a1b1[10].i = top_blob.dims; | |||
| constants_broadcast_a1b1[11].i = top_blob.w; | |||
| constants_broadcast_a1b1[12].i = top_blob.h * top_blob.d; | |||
| constants_broadcast_a1b1[13].i = top_blob.c; | |||
| constants_broadcast_a1b1[14].i = 0; //top_blob.cstep; | |||
| bool broadcast_a1b1 = true; | |||
| const Pipeline* pipeline = 0; | |||
| if (bottom_blob.elempack == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| pipeline = pipeline_binaryop_broadcast; | |||
| broadcast_a1b1 = false; | |||
| } | |||
| else | |||
| { | |||
| if (bottom_blob.dims == 1 && bottom_blob.w == 1 && bottom_blob.elempack == 1) | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; | |||
| } | |||
| else if (bottom_blob1.dims == 1 && bottom_blob1.w == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; | |||
| } | |||
| else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob1.c == 1 && bottom_blob1.elempack == 1) | |||
| { | |||
| // special type 2 | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_b1_pack8 : pipeline_binaryop_broadcast_b1_pack4; | |||
| } | |||
| else if (bottom_blob.dims == 3 && bottom_blob1.dims == 3 && bottom_blob1.w == bottom_blob.w && bottom_blob1.h == bottom_blob.h && bottom_blob.c == 1 && bottom_blob.elempack == 1) | |||
| { | |||
| // special type 4 | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_a1_pack8 : pipeline_binaryop_broadcast_a1_pack4; | |||
| } | |||
| else | |||
| { | |||
| pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_pack8 : pipeline_binaryop_broadcast_pack4; | |||
| broadcast_a1b1 = false; | |||
| } | |||
| } | |||
| cmd.record_pipeline(pipeline, bindings, broadcast_a1b1 ? constants_broadcast_a1b1 : constants, top_blob); | |||
| } | |||
| else | |||
| // no broadcast | |||
| if (A.dims == B.dims && A.w == B.w && A.h == B.h && A.d == B.d && A.c == B.c && A.elempack == B.elempack) | |||
| { | |||
| std::vector<vk_constant_type> constants(15); | |||
| constants[0].i = bottom_blob.dims; | |||
| constants[1].i = bottom_blob.w; | |||
| constants[2].i = bottom_blob.h * bottom_blob.d; | |||
| constants[3].i = bottom_blob.c; | |||
| constants[4].i = 0; //bottom_blob.cstep; | |||
| constants[5].i = bottom_blob1.dims; | |||
| constants[6].i = bottom_blob1.w; | |||
| constants[7].i = bottom_blob1.h * bottom_blob1.d; | |||
| constants[8].i = bottom_blob1.c; | |||
| constants[9].i = 0; //bottom_blob1.cstep; | |||
| constants[0].i = A.dims; | |||
| constants[1].i = A.w; | |||
| constants[2].i = A.h * A.d; | |||
| constants[3].i = A.c; | |||
| constants[4].i = 0; //A.cstep; | |||
| constants[5].i = B.dims; | |||
| constants[6].i = B.w; | |||
| constants[7].i = B.h * B.d; | |||
| constants[8].i = B.c; | |||
| constants[9].i = 0; //B.cstep; | |||
| constants[10].i = top_blob.dims; | |||
| constants[11].i = top_blob.w; | |||
| constants[12].i = top_blob.h * top_blob.d; | |||
| @@ -668,8 +601,86 @@ int BinaryOp_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::v | |||
| : pipeline_binaryop; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| std::vector<vk_constant_type> constants(18); | |||
| constants[0].i = A.dims; | |||
| constants[1].i = A.w; | |||
| constants[2].i = A.h; | |||
| constants[3].i = A.d; | |||
| constants[4].i = A.c; | |||
| constants[5].i = 0; //A.cstep; | |||
| constants[6].i = B.dims; | |||
| constants[7].i = B.w; | |||
| constants[8].i = B.h; | |||
| constants[9].i = B.d; | |||
| constants[10].i = B.c; | |||
| constants[11].i = 0; //B.cstep; | |||
| constants[12].i = top_blob.dims; | |||
| constants[13].i = top_blob.w; | |||
| constants[14].i = top_blob.h; | |||
| constants[15].i = top_blob.d; | |||
| constants[16].i = top_blob.c; | |||
| constants[17].i = 0; //top_blob.cstep; | |||
| const int ri = op_type_r == op_type ? 0 : 1; | |||
| if (B.w * B.h * B.d * B.c * B.elempack == 1) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] | |||
| : pipeline_binaryop_broadcast_outer[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // broadcast B for inner axis | |||
| if ((B.dims < A.dims) | |||
| || (A.dims == 2 && B.w == 1 && B.h == A.h) | |||
| || (A.dims == 3 && B.w == 1 && B.h == 1 && B.c == A.c) | |||
| || (A.dims == 3 && B.w == 1 && B.h == A.h && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == 1 && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == 1 && B.d == A.d && B.c == A.c) | |||
| || (A.dims == 4 && B.w == 1 && B.h == A.h && B.d == A.d && B.c == A.c)) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] | |||
| : pipeline_binaryop_broadcast_inner[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // broadcast B for outer axis | |||
| if (B.elempack == 1 && ((A.dims == 2 && B.w == A.w && B.h == 1) || (A.dims == 3 && B.w == A.w && B.h == 1 && B.c == 1) || (A.dims == 3 && B.w == A.w && B.h == A.h && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == 1 && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == 1 && B.c == 1) || (A.dims == 4 && B.w == A.w && B.h == A.h && B.d == A.d && B.c == 1))) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_outer_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_outer_pack4[ri] | |||
| : pipeline_binaryop_broadcast_outer[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // some special broadcast rule here | |||
| if (A.dims == 3 && B.dims == 3 && A.w == B.w && B.h == 1 && A.c == B.c) | |||
| { | |||
| const Pipeline* pipeline = out_elempack == 8 ? pipeline_binaryop_broadcast_inner_pack8[ri] | |||
| : out_elempack == 4 ? pipeline_binaryop_broadcast_inner_pack4[ri] | |||
| : pipeline_binaryop_broadcast_inner[ri]; | |||
| cmd.record_pipeline(pipeline, bindings, constants, top_blob); | |||
| return 0; | |||
| } | |||
| // should never reach here | |||
| return 0; | |||
| } | |||
| @@ -43,13 +43,12 @@ public: | |||
| Pipeline* pipeline_binaryop_pack8; | |||
| // broadcast | |||
| Pipeline* pipeline_binaryop_broadcast; | |||
| Pipeline* pipeline_binaryop_broadcast_pack4; | |||
| Pipeline* pipeline_binaryop_broadcast_a1_pack4; | |||
| Pipeline* pipeline_binaryop_broadcast_b1_pack4; | |||
| Pipeline* pipeline_binaryop_broadcast_pack8; | |||
| Pipeline* pipeline_binaryop_broadcast_a1_pack8; | |||
| Pipeline* pipeline_binaryop_broadcast_b1_pack8; | |||
| Pipeline* pipeline_binaryop_broadcast_inner[2]; | |||
| Pipeline* pipeline_binaryop_broadcast_inner_pack4[2]; | |||
| Pipeline* pipeline_binaryop_broadcast_inner_pack8[2]; | |||
| Pipeline* pipeline_binaryop_broadcast_outer[2]; | |||
| Pipeline* pipeline_binaryop_broadcast_outer_pack4[2]; | |||
| Pipeline* pipeline_binaryop_broadcast_outer_pack8[2]; | |||
| }; | |||
| } // namespace ncnn | |||
| @@ -92,52 +92,48 @@ void main() | |||
| afp v1 = buffer_ld1(a_blob_data, gi); | |||
| #endif | |||
| afp res; | |||
| afp v2; | |||
| if (with_scalar == 1) | |||
| { | |||
| // type 5 10 15 | |||
| afp b = afp(const_b); | |||
| if (op_type == 0) res = v1 + b; | |||
| if (op_type == 1) res = v1 - b; | |||
| if (op_type == 2) res = v1 * b; | |||
| if (op_type == 3) res = v1 / b; | |||
| if (op_type == 4) res = max(v1, b); | |||
| if (op_type == 5) res = min(v1, b); | |||
| if (op_type == 6) res = pow(v1, b); | |||
| if (op_type == 7) res = b - v1; | |||
| if (op_type == 8) res = b / v1; | |||
| // type 0 1 2 3 | |||
| v2 = afp(const_b); | |||
| } | |||
| else if (psc(bdims) == 1 && psc(bw) == 1) | |||
| { | |||
| // type 0 1 2 3 | |||
| #if NCNN_image_shader | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| v2 = image3d_ld1(b_blob_3d, ivec3(0, 0, 0)); | |||
| #else | |||
| buffer_st1(a_blob_data, gi, res); | |||
| v2 = buffer_ld1(b_blob_data, 0); | |||
| #endif | |||
| } | |||
| else | |||
| { | |||
| // type 7 13 19 | |||
| // type 4 5 6 7 | |||
| #if NCNN_image_shader | |||
| afp v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz)); | |||
| v2 = image3d_ld1(b_blob_3d, ivec3(gx, gy, gz)); | |||
| #else | |||
| afp v2 = buffer_ld1(b_blob_data, gi); | |||
| v2 = buffer_ld1(b_blob_data, gi); | |||
| #endif | |||
| } | |||
| afp res; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| if (op_type == 2) res = v1 * v2; | |||
| if (op_type == 3) res = v1 / v2; | |||
| if (op_type == 4) res = max(v1, v2); | |||
| if (op_type == 5) res = min(v1, v2); | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| if (op_type == 2) res = v1 * v2; | |||
| if (op_type == 3) res = v1 / v2; | |||
| if (op_type == 4) res = max(v1, v2); | |||
| if (op_type == 5) res = min(v1, v2); | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 9) res = pow(v2, v1); | |||
| #if NCNN_image_shader | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st1(top_blob_data, gi, res); | |||
| buffer_st1(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| } | |||
| @@ -1,553 +0,0 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #version 450 | |||
| #if NCNN_fp16_storage | |||
| #extension GL_EXT_shader_16bit_storage: require | |||
| #endif | |||
| #if NCNN_fp16_arithmetic | |||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16: require | |||
| #endif | |||
| layout (constant_id = 0) const int op_type = 0; | |||
| #define shape_constant_id_offset 1 | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| layout (binding = 1) uniform unfp sampler3D b_blob_3d; | |||
| layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d; | |||
| #else | |||
| layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; }; | |||
| #endif | |||
| layout (push_constant) uniform parameter | |||
| { | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| void main() | |||
| { | |||
| int gx = int(gl_GlobalInvocationID.x); | |||
| int gy = int(gl_GlobalInvocationID.y); | |||
| int gz = int(gl_GlobalInvocationID.z); | |||
| if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
| return; | |||
| #if NCNN_image_shader | |||
| int ax = gx; | |||
| int ay = gy; | |||
| int az = gz; | |||
| int bx = gx; | |||
| int by = gy; | |||
| int bz = gz; | |||
| if (psc(adims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 28 | |||
| bx = yh; | |||
| by = yd; | |||
| bz = gz; | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 27 | |||
| bx = yd; | |||
| by = gz; | |||
| bz = 0; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 25 | |||
| bx = 0; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 26 | |||
| bx = gz; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 3) | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| // type 23 | |||
| ax = yh; | |||
| ay = yd; | |||
| az = gz; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| if (psc(bw) == 1 && psc(bh) == 1) | |||
| { | |||
| // special type 1 | |||
| bx = 0; | |||
| by = 0; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| bz = 0; | |||
| } | |||
| if (psc(aw) == 1 && psc(ah) == 1) | |||
| { | |||
| // special type 3 | |||
| ax = 0; | |||
| ay = 0; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 4 | |||
| az = 0; | |||
| } | |||
| if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 5 | |||
| bx = 0; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 6 | |||
| by = 0; | |||
| } | |||
| if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 7 | |||
| ax = 0; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 8 | |||
| ay = 0; | |||
| } | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 18 | |||
| bx = gy; | |||
| by = gz; | |||
| bz = 0; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 16 | |||
| bx = 0; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 17 | |||
| bx = gz; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 2) | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| // type 22 | |||
| ax = yd; | |||
| ay = gz; | |||
| az = 0; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 14 | |||
| ax = gy; | |||
| ay = gz; | |||
| az = 0; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 11 | |||
| bx = 0; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 12 | |||
| bx = gy; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 1) | |||
| { | |||
| if (psc(aw) == 1) | |||
| { | |||
| // type 2 3 4 20 | |||
| ax = 0; | |||
| ay = 0; | |||
| az = 0; | |||
| } | |||
| else | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| // type 21 | |||
| ax = gz; | |||
| ay = 0; | |||
| az = 0; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 9 | |||
| ax = gz; | |||
| ay = 0; | |||
| az = 0; | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 8 | |||
| ax = gy; | |||
| ay = 0; | |||
| az = 0; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 6 | |||
| bx = 0; | |||
| by = 0; | |||
| bz = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| afp v1 = image3d_ld1(a_blob_3d, ivec3(ax, ay, az)); | |||
| afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); | |||
| #else | |||
| const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int ai; | |||
| int bi; | |||
| if (psc(adims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 28 | |||
| ai = gi; | |||
| bi = gz * psc(bcstep) + yd * psc(bw) + yh; | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 27 | |||
| ai = gi; | |||
| bi = gz * psc(bw) + yd; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 25 | |||
| ai = gi; | |||
| bi = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 26 | |||
| ai = gi; | |||
| bi = gz; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 3) | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| // type 23 | |||
| ai = gz * psc(acstep) + yd * psc(aw) + yh; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| if (psc(bw) == 1 && psc(bh) == 1) | |||
| { | |||
| // special type 1 | |||
| ai = gi; | |||
| bi = gz * psc(bcstep); | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| ai = gi; | |||
| bi = gy * psc(bw) + gx; | |||
| } | |||
| if (psc(aw) == 1 && psc(ah) == 1) | |||
| { | |||
| // special type 3 | |||
| ai = gz * psc(acstep); | |||
| bi = gi; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 4 | |||
| ai = gy * psc(aw) + gx; | |||
| bi = gi; | |||
| } | |||
| if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 5 | |||
| ai = gi; | |||
| bi = gz * psc(bcstep) + gy; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 6 | |||
| ai = gi; | |||
| bi = gz * psc(bcstep) + gx; | |||
| } | |||
| if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 7 | |||
| ai = gz * psc(acstep) + gy; | |||
| bi = gi; | |||
| } | |||
| if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
| { | |||
| // special type 8 | |||
| ai = gz * psc(acstep) + gx; | |||
| bi = gi; | |||
| } | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 18 | |||
| ai = gi; | |||
| bi = gz * psc(bw) + gy; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 16 | |||
| ai = gi; | |||
| bi = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 17 | |||
| ai = gi; | |||
| bi = gz; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 2) | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| // type 22 | |||
| ai = gz * psc(aw) + yd; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 14 | |||
| ai = gz * psc(aw) + gy; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 11 | |||
| ai = gi; | |||
| bi = 0; | |||
| } | |||
| else | |||
| { | |||
| // type 12 | |||
| ai = gi; | |||
| bi = gy; | |||
| } | |||
| } | |||
| } | |||
| else if (psc(adims) == 1) | |||
| { | |||
| if (psc(aw) == 1) | |||
| { | |||
| // type 2 3 4 20 | |||
| ai = 0; | |||
| bi = gi; | |||
| } | |||
| else | |||
| { | |||
| if (psc(bdims) == 4) | |||
| { | |||
| // type 21 | |||
| ai = gz; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 3) | |||
| { | |||
| // type 9 | |||
| ai = gz; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 2) | |||
| { | |||
| // type 8 | |||
| ai = gy; | |||
| bi = gi; | |||
| } | |||
| if (psc(bdims) == 1) | |||
| { | |||
| if (psc(bw) == 1) | |||
| { | |||
| // type 6 | |||
| ai = gi; | |||
| bi = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| afp v1 = buffer_ld1(a_blob_data, ai); | |||
| afp v2 = buffer_ld1(b_blob_data, bi); | |||
| #endif | |||
| afp res; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| if (op_type == 2) res = v1 * v2; | |||
| if (op_type == 3) res = v1 / v2; | |||
| if (op_type == 4) res = max(v1, v2); | |||
| if (op_type == 5) res = min(v1, v2); | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| #if NCNN_image_shader | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st1(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| @@ -1,169 +0,0 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #version 450 | |||
| #if NCNN_fp16_storage | |||
| #extension GL_EXT_shader_16bit_storage: require | |||
| struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; | |||
| #endif | |||
| #if NCNN_fp16_arithmetic | |||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16: require | |||
| #endif | |||
| layout (constant_id = 0) const int op_type = 0; | |||
| #define shape_constant_id_offset 1 | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| layout (binding = 1) uniform unfp sampler3D b_blob_3d; | |||
| layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; | |||
| #else | |||
| layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; | |||
| #endif | |||
| layout (push_constant) uniform parameter | |||
| { | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| void main() | |||
| { | |||
| int gx = int(gl_GlobalInvocationID.x); | |||
| int gy = int(gl_GlobalInvocationID.y); | |||
| int gz = int(gl_GlobalInvocationID.z); | |||
| if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) | |||
| return; | |||
| #if NCNN_image_shader | |||
| afpvec4 v1; | |||
| if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 4 | |||
| v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(gx, gy, 0))); | |||
| } | |||
| else | |||
| { | |||
| v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(0, 0, 0))); | |||
| } | |||
| afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); | |||
| #else | |||
| const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int ai = 0; | |||
| if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 2 | |||
| ai = gy * psc(bw) + gx; | |||
| } | |||
| // type 2 3 4 | |||
| afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); | |||
| afpvec8 v2 = buffer_ld8(b_blob_data, gi); | |||
| #endif | |||
| afpvec8 res; | |||
| if (op_type == 0) | |||
| { | |||
| res[0] = v1 + v2[0]; | |||
| res[1] = v1 + v2[1]; | |||
| } | |||
| if (op_type == 1) | |||
| { | |||
| res[0] = v1 - v2[0]; | |||
| res[1] = v1 - v2[1]; | |||
| } | |||
| if (op_type == 2) | |||
| { | |||
| res[0] = v1 * v2[0]; | |||
| res[1] = v1 * v2[1]; | |||
| } | |||
| if (op_type == 3) | |||
| { | |||
| res[0] = v1 / v2[0]; | |||
| res[1] = v1 / v2[1]; | |||
| } | |||
| if (op_type == 4) | |||
| { | |||
| res[0] = max(v1, v2[0]); | |||
| res[1] = max(v1, v2[1]); | |||
| } | |||
| if (op_type == 5) | |||
| { | |||
| res[0] = min(v1, v2[0]); | |||
| res[1] = min(v1, v2[1]); | |||
| } | |||
| if (op_type == 6) | |||
| { | |||
| res[0] = pow(v1, v2[0]); | |||
| res[1] = pow(v1, v2[1]); | |||
| } | |||
| if (op_type == 7) | |||
| { | |||
| res[0] = v2[0] - v1; | |||
| res[1] = v2[1] - v1; | |||
| } | |||
| if (op_type == 8) | |||
| { | |||
| res[0] = v2[0] / v1; | |||
| res[1] = v2[1] / v1; | |||
| } | |||
| #if NCNN_image_shader | |||
| image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st8(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| @@ -0,0 +1,193 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the binary operation; main() handles the values 0 through 9.
layout (constant_id = 0) const int op_type = 0;

// Shape specialization constants start right after op_type.
#define shape_constant_id_offset 1

// Operand A.
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

// Operand B.
layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

// Output.
layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
// Image path: A and B are read as 3D samplers, the result goes to a 3D image.
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d;
#else
// Buffer path: one sfp value per array element.
layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; };
#endif

// Push constants carry the same shape fields, in the same order, as the
// specialization constants above; main() reads every shape value through
// psc(name).
// NOTE(review): psc() is defined outside this file - presumably it picks the
// specialization constant when it is non-zero and p.<name> otherwise; confirm.
layout (push_constant) uniform parameter
{
    // operand A
    int adims, aw, ah, ad, ac, acstep;
    // operand B
    int bdims, bw, bh, bd, bc, bcstep;
    // output
    int outdims, outw, outh, outd, outc, outcstep;
} p;
// Computes one output element of C = op(A, B).
// A is read at the output position; B is read at a position derived from the
// output position by explicit (same dims) or implicit (fewer dims) broadcast.
void main()
{
    ivec3 out_pos = ivec3(gl_GlobalInvocationID);

    // The y axis of the dispatch covers outh * outd rows.
    bool inside = out_pos.x < psc(outw)
               && out_pos.y < psc(outh) * psc(outd)
               && out_pos.z < psc(outc);
    if (!inside)
        return;

    // Split the flattened row index into its depth and height parts.
    int depth_idx = out_pos.y / psc(outh);
    int height_idx = out_pos.y % psc(outh);

    int a_rank = psc(adims);
    int b_rank = psc(bdims);

    // Position read from B; stays equal to the output position when no case
    // below applies.
    ivec3 b_pos = out_pos;

    if (a_rank == b_rank)
    {
        // explicit broadcast: clamp every index to B's last valid index, so an
        // axis of extent 1 in B always reads index 0
        b_pos.x = min(out_pos.x, psc(bw) - 1);
        b_pos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(height_idx, psc(bh) - 1);
        b_pos.z = min(out_pos.z, psc(bc) - 1);
    }
    else if (a_rank == 4)
    {
        // implicit broadcast, A has 4 dims
        if (b_rank == 3)
            b_pos = ivec3(height_idx, depth_idx, out_pos.z); // type 13
        else if (b_rank == 2)
            b_pos = ivec3(depth_idx, out_pos.z, 0); // type 12
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 11
    }
    else if (a_rank == 3)
    {
        // implicit broadcast, A has 3 dims
        if (b_rank == 2)
            b_pos = ivec3(out_pos.y, out_pos.z, 0); // type 10
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 9
    }
    else
    {
        // every other dims combination is handled as type 8
        // (A with 2 dims, B with 1 dim)
        b_pos = ivec3(out_pos.y, 0, 0);
    }

#if NCNN_image_shader
    afp lhs = image3d_ld1(a_blob_3d, out_pos);
    afp rhs = image3d_ld1(b_blob_3d, b_pos);
#else
    // A shares the output's linear index; B uses its own cstep and width.
    int out_idx = out_pos.z * psc(outcstep) + out_pos.y * psc(outw) + out_pos.x;
    int b_idx = b_pos.z * psc(bcstep) + b_pos.y * psc(bw) + b_pos.x;

    afp lhs = buffer_ld1(a_blob_data, out_idx);
    afp rhs = buffer_ld1(b_blob_data, b_idx);
#endif

    // Codes 7, 8 and 9 are the operand-swapped forms of 1, 3 and 6.
    afp result;
    switch (op_type)
    {
    case 0: result = lhs + rhs; break;
    case 1: result = lhs - rhs; break;
    case 2: result = lhs * rhs; break;
    case 3: result = lhs / rhs; break;
    case 4: result = max(lhs, rhs); break;
    case 5: result = min(lhs, rhs); break;
    case 6: result = pow(lhs, rhs); break;
    case 7: result = rhs - lhs; break;
    case 8: result = rhs / lhs; break;
    case 9: result = pow(rhs, lhs); break;
    }

#if NCNN_image_shader
    image3d_st1(top_blob_3d, out_pos, result);
#else
    buffer_st1(top_blob_data, out_idx, result);
#endif
}
| @@ -0,0 +1,193 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the binary operation; main() handles the values 0 through 9.
layout (constant_id = 0) const int op_type = 0;

// Shape specialization constants start right after op_type.
#define shape_constant_id_offset 1

// Operand A.
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

// Operand B.
layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

// Output.
layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
// Image path: A and B are read as 3D samplers, the result goes to a 3D image.
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
// Buffer path: one sfpvec4 value per array element.
layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
#endif

// Push constants carry the same shape fields, in the same order, as the
// specialization constants above; main() reads every shape value through
// psc(name).
// NOTE(review): psc() is defined outside this file - presumably it picks the
// specialization constant when it is non-zero and p.<name> otherwise; confirm.
layout (push_constant) uniform parameter
{
    // operand A
    int adims, aw, ah, ad, ac, acstep;
    // operand B
    int bdims, bw, bh, bd, bc, bcstep;
    // output
    int outdims, outw, outh, outd, outc, outcstep;
} p;
// Computes one afpvec4 output element of C = op(A, B).
// A is read at the output position; B is read at a position derived from the
// output position by explicit (same dims) or implicit (fewer dims) broadcast.
void main()
{
    ivec3 out_pos = ivec3(gl_GlobalInvocationID);

    // The y axis of the dispatch covers outh * outd rows.
    bool inside = out_pos.x < psc(outw)
               && out_pos.y < psc(outh) * psc(outd)
               && out_pos.z < psc(outc);
    if (!inside)
        return;

    // Split the flattened row index into its depth and height parts.
    int depth_idx = out_pos.y / psc(outh);
    int height_idx = out_pos.y % psc(outh);

    int a_rank = psc(adims);
    int b_rank = psc(bdims);

    // Position read from B; stays equal to the output position when no case
    // below applies.
    ivec3 b_pos = out_pos;

    if (a_rank == b_rank)
    {
        // explicit broadcast: clamp every index to B's last valid index, so an
        // axis of extent 1 in B always reads index 0
        b_pos.x = min(out_pos.x, psc(bw) - 1);
        b_pos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(height_idx, psc(bh) - 1);
        b_pos.z = min(out_pos.z, psc(bc) - 1);
    }
    else if (a_rank == 4)
    {
        // implicit broadcast, A has 4 dims
        if (b_rank == 3)
            b_pos = ivec3(height_idx, depth_idx, out_pos.z); // type 13
        else if (b_rank == 2)
            b_pos = ivec3(depth_idx, out_pos.z, 0); // type 12
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 11
    }
    else if (a_rank == 3)
    {
        // implicit broadcast, A has 3 dims
        if (b_rank == 2)
            b_pos = ivec3(out_pos.y, out_pos.z, 0); // type 10
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 9
    }
    else
    {
        // every other dims combination is handled as type 8
        // (A with 2 dims, B with 1 dim)
        b_pos = ivec3(out_pos.y, 0, 0);
    }

#if NCNN_image_shader
    afpvec4 lhs = image3d_ld4(a_blob_3d, out_pos);
    afpvec4 rhs = image3d_ld4(b_blob_3d, b_pos);
#else
    // A shares the output's linear index; B uses its own cstep and width.
    int out_idx = out_pos.z * psc(outcstep) + out_pos.y * psc(outw) + out_pos.x;
    int b_idx = b_pos.z * psc(bcstep) + b_pos.y * psc(bw) + b_pos.x;

    afpvec4 lhs = buffer_ld4(a_blob_data, out_idx);
    afpvec4 rhs = buffer_ld4(b_blob_data, b_idx);
#endif

    // Codes 7, 8 and 9 are the operand-swapped forms of 1, 3 and 6.
    afpvec4 result;
    switch (op_type)
    {
    case 0: result = lhs + rhs; break;
    case 1: result = lhs - rhs; break;
    case 2: result = lhs * rhs; break;
    case 3: result = lhs / rhs; break;
    case 4: result = max(lhs, rhs); break;
    case 5: result = min(lhs, rhs); break;
    case 6: result = pow(lhs, rhs); break;
    case 7: result = rhs - lhs; break;
    case 8: result = rhs / lhs; break;
    case 9: result = pow(rhs, lhs); break;
    }

#if NCNN_image_shader
    image3d_st4(top_blob_3d, out_pos, result);
#else
    buffer_st4(top_blob_data, out_idx, result);
#endif
}
| @@ -0,0 +1,234 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
// Storage type of the buffer elements when fp16 storage is enabled:
// two f16vec4 halves.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects the binary operation; main() handles the values 0 through 9.
layout (constant_id = 0) const int op_type = 0;

// Shape specialization constants start right after op_type.
#define shape_constant_id_offset 1

// Operand A.
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ad = 0;
layout (constant_id = shape_constant_id_offset + 4) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0;

// Operand B.
layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bd = 0;
layout (constant_id = shape_constant_id_offset + 10) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0;

// Output.
layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 15) const int outd = 0;
layout (constant_id = shape_constant_id_offset + 16) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0;

#if NCNN_image_shader
// Image path: A and B are read as 3D samplers, the result goes to a 3D image.
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
// Buffer path: one sfpvec8 value per array element.
layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

// Push constants carry the same shape fields, in the same order, as the
// specialization constants above; main() reads every shape value through
// psc(name).
// NOTE(review): psc() is defined outside this file - presumably it picks the
// specialization constant when it is non-zero and p.<name> otherwise; confirm.
layout (push_constant) uniform parameter
{
    // operand A
    int adims, aw, ah, ad, ac, acstep;
    // operand B
    int bdims, bw, bh, bd, bc, bcstep;
    // output
    int outdims, outw, outh, outd, outc, outcstep;
} p;
// Computes one afpvec8 output element of C = op(A, B).
// A is read at the output position; B is read at a position derived from the
// output position by explicit (same dims) or implicit (fewer dims) broadcast.
void main()
{
    ivec3 out_pos = ivec3(gl_GlobalInvocationID);

    // The y axis of the dispatch covers outh * outd rows.
    bool inside = out_pos.x < psc(outw)
               && out_pos.y < psc(outh) * psc(outd)
               && out_pos.z < psc(outc);
    if (!inside)
        return;

    // Split the flattened row index into its depth and height parts.
    int depth_idx = out_pos.y / psc(outh);
    int height_idx = out_pos.y % psc(outh);

    int a_rank = psc(adims);
    int b_rank = psc(bdims);

    // Position read from B; stays equal to the output position when no case
    // below applies.
    ivec3 b_pos = out_pos;

    if (a_rank == b_rank)
    {
        // explicit broadcast: clamp every index to B's last valid index, so an
        // axis of extent 1 in B always reads index 0
        b_pos.x = min(out_pos.x, psc(bw) - 1);
        b_pos.y = min(depth_idx, psc(bd) - 1) * psc(bh) + min(height_idx, psc(bh) - 1);
        b_pos.z = min(out_pos.z, psc(bc) - 1);
    }
    else if (a_rank == 4)
    {
        // implicit broadcast, A has 4 dims
        if (b_rank == 3)
            b_pos = ivec3(height_idx, depth_idx, out_pos.z); // type 13
        else if (b_rank == 2)
            b_pos = ivec3(depth_idx, out_pos.z, 0); // type 12
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 11
    }
    else if (a_rank == 3)
    {
        // implicit broadcast, A has 3 dims
        if (b_rank == 2)
            b_pos = ivec3(out_pos.y, out_pos.z, 0); // type 10
        else if (b_rank == 1)
            b_pos = ivec3(out_pos.z, 0, 0); // type 9
    }
    else
    {
        // every other dims combination is handled as type 8
        // (A with 2 dims, B with 1 dim)
        b_pos = ivec3(out_pos.y, 0, 0);
    }

#if NCNN_image_shader
    afpvec8 lhs = image3d_ld8(a_blob_3d, out_pos);
    afpvec8 rhs = image3d_ld8(b_blob_3d, b_pos);
#else
    // A shares the output's linear index; B uses its own cstep and width.
    int out_idx = out_pos.z * psc(outcstep) + out_pos.y * psc(outw) + out_pos.x;
    int b_idx = b_pos.z * psc(bcstep) + b_pos.y * psc(bw) + b_pos.x;

    afpvec8 lhs = buffer_ld8(a_blob_data, out_idx);
    afpvec8 rhs = buffer_ld8(b_blob_data, b_idx);
#endif

    // The 8-wide value is processed as two halves, [0] and [1].
    // Codes 7, 8 and 9 are the operand-swapped forms of 1, 3 and 6.
    afpvec8 result;
    switch (op_type)
    {
    case 0:
        result[0] = lhs[0] + rhs[0];
        result[1] = lhs[1] + rhs[1];
        break;
    case 1:
        result[0] = lhs[0] - rhs[0];
        result[1] = lhs[1] - rhs[1];
        break;
    case 2:
        result[0] = lhs[0] * rhs[0];
        result[1] = lhs[1] * rhs[1];
        break;
    case 3:
        result[0] = lhs[0] / rhs[0];
        result[1] = lhs[1] / rhs[1];
        break;
    case 4:
        result[0] = max(lhs[0], rhs[0]);
        result[1] = max(lhs[1], rhs[1]);
        break;
    case 5:
        result[0] = min(lhs[0], rhs[0]);
        result[1] = min(lhs[1], rhs[1]);
        break;
    case 6:
        result[0] = pow(lhs[0], rhs[0]);
        result[1] = pow(lhs[1], rhs[1]);
        break;
    case 7:
        result[0] = rhs[0] - lhs[0];
        result[1] = rhs[1] - lhs[1];
        break;
    case 8:
        result[0] = rhs[0] / lhs[0];
        result[1] = rhs[1] / lhs[1];
        break;
    case 9:
        result[0] = pow(rhs[0], lhs[0]);
        result[1] = pow(rhs[1], lhs[1]);
        break;
    }

#if NCNN_image_shader
    image3d_st8(top_blob_3d, out_pos, result);
#else
    buffer_st8(top_blob_data, out_idx, result);
#endif
}
| @@ -1,6 +1,6 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| @@ -27,29 +27,32 @@ layout (constant_id = 0) const int op_type = 0; | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| layout (binding = 1) uniform unfp sampler3D b_blob_3d; | |||
| layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; | |||
| layout (binding = 2, imfmtc1) writeonly uniform unfp image3D top_blob_3d; | |||
| #else | |||
| layout (binding = 0) readonly buffer a_blob { sfp a_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfp b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfp top_blob_data[]; }; | |||
| #endif | |||
| layout (push_constant) uniform parameter | |||
| @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| @@ -79,40 +85,29 @@ void main() | |||
| int gy = int(gl_GlobalInvocationID.y); | |||
| int gz = int(gl_GlobalInvocationID.z); | |||
| if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) | |||
| if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
| return; | |||
| #if NCNN_image_shader | |||
| afpvec4 v1; | |||
| if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 4 | |||
| v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(gx, gy, 0))); | |||
| } | |||
| else | |||
| { | |||
| v1 = afpvec4(image3d_ld1(a_blob_3d, ivec3(0, 0, 0))); | |||
| } | |||
| afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); | |||
| #else | |||
| const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| int ai = 0; | |||
| // explicit broadcast | |||
| int bx = min(gx, psc(bw) - 1); | |||
| int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); | |||
| int bz = min(gz, psc(bc) - 1); | |||
| if (psc(adims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
| { | |||
| // special type 4 | |||
| ai = gy * psc(aw) + gx; | |||
| } | |||
| #if NCNN_image_shader | |||
| afp v1 = image3d_ld1(a_blob_3d, ivec3(gx, gy, gz)); | |||
| afp v2 = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); | |||
| #else | |||
| int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int bi = bz * psc(bcstep) + by * psc(bw) + bx; | |||
| // type 2 3 4 | |||
| afpvec4 v1 = afpvec4(buffer_ld1(a_blob_data, ai)); | |||
| afpvec4 v2 = buffer_ld4(b_blob_data, gi); | |||
| afp v1 = buffer_ld1(a_blob_data, gi); | |||
| afp v2 = buffer_ld1(b_blob_data, bi); | |||
| #endif | |||
| afpvec4 res; | |||
| afp res; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| @@ -123,10 +118,11 @@ void main() | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 9) res = pow(v2, v1); | |||
| #if NCNN_image_shader | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st4(top_blob_data, gi, res); | |||
| buffer_st1(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| @@ -1,6 +1,6 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| @@ -27,20 +27,23 @@ layout (constant_id = 0) const int op_type = 0; | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| @@ -57,18 +60,21 @@ layout (push_constant) uniform parameter | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| @@ -79,35 +85,24 @@ void main() | |||
| int gy = int(gl_GlobalInvocationID.y); | |||
| int gz = int(gl_GlobalInvocationID.z); | |||
| if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) | |||
| if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
| return; | |||
| #if NCNN_image_shader | |||
| afpvec4 v2; | |||
| if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0))); | |||
| } | |||
| else | |||
| { | |||
| v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0))); | |||
| } | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| // explicit broadcast | |||
| int bx = min(gx, psc(bw) - 1); | |||
| int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); | |||
| int bz = min(gz, psc(bc) - 1); | |||
| #if NCNN_image_shader | |||
| afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(gx, gy, gz)); | |||
| afpvec4 v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(bx, by, bz))); | |||
| #else | |||
| const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int bi = 0; | |||
| if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| bi = gy * psc(bw) + gx; | |||
| } | |||
| int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int bi = bz * psc(bcstep) + by * psc(bw) + bx; | |||
| // type 6 11 16 | |||
| afpvec4 v1 = buffer_ld4(a_blob_data, gi); | |||
| afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi)); | |||
| #endif | |||
| @@ -123,6 +118,7 @@ void main() | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 9) res = pow(v2, v1); | |||
| #if NCNN_image_shader | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| @@ -1,6 +1,6 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| @@ -28,20 +28,23 @@ layout (constant_id = 0) const int op_type = 0; | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| @@ -58,18 +61,21 @@ layout (push_constant) uniform parameter | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| @@ -80,85 +86,82 @@ void main() | |||
| int gy = int(gl_GlobalInvocationID.y); | |||
| int gz = int(gl_GlobalInvocationID.z); | |||
| if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc)) | |||
| if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
| return; | |||
| #if NCNN_image_shader | |||
| afpvec4 v2; | |||
| int yd = gy / psc(outh); | |||
| int yh = gy % psc(outh); | |||
| if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(gx, gy, 0))); | |||
| } | |||
| else | |||
| { | |||
| v2 = afpvec4(image3d_ld1(b_blob_3d, ivec3(0, 0, 0))); | |||
| } | |||
| // explicit broadcast | |||
| int bx = min(gx, psc(bw) - 1); | |||
| int by = min(yd, psc(bd) - 1) * psc(bh) + min(yh, psc(bh) - 1); | |||
| int bz = min(gz, psc(bc) - 1); | |||
| #if NCNN_image_shader | |||
| afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(gx, gy, gz)); | |||
| afp b = image3d_ld1(b_blob_3d, ivec3(bx, by, bz)); | |||
| #else | |||
| const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int bi = 0; | |||
| int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
| int bi = bz * psc(bcstep) + by * psc(bw) + bx; | |||
| if (psc(bdims) == 3 && psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
| { | |||
| // special type 2 | |||
| bi = gy * psc(bw) + gx; | |||
| } | |||
| // type 6 11 16 | |||
| afpvec8 v1 = buffer_ld8(a_blob_data, gi); | |||
| afpvec4 v2 = afpvec4(buffer_ld1(b_blob_data, bi)); | |||
| afp b = buffer_ld1(b_blob_data, bi); | |||
| #endif | |||
| afpvec8 v2; | |||
| v2[0] = afpvec4(b); | |||
| v2[1] = afpvec4(b); | |||
| afpvec8 res; | |||
| if (op_type == 0) | |||
| { | |||
| res[0] = v1[0] + v2; | |||
| res[1] = v1[1] + v2; | |||
| res[0] = v1[0] + v2[0]; | |||
| res[1] = v1[1] + v2[1]; | |||
| } | |||
| if (op_type == 1) | |||
| { | |||
| res[0] = v1[0] - v2; | |||
| res[1] = v1[1] - v2; | |||
| res[0] = v1[0] - v2[0]; | |||
| res[1] = v1[1] - v2[1]; | |||
| } | |||
| if (op_type == 2) | |||
| { | |||
| res[0] = v1[0] * v2; | |||
| res[1] = v1[1] * v2; | |||
| res[0] = v1[0] * v2[0]; | |||
| res[1] = v1[1] * v2[1]; | |||
| } | |||
| if (op_type == 3) | |||
| { | |||
| res[0] = v1[0] / v2; | |||
| res[1] = v1[1] / v2; | |||
| res[0] = v1[0] / v2[0]; | |||
| res[1] = v1[1] / v2[1]; | |||
| } | |||
| if (op_type == 4) | |||
| { | |||
| res[0] = max(v1[0], v2); | |||
| res[1] = max(v1[1], v2); | |||
| res[0] = max(v1[0], v2[0]); | |||
| res[1] = max(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 5) | |||
| { | |||
| res[0] = min(v1[0], v2); | |||
| res[1] = min(v1[1], v2); | |||
| res[0] = min(v1[0], v2[0]); | |||
| res[1] = min(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 6) | |||
| { | |||
| res[0] = pow(v1[0], v2); | |||
| res[1] = pow(v1[1], v2); | |||
| res[0] = pow(v1[0], v2[0]); | |||
| res[1] = pow(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 7) | |||
| { | |||
| res[0] = v2 - v1[0]; | |||
| res[1] = v2 - v1[1]; | |||
| res[0] = v2[0] - v1[0]; | |||
| res[1] = v2[1] - v1[1]; | |||
| } | |||
| if (op_type == 8) | |||
| { | |||
| res[0] = v2 / v1[0]; | |||
| res[1] = v2 / v1[1]; | |||
| res[0] = v2[0] / v1[0]; | |||
| res[1] = v2[1] / v1[1]; | |||
| } | |||
| if (op_type == 9) | |||
| { | |||
| res[0] = pow(v2[0], v1[0]); | |||
| res[1] = pow(v2[1], v1[1]); | |||
| } | |||
| #if NCNN_image_shader | |||
| @@ -1,502 +0,0 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #version 450 | |||
| #if NCNN_fp16_storage | |||
| #extension GL_EXT_shader_16bit_storage: require | |||
| #endif | |||
| #if NCNN_fp16_arithmetic | |||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16: require | |||
| #endif | |||
| // Operation selector, supplied as a specialization constant (see the op_type chain at the end of main()). | |||
| layout (constant_id = 0) const int op_type = 0; | |||
| // Shape specialization constants start right after op_type. | |||
| #define shape_constant_id_offset 1 | |||
| // Shape of operand A: dims, w, h, d, c and cstep (cstep is used as the per-channel stride when indexing buffers). | |||
| // All default to 0; presumably psc() falls back to the push constant of the same name in that case - psc is defined outside this view, confirm. | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| // Shape of operand B, same field layout as A. | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| // Shape of the output, same field layout as A. | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| // Resource bindings: 0 = operand A, 1 = operand B, 2 = output. | |||
| // NCNN_image_shader selects 3D image storage; otherwise plain storage buffers of 4-wide packed elements are used. | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| layout (binding = 1) uniform unfp sampler3D b_blob_3d; | |||
| layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; | |||
| #else | |||
| layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; | |||
| #endif | |||
| // Push-constant block carrying the same 18 shape fields, in the same order as the specialization constants above. | |||
| layout (push_constant) uniform parameter | |||
| { | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| // Entry point: computes one 4-wide packed output element of C = BinaryOp(A, B) with broadcasting. | |||
| // Each invocation handles output coordinate (gx, gy, gz). gy covers outh * outd rows, so for 4-D | |||
| // shapes it is split into yd = gy / outh and yh = gy % outh. | |||
| // The "type N" and "special type N" labels are the original case numbers; they appear to refer to | |||
| // the broadcast table in the operator documentation (not visible here - confirm). | |||
| // op_type: 0 add, 1 sub, 2 mul, 3 div, 4 max, 5 min, 6 pow, 7 reversed sub (b - a), 8 reversed div (b / a). | |||
| void main() | |||
| { | |||
|     int gx = int(gl_GlobalInvocationID.x); | |||
|     int gy = int(gl_GlobalInvocationID.y); | |||
|     int gz = int(gl_GlobalInvocationID.z); | |||
|     // Drop invocations that fall outside the output extent. | |||
|     if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
|         return; | |||
| #if NCNN_image_shader | |||
|     // Default: no broadcast, both operands are read at the output coordinate. | |||
|     int ax = gx; | |||
|     int ay = gy; | |||
|     int az = gz; | |||
|     int bx = gx; | |||
|     int by = gy; | |||
|     int bz = gz; | |||
|     if (psc(adims) == 4) | |||
|     { | |||
|         int yd = gy / psc(outh); | |||
|         int yh = gy % psc(outh); | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 28 | |||
|             bx = yh; | |||
|             by = yd; | |||
|             bz = gz; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 27 | |||
|             bx = yd; | |||
|             by = gz; | |||
|             bz = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 25 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 26 | |||
|                 bx = gz; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 3) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             int yd = gy / psc(outh); | |||
|             int yh = gy % psc(outh); | |||
|             // type 23 | |||
|             ax = yh; | |||
|             ay = yd; | |||
|             az = gz; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // The special-type checks are independent ifs, not an else-if chain: | |||
|             // each one zeroes only the coordinate(s) of the axis it broadcasts. | |||
|             if (psc(bw) == 1 && psc(bh) == 1) | |||
|             { | |||
|                 // special type 1 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
|             { | |||
|                 // special type 2 | |||
|                 bz = 0; | |||
|             } | |||
|             if (psc(aw) == 1 && psc(ah) == 1) | |||
|             { | |||
|                 // special type 3 | |||
|                 ax = 0; | |||
|                 ay = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
|             { | |||
|                 // special type 4 | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 5 | |||
|                 bx = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 6 | |||
|                 by = 0; | |||
|             } | |||
|             if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 7 | |||
|                 ax = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 8 | |||
|                 ay = 0; | |||
|             } | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 18 | |||
|             bx = gy; | |||
|             by = gz; | |||
|             bz = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 16 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 17 | |||
|                 bx = gz; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 2) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // Only the depth index is needed here; the row-within-depth index is not used. | |||
|             int yd = gy / psc(outh); | |||
|             // type 22 | |||
|             ax = yd; | |||
|             ay = gz; | |||
|             az = 0; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 14 | |||
|             ax = gy; | |||
|             ay = gz; | |||
|             az = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 11 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 12 | |||
|                 bx = gy; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 1) | |||
|     { | |||
|         if (psc(aw) == 1) | |||
|         { | |||
|             // type 2 3 4 20 | |||
|             ax = 0; | |||
|             ay = 0; | |||
|             az = 0; | |||
|         } | |||
|         else | |||
|         { | |||
|             if (psc(bdims) == 4) | |||
|             { | |||
|                 // type 21 | |||
|                 ax = gz; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 3) | |||
|             { | |||
|                 // type 9 | |||
|                 ax = gz; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 2) | |||
|             { | |||
|                 // type 8 | |||
|                 ax = gy; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 1) | |||
|             { | |||
|                 if (psc(bw) == 1) | |||
|                 { | |||
|                     // type 6 | |||
|                     bx = 0; | |||
|                     by = 0; | |||
|                     bz = 0; | |||
|                 } | |||
|             } | |||
|         } | |||
|     } | |||
|     afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(ax, ay, az)); | |||
|     afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(bx, by, bz)); | |||
| #else | |||
|     // Linear index of the output element. | |||
|     const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
|     // Default both operands to the no-broadcast index so that a shape combination matched by none of | |||
|     // the branches below never loads through an uninitialized index (mirrors the image path defaults). | |||
|     int ai = gi; | |||
|     int bi = gi; | |||
|     if (psc(adims) == 4) | |||
|     { | |||
|         int yd = gy / psc(outh); | |||
|         int yh = gy % psc(outh); | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 28 | |||
|             ai = gi; | |||
|             bi = gz * psc(bcstep) + yd * psc(bw) + yh; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 27 | |||
|             ai = gi; | |||
|             bi = gz * psc(bw) + yd; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 25 | |||
|                 ai = gi; | |||
|                 bi = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 26 | |||
|                 ai = gi; | |||
|                 bi = gz; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 3) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             int yd = gy / psc(outh); | |||
|             int yh = gy % psc(outh); | |||
|             // type 23 | |||
|             ai = gz * psc(acstep) + yd * psc(aw) + yh; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // Independent ifs: when several conditions hold, the last matching one wins. | |||
|             if (psc(bw) == 1 && psc(bh) == 1) | |||
|             { | |||
|                 // special type 1 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep); | |||
|             } | |||
|             if (psc(aw) == 1 && psc(ah) == 1) | |||
|             { | |||
|                 // special type 3 | |||
|                 ai = gz * psc(acstep); | |||
|                 bi = gi; | |||
|             } | |||
|             if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 5 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep) + gy; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 6 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep) + gx; | |||
|             } | |||
|             if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 7 | |||
|                 ai = gz * psc(acstep) + gy; | |||
|                 bi = gi; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 8 | |||
|                 ai = gz * psc(acstep) + gx; | |||
|                 bi = gi; | |||
|             } | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 18 | |||
|             ai = gi; | |||
|             bi = gz * psc(bw) + gy; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             // type 17 | |||
|             ai = gi; | |||
|             bi = gz; | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 2) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // Only the depth index is needed here; the row-within-depth index is not used. | |||
|             int yd = gy / psc(outh); | |||
|             // type 22 | |||
|             ai = gz * psc(aw) + yd; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 14 | |||
|             ai = gz * psc(aw) + gy; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             // type 12 | |||
|             ai = gi; | |||
|             bi = gy; | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 1) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // type 21 | |||
|             ai = gz; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 9 | |||
|             ai = gz; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 8 | |||
|             ai = gy; | |||
|             bi = gi; | |||
|         } | |||
|     } | |||
|     afpvec4 v1 = buffer_ld4(a_blob_data, ai); | |||
|     afpvec4 v2 = buffer_ld4(b_blob_data, bi); | |||
| #endif | |||
|     // Apply the selected operation; op_type values outside 0..8 leave res unassigned. | |||
|     afpvec4 res; | |||
|     if (op_type == 0) res = v1 + v2; | |||
|     if (op_type == 1) res = v1 - v2; | |||
|     if (op_type == 2) res = v1 * v2; | |||
|     if (op_type == 3) res = v1 / v2; | |||
|     if (op_type == 4) res = max(v1, v2); | |||
|     if (op_type == 5) res = min(v1, v2); | |||
|     if (op_type == 6) res = pow(v1, v2); | |||
|     if (op_type == 7) res = v2 - v1; | |||
|     if (op_type == 8) res = v2 / v1; | |||
| #if NCNN_image_shader | |||
|     image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
|     buffer_st4(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| @@ -1,539 +0,0 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #version 450 | |||
| #if NCNN_fp16_storage | |||
| #extension GL_EXT_shader_16bit_storage: require | |||
| struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; | |||
| #endif | |||
| #if NCNN_fp16_arithmetic | |||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16: require | |||
| #endif | |||
| // Operation selector, supplied as a specialization constant (see the op_type chain at the end of main()). | |||
| layout (constant_id = 0) const int op_type = 0; | |||
| // Shape specialization constants start right after op_type. | |||
| #define shape_constant_id_offset 1 | |||
| // Shape of operand A: dims, w, h, d, c and cstep (cstep is used as the per-channel stride when indexing buffers). | |||
| // All default to 0; presumably psc() falls back to the push constant of the same name in that case - psc is defined outside this view, confirm. | |||
| layout (constant_id = shape_constant_id_offset + 0) const int adims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 1) const int aw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 2) const int ah = 0; | |||
| layout (constant_id = shape_constant_id_offset + 3) const int ad = 0; | |||
| layout (constant_id = shape_constant_id_offset + 4) const int ac = 0; | |||
| layout (constant_id = shape_constant_id_offset + 5) const int acstep = 0; | |||
| // Shape of operand B, same field layout as A. | |||
| layout (constant_id = shape_constant_id_offset + 6) const int bdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 7) const int bw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 8) const int bh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 9) const int bd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 10) const int bc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 11) const int bcstep = 0; | |||
| // Shape of the output, same field layout as A. | |||
| layout (constant_id = shape_constant_id_offset + 12) const int outdims = 0; | |||
| layout (constant_id = shape_constant_id_offset + 13) const int outw = 0; | |||
| layout (constant_id = shape_constant_id_offset + 14) const int outh = 0; | |||
| layout (constant_id = shape_constant_id_offset + 15) const int outd = 0; | |||
| layout (constant_id = shape_constant_id_offset + 16) const int outc = 0; | |||
| layout (constant_id = shape_constant_id_offset + 17) const int outcstep = 0; | |||
| // Resource bindings: 0 = operand A, 1 = operand B, 2 = output. | |||
| // NCNN_image_shader selects 3D image storage; otherwise plain storage buffers of 8-wide packed elements are used. | |||
| #if NCNN_image_shader | |||
| layout (binding = 0) uniform unfp sampler3D a_blob_3d; | |||
| layout (binding = 1) uniform unfp sampler3D b_blob_3d; | |||
| layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d; | |||
| #else | |||
| layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; }; | |||
| layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; }; | |||
| layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; | |||
| #endif | |||
| // Push-constant block carrying the same 18 shape fields, in the same order as the specialization constants above. | |||
| layout (push_constant) uniform parameter | |||
| { | |||
| int adims; | |||
| int aw; | |||
| int ah; | |||
| int ad; | |||
| int ac; | |||
| int acstep; | |||
| int bdims; | |||
| int bw; | |||
| int bh; | |||
| int bd; | |||
| int bc; | |||
| int bcstep; | |||
| int outdims; | |||
| int outw; | |||
| int outh; | |||
| int outd; | |||
| int outc; | |||
| int outcstep; | |||
| } p; | |||
| // Entry point: computes one 8-wide packed output element of C = BinaryOp(A, B) with broadcasting. | |||
| // Each invocation handles output coordinate (gx, gy, gz). gy covers outh * outd rows, so for 4-D | |||
| // shapes it is split into yd = gy / outh and yh = gy % outh. | |||
| // The "type N" and "special type N" labels are the original case numbers; they appear to refer to | |||
| // the broadcast table in the operator documentation (not visible here - confirm). | |||
| // op_type: 0 add, 1 sub, 2 mul, 3 div, 4 max, 5 min, 6 pow, 7 reversed sub (b - a), 8 reversed div (b / a). | |||
| void main() | |||
| { | |||
|     int gx = int(gl_GlobalInvocationID.x); | |||
|     int gy = int(gl_GlobalInvocationID.y); | |||
|     int gz = int(gl_GlobalInvocationID.z); | |||
|     // Drop invocations that fall outside the output extent. | |||
|     if (gx >= psc(outw) || gy >= psc(outh) * psc(outd) || gz >= psc(outc)) | |||
|         return; | |||
| #if NCNN_image_shader | |||
|     // Default: no broadcast, both operands are read at the output coordinate. | |||
|     int ax = gx; | |||
|     int ay = gy; | |||
|     int az = gz; | |||
|     int bx = gx; | |||
|     int by = gy; | |||
|     int bz = gz; | |||
|     if (psc(adims) == 4) | |||
|     { | |||
|         int yd = gy / psc(outh); | |||
|         int yh = gy % psc(outh); | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 28 | |||
|             bx = yh; | |||
|             by = yd; | |||
|             bz = gz; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 27 | |||
|             bx = yd; | |||
|             by = gz; | |||
|             bz = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 25 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 26 | |||
|                 bx = gz; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 3) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             int yd = gy / psc(outh); | |||
|             int yh = gy % psc(outh); | |||
|             // type 23 | |||
|             ax = yh; | |||
|             ay = yd; | |||
|             az = gz; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // The special-type checks are independent ifs, not an else-if chain: | |||
|             // each one zeroes only the coordinate(s) of the axis it broadcasts. | |||
|             if (psc(bw) == 1 && psc(bh) == 1) | |||
|             { | |||
|                 // special type 1 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1) | |||
|             { | |||
|                 // special type 2 | |||
|                 bz = 0; | |||
|             } | |||
|             if (psc(aw) == 1 && psc(ah) == 1) | |||
|             { | |||
|                 // special type 3 | |||
|                 ax = 0; | |||
|                 ay = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1) | |||
|             { | |||
|                 // special type 4 | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 5 | |||
|                 bx = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 6 | |||
|                 by = 0; | |||
|             } | |||
|             if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 7 | |||
|                 ax = 0; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 8 | |||
|                 ay = 0; | |||
|             } | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 18 | |||
|             bx = gy; | |||
|             by = gz; | |||
|             bz = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 16 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 17 | |||
|                 bx = gz; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 2) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // Only the depth index is needed here; the row-within-depth index is not used. | |||
|             int yd = gy / psc(outh); | |||
|             // type 22 | |||
|             ax = yd; | |||
|             ay = gz; | |||
|             az = 0; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 14 | |||
|             ax = gy; | |||
|             ay = gz; | |||
|             az = 0; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 11 | |||
|                 bx = 0; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 12 | |||
|                 bx = gy; | |||
|                 by = 0; | |||
|                 bz = 0; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 1) | |||
|     { | |||
|         if (psc(aw) == 1) | |||
|         { | |||
|             // type 2 3 4 20 | |||
|             ax = 0; | |||
|             ay = 0; | |||
|             az = 0; | |||
|         } | |||
|         else | |||
|         { | |||
|             if (psc(bdims) == 4) | |||
|             { | |||
|                 // type 21 | |||
|                 ax = gz; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 3) | |||
|             { | |||
|                 // type 9 | |||
|                 ax = gz; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 2) | |||
|             { | |||
|                 // type 8 | |||
|                 ax = gy; | |||
|                 ay = 0; | |||
|                 az = 0; | |||
|             } | |||
|             if (psc(bdims) == 1) | |||
|             { | |||
|                 if (psc(bw) == 1) | |||
|                 { | |||
|                     // type 6 | |||
|                     bx = 0; | |||
|                     by = 0; | |||
|                     bz = 0; | |||
|                 } | |||
|             } | |||
|         } | |||
|     } | |||
|     afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(ax, ay, az)); | |||
|     afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(bx, by, bz)); | |||
| #else | |||
|     // Linear index of the output element. | |||
|     const int gi = gz * psc(outcstep) + gy * psc(outw) + gx; | |||
|     // Default both operands to the no-broadcast index so that a shape combination matched by none of | |||
|     // the branches below never loads through an uninitialized index (mirrors the image path defaults). | |||
|     int ai = gi; | |||
|     int bi = gi; | |||
|     if (psc(adims) == 4) | |||
|     { | |||
|         int yd = gy / psc(outh); | |||
|         int yh = gy % psc(outh); | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 28 | |||
|             ai = gi; | |||
|             bi = gz * psc(bcstep) + yd * psc(bw) + yh; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 27 | |||
|             ai = gi; | |||
|             bi = gz * psc(bw) + yd; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             if (psc(bw) == 1) | |||
|             { | |||
|                 // type 25 | |||
|                 ai = gi; | |||
|                 bi = 0; | |||
|             } | |||
|             else | |||
|             { | |||
|                 // type 26 | |||
|                 ai = gi; | |||
|                 bi = gz; | |||
|             } | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 3) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             int yd = gy / psc(outh); | |||
|             int yh = gy % psc(outh); | |||
|             // type 23 | |||
|             ai = gz * psc(acstep) + yd * psc(aw) + yh; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // Independent ifs: when several conditions hold, the last matching one wins. | |||
|             if (psc(bw) == 1 && psc(bh) == 1) | |||
|             { | |||
|                 // special type 1 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep); | |||
|             } | |||
|             if (psc(aw) == 1 && psc(ah) == 1) | |||
|             { | |||
|                 // special type 3 | |||
|                 ai = gz * psc(acstep); | |||
|                 bi = gi; | |||
|             } | |||
|             if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 5 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep) + gy; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 6 | |||
|                 ai = gi; | |||
|                 bi = gz * psc(bcstep) + gx; | |||
|             } | |||
|             if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 7 | |||
|                 ai = gz * psc(acstep) + gy; | |||
|                 bi = gi; | |||
|             } | |||
|             if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac)) | |||
|             { | |||
|                 // special type 8 | |||
|                 ai = gz * psc(acstep) + gx; | |||
|                 bi = gi; | |||
|             } | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 18 | |||
|             ai = gi; | |||
|             bi = gz * psc(bw) + gy; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             // type 17 | |||
|             ai = gi; | |||
|             bi = gz; | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 2) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // Only the depth index is needed here; the row-within-depth index is not used. | |||
|             int yd = gy / psc(outh); | |||
|             // type 22 | |||
|             ai = gz * psc(aw) + yd; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 14 | |||
|             ai = gz * psc(aw) + gy; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 1) | |||
|         { | |||
|             // type 12 | |||
|             ai = gi; | |||
|             bi = gy; | |||
|         } | |||
|     } | |||
|     else if (psc(adims) == 1) | |||
|     { | |||
|         if (psc(bdims) == 4) | |||
|         { | |||
|             // type 21 | |||
|             ai = gz; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 3) | |||
|         { | |||
|             // type 9 | |||
|             ai = gz; | |||
|             bi = gi; | |||
|         } | |||
|         if (psc(bdims) == 2) | |||
|         { | |||
|             // type 8 | |||
|             ai = gy; | |||
|             bi = gi; | |||
|         } | |||
|     } | |||
|     afpvec8 v1 = buffer_ld8(a_blob_data, ai); | |||
|     afpvec8 v2 = buffer_ld8(b_blob_data, bi); | |||
| #endif | |||
|     // Apply the selected operation to both 4-wide halves; op_type values outside 0..8 leave res unassigned. | |||
|     afpvec8 res; | |||
|     if (op_type == 0) | |||
|     { | |||
|         res[0] = v1[0] + v2[0]; | |||
|         res[1] = v1[1] + v2[1]; | |||
|     } | |||
|     if (op_type == 1) | |||
|     { | |||
|         res[0] = v1[0] - v2[0]; | |||
|         res[1] = v1[1] - v2[1]; | |||
|     } | |||
|     if (op_type == 2) | |||
|     { | |||
|         res[0] = v1[0] * v2[0]; | |||
|         res[1] = v1[1] * v2[1]; | |||
|     } | |||
|     if (op_type == 3) | |||
|     { | |||
|         res[0] = v1[0] / v2[0]; | |||
|         res[1] = v1[1] / v2[1]; | |||
|     } | |||
|     if (op_type == 4) | |||
|     { | |||
|         res[0] = max(v1[0], v2[0]); | |||
|         res[1] = max(v1[1], v2[1]); | |||
|     } | |||
|     if (op_type == 5) | |||
|     { | |||
|         res[0] = min(v1[0], v2[0]); | |||
|         res[1] = min(v1[1], v2[1]); | |||
|     } | |||
|     if (op_type == 6) | |||
|     { | |||
|         res[0] = pow(v1[0], v2[0]); | |||
|         res[1] = pow(v1[1], v2[1]); | |||
|     } | |||
|     if (op_type == 7) | |||
|     { | |||
|         res[0] = v2[0] - v1[0]; | |||
|         res[1] = v2[1] - v1[1]; | |||
|     } | |||
|     if (op_type == 8) | |||
|     { | |||
|         res[0] = v2[0] / v1[0]; | |||
|         res[1] = v2[1] / v1[1]; | |||
|     } | |||
| #if NCNN_image_shader | |||
|     image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
|     buffer_st8(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| @@ -92,52 +92,39 @@ void main() | |||
| afpvec4 v1 = buffer_ld4(a_blob_data, gi); | |||
| #endif | |||
| afpvec4 res; | |||
| afpvec4 v2; | |||
| if (with_scalar == 1) | |||
| { | |||
| // type 5 10 15 | |||
| afp b = afp(const_b); | |||
| if (op_type == 0) res = v1 + b; | |||
| if (op_type == 1) res = v1 - b; | |||
| if (op_type == 2) res = v1 * b; | |||
| if (op_type == 3) res = v1 / b; | |||
| if (op_type == 4) res = max(v1, b); | |||
| if (op_type == 5) res = min(v1, b); | |||
| if (op_type == 6) res = pow(v1, afpvec4(b)); | |||
| if (op_type == 7) res = b - v1; | |||
| if (op_type == 8) res = b / v1; | |||
| #if NCNN_image_shader | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st4(a_blob_data, gi, res); | |||
| #endif | |||
| // type 0 1 2 3 | |||
| v2 = afpvec4(const_b); | |||
| } | |||
| else | |||
| { | |||
| // type 7 13 19 | |||
| // type 4 5 6 7 | |||
| #if NCNN_image_shader | |||
| afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); | |||
| v2 = image3d_ld4(b_blob_3d, ivec3(gx, gy, gz)); | |||
| #else | |||
| afpvec4 v2 = buffer_ld4(b_blob_data, gi); | |||
| v2 = buffer_ld4(b_blob_data, gi); | |||
| #endif | |||
| } | |||
| afpvec4 res; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| if (op_type == 2) res = v1 * v2; | |||
| if (op_type == 3) res = v1 / v2; | |||
| if (op_type == 4) res = max(v1, v2); | |||
| if (op_type == 5) res = min(v1, v2); | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 0) res = v1 + v2; | |||
| if (op_type == 1) res = v1 - v2; | |||
| if (op_type == 2) res = v1 * v2; | |||
| if (op_type == 3) res = v1 / v2; | |||
| if (op_type == 4) res = max(v1, v2); | |||
| if (op_type == 5) res = min(v1, v2); | |||
| if (op_type == 6) res = pow(v1, v2); | |||
| if (op_type == 7) res = v2 - v1; | |||
| if (op_type == 8) res = v2 / v1; | |||
| if (op_type == 9) res = pow(v2, v1); | |||
| #if NCNN_image_shader | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st4(top_blob_data, gi, res); | |||
| buffer_st4(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| } | |||
| @@ -93,124 +93,80 @@ void main() | |||
| afpvec8 v1 = buffer_ld8(a_blob_data, gi); | |||
| #endif | |||
| afpvec8 res; | |||
| afpvec8 v2; | |||
| if (with_scalar == 1) | |||
| { | |||
| // type 5 10 15 | |||
| afp b = afp(const_b); | |||
| if (op_type == 0) | |||
| { | |||
| res[0] = v1[0] + b; | |||
| res[1] = v1[1] + b; | |||
| } | |||
| if (op_type == 1) | |||
| { | |||
| res[0] = v1[0] - b; | |||
| res[1] = v1[1] - b; | |||
| } | |||
| if (op_type == 2) | |||
| { | |||
| res[0] = v1[0] * b; | |||
| res[1] = v1[1] * b; | |||
| } | |||
| if (op_type == 3) | |||
| { | |||
| res[0] = v1[0] / b; | |||
| res[1] = v1[1] / b; | |||
| } | |||
| if (op_type == 4) | |||
| { | |||
| res[0] = max(v1[0], b); | |||
| res[1] = max(v1[1], b); | |||
| } | |||
| if (op_type == 5) | |||
| { | |||
| res[0] = min(v1[0], b); | |||
| res[1] = min(v1[1], b); | |||
| } | |||
| if (op_type == 6) | |||
| { | |||
| res[0] = pow(v1[0], afpvec4(b)); | |||
| res[1] = pow(v1[1], afpvec4(b)); | |||
| } | |||
| if (op_type == 7) | |||
| { | |||
| res[0] = b - v1[0]; | |||
| res[1] = b - v1[1]; | |||
| } | |||
| if (op_type == 8) | |||
| { | |||
| res[0] = b / v1[0]; | |||
| res[1] = b / v1[1]; | |||
| } | |||
| #if NCNN_image_shader | |||
| image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st8(a_blob_data, gi, res); | |||
| #endif | |||
| // type 0 1 2 3 | |||
| v2[0] = afpvec4(const_b); | |||
| v2[1] = afpvec4(const_b); | |||
| } | |||
| else | |||
| { | |||
| // type 7 13 19 | |||
| // type 4 5 6 7 | |||
| #if NCNN_image_shader | |||
| afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); | |||
| v2 = image3d_ld8(b_blob_3d, ivec3(gx, gy, gz)); | |||
| #else | |||
| afpvec8 v2 = buffer_ld8(b_blob_data, gi); | |||
| v2 = buffer_ld8(b_blob_data, gi); | |||
| #endif | |||
| } | |||
| if (op_type == 0) | |||
| { | |||
| res[0] = v1[0] + v2[0]; | |||
| res[1] = v1[1] + v2[1]; | |||
| } | |||
| if (op_type == 1) | |||
| { | |||
| res[0] = v1[0] - v2[0]; | |||
| res[1] = v1[1] - v2[1]; | |||
| } | |||
| if (op_type == 2) | |||
| { | |||
| res[0] = v1[0] * v2[0]; | |||
| res[1] = v1[1] * v2[1]; | |||
| } | |||
| if (op_type == 3) | |||
| { | |||
| res[0] = v1[0] / v2[0]; | |||
| res[1] = v1[1] / v2[1]; | |||
| } | |||
| if (op_type == 4) | |||
| { | |||
| res[0] = max(v1[0], v2[0]); | |||
| res[1] = max(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 5) | |||
| { | |||
| res[0] = min(v1[0], v2[0]); | |||
| res[1] = min(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 6) | |||
| { | |||
| res[0] = pow(v1[0], v2[0]); | |||
| res[1] = pow(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 7) | |||
| { | |||
| res[0] = v2[0] - v1[0]; | |||
| res[1] = v2[1] - v1[1]; | |||
| } | |||
| if (op_type == 8) | |||
| { | |||
| res[0] = v2[0] / v1[0]; | |||
| res[1] = v2[1] / v1[1]; | |||
| } | |||
| afpvec8 res; | |||
| if (op_type == 0) | |||
| { | |||
| res[0] = v1[0] + v2[0]; | |||
| res[1] = v1[1] + v2[1]; | |||
| } | |||
| if (op_type == 1) | |||
| { | |||
| res[0] = v1[0] - v2[0]; | |||
| res[1] = v1[1] - v2[1]; | |||
| } | |||
| if (op_type == 2) | |||
| { | |||
| res[0] = v1[0] * v2[0]; | |||
| res[1] = v1[1] * v2[1]; | |||
| } | |||
| if (op_type == 3) | |||
| { | |||
| res[0] = v1[0] / v2[0]; | |||
| res[1] = v1[1] / v2[1]; | |||
| } | |||
| if (op_type == 4) | |||
| { | |||
| res[0] = max(v1[0], v2[0]); | |||
| res[1] = max(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 5) | |||
| { | |||
| res[0] = min(v1[0], v2[0]); | |||
| res[1] = min(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 6) | |||
| { | |||
| res[0] = pow(v1[0], v2[0]); | |||
| res[1] = pow(v1[1], v2[1]); | |||
| } | |||
| if (op_type == 7) | |||
| { | |||
| res[0] = v2[0] - v1[0]; | |||
| res[1] = v2[1] - v1[1]; | |||
| } | |||
| if (op_type == 8) | |||
| { | |||
| res[0] = v2[0] / v1[0]; | |||
| res[1] = v2[1] / v1[1]; | |||
| } | |||
| if (op_type == 9) | |||
| { | |||
| res[0] = pow(v2[0], v1[0]); | |||
| res[1] = pow(v2[1], v1[1]); | |||
| } | |||
| #if NCNN_image_shader | |||
| image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| #else | |||
| buffer_st8(top_blob_data, gi, res); | |||
| buffer_st8(top_blob_data, gi, res); | |||
| #endif | |||
| } | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| #include "layer/binaryop.h" | |||
| #include "testutil.h" | |||
| #define OP_TYPE_MAX 9 | |||
| #define OP_TYPE_MAX 10 | |||
| static int op_type = 0; | |||
| @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| ncnn::Mat b = _b; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for pow/rpow | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.001f, 2.f); | |||
| Randomize(b, 0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| Randomize(b, 0.1f, 10.f); | |||
| } | |||
| @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| Randomize(a, 0.001f, 2.f); | |||
| b = RandomFloat(0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| } | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, op_type); | |||
| @@ -82,300 +92,274 @@ static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| return ret; | |||
| } | |||
| // https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting | |||
| static int test_binaryop_1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), 1.f); | |||
| } | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(1)) | |||
| || test_binaryop(RandomMat(1), RandomMat(4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 16)); | |||
| } | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); | |||
| } | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), 1.f) | |||
| || test_binaryop(RandomMat(4), 1.f) | |||
| || test_binaryop(RandomMat(16), 1.f); | |||
| } | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(16), RandomMat(1)); | |||
| } | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| ncnn::Mat b[] = { | |||
| RandomMat(1), | |||
| RandomMat(1, 1), | |||
| RandomMat(1, 1, 1), | |||
| RandomMat(1, 1, 1, 1) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) | |||
| { | |||
| int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| int ret = test_binaryop(a[i], 0.2f); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_9() | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b; | |||
| b.create_like(a[i]); | |||
| Randomize(b); | |||
| static int test_binaryop_10() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), 1.f) | |||
| || test_binaryop(RandomMat(11, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 16), 1.f); | |||
| } | |||
| int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_11() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(1)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_12() | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_13() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].h); | |||
| ncnn::Mat b1(1, a[i].h); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_14() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_15() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 16), 1.f); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_16() | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_17() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_18() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].h, a[i].c); | |||
| ncnn::Mat b3(1, a[i].h, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_19() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_20() | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_21() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].d, a[i].c); | |||
| ncnn::Mat b3(1, 1, a[i].d, a[i].c); | |||
| ncnn::Mat b4(a[i].h, a[i].d, a[i].c); | |||
| ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| Randomize(b4); | |||
| Randomize(b5); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) | |||
| || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) | |||
| || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_22() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_23() | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_24() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1); | |||
| Randomize(b0); | |||
| static int test_binaryop_25() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_26() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_27() | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_28() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_29() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s2() | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1, 1); | |||
| ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s5() | |||
| static int test_binaryop_9() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, a[i].c); | |||
| Randomize(b0); | |||
| static int test_binaryop_s7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| int main() | |||
| @@ -393,35 +377,7 @@ int main() | |||
| || test_binaryop_6() | |||
| || test_binaryop_7() | |||
| || test_binaryop_8() | |||
| || test_binaryop_9() | |||
| || test_binaryop_10() | |||
| || test_binaryop_11() | |||
| || test_binaryop_12() | |||
| || test_binaryop_13() | |||
| || test_binaryop_14() | |||
| || test_binaryop_15() | |||
| || test_binaryop_16() | |||
| || test_binaryop_17() | |||
| || test_binaryop_18() | |||
| || test_binaryop_19() | |||
| || test_binaryop_20() | |||
| || test_binaryop_21() | |||
| || test_binaryop_22() | |||
| || test_binaryop_23() | |||
| || test_binaryop_24() | |||
| || test_binaryop_25() | |||
| || test_binaryop_26() | |||
| || test_binaryop_27() | |||
| || test_binaryop_28() | |||
| || test_binaryop_29() | |||
| || test_binaryop_s1() | |||
| || test_binaryop_s2() | |||
| || test_binaryop_s3() | |||
| || test_binaryop_s4() | |||
| || test_binaryop_s5() | |||
| || test_binaryop_s6() | |||
| || test_binaryop_s7() | |||
| || test_binaryop_s8(); | |||
| || test_binaryop_9(); | |||
| if (ret != 0) | |||
| return ret; | |||
| @@ -15,7 +15,7 @@ | |||
| #include "layer/binaryop.h" | |||
| #include "testutil.h" | |||
| #define OP_TYPE_MAX 9 | |||
| #define OP_TYPE_MAX 10 | |||
| static int op_type = 0; | |||
| @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| ncnn::Mat b = _b; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for pow/rpow | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.001f, 2.f); | |||
| Randomize(b, 0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| Randomize(b, 0.1f, 10.f); | |||
| } | |||
| @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| Randomize(a, 0.001f, 2.f); | |||
| b = RandomFloat(0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| } | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, op_type); | |||
| @@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| static int test_binaryop_1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), 1.f); | |||
| } | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(1)) | |||
| || test_binaryop(RandomMat(1), RandomMat(4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 16)); | |||
| } | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); | |||
| } | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), 1.f) | |||
| || test_binaryop(RandomMat(4), 1.f) | |||
| || test_binaryop(RandomMat(16), 1.f); | |||
| } | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(16), RandomMat(1)); | |||
| } | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| ncnn::Mat b[] = { | |||
| RandomMat(1), | |||
| RandomMat(1, 1), | |||
| RandomMat(1, 1, 1), | |||
| RandomMat(1, 1, 1, 1) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) | |||
| { | |||
| int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| int ret = test_binaryop(a[i], 0.2f); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_9() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_10() | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), 1.f) | |||
| || test_binaryop(RandomMat(11, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 16), 1.f); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b; | |||
| b.create_like(a[i]); | |||
| Randomize(b); | |||
| static int test_binaryop_11() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(1)); | |||
| } | |||
| int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_12() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_13() | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_14() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].h); | |||
| ncnn::Mat b1(1, a[i].h); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_15() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 16), 1.f); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_16() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_17() | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_18() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].h, a[i].c); | |||
| ncnn::Mat b3(1, a[i].h, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_19() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_20() | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_21() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].d, a[i].c); | |||
| ncnn::Mat b3(1, 1, a[i].d, a[i].c); | |||
| ncnn::Mat b4(a[i].h, a[i].d, a[i].c); | |||
| ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| Randomize(b4); | |||
| Randomize(b5); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) | |||
| || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) | |||
| || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_22() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_23() | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_24() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1); | |||
| Randomize(b0); | |||
| static int test_binaryop_25() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_26() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_27() | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_28() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_29() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s2() | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1, 1); | |||
| ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s5() | |||
| static int test_binaryop_9() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, a[i].c); | |||
| Randomize(b0); | |||
| static int test_binaryop_s7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| int main() | |||
| @@ -393,35 +379,7 @@ int main() | |||
| || test_binaryop_6() | |||
| || test_binaryop_7() | |||
| || test_binaryop_8() | |||
| || test_binaryop_9() | |||
| || test_binaryop_10() | |||
| || test_binaryop_11() | |||
| || test_binaryop_12() | |||
| || test_binaryop_13() | |||
| || test_binaryop_14() | |||
| || test_binaryop_15() | |||
| || test_binaryop_16() | |||
| || test_binaryop_17() | |||
| || test_binaryop_18() | |||
| || test_binaryop_19() | |||
| || test_binaryop_20() | |||
| || test_binaryop_21() | |||
| || test_binaryop_22() | |||
| || test_binaryop_23() | |||
| || test_binaryop_24() | |||
| || test_binaryop_25() | |||
| || test_binaryop_26() | |||
| || test_binaryop_27() | |||
| || test_binaryop_28() | |||
| || test_binaryop_29() | |||
| || test_binaryop_s1() | |||
| || test_binaryop_s2() | |||
| || test_binaryop_s3() | |||
| || test_binaryop_s4() | |||
| || test_binaryop_s5() | |||
| || test_binaryop_s6() | |||
| || test_binaryop_s7() | |||
| || test_binaryop_s8(); | |||
| || test_binaryop_9(); | |||
| if (ret != 0) | |||
| return ret; | |||
| @@ -15,7 +15,7 @@ | |||
| #include "layer/binaryop.h" | |||
| #include "testutil.h" | |||
| #define OP_TYPE_MAX 9 | |||
| #define OP_TYPE_MAX 10 | |||
| static int op_type = 0; | |||
| @@ -23,15 +23,19 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| ncnn::Mat b = _b; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for pow/rpow | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.001f, 2.f); | |||
| Randomize(b, 0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| b = b.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| Randomize(b, 0.1f, 10.f); | |||
| } | |||
| @@ -59,12 +63,18 @@ static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b) | |||
| static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| { | |||
| ncnn::Mat a = _a; | |||
| if (op_type == 6) | |||
| if (op_type == 6 || op_type == 9) | |||
| { | |||
| // value must be positive for pow | |||
| // value must be positive for pow/rpow | |||
| Randomize(a, 0.001f, 2.f); | |||
| b = RandomFloat(0.001f, 2.f); | |||
| } | |||
| if (op_type == 3 || op_type == 8) | |||
| { | |||
| // value must be positive for div/rdiv | |||
| a = a.clone(); | |||
| Randomize(a, 0.1f, 10.f); | |||
| } | |||
| ncnn::ParamDict pd; | |||
| pd.set(0, op_type); | |||
| @@ -86,296 +96,272 @@ static int test_binaryop(const ncnn::Mat& _a, float b) | |||
| static int test_binaryop_1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), 1.f); | |||
| } | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(1)) | |||
| || test_binaryop(RandomMat(1), RandomMat(4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 16)); | |||
| } | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 6, 16)); | |||
| } | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), 1.f) | |||
| || test_binaryop(RandomMat(4), 1.f) | |||
| || test_binaryop(RandomMat(16), 1.f); | |||
| } | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(16), RandomMat(1)); | |||
| } | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(16)); | |||
| } | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| ncnn::Mat b[] = { | |||
| RandomMat(1), | |||
| RandomMat(1, 1), | |||
| RandomMat(1, 1, 1), | |||
| RandomMat(1, 1, 1, 1) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++) | |||
| { | |||
| int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| int ret = test_binaryop(a[i], 0.2f); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_9() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_10() | |||
| static int test_binaryop_2() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), 1.f) | |||
| || test_binaryop(RandomMat(11, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 16), 1.f); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(31), | |||
| RandomMat(28), | |||
| RandomMat(24), | |||
| RandomMat(32), | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32), | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32), | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b; | |||
| b.create_like(a[i]); | |||
| Randomize(b); | |||
| static int test_binaryop_11() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(1)); | |||
| } | |||
| int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_12() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_13() | |||
| static int test_binaryop_3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3), RandomMat(11, 3)) | |||
| || test_binaryop(RandomMat(11, 4), RandomMat(11, 4)) | |||
| || test_binaryop(RandomMat(11, 16), RandomMat(11, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_14() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].h); | |||
| ncnn::Mat b1(1, a[i].h); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_15() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 6, 16), 1.f); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_16() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_17() | |||
| static int test_binaryop_4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_18() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].h, a[i].c); | |||
| ncnn::Mat b3(1, a[i].h, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_19() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_20() | |||
| static int test_binaryop_5() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(1), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_21() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].c); | |||
| ncnn::Mat b1(1, 1, 1, a[i].c); | |||
| ncnn::Mat b2(a[i].d, a[i].c); | |||
| ncnn::Mat b3(1, 1, a[i].d, a[i].c); | |||
| ncnn::Mat b4(a[i].h, a[i].d, a[i].c); | |||
| ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| Randomize(b3); | |||
| Randomize(b4); | |||
| Randomize(b5); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]) | |||
| || test_binaryop(a[i], b3) || test_binaryop(b3, a[i]) | |||
| || test_binaryop(a[i], b4) || test_binaryop(b4, a[i]) | |||
| || test_binaryop(a[i], b5) || test_binaryop(b5, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_22() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(4, 16), RandomMat(11, 3, 4, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_23() | |||
| static int test_binaryop_6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(13, 31), | |||
| RandomMat(14, 28), | |||
| RandomMat(15, 24), | |||
| RandomMat(16, 32) | |||
| }; | |||
| static int test_binaryop_24() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), 1.f) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), 1.f); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1); | |||
| Randomize(b0); | |||
| static int test_binaryop_25() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(1)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(1)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_26() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_27() | |||
| static int test_binaryop_7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(4, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_28() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(3, 4, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| static int test_binaryop_29() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 3, 4, 2), RandomMat(11, 3, 4, 2)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 4), RandomMat(11, 3, 4, 4)) | |||
| || test_binaryop(RandomMat(11, 3, 4, 16), RandomMat(11, 3, 4, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s1() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 1, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s2() | |||
| static int test_binaryop_8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 6, 1)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 6, 1)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(2, 7, 3, 31), | |||
| RandomMat(3, 6, 4, 28), | |||
| RandomMat(4, 5, 5, 24), | |||
| RandomMat(5, 4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s3() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 1, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, 1, 1); | |||
| ncnn::Mat b1(a[i].w, a[i].h, 1, 1); | |||
| ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1); | |||
| Randomize(b0); | |||
| Randomize(b1); | |||
| Randomize(b2); | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]) | |||
| || test_binaryop(a[i], b1) || test_binaryop(b1, a[i]) | |||
| || test_binaryop(a[i], b2) || test_binaryop(b2, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s4() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 1), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| static int test_binaryop_s5() | |||
| static int test_binaryop_9() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(1, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(1, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(1, 6, 16)); | |||
| } | |||
| ncnn::Mat a[] = { | |||
| RandomMat(7, 3, 31), | |||
| RandomMat(6, 4, 28), | |||
| RandomMat(5, 5, 24), | |||
| RandomMat(4, 6, 32) | |||
| }; | |||
| static int test_binaryop_s6() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 6, 2), RandomMat(11, 1, 2)) | |||
| || test_binaryop(RandomMat(11, 6, 4), RandomMat(11, 1, 4)) | |||
| || test_binaryop(RandomMat(11, 6, 16), RandomMat(11, 1, 16)); | |||
| } | |||
| for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) | |||
| { | |||
| ncnn::Mat b0(a[i].w, 1, a[i].c); | |||
| Randomize(b0); | |||
| static int test_binaryop_s7() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(1, 6, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(1, 6, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(1, 6, 16), RandomMat(11, 6, 16)); | |||
| } | |||
| int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]); | |||
| if (ret != 0) | |||
| return ret; | |||
| } | |||
| static int test_binaryop_s8() | |||
| { | |||
| return 0 | |||
| || test_binaryop(RandomMat(11, 1, 2), RandomMat(11, 6, 2)) | |||
| || test_binaryop(RandomMat(11, 1, 4), RandomMat(11, 6, 4)) | |||
| || test_binaryop(RandomMat(11, 1, 16), RandomMat(11, 6, 16)); | |||
| return 0; | |||
| } | |||
| int main() | |||
| @@ -393,35 +379,7 @@ int main() | |||
| || test_binaryop_6() | |||
| || test_binaryop_7() | |||
| || test_binaryop_8() | |||
| || test_binaryop_9() | |||
| || test_binaryop_10() | |||
| || test_binaryop_11() | |||
| || test_binaryop_12() | |||
| || test_binaryop_13() | |||
| || test_binaryop_14() | |||
| || test_binaryop_15() | |||
| || test_binaryop_16() | |||
| || test_binaryop_17() | |||
| || test_binaryop_18() | |||
| || test_binaryop_19() | |||
| || test_binaryop_20() | |||
| || test_binaryop_21() | |||
| || test_binaryop_22() | |||
| || test_binaryop_23() | |||
| || test_binaryop_24() | |||
| || test_binaryop_25() | |||
| || test_binaryop_26() | |||
| || test_binaryop_27() | |||
| || test_binaryop_28() | |||
| || test_binaryop_29() | |||
| || test_binaryop_s1() | |||
| || test_binaryop_s2() | |||
| || test_binaryop_s3() | |||
| || test_binaryop_s4() | |||
| || test_binaryop_s5() | |||
| || test_binaryop_s6() | |||
| || test_binaryop_s7() | |||
| || test_binaryop_s8(); | |||
| || test_binaryop_9(); | |||
| if (ret != 0) | |||
| return ret; | |||
| @@ -366,6 +366,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/fuse_innerproduct_activation.cpp | |||
| pass_ncnn/fuse_transpose_matmul.cpp | |||
| pass_ncnn/fuse_binaryop_eltwise.cpp | |||
| pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp | |||
| pass_ncnn/insert_reshape_linear.cpp | |||
| pass_ncnn/insert_reshape_pooling.cpp | |||
| @@ -43,6 +43,7 @@ | |||
| #include "pass_ncnn/fuse_innerproduct_activation.h" | |||
| #include "pass_ncnn/fuse_transpose_matmul.h" | |||
| #include "pass_ncnn/fuse_binaryop_eltwise.h" | |||
| #include "pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h" | |||
| #include "pass_ncnn/insert_reshape_linear.h" | |||
| #include "pass_ncnn/insert_reshape_pooling.h" | |||
| @@ -85,6 +86,7 @@ void pass_ncnn(Graph& g) | |||
| ncnn::convert_half_to_float(g); | |||
| ncnn::insert_reshape_numpy_binaryop_broadcast(g); | |||
| ncnn::insert_reshape_pooling(g); | |||
| ncnn::insert_reshape_linear(g); | |||
| @@ -169,6 +169,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| Operand* op_unary_out = graph.new_operand(op->name + "_" + r); | |||
| op_unary_out->producer = op_unary; | |||
| op_unary_out->shape = op_unary_in->shape; | |||
| op_unary->inputs.push_back(op_unary_in); | |||
| op_unary->outputs.push_back(op_unary_out); | |||
| } | |||
| @@ -204,6 +206,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| Operand* op_binary_out = graph.new_operand(op->name + "_" + r); | |||
| op_binary_out->producer = op_binary; | |||
| op_binary_out->shape = op_binary_inb->shape; | |||
| op_binary->inputs.push_back(op_binary_inb); | |||
| op_binary->outputs.push_back(op_binary_out); | |||
| } | |||
| @@ -218,6 +222,8 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| Operand* op_binary_out = graph.new_operand(op->name + "_" + r); | |||
| op_binary_out->producer = op_binary; | |||
| op_binary_out->shape = op_binary_ina->shape; | |||
| op_binary->inputs.push_back(op_binary_ina); | |||
| op_binary->outputs.push_back(op_binary_out); | |||
| } | |||
| @@ -232,6 +238,28 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| Operand* op_binary_out = graph.new_operand(op->name + "_" + r); | |||
| op_binary_out->producer = op_binary; | |||
| // resolve out shape | |||
| std::vector<int> out_shape; | |||
| { | |||
| std::vector<int> a_shape = op_binary_ina->shape; | |||
| std::vector<int> b_shape = op_binary_inb->shape; | |||
| int outrank = (int)std::max(a_shape.size(), b_shape.size()); | |||
| for (int k = (int)a_shape.size(); k < outrank; k++) | |||
| { | |||
| a_shape.insert(a_shape.begin(), 1); | |||
| } | |||
| for (int k = (int)b_shape.size(); k < outrank; k++) | |||
| { | |||
| b_shape.insert(b_shape.begin(), 1); | |||
| } | |||
| out_shape.resize(outrank); | |||
| for (int k = 0; k < outrank; k++) | |||
| { | |||
| out_shape[k] = std::max(a_shape[k], b_shape[k]); | |||
| } | |||
| } | |||
| op_binary_out->shape = out_shape; | |||
| op_binary->inputs.push_back(op_binary_ina); | |||
| op_binary->inputs.push_back(op_binary_inb); | |||
| op_binary->outputs.push_back(op_binary_out); | |||
| @@ -0,0 +1,153 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "insert_reshape_numpy_binaryop_broadcast.h" | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void insert_reshape_numpy_binaryop_broadcast(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "BinaryOp") | |||
| continue; | |||
| if (op->inputs.size() != 2) | |||
| continue; | |||
| if (op->inputs[0]->shape.empty() || op->inputs[1]->shape.empty()) | |||
| continue; | |||
| int batch_index0 = op->inputs[0]->params["__batch_index"].i; | |||
| int batch_index1 = op->inputs[1]->params["__batch_index"].i; | |||
| if (batch_index0 != batch_index1) | |||
| { | |||
| fprintf(stderr, "binaryop broadcast across batch axis %d and %d is not supported\n", batch_index0, batch_index1); | |||
| continue; | |||
| } | |||
| if (op->inputs[0]->shape.size() == 5 && batch_index0 == 233) | |||
| { | |||
| if (op->inputs[0]->shape[0] == 1) | |||
| { | |||
| fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n"); | |||
| batch_index0 = 0; | |||
| } | |||
| } | |||
| if (op->inputs[1]->shape.size() == 5 && batch_index1 == 233) | |||
| { | |||
| if (op->inputs[1]->shape[0] == 1) | |||
| { | |||
| fprintf(stderr, "assume reshape 5-rank tensor has batch_index 0\n"); | |||
| batch_index1 = 0; | |||
| } | |||
| } | |||
| // drop shape batch index | |||
| std::vector<int> new_shape0; | |||
| std::vector<int> new_shape1; | |||
| for (int j = 0; j < (int)op->inputs[0]->shape.size(); j++) | |||
| { | |||
| if (j == batch_index0 && (op->inputs[0]->shape[j] == 1 || op->inputs[0]->shape[j] == op->inputs[1]->shape[j])) | |||
| continue; | |||
| new_shape0.push_back(op->inputs[0]->shape[j]); | |||
| } | |||
| for (int j = 0; j < (int)op->inputs[1]->shape.size(); j++) | |||
| { | |||
| if (j == batch_index1 && (op->inputs[1]->shape[j] == 1 || op->inputs[1]->shape[j] == op->inputs[0]->shape[j])) | |||
| continue; | |||
| new_shape1.push_back(op->inputs[1]->shape[j]); | |||
| } | |||
| const int input_rank0 = (int)new_shape0.size(); | |||
| const int input_rank1 = (int)new_shape1.size(); | |||
| if (input_rank0 >= 5) | |||
| { | |||
| fprintf(stderr, "binaryop tensor0 with rank %d is not supported yet!\n", (int)op->inputs[0]->shape.size()); | |||
| } | |||
| if (input_rank1 >= 5) | |||
| { | |||
| fprintf(stderr, "binaryop tensor1 with rank %d is not supported yet!\n", (int)op->inputs[1]->shape.size()); | |||
| } | |||
| if (input_rank0 == input_rank1) | |||
| { | |||
| // no broadcast after ignoring batch index | |||
| continue; | |||
| } | |||
| // fprintf(stderr, "insert_reshape_numpy_binaryop_broadcast %d %d\n", input_rank0, input_rank1); | |||
| matched = true; | |||
| const int binaryop_lower_rank_in_index = input_rank0 < input_rank1 ? 0 : 1; | |||
| Operand* binaryop_lower_rank_in = op->inputs[binaryop_lower_rank_in_index]; | |||
| Operator* reshape0 = graph.new_operator_before("Tensor.reshape", op->name + "_ncnnreshape0", op); | |||
| Operand* reshape0_out = graph.new_operand(op->name + "_ncnnreshape0_out"); | |||
| reshape0->inputs.push_back(binaryop_lower_rank_in); | |||
| reshape0->outputs.push_back(reshape0_out); | |||
| for (size_t j = 0; j < binaryop_lower_rank_in->consumers.size(); j++) | |||
| { | |||
| if (binaryop_lower_rank_in->consumers[j] == op) | |||
| { | |||
| binaryop_lower_rank_in->consumers[j] = reshape0; | |||
| break; | |||
| } | |||
| } | |||
| op->inputs[binaryop_lower_rank_in_index] = reshape0_out; | |||
| reshape0_out->producer = reshape0; | |||
| reshape0_out->consumers.push_back(op); | |||
| reshape0_out->params["__batch_index"] = input_rank0 < input_rank1 ? batch_index0 : batch_index1; | |||
| // insert explicit broadcast index for missing ranks | |||
| std::vector<int> reshape0_shape = input_rank0 < input_rank1 ? new_shape0 : new_shape1; | |||
| for (int j = 0; j < std::abs(input_rank0 - input_rank1); j++) | |||
| { | |||
| reshape0_shape.insert(reshape0_shape.begin(), 1); | |||
| } | |||
| reshape0->params["shape"] = reshape0_shape; | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,25 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| void insert_reshape_numpy_binaryop_broadcast(Graph& graph); | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -187,6 +187,7 @@ pnnx_ncnn_add_test(vit_b_32) | |||
| pnnx_ncnn_add_test(ncnn_fuse_transpose_matmul) | |||
| pnnx_ncnn_add_test(ncnn_fuse_shufflechannel_slice) | |||
| pnnx_ncnn_add_test(ncnn_fuse_binaryop_eltwise) | |||
| pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast) | |||
| if(Torch_VERSION VERSION_GREATER_EQUAL "1.9") | |||
| pnnx_ncnn_add_test(F_mish) | |||
| @@ -0,0 +1,78 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Collection of elementwise binary ops whose operands rely on broadcasting."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z, w, u, v):
        """Apply add/sub/mul/div to pairs of the six inputs.

        The two ``view`` calls fix ``x`` at 5 elements and ``y`` at 7 * 5
        elements; the remaining shapes only need to broadcast against each
        other (see ``test()`` for the concrete shapes used).

        Returns a tuple of 17 tensors in a fixed order.
        """
        # first batch of operand pairs
        x_add_y = x + y
        x_sub_z = x - z
        x_mul_w = x * w
        y_div_z = y / z
        y_add_w = y + w
        z_sub_w = z - w

        # the same pairs with the operands swapped
        y_add_x = y + x
        z_sub_x = z - x
        w_mul_x = w * x
        z_div_y = z / y
        w_add_y = w + y
        w_sub_z = w - z

        # chained binary ops
        chained_mul = (x - z) * w
        chained_sub = (x + y) - (z + w)

        # operands with explicit leading 1-sized axes
        x_3d = x.view(1, 1, 5)
        y_3d = y.view(1, 7, 5)
        viewed_sum = x_3d + y_3d - z

        # remaining pairs involving u and v
        u_mul_y = u * y
        z_div_v = z / v

        return (
            x_add_y,
            x_sub_z,
            x_mul_w,
            y_div_z,
            y_add_w,
            z_sub_w,
            y_add_x,
            z_sub_x,
            w_mul_x,
            z_div_y,
            w_add_y,
            w_sub_z,
            chained_mul,
            chained_sub,
            viewed_sum,
            u_mul_y,
            z_div_v,
        )
def test():
    """Compare torch results with the pnnx-converted ncnn results.

    Traces ``Model`` to torchscript, runs the external ``pnnx`` converter on
    the saved file, then imports the generated ncnn test module and checks
    every output against the torch reference.

    Returns:
        True when both sides produce the same number of outputs and every
        pair is close (rtol 1e-4, atol 1e-4); False otherwise.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(5)
    y = torch.rand(7, 5)
    z = torch.rand(4, 7, 5)
    w = torch.rand(6, 4, 7, 5)
    u = torch.rand(7, 1)
    v = torch.rand(4, 1, 1)

    a = net(x, y, z, w, u, v)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z, w, u, v))
    mod.save("test_ncnn_numpy_binaryop_broadcast.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_ncnn_numpy_binaryop_broadcast.pt inputshape=[5],[7,5],[4,7,5],[6,4,7,5],[7,1],[4,1,1]")

    # ncnn inference
    # this module is generated by the pnnx run above, so it must be imported here
    import test_ncnn_numpy_binaryop_broadcast_ncnn
    b = test_ncnn_numpy_binaryop_broadcast_ncnn.test_inference()

    # zip() stops at the shorter sequence, so a missing ncnn output would
    # otherwise be silently ignored
    if len(a) != len(b):
        return False

    return all(torch.allclose(a0, b0, 1e-4, 1e-4) for a0, b0 in zip(a, b))
if __name__ == "__main__":
    # process exit status: 0 when test() passes, 1 when it fails
    exit(0 if test() else 1)