| @@ -21,7 +21,7 @@ | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -24,7 +24,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -21,8 +21,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -19,7 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| @@ -18,7 +18,7 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -19,8 +19,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -19,7 +19,7 @@ | |||
| #include "src/common/log_adapter.h" | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| #include <float.h> | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "src/param_value_lite.h" | |||
| #endif | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| @@ -387,7 +387,9 @@ void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::Qu | |||
| } | |||
| void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | |||
| MS_ASSERT(index < this->input_quant_param_.size()); | |||
| if (index >= this->input_quant_param_.size()) { | |||
| this->input_quant_param_.resize(index + 1); | |||
| } | |||
| this->input_quant_param_.at(index) = input_quant_param; | |||
| } | |||
| @@ -493,7 +495,7 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() { | |||
| } | |||
| template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>> | |||
| std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||
| const schema::QuantType &quantType) { | |||
| auto primc = std::make_shared<T>(); | |||
| if (primc == nullptr) { | |||
| @@ -204,7 +204,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | |||
| return nullptr; | |||
| } | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize, | |||
| config->quantWeightChannel, config->bitNum); | |||
| if (mQuantizer == nullptr) { | |||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | |||
| @@ -32,8 +32,6 @@ | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/anf_importer/import_from_mindir.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "include/version.h" | |||
| namespace mindspore { | |||
| @@ -16,7 +16,6 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore::lite { | |||
| @@ -6,7 +6,6 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4) | |||
| file(GLOB QUANTIZER | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | |||
| @@ -39,6 +39,7 @@ | |||
| #include "tools/common/tensor_util.h" | |||
| #include "src/common/file_utils.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -380,182 +381,16 @@ STATUS Calibrator::AddQuantizedOp(const CNodePtr &node) { | |||
| return RET_OK; | |||
| } | |||
| void Calibrator::AddImage(const string &file, size_t index) { | |||
| if (index >= images_.size()) { | |||
| MS_LOG(ERROR) << "images_ size: " << images_.size() << " but index: " << index; | |||
| return; | |||
| } | |||
| auto exist = [](const string &file) { | |||
| struct stat buf {}; | |||
| return stat(file.c_str(), &buf) == 0; | |||
| }; | |||
| if (exist(file)) { | |||
| this->images_[index].push_back(file); | |||
| } else { | |||
| MS_LOG(WARNING) << "invalid image file path: " << file; | |||
| } | |||
| } | |||
| STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index, | |||
| mindspore::tensor::MSTensor *tensor) const { | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (input_index >= images_.size()) { | |||
| MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index; | |||
| return RET_ERROR; | |||
| } | |||
| if (image_index >= images_[input_index].size()) { | |||
| MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size() | |||
| << " but image_index: " << image_index; | |||
| return RET_ERROR; | |||
| } | |||
| string path = images_[input_index][image_index]; | |||
| MS_LOG(INFO) << "read image: " << path; | |||
| size_t size; | |||
| char *bin_buf = ReadFile(path.c_str(), &size); | |||
| if (bin_buf == nullptr) { | |||
| MS_LOG(ERROR) << "ReadFile return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data = tensor->MutableData(); | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (size != tensor->Size()) { | |||
| MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size | |||
| << " input tensor size: " << tensor->Size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s error: " << ret; | |||
| delete[] bin_buf; | |||
| return RET_ERROR; | |||
| } | |||
| delete[] bin_buf; | |||
| return RET_OK; | |||
| return CopyInputDataToTensor(input_index, image_index, images_, tensor); | |||
| } | |||
| STATUS Calibrator::CollectImages() { | |||
| this->images_.resize(config_param_.image_paths.size()); | |||
| auto input_i = 0; | |||
| bool multi_input = config_param_.image_paths.size() > 1; | |||
| for (const auto &image_path : config_param_.image_paths) { | |||
| DIR *root = opendir(image_path.c_str()); | |||
| if (root == nullptr) { | |||
| MS_LOG(ERROR) << "invalid image path: " << image_path; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| struct dirent *image_dir = readdir(root); | |||
| size_t count = 0; | |||
| while (image_dir != nullptr) { | |||
| string file_name(image_dir->d_name); | |||
| if (file_name != "." && file_name != "..") { | |||
| const std::string file_path = image_path + "/" + file_name; | |||
| if (multi_input || config_param_.batch_count == 0) { | |||
| this->AddImage(file_path, input_i); | |||
| count++; | |||
| } else if (count < config_param_.batch_count) { | |||
| this->AddImage(file_path, input_i); | |||
| count++; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| image_dir = readdir(root); | |||
| } | |||
| std::sort(images_[input_i].begin(), images_[input_i].end()); | |||
| if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) { | |||
| images_[input_i].resize(config_param_.batch_count); | |||
| } | |||
| closedir(root); | |||
| input_i++; | |||
| } | |||
| return RET_OK; | |||
| return CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||
| } | |||
| STATUS Calibrator::ReadConfig() { | |||
| if (config_path_.empty() || config_path_.length() > PATH_MAX) { | |||
| MS_LOG(ERROR) << "invalid config path!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| // check whether config file path is valid | |||
| char *resolved_path = new (std::nothrow) char[PATH_MAX]{0}; | |||
| if (resolved_path == nullptr) { | |||
| MS_LOG(ERROR) << "New an object failed."; | |||
| return RET_ERROR; | |||
| } | |||
| #ifdef _WIN32 | |||
| if (_fullpath(resolved_path, config_path_.c_str(), 1024) != nullptr) { | |||
| config_path_ = string(resolved_path); | |||
| } | |||
| #else | |||
| if (realpath(config_path_.c_str(), resolved_path) != nullptr) { | |||
| config_path_ = string(resolved_path); | |||
| } | |||
| #endif | |||
| std::ifstream fs(config_path_.c_str(), std::ifstream::in); | |||
| if (!fs.is_open()) { | |||
| MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_; | |||
| delete[] resolved_path; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::string line; | |||
| while (std::getline(fs, line)) { | |||
| auto index = line.find('='); | |||
| if (index == std::string::npos) { | |||
| MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; | |||
| delete[] resolved_path; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto key = line.substr(0, index); | |||
| auto value = line.substr(index + 1); | |||
| Trim(&key); | |||
| Trim(&value); | |||
| if (key == "image_path") { | |||
| auto &raw_image_paths = value; | |||
| auto ind = raw_image_paths.find(','); | |||
| while (ind != std::string::npos) { | |||
| auto image_path = raw_image_paths.substr(0, ind); | |||
| Trim(&image_path); | |||
| config_param_.image_paths.push_back(image_path); | |||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||
| Trim(&raw_image_paths); | |||
| ind = raw_image_paths.find(','); | |||
| } | |||
| config_param_.image_paths.push_back(raw_image_paths); | |||
| } else if (key == "batch_count") { | |||
| config_param_.batch_count = std::stoul(value); | |||
| } else if (key == "thread_num") { | |||
| config_param_.thread_num = std::stoul(value); | |||
| } else if (key == "method_x") { | |||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||
| } else { | |||
| config_param_.method_x = value; | |||
| } | |||
| } else if (key == "bias_correction") { | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| config_param_.bias_correction = true; | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "unsupported parameter"; | |||
| } | |||
| } | |||
| for (const auto &path : config_param_.image_paths) { | |||
| MS_LOG(DEBUG) << "calibration data_path: " << path; | |||
| } | |||
| MS_LOG(DEBUG) << "batch_count: " << config_param_.batch_count << " " | |||
| << "method_x: " << config_param_.method_x << " " | |||
| << "thread_num: " << config_param_.thread_num << " " | |||
| << "bias_correction: " << config_param_.bias_correction; | |||
| delete[] resolved_path; | |||
| fs.close(); | |||
| return RET_OK; | |||
| } | |||
| STATUS Calibrator::ReadConfig() { return ParseConfigFile(config_path_, &config_param_); } | |||
| Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min) | |||
| : config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {} | |||
| @@ -621,8 +456,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||
| return RET_OK; | |||
| } | |||
| STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, | |||
| bool perchanel) const { | |||
| STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, | |||
| std::shared_ptr<PrimitiveC> primitive_c, bool perchanel) const { | |||
| MS_ASSERT(weight != nullptr); | |||
| MS_ASSERT(lite_primitive != nullptr); | |||
| // perlayer | |||
| @@ -640,8 +475,21 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share | |||
| MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min, | |||
| bit_num, perchanel); | |||
| auto bit_num_t = bit_num; | |||
| auto quant_max_t = quant_max; | |||
| auto quant_min_t = quant_min; | |||
| if (calibrator_->config_param_.mixed) { | |||
| auto opname_iter = opname_bit_.find(op_name); | |||
| if (opname_iter == opname_bit_.end()) { | |||
| MS_LOG(WARNING) << op_name << " not in the opname_bit_ map"; | |||
| } else { | |||
| bit_num_t = opname_iter->second; | |||
| quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||
| quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||
| } | |||
| } | |||
| auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max_t, | |||
| quant_min_t, bit_num_t, perchanel); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed: " << status; | |||
| return status; | |||
| @@ -921,7 +769,7 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| } | |||
| if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { | |||
| MS_LOG(DEBUG) << "this parameter do quant"; | |||
| DoWeightQuant(input_node, primitive_c, false); | |||
| DoWeightQuant(op_name, input_node, primitive_c, false); | |||
| } else { | |||
| MS_LOG(DEBUG) << "this parameter no need to do quant"; | |||
| } | |||
| @@ -943,7 +791,7 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||
| op_type == PrimitiveType_FullConnection) { | |||
| perchannel = true; | |||
| } | |||
| DoWeightQuant(weight, primitive_c, perchannel); | |||
| DoWeightQuant(op_name, weight, primitive_c, perchannel); | |||
| // do bias quant | |||
| if (cnode->inputs().size() == 4) { | |||
| auto bias = cnode->input(3); | |||
| @@ -982,18 +830,8 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() { | |||
| * 3. save quantied node | |||
| **/ | |||
| STATUS PostTrainingQuantizer::PreProcess() { | |||
| if (this->calibrator_ == nullptr) { | |||
| MS_LOG(ERROR) << "calibrator is null!"; | |||
| return RET_ERROR; | |||
| } | |||
| // 1. generate config param | |||
| STATUS status = calibrator_->ReadConfig(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "read proto text failed!"; | |||
| return status; | |||
| } | |||
| // 2. collect image files | |||
| status = calibrator_->CollectImages(); | |||
| auto status = calibrator_->CollectImages(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "collect images failed!"; | |||
| return status; | |||
| @@ -1560,55 +1398,49 @@ STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->Com | |||
| STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| MS_LOG(INFO) << "start to parse config file"; | |||
| STATUS status = PreProcess(); | |||
| if (this->calibrator_ == nullptr) { | |||
| MS_LOG(ERROR) << "calibrator is null!"; | |||
| return RET_ERROR; | |||
| } | |||
| // 1. generate config param | |||
| STATUS status = calibrator_->ReadConfig(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "do pre process failed!"; | |||
| MS_LOG(ERROR) << "read proto text failed!"; | |||
| return status; | |||
| } | |||
| // anf -- fb | |||
| auto meta_graph = Export(func_graph, true, true); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Export to meta_graph return nullptr"; | |||
| return RET_ERROR; | |||
| if (calibrator_->config_param_.mixed) { | |||
| // get opname_bit map | |||
| auto weight_quant_func_graph = CopyFuncGraph(func_graph); | |||
| if (weight_quant_func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "CopyFuncGraph error"; | |||
| return RET_ERROR; | |||
| } | |||
| WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->config_param_); | |||
| weight_quantizer.flags = flags; | |||
| status = weight_quantizer.DoQuantize(weight_quant_func_graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Do mix weight quant error"; | |||
| return RET_ERROR; | |||
| } | |||
| opname_bit_ = weight_quantizer.opname_bit_; | |||
| } | |||
| // transform | |||
| GraphDefTransform transform; | |||
| transform.SetGraphDef(meta_graph); | |||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||
| status = transform.Transform(flags); | |||
| status = PreProcess(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "FBTransform model failed " << status; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "start create session"; | |||
| flatbuffers::FlatBufferBuilder builder(1024); | |||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph); | |||
| builder.Finish(offset); | |||
| schema::FinishMetaGraphBuffer(builder, offset); | |||
| size_t size = builder.GetSize(); | |||
| auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | |||
| if (content == nullptr) { | |||
| MS_LOG(ERROR) << "GetBufferPointer nullptr"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "do pre process failed!"; | |||
| return status; | |||
| } | |||
| auto model = lite::Model::Import(content, size); | |||
| Context ctx; | |||
| ctx.thread_num_ = calibrator_->GetThreadNum(); | |||
| fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx)); | |||
| // anf -- fb | |||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||
| MS_LOG(INFO) << "start create session"; | |||
| fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); | |||
| if (fp32_session_ == nullptr) { | |||
| MS_LOG(ERROR) << "create session failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = fp32_session_->CompileGraph(model); | |||
| if (ret != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "compile graph error"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "start to update divergence's max value"; | |||
| status = DoInference(); | |||
| if (status != RET_OK) { | |||
| @@ -1647,49 +1479,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| if (calibrator_->GetBiasCorrection()) { | |||
| // init in8 session | |||
| // anf -- fb | |||
| auto int8_meta_graph = Export(func_graph, true, true); | |||
| if (int8_meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| // transform | |||
| GraphDefTransform fb_transform; | |||
| fb_transform.SetGraphDef(int8_meta_graph); | |||
| MS_LOG(INFO) << "create quant session"; | |||
| flags.quantType = schema::QuantType_PostTraining; | |||
| status = fb_transform.Transform(flags); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "FBTransform model failed " << status; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "start create quantized session"; | |||
| flatbuffers::FlatBufferBuilder int8_builder(1024); | |||
| auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); | |||
| int8_builder.Finish(int8_offset); | |||
| schema::FinishMetaGraphBuffer(int8_builder, int8_offset); | |||
| size = int8_builder.GetSize(); | |||
| auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer()); | |||
| if (int8_content == nullptr) { | |||
| MS_LOG(ERROR) << "GetBufferPointer nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto int8_model = lite::Model::Import(int8_content, size); | |||
| Context int8_ctx; | |||
| int8_ctx.thread_num_ = calibrator_->GetThreadNum(); | |||
| int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||
| int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx)); | |||
| int8_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); | |||
| if (int8_session_ == nullptr) { | |||
| MS_LOG(ERROR) << "create session failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| ret = int8_session_->CompileGraph(int8_model); | |||
| if (ret != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "compile graph error"; | |||
| return RET_ERROR; | |||
| } | |||
| MS_LOG(INFO) << "do bias correction"; | |||
| status = BiasCorrection(func_graph); | |||
| @@ -28,6 +28,8 @@ | |||
| #include "tools/converter/quantizer/quantizer.h" | |||
| #include "tools/converter/converter.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| namespace mindspore::lite::quant { | |||
| class Calibrator; | |||
| @@ -38,19 +40,8 @@ struct MaxMin { | |||
| float max; | |||
| }; | |||
| const char kMethodMaxMin[] = "MAX_MIN"; | |||
| const char kMethodKL[] = "KL"; | |||
| const char kMethodOutlier[] = "RemovalOutlier"; | |||
| constexpr int kDefaultBinNumber = 2048; | |||
| struct ConfigParam { | |||
| std::vector<std::string> image_paths; | |||
| uint32_t batch_count{100}; | |||
| std::string method_x{kMethodKL}; | |||
| uint32_t thread_num{1}; | |||
| bool bias_correction{false}; | |||
| }; | |||
| class PostTrainingQuantizer : public Quantizer { | |||
| public: | |||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, | |||
| @@ -64,14 +55,16 @@ class PostTrainingQuantizer : public Quantizer { | |||
| int quant_min{INT8_MIN}; | |||
| private: | |||
| std::map<std::string, int> opname_bit_; | |||
| bool per_channel_{true}; | |||
| TypeId target_type_{kNumberTypeInt8}; | |||
| std::unique_ptr<Calibrator> calibrator_; | |||
| mindspore::lite::LiteSession *fp32_session_; | |||
| mindspore::lite::LiteSession *int8_session_; | |||
| session::LiteSession *fp32_session_{nullptr}; | |||
| session::LiteSession *int8_session_{nullptr}; | |||
| std::map<std::string, std::vector<float>> fp32_op_input_map; // concurency | |||
| std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurency | |||
| @@ -112,7 +105,8 @@ class PostTrainingQuantizer : public Quantizer { | |||
| STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, | |||
| const std::shared_ptr<PrimitiveC> &) const; | |||
| STATUS DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel) const; | |||
| STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, | |||
| bool perchannel) const; | |||
| STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr<PrimitiveC> &primitive_c); | |||
| STATUS Int8Inference(); | |||
| @@ -213,13 +207,13 @@ class Calibrator { | |||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo(); | |||
| PostQuantConfig config_param_; | |||
| private: | |||
| std::vector<std::vector<std::string>> images_; // multi_input, echo input has multi input data | |||
| std::string config_path_; | |||
| ConfigParam config_param_; | |||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_; | |||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_; | |||
| @@ -227,8 +221,6 @@ class Calibrator { | |||
| size_t bit_num_; | |||
| int quant_max_; | |||
| int quant_min_; | |||
| void AddImage(const std::string &file, size_t index); | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H | |||
| @@ -17,6 +17,8 @@ | |||
| #include "mindspore/lite/tools/converter/quantizer/quantize_util.h" | |||
| #include <cmath> | |||
| #include <string> | |||
| #include <map> | |||
| #include <fstream> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| @@ -26,6 +28,8 @@ | |||
| #include "src/common/utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "mindspore/lite/include/version.h" | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -83,10 +87,10 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { | |||
| bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||
| MS_ASSERT(node != nullptr); | |||
| if (!node->isa<CNode>()) { | |||
| if (!node->isa<mindspore::CNode>()) { | |||
| return false; | |||
| } | |||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||
| auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node); | |||
| auto type = NodePrimitiveType(cnode); | |||
| static const std::vector<schema::PrimitiveType> int8OpList = { | |||
| schema::PrimitiveType_Conv2D, | |||
| @@ -475,4 +479,307 @@ schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) { | |||
| } | |||
| return (schema::PrimitiveType)primitive_c->Type(); | |||
| } | |||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { | |||
| if (post_quant_config == nullptr) { | |||
| MS_LOG(ERROR) << "post_quant_config is null."; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| if (config_file.empty() || config_file.length() > PATH_MAX) { | |||
| MS_LOG(ERROR) << "invalid config path!"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| // check whether config file path is valid | |||
| auto resolved_path = std::make_unique<char[]>(PATH_MAX); | |||
| if (resolved_path == nullptr) { | |||
| MS_LOG(ERROR) << "New an object failed."; | |||
| return RET_ERROR; | |||
| } | |||
| #ifdef _WIN32 | |||
| if (_fullpath(resolved_path.get(), config_file.c_str(), 1024) != nullptr) { | |||
| config_file = string(resolved_path.get()); | |||
| } | |||
| #else | |||
| if (realpath(config_file.c_str(), resolved_path.get()) != nullptr) { | |||
| config_file = string(resolved_path.get()); | |||
| } | |||
| #endif | |||
| std::ifstream fs(config_file.c_str(), std::ifstream::in); | |||
| if (!fs.is_open()) { | |||
| MS_LOG(ERROR) << "config file open failed: " << config_file; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::string line; | |||
| while (std::getline(fs, line)) { | |||
| Trim(&line); | |||
| if (line.empty()) { | |||
| continue; | |||
| } | |||
| auto index = line.find('='); | |||
| if (index == std::string::npos) { | |||
| MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| auto key = line.substr(0, index); | |||
| auto value = line.substr(index + 1); | |||
| Trim(&key); | |||
| Trim(&value); | |||
| if (key == "image_path") { | |||
| auto &raw_image_paths = value; | |||
| auto ind = raw_image_paths.find(','); | |||
| while (ind != std::string::npos) { | |||
| auto image_path = raw_image_paths.substr(0, ind); | |||
| Trim(&image_path); | |||
| post_quant_config->image_paths.push_back(image_path); | |||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||
| Trim(&raw_image_paths); | |||
| ind = raw_image_paths.find(','); | |||
| } | |||
| post_quant_config->image_paths.push_back(raw_image_paths); | |||
| } else if (key == "batch_count") { | |||
| post_quant_config->batch_count = std::stoul(value); | |||
| } else if (key == "thread_num") { | |||
| post_quant_config->thread_num = std::stoul(value); | |||
| } else if (key == "method_x") { | |||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||
| } else { | |||
| post_quant_config->method_x = value; | |||
| } | |||
| } else if (key == "bias_correction") { | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->bias_correction = true; | |||
| } | |||
| } else if (key == "mixed") { | |||
| std::for_each(value.begin(), value.end(), ::tolower); | |||
| if (value == "true") { | |||
| post_quant_config->mixed = true; | |||
| } | |||
| } else if (key == "mean_error_threshold") { | |||
| post_quant_config->mean_error_threshold = std::stof(value); | |||
| } else { | |||
| MS_LOG(WARNING) << "unsupported parameter: " << key; | |||
| } | |||
| } | |||
| for (const auto &path : post_quant_config->image_paths) { | |||
| MS_LOG(DEBUG) << "calibration data_path: " << path; | |||
| } | |||
| MS_LOG(DEBUG) << "batch_count: " << post_quant_config->batch_count << "\n" | |||
| << "method_x: " << post_quant_config->method_x << "\n" | |||
| << "thread_num: " << post_quant_config->thread_num << "\n" | |||
| << "bias_correction: " << post_quant_config->bias_correction << "\n" | |||
| << "mixed: " << post_quant_config->mixed << "\n" | |||
| << "mean_error_threshold: " << post_quant_config->mean_error_threshold; | |||
| post_quant_config->inited = true; | |||
| fs.close(); | |||
| return RET_OK; | |||
| } | |||
| session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, | |||
| int thread_num) { | |||
| auto meta_graph = Export(func_graph, true, true); | |||
| if (meta_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Export to meta_graph failed"; | |||
| return nullptr; | |||
| } | |||
| // transform | |||
| GraphDefTransform fb_transform; | |||
| fb_transform.SetGraphDef(meta_graph); | |||
| auto status = fb_transform.Transform(flags); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "FBTransform model failed"; | |||
| return nullptr; | |||
| } | |||
| meta_graph->version = Version(); | |||
| flatbuffers::FlatBufferBuilder builder(1024); | |||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph); | |||
| builder.Finish(offset); | |||
| schema::FinishMetaGraphBuffer(builder, offset); | |||
| auto size = builder.GetSize(); | |||
| auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | |||
| if (content == nullptr) { | |||
| MS_LOG(ERROR) << "GetBufferPointer return null"; | |||
| return nullptr; | |||
| } | |||
| auto model = lite::Model::Import(content, size); | |||
| if (model == nullptr) { | |||
| MS_LOG(ERROR) << "Import model failed"; | |||
| return nullptr; | |||
| } | |||
| Context ctx; | |||
| ctx.thread_num_ = thread_num; | |||
| auto session = session::LiteSession::CreateSession(&ctx); | |||
| if (session == nullptr) { | |||
| MS_LOG(ERROR) << "create session failed."; | |||
| return nullptr; | |||
| } | |||
| status = session->CompileGraph(model); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CompileGraph error"; | |||
| return nullptr; | |||
| } | |||
| model->Free(); | |||
| return session; | |||
| } | |||
| STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited, | |||
| std::vector<std::vector<std::string>> *inputs) { | |||
| if (inputs == nullptr) { | |||
| MS_LOG(ERROR) << "inputs is null"; | |||
| return RET_ERROR; | |||
| } | |||
| auto AddImage = [&inputs](const std::string &file, size_t index) { | |||
| if (index >= inputs->size()) { | |||
| MS_LOG(ERROR) << "images_ size: " << inputs->size() << " but input index: " << index; | |||
| return; | |||
| } | |||
| struct stat buf {}; | |||
| if (stat(file.c_str(), &buf) == 0) { | |||
| inputs->at(index).push_back(file); | |||
| } else { | |||
| MS_LOG(WARNING) << "invalid image file path: " << file; | |||
| } | |||
| }; | |||
| inputs->resize(input_dirs.size()); | |||
| auto input_i = 0; | |||
| bool multi_input = input_dirs.size() > 1; | |||
| for (const auto &image_path : input_dirs) { | |||
| DIR *root = opendir(image_path.c_str()); | |||
| if (root == nullptr) { | |||
| MS_LOG(ERROR) << "invalid image path: " << image_path; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| struct dirent *image_dir = readdir(root); | |||
| size_t count = 0; | |||
| while (image_dir != nullptr) { | |||
| string file_name(image_dir->d_name); | |||
| if (file_name != "." && file_name != "..") { | |||
| const std::string file_path = image_path + "/" + file_name; | |||
| if (multi_input || count == 0) { | |||
| AddImage(file_path, input_i); | |||
| count++; | |||
| } else if (count < count_limited) { | |||
| AddImage(file_path, input_i); | |||
| count++; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| image_dir = readdir(root); | |||
| } | |||
| std::sort(inputs->at(input_i).begin(), inputs->at(input_i).end()); | |||
| if (count_limited != 0 && count_limited < inputs->at(input_i).size()) { | |||
| inputs->at(input_i).resize(count_limited); | |||
| } | |||
| closedir(root); | |||
| input_i++; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (input_index >= images.size()) { | |||
| MS_LOG(ERROR) << "images_ size: " << images.size() << " but input_index: " << input_index; | |||
| return RET_ERROR; | |||
| } | |||
| if (image_index >= images[input_index].size()) { | |||
| MS_LOG(ERROR) << "images_[input_index] size: " << images[input_index].size() << " but image_index: " << image_index; | |||
| return RET_ERROR; | |||
| } | |||
| string path = images[input_index][image_index]; | |||
| MS_LOG(INFO) << "read image: " << path; | |||
| size_t size; | |||
| char *bin_buf = ReadFile(path.c_str(), &size); | |||
| if (bin_buf == nullptr) { | |||
| MS_LOG(ERROR) << "ReadFile return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data = tensor->MutableData(); | |||
| if (data == nullptr) { | |||
| MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (size != tensor->Size()) { | |||
| MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size | |||
| << " input tensor size: " << tensor->Size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s error: " << ret; | |||
| delete[] bin_buf; | |||
| return RET_ERROR; | |||
| } | |||
| delete[] bin_buf; | |||
| return RET_OK; | |||
| } | |||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { | |||
| Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr); | |||
| auto new_func_graph = cloner[func_graph]; | |||
| std::map<std::string, CNodePtr> old_cnode_map; | |||
| for (const auto &cnode : func_graph->GetOrderedCnodes()) { | |||
| old_cnode_map[cnode->fullname_with_scope()] = cnode; | |||
| } | |||
| for (auto &cnode : new_func_graph->GetOrderedCnodes()) { | |||
| auto cnode_name = cnode->fullname_with_scope(); | |||
| auto old_cnode_iter = old_cnode_map.find(cnode_name); | |||
| if (old_cnode_iter == old_cnode_map.end()) { | |||
| MS_LOG(ERROR) << "can not find node: " << cnode_name; | |||
| return nullptr; | |||
| } | |||
| auto old_cnode = old_cnode_iter->second; | |||
| auto inputs = cnode->inputs(); | |||
| for (size_t i = 0; i < inputs.size(); i++) { | |||
| auto input_node = inputs[i]; | |||
| if (input_node->isa<Parameter>()) { | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (param_node->has_default()) { | |||
| ParamValueLitePtr old_param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| auto new_param_value = std::make_shared<ParamValueLite>(); | |||
| auto copyed_data = malloc(old_param_value->tensor_size()); | |||
| if (copyed_data == nullptr) { | |||
| MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size(); | |||
| return nullptr; | |||
| } | |||
| memcpy(copyed_data, old_param_value->tensor_addr(), old_param_value->tensor_size()); | |||
| new_param_value->set_tensor_size(old_param_value->tensor_size()); | |||
| new_param_value->set_tensor_addr(copyed_data); | |||
| new_param_value->set_tensor_shape(old_param_value->tensor_shape()); | |||
| new_param_value->set_format(old_param_value->format()); | |||
| new_param_value->set_tensor_type(old_param_value->tensor_type()); | |||
| param_node->set_default_param(new_param_value); | |||
| } | |||
| auto old_abstract_base = param_node->abstract(); | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); | |||
| return nullptr; | |||
| } | |||
| auto old_abstract = utils::cast<abstract::AbstractTensorPtr>(old_abstract_base); | |||
| auto new_abstract = std::make_shared<abstract::AbstractTensor>(old_abstract->element()->GetTypeTrack(), | |||
| old_abstract->GetShapeTrack()); | |||
| param_node->set_abstract(new_abstract); | |||
| } | |||
| } // end inputs loop | |||
| } // end cnodes loop | |||
| return new_func_graph; | |||
| } | |||
| } // namespace mindspore::lite::quant | |||
| @@ -17,6 +17,8 @@ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | |||
| #include <dirent.h> | |||
| #include <sys/stat.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <cmath> | |||
| @@ -35,11 +37,29 @@ | |||
| #include "ir/primitive.h" | |||
| #include "abstract/dshape.h" | |||
| #include "tools/converter/quantizer/bitpacking.h" | |||
| #include "src/lite_session.h" | |||
| #include "tools/converter/graphdef_transform.h" | |||
| #include "src/common/file_utils.h" | |||
| namespace mindspore::lite::quant { | |||
| static constexpr size_t UINT8_QUANTIZATION = 8; | |||
| static constexpr size_t WEIGHT_INDEX = 1; | |||
| const char kMethodMaxMin[] = "MAX_MIN"; | |||
| const char kMethodKL[] = "KL"; | |||
| const char kMethodOutlier[] = "RemovalOutlier"; | |||
// Options parsed from the post-training quantization config file
// (filled in by ParseConfigFile; defaults below apply to absent keys).
struct PostQuantConfig {
  std::vector<std::string> image_paths;  // one calibration-image directory per model input
  uint32_t batch_count{100};             // max calibration images used per input (CollectCalibInputs limit)
  std::string method_x{kMethodKL};       // calibration method name — presumably one of kMethodMaxMin/kMethodKL/kMethodOutlier
  uint32_t thread_num{1};                // thread count for the calibration inference sessions
  bool bias_correction{false};           // NOTE(review): not used in this chunk — presumably enables bias correction; confirm
  bool mixed{false};                     // true -> DoQuantize takes the mixed-bit path (DoMiexedQuant)
  float mean_error_threshold{0.04};      // max acceptable mean output error when searching mixed bit widths
  bool inited{false};                    // set true once ParseConfigFile has completed successfully
};
| /** | |||
| * 1. when op's weight size > mWeightSize just skip | |||
| * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization | |||
| @@ -320,6 +340,21 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit | |||
| return RET_OK; | |||
| } | |||
| // utils | |||
| schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode); | |||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config); | |||
| session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, | |||
| int thread_num); | |||
| STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited, | |||
| std::vector<std::vector<std::string>> *inputs); | |||
| STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); | |||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||
| @@ -18,6 +18,7 @@ | |||
| #include <list> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "src/common/common.h" | |||
| #include "ir/dtype/type_id.h" | |||
| @@ -36,6 +37,7 @@ bool WeightQuantizer::IsPosNum(const std::string &str) { | |||
| } | |||
| return true; | |||
| } | |||
| STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||
| MS_ASSERT(config != nullptr); | |||
| if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { | |||
| @@ -57,28 +59,57 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||
| } | |||
| return RET_OK; | |||
| } | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { | |||
| quant_strategy_ = std::make_unique<QuantStrategy>(0, 0); | |||
| config_param_ = config; | |||
| } | |||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, | |||
| const std::string &convWeightChannelThreshold, const std::string &bitNum) | |||
| : Quantizer(graph) { | |||
| this->config_file_ = config_file; | |||
| auto quantSize = static_cast<size_t>(std::stoull(weightSize)); | |||
| this->bitNum = static_cast<size_t>(std::stoull(bitNum)); | |||
| this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | |||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | |||
| mStrategy = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||
| quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; | |||
| quant_min = -(1 << (unsigned int)(this->bitNum - 1)); | |||
| quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||
| quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||
| quant_min = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||
| // parse type_id | |||
| if (this->bitNum > 0 && this->bitNum <= 8) { | |||
| if (this->bit_num_ > 0 && this->bit_num_ <= 8) { | |||
| type_id = kNumberTypeInt8; | |||
| } else if (this->bitNum <= 16) { | |||
| } else if (this->bit_num_ <= 16) { | |||
| type_id = kNumberTypeInt16; | |||
| } else { | |||
| MS_LOG(ERROR) << "invalid input bits"; | |||
| } | |||
| } | |||
// Releases the fp32 session created lazily by DoMiexedQuant (delete on nullptr is a no-op).
WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
| STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, | |||
| std::shared_ptr<PrimitiveC> primitive_c) { | |||
| // set dtype | |||
| param_value->set_tensor_type(type_id); | |||
| auto abstract_base = param_node->abstract(); | |||
| if (abstract_base == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||
| abstract_tensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| for (auto &cnode : nodes) { | |||
| if (!mStrategy->CanConvOpQuantized(cnode)) { | |||
| if (!quant_strategy_->CanConvOpQuantized(cnode)) { | |||
| continue; | |||
| } | |||
| @@ -108,36 +139,28 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||
| } | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| // set dtype | |||
| param_value->set_tensor_type(type_id); | |||
| auto abstractBase = param_node->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| for (auto &node : nodes) { | |||
| if (!mStrategy->CanMulOpQuantized(node)) { | |||
| if (!quant_strategy_->CanMulOpQuantized(node)) { | |||
| continue; | |||
| } | |||
| auto already_quant = false; | |||
| @@ -186,38 +209,271 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||
| auto status = RET_ERROR; | |||
| if (type_id == kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| status = | |||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } else if (type_id == kNumberTypeInt16) { | |||
| status = | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | |||
| return status; | |||
| } | |||
| param_value->set_tensor_type(type_id); | |||
| // set dtype | |||
| auto abstractBase = param_node->abstract(); | |||
| if (abstractBase == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| constexpr float relative_tolerance = 1e-5; | |||
| constexpr float abs_tolerance = 1e-4; | |||
| template <typename T> | |||
| float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor, | |||
| const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) { | |||
| auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); }; | |||
| float total_mean_error = 0.0f; | |||
| int tensor_cnt = expected_tensor.size(); | |||
| if (tensor_cnt <= 0) { | |||
| MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt; | |||
| return RET_ERROR; | |||
| } | |||
| for (const auto &exp_tensor_pair : expected_tensor) { | |||
| float mean_error = 0.0f; | |||
| int error_cnt = 0; | |||
| auto exp_tensor_name = exp_tensor_pair.first; | |||
| auto exp_tensor = exp_tensor_pair.second; | |||
| auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name); | |||
| if (cmp_tensor_find_iter == compare_tensor.end()) { | |||
| MS_LOG(ERROR) << "can not find: " << exp_tensor_name; | |||
| return RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||
| auto cmp_tensor = cmp_tensor_find_iter->second; | |||
| auto exp_tensor_shape = exp_tensor->shape(); | |||
| auto cmp_tensor_shape = cmp_tensor->shape(); | |||
| if (exp_tensor_shape != cmp_tensor_shape) { | |||
| MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum() | |||
| << " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum(); | |||
| return RET_ERROR; | |||
| } | |||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||
| auto exp_data = static_cast<T *>(exp_tensor->MutableData()); | |||
| auto cmp_data = static_cast<T *>(cmp_tensor->MutableData()); | |||
| auto elem_cnt = exp_tensor->ElementsNum(); | |||
| for (int i = 0; i < elem_cnt; i++) { | |||
| if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) { | |||
| MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i; | |||
| return RET_ERROR; | |||
| } | |||
| auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]); | |||
| auto abs_error = std::fabs(exp_data[i] - cmp_data[i]); | |||
| if (abs_error > tolerance) { | |||
| if (fabs(exp_data[i] == 0)) { | |||
| if (abs_error > 1e-5) { | |||
| mean_error += abs_error; | |||
| error_cnt++; | |||
| } else { | |||
| // it is ok, very close to 0 | |||
| continue; | |||
| } | |||
| } else { | |||
| mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN); | |||
| error_cnt++; | |||
| } | |||
| } else { | |||
| // it is ok, no error | |||
| continue; | |||
| } | |||
| } // end one tensor data loop | |||
| total_mean_error += mean_error / elem_cnt; | |||
| } // end tensor loop | |||
| return total_mean_error / tensor_cnt; | |||
| } | |||
| STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||
| // 0.1 Create Fp32 Session | |||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||
| fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num); | |||
| if (fp32_session_ == nullptr) { | |||
| MS_LOG(ERROR) << "CreateSessoin fail"; | |||
| return RET_ERROR; | |||
| } | |||
| auto fp32_inputs = fp32_session_->GetInputs(); | |||
| // 0.2 Parse input calib files | |||
| auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "CollectCalibInputs fail"; | |||
| return RET_ERROR; | |||
| } | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| for (auto iter = cnodes.end(); iter != cnodes.begin();) { | |||
| auto cnode = *(--iter); | |||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||
| if (primitive_c == nullptr) { | |||
| MS_LOG(ERROR) << "primitive_c is null."; | |||
| return RET_ERROR; | |||
| } | |||
| auto op_name = cnode->fullname_with_scope(); | |||
| MS_LOG(DEBUG) << "process node: " << op_name | |||
| << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); | |||
| if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { | |||
| auto input_node = cnode->input(2); | |||
| if (!input_node->isa<Parameter>()) { | |||
| MS_LOG(WARNING) << op_name << " the second input is not parameter"; | |||
| continue; | |||
| } | |||
| auto param_node = input_node->cast<ParameterPtr>(); | |||
| if (!param_node->has_default()) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| continue; | |||
| } | |||
| auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||
| continue; | |||
| } | |||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||
| MS_LOG(WARNING) << op_name << " the second input type is not float"; | |||
| continue; | |||
| } | |||
| // copy origin data in case to recover | |||
| auto *raw_data = static_cast<float *>(param_value->tensor_addr()); | |||
| auto elem_count = param_value->tensor_shape_size(); | |||
| auto origin_data = malloc(sizeof(float) * elem_count); | |||
| auto ret = memcpy_s(origin_data, sizeof(float) * elem_count, raw_data, param_value->tensor_size()); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy fail: " | |||
| << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size(); | |||
| return RET_ERROR; | |||
| } | |||
| // 1. try quant | |||
| for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | |||
| type_id = TypeId::kNumberTypeInt8; | |||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||
| if (type_id == TypeId::kNumberTypeInt8) { | |||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||
| quant_min_t, bit_num_t, true); | |||
| } else { | |||
| MS_LOG(ERROR) << "unexpected type_id: " << type_id; | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quant filter fail."; | |||
| return RET_ERROR; | |||
| } | |||
| status = SetAbstract(param_value, param_node, primitive_c); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||
| return RET_ERROR; | |||
| } | |||
| // 2. evaluate the quant | |||
| // 2.1 create quant session, get input, output tensor | |||
| flags.quantType = schema::QuantType_WeightQuant; | |||
| auto quant_session = | |||
| std::unique_ptr<session::LiteSession>(CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num)); | |||
| if (quant_session == nullptr) { | |||
| MS_LOG(ERROR) << "create session error: " << status; | |||
| return RET_ERROR; | |||
| } | |||
| auto quant_inputs = quant_session->GetInputs(); | |||
| auto mean_error = 0.0f; | |||
| if (fp32_inputs.size() != images_.size()) { | |||
| MS_LOG(ERROR) << "model's input tensor cnt: " << fp32_inputs.size() << " != " << images_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| auto image_cnt = images_.at(0).size(); | |||
| for (size_t i = 0; i < image_cnt; i++) { | |||
| // set multi-input data | |||
| for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) { | |||
| status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "generate input data from images failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "generate input data from images failed!"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| std::future<STATUS> fp32_inference = std::async( | |||
| std::launch::async, [](session::LiteSession *fp32_session) -> STATUS { return fp32_session->RunGraph(); }, | |||
| fp32_session_); | |||
| status = quant_session->RunGraph(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "quant session run error"; | |||
| return RET_ERROR; | |||
| } | |||
| status = fp32_inference.get(); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "fp32 session run error"; | |||
| return RET_ERROR; | |||
| } | |||
| // 3. compare betwen quant and fp32 | |||
| auto fp32_outputs = fp32_session_->GetOutputs(); | |||
| auto quant_outputs = quant_session->GetOutputs(); | |||
| mean_error += CompareOutputData<float>(fp32_outputs, quant_outputs); | |||
| } // end_for: calib data loop | |||
| mean_error = mean_error / image_cnt; | |||
| if (mean_error <= config_param_.mean_error_threshold) { | |||
| MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error; | |||
| opname_bit_[op_name] = bit_num_t; | |||
| break; | |||
| } else if (bit_num_t != 8) { | |||
| // recover | |||
| param_value->set_tensor_size(sizeof(float) * elem_count); | |||
| ret = memcpy_s(raw_data, param_value->tensor_size(), origin_data, sizeof(float) * elem_count); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy fail: " | |||
| << " src size: " << sizeof(float) * elem_count << " dst size: " << param_value->tensor_size(); | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error; | |||
| opname_bit_[op_name] = bit_num_t; | |||
| } | |||
| } // end bit loop | |||
| free(origin_data); | |||
| } // if: conv and matmul | |||
| } // end loop: all cnode | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { | |||
| MS_ASSERT(funcGraph != nullptr); | |||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| STATUS ret; | |||
| auto cnodes = funcGraph->GetOrderedCnodes(); | |||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||
| if (!config_file_.empty()) { | |||
| ret = ParseConfigFile(config_file_, &config_param_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ReadConfig error."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (config_param_.mixed) { | |||
| MS_LOG(INFO) << "Do mixed bit quantization"; | |||
| return DoMiexedQuant(func_graph); | |||
| } | |||
| ret = DoConvQuantize(cnodes); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; | |||
| @@ -17,9 +17,12 @@ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | |||
| #include <future> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <list> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "tools/converter/quantizer/quantizer.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "ir/func_graph.h" | |||
| @@ -27,27 +30,37 @@ | |||
| #include "include/model.h" | |||
| #include "base/base.h" | |||
| #include "abstract/dshape.h" | |||
| #include "src/lite_session.h" | |||
| namespace mindspore::lite::quant { | |||
| class WeightQuantizer : public Quantizer { | |||
| public: | |||
| WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold, | |||
| const std::string &bitNum); | |||
| WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, | |||
| const std::string &covWeightChannelThreshold, const std::string &bitNum); | |||
| WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); | |||
| ~WeightQuantizer(); | |||
| ~WeightQuantizer() = default; | |||
| STATUS DoQuantize(FuncGraphPtr funcGraph) override; | |||
| STATUS DoQuantize(FuncGraphPtr func_graph) override; | |||
| STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | |||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | |||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | |||
| static bool IsPosNum(const std::string &str); | |||
| int quant_max; | |||
| int quant_min; | |||
| TypeId type_id{kTypeUnknown}; | |||
| std::map<std::string, int> opname_bit_; | |||
| private: | |||
| std::unique_ptr<QuantStrategy> mStrategy; | |||
| size_t bitNum; | |||
| std::unique_ptr<QuantStrategy> quant_strategy_; | |||
| size_t bit_num_; | |||
| std::string config_file_; | |||
| PostQuantConfig config_param_; | |||
| std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...] | |||
| session::LiteSession *fp32_session_ = nullptr; | |||
| STATUS DoMiexedQuant(FuncGraphPtr); | |||
| STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | |||
| }; | |||
| } // namespace mindspore::lite::quant | |||
| #endif | |||