Browse Source

float16 quant fix

tags/v1.1.0
kai00 5 years ago
parent
commit
797221b144
7 changed files with 25 additions and 0 deletions
  1. +1
    -0
      mindspore/lite/src/lite_kernel.cc
  2. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc
  3. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc
  4. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc
  5. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc
  6. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
  7. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc

+ 1
- 0
mindspore/lite/src/lite_kernel.cc View File

@@ -214,6 +214,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) {
dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale);
}
}

return dequant_datas;
}
} // namespace mindspore::kernel

+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.cc View File

@@ -151,6 +151,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}

@@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -177,12 +179,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc View File

@@ -196,6 +196,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}

@@ -232,6 +233,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -243,12 +245,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_depthwise_fp16.cc View File

@@ -210,6 +210,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}

@@ -218,6 +219,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -229,12 +231,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/deconvolution_fp16.cc View File

@@ -217,6 +217,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}

@@ -225,6 +226,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -236,12 +238,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc View File

@@ -198,6 +198,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -205,6 +206,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -216,12 +218,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
delete kernel;
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty()) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp16/matmul_fp16.cc View File

@@ -261,6 +261,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->SetData(dequant_weight);
}
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "kernel is nullptr.";
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
@@ -279,12 +281,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
delete kernel;
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return nullptr;
}
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data);
}
return kernel;


Loading…
Cancel
Save