@@ -1,3 +1,4 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias,
                     __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) {
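Note: FLT16 is not defined in the .cl source itself; it is presumably injected at program build time so the same kernel text can compile as half16 or float16. A minimal host-side sketch of that pattern follows; the exact macro names and build plumbing are assumptions, not taken from this diff.

    // Hypothetical sketch: map the runtime fp16 switch to OpenCL kernel types.
    std::string build_options = enable_fp16_
                                  ? " -DFLT=half -DFLT4=half4 -DFLT16=half16 "
                                  : " -DFLT=float -DFLT4=float4 -DFLT16=float16 ";
    // clBuildProgram(program, 1, &device, build_options.c_str(), nullptr, nullptr);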
@@ -16,7 +16,7 @@
#include <string>
#include <set>
#include "src/common/utils.h"
#include "nnacl/fp32/common_func.h"
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
@@ -73,10 +73,6 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
  int div_co = UP_DIV(co, C4NUM);
  auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
  auto data_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
  using FLT = float;
  if (enable_fp16_) {
    using FLT = float16_t;
  }
  // IHWO to OHWI4(I)4(O)(converter format is IHWO)
  // init padWeight_(buffer mem)
@@ -97,8 +93,8 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
          int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * ci + co_offset;
          if (enable_fp16_) {
            if (weight_dtype == kNumberTypeFloat32) {
              reinterpret_cast<float16_t *>(padWeight_)[index++] =
                lite::Float32ToShort(reinterpret_cast<float *>(origin_weight)[ori_index]);
              reinterpret_cast<uint16_t *>(padWeight_)[index++] =
                Float32ToShort(reinterpret_cast<float *>(origin_weight)[ori_index]);
            } else {
              reinterpret_cast<float16_t *>(padWeight_)[index++] =
                reinterpret_cast<float16_t *>(origin_weight)[ori_index];
@@ -107,7 +103,11 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
              reinterpret_cast<float *>(padWeight_)[index++] = reinterpret_cast<float *>(origin_weight)[ori_index];
            }
          } else {
            reinterpret_cast<FLT *>(padWeight_)[index++] = 0.;
            if (enable_fp16_) {
              reinterpret_cast<float16_t *>(padWeight_)[index++] = 0.;
            } else {
              reinterpret_cast<float *>(padWeight_)[index++] = 0.;
            }
          }
        }
      }
@@ -134,7 +134,7 @@ void Conv2dTransposeOpenCLKernel::PadWeight() {
  if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) {
    auto fdata = reinterpret_cast<float *>(in_tensors_[2]->Data());
    for (int i = 0; i < co; i++) {
      reinterpret_cast<float16_t *>(bias_)[i] = lite::Float32ToShort(fdata[i]);
      reinterpret_cast<uint16_t *>(bias_)[i] = Float32ToShort(fdata[i]);
    }
  } else {
    memcpy(bias_, in_tensors_[2]->Data(), co * data_size);
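Aside: the new fp16 path stores half-precision values through uint16_t writes using nnacl's Float32ToShort. Below is a minimal sketch of what such a float32-to-IEEE-754-half bit conversion looks like, under the assumption that this is roughly what the helper does; the real nnacl implementation may handle rounding, denormals, and NaN payloads differently.

    #include <cstdint>
    #include <cstring>

    // Hypothetical sketch of Float32ToShort-style conversion (truncating).
    uint16_t Float32ToShortSketch(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));            // reinterpret float as bits
      uint16_t sign = (bits >> 16) & 0x8000;           // move sign to half's bit 15
      int32_t exp = ((bits >> 23) & 0xFF) - 127 + 15;  // rebias exponent 8 -> 5 bits
      uint32_t mant = bits & 0x7FFFFF;
      if (exp >= 31) return sign | 0x7C00;             // overflow/Inf/NaN -> Inf (payload dropped)
      if (exp <= 0) return sign;                       // underflow -> signed zero (flush denormals)
      return sign | (exp << 10) | (mant >> 13);        // truncate mantissa to 10 bits
    }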
@@ -16,6 +16,7 @@
#include <set>
#include <string>
#include "nnacl/fp32/common_func.h"
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "nnacl/fp32/matmul.h"
@@ -34,7 +35,7 @@ namespace mindspore::kernel {
int MatMulOpenCLKernel::Init() {
  std::string kernel_name = "MatMul";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  enable_fp16_ = ocl_runtime->GetFp16Enable();
#ifdef PROGRAM_WITH_IL
  kernel_ = ocl_runtime->GetKernelFromBinary(kernel_name);
#else
@@ -74,11 +75,12 @@ int MatMulOpenCLKernel::ReSize() { return RET_OK; }
void MatMulOpenCLKernel::PadWeight() {
  auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator();
  padWeight_ =
    reinterpret_cast<FLOAT_t *>(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * C4NUM * C4NUM * sizeof(FLOAT_t)));
  padWeight_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
  auto origin_weight = reinterpret_cast<FLOAT_t *>(in_tensors_.at(kWeightIndex)->Data());
  size_t dtype_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
  padWeight_ = allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * C4NUM * C4NUM * dtype_size);
  padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true);
  auto origin_weight = in_tensors_.at(kWeightIndex)->Data();
  int divCI = sizeCI.s[1];
  int divCO = sizeCO.s[1];
  int co = sizeCO.s[0];
@@ -90,9 +92,29 @@ void MatMulOpenCLKernel::PadWeight() {
          int src_x = i * C4NUM + l;
          int src_y = j * C4NUM + k;
          if (src_x < sizeCI.s[0] && src_y < sizeCO.s[0]) {
            padWeight_[index++] = origin_weight[src_y * sizeCI.s[0] + src_x];
            if (enable_fp16_) {
              if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat32) {
                reinterpret_cast<uint16_t *>(padWeight_)[index++] =
                  Float32ToShort(reinterpret_cast<float *>(origin_weight)[src_y * sizeCI.s[0] + src_x]);
              } else {
                reinterpret_cast<uint16_t *>(padWeight_)[index++] =
                  reinterpret_cast<uint16_t *>(origin_weight)[src_y * sizeCI.s[0] + src_x];
              }
            } else {
              if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeFloat16) {
                reinterpret_cast<float *>(padWeight_)[index++] =
                  ShortToFloat32(reinterpret_cast<uint16_t *>(origin_weight)[src_y * sizeCI.s[0] + src_x]);
              } else {
                reinterpret_cast<float *>(padWeight_)[index++] =
                  reinterpret_cast<float *>(origin_weight)[src_y * sizeCI.s[0] + src_x];
              }
            }
          } else {
            padWeight_[index++] = 0;
            if (enable_fp16_) {
              reinterpret_cast<float16_t *>(padWeight_)[index++] = 0;
            } else {
              reinterpret_cast<float *>(padWeight_)[index++] = 0;
            }
          }
        }
      }
@@ -102,17 +124,23 @@ void MatMulOpenCLKernel::PadWeight() {
  size_t im_dst_x, im_dst_y;
  im_dst_x = divCO;
  im_dst_y = 1;
#ifdef ENABLE_FP16
  size_t img_dtype = CL_HALF_FLOAT;
#else
  size_t img_dtype = CL_FLOAT;
#endif
  if (enable_fp16_) {
    img_dtype = CL_HALF_FLOAT;
  }
  std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
  bias_ = reinterpret_cast<FLOAT_t *>(allocator->Malloc(im_dst_x * im_dst_y * C4NUM * sizeof(FLOAT_t), img_size));
  bias_ = reinterpret_cast<FLOAT_t *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
  memset(bias_, 0x00, divCO * C4NUM * sizeof(FLOAT_t));
  bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size);
  bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
  memset(bias_, 0x00, divCO * C4NUM * dtype_size);
  if (in_tensors_.size() >= 3) {
    memcpy(bias_, in_tensors_[2]->Data(), co * sizeof(FLOAT_t));
    if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) {
      auto fdata = reinterpret_cast<float *>(in_tensors_[2]->Data());
      for (int i = 0; i < co; i++) {
        reinterpret_cast<uint16_t *>(bias_)[i] = Float32ToShort(fdata[i]);
      }
    } else {
      memcpy(bias_, in_tensors_[2]->Data(), co * dtype_size);
    }
  }
  allocator->UnmapBuffer(bias_);
}
@@ -121,11 +149,10 @@ int MatMulOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
  size_t im_dst_x, im_dst_y;
  im_dst_x = sizeCO.s[1];
  im_dst_y = 1;
#ifdef ENABLE_FP16
  size_t img_dtype = CL_HALF_FLOAT;
#else
  size_t img_dtype = CL_FLOAT;
#endif
  if (enable_fp16_) {
    img_dtype = CL_HALF_FLOAT;
  }
  img_size->clear();
  std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
  *img_size = vec;
@@ -23,7 +23,6 @@
#include "nnacl/conv_parameter.h"
#include "src/runtime/opencl/opencl_runtime.h"
namespace mindspore::kernel {
class MatMulOpenCLKernel : public OpenCLKernel {
@@ -43,9 +42,10 @@ class MatMulOpenCLKernel : public OpenCLKernel {
 private:
  cl::Kernel kernel_;
  FLOAT_t *padWeight_;
  FLOAT_t *bias_;
  bool hasBias_ = false;
  void *padWeight_;
  void *bias_;
  bool hasBias_{false};
  bool enable_fp16_{false};
  cl_int2 sizeCI;
  cl_int2 sizeCO;
};
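Aside: with the element type now chosen at runtime, the header drops the compile-time FLOAT_t pointers and keeps padWeight_/bias_ as void *, casting at each use site as in PadWeight() above. A generic helper in the same spirit could centralize that cast; the sketch below is illustrative only (the diff inlines these casts instead, and the helper name is hypothetical).

    // Hypothetical helper: write one weight value into a type-erased buffer.
    inline void StoreWeight(void *dst, int idx, float value, bool fp16) {
      if (fp16) {
        reinterpret_cast<uint16_t *>(dst)[idx] = Float32ToShort(value);  // fp16 bit pattern
      } else {
        reinterpret_cast<float *>(dst)[idx] = value;
      }
    }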
@@ -22,6 +22,7 @@
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
#include "mindspore/core/utils/log_adapter.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
@@ -29,7 +30,7 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
  TestConv2dTransposeOpenCL() {}
};
void RunTestCase(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  if (fp16) {
    ocl_runtime->SetFp16Enable(true);
@@ -146,32 +147,12 @@ void RunTestCase(const std::vector<int> shape, const std::vector<std::string> fi
  pGraph->Init();
  memcpy(inputs[0]->Data(), input_data, input_size);
  pGraph->Run();
  using FLT = float;
  if (fp16) {
    using FLT = float16_t;
    CompareOutput(tensor_out, file_path[3], static_cast<float16_t>(1e-2), 2e-2);
  } else {
    CompareOutput(tensor_out, file_path[3], static_cast<float>(1e-5));
  }
  std::cout << "==================output data=================" << std::endl;
  FLT *output_data = reinterpret_cast<FLT *>(tensor_out->Data());
  std::cout << std::endl;
  size_t output_size;
  std::string output_path = file_path[3];
  auto correct_data = reinterpret_cast<FLT *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
  if (correct_data == nullptr) {
    MS_LOG(ERROR) << "correct_data create error.";
    return;
  }
  int size_n = oh * ow * co;
  size_n = size_n > 100 ? 100 : size_n;
  for (int i = 0; i < size_n; i++) {
    std::cout << output_data[i] << ", " << correct_data[i] << " ";
    if ((i + 1) % co == 0) {
      std::cout << std::endl;
    }
  }
  std::cout << std::endl;
  // compare
  CommonTest::CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
  inputs[0]->SetData(nullptr);
  outputs[0]->SetData(nullptr);
  MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
@@ -190,7 +171,7 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin"};
  RunTestCase(shape, file_path, false);
  RunTestCaseConv2dTranspose(shape, file_path, false);
}
TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16) {
@@ -207,6 +188,6 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16) {
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_weight.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_bias.bin",
    "./test_data/conv2d_transpose/conv2d_transpose_fp16_output.bin"};
  RunTestCase(shape, file_path, true);
  RunTestCaseConv2dTranspose(shape, file_path, true);
}
}  // namespace mindspore
@@ -21,6 +21,7 @@
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
namespace mindspore {
class TestMatMulOpenCL : public mindspore::CommonTest {
@@ -28,29 +29,32 @@ class TestMatMulOpenCL : public mindspore::CommonTest {
  TestMatMulOpenCL() {}
};
TEST_F(TestMatMulOpenCL, MatMulFp32) {
void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  if (fp16) {
    ocl_runtime->SetFp16Enable(true);
  }
  auto allocator = ocl_runtime->GetAllocator();
  size_t input_size;
  int ci = 1280;
  int co = 1001;
  std::string input_path = "./test_data/matmul/matmul_fp32_input.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  int ci = shape[0];
  int co = shape[1];
  std::string input_path = file_path[0];
  auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
  if (input_data == nullptr) {
    MS_LOG(ERROR) << "input_data load error.";
    return;
  }
  size_t weight_size;
  std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin";
  auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
  std::string weight_path = file_path[1];
  auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size);
  if (weight_data == nullptr) {
    MS_LOG(ERROR) << "weight_data load error.";
    return;
  }
  std::vector<int> input_shape = {1, ci};
  auto tensor_x_ptr =
    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NC);
  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
                                                             input_shape, schema::Format_NC);
  auto tensor_x = tensor_x_ptr.get();
  if (tensor_x == nullptr) {
    MS_LOG(ERROR) << "tensor_x create error.";
@@ -58,7 +62,8 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
  }
  std::vector<int> w_shape = {co, ci};
  auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), w_shape);
  auto tensor_w_ptr =
    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape);
  auto tensor_w = tensor_w_ptr.get();
  if (tensor_w == nullptr) {
    MS_LOG(ERROR) << "tensor_w create error.";
@@ -67,8 +72,8 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
  tensor_w->SetData(weight_data);
  std::vector<int> out_shape = {1, co};
  auto tensor_out_ptr =
    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NC);
  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
                                                               out_shape, schema::Format_NC);
  auto tensor_out = tensor_out_ptr.get();
  if (tensor_out == nullptr) {
    MS_LOG(ERROR) << "tensor_out create error.";
@@ -76,16 +81,16 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
  }
  std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w};
  std::vector<lite::tensor::Tensor *> outputs{tensor_out};
  auto arith_kernel_ptr = std::make_unique<kernel::MatMulOpenCLKernel>(nullptr, inputs, outputs, false);
  auto arith_kernel = arith_kernel_ptr.get();
  if (arith_kernel == nullptr) {
    MS_LOG(ERROR) << "arith_kernel create error.";
  auto op_kernel_ptr = std::make_unique<kernel::MatMulOpenCLKernel>(nullptr, inputs, outputs, false);
  auto op_kernel = op_kernel_ptr.get();
  if (op_kernel == nullptr) {
    MS_LOG(ERROR) << "op_kernel create error.";
    return;
  }
  arith_kernel->Init();
  op_kernel->Init();
  inputs[0]->MallocData(allocator);
  std::vector<kernel::LiteKernel *> kernels{arith_kernel};
  std::vector<kernel::LiteKernel *> kernels{op_kernel};
  std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
@@ -97,24 +102,34 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
  pGraph->Init();
  memcpy(inputs[0]->Data(), input_data, input_size);
  pGraph->Run();
  size_t output_size;
  std::string output_path = "./test_data/matmul/matmul_fp32_output.bin";
  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
  printf("==================output data=================\n");
  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
  std::cout << std::endl;
  int size_n = co;
  size_n = size_n > 100 ? 100 : size_n;
  for (int i = 0; i < size_n; i++) {
    std::cout << output_data[i] << " ";
  if (fp16) {
    CompareOutput(tensor_out, file_path[2], static_cast<float16_t>(1e-3), 2e-2);
  } else {
    CompareOutput(tensor_out, file_path[2], static_cast<float>(1e-5));
  }
  std::cout << std::endl;
  // compare
  CompareOutputData(output_data, correct_data, co, 0.0001);
  tensor_x->SetData(nullptr);
  tensor_out->SetData(nullptr);
  MS_LOG(INFO) << "TestMatMulFp32 passed";
}
TEST_F(TestMatMulOpenCL, MatMulFp32) {
  int ci = 1280;
  int co = 1001;
  std::vector<int> shape = {ci, co};
  std::vector<std::string> file_path = {"./test_data/matmul/matmul_fp32_input.bin",
                                        "./test_data/matmul/matmul_fp32_weight.bin",
                                        "./test_data/matmul/matmul_fp32_output.bin"};
  RunTestCaseMatMul(shape, file_path, false);
}
TEST_F(TestMatMulOpenCL, MatMulFp16) {
  int ci = 1280;
  int co = 1001;
  std::vector<int> shape = {ci, co};
  std::vector<std::string> file_path = {"./test_data/matmul/matmul_fp16_input.bin",
                                        "./test_data/matmul/matmul_fp16_weight.bin",
                                        "./test_data/matmul/matmul_fp16_output.bin"};
  RunTestCaseMatMul(shape, file_path, true);
}
}  // namespace mindspore
@@ -109,7 +109,7 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {
  MS_LOG(INFO) << "compare result";
  std::cout << "compare result" << std::endl;
  CompareOutput(output_tensor, expect_file);
  CompareOutput(output_tensor, expect_file, static_cast<float>(1e-5));
  for (auto tensor : inputs) {
    delete tensor;
  }
@@ -83,7 +83,7 @@ void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, st
  pGraph->Run();
  MS_LOG(INFO) << "compare result";
  CompareOutput(output_tensor, expect_file);
  CompareOutput(output_tensor, expect_file, static_cast<float>(1e-5));
  for (auto tensor : inputs) {
    delete tensor;
  }
@@ -35,34 +35,4 @@ void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) {
  }
}
void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path) {
  float *output_data = reinterpret_cast<float *>(output_tensor->Data());
  size_t output_size = output_tensor->Size();
  float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
  printf("output[0:12]:");
  for (int i = 0; i < 12; i++) {
    printf("[%d]:%.3f ", i, output_data[i]);
  }
  printf("\n");
  printf("expect[0:12]:");
  for (int i = 0; i < 12; i++) {
    printf("[%d]:%.3f ", i, expect_data[i]);
  }
  printf("\n");
  constexpr float atol = 1e-5;
  for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
    if (std::fabs(output_data[i] - expect_data[i]) > atol) {
      printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]);
      printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]);
      printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]);
      return;
    }
  }
  printf("compare success!\n");
  printf("compare success!\n");
  printf("compare success!\n\n\n");
}
}  // namespace mindspore
@@ -29,7 +29,30 @@ namespace mindspore {
void LoadTestData(void *dst, size_t dst_size, const std::string &file_path);
void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path);
template <typename T>
void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path, T atol, float rtol = 1e-5) {
  T *output_data = reinterpret_cast<T *>(output_tensor->Data());
  size_t output_size = output_tensor->Size();
  T *expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
  printf("output[0:12]:");
  for (int i = 0; i < 12; i++) {
    printf("[%d]:%.3f ", i, output_data[i]);
  }
  printf("\n");
  printf("expect[0:12]:");
  for (int i = 0; i < 12; i++) {
    printf("[%d]:%.3f ", i, expect_data[i]);
  }
  printf("\n");
  for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
    if (std::fabs(output_data[i] - expect_data[i]) > atol + rtol * std::fabs(expect_data[i])) {
      printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]);
      return;
    }
  }
  printf("compare success!\n");
}
}  // namespace mindspore
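Note: the templated CompareOutput above uses a numpy.isclose-style mixed tolerance, failing when |output - expect| > atol + rtol * |expect|, which is why fp16 tests can pass a looser atol plus an explicit rtol while fp32 keeps the old 1e-5 absolute bound. A usage sketch (tensor_out, expect_path, and fp16 are hypothetical locals; the tolerances mirror the calls in the updated tests):

    // fp16 results tolerate larger absolute and relative error than fp32.
    if (fp16) {
      CompareOutput(tensor_out, expect_path, static_cast<float16_t>(1e-2), 2e-2);
    } else {
      CompareOutput(tensor_out, expect_path, static_cast<float>(1e-5));
    }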