Browse Source

!11479 【MS】【LITE】【GPU】opencl support int dtype

From: @wangdongxu6
Reviewed-by: @HilbertDavid,@ddwsky
Signed-off-by: @ddwsky
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
1929b420b2
4 changed files with 43 additions and 5 deletions
  1. +26
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  2. +4
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
  3. +12
    -2
      mindspore/lite/src/runtime/opencl/opencl_allocator.cc
  4. +1
    -1
      mindspore/lite/src/runtime/opencl/opencl_runtime.cc

+ 26
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -67,7 +67,30 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
return RET_ERROR;
}
auto img_info = GpuTensorInfo(out_tensors_[idx]);
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
size_t img_dtype = CL_FLOAT;
switch (out_tensors_[idx]->data_type()) {
case kNumberTypeFloat32:
case kNumberTypeInt32:
case kNumberTypeUInt32: {
img_dtype = CL_FLOAT;
break;
}
case kNumberTypeFloat16:
case kNumberTypeInt16:
case kNumberTypeUInt16: {
img_dtype = CL_HALF_FLOAT;
break;
}
case kNumberTypeInt8:
case kNumberTypeUInt8: {
img_dtype = CL_UNSIGNED_INT8;
break;
}
default: {
MS_LOG(WARNING) << "Unsupported data_type " << out_tensors_[idx]->data_type();
return RET_ERROR;
}
}
*img_size = {img_info.width, img_info.height, img_dtype};
return RET_OK;
}
@@ -326,6 +349,7 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) {
}
return local_;
}

int OpenCLKernel::DequantWeight() {
bool is_fp16 = ocl_runtime_->GetFp16Enable();
auto *weight_tensor = in_tensors_.at(kWeightIndex);
@@ -368,6 +392,7 @@ int OpenCLKernel::DequantWeight() {
}
return RET_OK;
}

void OpenCLKernel::FreeDequantedWeight() {
auto *weight_tensor = in_tensors_.at(kWeightIndex);
if (dequant_flag_) {


+ 4
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc View File

@@ -254,7 +254,10 @@ int OpenCLSubGraph::UpdateTensorDataTypePass() {
for (auto jv : cur_outs) {
if (out_set.count(jv) == 0) {
MS_ASSERT(jv);
jv->set_data_type(kNumberTypeFloat16);
// if Fp16Enable, only change fp32 to fp16, other dtype is reserved
if (jv->data_type() == kNumberTypeFloat32) {
jv->set_data_type(kNumberTypeFloat16);
}
}
}
}


+ 12
- 2
mindspore/lite/src/runtime/opencl/opencl_allocator.cc View File

@@ -140,9 +140,19 @@ void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
MS_ASSERT(img_size.size() == 0 || img_size.size() == 3);
if (mem_type == MemType::IMG) {
size_t dtype_size = img_size.dtype == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4);
size_t dtype_size = 0;
if (img_size.dtype == CL_FLOAT) {
dtype_size = sizeof(cl_float);
} else if (img_size.dtype == CL_HALF_FLOAT) {
dtype_size = sizeof(cl_half);
} else if (img_size.dtype == CL_UNSIGNED_INT8) {
dtype_size = sizeof(cl_uchar);
} else {
MS_LOG(ERROR) << "Unsupported dtype " << img_size.dtype;
return nullptr;
}
uint32_t image_alignment = ocl_runtime_->GetImagePitchAlignment();
size = UP_ROUND(img_size.width, image_alignment) * img_size.height * dtype_size;
size = UP_ROUND(img_size.width, image_alignment) * img_size.height * C4NUM * dtype_size;
}
if (size > ocl_runtime_->GetMaxAllocSize()) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;


+ 1
- 1
mindspore/lite/src/runtime/opencl/opencl_runtime.cc View File

@@ -360,7 +360,7 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na
"-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4";
} else {
build_option +=
" -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 "
" -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 "
"-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4";
}
build_option =


Loading…
Cancel
Save