diff --git a/mindspore/lite/test/models_tflite.cfg b/mindspore/lite/test/models_tflite.cfg index 44977a4dfe..13f1db88dc 100644 --- a/mindspore/lite/test/models_tflite.cfg +++ b/mindspore/lite/test/models_tflite.cfg @@ -78,9 +78,9 @@ hiai_cv_labelDetectorModel_v2.tflite #hiai_cv_labelDetectorModel_v3.tflite hiai_cv_labelDetectorModel_v4.tflite hiai_dress_detect.tflite -#hiai_frozen_inference_graph.tflite +hiai_frozen_inference_graph.tflite hiai_ghostnet.tflite hiai_iMaxDN_RGB.tflite hiai_iMaxSR_RGB.tflite hiai_label_and_video.tflite -#hiai_lm_inference_graph.tflite +hiai_lm_inference_graph.tflite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index bb8a58c163..dac9b3e578 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr auto isConst = (!tensor_buffer->data.empty()); if (isConst) { CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); - } else if (tensor->dataType == TypeId::kNumberTypeUInt8) { + } else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) { // set in/out tensor to int8 to fit ms-lite op tensor->dataType = TypeId::kNumberTypeInt8; } @@ -299,6 +299,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const QuantType &quant_type) { std::unique_ptr sub_graph = std::make_unique(); sub_graph->name = "MS_model converted by TF-Lite"; + quantType = quant_type; // load graph std::unique_ptr tflite_model = ReadTfliteModel(model_file.c_str()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 2d95500213..71e28c3c88 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -77,6 +77,7 @@ class TfliteModelParser : public ModelParser { std::map opMap; std::map tfliteOpMap; + QuantType quantType = QuantType_QUANT_NONE; }; } // namespace lite } // namespace mindspore