| @@ -21,7 +21,7 @@ | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -21,8 +21,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| #include "src/ops/ops_register.h" | #include "src/ops/ops_register.h" | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -19,8 +19,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #ifdef PRIMITIVE_WRITEABLE | #ifdef PRIMITIVE_WRITEABLE | ||||
| #include <float.h> | #include <float.h> | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #endif | #endif | ||||
| #ifndef PRIMITIVE_WRITEABLE | #ifndef PRIMITIVE_WRITEABLE | ||||
| @@ -387,7 +387,9 @@ void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::Qu | |||||
| } | } | ||||
| void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { | ||||
| MS_ASSERT(index < this->input_quant_param_.size()); | |||||
| if (index >= this->input_quant_param_.size()) { | |||||
| this->input_quant_param_.resize(index + 1); | |||||
| } | |||||
| this->input_quant_param_.at(index) = input_quant_param; | this->input_quant_param_.at(index) = input_quant_param; | ||||
| } | } | ||||
| @@ -493,7 +495,7 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() { | |||||
| } | } | ||||
| template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>> | template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>> | ||||
| std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||||
| std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs, | |||||
| const schema::QuantType &quantType) { | const schema::QuantType &quantType) { | ||||
| auto primc = std::make_shared<T>(); | auto primc = std::make_shared<T>(); | ||||
| if (primc == nullptr) { | if (primc == nullptr) { | ||||
| @@ -204,7 +204,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap | |||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize, | |||||
| this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize, | |||||
| config->quantWeightChannel, config->bitNum); | config->quantWeightChannel, config->bitNum); | ||||
| if (mQuantizer == nullptr) { | if (mQuantizer == nullptr) { | ||||
| MS_LOG(ERROR) << "New WeightQuantizer failed"; | MS_LOG(ERROR) << "New WeightQuantizer failed"; | ||||
| @@ -32,8 +32,6 @@ | |||||
| #include "tools/anf_exporter/anf_exporter.h" | #include "tools/anf_exporter/anf_exporter.h" | ||||
| #include "tools/anf_importer/import_from_mindir.h" | #include "tools/anf_importer/import_from_mindir.h" | ||||
| #include "proto/onnx.pb.h" | #include "proto/onnx.pb.h" | ||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||||
| #include "tools/converter/quantizer/quant_cast.h" | |||||
| #include "include/version.h" | #include "include/version.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" | ||||
| #include "tools/converter/converter_context.h" | #include "tools/converter/converter_context.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -6,7 +6,6 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4) | |||||
| file(GLOB QUANTIZER | file(GLOB QUANTIZER | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc | ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc | ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ||||
| @@ -39,6 +39,7 @@ | |||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -380,182 +381,16 @@ STATUS Calibrator::AddQuantizedOp(const CNodePtr &node) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| void Calibrator::AddImage(const string &file, size_t index) { | |||||
| if (index >= images_.size()) { | |||||
| MS_LOG(ERROR) << "images_ size: " << images_.size() << " but index: " << index; | |||||
| return; | |||||
| } | |||||
| auto exist = [](const string &file) { | |||||
| struct stat buf {}; | |||||
| return stat(file.c_str(), &buf) == 0; | |||||
| }; | |||||
| if (exist(file)) { | |||||
| this->images_[index].push_back(file); | |||||
| } else { | |||||
| MS_LOG(WARNING) << "invalid image file path: " << file; | |||||
| } | |||||
| } | |||||
| STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index, | STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index, | ||||
| mindspore::tensor::MSTensor *tensor) const { | mindspore::tensor::MSTensor *tensor) const { | ||||
| MS_ASSERT(tensor != nullptr); | |||||
| if (input_index >= images_.size()) { | |||||
| MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (image_index >= images_[input_index].size()) { | |||||
| MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size() | |||||
| << " but image_index: " << image_index; | |||||
| return RET_ERROR; | |||||
| } | |||||
| string path = images_[input_index][image_index]; | |||||
| MS_LOG(INFO) << "read image: " << path; | |||||
| size_t size; | |||||
| char *bin_buf = ReadFile(path.c_str(), &size); | |||||
| if (bin_buf == nullptr) { | |||||
| MS_LOG(ERROR) << "ReadFile return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto data = tensor->MutableData(); | |||||
| if (data == nullptr) { | |||||
| MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (size != tensor->Size()) { | |||||
| MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size | |||||
| << " input tensor size: " << tensor->Size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s error: " << ret; | |||||
| delete[] bin_buf; | |||||
| return RET_ERROR; | |||||
| } | |||||
| delete[] bin_buf; | |||||
| return RET_OK; | |||||
| return CopyInputDataToTensor(input_index, image_index, images_, tensor); | |||||
| } | } | ||||
| STATUS Calibrator::CollectImages() { | STATUS Calibrator::CollectImages() { | ||||
| this->images_.resize(config_param_.image_paths.size()); | |||||
| auto input_i = 0; | |||||
| bool multi_input = config_param_.image_paths.size() > 1; | |||||
| for (const auto &image_path : config_param_.image_paths) { | |||||
| DIR *root = opendir(image_path.c_str()); | |||||
| if (root == nullptr) { | |||||
| MS_LOG(ERROR) << "invalid image path: " << image_path; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| struct dirent *image_dir = readdir(root); | |||||
| size_t count = 0; | |||||
| while (image_dir != nullptr) { | |||||
| string file_name(image_dir->d_name); | |||||
| if (file_name != "." && file_name != "..") { | |||||
| const std::string file_path = image_path + "/" + file_name; | |||||
| if (multi_input || config_param_.batch_count == 0) { | |||||
| this->AddImage(file_path, input_i); | |||||
| count++; | |||||
| } else if (count < config_param_.batch_count) { | |||||
| this->AddImage(file_path, input_i); | |||||
| count++; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| image_dir = readdir(root); | |||||
| } | |||||
| std::sort(images_[input_i].begin(), images_[input_i].end()); | |||||
| if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) { | |||||
| images_[input_i].resize(config_param_.batch_count); | |||||
| } | |||||
| closedir(root); | |||||
| input_i++; | |||||
| } | |||||
| return RET_OK; | |||||
| return CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||||
| } | } | ||||
| STATUS Calibrator::ReadConfig() { | |||||
| if (config_path_.empty() || config_path_.length() > PATH_MAX) { | |||||
| MS_LOG(ERROR) << "invalid config path!"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| // check whether config file path is valid | |||||
| char *resolved_path = new (std::nothrow) char[PATH_MAX]{0}; | |||||
| if (resolved_path == nullptr) { | |||||
| MS_LOG(ERROR) << "New an object failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| #ifdef _WIN32 | |||||
| if (_fullpath(resolved_path, config_path_.c_str(), 1024) != nullptr) { | |||||
| config_path_ = string(resolved_path); | |||||
| } | |||||
| #else | |||||
| if (realpath(config_path_.c_str(), resolved_path) != nullptr) { | |||||
| config_path_ = string(resolved_path); | |||||
| } | |||||
| #endif | |||||
| std::ifstream fs(config_path_.c_str(), std::ifstream::in); | |||||
| if (!fs.is_open()) { | |||||
| MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_; | |||||
| delete[] resolved_path; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| std::string line; | |||||
| while (std::getline(fs, line)) { | |||||
| auto index = line.find('='); | |||||
| if (index == std::string::npos) { | |||||
| MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; | |||||
| delete[] resolved_path; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto key = line.substr(0, index); | |||||
| auto value = line.substr(index + 1); | |||||
| Trim(&key); | |||||
| Trim(&value); | |||||
| if (key == "image_path") { | |||||
| auto &raw_image_paths = value; | |||||
| auto ind = raw_image_paths.find(','); | |||||
| while (ind != std::string::npos) { | |||||
| auto image_path = raw_image_paths.substr(0, ind); | |||||
| Trim(&image_path); | |||||
| config_param_.image_paths.push_back(image_path); | |||||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||||
| Trim(&raw_image_paths); | |||||
| ind = raw_image_paths.find(','); | |||||
| } | |||||
| config_param_.image_paths.push_back(raw_image_paths); | |||||
| } else if (key == "batch_count") { | |||||
| config_param_.batch_count = std::stoul(value); | |||||
| } else if (key == "thread_num") { | |||||
| config_param_.thread_num = std::stoul(value); | |||||
| } else if (key == "method_x") { | |||||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||||
| } else { | |||||
| config_param_.method_x = value; | |||||
| } | |||||
| } else if (key == "bias_correction") { | |||||
| std::for_each(value.begin(), value.end(), ::tolower); | |||||
| if (value == "true") { | |||||
| config_param_.bias_correction = true; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(WARNING) << "unsupported parameter"; | |||||
| } | |||||
| } | |||||
| for (const auto &path : config_param_.image_paths) { | |||||
| MS_LOG(DEBUG) << "calibration data_path: " << path; | |||||
| } | |||||
| MS_LOG(DEBUG) << "batch_count: " << config_param_.batch_count << " " | |||||
| << "method_x: " << config_param_.method_x << " " | |||||
| << "thread_num: " << config_param_.thread_num << " " | |||||
| << "bias_correction: " << config_param_.bias_correction; | |||||
| delete[] resolved_path; | |||||
| fs.close(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS Calibrator::ReadConfig() { return ParseConfigFile(config_path_, &config_param_); } | |||||
| Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min) | Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min) | ||||
| : config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {} | : config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {} | ||||
| @@ -621,8 +456,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, | |||||
| bool perchanel) const { | |||||
| STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, | |||||
| std::shared_ptr<PrimitiveC> primitive_c, bool perchanel) const { | |||||
| MS_ASSERT(weight != nullptr); | MS_ASSERT(weight != nullptr); | ||||
| MS_ASSERT(lite_primitive != nullptr); | MS_ASSERT(lite_primitive != nullptr); | ||||
| // perlayer | // perlayer | ||||
| @@ -640,8 +475,21 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share | |||||
| MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; | MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min, | |||||
| bit_num, perchanel); | |||||
| auto bit_num_t = bit_num; | |||||
| auto quant_max_t = quant_max; | |||||
| auto quant_min_t = quant_min; | |||||
| if (calibrator_->config_param_.mixed) { | |||||
| auto opname_iter = opname_bit_.find(op_name); | |||||
| if (opname_iter == opname_bit_.end()) { | |||||
| MS_LOG(WARNING) << op_name << " not in the opname_bit_ map"; | |||||
| } else { | |||||
| bit_num_t = opname_iter->second; | |||||
| quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||||
| quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||||
| } | |||||
| } | |||||
| auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max_t, | |||||
| quant_min_t, bit_num_t, perchanel); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed: " << status; | MS_LOG(ERROR) << "QuantFilter failed: " << status; | ||||
| return status; | return status; | ||||
| @@ -921,7 +769,7 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||||
| } | } | ||||
| if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { | if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { | ||||
| MS_LOG(DEBUG) << "this parameter do quant"; | MS_LOG(DEBUG) << "this parameter do quant"; | ||||
| DoWeightQuant(input_node, primitive_c, false); | |||||
| DoWeightQuant(op_name, input_node, primitive_c, false); | |||||
| } else { | } else { | ||||
| MS_LOG(DEBUG) << "this parameter no need to do quant"; | MS_LOG(DEBUG) << "this parameter no need to do quant"; | ||||
| } | } | ||||
| @@ -943,7 +791,7 @@ STATUS PostTrainingQuantizer::QuantNode() { | |||||
| op_type == PrimitiveType_FullConnection) { | op_type == PrimitiveType_FullConnection) { | ||||
| perchannel = true; | perchannel = true; | ||||
| } | } | ||||
| DoWeightQuant(weight, primitive_c, perchannel); | |||||
| DoWeightQuant(op_name, weight, primitive_c, perchannel); | |||||
| // do bias quant | // do bias quant | ||||
| if (cnode->inputs().size() == 4) { | if (cnode->inputs().size() == 4) { | ||||
| auto bias = cnode->input(3); | auto bias = cnode->input(3); | ||||
| @@ -982,18 +830,8 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() { | |||||
| * 3. save quantied node | * 3. save quantied node | ||||
| **/ | **/ | ||||
| STATUS PostTrainingQuantizer::PreProcess() { | STATUS PostTrainingQuantizer::PreProcess() { | ||||
| if (this->calibrator_ == nullptr) { | |||||
| MS_LOG(ERROR) << "calibrator is null!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // 1. generate config param | |||||
| STATUS status = calibrator_->ReadConfig(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "read proto text failed!"; | |||||
| return status; | |||||
| } | |||||
| // 2. collect image files | // 2. collect image files | ||||
| status = calibrator_->CollectImages(); | |||||
| auto status = calibrator_->CollectImages(); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "collect images failed!"; | MS_LOG(ERROR) << "collect images failed!"; | ||||
| return status; | return status; | ||||
| @@ -1560,55 +1398,49 @@ STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->Com | |||||
| STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | ||||
| MS_LOG(INFO) << "start to parse config file"; | MS_LOG(INFO) << "start to parse config file"; | ||||
| STATUS status = PreProcess(); | |||||
| if (this->calibrator_ == nullptr) { | |||||
| MS_LOG(ERROR) << "calibrator is null!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // 1. generate config param | |||||
| STATUS status = calibrator_->ReadConfig(); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "do pre process failed!"; | |||||
| MS_LOG(ERROR) << "read proto text failed!"; | |||||
| return status; | return status; | ||||
| } | } | ||||
| // anf -- fb | |||||
| auto meta_graph = Export(func_graph, true, true); | |||||
| if (meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Export to meta_graph return nullptr"; | |||||
| return RET_ERROR; | |||||
| if (calibrator_->config_param_.mixed) { | |||||
| // get opname_bit map | |||||
| auto weight_quant_func_graph = CopyFuncGraph(func_graph); | |||||
| if (weight_quant_func_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "CopyFuncGraph error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->config_param_); | |||||
| weight_quantizer.flags = flags; | |||||
| status = weight_quantizer.DoQuantize(weight_quant_func_graph); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "Do mix weight quant error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| opname_bit_ = weight_quantizer.opname_bit_; | |||||
| } | } | ||||
| // transform | |||||
| GraphDefTransform transform; | |||||
| transform.SetGraphDef(meta_graph); | |||||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||||
| status = transform.Transform(flags); | |||||
| status = PreProcess(); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "FBTransform model failed " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_LOG(INFO) << "start create session"; | |||||
| flatbuffers::FlatBufferBuilder builder(1024); | |||||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph); | |||||
| builder.Finish(offset); | |||||
| schema::FinishMetaGraphBuffer(builder, offset); | |||||
| size_t size = builder.GetSize(); | |||||
| auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | |||||
| if (content == nullptr) { | |||||
| MS_LOG(ERROR) << "GetBufferPointer nullptr"; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "do pre process failed!"; | |||||
| return status; | |||||
| } | } | ||||
| auto model = lite::Model::Import(content, size); | |||||
| Context ctx; | |||||
| ctx.thread_num_ = calibrator_->GetThreadNum(); | |||||
| fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx)); | |||||
| // anf -- fb | |||||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||||
| MS_LOG(INFO) << "start create session"; | |||||
| fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); | |||||
| if (fp32_session_ == nullptr) { | if (fp32_session_ == nullptr) { | ||||
| MS_LOG(ERROR) << "create session failed!"; | MS_LOG(ERROR) << "create session failed!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto ret = fp32_session_->CompileGraph(model); | |||||
| if (ret != lite::RET_OK) { | |||||
| MS_LOG(ERROR) << "compile graph error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_LOG(INFO) << "start to update divergence's max value"; | MS_LOG(INFO) << "start to update divergence's max value"; | ||||
| status = DoInference(); | status = DoInference(); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -1647,49 +1479,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||||
| if (calibrator_->GetBiasCorrection()) { | if (calibrator_->GetBiasCorrection()) { | ||||
| // init in8 session | // init in8 session | ||||
| // anf -- fb | |||||
| auto int8_meta_graph = Export(func_graph, true, true); | |||||
| if (int8_meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // transform | |||||
| GraphDefTransform fb_transform; | |||||
| fb_transform.SetGraphDef(int8_meta_graph); | |||||
| MS_LOG(INFO) << "create quant session"; | |||||
| flags.quantType = schema::QuantType_PostTraining; | flags.quantType = schema::QuantType_PostTraining; | ||||
| status = fb_transform.Transform(flags); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "FBTransform model failed " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_LOG(INFO) << "start create quantized session"; | |||||
| flatbuffers::FlatBufferBuilder int8_builder(1024); | |||||
| auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); | |||||
| int8_builder.Finish(int8_offset); | |||||
| schema::FinishMetaGraphBuffer(int8_builder, int8_offset); | |||||
| size = int8_builder.GetSize(); | |||||
| auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer()); | |||||
| if (int8_content == nullptr) { | |||||
| MS_LOG(ERROR) << "GetBufferPointer nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto int8_model = lite::Model::Import(int8_content, size); | |||||
| Context int8_ctx; | |||||
| int8_ctx.thread_num_ = calibrator_->GetThreadNum(); | |||||
| int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; | |||||
| int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx)); | |||||
| int8_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); | |||||
| if (int8_session_ == nullptr) { | if (int8_session_ == nullptr) { | ||||
| MS_LOG(ERROR) << "create session failed!"; | MS_LOG(ERROR) << "create session failed!"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| ret = int8_session_->CompileGraph(int8_model); | |||||
| if (ret != lite::RET_OK) { | |||||
| MS_LOG(ERROR) << "compile graph error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_LOG(INFO) << "do bias correction"; | MS_LOG(INFO) << "do bias correction"; | ||||
| status = BiasCorrection(func_graph); | status = BiasCorrection(func_graph); | ||||
| @@ -28,6 +28,8 @@ | |||||
| #include "tools/converter/quantizer/quantizer.h" | #include "tools/converter/quantizer/quantizer.h" | ||||
| #include "tools/converter/converter.h" | #include "tools/converter/converter.h" | ||||
| #include "include/ms_tensor.h" | #include "include/ms_tensor.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | |||||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||||
| namespace mindspore::lite::quant { | namespace mindspore::lite::quant { | ||||
| class Calibrator; | class Calibrator; | ||||
| @@ -38,19 +40,8 @@ struct MaxMin { | |||||
| float max; | float max; | ||||
| }; | }; | ||||
| const char kMethodMaxMin[] = "MAX_MIN"; | |||||
| const char kMethodKL[] = "KL"; | |||||
| const char kMethodOutlier[] = "RemovalOutlier"; | |||||
| constexpr int kDefaultBinNumber = 2048; | constexpr int kDefaultBinNumber = 2048; | ||||
| struct ConfigParam { | |||||
| std::vector<std::string> image_paths; | |||||
| uint32_t batch_count{100}; | |||||
| std::string method_x{kMethodKL}; | |||||
| uint32_t thread_num{1}; | |||||
| bool bias_correction{false}; | |||||
| }; | |||||
| class PostTrainingQuantizer : public Quantizer { | class PostTrainingQuantizer : public Quantizer { | ||||
| public: | public: | ||||
| PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, | PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, | ||||
| @@ -64,14 +55,16 @@ class PostTrainingQuantizer : public Quantizer { | |||||
| int quant_min{INT8_MIN}; | int quant_min{INT8_MIN}; | ||||
| private: | private: | ||||
| std::map<std::string, int> opname_bit_; | |||||
| bool per_channel_{true}; | bool per_channel_{true}; | ||||
| TypeId target_type_{kNumberTypeInt8}; | TypeId target_type_{kNumberTypeInt8}; | ||||
| std::unique_ptr<Calibrator> calibrator_; | std::unique_ptr<Calibrator> calibrator_; | ||||
| mindspore::lite::LiteSession *fp32_session_; | |||||
| mindspore::lite::LiteSession *int8_session_; | |||||
| session::LiteSession *fp32_session_{nullptr}; | |||||
| session::LiteSession *int8_session_{nullptr}; | |||||
| std::map<std::string, std::vector<float>> fp32_op_input_map; // concurency | std::map<std::string, std::vector<float>> fp32_op_input_map; // concurency | ||||
| std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurency | std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurency | ||||
| @@ -112,7 +105,8 @@ class PostTrainingQuantizer : public Quantizer { | |||||
| STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, | STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, | ||||
| const std::shared_ptr<PrimitiveC> &) const; | const std::shared_ptr<PrimitiveC> &) const; | ||||
| STATUS DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel) const; | |||||
| STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, | |||||
| bool perchannel) const; | |||||
| STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr<PrimitiveC> &primitive_c); | STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr<PrimitiveC> &primitive_c); | ||||
| STATUS Int8Inference(); | STATUS Int8Inference(); | ||||
| @@ -213,13 +207,13 @@ class Calibrator { | |||||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo(); | std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo(); | ||||
| PostQuantConfig config_param_; | |||||
| private: | private: | ||||
| std::vector<std::vector<std::string>> images_; // multi_input, echo input has multi input data | std::vector<std::vector<std::string>> images_; // multi_input, echo input has multi input data | ||||
| std::string config_path_; | std::string config_path_; | ||||
| ConfigParam config_param_; | |||||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_; | std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_; | ||||
| std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_; | std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_; | ||||
| @@ -227,8 +221,6 @@ class Calibrator { | |||||
| size_t bit_num_; | size_t bit_num_; | ||||
| int quant_max_; | int quant_max_; | ||||
| int quant_min_; | int quant_min_; | ||||
| void AddImage(const std::string &file, size_t index); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "mindspore/lite/tools/converter/quantizer/quantize_util.h" | #include "mindspore/lite/tools/converter/quantizer/quantize_util.h" | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <string> | #include <string> | ||||
| #include <map> | |||||
| #include <fstream> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -26,6 +28,8 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "abstract/abstract_value.h" | #include "abstract/abstract_value.h" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| #include "mindspore/lite/include/version.h" | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -83,10 +87,10 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { | |||||
| bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| if (!node->isa<CNode>()) { | |||||
| if (!node->isa<mindspore::CNode>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||||
| auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node); | |||||
| auto type = NodePrimitiveType(cnode); | auto type = NodePrimitiveType(cnode); | ||||
| static const std::vector<schema::PrimitiveType> int8OpList = { | static const std::vector<schema::PrimitiveType> int8OpList = { | ||||
| schema::PrimitiveType_Conv2D, | schema::PrimitiveType_Conv2D, | ||||
| @@ -475,4 +479,307 @@ schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) { | |||||
| } | } | ||||
| return (schema::PrimitiveType)primitive_c->Type(); | return (schema::PrimitiveType)primitive_c->Type(); | ||||
| } | } | ||||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { | |||||
| if (post_quant_config == nullptr) { | |||||
| MS_LOG(ERROR) << "post_quant_config is null."; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| if (config_file.empty() || config_file.length() > PATH_MAX) { | |||||
| MS_LOG(ERROR) << "invalid config path!"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| // check whether config file path is valid | |||||
| auto resolved_path = std::make_unique<char[]>(PATH_MAX); | |||||
| if (resolved_path == nullptr) { | |||||
| MS_LOG(ERROR) << "New an object failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| #ifdef _WIN32 | |||||
| if (_fullpath(resolved_path.get(), config_file.c_str(), 1024) != nullptr) { | |||||
| config_file = string(resolved_path.get()); | |||||
| } | |||||
| #else | |||||
| if (realpath(config_file.c_str(), resolved_path.get()) != nullptr) { | |||||
| config_file = string(resolved_path.get()); | |||||
| } | |||||
| #endif | |||||
| std::ifstream fs(config_file.c_str(), std::ifstream::in); | |||||
| if (!fs.is_open()) { | |||||
| MS_LOG(ERROR) << "config file open failed: " << config_file; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| std::string line; | |||||
| while (std::getline(fs, line)) { | |||||
| Trim(&line); | |||||
| if (line.empty()) { | |||||
| continue; | |||||
| } | |||||
| auto index = line.find('='); | |||||
| if (index == std::string::npos) { | |||||
| MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| auto key = line.substr(0, index); | |||||
| auto value = line.substr(index + 1); | |||||
| Trim(&key); | |||||
| Trim(&value); | |||||
| if (key == "image_path") { | |||||
| auto &raw_image_paths = value; | |||||
| auto ind = raw_image_paths.find(','); | |||||
| while (ind != std::string::npos) { | |||||
| auto image_path = raw_image_paths.substr(0, ind); | |||||
| Trim(&image_path); | |||||
| post_quant_config->image_paths.push_back(image_path); | |||||
| raw_image_paths = raw_image_paths.substr(ind + 1); | |||||
| Trim(&raw_image_paths); | |||||
| ind = raw_image_paths.find(','); | |||||
| } | |||||
| post_quant_config->image_paths.push_back(raw_image_paths); | |||||
| } else if (key == "batch_count") { | |||||
| post_quant_config->batch_count = std::stoul(value); | |||||
| } else if (key == "thread_num") { | |||||
| post_quant_config->thread_num = std::stoul(value); | |||||
| } else if (key == "method_x") { | |||||
| if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { | |||||
| MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; | |||||
| } else { | |||||
| post_quant_config->method_x = value; | |||||
| } | |||||
| } else if (key == "bias_correction") { | |||||
| std::for_each(value.begin(), value.end(), ::tolower); | |||||
| if (value == "true") { | |||||
| post_quant_config->bias_correction = true; | |||||
| } | |||||
| } else if (key == "mixed") { | |||||
| std::for_each(value.begin(), value.end(), ::tolower); | |||||
| if (value == "true") { | |||||
| post_quant_config->mixed = true; | |||||
| } | |||||
| } else if (key == "mean_error_threshold") { | |||||
| post_quant_config->mean_error_threshold = std::stof(value); | |||||
| } else { | |||||
| MS_LOG(WARNING) << "unsupported parameter: " << key; | |||||
| } | |||||
| } | |||||
| for (const auto &path : post_quant_config->image_paths) { | |||||
| MS_LOG(DEBUG) << "calibration data_path: " << path; | |||||
| } | |||||
| MS_LOG(DEBUG) << "batch_count: " << post_quant_config->batch_count << "\n" | |||||
| << "method_x: " << post_quant_config->method_x << "\n" | |||||
| << "thread_num: " << post_quant_config->thread_num << "\n" | |||||
| << "bias_correction: " << post_quant_config->bias_correction << "\n" | |||||
| << "mixed: " << post_quant_config->mixed << "\n" | |||||
| << "mean_error_threshold: " << post_quant_config->mean_error_threshold; | |||||
| post_quant_config->inited = true; | |||||
| fs.close(); | |||||
| return RET_OK; | |||||
| } | |||||
| session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, | |||||
| int thread_num) { | |||||
| auto meta_graph = Export(func_graph, true, true); | |||||
| if (meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Export to meta_graph failed"; | |||||
| return nullptr; | |||||
| } | |||||
| // transform | |||||
| GraphDefTransform fb_transform; | |||||
| fb_transform.SetGraphDef(meta_graph); | |||||
| auto status = fb_transform.Transform(flags); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "FBTransform model failed"; | |||||
| return nullptr; | |||||
| } | |||||
| meta_graph->version = Version(); | |||||
| flatbuffers::FlatBufferBuilder builder(1024); | |||||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph); | |||||
| builder.Finish(offset); | |||||
| schema::FinishMetaGraphBuffer(builder, offset); | |||||
| auto size = builder.GetSize(); | |||||
| auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | |||||
| if (content == nullptr) { | |||||
| MS_LOG(ERROR) << "GetBufferPointer return null"; | |||||
| return nullptr; | |||||
| } | |||||
| auto model = lite::Model::Import(content, size); | |||||
| if (model == nullptr) { | |||||
| MS_LOG(ERROR) << "Import model failed"; | |||||
| return nullptr; | |||||
| } | |||||
| Context ctx; | |||||
| ctx.thread_num_ = thread_num; | |||||
| auto session = session::LiteSession::CreateSession(&ctx); | |||||
| if (session == nullptr) { | |||||
| MS_LOG(ERROR) << "create session failed."; | |||||
| return nullptr; | |||||
| } | |||||
| status = session->CompileGraph(model); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "CompileGraph error"; | |||||
| return nullptr; | |||||
| } | |||||
| model->Free(); | |||||
| return session; | |||||
| } | |||||
| STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited, | |||||
| std::vector<std::vector<std::string>> *inputs) { | |||||
| if (inputs == nullptr) { | |||||
| MS_LOG(ERROR) << "inputs is null"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto AddImage = [&inputs](const std::string &file, size_t index) { | |||||
| if (index >= inputs->size()) { | |||||
| MS_LOG(ERROR) << "images_ size: " << inputs->size() << " but input index: " << index; | |||||
| return; | |||||
| } | |||||
| struct stat buf {}; | |||||
| if (stat(file.c_str(), &buf) == 0) { | |||||
| inputs->at(index).push_back(file); | |||||
| } else { | |||||
| MS_LOG(WARNING) << "invalid image file path: " << file; | |||||
| } | |||||
| }; | |||||
| inputs->resize(input_dirs.size()); | |||||
| auto input_i = 0; | |||||
| bool multi_input = input_dirs.size() > 1; | |||||
| for (const auto &image_path : input_dirs) { | |||||
| DIR *root = opendir(image_path.c_str()); | |||||
| if (root == nullptr) { | |||||
| MS_LOG(ERROR) << "invalid image path: " << image_path; | |||||
| return RET_PARAM_INVALID; | |||||
| } | |||||
| struct dirent *image_dir = readdir(root); | |||||
| size_t count = 0; | |||||
| while (image_dir != nullptr) { | |||||
| string file_name(image_dir->d_name); | |||||
| if (file_name != "." && file_name != "..") { | |||||
| const std::string file_path = image_path + "/" + file_name; | |||||
| if (multi_input || count == 0) { | |||||
| AddImage(file_path, input_i); | |||||
| count++; | |||||
| } else if (count < count_limited) { | |||||
| AddImage(file_path, input_i); | |||||
| count++; | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| image_dir = readdir(root); | |||||
| } | |||||
| std::sort(inputs->at(input_i).begin(), inputs->at(input_i).end()); | |||||
| if (count_limited != 0 && count_limited < inputs->at(input_i).size()) { | |||||
| inputs->at(input_i).resize(count_limited); | |||||
| } | |||||
| closedir(root); | |||||
| input_i++; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor) { | |||||
| MS_ASSERT(tensor != nullptr); | |||||
| if (input_index >= images.size()) { | |||||
| MS_LOG(ERROR) << "images_ size: " << images.size() << " but input_index: " << input_index; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (image_index >= images[input_index].size()) { | |||||
| MS_LOG(ERROR) << "images_[input_index] size: " << images[input_index].size() << " but image_index: " << image_index; | |||||
| return RET_ERROR; | |||||
| } | |||||
| string path = images[input_index][image_index]; | |||||
| MS_LOG(INFO) << "read image: " << path; | |||||
| size_t size; | |||||
| char *bin_buf = ReadFile(path.c_str(), &size); | |||||
| if (bin_buf == nullptr) { | |||||
| MS_LOG(ERROR) << "ReadFile return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto data = tensor->MutableData(); | |||||
| if (data == nullptr) { | |||||
| MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (size != tensor->Size()) { | |||||
| MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size | |||||
| << " input tensor size: " << tensor->Size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s error: " << ret; | |||||
| delete[] bin_buf; | |||||
| return RET_ERROR; | |||||
| } | |||||
| delete[] bin_buf; | |||||
| return RET_OK; | |||||
| } | |||||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { | |||||
| Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr); | |||||
| auto new_func_graph = cloner[func_graph]; | |||||
| std::map<std::string, CNodePtr> old_cnode_map; | |||||
| for (const auto &cnode : func_graph->GetOrderedCnodes()) { | |||||
| old_cnode_map[cnode->fullname_with_scope()] = cnode; | |||||
| } | |||||
| for (auto &cnode : new_func_graph->GetOrderedCnodes()) { | |||||
| auto cnode_name = cnode->fullname_with_scope(); | |||||
| auto old_cnode_iter = old_cnode_map.find(cnode_name); | |||||
| if (old_cnode_iter == old_cnode_map.end()) { | |||||
| MS_LOG(ERROR) << "can not find node: " << cnode_name; | |||||
| return nullptr; | |||||
| } | |||||
| auto old_cnode = old_cnode_iter->second; | |||||
| auto inputs = cnode->inputs(); | |||||
| for (size_t i = 0; i < inputs.size(); i++) { | |||||
| auto input_node = inputs[i]; | |||||
| if (input_node->isa<Parameter>()) { | |||||
| auto param_node = input_node->cast<ParameterPtr>(); | |||||
| if (param_node->has_default()) { | |||||
| ParamValueLitePtr old_param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||||
| auto new_param_value = std::make_shared<ParamValueLite>(); | |||||
| auto copyed_data = malloc(old_param_value->tensor_size()); | |||||
| if (copyed_data == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size(); | |||||
| return nullptr; | |||||
| } | |||||
| memcpy(copyed_data, old_param_value->tensor_addr(), old_param_value->tensor_size()); | |||||
| new_param_value->set_tensor_size(old_param_value->tensor_size()); | |||||
| new_param_value->set_tensor_addr(copyed_data); | |||||
| new_param_value->set_tensor_shape(old_param_value->tensor_shape()); | |||||
| new_param_value->set_format(old_param_value->format()); | |||||
| new_param_value->set_tensor_type(old_param_value->tensor_type()); | |||||
| param_node->set_default_param(new_param_value); | |||||
| } | |||||
| auto old_abstract_base = param_node->abstract(); | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract_base)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); | |||||
| return nullptr; | |||||
| } | |||||
| auto old_abstract = utils::cast<abstract::AbstractTensorPtr>(old_abstract_base); | |||||
| auto new_abstract = std::make_shared<abstract::AbstractTensor>(old_abstract->element()->GetTypeTrack(), | |||||
| old_abstract->GetShapeTrack()); | |||||
| param_node->set_abstract(new_abstract); | |||||
| } | |||||
| } // end inputs loop | |||||
| } // end cnodes loop | |||||
| return new_func_graph; | |||||
| } | |||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| @@ -17,6 +17,8 @@ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | ||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H | ||||
| #include <dirent.h> | |||||
| #include <sys/stat.h> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <cmath> | #include <cmath> | ||||
| @@ -35,11 +37,29 @@ | |||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| #include "abstract/dshape.h" | #include "abstract/dshape.h" | ||||
| #include "tools/converter/quantizer/bitpacking.h" | #include "tools/converter/quantizer/bitpacking.h" | ||||
| #include "src/lite_session.h" | |||||
| #include "tools/converter/graphdef_transform.h" | |||||
| #include "src/common/file_utils.h" | |||||
| namespace mindspore::lite::quant { | namespace mindspore::lite::quant { | ||||
| static constexpr size_t UINT8_QUANTIZATION = 8; | static constexpr size_t UINT8_QUANTIZATION = 8; | ||||
| static constexpr size_t WEIGHT_INDEX = 1; | static constexpr size_t WEIGHT_INDEX = 1; | ||||
| const char kMethodMaxMin[] = "MAX_MIN"; | |||||
| const char kMethodKL[] = "KL"; | |||||
| const char kMethodOutlier[] = "RemovalOutlier"; | |||||
| struct PostQuantConfig { | |||||
| std::vector<std::string> image_paths; | |||||
| uint32_t batch_count{100}; | |||||
| std::string method_x{kMethodKL}; | |||||
| uint32_t thread_num{1}; | |||||
| bool bias_correction{false}; | |||||
| bool mixed{false}; | |||||
| float mean_error_threshold{0.04}; | |||||
| bool inited{false}; | |||||
| }; | |||||
| /** | /** | ||||
| * 1. when op's weight size > mWeightSize just skip | * 1. when op's weight size > mWeightSize just skip | ||||
| * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization | * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization | ||||
| @@ -320,6 +340,21 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| // utils | |||||
| schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode); | schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode); | ||||
| STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config); | |||||
| session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, | |||||
| int thread_num); | |||||
| STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited, | |||||
| std::vector<std::vector<std::string>> *inputs); | |||||
| STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, | |||||
| const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor); | |||||
| FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); | |||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif | #endif | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <list> | #include <list> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | |||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| @@ -36,6 +37,7 @@ bool WeightQuantizer::IsPosNum(const std::string &str) { | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | ||||
| MS_ASSERT(config != nullptr); | MS_ASSERT(config != nullptr); | ||||
| if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { | if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { | ||||
| @@ -57,28 +59,57 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, | |||||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { | |||||
| quant_strategy_ = std::make_unique<QuantStrategy>(0, 0); | |||||
| config_param_ = config; | |||||
| } | |||||
| WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, | |||||
| const std::string &convWeightChannelThreshold, const std::string &bitNum) | const std::string &convWeightChannelThreshold, const std::string &bitNum) | ||||
| : Quantizer(graph) { | : Quantizer(graph) { | ||||
| this->config_file_ = config_file; | |||||
| auto quantSize = static_cast<size_t>(std::stoull(weightSize)); | auto quantSize = static_cast<size_t>(std::stoull(weightSize)); | ||||
| this->bitNum = static_cast<size_t>(std::stoull(bitNum)); | |||||
| this->bit_num_ = static_cast<size_t>(std::stoull(bitNum)); | |||||
| auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold)); | ||||
| mStrategy = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||||
| quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; | |||||
| quant_min = -(1 << (unsigned int)(this->bitNum - 1)); | |||||
| quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); | |||||
| quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; | |||||
| quant_min = -(1 << (unsigned int)(this->bit_num_ - 1)); | |||||
| // parse type_id | // parse type_id | ||||
| if (this->bitNum > 0 && this->bitNum <= 8) { | |||||
| if (this->bit_num_ > 0 && this->bit_num_ <= 8) { | |||||
| type_id = kNumberTypeInt8; | type_id = kNumberTypeInt8; | ||||
| } else if (this->bitNum <= 16) { | |||||
| } else if (this->bit_num_ <= 16) { | |||||
| type_id = kNumberTypeInt16; | type_id = kNumberTypeInt16; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "invalid input bits"; | MS_LOG(ERROR) << "invalid input bits"; | ||||
| } | } | ||||
| } | } | ||||
| WeightQuantizer::~WeightQuantizer() { delete fp32_session_; } | |||||
| STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, | |||||
| std::shared_ptr<PrimitiveC> primitive_c) { | |||||
| // set dtype | |||||
| param_value->set_tensor_type(type_id); | |||||
| auto abstract_base = param_node->abstract(); | |||||
| if (abstract_base == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base); | |||||
| abstract_tensor->element()->set_type(TypeIdToType(type_id)); | |||||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | ||||
| for (auto &cnode : nodes) { | for (auto &cnode : nodes) { | ||||
| if (!mStrategy->CanConvOpQuantized(cnode)) { | |||||
| if (!quant_strategy_->CanConvOpQuantized(cnode)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -108,36 +139,28 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) { | |||||
| } | } | ||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | if (type_id == kNumberTypeInt8) { | ||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } else if (type_id == kNumberTypeInt16) { | } else if (type_id == kNumberTypeInt16) { | ||||
| status = | status = | ||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | MS_LOG(ERROR) << "QuantFilter failed : " << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| // set dtype | |||||
| param_value->set_tensor_type(type_id); | |||||
| auto abstractBase = param_node->abstract(); | |||||
| if (abstractBase == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (!mStrategy->CanMulOpQuantized(node)) { | |||||
| if (!quant_strategy_->CanMulOpQuantized(node)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto already_quant = false; | auto already_quant = false; | ||||
| @@ -186,38 +209,271 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) { | |||||
| auto status = RET_ERROR; | auto status = RET_ERROR; | ||||
| if (type_id == kNumberTypeInt8) { | if (type_id == kNumberTypeInt8) { | ||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||||
| status = | |||||
| QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } else if (type_id == kNumberTypeInt16) { | } else if (type_id == kNumberTypeInt16) { | ||||
| status = | status = | ||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); | |||||
| QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "QuantFilter failed : " << status; | MS_LOG(ERROR) << "QuantFilter failed : " << status; | ||||
| return status; | return status; | ||||
| } | } | ||||
| param_value->set_tensor_type(type_id); | |||||
| // set dtype | |||||
| auto abstractBase = param_node->abstract(); | |||||
| if (abstractBase == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| constexpr float relative_tolerance = 1e-5; | |||||
| constexpr float abs_tolerance = 1e-4; | |||||
| template <typename T> | |||||
| float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor, | |||||
| const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) { | |||||
| auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); }; | |||||
| float total_mean_error = 0.0f; | |||||
| int tensor_cnt = expected_tensor.size(); | |||||
| if (tensor_cnt <= 0) { | |||||
| MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt; | |||||
| return RET_ERROR; | |||||
| } | |||||
| for (const auto &exp_tensor_pair : expected_tensor) { | |||||
| float mean_error = 0.0f; | |||||
| int error_cnt = 0; | |||||
| auto exp_tensor_name = exp_tensor_pair.first; | |||||
| auto exp_tensor = exp_tensor_pair.second; | |||||
| auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name); | |||||
| if (cmp_tensor_find_iter == compare_tensor.end()) { | |||||
| MS_LOG(ERROR) << "can not find: " << exp_tensor_name; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); | |||||
| auto cmp_tensor = cmp_tensor_find_iter->second; | |||||
| auto exp_tensor_shape = exp_tensor->shape(); | |||||
| auto cmp_tensor_shape = cmp_tensor->shape(); | |||||
| if (exp_tensor_shape != cmp_tensor_shape) { | |||||
| MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum() | |||||
| << " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| abstractTensor->element()->set_type(TypeIdToType(type_id)); | |||||
| primitive_c->set_quant_type(schema::QuantType_WeightQuant); | |||||
| auto exp_data = static_cast<T *>(exp_tensor->MutableData()); | |||||
| auto cmp_data = static_cast<T *>(cmp_tensor->MutableData()); | |||||
| auto elem_cnt = exp_tensor->ElementsNum(); | |||||
| for (int i = 0; i < elem_cnt; i++) { | |||||
| if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) { | |||||
| MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]); | |||||
| auto abs_error = std::fabs(exp_data[i] - cmp_data[i]); | |||||
| if (abs_error > tolerance) { | |||||
| if (fabs(exp_data[i] == 0)) { | |||||
| if (abs_error > 1e-5) { | |||||
| mean_error += abs_error; | |||||
| error_cnt++; | |||||
| } else { | |||||
| // it is ok, very close to 0 | |||||
| continue; | |||||
| } | |||||
| } else { | |||||
| mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN); | |||||
| error_cnt++; | |||||
| } | |||||
| } else { | |||||
| // it is ok, no error | |||||
| continue; | |||||
| } | |||||
| } // end one tensor data loop | |||||
| total_mean_error += mean_error / elem_cnt; | |||||
| } // end tensor loop | |||||
| return total_mean_error / tensor_cnt; | |||||
| } | |||||
| STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { | |||||
| // 0.1 Create Fp32 Session | |||||
| flags.quantType = schema::QuantType_QUANT_NONE; | |||||
| fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num); | |||||
| if (fp32_session_ == nullptr) { | |||||
| MS_LOG(ERROR) << "CreateSessoin fail"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto fp32_inputs = fp32_session_->GetInputs(); | |||||
| // 0.2 Parse input calib files | |||||
| auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "CollectCalibInputs fail"; | |||||
| return RET_ERROR; | |||||
| } | } | ||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| for (auto iter = cnodes.end(); iter != cnodes.begin();) { | |||||
| auto cnode = *(--iter); | |||||
| auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); | |||||
| if (primitive_c == nullptr) { | |||||
| MS_LOG(ERROR) << "primitive_c is null."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto op_name = cnode->fullname_with_scope(); | |||||
| MS_LOG(DEBUG) << "process node: " << op_name | |||||
| << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); | |||||
| if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { | |||||
| auto input_node = cnode->input(2); | |||||
| if (!input_node->isa<Parameter>()) { | |||||
| MS_LOG(WARNING) << op_name << " the second input is not parameter"; | |||||
| continue; | |||||
| } | |||||
| auto param_node = input_node->cast<ParameterPtr>(); | |||||
| if (!param_node->has_default()) { | |||||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||||
| continue; | |||||
| } | |||||
| auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param()); | |||||
| if (param_value == nullptr) { | |||||
| MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; | |||||
| continue; | |||||
| } | |||||
| if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { | |||||
| MS_LOG(WARNING) << op_name << " the second input type is not float"; | |||||
| continue; | |||||
| } | |||||
| // copy origin data in case to recover | |||||
| auto *raw_data = static_cast<float *>(param_value->tensor_addr()); | |||||
| auto elem_count = param_value->tensor_shape_size(); | |||||
| auto origin_data = malloc(sizeof(float) * elem_count); | |||||
| auto ret = memcpy_s(origin_data, sizeof(float) * elem_count, raw_data, param_value->tensor_size()); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy fail: " | |||||
| << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| // 1. try quant | |||||
| for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { | |||||
| type_id = TypeId::kNumberTypeInt8; | |||||
| int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; | |||||
| int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); | |||||
| if (type_id == TypeId::kNumberTypeInt8) { | |||||
| status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||||
| quant_min_t, bit_num_t, true); | |||||
| } else if (type_id == TypeId::kNumberTypeInt16) { | |||||
| status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, | |||||
| quant_min_t, bit_num_t, true); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "unexpected type_id: " << type_id; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "quant filter fail."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| status = SetAbstract(param_value, param_node, primitive_c); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetAbstract failed : " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // 2. evaluate the quant | |||||
| // 2.1 create quant session, get input, output tensor | |||||
| flags.quantType = schema::QuantType_WeightQuant; | |||||
| auto quant_session = | |||||
| std::unique_ptr<session::LiteSession>(CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num)); | |||||
| if (quant_session == nullptr) { | |||||
| MS_LOG(ERROR) << "create session error: " << status; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto quant_inputs = quant_session->GetInputs(); | |||||
| auto mean_error = 0.0f; | |||||
| if (fp32_inputs.size() != images_.size()) { | |||||
| MS_LOG(ERROR) << "model's input tensor cnt: " << fp32_inputs.size() << " != " << images_.size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto image_cnt = images_.at(0).size(); | |||||
| for (size_t i = 0; i < image_cnt; i++) { | |||||
| // set multi-input data | |||||
| for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) { | |||||
| status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "generate input data from images failed!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "generate input data from images failed!"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| std::future<STATUS> fp32_inference = std::async( | |||||
| std::launch::async, [](session::LiteSession *fp32_session) -> STATUS { return fp32_session->RunGraph(); }, | |||||
| fp32_session_); | |||||
| status = quant_session->RunGraph(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "quant session run error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| status = fp32_inference.get(); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "fp32 session run error"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // 3. compare between quant and fp32 | |||||
| auto fp32_outputs = fp32_session_->GetOutputs(); | |||||
| auto quant_outputs = quant_session->GetOutputs(); | |||||
| mean_error += CompareOutputData<float>(fp32_outputs, quant_outputs); | |||||
| } // end_for: calib data loop | |||||
| mean_error = mean_error / image_cnt; | |||||
| if (mean_error <= config_param_.mean_error_threshold) { | |||||
| MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error; | |||||
| opname_bit_[op_name] = bit_num_t; | |||||
| break; | |||||
| } else if (bit_num_t != 8) { | |||||
| // recover | |||||
| param_value->set_tensor_size(sizeof(float) * elem_count); | |||||
| ret = memcpy_s(raw_data, param_value->tensor_size(), origin_data, sizeof(float) * elem_count); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy fail: " | |||||
| << " src size: " << sizeof(float) * elem_count << " dst size: " << param_value->tensor_size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error; | |||||
| opname_bit_[op_name] = bit_num_t; | |||||
| } | |||||
| } // end bit loop | |||||
| free(origin_data); | |||||
| } // if: conv and matmul | |||||
| } // end loop: all cnode | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { | |||||
| MS_ASSERT(funcGraph != nullptr); | |||||
| STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||||
| MS_ASSERT(func_graph != nullptr); | |||||
| STATUS ret; | STATUS ret; | ||||
| auto cnodes = funcGraph->GetOrderedCnodes(); | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| if (!config_file_.empty()) { | |||||
| ret = ParseConfigFile(config_file_, &config_param_); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ReadConfig error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (config_param_.mixed) { | |||||
| MS_LOG(INFO) << "Do mixed bit quantization"; | |||||
| return DoMiexedQuant(func_graph); | |||||
| } | |||||
| ret = DoConvQuantize(cnodes); | ret = DoConvQuantize(cnodes); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; | MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; | ||||
| @@ -17,9 +17,12 @@ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | ||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H | ||||
| #include <future> | |||||
| #include <memory> | #include <memory> | ||||
| #include <map> | |||||
| #include <list> | #include <list> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include "tools/converter/quantizer/quantizer.h" | #include "tools/converter/quantizer/quantizer.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| @@ -27,27 +30,37 @@ | |||||
| #include "include/model.h" | #include "include/model.h" | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "abstract/dshape.h" | #include "abstract/dshape.h" | ||||
| #include "src/lite_session.h" | |||||
| namespace mindspore::lite::quant { | namespace mindspore::lite::quant { | ||||
| class WeightQuantizer : public Quantizer { | class WeightQuantizer : public Quantizer { | ||||
| public: | public: | ||||
| WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold, | |||||
| const std::string &bitNum); | |||||
| WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, | |||||
| const std::string &covWeightChannelThreshold, const std::string &bitNum); | |||||
| WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); | |||||
| ~WeightQuantizer(); | |||||
| ~WeightQuantizer() = default; | |||||
| STATUS DoQuantize(FuncGraphPtr funcGraph) override; | |||||
| STATUS DoQuantize(FuncGraphPtr func_graph) override; | |||||
| STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | STATUS DoConvQuantize(const std::list<CNodePtr> &nodes); | ||||
| STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | STATUS DoMulQuantize(const std::list<CNodePtr> &nodes); | ||||
| static STATUS WeightQuantInputCheck(const converter::Flags *config); | static STATUS WeightQuantInputCheck(const converter::Flags *config); | ||||
| static bool IsPosNum(const std::string &str); | static bool IsPosNum(const std::string &str); | ||||
| int quant_max; | int quant_max; | ||||
| int quant_min; | int quant_min; | ||||
| TypeId type_id{kTypeUnknown}; | TypeId type_id{kTypeUnknown}; | ||||
| std::map<std::string, int> opname_bit_; | |||||
| private: | private: | ||||
| std::unique_ptr<QuantStrategy> mStrategy; | |||||
| size_t bitNum; | |||||
| std::unique_ptr<QuantStrategy> quant_strategy_; | |||||
| size_t bit_num_; | |||||
| std::string config_file_; | |||||
| PostQuantConfig config_param_; | |||||
| std::vector<std::vector<std::string>> images_; // multi-input, [[model_input_0], [model_input_1]...] | |||||
| session::LiteSession *fp32_session_ = nullptr; | |||||
| STATUS DoMiexedQuant(FuncGraphPtr); | |||||
| STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::quant | } // namespace mindspore::lite::quant | ||||
| #endif | #endif | ||||