diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc index 549df88f2d..f8b2430a77 100644 --- a/mindspore/lite/tools/common/storage.cc +++ b/mindspore/lite/tools/common/storage.cc @@ -34,7 +34,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath std::ofstream output(outputPath + ".ms", std::ofstream::binary); if (!output.is_open()) { - MS_LOG(ERROR) << "ofstream open failed"; + MS_LOG(ERROR) << "Output file path is error"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 7095675789..09f0406ee2 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -30,9 +30,8 @@ Flags::Flags() { AddFlag(&Flags::weightFile, "weightFile", "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); AddFlag(&Flags::inferenceTypeIn, "inferenceType", - "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT"); + "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT"); AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); - AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT"); AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); @@ -41,8 +40,10 @@ Flags::Flags() { "16"); AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); - AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device." - " true | false", "false"); + AddFlag(&Flags::trainModelIn, "trainModel", + "whether the model is going to be trained on device." + " true | false", + "false"); } int Flags::Init(int argc, const char **argv) { @@ -80,22 +81,15 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "INPUT MISSING: fmk is necessary"; return RET_INPUT_PARAM_LACK; } - if (this->inputInferenceTypeIn == "FLOAT") { - this->inputInferenceType = TypeId::kNumberTypeFloat; - } else if (this->inputInferenceTypeIn == "INT8") { - this->inputInferenceType = TypeId::kNumberTypeInt8; - } else { - std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", - this->inputInferenceTypeIn.c_str(); - return RET_INPUT_PARAM_INVALID; - } if (this->inferenceTypeIn == "FLOAT") { this->inferenceType = TypeId::kNumberTypeFloat; } else if (this->inferenceTypeIn == "INT8") { this->inferenceType = TypeId::kNumberTypeInt8; + } else if (this->inferenceTypeIn == "SAME") { + this->inferenceType = TypeId::kTypeUnknown; } else { - std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", + std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME", this->inferenceTypeIn.c_str(); return RET_INPUT_PARAM_INVALID; } @@ -130,7 +124,6 @@ int Flags::Init(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } - if (this->trainModelIn == "true") { this->trainModel = true; } else if (this->trainModelIn == "false") { diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 7e1b822466..6a3bd36bfa 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -56,10 +56,8 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string quantTypeIn; QuantType quantType; std::string inferenceTypeIn; - TypeId inferenceType = TypeId::kNumberTypeFloat; // used for parse aware trainning - std::string inputInferenceTypeIn; - TypeId inputInferenceType = TypeId::kNumberTypeFloat; + TypeId inferenceType = TypeId::kNumberTypeFloat; std::string stdDev; std::string mean; // used for post-trainning-weight diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 82188fff45..920947d7ac 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -51,7 +51,7 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { case QuantType::QuantType_AwareTraining: { MS_LOG(INFO) << "create AwareTrainingQuantizer!"; fbQuantizer = - std::make_unique(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); + std::make_unique(graphDefT, flags->inferenceType, flags->stdDev, flags->mean); break; } default: @@ -194,11 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { if (ctx.quantType == QuantType_AwareTraining) { Optimizer quantNodeOptimizer; auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); - if (dTypeTransPass == nullptr) { - MS_LOG(ERROR) << "new dTypeTransPass failed"; - return RET_MEMORY_FAILED; - } - dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); + dTypeTransPass->SetInputDataDType(ctx.inferenceType); dTypeTransPass->SetOutputDataDType(ctx.inferenceType); quantNodeOptimizer.AddPass(dTypeTransPass); quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 8e64886cb8..21f16e4ba8 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -71,7 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor return RET_OK; } -AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, +AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const string &stdValues, const string &meanValues) : FbQuantizer(graph) { MS_ASSERT(graph != nullptr); @@ -80,7 +80,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInf sz = 0; const float mean = std::stof(meanValues, &sz); std::unique_ptr inArr = nullptr; - if (inputInferType == "FLOAT") { + if (inferType == kNumberTypeFloat) { inArr.reset(new (std::nothrow) InputArray(mean, stdValue)); } else { inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8)); diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h index 40b92f6ead..f7038b1126 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h @@ -48,7 +48,7 @@ struct InputArray { class AwareQuantizer : public FbQuantizer { public: - AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues, + AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const std::string &stdValues, const std::string &meanValues); ~AwareQuantizer() { delete (mInputArray); } diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index d04a7fb2b7..1c85cc5788 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -116,11 +116,11 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { return status; } if (inputParamDone != node.inputIndex.size()) { - MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name; + MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name; return RET_ERROR; } if (outputParamDone != node.outputIndex.size()) { - MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name; + MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name; return RET_ERROR; } return RET_OK; @@ -138,7 +138,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { MS_ASSERT(outTensor != nullptr); auto outputQuantParam = GetTensorQuantParam(outTensor); MS_ASSERT(outputQuantParam != nullptr); - if (!outputQuantParam->inited) { + if (outputQuantParam == nullptr || !outputQuantParam->inited) { MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; return RET_ERROR; } @@ -204,8 +204,7 @@ class CalcConcat : public QuantParamCalcer { auto &inTensor = graph->allTensors.at(i); MS_ASSERT(inTensor != nullptr); auto inQuantParam = GetTensorQuantParam(inTensor); - MS_ASSERT(inQuantParam != nullptr); - if (!inQuantParam->inited) { + if (inQuantParam == nullptr || !inQuantParam->inited) { return RET_ERROR; } if (numBits == -1) {