| @@ -78,6 +78,7 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { | |||
| // init padWeight_(buffer mem) | |||
| padWeight_ = allocator->Malloc(div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size); | |||
| padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true); | |||
| memset(padWeight_, 0x00, div_ci * div_co * C4NUM * C4NUM * kh * kw * data_size); | |||
| auto origin_weight = in_tensors_.at(kWeightIndex)->Data(); | |||
| auto weight_dtype = in_tensors_.at(kWeightIndex)->data_type(); | |||
| int index = 0; | |||
| @@ -90,24 +91,20 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { | |||
| int co_offset = co_i * C4NUM + co4_i; | |||
| int ci_offset = ci_i * C4NUM + ci4_i; | |||
| if (co_offset < co && ci_offset < ci) { | |||
| int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * ci + co_offset; | |||
| int ori_index = ((ci_offset * kh + kh_i) * kw + kw_i) * co + co_offset; | |||
| if (enable_fp16_) { | |||
| if (weight_dtype == kNumberTypeFloat32) { | |||
| reinterpret_cast<uint16_t *>(padWeight_)[index++] = | |||
| Float32ToShort(reinterpret_cast<float *>(origin_weight)[ori_index]); | |||
| } else { | |||
| reinterpret_cast<float16_t *>(padWeight_)[index++] = | |||
| reinterpret_cast<float16_t *>(origin_weight)[ori_index]; | |||
| reinterpret_cast<uint16_t *>(padWeight_)[index++] = | |||
| reinterpret_cast<uint16_t *>(origin_weight)[ori_index]; | |||
| } | |||
| } else { | |||
| reinterpret_cast<float *>(padWeight_)[index++] = reinterpret_cast<float *>(origin_weight)[ori_index]; | |||
| } | |||
| } else { | |||
| if (enable_fp16_) { | |||
| reinterpret_cast<float16_t *>(padWeight_)[index++] = 0.; | |||
| } else { | |||
| reinterpret_cast<float *>(padWeight_)[index++] = 0.; | |||
| } | |||
| index++; | |||
| } | |||
| } | |||
| } | |||
| @@ -128,7 +125,7 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { | |||
| std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype}; | |||
| bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size); | |||
| bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); | |||
| memset(bias_, 0x00, div_co * C4NUM * sizeof(data_size)); | |||
| memset(bias_, 0x00, div_co * C4NUM * data_size); | |||
| auto bias_dtype = in_tensors_[2]->data_type(); | |||
| if (in_tensors_.size() >= 3) { | |||
| if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) { | |||
| @@ -145,7 +142,7 @@ void Conv2dTransposeOpenCLKernel::PadWeight() { | |||
| int Conv2dTransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t im_dst_x, im_dst_y; | |||
| im_dst_x = UP_DIV(out_tensors_[0]->Channel() * out_tensors_[0]->Width(), C4NUM); | |||
| im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||
| im_dst_y = out_tensors_[0]->Height(); | |||
| size_t img_dtype = CL_FLOAT; | |||
| if (enable_fp16_) { | |||
| @@ -168,6 +165,7 @@ int Conv2dTransposeOpenCLKernel::Run() { | |||
| ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_); | |||
| int ci = in_tensors_[0]->Channel(); | |||
| int co = out_tensors_[0]->Channel(); | |||
| int co4 = UP_DIV(co, C4NUM); | |||
| int kh = param->kernel_h_; | |||
| int kw = param->kernel_w_; | |||
| int pad = param->pad_u_; | |||
| @@ -179,7 +177,7 @@ int Conv2dTransposeOpenCLKernel::Run() { | |||
| // local size should less than MAX_GROUP_SIZE | |||
| std::vector<size_t> local = {16, 1, 16}; | |||
| std::vector<size_t> global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]), | |||
| UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND((size_t)co / 4, local[2])}; | |||
| UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND(co4, local[2])}; | |||
| cl_int2 kernel_size = {kh, kw}; | |||
| cl_int2 stride = {2, 2}; | |||
| @@ -79,7 +79,7 @@ void MatMulOpenCLKernel::PadWeight() { | |||
| size_t dtype_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float); | |||
| padWeight_ = allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * C4NUM * C4NUM * dtype_size); | |||
| padWeight_ = allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true); | |||
| memset(padWeight_, 0x00, sizeCI.s[1] * sizeCO.s[1] * C4NUM * C4NUM * dtype_size); | |||
| auto origin_weight = in_tensors_.at(kWeightIndex)->Data(); | |||
| int divCI = sizeCI.s[1]; | |||
| int divCO = sizeCO.s[1]; | |||
| @@ -110,11 +110,7 @@ void MatMulOpenCLKernel::PadWeight() { | |||
| } | |||
| } | |||
| } else { | |||
| if (enable_fp16_) { | |||
| reinterpret_cast<float16_t *>(padWeight_)[index++] = 0; | |||
| } else { | |||
| reinterpret_cast<float *>(padWeight_)[index++] = 0; | |||
| } | |||
| index++; | |||
| } | |||
| } | |||
| } | |||
| @@ -30,10 +30,13 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest { | |||
| TestConv2dTransposeOpenCL() {} | |||
| }; | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) { | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, void *weight_data, void *bias_data, | |||
| void *output_data, bool fp16) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| size_t dtype_size = sizeof(float); | |||
| if (fp16) { | |||
| ocl_runtime->SetFp16Enable(true); | |||
| dtype_size = sizeof(float16_t); | |||
| } | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| @@ -47,30 +50,6 @@ void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector< | |||
| int co = shape[7]; | |||
| int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1; | |||
| int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1; | |||
| size_t input_size; | |||
| std::string input_path = file_path[0]; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| size_t weight_size; | |||
| std::string weight_path = file_path[1]; | |||
| auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size)); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(ERROR) << "weight_data load error."; | |||
| return; | |||
| } | |||
| size_t bias_size; | |||
| std::string bias_path = file_path[2]; | |||
| auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size)); | |||
| if (bias_data == nullptr) { | |||
| MS_LOG(ERROR) << "bias_data load error."; | |||
| return; | |||
| } | |||
| std::vector<int> input_shape = {n, h, w, ci}; | |||
| auto tensor_x_ptr = | |||
| std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); | |||
| @@ -145,18 +124,52 @@ void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector< | |||
| } | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->Data(), input_data, input_size); | |||
| memcpy(inputs[0]->Data(), input_data, n * h * w * ci * dtype_size); | |||
| pGraph->Run(); | |||
| if (fp16) { | |||
| CompareOutput(tensor_out, file_path[3], static_cast<float16_t>(1e-2), 2e-2); | |||
| CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float16_t>(1e-3), 2e-2); | |||
| } else { | |||
| CompareOutput(tensor_out, file_path[3], static_cast<float>(1e-5)); | |||
| CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float>(1e-5)); | |||
| } | |||
| inputs[0]->SetData(nullptr); | |||
| outputs[0]->SetData(nullptr); | |||
| MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed"; | |||
| } | |||
| void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) { | |||
| size_t input_size; | |||
| std::string input_path = file_path[0]; | |||
| auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| size_t weight_size; | |||
| std::string weight_path = file_path[1]; | |||
| auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size); | |||
| if (weight_data == nullptr) { | |||
| MS_LOG(ERROR) << "weight_data load error."; | |||
| return; | |||
| } | |||
| size_t bias_size; | |||
| std::string bias_path = file_path[2]; | |||
| auto bias_data = mindspore::lite::ReadFile(bias_path.c_str(), &bias_size); | |||
| if (bias_data == nullptr) { | |||
| MS_LOG(ERROR) << "bias_data load error."; | |||
| return; | |||
| } | |||
| size_t output_size; | |||
| std::string output_path = file_path[3]; | |||
| auto output_data = mindspore::lite::ReadFile(output_path.c_str(), &output_size); | |||
| if (output_data == nullptr) { | |||
| MS_LOG(ERROR) << "output_data load error."; | |||
| return; | |||
| } | |||
| RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, fp16); | |||
| } | |||
| TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { | |||
| int pad = 0; | |||
| int n = 1; | |||
| @@ -190,4 +203,41 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16) { | |||
| "./test_data/conv2d_transpose/conv2d_transpose_fp16_output.bin"}; | |||
| RunTestCaseConv2dTranspose(shape, file_path, true); | |||
| } | |||
| TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32_2) { | |||
| int pad = 0; | |||
| int n = 1; | |||
| int h = 2; | |||
| int w = 2; | |||
| int kh = 2; | |||
| int kw = 2; | |||
| int ci = 2; | |||
| int co = 1; | |||
| std::vector<int> shape = {pad, n, h, w, kh, kw, ci, co}; | |||
| std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; | |||
| std::vector<float> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; | |||
| std::vector<float> bias_data = {0.5f}; | |||
| std::vector<float> output_data = {5.5f, 6.5f, 17.5f, 22.5f, 7.5f, 8.5f, 27.5f, 32.5f, | |||
| 29.5f, 38.5f, 41.5f, 54.5f, 47.5f, 56.5f, 67.5f, 80.5f}; | |||
| RunTestCaseConv2dTranspose(shape, input_data.data(), weight_data.data(), bias_data.data(), output_data.data(), false); | |||
| } | |||
| TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp16_2) { | |||
| int pad = 0; | |||
| int n = 1; | |||
| int h = 2; | |||
| int w = 2; | |||
| int kh = 2; | |||
| int kw = 2; | |||
| int ci = 2; | |||
| int co = 1; | |||
| std::vector<int> shape = {pad, n, h, w, kh, kw, ci, co}; | |||
| std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; | |||
| std::vector<float16_t> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; | |||
| std::vector<float16_t> bias_data = {0.5f}; | |||
| std::vector<float16_t> output_data = {5.5f, 6.5f, 17.5f, 22.5f, 7.5f, 8.5f, 27.5f, 32.5f, | |||
| 29.5f, 38.5f, 41.5f, 54.5f, 47.5f, 56.5f, 67.5f, 80.5f}; | |||
| RunTestCaseConv2dTranspose(shape, input_data.data(), weight_data.data(), bias_data.data(), output_data.data(), true); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -30,22 +30,21 @@ namespace mindspore { | |||
| void LoadTestData(void *dst, size_t dst_size, const std::string &file_path); | |||
| template <typename T> | |||
| void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path, T atol, float rtol = 1e-5) { | |||
| T *output_data = reinterpret_cast<T *>(output_tensor->Data()); | |||
| size_t output_size = output_tensor->Size(); | |||
| T *expect_data = reinterpret_cast<T *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||
| void CompareOutput(void *output, void *expect, size_t elem_num, T atol, float rtol = 1e-5) { | |||
| T *output_data = reinterpret_cast<T *>(output); | |||
| T *expect_data = reinterpret_cast<T *>(expect); | |||
| printf("output[0:12]:"); | |||
| for (int i = 0; i < 12; i++) { | |||
| for (int i = 0; i < 12 && i < elem_num; i++) { | |||
| printf("[%d]:%.3f ", i, output_data[i]); | |||
| } | |||
| printf("\n"); | |||
| printf("expect[0:12]:"); | |||
| for (int i = 0; i < 12; i++) { | |||
| for (int i = 0; i < 12 && i < elem_num; i++) { | |||
| printf("[%d]:%.3f ", i, expect_data[i]); | |||
| } | |||
| printf("\n"); | |||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||
| for (int i = 0; i < elem_num; ++i) { | |||
| if (std::fabs(output_data[i] - expect_data[i]) > atol + rtol * std::fabs(expect_data[i])) { | |||
| printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); | |||
| return; | |||
| @@ -54,6 +53,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_ | |||
| printf("compare success!\n"); | |||
| } | |||
| template <typename T> | |||
| void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path, T atol, float rtol = 1e-5) { | |||
| size_t output_size; | |||
| auto expect_data = mindspore::lite::ReadFile(file_path.c_str(), &output_size); | |||
| CompareOutput(output_tensor->Data(), expect_data, output_tensor->ElementsNum(), atol, rtol); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ | |||