From f32ff96ecdd411176f2739e669d53404818cd0a7 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Tue, 11 Aug 2020 14:40:33 +0800 Subject: [PATCH] post training quantization --- .../kernel/arm/base/quant_dtype_cast.cc | 2 +- .../src/runtime/kernel/arm/fp32/softmax.h | 2 +- mindspore/lite/tools/converter/converter.cc | 3 +- .../node/weight_format_pass.cc | 7 +- .../tools/converter/quantizer/CMakeLists.txt | 5 +- ...training.cc => post_training_quantizer.cc} | 99 ++++++++++--------- ...t_training.h => post_training_quantizer.h} | 8 +- .../converter/quantizer/quantize_util.cc | 4 +- .../tools/converter/quantizer/quantizer.h | 2 + 9 files changed, 75 insertions(+), 57 deletions(-) rename mindspore/lite/tools/converter/quantizer/{post_training.cc => post_training_quantizer.cc} (93%) rename mindspore/lite/tools/converter/quantizer/{post_training.h => post_training_quantizer.h} (95%) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc index 10068184d5..656ec30096 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc @@ -118,7 +118,7 @@ int QuantDTypeCastCPUKernel::Run() { int8_ptr_ = reinterpret_cast(out_tensors_[0]->Data()); } - int ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_); + auto ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h index 515535a328..e24fe9c640 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h @@ -39,7 +39,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { int Run() override; private: - float *sum_data_; + float *sum_data_ = nullptr; 
}; } // namespace mindspore::kernel diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 87ab8a9910..175d5a86dd 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -31,7 +31,7 @@ #include "src/common/anf_importer/import_from_protobuf.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/quantizer/weight_quantizer.h" -#include "tools/converter/quantizer/post_training.h" +#include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" namespace mindspore { @@ -94,6 +94,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { CreateQuantizer(graph, flag); if (mQuantizer != nullptr) { + mQuantizer->flags = *flag; auto status = mQuantizer->DoQuantize(graph); if (status != RET_OK) { MS_LOG(ERROR) << "Quant failed " << status; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 7add49877f..06f6d39ca0 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -277,8 +277,9 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } else if (weightTensor->format == schema::Format_CHWK) { // from onnx if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC"; } else { - status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } } else if (weightTensor->format == schema::Format_KCHW) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { @@ -291,8 +292,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode 
*graphNode) { return -1; } if (status == 0) { - node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; - weightTensor->format = schema::Format_HWCK; + node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_KHWC; } else { MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index f22f952164..24a17962c1 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -4,15 +4,12 @@ include_directories(${3RD_DIR}/flatbuffers/include) include_directories(${3RD_DIR}/opencv/build/include/opencv4) add_library(quantizer_mid OBJECT - #${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc - #${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc - ${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc + ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc - #${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc ) if(ENABLE_ASAN) diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc similarity index 93% rename from mindspore/lite/tools/converter/quantizer/post_training.cc rename to mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index bf149ff9cf..3e8840c4da 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -28,7 +28,7 @@ #include "schema/inner/model_generated.h" #include "src/ir/tensor.h" #include 
"src/common/anf_exporter/anf_exporter.h" -#include "tools/converter/quantizer/post_training.h" +#include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quantize_util.h" #include "src/common/common.h" #include "utils/log_adapter.h" @@ -54,7 +54,10 @@ struct DivergInfo { size_t bit_num; int quant_max = 255; int quant_min = 0; - DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { + std::string method_x = kMethodKL; + + DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) { + this->method_x = method_x; this->cnode = cnode; this->bin_num = bins; this->bit_num = bits; @@ -99,6 +102,12 @@ struct DivergInfo { } STATUS ComputeThreshold() { + if (method_x == kMethodMaxMin) { + this->best_T = std::max(fabs(this->max), fabs(this->min)); + MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T; + return RET_OK; + } + constexpr int quant_bint_nums = 128; int threshold = quant_bint_nums; float min_kl = FLT_MAX; @@ -200,46 +209,32 @@ struct DivergInfo { threshold = i; } } - MS_LOG(DEBUG) << "Best threshold bin index: " << threshold; this->best_T = (static_cast(threshold) + 0.5f) * this->interval; + MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold + << " T: " << best_T + << " max: " << std::max(fabs(this->max), fabs(this->min)); return RET_OK; } std::pair GetScale() { float max_value = this->best_T; float min_value = -max_value; + MS_ASSERT(quant_max - quant_min != 0); - double scale = (max_value - min_value) / (quant_max - quant_min); + float scale = (max_value - min_value) / (quant_max - quant_min); MS_ASSERT(scale != 0); return std::make_pair(this->cnode, scale); } std::pair GetZeropoint() { - float max_value = this->best_T; - float min_value = -max_value; - MS_ASSERT(quant_max - quant_min != 0); - float scale = (max_value - min_value) / (quant_max - quant_min); - - auto quant_min_float = 
static_cast(quant_min); - auto quant_max_float = static_cast(quant_max); - MS_ASSERT(scale != 0); - const float zero_point_from_min = quant_min_float - min_value / scale; - // const float zero_point_from_max = quant_max_float - max_value / scale; - int zero_point; - if (zero_point_from_min < quant_min_float) { - zero_point = quant_min; - } else if (zero_point_from_min > quant_max_float) { - zero_point = quant_max; - } else { - zero_point = static_cast(std::round(zero_point_from_min)); - } - MS_LOG(DEBUG) << "zero point:" << zero_point; + int zero_point = 0; if (quant_min == 0 && quant_max == 255) { zero_point = 128; } else if (quant_min == -128 && quant_max == 127) { zero_point = 0; + } else { + MS_LOG(ERROR) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max; } - return std::make_pair(this->cnode, zero_point); } }; @@ -356,9 +351,9 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { } string node_name = node->fullname_with_scope(); std::unique_ptr input_diverg = - std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); std::unique_ptr output_diverg = - std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); + std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); @@ -383,13 +378,13 @@ STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTenso MS_LOG(INFO) << "read image: " << path; size_t size; char *binBuf = ReadFile(path.c_str(), &size); - - // auto *rawinputDatas = reinterpret_cast(binBuf); - // auto mobilenet_input = const_cast(rawinputDatas); auto data = tensor->MutableData(); + if (size != tensor->Size()) { + MS_LOG(ERROR) << "the 
input data is not consistent with model input, file_size: " << size + " input tensor size: " << tensor->Size(); + return RET_ERROR; + } memcpy(data, binBuf, size); - - // tensor->SetData(mobilenet_input); return RET_OK; } @@ -457,13 +452,20 @@ STATUS Calibrator::ReadConfig() { config_param_.batch_count = std::stoul(value); } else if (key == "thread_num") { config_param_.thread_num = std::stoul(value); + } else if (key == "method_x") { + if (value != kMethodKL && value != kMethodMaxMin) { + MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; + } else { + config_param_.method_x = value; + } } else { MS_LOG(WARNING) << "unsupported parameter"; } } - MS_LOG(INFO) << "image_path: " << config_param_.image_path << " " - << "batch_count: " << config_param_.batch_count << " " - << "thread_num: " << config_param_.thread_num; + MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " " - << "batch_count: " << config_param_.batch_count << " " + << "method_x: " << config_param_.method_x << " " + << "thread_num: " << config_param_.thread_num; delete[] resolved_path; fs.close(); @@ -615,7 +617,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr input quant_datas[i] = quant_data; } auto ret = - memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t)); + memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; delete[] quant_datas; @@ -805,14 +807,6 @@ STATUS PostTrainingQuantizer::DoInference() { MS_LOG(ERROR) << "generate input data from images failed!"; return RET_ERROR; } - /** - * struct CallBackParam { - std::string nodeType; - NODE_ID nodeName; - std::unordered_set depends; - int opExecResult; - }; - */ mindspore::session::KernelCallBack beforeCallBack = [&](const std::vector &beforeInputs, const std::vector &beforeOutputs, @@ -916,9 +910,26 @@ STATUS 
PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) { MS_LOG(ERROR) << "do pre process failed!"; return status; } + + // anf -- fb + auto meta_graph = Export(funcGraph); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Export to meta_graph return nullptr"; + return RET_ERROR; + } + + // transform + GraphDefTransform transform; + transform.SetGraphDef(meta_graph); + flags.quantType = schema::QuantType_QUANT_NONE; + status = transform.Transform(flags); + if (status != RET_OK) { + MS_LOG(ERROR) << "FBTransform model failed " << status; + return RET_ERROR; + } MS_LOG(INFO) << "start create session"; flatbuffers::FlatBufferBuilder builder(1024); - auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph)); + auto offset = schema::MetaGraph::Pack(builder, meta_graph); builder.Finish(offset); size_t size = builder.GetSize(); auto *content = reinterpret_cast(builder.GetBufferPointer()); diff --git a/mindspore/lite/tools/converter/quantizer/post_training.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h similarity index 95% rename from mindspore/lite/tools/converter/quantizer/post_training.h rename to mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 06273396b8..d9e16e1f84 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -46,10 +46,14 @@ enum ImageFormat { BGR = 2, }; +const char kMethodMaxMin[] = "MAX_MIN"; +const char kMethodKL[] = "KL"; + struct ConfigParam { // ImageFormat imageFormat; std::string image_path; - uint32_t batch_count; + uint32_t batch_count{100}; + std::string method_x{kMethodKL}; uint32_t thread_num; }; @@ -115,6 +119,8 @@ class Calibrator { uint32_t GetThreadNum() const { return config_param_.thread_num; } + std::string GetMethodX() const { return config_param_.method_x; } + STATUS AddQuantizedOp(CNodePtr node); STATUS RecordMaxValue(std::string opName, std::vector data, diff --git 
a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 151bdacd52..9fc5e55df0 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -89,7 +89,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { - MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); return false; } @@ -344,7 +344,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ } weightPtr->set_quant_param(quantParam); - auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t), + auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index 19284052f3..1cbd6f26cc 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -24,6 +24,7 @@ #include "include/model.h" #include "base/base.h" #include "src/param_value_lite.h" +#include "tools/converter/converter_flags.h" namespace mindspore { namespace lite { @@ -52,6 +53,7 @@ class Quantizer { virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; + mindspore::lite::converter::Flags flags; protected: FuncGraphPtr funcGraph = nullptr; };