Browse Source

fix bug: image2d oversize in adreno

tags/v1.2.0-rc1
chenzupeng 5 years ago
parent
commit
e1a3d737ac
8 changed files with 126 additions and 47 deletions
  1. +3
    -0
      mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc
  2. +4
    -2
      mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h
  3. +28
    -29
      mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl
  4. +18
    -4
      mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl
  5. +43
    -4
      mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl
  6. +4
    -4
      mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
  7. +9
    -0
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  8. +17
    -4
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h

+ 3
- 0
mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.cc View File

@@ -165,6 +165,8 @@ int OpenCLRuntime::InitGPUDevice(std::vector<cl::Platform> *platforms) {
}
global_memery_size_ = device_->getInfo<CL_DEVICE_GLOBAL_MEM_SIZE>();
max_alloc_size_ = device_->getInfo<CL_DEVICE_MAX_MEM_ALLOC_SIZE>();
max_image2d_width_ = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_WIDTH>();
max_image2d_height_ = device_->getInfo<CL_DEVICE_IMAGE2D_MAX_HEIGHT>();
MS_LOG(INFO) << "Address space bits: " << device_->getInfo<CL_DEVICE_ADDRESS_BITS>();
MS_LOG(INFO) << "Global Mem Size: " << global_memery_size_;
MS_LOG(INFO) << "Global Mem Cache Size: " << global_memery_cachesize_;
@@ -377,6 +379,7 @@ int OpenCLRuntime::BuildKernel(const cl::Kernel &kernel, const std::string &prog
" -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4"
" -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4";
}
build_option += " -DMAX_IMAGE2D_WIDTH=" + std::to_string(max_image2d_width_);
build_option =
std::accumulate(build_options_ext.begin(), build_options_ext.end(), build_option,
[](const std::string &options, const std::string &option) { return options + " " + option; });


+ 4
- 2
mindspore/lite/src/runtime/gpu/opencl/opencl_runtime.h View File

@@ -16,8 +16,6 @@ j* you may not use this file except in compliance with the License.

#ifndef MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_
#define MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_
// Get from Device?
#define MAX_IMAGE2D_SIZE 65535
#include <vector>
#include <map>
#include <memory>
@@ -66,6 +64,8 @@ class OpenCLRuntime {
uint32_t GetSubGroupSize(const cl::Kernel &kernel, const cl::NDRange &range = cl::NullRange);
uint64_t GetGlobalMemSize() { return global_memery_size_; }
uint64_t GetMaxAllocSize() { return max_alloc_size_; }
uint64_t GetMaxImage2DWidth() { return max_image2d_width_; }
uint64_t GetMaxImage2DHeight() { return max_image2d_height_; }
GpuInfo GetGpuInfo();
bool GetFp16Enable() const;
bool SetFp16Enable(bool enable);
@@ -177,6 +177,8 @@ class OpenCLRuntime {
uint64_t global_memery_cachesize_{0};
uint64_t global_memery_size_{0};
uint64_t max_alloc_size_{0};
uint64_t max_image2d_width_{0};
uint64_t max_image2d_height_{0};
int max_work_group_size_{1};
uint32_t compute_units_{0};
uint32_t max_freq_{0};


+ 28
- 29
mindspore/lite/src/runtime/kernel/opencl/cl/conv2d.cl View File

@@ -3,7 +3,6 @@
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

#define CI_TILE 4
#define MAX_IMAGE2D_SIZE 65535
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

#define DEFINE_ARGS \
@@ -92,10 +91,10 @@ __kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t
out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));
}

if (OW * CO_SLICES <= MAX_IMAGE2D_SIZE) {
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
} else {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
}
}

@@ -173,15 +172,15 @@ __kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));
}

if (OW * CO_SLICES <= MAX_IMAGE2D_SIZE) {
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
} // end if (oh1 < OH)
} else {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
}
}
@@ -284,7 +283,7 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));
}

if (OW * CO_SLICES <= MAX_IMAGE2D_SIZE) {
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);
@@ -296,14 +295,14 @@ __kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice1, ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice1, ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
@@ -457,7 +456,7 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
}

if (OW * CO_SLICES <= MAX_IMAGE2D_SIZE) {
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
@@ -473,18 +472,18 @@ __kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow1), out_h0_w1_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow1), out_h1_w1_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice1, ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice1, ow1), out_h0_w1_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice1, ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice1, ow1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}
@@ -645,7 +644,7 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2
out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));
}

if (OW * CO_SLICES <= MAX_IMAGE2D_SIZE) {
if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {
WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);
if (oh1 < OH) {
@@ -661,18 +660,18 @@ __kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
} else {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice0, ow1), out_h0_w1_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice0, ow1), out_h1_w1_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);
WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);
} // end (oh1 < OH)
if (co_slice1 < CO_SLICES) {
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice1, ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(n_oh0 * CO_SLICES + co_slice1, ow1), out_h0_w1_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);
if (oh1 < OH) {
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice1, ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(n_oh1 * CO_SLICES + co_slice1, ow1), out_h1_w1_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);
WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);
} // end if (oh1 < OH)
} // end if (co_slice1 < CO_SLICES)
}


