|
|
|
@@ -67,7 +67,30 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto img_info = GpuTensorInfo(out_tensors_[idx]); |
|
|
|
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; |
|
|
|
size_t img_dtype = CL_FLOAT; |
|
|
|
switch (out_tensors_[idx]->data_type()) { |
|
|
|
case kNumberTypeFloat32: |
|
|
|
case kNumberTypeInt32: |
|
|
|
case kNumberTypeUInt32: { |
|
|
|
img_dtype = CL_FLOAT; |
|
|
|
break; |
|
|
|
} |
|
|
|
case kNumberTypeFloat16: |
|
|
|
case kNumberTypeInt16: |
|
|
|
case kNumberTypeUInt16: { |
|
|
|
img_dtype = CL_HALF_FLOAT; |
|
|
|
break; |
|
|
|
} |
|
|
|
case kNumberTypeInt8: |
|
|
|
case kNumberTypeUInt8: { |
|
|
|
img_dtype = CL_UNSIGNED_INT8; |
|
|
|
break; |
|
|
|
} |
|
|
|
default: { |
|
|
|
MS_LOG(WARNING) << "Unsupported data_type " << out_tensors_[idx]->data_type(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
*img_size = {img_info.width, img_info.height, img_dtype}; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -326,6 +349,7 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) { |
|
|
|
} |
|
|
|
return local_; |
|
|
|
} |
|
|
|
|
|
|
|
int OpenCLKernel::DequantWeight() { |
|
|
|
bool is_fp16 = ocl_runtime_->GetFp16Enable(); |
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex); |
|
|
|
@@ -368,6 +392,7 @@ int OpenCLKernel::DequantWeight() { |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void OpenCLKernel::FreeDequantedWeight() { |
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex); |
|
|
|
if (dequant_flag_) { |
|
|
|
|