| @@ -3,7 +3,7 @@ | |||||
| #define READ_IMAGE read_imagef | #define READ_IMAGE read_imagef | ||||
| #define WRITE_IMAGE write_imagef | #define WRITE_IMAGE write_imagef | ||||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | ||||
| __kernel void transpose(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) { | |||||
| __kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) { | |||||
| int X = get_global_id(0); | int X = get_global_id(0); | ||||
| int Y = get_global_id(1); | int Y = get_global_id(1); | ||||
| if (X >= HW.y || Y >= C.y) { | if (X >= HW.y || Y >= C.y) { | ||||
| @@ -43,3 +43,44 @@ __kernel void transpose(__read_only image2d_t src_data, __write_only image2d_t d | |||||
| WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 2), result[2]); | WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 2), result[2]); | ||||
| WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 3), result[3]); | WRITE_IMAGE(dst_data, (int2)(X, 4 * Y + 3), result[3]); | ||||
| } | } | ||||
| // Reads four vertically adjacent texels of src_data, transposes them as a 4x4 | |||||
| // block of components, and stores the four resulting vectors into the linear | |||||
| // buffer dst_data, using HW.y as the row stride (in FLT4 elements). | |||||
| // | |||||
| //   src_data  source image, sampled with smp_zero at (col, 4 * row + i) | |||||
| //   dst_data  destination buffer of FLT4 elements | |||||
| //   HW        only HW.y is used: bound for global id 0 and the output row stride | |||||
| //   C         only C.y is used: bound for global id 1 | |||||
| __kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_data, int2 HW, int2 C) { | |||||
|   const int gx = get_global_id(0); | |||||
|   const int gy = get_global_id(1); | |||||
|   // Work-items outside the HW.y x C.y range have nothing to do. | |||||
|   if (gx >= HW.y || gy >= C.y) { | |||||
|     return; | |||||
|   } | |||||
|   // Load the four source texels at column gy, rows 4 * gx .. 4 * gx + 3. | |||||
|   const FLT4 r0 = READ_IMAGE(src_data, smp_zero, (int2)(gy, 4 * gx)); | |||||
|   const FLT4 r1 = READ_IMAGE(src_data, smp_zero, (int2)(gy, 4 * gx + 1)); | |||||
|   const FLT4 r2 = READ_IMAGE(src_data, smp_zero, (int2)(gy, 4 * gx + 2)); | |||||
|   const FLT4 r3 = READ_IMAGE(src_data, smp_zero, (int2)(gy, 4 * gx + 3)); | |||||
|   // 4x4 component transpose: output vector k collects component k of every input. | |||||
|   const FLT4 t0 = (FLT4)(r0.x, r1.x, r2.x, r3.x); | |||||
|   const FLT4 t1 = (FLT4)(r0.y, r1.y, r2.y, r3.y); | |||||
|   const FLT4 t2 = (FLT4)(r0.z, r1.z, r2.z, r3.z); | |||||
|   const FLT4 t3 = (FLT4)(r0.w, r1.w, r2.w, r3.w); | |||||
|   // Output vector k goes to buffer row 4 * gy + k, column gx. | |||||
|   dst_data[4 * gy * HW.y + gx] = t0; | |||||
|   dst_data[(4 * gy + 1) * HW.y + gx] = t1; | |||||
|   dst_data[(4 * gy + 2) * HW.y + gx] = t2; | |||||
|   dst_data[(4 * gy + 3) * HW.y + gx] = t3; | |||||
| } | |||||
| @@ -36,7 +36,11 @@ namespace mindspore::kernel { | |||||
| int TransposeOpenCLKernel::Init() { | int TransposeOpenCLKernel::Init() { | ||||
| std::string kernel_name = "transpose"; | std::string kernel_name = "transpose"; | ||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | ||||
| if (!is_image_out_) { | |||||
| kernel_name += "_BUF"; | |||||
| } else { | |||||
| kernel_name += "_IMG"; | |||||
| } | |||||
| #ifdef PROGRAM_WITH_IL | #ifdef PROGRAM_WITH_IL | ||||
| ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | ||||
| #else | #else | ||||
| @@ -60,8 +64,12 @@ int TransposeOpenCLKernel::Init() { | |||||
| MS_LOG(ERROR) << "input H * W % 4 != 0 not support!"; | MS_LOG(ERROR) << "input H * W % 4 != 0 not support!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ori_format_ = out_tensors_[0]->GetFormat(); | |||||
| // Transpose::InferShape just set output->SetFormat(input->GetFormat()); -^-! | |||||
| ori_format_ = schema::Format_NCHW; | |||||
| out_tensors_[0]->SetFormat(schema::Format_NCHW); | out_tensors_[0]->SetFormat(schema::Format_NCHW); | ||||
| if (!is_image_out_) { | |||||
| out_mem_type_ = OpenCLMemType::BUF; | |||||
| } | |||||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | MS_LOG(DEBUG) << kernel_name << " Init Done!"; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -38,6 +38,7 @@ class TransposeOpenCLKernel : public OpenCLKernel { | |||||
| private: | private: | ||||
| cl::Kernel kernel_; | cl::Kernel kernel_; | ||||
| bool is_image_out_ = false; | |||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -34,13 +34,13 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::tensor::Tensor * | |||||
| out_parameters->clear(); | out_parameters->clear(); | ||||
| out_convert_ops->clear(); | out_convert_ops->clear(); | ||||
| for (size_t i = 0; i < in_tensors.size(); ++i) { | for (size_t i = 0; i < in_tensors.size(); ++i) { | ||||
| OpenCLKernel* cur_opencl_op = reinterpret_cast<OpenCLKernel*>(in_kernels[i]); | |||||
| OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(in_kernels[i]); | |||||
| schema::Format ori_format = cur_opencl_op->GetOriFormat(); | schema::Format ori_format = cur_opencl_op->GetOriFormat(); | ||||
| if (mem_type == cur_opencl_op->GetMemType() && in_tensors[i]->GetFormat() == ori_format) { | |||||
| if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() && | |||||
| in_tensors[i]->GetFormat() == ori_format) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto dst_format = | |||||
| (mem_type == OpenCLMemType::IMG) ? in_kernels[i]->out_tensors()[0]->GetFormat() : ori_format; | |||||
| auto dst_format = (mem_type == OpenCLMemType::IMG) ? in_kernels[i]->out_tensors()[0]->GetFormat() : ori_format; | |||||
| auto src_format = | auto src_format = | ||||
| (mem_type == OpenCLMemType::IMG) ? in_tensors[i]->GetFormat() : in_kernels[i]->out_tensors()[0]->GetFormat(); | (mem_type == OpenCLMemType::IMG) ? in_tensors[i]->GetFormat() : in_kernels[i]->out_tensors()[0]->GetFormat(); | ||||
| lite::tensor::Tensor *new_tensor = new (std::nothrow) lite::tensor::Tensor(); | lite::tensor::Tensor *new_tensor = new (std::nothrow) lite::tensor::Tensor(); | ||||
| @@ -125,6 +125,7 @@ int RunSubGraphOpenCLKernel(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error."; | MS_LOG(ERROR) << "Run SubGraphOpenCLKernel error."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| delete sub_graph; | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -180,6 +181,7 @@ TEST_F(TestActivationOpenCL, LeakyReluFp32_dim4) { | |||||
| delete input_tensor; | delete input_tensor; | ||||
| delete output_tensor; | delete output_tensor; | ||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| return; | return; | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -119,6 +119,11 @@ TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) { | |||||
| } | } | ||||
| printf("test all close OK!\n"); | printf("test all close OK!\n"); | ||||
| lite::CompareOutputData(output_data, expect, 4); | lite::CompareOutputData(output_data, expect, 4); | ||||
| delete tensor_in; | |||||
| delete tensor_out; | |||||
| delete pooling_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -175,6 +175,14 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { | |||||
| sub_graph->Run(); | sub_graph->Run(); | ||||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data()); | ||||
| CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); | CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete concat_kernel; | |||||
| delete sub_graph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | lite::opencl::OpenCLRuntime::DeleteInstance(); | ||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -108,5 +108,14 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { | |||||
| CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); | CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001); | ||||
| MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; | MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete arith_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -120,6 +120,14 @@ void TEST_MAIN(ConvParameter *param, schema::Format data_format, const std::stri | |||||
| MyCompareOutput(output_tensor, expect_file); | MyCompareOutput(output_tensor, expect_file); | ||||
| // lite::CompareOutput(reinterpret_cast<float *>(output_tensor->Data()), expect_file); | // lite::CompareOutput(reinterpret_cast<float *>(output_tensor->Data()), expect_file); | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete conv_kernel; | |||||
| delete sub_graph; | |||||
| mindspore::lite::opencl::OpenCLRuntime::DeleteInstance(); | mindspore::lite::opencl::OpenCLRuntime::DeleteInstance(); | ||||
| } | } | ||||
| @@ -75,12 +75,15 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) { | |||||
| // compare | // compare | ||||
| CompareOutputData(output_data, correct_data, co, 0.00001); | CompareOutputData(output_data, correct_data, co, 0.00001); | ||||
| delete input_data; | |||||
| delete weight_data; | |||||
| delete tensor_x; | |||||
| delete tensor_w; | |||||
| delete tensor_out; | |||||
| delete correct_data; | |||||
| MS_LOG(INFO) << "TestMatMulFp32 passed"; | MS_LOG(INFO) << "TestMatMulFp32 passed"; | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete arith_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -92,6 +92,15 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { | |||||
| MS_LOG(INFO) << "compare result"; | MS_LOG(INFO) << "compare result"; | ||||
| std::cout << "compare result" << std::endl; | std::cout << "compare result" << std::endl; | ||||
| CompareOutput(output_tensor, expect_file); | CompareOutput(output_tensor, expect_file); | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete pooling_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -77,6 +77,15 @@ void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, st | |||||
| MS_LOG(INFO) << "compare result"; | MS_LOG(INFO) << "compare result"; | ||||
| std::cout << "compare result" << std::endl; | std::cout << "compare result" << std::endl; | ||||
| CompareOutput(output_tensor, expect_file); | CompareOutput(output_tensor, expect_file); | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| TEST_F(TestSoftmaxOpenCL, Softmax_1) { | TEST_F(TestSoftmaxOpenCL, Softmax_1) { | ||||
| @@ -75,5 +75,14 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) { | |||||
| // compare | // compare | ||||
| CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | ||||
| MS_LOG(INFO) << "TestMatMulFp32 passed"; | MS_LOG(INFO) << "TestMatMulFp32 passed"; | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete arith_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -75,5 +75,14 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||||
| // compare | // compare | ||||
| CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | ||||
| MS_LOG(INFO) << "TestMatMulFp32 passed"; | MS_LOG(INFO) << "TestMatMulFp32 passed"; | ||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| delete arith_kernel; | |||||
| delete pGraph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||