From aee2cc6a35ccf99667451afc514e40cb24d961b6 Mon Sep 17 00:00:00 2001 From: Corleone Date: Tue, 15 Sep 2020 11:03:46 +0800 Subject: [PATCH] fixed opencl build error for fp16 --- .../lite/src/runtime/kernel/opencl/cl/arithmetic.cl | 8 ++++---- .../src/runtime/kernel/opencl/kernel/arithmetic.cc | 10 +++++++++- mindspore/lite/src/runtime/opencl/opencl_runtime.cc | 4 ++-- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl index 7f79bcc243..e01d0cdd4d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl @@ -64,7 +64,7 @@ __kernel void ElementAnd_IMG(__read_only image2d_t input_a, __read_only image2d_ FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y)); FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y)); - WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(as_int4(a) & as_int4(b))); + WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(AS_UINT4(a) & AS_UINT4(b))); } __kernel void ElementOr_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output, @@ -77,7 +77,7 @@ __kernel void ElementOr_IMG(__read_only image2d_t input_a, __read_only image2d_t FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y)); FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y)); - WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(as_int4(a) | as_int4(b))); + WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(AS_UINT4(a) | AS_UINT4(b))); } __kernel void ElementMax_IMG(__read_only image2d_t input_a, __read_only image2d_t input_b, @@ -279,7 +279,7 @@ __kernel void BroadcastAnd_IMG(__read_only image2d_t input_a, float b, __write_o } FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y)); - WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(as_int4(a) & (int4)(b))); + WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(AS_UINT4(a) & (UINT4)((FLT)b))); } __kernel void BroadcastOr_IMG(__read_only image2d_t 
input_a, float b, __write_only image2d_t output, @@ -291,7 +291,7 @@ __kernel void BroadcastOr_IMG(__read_only image2d_t input_a, float b, __write_on } FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y)); - WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(as_int4(a) | (int4)b)); + WRITE_IMAGE(output, (int2)(X, Y), AS_FLT4(AS_UINT4(a) | (UINT4)((FLT)b))); } __kernel void BroadcastMax_IMG(__read_only image2d_t input_a, float b, __write_only image2d_t output, diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc index 826d9ddc2c..24ad00b463 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc @@ -307,7 +307,15 @@ int ArithmeticOpenCLKernel::Run() { void *weight = weight_ptr_ == nullptr ? in_tensors_[1]->MutableData() : weight_ptr_; runtime_->SetKernelArg(kernel_, arg_idx++, weight); } else { - float weight = static_cast<float *>(in_tensors_[1]->MutableData())[0]; + float weight = 0.f; + if (in_tensors_[1]->data_type() == kNumberTypeFloat32) { + weight = static_cast<float *>(in_tensors_[1]->MutableData())[0]; + } else if (in_tensors_[1]->data_type() == kNumberTypeFloat16) { + weight = static_cast<float>(static_cast<float16_t *>(in_tensors_[1]->MutableData())[0]); + } else { + MS_LOG(ERROR) << "Unsupport data type " << in_tensors_[1]->data_type(); + return RET_ERROR; + } runtime_->SetKernelArg(kernel_, arg_idx++, weight); } runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->MutableData()); diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc index 2f13dad701..35159a5219 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc @@ -300,12 +300,12 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na if (fp16_enable_) { // fp16 enable, kernel will use half and read_imageh and 
write_imageh. build_options_str = - "-DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 " + "-DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4 " "-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4 "; } else { // fp16 not enable, kernel will use float and read_imagef and write_imagef. build_options_str = - "-DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 " + "-DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 " "-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4 "; }