From 531b3d52523036f982baac969397ec181f788b8e Mon Sep 17 00:00:00 2001 From: Pengyongrong Date: Tue, 15 Sep 2020 02:02:45 -0700 Subject: [PATCH] add new ops name cast --- .../src/runtime/kernel/opencl/kernel/cast.cc | 6 +- .../src/runtime/kernel/opencl/kernel/cast.h | 4 +- mindspore/lite/test/run_test.sh | 6 + .../runtime/kernel/opencl/activation_tests.cc | 14 +- .../kernel/opencl/arithmetic_self_tests.cc | 108 +++++++++++++- .../runtime/kernel/opencl/batchnorm_tests.cc | 138 +++++++++++++++++- .../src/runtime/kernel/opencl/cast_tests.cc | 13 +- .../src/runtime/kernel/opencl/concat_tests.cc | 121 ++++++++++++++- .../kernel/opencl/depthwise_conv2d_tests.cc | 1 - .../src/runtime/kernel/opencl/slice_tests.cc | 135 +++++++++++++++-- 10 files changed, 504 insertions(+), 42 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc index 0fe854e339..2e85a596a5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include "src/kernel_registry.h" #include "src/runtime/kernel/opencl/kernel/cast.h" #include "src/runtime/kernel/opencl/utils.h" @@ -49,14 +49,16 @@ int CastOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { return RET_OK; } -void CastOpenCLKernel::GetKernelName(std::string *kernel_name, CastParameter *param) { +int CastOpenCLKernel::GetKernelName(std::string *kernel_name, CastParameter *param) { if (param->src_type_ == kNumberTypeFloat32 && param->dst_type_ == kNumberTypeFloat16) { kernel_name[0] += "_Fp32ToFp16"; } else if (param->src_type_ == kNumberTypeFloat16 && param->dst_type_ == kNumberTypeFloat32) { kernel_name[0] += "_Fp16ToFp32"; } else { MS_LOG(ERROR) << "unsupported convert format from : " << param->src_type_ << "to " << param->dst_type_; + return RET_ERROR; } + return RET_OK; } int CastOpenCLKernel::Init() { diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h index 1542552522..28e73c2848 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CAST_H_ #include -#include +#include #include "ir/anf.h" #include "src/runtime/kernel/opencl/opencl_kernel.h" #include "nnacl/fp32/cast.h" @@ -39,7 +39,7 @@ class CastOpenCLKernel : public OpenCLKernel { int Run() override; - void GetKernelName(std::string *kernel_name, CastParameter *param); + int GetKernelName(std::string *kernel_name, CastParameter *param); int GetImageSize(size_t idx, std::vector *img_size) override; diff --git a/mindspore/lite/test/run_test.sh b/mindspore/lite/test/run_test.sh index 7dd8e10497..ea99abf4fc 100755 --- a/mindspore/lite/test/run_test.sh +++ b/mindspore/lite/test/run_test.sh @@ -31,3 +31,9 @@ cp -fr $TEST_DATA_DIR/testPK ./data # for GPU OpenCL ./lite-test --gtest_filter="TestConvolutionOpenCL.simple_test*" + +./lite-test --gtest_filter="TestArithmeticSelfOpenCLCI.ArithmeticSelfRound*" +./lite-test --gtest_filter="TestConcatOpenCLCI.ConcatFp32_2inputforCI*" +./lite-test --gtest_filter="TestSliceOpenCLfp32.Slicefp32CI*" +./lite-test --gtest_filter="TestBatchnormOpenCLCI.Batchnormfp32CI*" + diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc index f45f0f8f2d..128dd57c03 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc @@ -519,17 +519,16 @@ TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) { delete param; delete input_tensor; delete output_tensor; - delete sub_graph; lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) { - std::string in_file = "/data/local/tmp/test_data/in_tanh.bin"; - std::string out_file = "/data/local/tmp/test_data/out_tanh.bin"; + std::string in_file = "/data/local/tmp/test_data/in_tanhfp16.bin"; + std::string out_file = "/data/local/tmp/test_data/out_tanhfp16.bin"; MS_LOG(INFO) << "Tanh Begin test!"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); - auto data_type = kNumberTypeFloat32; + auto data_type = kNumberTypeFloat16; ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); bool enable_fp16 = ocl_runtime->GetFp16Enable(); @@ -561,7 +560,7 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) { printf_tensor("Tanh:FP32--input data--", inputs[0]); } - auto *param = new (std::nothrow) ActivationParameter(); + auto param = reinterpret_cast(malloc(sizeof(ActivationParameter))); if (param == nullptr) { MS_LOG(ERROR) << "New ActivationParameter fail."; delete input_tensor; @@ -628,10 +627,11 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) { printf_tensor("Tanh:FP32--output data---", outputs[0]); CompareRes(output_tensor, out_file); } - delete param; + lite::opencl::OpenCLRuntime::DeleteInstance(); + input_tensor->SetData(nullptr); delete input_tensor; + output_tensor->SetData(nullptr); delete output_tensor; delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc index 6fee1ca04f..21778846f4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc @@ -28,9 +28,14 @@ class TestArithmeticSelfOpenCLfp16 : public mindspore::CommonTest { TestArithmeticSelfOpenCLfp16() {} }; +class TestArithmeticSelfOpenCLCI : public mindspore::CommonTest { + public: + TestArithmeticSelfOpenCLCI() {} +}; + template void CompareOutputData1(T *input_data1, T *output_data, T *correct_data, int size, float err_bound) { - for (size_t i = 0; i < 100; i++) { + for (size_t i = 0; i < size; i++) { T abs = fabs(output_data[i] - correct_data[i]); ASSERT_LE(abs, err_bound); } @@ -53,7 +58,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { MS_LOG(INFO) << " init tensors "; - std::vector shape = {1, 19, 19, 96}; + std::vector shape = {1, 2, 2, 144}; auto data_type = kNumberTypeFloat16; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); @@ -66,7 +71,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { std::vector outputs{output_tensor}; MS_LOG(INFO) << " initialize param "; - auto param = new (std::nothrow) ArithmeticSelfParameter(); + auto param = reinterpret_cast(malloc(sizeof(ArithmeticSelfParameter))); if (param == nullptr) { MS_LOG(INFO) << " new ConcatParameter failed "; for (auto tensor : inputs) { @@ -77,7 +82,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { } return; } - param->op_parameter_.type_ = schema::PrimitiveType_Neg; + param->op_parameter_.type_ = schema::PrimitiveType_Round; auto *arithmeticself_kernel = new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast(param), inputs, outputs); if (arithmeticself_kernel == nullptr) { @@ -120,13 +125,106 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); + lite::opencl::OpenCLRuntime::DeleteInstance(); + for (auto tensor : inputs) { + tensor->SetData(nullptr); + delete tensor; + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + delete tensor; + } + delete sub_graph; +} + +TEST_F(TestArithmeticSelfOpenCLCI, ArithmeticSelfRound) { + MS_LOG(INFO) << " begin test "; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + float input_data1[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f, + 0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f}; + float correctOutput[] = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f}; + + MS_LOG(INFO) << " init tensors "; + std::vector shape = {1, 1, 4, 4}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); + auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); + auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); + if (input_tensor == nullptr || output_tensor == nullptr) { + MS_LOG(INFO) << " new input_tensor or output_tensor failed "; + return; + } + std::vector inputs{input_tensor}; + std::vector outputs{output_tensor}; + + MS_LOG(INFO) << " initialize param "; + auto param = reinterpret_cast(malloc(sizeof(ArithmeticSelfParameter))); + if (param == nullptr) { + MS_LOG(INFO) << " new ConcatParameter failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + return; + } + param->op_parameter_.type_ = schema::PrimitiveType_Round; + auto *arithmeticself_kernel = + new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (arithmeticself_kernel == nullptr) { + MS_LOG(INFO) << " new kernel::ArithmeticSelfOpenCLKernel failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + return; + } + arithmeticself_kernel->SetFormatType(schema::Format_NC4HW4); + arithmeticself_kernel->Init(); + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{arithmeticself_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete arithmeticself_kernel; + return; + } + sub_graph->Init(); + MS_LOG(INFO) << " initialize input data "; + std::cout << sizeof(input_data1) / sizeof(float) << std::endl; + memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); + CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } - delete param; delete sub_graph; } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc index 87242397d2..6a89badaf1 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc @@ -31,6 +31,128 @@ class TestBatchnormOpenCLfp16 : public mindspore::CommonTest { public: TestBatchnormOpenCLfp16() {} }; +class TestBatchnormOpenCLCI : public mindspore::CommonTest { + public: + TestBatchnormOpenCLCI() {} +}; + +TEST_F(TestBatchnormOpenCLCI, Batchnormfp32CI) { + MS_LOG(INFO) << " begin test "; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << " Read tensors from .bin "; + std::vector input_shape = {1, 2, 2, 8}; + std::vector output_shape = {1, 2, 2, 8}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); + + float input_data[] = {2.471454, -2.1379554, -0.0904604, 1.2928944, -0.19215967, -0.8677279, -0.12759617, + 1.2242758, -0.06398406, -0.4041858, 0.20352598, -2.067808, 0.52113044, -1.567617, + 0.28003863, 0.41367245, 0.77298605, 0.29908583, 1.4015813, 1.330567, 1.760135, + 0.6320845, 0.6995399, -1.208123, -1.9738104, -1.3283046, 1.022744, 0.02741058, + 0.84505165, -0.89434445, 1.983211, -0.5485428}; + float correct_data[] = {0.7505676, 0.515882, 0.26147857, 1.6026789, 0.47575232, 0.50116986, 0.33589783, + 1.4884706, 0.56019205, 0.7832671, 0.53893626, -0.5093127, 0.71395767, 0.18509413, + 0.33990562, 0.891792, 0.6230367, 0.89172685, 1.6696336, 1.6263539, 1.1277269, + 1.1784974, 0.34403008, -0.3019984, 0.4167911, 0.6407478, 1.3120956, 0.80740136, + 0.8221321, 0.4891496, 0.3566509, 0.18351318}; + float mean_data[] = {0.3016613, -0.89284, 0.63434774, 0.145766, 0.73353934, -0.6744012, 0.7087985, -0.02967937}; + float var_data[] = {2.5604038, 0.84985304, 0.36261332, 1.9083935, 0.4920925, 0.6476224, 0.6269014, 0.8567283}; + float scale_data[] = {0.1201471, 0.142174, 0.5683258, 0.86815494, 0.23426804, 0.3634345, 0.0077846, 0.6813278}; + float offset_data[] = {0.58764684, 0.70790595, 0.945536, 0.8817803, 0.78489226, 0.5884778, 0.3441211, 0.5654443}; + + MS_LOG(INFO) << " construct tensors "; + lite::Tensor *tensor_data = new (std::nothrow) lite::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); + lite::Tensor *tensor_mean = + new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::Tensor *tensor_var = + new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::Tensor *tensor_scale = + new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::Tensor *tensor_offset = + new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + if (tensor_data == nullptr || tensor_mean == nullptr || tensor_var == nullptr || tensor_scale == nullptr || + tensor_offset == nullptr) { + MS_LOG(INFO) << " init tensor failed "; + return; + } + auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); + if (output_tensor == nullptr) { + MS_LOG(INFO) << " init tensor failed "; + delete tensor_data; + delete tensor_mean; + delete tensor_var; + delete tensor_scale; + delete tensor_offset; + return; + } + std::vector inputs = {tensor_data, tensor_scale, tensor_offset, tensor_mean, tensor_var}; + std::vector outputs{output_tensor}; + + MS_LOG(INFO) << " initialize tensors "; + auto param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); + if (param == nullptr) { + MS_LOG(INFO) << " new BatchNormParameter failed "; + for (auto tensor : outputs) { + delete tensor; + } + return; + } + param->epsilon_ = pow(10, -5); + auto *batchnorm_kernel = + new (std::nothrow) kernel::BatchNormOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (batchnorm_kernel == nullptr) { + MS_LOG(INFO) << " new kernel::BatchNorm_kernel failed "; + for (auto tensor : outputs) { + delete tensor; + } + delete param; + return; + } + batchnorm_kernel->Init(); + + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{batchnorm_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete batchnorm_kernel; + return; + } + sub_graph->Init(); + MS_LOG(INFO) << " init tensors "; + memcpy(inputs[0]->data_c(), input_data, sizeof(input_data)); + memcpy(inputs[1]->data_c(), scale_data, sizeof(scale_data)); + memcpy(inputs[2]->data_c(), offset_data, sizeof(offset_data)); + memcpy(inputs[3]->data_c(), mean_data, sizeof(mean_data)); + memcpy(inputs[4]->data_c(), var_data, sizeof(var_data)); + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + + auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); + CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + lite::opencl::OpenCLRuntime::DeleteInstance(); + for (auto tensor : inputs) { + tensor->SetData(nullptr); + delete tensor; + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + delete tensor; + } + delete sub_graph; +} TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { MS_LOG(INFO) << "begin test"; @@ -42,7 +164,7 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { MS_LOG(INFO) << " Read tensors from .bin "; std::vector input_shape = {1, 256, 256, 48}; std::vector output_shape = {1, 256, 256, 48}; - auto data_type = kNumberTypeFloat32; + auto data_type = kNumberTypeFloat16; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); // get the input from .bin @@ -90,7 +212,7 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { std::vector outputs{output_tensor}; MS_LOG(INFO) << " initialize tensors "; - auto param = new (std::nothrow) BatchNormParameter(); + auto param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); if (param == nullptr) { MS_LOG(INFO) << " new BatchNormParameter failed "; for (auto tensor : outputs) { @@ -140,15 +262,18 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.01); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } - delete param; delete sub_graph; } + TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { MS_LOG(INFO) << " begin test "; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); @@ -206,7 +331,7 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { std::vector outputs{output_tensor}; MS_LOG(INFO) << " initialize tensors "; - auto param = new (std::nothrow) BatchNormParameter(); + auto param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); if (param == nullptr) { MS_LOG(INFO) << " new BatchNormParameter failed "; for (auto tensor : outputs) { @@ -256,14 +381,15 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } - delete param; - delete batchnorm_kernel; delete sub_graph; } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc index f3a70bb5e4..b74f735653 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc @@ -48,7 +48,7 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { std::string correctOutputPath = "./test_data/out_castfp16.bin"; MS_LOG(INFO) << " initialize param "; - auto param = new (std::nothrow) CastParameter(); + auto param = reinterpret_cast(malloc(sizeof(CastParameter))); if (param == nullptr) { MS_LOG(INFO) << " new CastParameter failed "; return; @@ -113,14 +113,16 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { @@ -135,7 +137,7 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { std::string correctOutputPath = "./test_data/out_castfp32.bin"; MS_LOG(INFO) << " initialize param "; - auto param = new (std::nothrow) CastParameter(); + auto param = reinterpret_cast(malloc(sizeof(CastParameter))); if (param == nullptr) { MS_LOG(INFO) << " new CastParameter failed "; return; @@ -199,14 +201,15 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); - + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index 45fd1d3b3d..c1ca13634b 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -32,6 +32,11 @@ class TestConcatOpenCLfp16 : public mindspore::CommonTest { TestConcatOpenCLfp16() {} }; +class TestConcatOpenCLCI : public mindspore::CommonTest { + public: + TestConcatOpenCLCI() {} +}; + template void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) { for (size_t i = 0; i < size; i++) { @@ -39,7 +44,109 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou ASSERT_LE(abs, err_bound); } } -TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { + +TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) { + MS_LOG(INFO) << " begin test "; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << " init tensors "; + constexpr int INPUT_NUM = 2; + std::array, INPUT_NUM> input_shapes = {std::vector{1, 1, 1, 8}, std::vector{1, 1, 1, 8}}; + std::vector output_shape = {2, 1, 1, 8}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); + float input_data1[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f}; + float input_data2[] = {0.5f, 0.6f, 0.74f, 0.23f, 0.46f, 0.69f, 0.13f, 0.47f}; + float correctOutput[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f, + 0.5f, 0.6f, 0.74f, 0.23f, 0.46f, 0.69f, 0.13f, 0.47f}; + auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); + if (output_tensor == nullptr) { + MS_LOG(INFO) << " new output_tensor failed "; + return; + } + std::vector inputs; + std::vector outputs{output_tensor}; + for (auto &shape : input_shapes) { + auto input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type); + inputs.push_back(input_temp); + if (input_temp == nullptr) { + MS_LOG(INFO) << " new input_tensor failed "; + return; + } + } + + MS_LOG(INFO) << " initialize tensors "; + auto param = reinterpret_cast(malloc(sizeof(ConcatParameter))); + if (param == nullptr) { + MS_LOG(INFO) << " new ConcatParameter failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + return; + } + param->axis_ = 0; + auto *concat_kernel = + new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (concat_kernel == nullptr) { + MS_LOG(INFO) << " new kernel::ConcatOpenCLKernel failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + return; + } + concat_kernel->SetFormatType(schema::Format_NC4HW4); + concat_kernel->Init(); + // to do allocate memory for inputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{concat_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete concat_kernel; + return; + } + sub_graph->Init(); + MS_LOG(INFO) << " initialize input data "; + memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); + memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); + CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001); + lite::opencl::OpenCLRuntime::DeleteInstance(); + for (auto tensor : inputs) { + tensor->SetData(nullptr); + delete tensor; + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + delete tensor; + } + delete sub_graph; +} + +TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) { MS_LOG(INFO) << " begin test "; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->SetFp16Enable(true); @@ -89,7 +196,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { MS_LOG(INFO) << " input_shapes size =: " << input_shapes.size(); MS_LOG(INFO) << " initialize tensors "; - auto param = new (std::nothrow) ConcatParameter(); + auto param = reinterpret_cast(malloc(sizeof(ConcatParameter))); if (param == nullptr) { MS_LOG(INFO) << " new ConcatParameter failed "; for (auto tensor : inputs) { @@ -157,13 +264,15 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } - delete param; delete sub_graph; } @@ -212,7 +321,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { MS_LOG(INFO) << " input_shapes size=: " << input_shapes.size(); MS_LOG(INFO) << " initialize tensors "; - auto param = new (std::nothrow) ConcatParameter(); + auto param = reinterpret_cast(malloc(sizeof(ConcatParameter))); if (param == nullptr) { MS_LOG(INFO) << " new ConcatParameter failed "; for (auto tensor : inputs) { @@ -276,13 +385,15 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } - delete param; delete sub_graph; } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc index 6ed1b61af0..e1633936a7 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -451,7 +451,6 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp16) { float16_t gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686, 2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988}; - lite::opencl::OpenCLRuntime::GetInstance()->SetFp16Enable(true); DepthWiseTestMain(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4, kNumberTypeFloat16, true, 1e-2); } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc index ce895f4c0c..5b6290e3f3 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc @@ -40,6 +40,117 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou } } +TEST_F(TestSliceOpenCLfp32, Slicefp32CI) { + MS_LOG(INFO) << " begin test "; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << " Read tensors from .bin "; + std::vector input_shape = {1, 2, 2, 8}; + std::vector output_shape = {1, 2, 2, 5}; + std::vector begin = {0, 0, 0, 2}; + std::vector size = {1, 2, 2, 5}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); + + float input_data[] = {-0.45816937, 0.92391545, -0.9135602, -1.4002057, 1.1080881, 0.40712625, -0.28128958, + 0.09470133, 0.19801073, 0.04927751, -1.2808367, 0.1470597, 0.03393711, -0.33282498, + -1.0433807, -1.3678077, -0.6423931, 0.5584889, 0.28965706, 0.5343769, 0.75480366, + -1.9328151, -0.48714373, 1.711132, -1.8871949, -0.2987629, -0.14000037, -0.080552, + 0.95056856, -0.06886655, 0.5316237, 0.05787678}; + float correct_data[] = {-0.9135602, -1.4002057, 1.1080881, 0.40712625, -0.28128958, -1.2808367, 0.1470597, + 0.03393711, -0.33282498, -1.0433807, 0.28965706, 0.5343769, 0.75480366, -1.9328151, + -0.48714373, -0.14000037, -0.080552, 0.95056856, -0.06886655, 0.5316237}; + MS_LOG(INFO) << " construct tensors "; + lite::Tensor *tensor_data = new (std::nothrow) lite::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); + if (tensor_data == nullptr) { + MS_LOG(INFO) << " init tensor failed "; + return; + } + auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); + if (output_tensor == nullptr) { + delete tensor_data; + MS_LOG(INFO) << " init tensor failed "; + return; + } + std::vector inputs = {tensor_data}; + std::vector outputs = {output_tensor}; + + MS_LOG(INFO) << "setting SliceParameter "; + auto param = reinterpret_cast(malloc(sizeof(SliceParameter))); + if (param == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + MS_LOG(INFO) << "new SliceParameter failed "; + return; + } + for (int i = 0; i < input_shape.size(); i++) { + param->begin_[i] = begin[i]; + param->size_[i] = size[i]; + } + + auto *slice_kernel = + new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (slice_kernel == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + MS_LOG(INFO) << "new kernel::slice_kernel failed "; + return; + } + slice_kernel->Init(); + + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{slice_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete slice_kernel; + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + return; + } + sub_graph->Init(); + + MS_LOG(INFO) << " init tensors "; + memcpy(inputs[0]->data_c(), input_data, sizeof(input_data)); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + + auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); + CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + lite::opencl::OpenCLRuntime::DeleteInstance(); + for (auto tensor : inputs) { + tensor->SetData(nullptr); + delete tensor; + } + for (auto tensor : outputs) { + tensor->SetData(nullptr); + delete tensor; + } + delete sub_graph; +} + TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { MS_LOG(INFO) << " begin test "; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); @@ -49,7 +160,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { MS_LOG(INFO) << " Read tensors from .bin "; std::vector input_shape = {1, 19, 19, 96}; std::vector output_shape = {1, 10, 10, 13}; - std::vector begin = {0, 2, 3, 3}; + std::vector begin = {0, 2, 3, 4}; std::vector size = {1, 10, 10, 13}; auto data_type = kNumberTypeFloat32; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); @@ -76,7 +187,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { std::vector outputs = {output_tensor}; MS_LOG(INFO) << "setting SliceParameter "; - auto param = new (std::nothrow) SliceParameter(); + auto param = reinterpret_cast(malloc(sizeof(SliceParameter))); if (param == nullptr) { for (auto tensor : inputs) { delete tensor; @@ -137,14 +248,18 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } delete sub_graph; } + TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { MS_LOG(INFO) << " begin test "; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); @@ -153,17 +268,17 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { auto allocator = ocl_runtime->GetAllocator(); MS_LOG(INFO) << " Read tensors from .bin "; - std::vector input_shape = {1, 256, 256, 48}; - std::vector output_shape = {1, 255, 255, 15}; + std::vector input_shape = {1, 25, 25, 48}; + std::vector output_shape = {1, 24, 24, 15}; std::vector begin = {0, 1, 1, 7}; - std::vector size = {1, 255, 255, 15}; + std::vector size = {1, 24, 24, 15}; auto data_type = kNumberTypeFloat16; auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode); // get the input from .bin size_t input_size, output_size; - std::string input_path = "./test_data/in_data.bin"; - std::string output_path = "./test_data/out_data.bin"; + std::string input_path = "./test_data/in_slicefp16.bin"; + std::string output_path = "./test_data/out_slicefp16.bin"; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); @@ -183,7 +298,7 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { std::vector outputs = {output_tensor}; MS_LOG(INFO) << " setting SliceParameter "; - auto param = new (std::nothrow) SliceParameter(); + auto param = reinterpret_cast(malloc(sizeof(SliceParameter))); if (param == nullptr) { for (auto tensor : inputs) { delete tensor; @@ -241,13 +356,15 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { std::cout << "==================output data================" << std::endl; sub_graph->Run(); - auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { + tensor->SetData(nullptr); delete tensor; } for (auto tensor : outputs) { + tensor->SetData(nullptr); delete tensor; } delete sub_graph;