@@ -54,6 +54,7 @@ endif ()
 if (SUPPORT_GPU)
     add_definitions(-DUSE_OPENCL_WRAPPER)
     add_definitions(-DMS_OPENCL_PROFILE=false)
+    add_definitions(-DCL_HPP_TARGET_OPENCL_VERSION=200)
     add_compile_definitions(SUPPORT_GPU)
     if(OFFLINE_COMPILE)
         add_compile_definitions(PROGRAM_WITH_IL)
@@ -1,6 +1,8 @@
 #define FLT half
 #define FLT4 half4
 #define FLT16 half16
+#define READ_IMAGE read_imageh
+#define WRITE_IMAGE write_imageh
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases,
                                   __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding,
@@ -14,17 +16,17 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16
   int src_w = w / 2;
   src_w = src_w * 2;
   int co = get_global_id(2);
-  if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return;
+  if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return;
   FLT4 r0 = (FLT4)(0.f);
   FLT4 r1 = (FLT4)(0.f);
   FLT4 r2 = (FLT4)(0.f);
   FLT4 r3 = (FLT4)(0.f);
   int base_w = (co * 4 + kh + kw * 2) * src_size.z;
   for (int ci = 0; ci < src_size.z; ++ci) {
-    FLT4 x0 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h));
-    FLT4 x1 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1));
-    FLT4 x2 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h));
-    FLT4 x3 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1));
+    FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h));
+    FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1));
+    FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h));
+    FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1));
     FLT16 weight_cache = weight[base_w++];
     r0 += x0.x * weight_cache.s0123;
     r0 += x0.y * weight_cache.s4567;
@@ -46,14 +48,14 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16
     r3 += x3.z * weight_cache.s89ab;
     r3 += x3.w * weight_cache.scdef;
   }
-  FLT4 bias_val = read_imagef(biases, smp_zero, (int2)(co, 0));
+  FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0));
   r0 += bias_val;
   r1 += bias_val;
   r2 += bias_val;
   r3 += bias_val;
-  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0);
-  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1);
-  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2);
-  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3);
 }
@@ -1,31 +1,32 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #define FLT4 half4
 #define FLT16 half16
-__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, __global FLT4 *buffer, __global FLT4 *bias,
-                     int2 offset_ci, int2 offset_co, int has_bias) {
+#define READ_IMAGE read_imageh
+#define WRITE_IMAGE write_imageh
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias,
+                     __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) {
   int2 gid = (int2)(get_global_id(0), get_global_id(1));
   int2 lid = (int2)(get_local_id(0), get_local_id(1));
-  FLT4 s = (FLT4)(0.0f);
+  FLT4 result = (FLT4)(0.0f);
   bool inside = gid.x < offset_co.y;
   for (uint i = lid.y; i < offset_ci.y && inside; i += 4) {
-    FLT4 v = x[i];
+    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0));
     FLT16 w = weight[gid.x + i * offset_co.y];
-    s.x += dot(v, w.s0123);
-    s.y += dot(v, w.s4567);
-    s.z += dot(v, w.s89ab);
-    s.w += dot(v, w.scdef);
+    result.x += dot(v, w.s0123);
+    result.y += dot(v, w.s4567);
+    result.z += dot(v, w.s89ab);
+    result.w += dot(v, w.scdef);
   }
   __local FLT4 temp[64][4];
-  temp[lid.x][lid.y] = s;
+  temp[lid.x][lid.y] = result;
   barrier(CLK_LOCAL_MEM_FENCE);
   if (lid.y == 0 && inside) {
-    s += temp[lid.x][1];
-    s += temp[lid.x][2];
-    s += temp[lid.x][3];
+    result += temp[lid.x][1];
+    result += temp[lid.x][2];
+    result += temp[lid.x][3];
     if (has_bias != 0) {
-      s += bias[gid.x];
+      result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0));
     }
-    buffer[gid.x] = s;
+    // memory pollution? or protected by opencl
+    WRITE_IMAGE(output, (int2)(gid.x, 0), result);
   }
 }
@@ -1,6 +1,8 @@
 #define FLT float
 #define FLT4 float4
 #define FLT16 float16
+#define READ_IMAGE read_imagef
+#define WRITE_IMAGE write_imagef
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16 *weight, __read_only image2d_t biases,
                                   __write_only image2d_t dst_data, int2 kernel_size, int2 stride, int2 padding,
@@ -14,17 +16,17 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16
   int src_w = w / 2;
   src_w = src_w * 2;
   int co = get_global_id(2);
-  if (h * 2 >= dst_size.x || w * 2 >= dst_size.y || co >= dst_size.z) return;
+  if (src_h * 2 >= dst_size.x || src_w * 2 >= dst_size.y || co >= dst_size.z) return;
   FLT4 r0 = (FLT4)(0.f);
   FLT4 r1 = (FLT4)(0.f);
   FLT4 r2 = (FLT4)(0.f);
   FLT4 r3 = (FLT4)(0.f);
   int base_w = (co * 4 + kh + kw * 2) * src_size.z;
   for (int ci = 0; ci < src_size.z; ++ci) {
-    FLT4 x0 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h));
-    FLT4 x1 = read_imagef(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1));
-    FLT4 x2 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h));
-    FLT4 x3 = read_imagef(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1));
+    FLT4 x0 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h));
+    FLT4 x1 = READ_IMAGE(src_data, smp_zero, (int2)(src_w * src_size.z + ci, src_h + 1));
+    FLT4 x2 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h));
+    FLT4 x3 = READ_IMAGE(src_data, smp_zero, (int2)((src_w + 1) * src_size.z + ci, src_h + 1));
     FLT16 weight_cache = weight[base_w++];
     r0 += x0.x * weight_cache.s0123;
     r0 += x0.y * weight_cache.s4567;
@@ -46,14 +48,14 @@ __kernel void conv2d_transpose2x2(__read_only image2d_t src_data, __global FLT16
     r3 += x3.z * weight_cache.s89ab;
     r3 += x3.w * weight_cache.scdef;
   }
-  FLT4 bias_val = read_imagef(biases, smp_zero, (int2)(co, 0));
+  FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(co, 0));
   r0 += bias_val;
   r1 += bias_val;
   r2 += bias_val;
   r3 += bias_val;
-  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0);
-  write_imagef(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1);
-  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2);
-  write_imagef(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh), r0);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw) * dst_size.z + co, 2 * src_h + kh + 2), r1);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh), r2);
+  WRITE_IMAGE(dst_data, (int2)((2 * src_w + kw + 2) * dst_size.z + co, 2 * src_h + kh + 2), r3);
 }
@@ -1,30 +1,32 @@
 #define FLT4 float4
 #define FLT16 float16
-__kernel void MatMul(__global FLT4 *x, __global FLT16 *weight, __global FLT4 *buffer, __global FLT4 *bias,
-                     int2 offset_ci, int2 offset_co, int has_bias) {
+#define READ_IMAGE read_imagef
+#define WRITE_IMAGE write_imagef
+__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+__kernel void MatMul(__read_only image2d_t input, __global FLT16 *weight, __read_only image2d_t bias,
+                     __write_only image2d_t output, int2 offset_ci, int2 offset_co, int has_bias) {
   int2 gid = (int2)(get_global_id(0), get_global_id(1));
   int2 lid = (int2)(get_local_id(0), get_local_id(1));
-  FLT4 s = (FLT4)(0.0f);
+  FLT4 result = (FLT4)(0.0f);
   bool inside = gid.x < offset_co.y;
   for (uint i = lid.y; i < offset_ci.y && inside; i += 4) {
-    FLT4 v = x[i];
+    FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, 0));
     FLT16 w = weight[gid.x + i * offset_co.y];
-    s.x += dot(v, w.s0123);
-    s.y += dot(v, w.s4567);
-    s.z += dot(v, w.s89ab);
-    s.w += dot(v, w.scdef);
+    result.x += dot(v, w.s0123);
+    result.y += dot(v, w.s4567);
+    result.z += dot(v, w.s89ab);
+    result.w += dot(v, w.scdef);
   }
   __local FLT4 temp[64][4];
-  temp[lid.x][lid.y] = s;
+  temp[lid.x][lid.y] = result;
   barrier(CLK_LOCAL_MEM_FENCE);
   if (lid.y == 0 && inside) {
-    s += temp[lid.x][1];
-    s += temp[lid.x][2];
-    s += temp[lid.x][3];
+    result += temp[lid.x][1];
+    result += temp[lid.x][2];
+    result += temp[lid.x][3];
     if (has_bias != 0) {
-      s += bias[gid.x];
+      result += READ_IMAGE(bias, smp_zero, (int2)(gid.x, 0));
     }
-    buffer[gid.x] = s;
+    // memory pollution? or protected by opencl
+    WRITE_IMAGE(output, (int2)(gid.x, 0), result);
  }
 }
@@ -144,8 +144,8 @@ int Conv2dTransposeOpenCLKernel::Run() {
                       &out_error_code);
   // local size should less than MAX_GROUP_SIZE
   std::vector<size_t> local = {16, 1, 16};
-  std::vector<size_t> global = {UP_ROUND((size_t)oh / 2, local[0]), UP_ROUND((size_t)ow / 2, local[1]),
-                                UP_ROUND((size_t)co / 4, local[2])};
+  std::vector<size_t> global = {UP_ROUND((size_t)UP_ROUND(oh / 2, 2), local[0]),
+                                UP_ROUND((size_t)UP_ROUND(ow / 2, 2), local[1]), UP_ROUND((size_t)co / 4, local[2])};
   cl_int2 kernel_size = {kh, kw};
   cl_int2 stride = {2, 2};
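Note on the global-size change above: the extra inner UP_ROUND keeps the first two dimensions even before aligning them to the work-group size. A minimal sketch of the computation, assuming UP_ROUND(x, y) is the usual round-up-to-a-multiple-of-y macro used elsewhere in the code base (the UpRound helper name below is only illustrative):

#include <cstddef>
#include <vector>

// Assumed semantics of UP_ROUND: round x up to the nearest multiple of y.
static size_t UpRound(size_t x, size_t y) { return (x + y - 1) / y * y; }

// Mirrors the new global-size computation: oh / 2 and ow / 2 are first rounded
// up to an even value, then aligned to the work-group size {16, 1, 16}.
std::vector<size_t> GlobalSize(size_t oh, size_t ow, size_t co, const std::vector<size_t> &local) {
  return {UpRound(UpRound(oh / 2, 2), local[0]),
          UpRound(UpRound(ow / 2, 2), local[1]),
          UpRound(co / 4, local[2])};
}

With the shapes used in the new test below (oh = ow = 480, co = 128) this evaluates to {240, 240, 32}, so the inner rounding only changes the result when oh / 2 or ow / 2 is odd.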
@@ -50,22 +50,24 @@ int MatMulOpenCLKernel::Init() {
   ocl_runtime->LoadSource(program_name, source);
   ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
 #endif
-  int ci = inputs_[1]->shape()[1];
+  auto weight_format = inputs_[1]->GetFormat();
+  if (weight_format != schema::Format_NHWC) {
+    MS_LOG(ERROR) << "weight format(" << weight_format << ") "
+                  << "format not support!";
+    return 1;
+  }
+  int ci = inputs_[1]->shape()[3];
   int co = inputs_[1]->shape()[0];
   sizeCI = {ci, UP_DIV(ci, 4)};
   sizeCO = {co, UP_DIV(co, 4)};
   auto allocator = ocl_runtime->GetAllocator();
   padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(sizeCI.s[1] * sizeCO.s[1] * 16 * sizeof(FLOAT_T)));
   padWeight_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(padWeight_, CL_MAP_WRITE, nullptr, true));
-  if (hasBias_) {
-    bias_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T)));
-    bias_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
-  }
+  bias_ = reinterpret_cast<FLOAT_T *>(allocator->Malloc(sizeCO.s[1] * 4 * sizeof(FLOAT_T)));
+  bias_ = reinterpret_cast<FLOAT_T *>(allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true));
   PadWeight();
   allocator->UnmapBuffer(padWeight_);
-  if (hasBias_) {
-    allocator->UnmapBuffer(bias_);
-  }
+  allocator->UnmapBuffer(bias_);
   outputs_[0]->SetFormat(schema::Format_NHWC4);
   MS_LOG(DEBUG) << kernel_name << " Init Done!";
   return 0;
@@ -98,6 +100,10 @@ void MatMulOpenCLKernel::PadWeight() {
     for (int i = sizeCO.s[0]; i < sizeCO.s[1] * 4; i++) {
       bias_[i] = 0;
     }
+  } else {
+    for (int i = 0; i < sizeCO.s[1] * 4; i++) {
+      bias_[i] = 0;
+    }
   }
 }
@@ -114,18 +120,34 @@ int MatMulOpenCLKernel::Run() {
   std::vector<size_t> local = {64, 4};
   std::vector<size_t> global = {UP_ROUND(sizeCO.s[1], local[0]), 4};
-  ocl_runtime->SetKernelArg(kernel_, 0, inputs_[0]->Data());
-  ocl_runtime->SetKernelArg(kernel_, 1, padWeight_);
-  ocl_runtime->SetKernelArg(kernel_, 2, outputs_[0]->Data());
-  if (hasBias_) {
-    ocl_runtime->SetKernelArg(kernel_, 3, bias_);
-  } else {
-    ocl_runtime->SetKernelArg(kernel_, 3, nullptr);
-  }
+  cl::ImageFormat image_format;
+  {
+    image_format.image_channel_order = CL_RGBA;
+#ifdef ENABLE_FP16
+    image_format.image_channel_data_type = CL_HALF_FLOAT;
+#else
+    image_format.image_channel_data_type = CL_FLOAT;
+#endif
+  }
+  cl_int in_error_code, in_error_code_weight, in_error_code_bias, out_error_code;
+  cl::Image2D img_input(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCI.s[1], 1,
+                        0, inputs_[0]->Data(), &in_error_code);
+  cl::Image2D img_bias(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCO.s[1], 1,
+                       0, bias_, &in_error_code_bias);
+  cl::Image2D img_out(*ocl_runtime->Context(), CL_MEM_WRITE_ONLY, image_format, sizeCO.s[1], 1, 0, nullptr,
+                      &out_error_code);
+  ocl_runtime->SetKernelArg(kernel_, 0, img_input);
+  ocl_runtime->SetKernelArg(kernel_, 1, padWeight_);
+  ocl_runtime->SetKernelArg(kernel_, 2, img_bias);
+  ocl_runtime->SetKernelArg(kernel_, 3, img_out);
   ocl_runtime->SetKernelArg(kernel_, 4, sizeCI);
   ocl_runtime->SetKernelArg(kernel_, 5, sizeCO);
   ocl_runtime->SetKernelArg(kernel_, 6, hasBias_ ? 1 : 0);
   ocl_runtime->RunKernel(kernel_, global, local, nullptr);
+  auto origin = cl::array<cl::size_type, 3U>{0, 0, 0};
+  auto region = cl::array<cl::size_type, 3U>{(size_t)(sizeCO.s[1]), 1, 1};
+  ocl_runtime->GetDefaultCommandQueue()->enqueueReadImage(img_out, CL_TRUE, origin, region, 0, 0, outputs_[0]->Data());
   return 0;
 }
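Note on the Run() hunk above: in_error_code, in_error_code_weight, in_error_code_bias, and out_error_code are passed to the cl::Image2D constructors but never inspected. A hedged sketch of the kind of guard that could follow each construction (this is a fragment of Run(), not standalone code, and the log wording is only a suggestion):

cl_int in_error_code = CL_SUCCESS;
cl::Image2D img_input(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, sizeCI.s[1], 1,
                      0, inputs_[0]->Data(), &in_error_code);
if (in_error_code != CL_SUCCESS) {
  // Surface image-creation failures instead of silently running the kernel on an invalid image.
  MS_LOG(ERROR) << "Create input image failed, error code: " << in_error_code;
  return 1;  // same error convention as the weight-format check in Init()
}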
@@ -151,4 +173,3 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::tensor::Te
 REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_MatMul, OpenCLMatMulKernelCreator)
 REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_FullConnection, OpenCLMatMulKernelCreator)
 }  // namespace mindspore::kernel
@@ -294,6 +294,7 @@ if (SUPPORT_GPU)
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
         )
 endif()
@@ -0,0 +1,110 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <memory>
+#include "common/common_test.h"
+#include "mindspore/lite/src/common/file_utils.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.h"
+#include "mindspore/core/utils/log_adapter.h"
+namespace mindspore {
+class TestConv2dTransposeOpenCL : public mindspore::Common {
+ public:
+  TestConv2dTransposeOpenCL() {}
+};
+TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
+  // setbuf(stdout, NULL);
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->Init();
+  int pad = 0;
+  int n = 1;
+  int h = 240;
+  int w = 240;
+  int kh = 2;
+  int kw = 2;
+  int ci = 128;
+  int co = 128;
+  int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1;
+  int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;
+  size_t input_size;
+  std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  size_t weight_size;
+  std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin";
+  auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
+  size_t bias_size;
+  std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin";
+  auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size));
+  lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci});
+  tensor_x->SetData(input_data);
+  lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci});
+  tensor_w->SetData(weight_data);
+  lite::tensor::Tensor *tensor_bias = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co});
+  tensor_bias->SetData(bias_data);
+  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, oh, ow, co});
+  std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w, tensor_bias};
+  std::vector<lite::tensor::Tensor *> outputs{tensor_out};
+  ConvParameter *opParameter = new ConvParameter();
+  opParameter->kernel_h_ = kh;
+  opParameter->kernel_w_ = kw;
+  opParameter->stride_h_ = 2;
+  opParameter->stride_w_ = 2;
+  opParameter->pad_h_ = pad;
+  opParameter->pad_w_ = pad;
+  opParameter->input_channel_ = ci;
+  opParameter->output_channel_ = co;
+  auto *arith_kernel =
+      new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  arith_kernel->Init();
+  std::vector<kernel::LiteKernel *> kernels{arith_kernel};
+  auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
+  pGraph->Init();
+  pGraph->Run();
+  printf("==================output data=================\n");
+  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
+  std::cout << std::endl;
+  size_t output_size;
+  std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin";
+  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  int size_n = oh * ow * co;
+  size_n = size_n > 100 ? 100 : size_n;
+  for (int i = 0; i < size_n; i++) {
+    std::cout << output_data[i] << ", ";
+    if ((i + 1) % co == 0) {
+      std::cout << std::endl;
+    }
+  }
+  std::cout << std::endl;
+  // compare
+  CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
+  MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
+}
+}  // namespace mindspore
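Note: for the parameters used in this new test (stride 2, pad 0, 2x2 kernel), the oh/ow expression reduces to exactly 2 * h and 2 * w. A standalone check that just evaluates the same formula:

#include <cstdio>

int main() {
  // Same parameters as the Conv2dTransposeFp32 test: 240x240 input, 2x2 kernel, stride 2, pad 0.
  const int h = 240, w = 240, kh = 2, kw = 2, pad = 0;
  const int oh = 2 * h - 1 + 2 * (kh - 1 - pad) - kh + 1;  // 479 + 2 - 2 + 1 = 480
  const int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;  // 480
  printf("oh=%d ow=%d\n", oh, ow);  // prints oh=480 ow=480
  return 0;
}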
@@ -41,38 +41,39 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
   std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin";
   auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
-  lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, ci});
+  lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, ci});
   tensor_x->SetData(input_data);
-  lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, ci});
+  lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, 1, 1, ci});
   tensor_w->SetData(weight_data);
-  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, co});
+  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, co});
   std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
   auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false);
   arith_kernel->Init();
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
   pGraph->Init();
   memcpy(inputs[0]->Data(), input_data, sizeof(float) * ci);
   pGraph->Run();
-  size_t output_size;
-  std::string output_path = "./test_data/matmul/matmul_fp32_output.bin";
-  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
   printf("==================output data=================\n");
   float *output_data = reinterpret_cast<float *>(tensor_out->Data());
   std::cout << std::endl;
-  for (int i = 0; i < co; i++) {
-    std::cout << output_data[i] << ", ";
+  int size_n = co;
+  size_n = size_n > 100 ? 100 : size_n;
+  for (int i = 0; i < size_n; i++) {
+    std::cout << output_data[i] << " ";
   }
   std::cout << std::endl;
+  size_t output_size;
+  std::string output_path = "./test_data/matmul/matmul_fp32_output.bin";
+  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
   // compare
-  CompareOutputData(output_data, correct_data, co * sizeof(float), 0.00001);
+  CompareOutputData(output_data, correct_data, co, 0.00001);
   delete input_data;
   delete weight_data;