| @@ -214,6 +214,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||
| dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale); | |||
| } | |||
| } | |||
| return dequant_datas; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -151,6 +151,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| @@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -177,12 +179,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -196,6 +196,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| @@ -232,6 +233,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -243,12 +245,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> & | |||
| << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -210,6 +210,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| @@ -218,6 +219,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -229,12 +231,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -217,6 +217,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| @@ -225,6 +226,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -236,12 +238,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -198,6 +198,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| @@ -205,6 +206,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -216,12 +218,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty()) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||
| @@ -261,6 +261,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "dequant data is nullptr."; | |||
| return nullptr; | |||
| } | |||
| weight_tensor->set_data_type(kNumberTypeFloat32); | |||
| weight_tensor->SetData(dequant_weight); | |||
| } | |||
| auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| @@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| MS_LOG(ERROR) << "kernel is nullptr."; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| @@ -279,12 +281,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *> | |||
| delete kernel; | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||
| weight_tensor->FreeData(); | |||
| weight_tensor->set_data_type(kNumberTypeInt8); | |||
| weight_tensor->SetData(restore_data); | |||
| } | |||
| return kernel; | |||