| @@ -1,3 +1,4 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void reshape(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { | |||
| int X = get_global_id(0); | |||
| @@ -1,3 +1,4 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) { | |||
| int X = get_global_id(0); | |||
| @@ -75,8 +76,8 @@ __kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_dat | |||
| result[3].z = x2.w; | |||
| result[3].w = x3.w; | |||
| dst_data[4 * Y * HW.y + X] = result[0]; | |||
| dst_data[(4 * Y + 1) * HW.y + X] = result[1]; | |||
| dst_data[(4 * Y + 2) * HW.y + X] = result[2]; | |||
| dst_data[(4 * Y + 3) * HW.y + X] = result[3]; | |||
| if (4 * Y < C.x) dst_data[4 * Y * HW.y + X] = result[0]; | |||
| if (4 * Y + 1 < C.x) dst_data[(4 * Y + 1) * HW.y + X] = result[1]; | |||
| if (4 * Y + 2 < C.x) dst_data[(4 * Y + 2) * HW.y + X] = result[2]; | |||
| if (4 * Y + 3 < C.x) dst_data[(4 * Y + 3) * HW.y + X] = result[3]; | |||
| } | |||
| @@ -33,6 +33,7 @@ namespace mindspore::kernel { | |||
| int ReshapeOpenCLKernel::Init() { | |||
| std::string kernel_name = "reshape"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| enable_fp16_ = ocl_runtime->GetFp16Enable(); | |||
| in_ori_format_ = in_tensors_[0]->GetFormat(); | |||
| out_ori_format_ = out_tensors_[0]->GetFormat(); | |||
| if (in_ori_format_ != schema::Format_NHWC4 && in_ori_format_ != schema::Format_NHWC) { | |||
| @@ -73,11 +74,10 @@ int ReshapeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) | |||
| int c = shapex[3]; | |||
| im_dst_x = w * UP_DIV(c, C4NUM); | |||
| im_dst_y = h; | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| if (enable_fp16_) { | |||
| img_dtype = CL_HALF_FLOAT; | |||
| } | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| @@ -121,4 +121,5 @@ kernel::LiteKernel *OpenCLReshapeKernelCreator(const std::vector<lite::tensor::T | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Reshape, OpenCLReshapeKernelCreator) | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Reshape, OpenCLReshapeKernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -38,6 +38,7 @@ class ReshapeOpenCLKernel : public OpenCLKernel { | |||
| private: | |||
| cl::Kernel kernel_; | |||
| bool enable_fp16_{false}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -35,6 +35,7 @@ namespace mindspore::kernel { | |||
| int TransposeOpenCLKernel::Init() { | |||
| std::string kernel_name = "transpose"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| enable_fp16_ = ocl_runtime->GetFp16Enable(); | |||
| if (!is_image_out_) { | |||
| kernel_name += "_BUF"; | |||
| } else { | |||
| @@ -70,11 +71,10 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz | |||
| size_t im_dst_x, im_dst_y; | |||
| im_dst_x = UP_DIV(out_tensors_[0]->Height() * out_tensors_[0]->Width(), C4NUM); | |||
| im_dst_y = out_tensors_[0]->Channel(); | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| if (enable_fp16_) { | |||
| img_dtype = CL_HALF_FLOAT; | |||
| } | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| @@ -82,6 +82,7 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz | |||
| } | |||
| int TransposeOpenCLKernel::Run() { | |||
| // notice: input image2d size = {c/4, h * w} | |||
| MS_LOG(DEBUG) << this->name() << " Running!"; | |||
| std::vector<int> shapex = in_tensors_[0]->shape(); | |||
| int h = shapex[1]; | |||
| @@ -38,7 +38,8 @@ class TransposeOpenCLKernel : public OpenCLKernel { | |||
| private: | |||
| cl::Kernel kernel_; | |||
| bool is_image_out_ = false; | |||
| bool is_image_out_{false}; | |||
| bool enable_fp16_{false}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -31,14 +31,14 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest { | |||
| }; | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, void *weight_data, void *bias_data, | |||
| void *output_data, bool fp16) { | |||
| void *output_data, bool enable_fp16) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| size_t dtype_size = sizeof(float); | |||
| if (fp16) { | |||
| if (enable_fp16) { | |||
| ocl_runtime->SetFp16Enable(true); | |||
| dtype_size = sizeof(float16_t); | |||
| } | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int pad = shape[0]; | |||
| int n = shape[1]; | |||
| @@ -52,7 +52,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1; | |||
| std::vector<int> input_shape = {n, h, w, ci}; | |||
| auto tensor_x_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); | |||
| auto tensor_x = tensor_x_ptr.get(); | |||
| if (tensor_x == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_x create error."; | |||
| @@ -61,7 +61,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| std::vector<int> weight_shape = {co, kh, kw, ci}; | |||
| auto tensor_w_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape); | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape); | |||
| auto tensor_w = tensor_w_ptr.get(); | |||
| if (tensor_w == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_w create error."; | |||
| @@ -71,7 +71,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| std::vector<int> bias_shape = {co}; | |||
| auto tensor_bias_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape); | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape); | |||
| auto tensor_bias = tensor_bias_ptr.get(); | |||
| if (tensor_bias == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_bias create error."; | |||
| @@ -81,7 +81,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| std::vector<int> out_shape = {1, oh, ow, co}; | |||
| auto tensor_out_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape); | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape); | |||
| auto tensor_out = tensor_out_ptr.get(); | |||
| if (tensor_out == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_out create error."; | |||
| @@ -126,7 +126,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, n * h * w * ci * dtype_size); | |||
| pGraph->Run(); | |||
| if (fp16) { | |||
| if (enable_fp16) { | |||
| CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float16_t>(1e-3), 2e-2); | |||
| } else { | |||
| CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float>(1e-5)); | |||
| @@ -137,7 +137,8 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) { | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, | |||
| bool enable_fp16) { | |||
| size_t input_size; | |||
| std::string input_path = file_path[0]; | |||
| auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size); | |||
| @@ -168,7 +169,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector< | |||
| MS_LOG(ERROR) << "output_data load error."; | |||
| return; | |||
| } | |||
| RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, fp16); | |||
| RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, enable_fp16); | |||
| } | |||
| TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { | |||
| @@ -29,32 +29,21 @@ class TestMatMulOpenCL : public mindspore::CommonTest { | |||
| TestMatMulOpenCL() {} | |||
| }; | |||
| void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) { | |||
| void RunTestCaseMatMul(const std::vector<int> &shape, void *input_data, void *weight_data, void *output_data, | |||
| bool enable_fp16) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| if (fp16) { | |||
| size_t dtype_size = sizeof(float); | |||
| if (enable_fp16) { | |||
| ocl_runtime->SetFp16Enable(true); | |||
| dtype_size = sizeof(float16_t); | |||
| } | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| size_t input_size; | |||
| int ci = shape[0]; | |||
| int co = shape[1]; | |||
| std::string input_path = file_path[0]; | |||
| auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| size_t weight_size; | |||
| std::string weight_path = file_path[1]; | |||
| auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(ERROR) << "weight_data load error."; | |||
| return; | |||
| } | |||
| std::vector<int> input_shape = {1, ci}; | |||
| auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), | |||
| input_shape, schema::Format_NC); | |||
| auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NC); | |||
| auto tensor_x = tensor_x_ptr.get(); | |||
| if (tensor_x == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_x create error."; | |||
| @@ -63,7 +52,7 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri | |||
| std::vector<int> w_shape = {co, ci}; | |||
| auto tensor_w_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape); | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape); | |||
| auto tensor_w = tensor_w_ptr.get(); | |||
| if (tensor_w == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_w create error."; | |||
| @@ -72,8 +61,8 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri | |||
| tensor_w->SetData(weight_data); | |||
| std::vector<int> out_shape = {1, co}; | |||
| auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), | |||
| out_shape, schema::Format_NC); | |||
| auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC); | |||
| auto tensor_out = tensor_out_ptr.get(); | |||
| if (tensor_out == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_out create error."; | |||
| @@ -100,12 +89,12 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri | |||
| return; | |||
| } | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||
| memcpy(inputs[0]->Data(), input_data, ci * dtype_size); | |||
| pGraph->Run(); | |||
| if (fp16) { | |||
| CompareOutput(tensor_out, file_path[2], static_cast<float16_t>(1e-3), 2e-2); | |||
| if (enable_fp16) { | |||
| CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float16_t>(1e-3), 2e-2); | |||
| } else { | |||
| CompareOutput(tensor_out, file_path[2], static_cast<float>(1e-5)); | |||
| CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float>(1e-5)); | |||
| } | |||
| tensor_x->SetData(nullptr); | |||
| @@ -114,6 +103,31 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool enable_fp16) { | |||
| size_t input_size; | |||
| std::string input_path = file_path[0]; | |||
| auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| size_t weight_size; | |||
| std::string weight_path = file_path[1]; | |||
| auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(ERROR) << "weight_data load error."; | |||
| return; | |||
| } | |||
| size_t output_size; | |||
| std::string output_path = file_path[2]; | |||
| auto output_data = mindspore::lite::ReadFile(output_path.c_str(), &output_size); | |||
| if (output_data == nullptr) { | |||
| MS_LOG(ERROR) << "output_data load error."; | |||
| return; | |||
| } | |||
| RunTestCaseMatMul(shape, input_data, weight_data, output_data, enable_fp16); | |||
| } | |||
| TEST_F(TestMatMulOpenCL, MatMulFp32) { | |||
| int ci = 1280; | |||
| int co = 1001; | |||
| @@ -133,4 +147,26 @@ TEST_F(TestMatMulOpenCL, MatMulFp16) { | |||
| "./test_data/matmul/matmul_fp16_output.bin"}; | |||
| RunTestCaseMatMul(shape, file_path, true); | |||
| } | |||
| TEST_F(TestMatMulOpenCL, MatMulFp32_2) { | |||
| int ci = 5; | |||
| int co = 3; | |||
| std::vector<int> shape = {ci, co}; | |||
| std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; | |||
| std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, | |||
| 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | |||
| std::vector<float> output_data = {10.f, 10.f, 10.f}; | |||
| RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), false); | |||
| } | |||
| TEST_F(TestMatMulOpenCL, MatMulFp16_2) { | |||
| int ci = 5; | |||
| int co = 3; | |||
| std::vector<int> shape = {ci, co}; | |||
| std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; | |||
| std::vector<float16_t> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, | |||
| 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | |||
| std::vector<float16_t> output_data = {10.f, 10.f, 10.f}; | |||
| RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), true); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h" | |||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||
| namespace mindspore { | |||
| class TestReshapeOpenCL : public mindspore::CommonTest { | |||
| @@ -28,29 +29,27 @@ class TestReshapeOpenCL : public mindspore::CommonTest { | |||
| TestReshapeOpenCL() {} | |||
| }; | |||
| TEST_F(TestReshapeOpenCL, ReshapeFp32) { | |||
| void RunTestCaseReshape(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int c = 63; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/reshape/reshape_fp32_input.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| size_t dtype_size = sizeof(float); | |||
| if (enable_fp16) { | |||
| ocl_runtime->SetFp16Enable(true); | |||
| dtype_size = sizeof(float16_t); | |||
| } | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int c = shape[0]; | |||
| std::vector<int> input_shape = {1, 1, 1, c}; | |||
| auto tensor_x_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC); | |||
| auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC); | |||
| auto tensor_x = tensor_x_ptr.get(); | |||
| if (tensor_x == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_x create error."; | |||
| return; | |||
| } | |||
| std::vector<int> out_shape = {1, c}; | |||
| auto tensor_out_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NC); | |||
| auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC); | |||
| auto tensor_out = tensor_out_ptr.get(); | |||
| if (tensor_out == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_out create error."; | |||
| @@ -76,36 +75,36 @@ TEST_F(TestReshapeOpenCL, ReshapeFp32) { | |||
| return; | |||
| } | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||
| memcpy(inputs[0]->Data(), input_data, c * dtype_size); | |||
| pGraph->Run(); | |||
| size_t output_size; | |||
| std::string output_path = "./test_data/reshape/reshape_fp32_output.bin"; | |||
| auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); | |||
| if (correct_data == nullptr) { | |||
| MS_LOG(ERROR) << "correct_data create error."; | |||
| return; | |||
| } | |||
| printf("==================output data=================\n"); | |||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||
| std::cout << std::endl; | |||
| int size_n = c; | |||
| size_n = size_n > 100 ? 100 : size_n; | |||
| for (int i = 0; i < size_n; i++) { | |||
| std::cout << output_data[i] << " "; | |||
| if ((i + 1) % c == 0) { | |||
| std::cout << std::endl; | |||
| } | |||
| if (enable_fp16) { | |||
| CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float16_t>(1e-3), 2e-2); | |||
| } else { | |||
| CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float>(1e-5)); | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| CompareOutputData(output_data, correct_data, c, 0.00001); | |||
| inputs[0]->SetData(nullptr); | |||
| outputs[0]->SetData(nullptr); | |||
| MS_LOG(INFO) << "Test ReshapeFp32 passed"; | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestReshapeOpenCL, ReshapeFp32) { | |||
| int c = 7; | |||
| std::vector<int> shape = {c}; | |||
| std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; | |||
| std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; | |||
| RunTestCaseReshape(shape, input_data.data(), output_data.data(), false); | |||
| } | |||
| TEST_F(TestReshapeOpenCL, ReshapeFp16) { | |||
| int c = 7; | |||
| std::vector<int> shape = {c}; | |||
| std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; | |||
| std::vector<float16_t> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; | |||
| RunTestCaseReshape(shape, input_data.data(), output_data.data(), true); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -21,6 +21,7 @@ | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h" | |||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||
| namespace mindspore { | |||
| class TestTransposeOpenCL : public mindspore::CommonTest { | |||
| @@ -28,31 +29,29 @@ class TestTransposeOpenCL : public mindspore::CommonTest { | |||
| TestTransposeOpenCL() {} | |||
| }; | |||
| TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||
| void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int h = 64; | |||
| int w = 1; | |||
| int c = 7360; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/transpose/transpose_fp32_input.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| size_t dtype_size = sizeof(float); | |||
| if (enable_fp16) { | |||
| ocl_runtime->SetFp16Enable(true); | |||
| dtype_size = sizeof(float16_t); | |||
| } | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int h = shape[0]; | |||
| int w = shape[1]; | |||
| int c = shape[2]; | |||
| std::vector<int> input_shape = {1, h, w, c}; | |||
| auto tensor_x_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC); | |||
| auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC); | |||
| auto tensor_x = tensor_x_ptr.get(); | |||
| if (tensor_x == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_x create error."; | |||
| return; | |||
| } | |||
| std::vector<int> out_shape = {1, c, h, w}; | |||
| auto tensor_out_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NCHW); | |||
| auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>( | |||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NCHW); | |||
| auto tensor_out = tensor_out_ptr.get(); | |||
| if (tensor_out == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_out create error."; | |||
| @@ -78,9 +77,35 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||
| return; | |||
| } | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||
| memcpy(inputs[0]->Data(), input_data, h * w * c * dtype_size); | |||
| pGraph->Run(); | |||
| if (enable_fp16) { | |||
| CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float16_t>(1e-3), 2e-2); | |||
| } else { | |||
| CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float>(1e-5)); | |||
| } | |||
| inputs[0]->SetData(nullptr); | |||
| outputs[0]->SetData(nullptr); | |||
| MS_LOG(INFO) << "Test TransposeFp32 passed"; | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| } | |||
| TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||
| int h = 64; | |||
| int w = 1; | |||
| int c = 7360; | |||
| std::vector<int> shape = {h, w, c}; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/transpose/transpose_fp32_input.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| size_t output_size; | |||
| std::string output_path = "./test_data/transpose/transpose_fp32_output.bin"; | |||
| auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); | |||
| @@ -88,26 +113,17 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) { | |||
| MS_LOG(ERROR) << "correct_data create error."; | |||
| return; | |||
| } | |||
| printf("==================output data=================\n"); | |||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||
| std::cout << std::endl; | |||
| int size_n = h * w * c; | |||
| size_n = size_n > 100 ? 100 : size_n; | |||
| for (int i = 0; i < size_n; i++) { | |||
| std::cout << output_data[i] << " "; | |||
| if ((i + 1) % c == 0) { | |||
| std::cout << std::endl; | |||
| } | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| CompareOutputData(output_data, correct_data, h * w * c, 0.00001); | |||
| RunTestTranspose(shape, input_data, correct_data, false); | |||
| } | |||
| inputs[0]->SetData(nullptr); | |||
| outputs[0]->SetData(nullptr); | |||
| TEST_F(TestTransposeOpenCL, TransposeFp16) { | |||
| int h = 4; | |||
| int w = 1; | |||
| int c = 3; | |||
| std::vector<int> shape = {h, w, c}; | |||
| std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; | |||
| std::vector<float16_t> output_data = {0.0f, 3.0f, 6.0f, 9.0f, 1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f, 8.0f, 11.0f}; | |||
| MS_LOG(INFO) << "Test TransposeFp32 passed"; | |||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||
| RunTestTranspose(shape, input_data.data(), output_data.data(), true); | |||
| } | |||
| } // namespace mindspore | |||