Merge pull request !4336 from pengyongrong/concattags/v0.7.0-beta
| @@ -1,44 +1,41 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| // #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| #define FLT4 float4 | |||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| __kernel void Concat(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, | |||
| __read_only image2d_t input1_image2d, int2 shared_int0, int4 shared_out) { | |||
| __kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output, | |||
| int2 input_channels, int4 output_shape) { | |||
| int X = get_global_id(0); // H | |||
| int Y = get_global_id(1); // W | |||
| int S = 0; | |||
| if (X >= shared_out.y || Y >= shared_out.z) return; | |||
| for (int i = 0; i < shared_int0.x; i++) { | |||
| FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); | |||
| S++; | |||
| int Z = get_global_id(2); // c/4 | |||
| if (X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < shared_int0.y; i++) { | |||
| FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); | |||
| S++; | |||
| if (Z < input_channels.x) { | |||
| FLT4 result = read_imagef(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); | |||
| write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result); | |||
| } else { | |||
| FLT4 result = read_imagef(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); | |||
| write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result); | |||
| } | |||
| } | |||
| __kernel void Concat3input(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, | |||
| __read_only image2d_t input1_image2d, __read_only image2d_t input2_image2d, int3 shared_int0, | |||
| int4 shared_out) { | |||
| __kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, | |||
| __write_only image2d_t output, int3 input_channels, int4 output_shape) { | |||
| int X = get_global_id(0); // H | |||
| int Y = get_global_id(1); // W | |||
| int S = 0; | |||
| if (X >= shared_out.y || Y >= shared_out.z) return; | |||
| for (int i = 0; i < shared_int0.x; i++) { | |||
| FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); | |||
| S++; | |||
| } | |||
| for (int i = 0; i < shared_int0.y; i++) { | |||
| FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); | |||
| S++; | |||
| int Z = get_global_id(2); // c/4 | |||
| if (X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { | |||
| return; | |||
| } | |||
| for (int i = 0; i < shared_int0.z; i++) { | |||
| FLT4 result2 = read_imagef(input2_image2d, smp_none, (int2)((Y)*shared_int0.z + (i), (X))); | |||
| write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result2); | |||
| S++; | |||
| if (Z < input_channels.x) { | |||
| FLT4 result0 = read_imagef(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); | |||
| write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result0); | |||
| } else if (Z < (input_channels.x + input_channels.y)) { | |||
| FLT4 result1 = read_imagef(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); | |||
| write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result1); | |||
| } else { | |||
| FLT4 result2 = | |||
| read_imagef(input2, smp_none, (int2)((Y)*input_channels.z + Z - input_channels.x - input_channels.y, (X))); | |||
| write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result2); | |||
| } | |||
| } | |||
| @@ -46,21 +46,23 @@ int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| return 1; | |||
| return RET_OK; | |||
| } | |||
| int ConcatOpenCLKernel::Init() { | |||
| if (in_tensors_[0]->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "only support dim=4"; | |||
| return RET_ERROR; | |||
| } | |||
| auto param = reinterpret_cast<ConcatParameter *>(this->op_parameter_); | |||
| MS_LOG(INFO) << "concat at axis=: " << param->axis_; | |||
| MS_LOG(DEBUG) << "concat at axis=: " << param->axis_; | |||
| if (param->axis_ != 0 && param->axis_ != 3) { | |||
| MS_LOG(ERROR) << "only support axis=0 or axis=3"; | |||
| return RET_ERROR; | |||
| } | |||
| if (param->axis_ == 0) { | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| if (in_tensors_.size() == 2) { | |||
| std::set<std::string> build_options; | |||
| @@ -82,10 +84,10 @@ int ConcatOpenCLKernel::Init() { | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| } | |||
| return 0; | |||
| return RET_OK; | |||
| } | |||
| int ConcatOpenCLKernel::ReSize() { return 0; } | |||
| int ConcatOpenCLKernel::ReSize() { return RET_OK; } | |||
| int ConcatOpenCLKernel::Run_axis0() { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| @@ -111,11 +113,7 @@ int ConcatOpenCLKernel::Run_axis0() { | |||
| ocl_runtime->UnmapBuffer(*buffer, tensor->Data()); | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| int DivideRoundUp(int n, int div) { | |||
| int q = n / div; | |||
| return n % div == 0 ? q : q + 1; | |||
| return RET_OK; | |||
| } | |||
| int GetBiggestDividerWithPriority(int number, int max_divider) { | |||
| @@ -128,19 +126,22 @@ int GetBiggestDividerWithPriority(int number, int max_divider) { | |||
| if (number % 2 == 0 && 2 <= max_divider) { | |||
| return number / 2; | |||
| } | |||
| for (int i = max_divider; i != 0; i--) { | |||
| if (number % i == 0) { | |||
| return i; | |||
| } | |||
| } | |||
| return 1; | |||
| return RET_OK; | |||
| } | |||
| void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | |||
| int x = std::min(GetBiggestDividerWithPriority(global[0], 8), 4); | |||
| const int max_divider = 8; | |||
| const int max_x = 4, max_y = 8; | |||
| int x = std::min(GetBiggestDividerWithPriority(global[0], max_divider), max_x); | |||
| int yz = max_size / x; | |||
| int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], 8), yz), 8); | |||
| int z = std::min(yz / y, DivideRoundUp(global[2], 2)); | |||
| int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], max_divider), yz), max_y); | |||
| int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2))); | |||
| local->clear(); | |||
| local->push_back(x); | |||
| @@ -159,44 +160,53 @@ int ConcatOpenCLKernel::Run() { | |||
| auto input1_shape = in_tensors_[1]->shape(); | |||
| auto output_shape = out_tensors_[0]->shape(); | |||
| cl_int2 input0_shape2_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4)}; // change | |||
| cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], DivideRoundUp(output_shape[3], 4)}; | |||
| cl_int2 input0_shape2_ = {UP_DIV(input0_shape[3], C4NUM), UP_DIV(input1_shape[3], C4NUM)}; // change | |||
| cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], UP_DIV(output_shape[3], C4NUM)}; | |||
| uint32_t OH = output_shape[0] * output_shape[1]; // N*H | |||
| uint32_t OH = output_shape[1]; // N*H | |||
| uint32_t OW = output_shape[2]; | |||
| uint32_t OC = UP_DIV(output_shape[3], C4NUM); | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {OH, OW}; | |||
| // ConcatGetWorkGroup(global, &local, 512); | |||
| const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize(); | |||
| std::vector<size_t> local = {1, 1, 1}; // init local | |||
| std::vector<size_t> global = {OH, OW, OC}; | |||
| ConcatGetWorkGroup(global, &local, max_global[0]); | |||
| int arg_cn = 0; | |||
| if (in_tensors_.size() == 2) { | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape2_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); | |||
| } else if (in_tensors_.size() == 3) { | |||
| auto input2_shape = in_tensors_[2]->shape(); | |||
| cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4), | |||
| DivideRoundUp(input2_shape[3], 4)}; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); | |||
| cl_int3 input0_shape3_ = {UP_DIV(input0_shape[3], C4NUM), UP_DIV(input1_shape[3], C4NUM), | |||
| UP_DIV(input2_shape[3], C4NUM)}; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data()); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape3_); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); | |||
| } else { | |||
| MS_LOG(ERROR) << "only support inputs<=3"; | |||
| return RET_ERROR; | |||
| } | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| return 0; | |||
| return RET_OK; | |||
| } // namespace mindspore::kernel | |||
| kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::Context *ctx, | |||
| const kernel::KernelKey &desc, const lite::Primitive *primitive) { | |||
| auto *kernel = new ConcatOpenCLKernel(opParameter, inputs, outputs); | |||
| auto *kernel = new (std::nothrow) ConcatOpenCLKernel(opParameter, inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new ConcatOpenCLKernel failed"; | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (0 != ret) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: Convolution"; | |||
| @@ -255,6 +255,14 @@ int ConvolutionOpenCLKernel::GetGlobalLocal(std::vector<size_t> *global, std::ve | |||
| local_h = global_h / 2; | |||
| } | |||
| auto output_tensor = out_tensors_[0]; | |||
| const size_t CO = output_tensor->Channel(); | |||
| const size_t CO_SLICES = UP_DIV(CO, C4NUM); | |||
| const size_t OW = output_tensor->Width(); | |||
| if (OW * CO_SLICES > 65536) { | |||
| local_w = 4; | |||
| } | |||
| global->clear(); | |||
| global->push_back(UP_DIV(param->output_h_, local_h) * local_h); | |||
| global->push_back(UP_DIV(param->output_w_, local_w) * local_w); | |||
| @@ -21,31 +21,9 @@ | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" | |||
| int DivideRoundUp(int n, int div) { | |||
| int q = n / div; | |||
| return n % div == 0 ? q : q + 1; | |||
| } | |||
| void printfNode(float *result, const std::vector<int> &tempNode) { | |||
| for (int i = 0; i < tempNode[0]; i++) { | |||
| for (int j = 0; j < tempNode[1]; j++) { | |||
| for (int k = 0; k < tempNode[2]; k++) { | |||
| for (int w = 0; w < tempNode[3]; w++) { | |||
| std::cout | |||
| << result[i * tempNode[2] * tempNode[1] * tempNode[3] + j * tempNode[2] * tempNode[3] + k * tempNode[3] + w] | |||
| << " "; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| void ConcatComputeByCPU_2input_dim4_axis3(float *input0, float *input1, float *output, std::vector<int> input_shape0, | |||
| std::vector<int> input_shape1, std::vector<int> output_shape, | |||
| const int axis) { | |||
| void ConcatComputeByCPU_2input_dim4_axis3(const float *input0, const float *input1, float *output, | |||
| std::vector<int> input_shape0, std::vector<int> input_shape1, | |||
| std::vector<int> output_shape, const int axis) { | |||
| int postion, index0 = 0, index1 = 0; | |||
| for (int i = 0; i < output_shape[0]; i++) { | |||
| for (int j = 0; j < output_shape[1]; j++) { | |||
| @@ -77,17 +55,17 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i | |||
| k * output_shape[3]; | |||
| for (int w = 0; w < output_shape[3]; w++) { | |||
| if (w < input_shape0[3]) { | |||
| int align = DivideRoundUp(input_shape0[3], 4) * 4; | |||
| int align = UP_DIV(input_shape0[3], 4) * 4; | |||
| index0 = i * input_shape0[1] * input_shape0[2] * align + j * input_shape0[2] * align + k * align + w; | |||
| output[postion++] = input0[index0]; | |||
| } else if (w >= input_shape0[3] && w < (input_shape0[3] + input_shape1[3])) { | |||
| int align = DivideRoundUp(input_shape1[3], 4) * 4; | |||
| int align = UP_DIV(input_shape1[3], 4) * 4; | |||
| index1 = i * input_shape1[1] * input_shape1[2] * align + j * input_shape1[2] * align + k * align + w - | |||
| input_shape0[3]; | |||
| output[postion++] = input1[index1]; | |||
| } else if ((input_shape0[3] + input_shape1[3]) <= w && | |||
| w < (input_shape0[3] + input_shape1[3] + input_shape2[3])) { | |||
| int align = DivideRoundUp(input_shape2[3], 4) * 4; | |||
| int align = UP_DIV(input_shape2[3], 4) * 4; | |||
| index2 = i * input_shape2[1] * input_shape2[2] * align + j * input_shape2[2] * align + k * align + w - | |||
| input_shape0[3] - input_shape1[3]; | |||
| output[postion++] = input2[index2]; | |||
| @@ -113,7 +91,6 @@ template <typename T> | |||
| void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) { | |||
| for (size_t i = 0; i < size; i++) { | |||
| T abs = fabs(output_data[i] - correct_data[i]); | |||
| // printf("i=%d %.3f %.3f\n", i, output_data[i], correct_data[i]); | |||
| ASSERT_LE(abs, err_bound); | |||
| } | |||
| } | |||
| @@ -126,34 +103,50 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| MS_LOG(INFO) << "init tensors"; | |||
| constexpr int INPUT_NUM = 2; | |||
| // std::array<std::vector<int>, INPUT_NUM> input_shapes = { | |||
| // std::vector<int>{1, 120, 120, 16}, std::vector<int>{1, 120, 120, 16},std::vector<int>{1, 120, 120, 96}}; | |||
| std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 32, 512, 48}, | |||
| std::vector<int>{1, 32, 512, 48}}; | |||
| std::vector<int> output_shape = {1, 32, 512, 96}; | |||
| output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; | |||
| std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 16, 256, 80}, | |||
| std::vector<int>{1, 16, 256, 80}}; | |||
| std::vector<int> output_shape = {1, 16, 256, 160}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = schema::NodeType_ValueNode; | |||
| std::vector<lite::tensor::Tensor *> inputs; | |||
| for (auto &shape : input_shapes) { | |||
| inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC, tensor_type)); | |||
| inputs.push_back(new lite::tensor::Tensor(data_type, shape, schema::Format_NHWC4, tensor_type)); | |||
| } | |||
| auto *output_tensor = | |||
| new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); | |||
| if (output_tensor == nullptr) { | |||
| MS_LOG(INFO) << "new output_tensor failed"; | |||
| return; | |||
| } | |||
| auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; | |||
| MS_LOG(INFO) << "input_shapes size=: " << input_shapes.size(); | |||
| std::cout << "initialize tensors"; | |||
| auto param = new ConcatParameter(); | |||
| MS_LOG(INFO) << "initialize tensors"; | |||
| auto param = new (std::nothrow) ConcatParameter(); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << "new ConcatParameter failed"; | |||
| return; | |||
| } | |||
| param->axis_ = 3; | |||
| auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| auto *concat_kernel = | |||
| new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (concat_kernel == nullptr) { | |||
| MS_LOG(INFO) << "new kernel::ConcatOpenCLKernel failed"; | |||
| return; | |||
| } | |||
| concat_kernel->Init(); | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{concat_kernel}; | |||
| auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| // to do allocate memory for inputs and outputs | |||
| for (auto &input_tensor : inputs) { | |||
| input_tensor->MallocData(allocator); | |||
| } | |||
| MS_LOG(INFO) << "initialize sub_graph"; | |||
| std::vector<kernel::LiteKernel *> kernels{concat_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed"; | |||
| return; | |||
| } | |||
| sub_graph->Init(); | |||
| unsigned int seed = 123; | |||
| MS_LOG(INFO) << "initialize input data"; | |||
| @@ -182,5 +175,6 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | |||
| CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| } // namespace mindspore | |||