|
|
|
@@ -351,57 +351,6 @@ std::set<size_t> OpenCLKernel::GenerateLocalByGlobal(size_t global_i) { |
|
|
|
return local_; |
|
|
|
} |
|
|
|
|
|
|
|
int OpenCLKernel::DequantWeight() { |
|
|
|
bool is_fp16 = ocl_runtime_->GetFp16Enable(); |
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex); |
|
|
|
restore_quant_data_ = weight_tensor->data_c(); |
|
|
|
dequant_flag_ = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && |
|
|
|
restore_quant_data_ != nullptr; |
|
|
|
if (dequant_flag_) { |
|
|
|
void *dequant_weight{nullptr}; |
|
|
|
bool set_flag{true}; |
|
|
|
if (is_fp16) { |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { |
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor); |
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat16); |
|
|
|
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { |
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor); |
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat16); |
|
|
|
} else { |
|
|
|
set_flag = false; |
|
|
|
} |
|
|
|
#else |
|
|
|
set_flag = false; |
|
|
|
#endif |
|
|
|
} else { |
|
|
|
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) { |
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor); |
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat32); |
|
|
|
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) { |
|
|
|
dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor); |
|
|
|
weight_tensor->set_data_type(kNumberTypeFloat32); |
|
|
|
} else { |
|
|
|
set_flag = false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (set_flag && dequant_weight == nullptr) { |
|
|
|
MS_LOG(ERROR) << "dequant data failed."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
weight_tensor->set_data(dequant_weight); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void OpenCLKernel::FreeDequantedWeight() { |
|
|
|
auto *weight_tensor = in_tensors_.at(kWeightIndex); |
|
|
|
if (dequant_flag_) { |
|
|
|
free(weight_tensor->data_c()); |
|
|
|
weight_tensor->set_data(restore_quant_data_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
int OpenCLKernel::CheckSpecs() { |
|
|
|
if (out_mem_type_ == lite::opencl::MemType::IMG) { |
|
|
|
if (!GpuTensorInfo(out_tensors_[0]).IsImageSizeValid()) { |
|
|
|
|