fix bias correction bug

tags/v1.1.0
xutianchun, 5 years ago
commit c5b5cb03c9
6 changed files with 100 additions and 50 deletions
  1. mindspore/lite/tools/anf_exporter/anf_exporter.cc (+12, -4)
  2. mindspore/lite/tools/anf_exporter/anf_exporter.h (+6, -3)
  3. mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+62, -29)
  4. mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+13, -7)
  5. mindspore/lite/tools/converter/quantizer/quant_cast.cc (+2, -1)
  6. mindspore/lite/tools/converter/quantizer/quantize_util.cc (+5, -6)

mindspore/lite/tools/anf_exporter/anf_exporter.cc (+12, -4)

@@ -195,7 +195,7 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_pt
   return RET_OK;
 }
 
-schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph) {
+schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
   auto cnodes = func_graph->GetOrderedCnodes();
   auto meta_graphT = std::make_unique<schema::MetaGraphT>();
   int ret = RET_OK;
@@ -236,7 +236,15 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
     }
     node->nodeType = schema::NodeType_CNode;
     node->name = cnode->fullname_with_scope();
-    node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
+    if (copy_primitive) {
+      auto primitive = new (std::nothrow) schema::PrimitiveT();
+      if (primitive != nullptr) {
+        *primitive = *primT;
+        node->primitive = std::unique_ptr<schema::PrimitiveT>(primitive);
+      }
+    } else {
+      node->primitive = std::unique_ptr<schema::PrimitiveT>(primT);
+    }
     ret = SetOpInputNode(cnode, meta_graphT, node.get());
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "SetOpInputNode failed";
@@ -518,8 +526,8 @@ bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType
   return (schema::PrimitiveType)(prim->Type()) == type;
 }
 
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph) {
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
   AnfExporter anf_exporter;
-  return anf_exporter.Export(func_graph, keep_graph);
+  return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
 }
 }  // namespace mindspore::lite

mindspore/lite/tools/anf_exporter/anf_exporter.h (+6, -3)

@@ -31,7 +31,7 @@ class AnfExporter {
  public:
   AnfExporter() = default;
   virtual ~AnfExporter() = default;
-  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false);
+  schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
   void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                        schema::CNodeT *fb_node);
   int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@@ -56,7 +56,10 @@ class AnfExporter {
   std::map<std::string, int> node_id_map_;
   std::vector<schema::CNodeT *> graph_input_nodes_;
 };
 
-schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false);
+// By default copy_primitive is false, meaning the exported MetaGraph and the func_graph share the same
+// schema::PrimitiveT objects. In post-training quantization, however, the func_graph must first be converted to a
+// MetaGraph so that MetaGraph passes can run on it; those passes may modify the shared schema::PrimitiveT and
+// thereby corrupt the func_graph. Once every pass runs directly on the func_graph, this copy can be dropped.
+schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
 }  // namespace mindspore::lite
 #endif  // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_
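
To make the ownership difference concrete, here is a minimal standalone sketch of the two export modes. The PrimT and Node types are hypothetical stand-ins for schema::PrimitiveT and schema::CNodeT, not the MindSpore API:

#include <memory>

struct PrimT { int pad = 0; };                      // stands in for schema::PrimitiveT
struct Node { std::unique_ptr<PrimT> primitive; };  // stands in for schema::CNodeT

Node MakeNode(PrimT *primT, bool copy_primitive) {
  Node node;
  if (copy_primitive) {
    // Deep copy: later passes may mutate node.primitive freely without
    // touching the primitive still referenced by the source graph.
    node.primitive = std::make_unique<PrimT>(*primT);
  } else {
    // Shared/transferred pointer: any later mutation through the node is a
    // mutation of the source graph's primitive as well.
    node.primitive.reset(primT);
  }
  return node;
}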

mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc (+62, -29)

@@ -919,14 +919,10 @@ STATUS PostTrainingQuantizer::Int8Inference() {
                          const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
                          const mindspore::session::CallBackParam &callParam) -> bool {
     if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
-      while (!fp32_op_input_ready) {
+      vector<float> fp32_op_input;
+      while (!OpInputDataHandle(FETCH, callParam.name_callback_param, &fp32_op_input)) {
         std::this_thread::sleep_for(std::chrono::milliseconds(10));
       }
-      if (callParam.name_callback_param != fp32_op_input_name) {
-        MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
-                      << " ready fp32 op name: " << fp32_op_input_name;
-        return false;
-      }
       auto tensor = beforeInputs[0];
       auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);

@@ -962,7 +958,6 @@ STATUS PostTrainingQuantizer::Int8Inference() {
         MS_LOG(ERROR) << "memcpy error: " << ret;
         return false;
       }
-      fp32_op_input_ready = false;
     }
     return true;
   };
@@ -972,14 +967,10 @@ STATUS PostTrainingQuantizer::Int8Inference() {
                         const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
                         const mindspore::session::CallBackParam &callParam) -> bool {
     if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
-      while (!fp32_op_output_ch_mean_ready) {
+      vector<float> fp32_op_output_ch_mean;
+      while (!OpOutputChMeanDataHandle(FETCH, callParam.name_callback_param, &fp32_op_output_ch_mean)) {
         std::this_thread::sleep_for(std::chrono::milliseconds(10));
       }
-      if (callParam.name_callback_param != fp32_op_output_name) {
-        MS_LOG(ERROR) << "current int8 op name: " << callParam.name_callback_param
-                      << " ready fp32 op name: " << fp32_op_output_name;
-        return false;
-      }
       auto tensor = afterOutputs[0];
       auto lite_tensor = dynamic_cast<mindspore::lite::Tensor *>(tensor);

@@ -1036,9 +1027,7 @@ STATUS PostTrainingQuantizer::Int8Inference() {
       } else {
         op_bias_diff_map[callParam.name_callback_param] = dequant_op_output_ch_mean;
       }
-      fp32_op_output_ch_mean_ready = false;
     }
-
     return true;
   };
   ret = int8_session_->RunGraph(beforeCallBack, afterCallBack);
@@ -1072,23 +1061,21 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
                          const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
                          const mindspore::session::CallBackParam &callParam) -> bool {
     if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
-      while (fp32_op_input_ready) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-      }
       if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, beforeInputs) != RET_OK) {
         return false;
       }
       auto tensor = beforeInputs[0];
       size_t elem_count = tensor->ElementsNum();
-      fp32_op_input.resize(elem_count);
+      std::vector<float> fp32_op_input(elem_count);
       auto ret =
         memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
       if (ret != EOK) {
         MS_LOG(ERROR) << "memcpy error: " << ret;
         return false;
       }
-      fp32_op_input_name = callParam.name_callback_param;
-      fp32_op_input_ready = true;
+      while (!OpInputDataHandle(STORE, callParam.name_callback_param, &fp32_op_input)) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+      }
     }
     return true;
   };
@@ -1098,9 +1085,6 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
                         const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
                         const mindspore::session::CallBackParam &callParam) -> bool {
     if (callParam.type_callback_param == kTypeConv2D || callParam.type_callback_param == kTypeDepthwiseConv2D) {
-      while (fp32_op_output_ch_mean_ready) {
-        std::this_thread::sleep_for(std::chrono::milliseconds(10));
-      }
       if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.name_callback_param, afterOutputs) != RET_OK) {
         return false;
       }
@@ -1118,7 +1102,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
         MS_LOG(ERROR) << "unexpected channels: 0";
         return false;
       }
-      fp32_op_output_ch_mean.resize(channels);
+      std::vector<float> fp32_op_output_ch_mean(channels);
       auto one_filter_size = elem_count / channels;
       for (int i = 0; i < channels; i++) {
         float sum = 0;
@@ -1133,8 +1117,9 @@ STATUS PostTrainingQuantizer::BiasCorrection(FuncGraphPtr func_graph) {
         sum = sum / one_filter_size;
         fp32_op_output_ch_mean[i] = sum;
       }
-      fp32_op_output_name = callParam.name_callback_param;
-      fp32_op_output_ch_mean_ready = true;
+      while (!OpOutputChMeanDataHandle(STORE, callParam.name_callback_param, &fp32_op_output_ch_mean)) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+      }
     }
 
     return true;
@@ -1326,7 +1311,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return status;
   }
   // anf -- fb
-  auto meta_graph = Export(func_graph, true);
+  auto meta_graph = Export(func_graph, true, true);
   if (meta_graph == nullptr) {
     MS_LOG(ERROR) << "Export to meta_graph return nullptr";
     return RET_ERROR;
@@ -1409,7 +1394,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   if (calibrator_->GetBiasCorrection()) {
     // init int8 session
     // anf -- fb
-    auto int8_meta_graph = Export(func_graph, true);
+    auto int8_meta_graph = Export(func_graph, true, true);
     if (int8_meta_graph == nullptr) {
       MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr";
       return RET_ERROR;
@@ -1461,6 +1446,54 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
 
   return RET_OK;
 }
 
+bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
+  std::lock_guard<std::mutex> lg(mutex_op_input);
+  if (type == STORE) {
+    if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
+      // the previous data has not been fetched by the int8 model yet
+      return false;
+    }
+    fp32_op_input_map[op_name] = *data;
+    return true;
+  } else if (type == FETCH) {
+    if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) {
+      // the fp32 model has not generated the data yet
+      return false;
+    }
+    *data = fp32_op_input_map[op_name];
+    fp32_op_input_map.erase(op_name);
+    return true;
+  } else {
+    MS_LOG(ERROR) << "unexpected type: " << type;
+  }
+  return false;
+}
+
+bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name,
+                                                     std::vector<float> *data) {
+  std::lock_guard<std::mutex> lg(mutex_op_output);
+  if (type == STORE) {
+    if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
+      // the previous data has not been fetched by the int8 model yet
+      return false;
+    }
+    fp32_op_output_ch_mean_map[op_name] = *data;
+    return true;
+  } else if (type == FETCH) {
+    if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) {
+      // the fp32 model has not generated the data yet
+      return false;
+    }
+    *data = fp32_op_output_ch_mean_map[op_name];
+    fp32_op_output_ch_mean_map.erase(op_name);
+    return true;
+  } else {
+    MS_LOG(ERROR) << "unexpected type: " << type;
+  }
+  return false;
+}
+
 }  // namespace quant
 }  // namespace lite
 }  // namespace mindspore
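
These two handlers replace the old single ready-flag handshake: the fp32 session's callbacks STORE per-op data keyed by op name, and the int8 session's callbacks FETCH and erase it, so callbacks arriving out of order can no longer pair data with the wrong op or overwrite an unconsumed buffer. A self-contained sketch of the same rendezvous pattern, with generic names rather than the MindSpore API:

#include <chrono>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

enum OperationType { STORE, FETCH };

std::mutex g_mutex;
std::map<std::string, std::vector<float>> g_slots;  // at most one pending buffer per op name

// Returns false if the slot is not in the required state; callers poll until it succeeds.
bool Handle(OperationType type, const std::string &op, std::vector<float> *data) {
  std::lock_guard<std::mutex> lg(g_mutex);
  if (type == STORE) {
    if (g_slots.count(op) != 0) return false;  // previous buffer not consumed yet
    g_slots[op] = *data;
    return true;
  }
  if (g_slots.count(op) == 0) return false;  // producer has not stored this op yet
  *data = g_slots[op];
  g_slots.erase(op);  // consume, so the producer may store the next round
  return true;
}

int main() {
  std::thread producer([] {
    std::vector<float> v = {1.0f, 2.0f, 3.0f};
    while (!Handle(STORE, "conv1", &v)) std::this_thread::sleep_for(std::chrono::milliseconds(10));
  });
  std::vector<float> out;
  while (!Handle(FETCH, "conv1", &out)) std::this_thread::sleep_for(std::chrono::milliseconds(10));
  producer.join();
  std::cout << "fetched " << out.size() << " floats for conv1\n";
  return 0;
}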

mindspore/lite/tools/converter/quantizer/post_training_quantizer.h (+13, -7)

@@ -73,13 +73,19 @@ class PostTrainingQuantizer : public Quantizer {
   mindspore::lite::LiteSession *fp32_session_;
   mindspore::lite::LiteSession *int8_session_;
 
-  std::string fp32_op_input_name;
-  std::string fp32_op_output_name;
-  std::vector<float> fp32_op_input;
-  std::vector<float> fp32_op_output_ch_mean;
-  std::map<std::string, std::vector<float>> op_bias_diff_map;
-  std::atomic<bool> fp32_op_input_ready{false};
-  std::atomic<bool> fp32_op_output_ch_mean_ready{false};
+  std::map<std::string, std::vector<float>> fp32_op_input_map;           // accessed concurrently
+  std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map;  // accessed concurrently
+  std::map<std::string, std::vector<float>> op_bias_diff_map;            // used only by the int8 model
+  std::mutex mutex_op_input;
+  std::mutex mutex_op_output;
+
+  enum OperationType {
+    STORE,
+    FETCH,
+  };
+
+  bool OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
+  bool OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
 
 const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D);
 const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D);


mindspore/lite/tools/converter/quantizer/quant_cast.cc (+2, -1)

@@ -32,6 +32,7 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
   for (auto &quant_param : quant_params) {
     std::vector<schema::QuantParamT> quant_params_in = {quant_param};
     primTValue->AddInputQuantParam(quant_params_in);
+    primTValue->AddOutputQuantParam(quant_params_in);
   }
   return NewValueNode(primTValue);
 }
@@ -88,7 +89,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
     } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
                input_cnode_quant_type == schema::QuantType_PostTraining) {
       value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
-                                         input_cnode_primitive_c->GetInputQuantParams().front());
+                                         input_cnode_primitive_c->GetOutputQuantParams().front());
     }
     if (value_node == nullptr) {
       MS_LOG(WARNING) << "value_node is null! "
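
For context on the GetInputQuantParams to GetOutputQuantParams change: an int8-to-fp32 cast dequantizes the producing op's output tensor, so it must be built from that op's output quant params; using the input params scales the recovered floats against the wrong distribution. A minimal sketch of affine dequantization, illustrative only and not MindSpore code:

#include <cstdint>
#include <vector>

// Affine dequantization: real = scale * (q - zero_point).
std::vector<float> Dequantize(const std::vector<int8_t> &q, float scale, int32_t zero_point) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = scale * (static_cast<int32_t>(q[i]) - zero_point);
  }
  return out;
}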


mindspore/lite/tools/converter/quantizer/quantize_util.cc (+5, -6)

@@ -99,12 +99,11 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
   auto type = (schema::PrimitiveType)primitive_c->Type();
   MS_LOG(INFO) << "Primitive type: " << type;
   static const std::vector<schema::PrimitiveType> uint8OpList = {
-    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
-    schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
-    schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
-    schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
-    schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
-    schema::PrimitiveType_MatMul, schema::PrimitiveType_Activation};
+    schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D,
+    schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
+    /*schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,*/
+    schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMul,
+    schema::PrimitiveType_Activation};
   return IsContain(uint8OpList, type);
 }
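
Besides reflowing the initializer, this hunk moves PrimitiveType_Concat into the commented-out section, so Concat ops are no longer eligible for post-training quantization. The eligibility check itself is a plain whitelist lookup; a generic sketch, assuming IsContain has the usual membership-test semantics:

#include <algorithm>
#include <vector>

// Membership test in the spirit of the repo's IsContain helper (assumed semantics).
template <typename T>
bool IsContain(const std::vector<T> &list, const T &value) {
  return std::find(list.begin(), list.end(), value) != list.end();
}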



