|
|
|
@@ -192,41 +192,19 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::Tensor *> |
|
|
|
MS_LOG(ERROR) << "opParameter is nullptr"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto *weight_tensor = inputs.at(kWeightIndex); |
|
|
|
auto *restore_data = weight_tensor->MutableData(); |
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { |
|
|
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); |
|
|
|
if (dequant_weight == nullptr) { |
|
|
|
MS_LOG(ERROR) << "dequant data is nullptr."; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
weight_tensor->SetData(dequant_weight); |
|
|
|
} |
|
|
|
|
|
|
|
auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
if (kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "New kernel fails."; |
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { |
|
|
|
weight_tensor->FreeData(); |
|
|
|
weight_tensor->SetData(restore_data); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = kernel->Init(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " |
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); |
|
|
|
delete kernel; |
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { |
|
|
|
weight_tensor->FreeData(); |
|
|
|
weight_tensor->SetData(restore_data); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { |
|
|
|
weight_tensor->FreeData(); |
|
|
|
weight_tensor->SetData(restore_data); |
|
|
|
} |
|
|
|
return kernel; |
|
|
|
} |
|
|
|
|
|
|
|
|