+ 18
- 4
mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl View File

@@ -25,7 +25,10 @@ __kernel void to_format_NHWC_to_NHWC4_IMG_float(__global float4 *src_data, __wri
data.z = (FLT)src_addr[2];
}
}
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
if (size.y * size.z <= MAX_IMAGE2D_WIDTH)
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
else
WRITE_IMAGE(dst_data, (int2)(Z, X * size.y + Y), data);
}
__kernel void to_format_NHWC_to_NHWC4_IMG_half(__global half4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
@@ -52,7 +55,10 @@ __kernel void to_format_NHWC_to_NHWC4_IMG_half(__global half4 *src_data, __write
data.z = (FLT)src_addr[2];
}
}
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
if (size.y * size.z <= MAX_IMAGE2D_WIDTH)
WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), data);
else
WRITE_IMAGE(dst_data, (int2)(Z, X * size.y + Y), data);
}
__kernel void to_format_NCHW_to_NHWC4_IMG_float(__global float4 *src_data, __write_only image2d_t dst_data, int4 size,
int4 shape) {
@@ -236,7 +242,11 @@ __kernel void to_format_NHWC4_to_NHWC_BUF_float(__read_only image2d_t src_data,
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
float4 data = convert_float4(READ_IMAGEIN(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
float4 data;
if (size.y * size.z <= MAX_IMAGE2D_WIDTH)
data = convert_float4(READ_IMAGEIN(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
else
data = convert_float4(READ_IMAGEIN(src_data, smp_zero, (int2)(Z, X * size.y + Y)));
int offset = (X * shape.z + Y) * shape.w + Z * 4;
__global float *dst_addr = (__global float *)dst_data;
dst_addr += offset;
@@ -320,7 +330,11 @@ __kernel void to_format_NHWC4_to_NHWC_BUF_half(__read_only image2d_t src_data, _
if (X >= size.x || Y >= size.y || Z >= size.z) {
return;
}
half4 data = convert_half4(READ_IMAGEIN(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
half4 data;
if (size.y * size.z <= MAX_IMAGE2D_WIDTH)
data = convert_half4(READ_IMAGEIN(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
else
data = convert_half4(READ_IMAGEIN(src_data, smp_zero, (int2)(Z, X * size.y + Y)));
int offset = (X * shape.z + Y) * shape.w + Z * 4;
__global half *dst_addr = (__global half *)dst_data;
dst_addr += offset;


+ 43
- 4
mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl View File

@@ -87,18 +87,18 @@ __kernel void transpose_0312_oversize_NHWC4(__read_only image2d_t src_data, __wr
}
int H4 = UP_DIV(shape.y, 4);
int C4 = UP_DIV(shape.w, 4);
FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(Y * H4 + X, 4 * Z));
FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z));
FLT4 src1 = (FLT4)0.f;
if (4 * Z + 1 < shape.w) {
src1 = READ_IMAGE(src_data, smp_zero, (int2)(Y * H4 + X, 4 * Z + 1));
src1 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 1));
}
FLT4 src2 = (FLT4)0.f;
if (4 * Z + 2 < shape.w) {
src2 = READ_IMAGE(src_data, smp_zero, (int2)(Y * H4 + X, 4 * Z + 2));
src2 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 2));
}
FLT4 src3 = (FLT4)0.f;
if (4 * Z + 3 < shape.w) {
src3 = READ_IMAGE(src_data, smp_zero, (int2)(Y * H4 + X, 4 * Z + 3));
src3 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 3));
}
FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);
FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);
@@ -154,6 +154,45 @@ __kernel void transpose_0231_NHWC4(__read_only image2d_t src_data, __write_only
}
}

__kernel void transpose_0231_oversize_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,
int4 shape) {
int X = get_global_id(0); // H, W for src
int Y = get_global_id(1); // W4, C4 for src
int Z = get_global_id(2); // C4, H4 for src
if (X >= shape.y || 4 * Y >= shape.z || 4 * Z >= shape.w) {
return;
}
int W4 = UP_DIV(shape.z, 4);
int C4 = UP_DIV(shape.w, 4);
FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * Z * shape.y + X));
FLT4 src1 = (FLT4)0.f;
if (4 * Z + 1 < shape.w) {
src1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 1) * shape.y + X));
}
FLT4 src2 = (FLT4)0.f;
if (4 * Z + 2 < shape.w) {
src2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 2) * shape.y + X));
}
FLT4 src3 = (FLT4)0.f;
if (4 * Z + 3 < shape.w) {
src3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 3) * shape.y + X));
}
FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);
FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);
FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);
FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);
WRITE_IMAGE(dst_data, (int2)(4 * Y * C4 + Z, X), dst0);
if (4 * Y + 1 < shape.z) {
WRITE_IMAGE(dst_data, (int2)((4 * Y + 1) * C4 + Z, X), dst1);
}
if (4 * Y + 2 < shape.z) {
WRITE_IMAGE(dst_data, (int2)((4 * Y + 2) * C4 + Z, X), dst2);
}
if (4 * Y + 3 < shape.z) {
WRITE_IMAGE(dst_data, (int2)((4 * Y + 3) * C4 + Z, X), dst3);
}
}

