| @@ -32,10 +32,9 @@ table QuantParam { | |||||
| narrowRange: bool = true; | narrowRange: bool = true; | ||||
| numBits: int = 8; | numBits: int = 8; | ||||
| inited: bool = false; | inited: bool = false; | ||||
| varCorr: double = 1; | |||||
| meanCorr: double = 0; | |||||
| varCorr: float = 1; | |||||
| meanCorr: float = 0; | |||||
| dstDtype: int = 32; | dstDtype: int = 32; | ||||
| clusters: [float]; | |||||
| } | } | ||||
| table Tensor { | table Tensor { | ||||
| @@ -49,6 +48,7 @@ table Tensor { | |||||
| offset: int; | offset: int; | ||||
| data: [ubyte]; | data: [ubyte]; | ||||
| quantParams: [QuantParam]; | quantParams: [QuantParam]; | ||||
| quantClusters: [float]; | |||||
| } | } | ||||
| union PrimitiveType { | union PrimitiveType { | ||||
| @@ -107,15 +107,17 @@ int LiteSession::ConvertTensors(const lite::Model *model) { | |||||
| quant_arg.var_corr = quant_params->Get(j)->varCorr(); | quant_arg.var_corr = quant_params->Get(j)->varCorr(); | ||||
| quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); | ||||
| quant_arg.inited = quant_params->Get(j)->inited(); | quant_arg.inited = quant_params->Get(j)->inited(); | ||||
| auto quant_clusters = quant_params->Get(j)->clusters(); | |||||
| if (quant_clusters != nullptr) { | |||||
| for (size_t k = 0; k < quant_clusters->size(); k++) { | |||||
| quant_arg.clusters.emplace_back(quant_clusters->Get(k)); | |||||
| } | |||||
| } | |||||
| dstTensor->AddQuantParam(quant_arg); | dstTensor->AddQuantParam(quant_arg); | ||||
| } | } | ||||
| } | } | ||||
| auto quant_clusters = srcTensor->quantClusters(); | |||||
| if (quant_clusters != nullptr) { | |||||
| std::vector<float> clusters; | |||||
| for (size_t j = 0; j < quant_clusters->size(); j++) { | |||||
| clusters.push_back(quant_clusters->Get(j)); | |||||
| } | |||||
| dstTensor->SetQuantClusters(clusters); | |||||
| } | |||||
| this->tensors_.emplace_back(dstTensor); | this->tensors_.emplace_back(dstTensor); | ||||
| } | } | ||||
| @@ -79,11 +79,12 @@ class DequantUtil { | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto quant_param = input_tensor->GetQuantParams(); | auto quant_param = input_tensor->GetQuantParams(); | ||||
| auto quant_clusters = input_tensor->GetQuantClusters(); | |||||
| auto param = quant_param.front(); | auto param = quant_param.front(); | ||||
| auto scale = param.scale; | auto scale = param.scale; | ||||
| auto zero_point = param.zeroPoint; | auto zero_point = param.zeroPoint; | ||||
| for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { | ||||
| if (param.clusters.size() != 0) { | |||||
| if (!quant_clusters.empty()) { | |||||
| int8_t index = quant_datas[j]; | int8_t index = quant_datas[j]; | ||||
| if (index > INT8_MAX || index < INT8_MIN) { | if (index > INT8_MAX || index < INT8_MIN) { | ||||
| MS_LOG(ERROR) << "KMeans param quant is error."; | MS_LOG(ERROR) << "KMeans param quant is error."; | ||||
| @@ -367,6 +367,10 @@ void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push | |||||
| std::vector<QuantArg> Tensor::GetQuantParams() const { return this->quant_params_; } | std::vector<QuantArg> Tensor::GetQuantParams() const { return this->quant_params_; } | ||||
| std::vector<float> Tensor::GetQuantClusters() const { return this->quant_clusters_; } | |||||
| void Tensor::SetQuantClusters(const std::vector<float> &clusters) { this->quant_clusters_ = clusters; } | |||||
| std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &src) { | std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &src) { | ||||
| std::vector<tensor::MSTensor *> target(src.size()); | std::vector<tensor::MSTensor *> target(src.size()); | ||||
| std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast<tensor::MSTensor *>(t); }); | std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast<tensor::MSTensor *>(t); }); | ||||
| @@ -33,8 +33,8 @@ namespace lite { | |||||
// Per-tensor quantization arguments.
//
// Every member has a default member initializer so that a default-constructed
// `QuantArg q;` is fully defined. Previously `scale`, `zeroPoint` and `inited`
// were left indeterminate unless the caller assigned each of them, while the
// remaining members were already initialized. The defaults chosen here are the
// same values that `QuantArg{}` already produced, so `QuantArg q;` and
// `QuantArg q{};` now agree.
struct QuantArg {
  double scale{0.0};      // read together with zeroPoint when dequantizing
  int32_t zeroPoint{0};
  float var_corr{1};      // NOTE(review): presumably a variance correction factor — confirm
  float mean_corr{0};     // NOTE(review): presumably a mean correction offset — confirm
  bool inited{false};     // same default as the schema's `inited` field
  std::vector<float> clusters{};
};
| @@ -119,6 +119,10 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| std::vector<QuantArg> GetQuantParams() const; | std::vector<QuantArg> GetQuantParams() const; | ||||
| std::vector<float> GetQuantClusters() const; | |||||
| void SetQuantClusters(const std::vector<float> &clusters); | |||||
| bool IsConst(); | bool IsConst(); | ||||
| bool IsScalar(); | bool IsScalar(); | ||||
| @@ -138,6 +142,7 @@ class Tensor : public mindspore::tensor::MSTensor { | |||||
| Category category_; | Category category_; | ||||
| size_t ref_count_ = 0; | size_t ref_count_ = 0; | ||||
| std::vector<QuantArg> quant_params_; | std::vector<QuantArg> quant_params_; | ||||
| std::vector<float> quant_clusters_; | |||||
| mindspore::lite::Allocator *allocator_ = nullptr; | mindspore::lite::Allocator *allocator_ = nullptr; | ||||
| }; | }; | ||||
| @@ -449,7 +449,6 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc | |||||
| error = error_cur; | error = error_cur; | ||||
| } | } | ||||
| // update data | // update data | ||||
| quantParam->clusters = clusters; | |||||
| return clusters_index; | return clusters_index; | ||||
| } | } | ||||
| @@ -130,7 +130,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, | STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, | ||||
| int quant_max, int quant_min, size_t bitNum, bool per_channel) { | |||||
| int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) { | |||||
| auto dims = weight->tensor_shape(); | auto dims = weight->tensor_shape(); | ||||
| auto op_type = (schema::PrimitiveType)primitive_c->Type(); | auto op_type = (schema::PrimitiveType)primitive_c->Type(); | ||||
| if (per_channel) { | if (per_channel) { | ||||
| @@ -208,7 +208,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| average_raw += raw_data; | average_raw += raw_data; | ||||
| } | } | ||||
| } | } | ||||
| if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) { | |||||
| if (quantType == QuantType_WeightQuant && !k_means) { | |||||
| // mean | // mean | ||||
| average_dequant = average_dequant / one_filter_size; | average_dequant = average_dequant / one_filter_size; | ||||
| average_raw = average_raw / one_filter_size; | average_raw = average_raw / one_filter_size; | ||||
| @@ -256,7 +256,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| } | } | ||||
| schema::QuantParamT quant_param; | schema::QuantParamT quant_param; | ||||
| if (quant_param.clusters.size() == 0) { | |||||
| if (!k_means) { | |||||
| STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); | STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | MS_LOG(ERROR) << "CalQuantizationParams failed" << status; | ||||
| @@ -267,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti | |||||
| // update data and datatype | // update data and datatype | ||||
| for (uint32_t i = 0; i < elem_count; i++) { | for (uint32_t i = 0; i < elem_count; i++) { | ||||
| float raw_data = raw_datas[i]; | float raw_data = raw_datas[i]; | ||||
| if (quant_param.clusters.size() == 0) { | |||||
| if (!k_means) { | |||||
| auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); | ||||
| quant_datas[i] = quant_data; | quant_datas[i] = quant_data; | ||||
| } | } | ||||