From 11762e59db9887c3df2987a914d1df35628d59f2 Mon Sep 17 00:00:00 2001
From: Pengyongrong
Date: Tue, 15 Sep 2020 05:19:18 -0700
Subject: [PATCH] concat ops support 4input

---
Review note: Concat4input_NC4HW4 now also returns early when input_shape3.y == 0,
because that extent is used as a divisor in the last branch of every axis case.
The diffstat and the last concat.cl hunk header are recounted for the added
lines; the post-image blob id on the concat.cl index line still names the
content from before this review change.

 .../src/runtime/kernel/opencl/cl/concat.cl    | 153 +++++++++++++++++-
 .../runtime/kernel/opencl/kernel/concat.cc    |  26 +++-
 .../src/runtime/kernel/opencl/concat_tests.cc |  22 ++-
 3 files changed, 187 insertions(+), 14 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl
index 25a3f176cc..8d2cfe692a 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl
@@ -62,7 +62,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image
     if (Y < input_shape0.z) {
       FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
       WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
-    } else if (Y < (input_shape0.z + input_shape0.z)) {
+    } else if (Y < (input_shape0.z + input_shape1.z)) {
       FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
       WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
     } else {
@@ -74,7 +74,7 @@ __kernel void Concat3input_NHWC4(__read_only image2d_t input0, __read_only image
     if (Z < input_shape0.w) {
       FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
       WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
-    } else if (Z < (input_shape0.w + input_shape0.w)) {
+    } else if (Z < (input_shape0.w + input_shape1.w)) {
       FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
       WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
     } else {
@@ -196,3 +196,152 @@ __kernel void Concat3input_NC4HW4(__read_only image2d_t input0, __read_only imag
     }
   }
 }
+
+// Concat of four NC4HW4 images: axis 1 splits on X against the .y extents, axis 2 on Y against the .z extents,
+// and any other axis on Z against the .w extents. Each work-item copies one FLT4 from the input that owns its
+// output coordinate; the input row is (X / y) * w * y + Z * y + X % y after subtracting the concat-axis offset.
+__kernel void Concat4input_NC4HW4(__read_only image2d_t input0, __read_only image2d_t input1,
+                                  __read_only image2d_t input2, __read_only image2d_t input3,
+                                  __write_only image2d_t output, int4 input_shape0, int4 input_shape1,
+                                  int4 input_shape2, int4 input_shape3, int4 output_shape, const int axis) {
+  int X = get_global_id(0);  // N*H
+  int Y = get_global_id(1);  // W
+  int Z = get_global_id(2);  // c/4
+  if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
+    return;
+  }
+  // All five .y extents are divisors below; a zero one would make the index math undefined.
+  if (input_shape0.y == 0 || input_shape1.y == 0 || input_shape2.y == 0 || input_shape3.y == 0 ||
+      output_shape.y == 0) {
+    return;
+  }
+  int in_postion_x;
+  int out_pos_x = (X / output_shape.y) * output_shape.w * output_shape.y + Z * output_shape.y + X % output_shape.y;
+  if (axis == 1) {
+    if (X < input_shape0.y) {
+      in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
+      FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (X < input_shape0.y + input_shape1.y) {
+      in_postion_x = ((X - input_shape0.y) / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y +
+                     ((X - input_shape0.y) % input_shape1.y);
+      FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (X < input_shape0.y + input_shape1.y + input_shape2.y) {
+      in_postion_x = ((X - input_shape0.y - input_shape1.y) / input_shape2.y) * input_shape2.w * input_shape2.y +
+                     Z * input_shape2.y + ((X - input_shape0.y - input_shape1.y) % input_shape2.y);
+      FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else {
+      in_postion_x =
+        ((X - input_shape0.y - input_shape1.y - input_shape2.y) / input_shape3.y) * input_shape3.w * input_shape3.y +
+        Z * input_shape3.y + ((X - input_shape0.y - input_shape1.y - input_shape2.y) % input_shape3.y);
+      FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    }
+  } else if (axis == 2) {
+    if (Y < input_shape0.z) {
+      in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
+      FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (Y < input_shape0.z + input_shape1.z) {
+      in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + Z * input_shape1.y + (X % input_shape1.y);
+      FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (Y < input_shape0.z + input_shape1.z + input_shape2.z) {
+      in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y + Z * input_shape2.y + (X % input_shape2.y);
+      FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else {
+      in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y + Z * input_shape3.y + (X % input_shape3.y);
+      FLT4 result =
+        READ_IMAGE(input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    }
+  } else {
+    if (Z < input_shape0.w) {
+      in_postion_x = (X / input_shape0.y) * input_shape0.w * input_shape0.y + Z * input_shape0.y + X % input_shape0.y;
+      FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (Z < input_shape0.w + input_shape1.w) {
+      in_postion_x = (X / input_shape1.y) * input_shape1.w * input_shape1.y + (Z - input_shape0.w) * input_shape1.y +
+                     (X % input_shape1.y);
+      FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else if (Z < input_shape0.w + input_shape1.w + input_shape2.w) {
+      in_postion_x = (X / input_shape2.y) * input_shape2.w * input_shape2.y +
+                     (Z - input_shape0.w - input_shape1.w) * input_shape2.y + (X % input_shape2.y);
+      FLT4 result = READ_IMAGE(input2, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    } else {
+      in_postion_x = (X / input_shape3.y) * input_shape3.w * input_shape3.y +
+                     (Z - input_shape0.w - input_shape1.w - input_shape2.w) * input_shape3.y + (X % input_shape3.y);
+      FLT4 result = READ_IMAGE(input3, smp_none, (int2)((Y), in_postion_x));
+      WRITE_IMAGE(output, (int2)((Y), out_pos_x), result);
+    }
+  }
+}
+
+// Concat of four NHWC4 images, using the same axis dispatch as the NC4HW4 variant above. Pixels are addressed
+// as (Y * shape.w + Z, X); the offset of the preceding inputs is subtracted from X, Y or Z on the concat axis.
+// No division is involved, so no zero-extent guard is needed here.
+__kernel void Concat4input_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,
+                                 __read_only image2d_t input2, __read_only image2d_t input3,
+                                 __write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2,
+                                 int4 input_shape3, int4 output_shape, const int axis) {
+  int X = get_global_id(0);  // N*H
+  int Y = get_global_id(1);  // W
+  int Z = get_global_id(2);  // c/4
+  if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {
+    return;
+  }
+  if (axis == 1) {
+    if (X < input_shape0.y) {
+      FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
+    } else if (X < (input_shape0.y + input_shape1.y)) {
+      FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z, (X - input_shape0.y)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
+    } else if (X < (input_shape0.y + input_shape1.y + input_shape2.y)) {
+      FLT4 result2 =
+        READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z, (X - input_shape0.y - input_shape1.y)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
+    } else {
+      FLT4 result3 = READ_IMAGE(input3, smp_none,
+                                (int2)((Y)*input_shape3.w + Z, (X - input_shape0.y - input_shape1.y - input_shape2.y)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
+    }
+  } else if (axis == 2) {
+    if (Y < input_shape0.z) {
+      FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
+    } else if (Y < (input_shape0.z + input_shape1.z)) {
+      FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y - input_shape0.z) * input_shape1.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
+    } else if (Y < (input_shape0.z + input_shape1.z + input_shape2.z)) {
+      FLT4 result2 =
+        READ_IMAGE(input2, smp_none, (int2)((Y - input_shape0.z - input_shape1.z) * input_shape2.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
+    } else {
+      FLT4 result3 = READ_IMAGE(
+        input3, smp_none, (int2)((Y - input_shape0.z - input_shape1.z - input_shape2.z) * input_shape3.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
+    }
+  } else {
+    if (Z < input_shape0.w) {
+      FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_shape0.w + Z, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0);
+    } else if (Z < (input_shape0.w + input_shape1.w)) {
+      FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_shape1.w + Z - input_shape0.w, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1);
+    } else if (Z < (input_shape0.w + input_shape1.w + input_shape2.w)) {
+      FLT4 result2 =
+        READ_IMAGE(input2, smp_none, (int2)((Y)*input_shape2.w + Z - input_shape0.w - input_shape1.w, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2);
+    } else {
+      FLT4 result3 = READ_IMAGE(input3, smp_none,
+                                (int2)((Y)*input_shape3.w + Z - input_shape0.w - input_shape1.w - input_shape2.w, (X)));
+      WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result3);
+    }
+  }
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc index 4f44aa4723..276833f168 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -97,8 +97,10 @@ int ConcatOpenCLKernel::Init() { kernel_name += "2input"; } else if (in_tensors_.size() == 3) { kernel_name += "3input"; + } else if (in_tensors_.size() == 4) { + kernel_name += "4input"; } else { - MS_LOG(ERROR) << " input must be 2 or 3"; + MS_LOG(ERROR) << " input must be 2 3 or 4"; return RET_ERROR; } if (in_format == schema::Format_NC4HW4) { @@ -193,11 +195,25 @@ int ConcatOpenCLKernel::Run() { ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape3_); ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); - } else if (in_tensors_.size() < 2) { - MS_LOG(ERROR) << " input sizes must >= 2 "; - return RET_ERROR; + } else if (in_tensors_.size() == 4) { + auto input3_shape = in_tensors_[2]->shape(); + auto input4_shape = in_tensors_[3]->shape(); + cl_int4 input_shape3_ = {input3_shape[0], input3_shape[1], input3_shape[2], UP_DIV(input3_shape[3], C4NUM)}; + cl_int4 input_shape4_ = {input4_shape[0], input4_shape[1], input4_shape[2], UP_DIV(input4_shape[3], C4NUM)}; + + ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->MutableData()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->MutableData()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->MutableData()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->MutableData()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->MutableData()); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape1_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape2_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape3_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape4_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, 
output_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); } else { - MS_LOG(ERROR) << " only support inputs <= 3 "; + MS_LOG(ERROR) << " input sizes must 2 or 3 or 4"; return RET_ERROR; } ocl_runtime->RunKernel(kernel_, global, local, nullptr); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index 74822fd0cd..4f79dbbcd8 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -47,22 +47,25 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { auto allocator = ocl_runtime->GetAllocator(); // get the input from .bin - size_t input1_size, input2_size, input3_size, output_size; + size_t input1_size, input2_size, input3_size, input4_size, output_size; std::string input1Ppath = "./test_data/concatfp16_input1.bin"; std::string input2Ppath = "./test_data/concatfp16_input2.bin"; std::string input3Ppath = "./test_data/concatfp16_input3.bin"; + std::string input4Ppath = "./test_data/concatfp16_input4.bin"; std::string correctOutputPath = "./test_data/concatfp16_output.bin"; auto input_data1 = reinterpret_cast(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); auto input_data2 = reinterpret_cast(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); auto input_data3 = reinterpret_cast(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size)); + auto input_data4 = reinterpret_cast(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size)); auto correctOutput = reinterpret_cast(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); MS_LOG(INFO) << " init tensors "; - constexpr int INPUT_NUM = 2; - std::array, INPUT_NUM> input_shapes = {std::vector{1, 19, 19, 96}, - std::vector{1, 19, 19, 96}}; - std::vector output_shape = {1, 19, 19, 192}; + constexpr int INPUT_NUM = 4; + std::array, INPUT_NUM> input_shapes = { + 
std::vector{1, 19, 19, 96}, std::vector{1, 19, 19, 96}, std::vector{1, 19, 19, 96}, + std::vector{1, 19, 19, 96}}; + std::vector output_shape = {1, 76, 19, 96}; auto data_type = kNumberTypeFloat16; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); std::vector inputs; @@ -97,7 +100,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { } return; } - param->axis_ = 3; + param->axis_ = 1; auto *concat_kernel = new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); if (concat_kernel == nullptr) { @@ -141,8 +144,13 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { memcpy(inputs[0]->MutableData(), input_data1, input1_size); memcpy(inputs[1]->MutableData(), input_data2, input2_size); memcpy(inputs[2]->MutableData(), input_data3, input3_size); + } else if (inputs.size() == 4) { + memcpy(inputs[0]->MutableData(), input_data1, input1_size); + memcpy(inputs[1]->MutableData(), input_data2, input2_size); + memcpy(inputs[2]->MutableData(), input_data3, input3_size); + memcpy(inputs[3]->MutableData(), input_data4, input4_size); } else { - MS_LOG(ERROR) << " input size must be 2 or 3"; + MS_LOG(ERROR) << " input size must be 2 or 3 or 4"; } std::cout << "==================output data================" << std::endl;