__kernel void transpose_0231_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 shape) {
int X = get_global_id(0); // H, W for src
int Y = get_global_id(1); // W4, C4 for src


+ 4
- 4
mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc View File

@@ -91,7 +91,7 @@ int TransposeOpenCLKernel::Prepare() {
kernel_name += "_general";
}
if (in_tensors_[0]->shape().size() == 4 &&
in_tensors_[0]->shape()[2] * UP_DIV(in_tensors_[0]->shape()[3], C4NUM) > MAX_IMAGE2D_SIZE) {
in_tensors_[0]->shape()[2] * UP_DIV(in_tensors_[0]->shape()[3], C4NUM) > ocl_runtime_->GetMaxImage2DWidth()) {
// just for input
kernel_name += "_oversize";
}
@@ -127,9 +127,9 @@ void TransposeOpenCLKernel::SetConstArgs() {
cl_int4 de_perm_cl = {de_perm[0], de_perm[1], de_perm[2], de_perm[3]};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, de_perm_cl);
GpuTensorInfo in_shape = GpuTensorInfo(in_tensors_[0]);
cl_int4 out_shape = {static_cast<cl_int>(in_shape.N), static_cast<cl_int>(in_shape.H),
static_cast<cl_int>(in_shape.W), static_cast<cl_int>(in_shape.C)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_shape);
cl_int4 in_shape_int4 = {static_cast<cl_int>(in_shape.N), static_cast<cl_int>(in_shape.H),
static_cast<cl_int>(in_shape.W), static_cast<cl_int>(in_shape.C)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_shape_int4);
}
}



+ 9
- 0
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -400,4 +400,13 @@ void OpenCLKernel::FreeDequantedWeight() {
weight_tensor->set_data(restore_quant_data_);
}
}

int OpenCLKernel::CheckSpecs() {
if (out_mem_type_ == lite::opencl::MemType::IMG) {
if (!GpuTensorInfo(out_tensors_[0]).IsImageSizeValid()) {
return RET_ERROR;
}
}
return RET_OK;
}
} // namespace mindspore::kernel

+ 17
- 4
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h View File

@@ -80,6 +80,7 @@ void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_va
struct GpuTensorInfo {
GpuTensorInfo() = default;
explicit GpuTensorInfo(const lite::Tensor *tensor) {
auto ocl_runtime_wrap_ = lite::opencl::OpenCLRuntimeWrapper();
if (tensor == nullptr) {
return;
}
@@ -95,12 +96,16 @@ struct GpuTensorInfo {

FLT_size = tensor->data_type() == kNumberTypeFloat16 ? sizeof(cl_half) : sizeof(cl_float);
FLT4_size = FLT_size * 4;
if (W * Slice <= MAX_IMAGE2D_SIZE) {
if (W * Slice <= ocl_runtime_wrap_.GetInstance()->GetMaxImage2DWidth()) {
height = N * H;
width = W * Slice;
} else {
height = W;
width = N * H * Slice;
height = N * H * W;
width = Slice;
if (height > ocl_runtime_wrap_.GetInstance()->GetMaxImage2DHeight()) {
height = -1;
width = -1;
}
}

ElementsNum = N * H * W * C;
@@ -128,6 +133,8 @@ struct GpuTensorInfo {
return static_cast<int>(no_neg_axis + 4 - NDim);
}

bool IsImageSizeValid() { return width > 0 && height > 0; }

size_t N{1};
size_t H{1};
size_t W{1};
@@ -183,7 +190,7 @@ class OpenCLKernel : public LiteKernel {
int ReSize() override;
int Run() override { return RET_ERROR; }

virtual int CheckSpecs() { return RET_ERROR; }
virtual int CheckSpecs();
virtual int InitWeights() { return RET_OK; }
virtual void SetConstArgs() {}
virtual void SetGlobalLocal() {}
@@ -257,6 +264,12 @@ kernel::LiteKernel *OpenCLKernelCreator(const std::vector<lite::Tensor *> &input
delete kernel;
return nullptr;
}
ret = kernel->OpenCLKernel::CheckSpecs();
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!";
delete kernel;
return nullptr;
}
return kernel;
}
} // namespace mindspore::kernel


Loading…
Cancel
Save