Merge pull request !6754 from xutianchun/quant_0923 (tags/v1.1.0)
| @@ -32,6 +32,8 @@ table QuantParam { | |||||
| narrowRange: bool = true; | narrowRange: bool = true; | ||||
| numBits: int = 8; | numBits: int = 8; | ||||
| inited: bool = false; | inited: bool = false; | ||||
| var_corr: double = 1; | |||||
| mean_corr: double = 0; | |||||
| } | } | ||||
| table Tensor { | table Tensor { | ||||
| @@ -174,9 +174,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||||
| MS_LOG(ERROR) << "no quant param"; | MS_LOG(ERROR) << "no quant param"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| const auto *quant_data = static_cast<const int8_t *>(input_tensor->MutableData()); | |||||
| auto *dequant_data = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float))); | |||||
| if (dequant_data == nullptr) { | |||||
| const auto *quant_datas = static_cast<const int8_t *>(input_tensor->MutableData()); | |||||
| auto *dequant_datas = static_cast<float *>(malloc(input_tensor->ElementsNum() * sizeof(float))); | |||||
| if (dequant_datas == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc faile"; | MS_LOG(ERROR) << "malloc faile"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -185,7 +185,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||||
| size_t channels = static_cast<size_t>(input_tensor->Batch()); | size_t channels = static_cast<size_t>(input_tensor->Batch()); | ||||
| if (input_tensor->GetQuantParams().size() != channels) { | if (input_tensor->GetQuantParams().size() != channels) { | ||||
| MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; | MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; | ||||
| free(dequant_data); | |||||
| free(dequant_datas); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| size_t per_channel_size = input_tensor->ElementsNum() / channels; | size_t per_channel_size = input_tensor->ElementsNum() / channels; | ||||
| @@ -194,9 +194,15 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||||
| auto param = quant_param.at(i); | auto param = quant_param.at(i); | ||||
| auto scale = param.scale; | auto scale = param.scale; | ||||
| auto zero_point = param.zeroPoint; | auto zero_point = param.zeroPoint; | ||||
| auto var_corr = param.var_corr; | |||||
| auto mean_corr = param.mean_corr; | |||||
| if (var_corr < 0 || var_corr > 10) { | |||||
| MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr; | |||||
| var_corr = 1; | |||||
| } | |||||
| for (size_t j = 0; j < per_channel_size; j++) { | for (size_t j = 0; j < per_channel_size; j++) { | ||||
| dequant_data[per_channel_size * i + j] = | |||||
| static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale); | |||||
| auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; | |||||
| dequant_datas[per_channel_size * i + j] = static_cast<float>(dequant_data * var_corr + mean_corr); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -205,9 +211,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { | |||||
| auto scale = param.scale; | auto scale = param.scale; | ||||
| auto zero_point = param.zeroPoint; | auto zero_point = param.zeroPoint; | ||||
| for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | ||||
| dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale); | |||||
| dequant_datas[j] = static_cast<float>((quant_datas[j] - zero_point) * scale); | |||||
| } | } | ||||
| } | } | ||||
| return dequant_data; | |||||
| return dequant_datas; | |||||
| } | } | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -106,6 +106,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||||
| QuantArg quant_arg{}; | QuantArg quant_arg{}; | ||||
| quant_arg.scale = quant_params->Get(j)->scale(); | quant_arg.scale = quant_params->Get(j)->scale(); | ||||
| quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); | ||||
| quant_arg.var_corr = quant_params->Get(j)->var_corr(); | |||||
| quant_arg.mean_corr = quant_params->Get(j)->mean_corr(); | |||||
| dstTensor->AddQuantParam(quant_arg); | dstTensor->AddQuantParam(quant_arg); | ||||
| } | } | ||||
| } | } | ||||
| @@ -351,7 +353,7 @@ int LiteSession::Init(Context *context) { | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| executor = new(std::nothrow) Executor(); | |||||
| executor = new (std::nothrow) Executor(); | |||||
| if (nullptr == executor) { | if (nullptr == executor) { | ||||
| MS_LOG(ERROR) << "New Executor failed"; | MS_LOG(ERROR) << "New Executor failed"; | ||||
| is_running_.store(false); | is_running_.store(false); | ||||
| @@ -33,6 +33,8 @@ namespace lite { | |||||
| struct QuantArg { | struct QuantArg { | ||||
| double scale; | double scale; | ||||
| int32_t zeroPoint; | int32_t zeroPoint; | ||||
| double var_corr{1}; | |||||
| double mean_corr{0}; | |||||
| }; | }; | ||||
| class Tensor : public mindspore::tensor::MSTensor { | class Tensor : public mindspore::tensor::MSTensor { | ||||
| @@ -143,7 +143,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<T> quant_datas(elem_count); | std::vector<T> quant_datas(elem_count); | ||||
| std::vector<float> dequant_datas(elem_count); | |||||
| if (per_channel) { | if (per_channel) { | ||||
| // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC | // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC | ||||
| // channel at first | // channel at first | ||||
| @@ -173,8 +173,9 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| quant_params.emplace_back(quant_param); | |||||
| // do quantization | // do quantization | ||||
| double average_dequant = 0; | |||||
| double average_raw = 0; | |||||
| for (uint32_t j = 0; j < one_filter_size; j++) { | for (uint32_t j = 0; j < one_filter_size; j++) { | ||||
| auto index = j + i * one_filter_size; | auto index = j + i * one_filter_size; | ||||
| if (index >= elem_count) { | if (index >= elem_count) { | ||||
| @@ -184,7 +185,44 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| float raw_data = raw_datas[index]; | float raw_data = raw_datas[index]; | ||||
| auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | ||||
| quant_datas[index] = quant_data; | quant_datas[index] = quant_data; | ||||
| if (quantType == QuantType_WeightQuant) { | |||||
| float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint); | |||||
| dequant_datas[index] = dequant_data; | |||||
| average_dequant += dequant_data; | |||||
| average_raw += raw_data; | |||||
| } | |||||
| } | } | ||||
| if (quantType == QuantType_WeightQuant) { | |||||
| // mean | |||||
| average_dequant = average_dequant / one_filter_size; | |||||
| average_raw = average_raw / one_filter_size; | |||||
| // std | |||||
| double variance_dequant = 0; | |||||
| double variance_raw = 0; | |||||
| for (uint32_t j = 0; j < one_filter_size; j++) { | |||||
| auto index = j + i * one_filter_size; | |||||
| if (index >= elem_count) { | |||||
| MS_LOG(ERROR) << "over flow!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2); | |||||
| variance_raw += std::pow(raw_datas[index] - average_raw, 2); | |||||
| } | |||||
| variance_dequant = std::sqrt(variance_dequant / one_filter_size); | |||||
| variance_raw = std::sqrt(variance_raw / one_filter_size); | |||||
| quant_param.var_corr = 1; | |||||
| if (variance_raw != 0 && variance_dequant != 0) { | |||||
| auto temp_var_corr = variance_raw / variance_dequant; | |||||
| if (temp_var_corr > 0 && temp_var_corr < 10) { | |||||
| quant_param.var_corr = temp_var_corr; | |||||
| } else { | |||||
| MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; | |||||
| } | |||||
| } | |||||
| quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr; | |||||
| } | |||||
| quant_params.emplace_back(quant_param); | |||||
| } | } | ||||
| auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); | auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||