diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 557d4c1012..5acf6f9f28 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -289,24 +289,22 @@ std::unordered_map> *Calibrator::GetInp return &this->input_diverg_info_; } -std::unordered_map> *Calibrator::GetOutputDivergInfo() { - return &this->output_diverg_info_; +std::unordered_map>> *Calibrator::GetOutputDivergInfo() { + return &this->outputs_diverg_info_; } -STATUS Calibrator::RecordMaxValue(const std::string &op_name, const vector &data, - std::unordered_map> *diverg_info) { - auto got = (*diverg_info).find(op_name); - if (got != (*diverg_info).end()) { - ((*got).second)->RecordMaxValue(data); - ((*got).second)->RecordMaxValueArray(data); - } +STATUS Calibrator::RecordMaxValue(const vector &data, const std::unique_ptr &diverg_info) { + diverg_info->RecordMaxValue(data); + diverg_info->RecordMaxValueArray(data); return RET_OK; } STATUS Calibrator::ComputeThreshold() { - for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) { - DivergInfo *info = iter->second.get(); - info->ComputeThreshold(); + for (auto &kv : this->outputs_diverg_info_) { + auto &outputs_diverg_info = kv.second; + for (auto &diverg_info : outputs_diverg_info) { + diverg_info->ComputeThreshold(); + } } // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { @@ -317,14 +315,21 @@ STATUS Calibrator::ComputeThreshold() { auto input = cnode->input(1); if (input->isa()) { auto input_cnode = std::dynamic_pointer_cast(input); - for (const auto &output_diverg_info : output_diverg_info_) { - auto output_diverg_cnode = output_diverg_info.second->cnode; - if (output_diverg_cnode == input_cnode) { - *info = *(output_diverg_info.second); - info->cnode = cnode; - already_computed = true; + for (const auto &outputs_diverg_info : outputs_diverg_info_) { + if (already_computed) { break; } + for (const auto &output_diverg_info : outputs_diverg_info.second) { + auto output_diverg_cnode = output_diverg_info->cnode; + if (output_diverg_cnode == input_cnode) { + if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) { + *info = *output_diverg_info; + info->cnode = cnode; + already_computed = true; + break; + } + } + } } } if (!already_computed) { @@ -334,6 +339,16 @@ STATUS Calibrator::ComputeThreshold() { return RET_OK; } +STATUS Calibrator::UpdateOutputDivergInverval( + std::unordered_map>> *diverg_info) { + for (auto &kv : *diverg_info) { + for (auto &info : kv.second) { + info->UpdateInterval(); + } + } + return RET_OK; +} + STATUS Calibrator::UpdateDivergInverval(std::unordered_map> *diverg_info) { for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) { DivergInfo *info = iter->second.get(); @@ -342,12 +357,8 @@ STATUS Calibrator::UpdateDivergInverval(std::unordered_map &data, - std::unordered_map> *diverg_info) { - auto got = (*diverg_info).find(op_name); - if (got != (*diverg_info).end()) { - ((*got).second)->UpdateHistogram(data); - } +STATUS Calibrator::UpdateDataFrequency(const vector &data, const std::unique_ptr &diverg_info) { + diverg_info->UpdateHistogram(data); return RET_OK; } @@ -362,8 +373,8 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { std::unique_ptr output_diverg = std::unique_ptr( new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x)); - input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); - output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); + input_diverg_info_.insert(std::make_pair(node_name, std::move(input_diverg))); + outputs_diverg_info_[node_name].push_back(std::move(output_diverg)); return RET_OK; } @@ -377,7 +388,6 @@ void Calibrator::AddImage(const string &file, size_t index) { return stat(file.c_str(), &buf) == 0; }; if (exist(file)) { - MS_LOG(INFO) << "load image: " << file; this->images_[index].push_back(file); } else { MS_LOG(WARNING) << "invalid image file path: " << file; @@ -411,6 +421,7 @@ STATUS Calibrator::GenerateInputData(int input_index, int image_index, mindspore STATUS Calibrator::CollectImages() { this->images_.resize(config_param_.image_paths.size()); auto input_i = 0; + bool multi_input = config_param_.image_paths.size() > 1; for (const auto &image_path : config_param_.image_paths) { DIR *root = opendir(image_path.c_str()); if (root == nullptr) { @@ -420,13 +431,14 @@ STATUS Calibrator::CollectImages() { struct dirent *image_dir = readdir(root); size_t count = 0; while (image_dir != nullptr) { - if (image_dir->d_name[0] != '.') { - const std::string file_name = image_path + "/" + image_dir->d_name; - if (config_param_.batch_count == 0) { - this->AddImage(file_name, input_i); + string file_name(image_dir->d_name); + if (file_name != "." && file_name != "..") { + const std::string file_path = image_path + "/" + file_name; + if (multi_input || config_param_.batch_count == 0) { + this->AddImage(file_path, input_i); count++; } else if (count < config_param_.batch_count) { - this->AddImage(file_name, input_i); + this->AddImage(file_path, input_i); count++; } else { break; @@ -434,6 +446,10 @@ STATUS Calibrator::CollectImages() { } image_dir = readdir(root); } + std::sort(images_[input_i].begin(), images_[input_i].end()); + if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) { + images_[input_i].resize(config_param_.batch_count); + } closedir(root); input_i++; } @@ -487,6 +503,7 @@ STATUS Calibrator::ReadConfig() { config_param_.image_paths.push_back(image_path); raw_image_paths = raw_image_paths.substr(ind + 1); Trim(&raw_image_paths); + ind = raw_image_paths.find(','); } config_param_.image_paths.push_back(raw_image_paths); } else if (key == "batch_count") { @@ -550,7 +567,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min, std::shared_ptr lite_primitive) { if (!lite_primitive->GetInputQuantParams().empty()) { - return RET_OK; + MS_LOG(DEBUG) << "input quant params not empty"; // multi-input op: like concat } schema::QuantParamT quant_param; quant_param.scale = scale; @@ -567,7 +584,7 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min, std::shared_ptr lite_primitive) { if (!lite_primitive->GetOutputQuantParams().empty()) { - return RET_OK; + MS_LOG(DEBUG) << "output quant params not empty"; // multi-output op: like split } schema::QuantParamT quant_param; quant_param.scale = scale; @@ -737,9 +754,7 @@ STATUS PostTrainingQuantizer::QuantNode() { auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo()); auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo()); - auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo()); - auto output_scale = this->calibrator_->GetScale(this->calibrator_->GetOutputDivergInfo()); - auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo()); + auto outputs_diverg_info = calibrator_->GetOutputDivergInfo(); auto cnodes = funcGraph->GetOrderedCnodes(); for (auto &cnode : cnodes) { @@ -760,8 +775,35 @@ STATUS PostTrainingQuantizer::QuantNode() { auto op_type = (schema::PrimitiveType)primitive_c->Type(); MS_LOG(DEBUG) << "OpName: " << op_name; - if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && - op_type != PrimitiveType_FullConnection) { + if (op_type == PrimitiveType_TupleGetItem) { + auto index_node = cnode->input(2); + auto index_value_node = std::dynamic_pointer_cast(index_node); + if (index_value_node == nullptr) { + MS_LOG(WARNING) << "index value node is null"; + continue; + } + size_t index = GetValue(index_value_node->value()); + + auto input_node = cnode->input(1); + auto input_cnode = std::dynamic_pointer_cast(input_node); + auto input_cnode_primitive_c = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitive_c == nullptr) { + MS_LOG(WARNING) << "input_cnode_primitive_c is null"; + continue; + } + if (input_cnode_primitive_c->GetOutputQuantParams().size() > index) { + auto quant_param = input_cnode_primitive_c->GetOutputQuantParams()[index]; + primitive_c->AddInputQuantParam(quant_param); + primitive_c->AddOutputQuantParam(quant_param); + } else { + MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope() + << "'s output quant_params size: " << input_cnode_primitive_c->GetOutputQuantParams().size() + << ", but index: " << index; + } + primitive_c->SetQuantType(schema::QuantType_PostTraining); + continue; + } else if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && + op_type != PrimitiveType_FullConnection) { for (size_t i = 1; i < cnode->inputs().size(); i++) { auto input_node = cnode->input(i); if (!input_node->isa()) { @@ -821,18 +863,25 @@ STATUS PostTrainingQuantizer::QuantNode() { DoBiasQuant(bias, primitive_c); } } - // do output quant - double OutputScale = output_scale[cnode]; - int32_t OutputZeropoint = output_zeropoint[cnode]; - DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c); - primitive_c->SetQuantType(schema::QuantType_PostTraining); + // do output quant, there may multi-output + auto &infos = (*outputs_diverg_info)[op_name]; + for (auto &info : infos) { + auto output_scale = info->GetScale().second; + auto output_zp = info->GetZeropoint().second; + struct MaxMin output_min_max {}; + output_min_max.max = info->max; + output_min_max.min = info->min; + + DoQuantOutput(output_scale, output_zp, &output_min_max, primitive_c); + primitive_c->SetQuantType(schema::QuantType_PostTraining); + } } return RET_OK; } STATUS PostTrainingQuantizer::UpdateDivergInverval() { this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo()); - this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo()); + this->calibrator_->UpdateOutputDivergInverval(this->calibrator_->GetOutputDivergInfo()); return RET_OK; } @@ -869,7 +918,6 @@ STATUS PostTrainingQuantizer::PreProcess() { for (auto &cnode : cnodes) { AnfNodePtr anf = std::dynamic_pointer_cast(cnode); if (strategy.CanOpPostQuantized(anf)) { - MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized"; calibrator_->AddQuantizedOp(cnode); } } @@ -917,6 +965,10 @@ STATUS PostTrainingQuantizer::DoInference() { KernelCallBack beforeCallBack = [&](const std::vector &beforeInputs, const std::vector &beforeOutputs, const CallBackParam &callParam) -> bool { + auto diverg_info_map = calibrator_->GetInputDivergInfo(); + if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) { + return true; + } if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) { return false; } @@ -924,21 +976,35 @@ STATUS PostTrainingQuantizer::DoInference() { const float *tData = static_cast(tensor->MutableData()); size_t elem_count = tensor->ElementsNum(); vector data(tData, tData + elem_count); - this->calibrator_->RecordMaxValue(callParam.node_name, data, this->calibrator_->GetInputDivergInfo()); + this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name]); return true; }; // func KernelCallBack afterCallBack = [&](const std::vector &afterInputs, const std::vector &afterOutputs, const CallBackParam &callParam) -> bool { + auto diverg_info_map = calibrator_->GetOutputDivergInfo(); + if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) { + return true; + } if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) { return false; } - auto tensor = afterOutputs[0]; - const float *tensor_data = static_cast(tensor->MutableData()); - size_t elem_count = tensor->ElementsNum(); - vector data(tensor_data, tensor_data + elem_count); - this->calibrator_->RecordMaxValue(callParam.node_name, data, this->calibrator_->GetOutputDivergInfo()); + if (afterOutputs.size() > 1) { + auto output_diverg = std::make_unique(); + *output_diverg = *((*diverg_info_map)[callParam.node_name][0]); + for (size_t i = 1; i < afterOutputs.size(); i++) { + (*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg)); + } + } + size_t output_i = 0; + for (const auto &tensor : afterOutputs) { + const float *tensor_data = static_cast(tensor->MutableData()); + size_t elem_count = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + elem_count); + this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name][output_i]); + output_i++; + } return true; }; auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack); @@ -1317,6 +1383,10 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() { KernelCallBack beforeCallBack = [&](const std::vector &beforeInputs, const std::vector &beforeOutputs, const CallBackParam &callParam) { + auto diverg_info_map = calibrator_->GetInputDivergInfo(); + if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) { + return true; + } if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) { return false; } @@ -1324,21 +1394,28 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() { const float *tensor_data = static_cast(tensor->MutableData()); size_t shape_size = tensor->ElementsNum(); vector data(tensor_data, tensor_data + shape_size); - this->calibrator_->UpdateDataFrequency(callParam.node_name, data, this->calibrator_->GetInputDivergInfo()); + this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name]); return true; }; KernelCallBack afterCallBack = [&](const std::vector &after_inputs, const std::vector &after_outputs, const CallBackParam &call_param) { + auto diverg_info_map = calibrator_->GetOutputDivergInfo(); + if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) { + return true; + } if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) { return false; } - auto tensor = after_outputs[0]; - const float *tenosr_data = static_cast(tensor->MutableData()); - size_t shape_size = tensor->ElementsNum(); - vector data(tenosr_data, tenosr_data + shape_size); - this->calibrator_->UpdateDataFrequency(call_param.node_name, data, this->calibrator_->GetOutputDivergInfo()); + int output_i = 0; + for (const auto &tensor : after_outputs) { + const float *tensor_data = static_cast(tensor->MutableData()); + size_t elem_count = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + elem_count); + this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i]); + output_i++; + } return true; }; auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack); diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 6ef6a98d0d..23c4df7696 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -132,7 +132,7 @@ struct DivergInfo { std::vector max_datas; std::pair percent_result{0.0, 0.0}; float scale_tmp = 0; - + DivergInfo() = default; DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) { this->method_x = method_x; this->cnode = cnode; @@ -187,13 +187,14 @@ class Calibrator { STATUS AddQuantizedOp(CNodePtr node); - STATUS RecordMaxValue(const std::string &op_name, const std::vector &data, - std::unordered_map> *diverg_info); + STATUS RecordMaxValue(const std::vector &data, const std::unique_ptr &diverg_info); STATUS UpdateDivergInverval(std::unordered_map> *diverg_info); - STATUS UpdateDataFrequency(const std::string &op_name, const std::vector &data, - std::unordered_map> *diverg_info); + STATUS UpdateOutputDivergInverval( + std::unordered_map>> *diverg_info); + + STATUS UpdateDataFrequency(const std::vector &data, const std::unique_ptr &diverg_info); void Dump(); STATUS ComputeThreshold(); @@ -208,7 +209,7 @@ class Calibrator { std::unordered_map> *GetInputDivergInfo(); - std::unordered_map> *GetOutputDivergInfo(); + std::unordered_map>> *GetOutputDivergInfo(); private: std::vector> images_; // multi_input, echo input has multi input data @@ -219,7 +220,7 @@ class Calibrator { std::unordered_map> input_diverg_info_; - std::unordered_map> output_diverg_info_; + std::unordered_map>> outputs_diverg_info_; size_t bit_num_; int quant_max_; diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index 8752b9e50f..d6775146c0 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -85,7 +85,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (curnode_quant_type == schema::QuantType_PostTraining && input_cnode_quant_type == schema::QuantType_QUANT_NONE) { value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front()); + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams()[i - 1]); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && input_cnode_quant_type == schema::QuantType_PostTraining) { value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 24993e95c5..8c622c3e17 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -90,22 +90,32 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { return false; } auto cnode = std::dynamic_pointer_cast(node); - - auto primitive_c = GetValueNode>(cnode->input(0)); - if (primitive_c == nullptr) { - MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); - return false; - } - - auto type = (schema::PrimitiveType)primitive_c->Type(); - MS_LOG(INFO) << "Primitive type: " << type; - static const std::vector uint8OpList = { - schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, - schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - /*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/ - schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul, - schema::PrimitiveType_Activation}; - return IsContain(uint8OpList, type); + auto type = NodePrimitiveType(cnode); + static const std::vector int8OpList = { + schema::PrimitiveType_Nchw2Nhwc, + schema::PrimitiveType_Nhwc2Nchw, + schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_Add, + schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, + schema::PrimitiveType_Split, + schema::PrimitiveType_TupleGetItem, + schema::PrimitiveType_Reshape, + schema::PrimitiveType_FullConnection, + schema::PrimitiveType_MatMul, + schema::PrimitiveType_Crop, + schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_Activation, + schema::PrimitiveType_TupleGetItem, + }; + bool contain = IsContain(int8OpList, type); + if (!contain) { + MS_LOG(INFO) << "not quant, " << cnode->fullname_with_scope() + << " of type: " << schema::EnumNamePrimitiveType(type); + } + return contain; } bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { @@ -431,6 +441,19 @@ std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epoc quantParam->clusters = clusters; return clusters_index; } + +schema::PrimitiveType NodePrimitiveType(CNodePtr cnode) { + if (cnode == nullptr) { + MS_LOG(ERROR) << "cnode is null"; + return schema::PrimitiveType_NONE; + } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is null"; + return schema::PrimitiveType_NONE; + } + return (schema::PrimitiveType)primitive_c->Type(); +} } // namespace quant } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index c751199cec..2492c7172f 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -287,6 +287,8 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti } STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); + +schema::PrimitiveType NodePrimitiveType(CNodePtr cnode); } // namespace quant } // namespace lite } // namespace mindspore