From c5b5cb03c9967954f5e6fbcaaf393976584d06b3 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Tue, 29 Sep 2020 15:15:41 +0800 Subject: [PATCH] fix bias correction bug --- .../lite/tools/anf_exporter/anf_exporter.cc | 16 +++- .../lite/tools/anf_exporter/anf_exporter.h | 9 +- .../quantizer/post_training_quantizer.cc | 91 +++++++++++++------ .../quantizer/post_training_quantizer.h | 20 ++-- .../tools/converter/quantizer/quant_cast.cc | 3 +- .../converter/quantizer/quantize_util.cc | 11 +-- 6 files changed, 100 insertions(+), 50 deletions(-) diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 877b2d5868..330c1c12a4 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -195,7 +195,7 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt return RET_OK; } -schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph) { +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { auto cnodes = func_graph->GetOrderedCnodes(); auto meta_graphT = std::make_unique(); int ret = RET_OK; @@ -236,7 +236,15 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee } node->nodeType = schema::NodeType_CNode; node->name = cnode->fullname_with_scope(); - node->primitive = std::unique_ptr(primT); + if (copy_primitive) { + auto primitive = new (std::nothrow) schema::PrimitiveT(); + if (primitive != nullptr) { + *primitive = *primT; + node->primitive = std::unique_ptr(primitive); + } + } else { + node->primitive = std::unique_ptr(primT); + } ret = SetOpInputNode(cnode, meta_graphT, node.get()); if (ret != RET_OK) { MS_LOG(ERROR) << "SetOpInputNode failed"; @@ -518,8 +526,8 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType return (schema::PrimitiveType)(prim->Type()) == type; } 
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph) { +schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) { AnfExporter anf_exporter; - return anf_exporter.Export(func_graph, keep_graph); + return anf_exporter.Export(func_graph, keep_graph, copy_primitive); } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index d0a5c7cadc..f8d5011f48 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -31,7 +31,7 @@ class AnfExporter { public: AnfExporter() = default; virtual ~AnfExporter() = default; - schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false); + schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT *fb_node); int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, @@ -56,7 +56,10 @@ class AnfExporter { std::map node_id_map_; std::vector graph_input_nodes_; }; - -schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false); +// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT. +// but in PostQuantization, the func_graph needs to be transferred to MetaGraph first and do MetaGraph pass, which may modify +// the schema::PrimitiveT and cause bugs; if all the passes have been done in func_graph, everything would be simple +// and clear. 
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index c4add981aa..7aeb91c5ae 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -919,14 +919,10 @@ STATUS PostTrainingQuantizer::Int8Inference() { const std::vector &beforeOutputs, const mindspore::session::CallBackParam &callParam) -> bool { if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) { - while (!fp32_op_input_ready) { + vector fp32_op_input; + while (!OpInputDataHandle(FETCH, callParam.name_callback_param, &fp32_op_input)) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - if (callParam.name_callback_param != fp32_op_input_name) { - MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param - << " ready fp32 op name: " << fp32_op_input_name; - return false; - } auto tensor = beforeInputs[0]; auto lite_tensor = dynamic_cast(tensor); @@ -962,7 +958,6 @@ STATUS PostTrainingQuantizer::Int8Inference() { MS_LOG(ERROR) << "memcpy error: " << ret; return false; } - fp32_op_input_ready = false; } return true; }; @@ -972,14 +967,10 @@ STATUS PostTrainingQuantizer::Int8Inference() { const std::vector &afterOutputs, const mindspore::session::CallBackParam &callParam) -> bool { if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) { - while (!fp32_op_output_ch_mean_ready) { + vector fp32_op_output_ch_mean; + while (!OpOutputChMeanDataHandle(FETCH, callParam.name_callback_param, &fp32_op_output_ch_mean)) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - if 
(callParam.name_callback_param != fp32_op_output_name) { - MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param - << " ready fp32 op name: " << fp32_op_output_name; - return false; - } auto tensor = afterOutputs[0]; auto lite_tensor = dynamic_cast(tensor); @@ -1036,9 +1027,7 @@ STATUS PostTrainingQuantizer::Int8Inference() { } else { op_bias_diff_map[callParam.name_callback_param] = dequant_op_output_ch_mean; } - fp32_op_output_ch_mean_ready = false; } - return true; }; ret = int8_session_->RunGraph(beforeCallBack, afterCallBack); @@ -1072,23 +1061,21 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) { const std::vector &beforeOutputs, const mindspore::session::CallBackParam &callParam) -> bool { if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) { - while (fp32_op_input_ready) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) { return false; } auto tensor = beforeInputs[0]; size_t elem_count = tensor->ElementsNum(); - fp32_op_input.resize(elem_count); + std::vector fp32_op_input(elem_count); auto ret = memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size()); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return false; } - fp32_op_input_name = callParam.name_callback_param; - fp32_op_input_ready = true; + while (!OpInputDataHandle(STORE, callParam.name_callback_param, &fp32_op_input)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } } return true; }; @@ -1098,9 +1085,6 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) { const std::vector &afterOutputs, const mindspore::session::CallBackParam &callParam) -> bool { if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) { - while 
(fp32_op_output_ch_mean_ready) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) { return false; } @@ -1118,7 +1102,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) { MS_LOG(ERROR) << "unexpected channels: 0"; return false; } - fp32_op_output_ch_mean.resize(channels); + std::vector fp32_op_output_ch_mean(channels); auto one_filter_size = elem_count / channels; for (int i = 0; i < channels; i++) { float sum = 0; @@ -1133,8 +1117,9 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) { sum = sum / one_filter_size; fp32_op_output_ch_mean[i] = sum; } - fp32_op_output_name = callParam.name_callback_param; - fp32_op_output_ch_mean_ready = true; + while (!OpOutputChMeanDataHandle(STORE, callParam.name_callback_param, &fp32_op_output_ch_mean)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } } return true; @@ -1326,7 +1311,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { return status; } // anf -- fb - auto meta_graph = Export(func_graph, true); + auto meta_graph = Export(func_graph, true, true); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta_graph return nullptr"; return RET_ERROR; @@ -1409,7 +1394,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { if (calibrator_->GetBiasCorrection()) { // init in8 session // anf -- fb - auto int8_meta_graph = Export(func_graph, true); + auto int8_meta_graph = Export(func_graph, true, true); if (int8_meta_graph == nullptr) { MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr"; return RET_ERROR; @@ -1461,6 +1446,54 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { return RET_OK; } + +bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector *data) { + std::lock_guard lg(mutex_op_input); + if (type == STORE) { + if 
(fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) { + // the data has not been fetched by int8 model + return false; + } + fp32_op_input_map[op_name] = *data; + return true; + } else if (type == FETCH) { + if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) { + // the data not generated by fp32 model yet + return false; + } + *data = fp32_op_input_map[op_name]; + fp32_op_input_map.erase(op_name); + return true; + } else { + MS_LOG(ERROR) << "unexpected type: " << type; + } + return false; +} + +bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, + std::vector *data) { + std::lock_guard lg(mutex_op_output); + if (type == STORE) { + if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) { + // the data has not been fetched by int8 model + return false; + } + fp32_op_output_ch_mean_map[op_name] = *data; + return true; + } else if (type == FETCH) { + if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) { + // the data not generated by fp32 model yet + return false; + } + *data = fp32_op_output_ch_mean_map[op_name]; + fp32_op_output_ch_mean_map.erase(op_name); + return true; + } else { + MS_LOG(ERROR) << "unexpected type: " << type; + } + return false; +} + } // namespace quant } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 5d78abe691..021ac30bac 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -73,13 +73,19 @@ class PostTrainingQuantizer : public Quantizer { mindspore::lite::LiteSession *fp32_session_; mindspore::lite::LiteSession *int8_session_; - std::string fp32_op_input_name; - std::string fp32_op_output_name; - std::vector fp32_op_input; - std::vector fp32_op_output_ch_mean; - std::map> 
op_bias_diff_map; - std::atomic fp32_op_input_ready{false}; - std::atomic fp32_op_output_ch_mean_ready{false}; + std::map> fp32_op_input_map; // concurrency + std::map> fp32_op_output_ch_mean_map; // concurrency + std::map> op_bias_diff_map; // only used by int8 model + std::mutex mutex_op_input; + std::mutex mutex_op_output; + + enum OperationType { + STORE, + FETCH, + }; + + bool OpInputDataHandle(OperationType type, const string &op_name, std::vector *data); + bool OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector *data); const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D); const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D); diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index ea90526069..8752b9e50f 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -32,6 +32,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector for (auto &quant_param : quant_params) { std::vector quant_params_in = {quant_param}; primTValue->AddInputQuantParam(quant_params_in); + primTValue->AddOutputQuantParam(quant_params_in); } return NewValueNode(primTValue); } @@ -88,7 +89,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && input_cnode_quant_type == schema::QuantType_PostTraining) { value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, - input_cnode_primitive_c->GetInputQuantParams().front()); + input_cnode_primitive_c->GetOutputQuantParams().front()); } if (value_node == nullptr) { MS_LOG(WARNING) << "value_node is null! 
" diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 498b47c8c5..a79c0452d4 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -99,12 +99,11 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { auto type = (schema::PrimitiveType)primitive_c->Type(); MS_LOG(INFO) << "Primitive type: " << type; static const std::vector uint8OpList = { - schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, - schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, - schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ - schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, - schema::PrimitiveType_MatMul, schema::PrimitiveType_Activation}; + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + /*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/ + schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul, + schema::PrimitiveType_Activation}; return IsContain(uint8OpList, type); }