| @@ -34,7 +34,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||
| std::ofstream output(outputPath + ".ms", std::ofstream::binary); | |||
| if (!output.is_open()) { | |||
| MS_LOG(ERROR) << "ofstream open failed"; | |||
| MS_LOG(ERROR) << "Output file path is error"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -30,9 +30,8 @@ Flags::Flags() { | |||
| AddFlag(&Flags::weightFile, "weightFile", | |||
| "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | |||
| AddFlag(&Flags::inferenceTypeIn, "inferenceType", | |||
| "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT"); | |||
| "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT"); | |||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); | |||
| AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT"); | |||
| AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | |||
| AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); | |||
| AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | |||
| @@ -41,8 +40,10 @@ Flags::Flags() { | |||
| "16"); | |||
| AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); | |||
| AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); | |||
| AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device." | |||
| " true | false", "false"); | |||
| AddFlag(&Flags::trainModelIn, "trainModel", | |||
| "whether the model is going to be trained on device." | |||
| " true | false", | |||
| "false"); | |||
| } | |||
| int Flags::Init(int argc, const char **argv) { | |||
| @@ -80,22 +81,15 @@ int Flags::Init(int argc, const char **argv) { | |||
| std::cerr << "INPUT MISSING: fmk is necessary"; | |||
| return RET_INPUT_PARAM_LACK; | |||
| } | |||
| if (this->inputInferenceTypeIn == "FLOAT") { | |||
| this->inputInferenceType = TypeId::kNumberTypeFloat; | |||
| } else if (this->inputInferenceTypeIn == "INT8") { | |||
| this->inputInferenceType = TypeId::kNumberTypeInt8; | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", | |||
| this->inputInferenceTypeIn.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->inferenceTypeIn == "FLOAT") { | |||
| this->inferenceType = TypeId::kNumberTypeFloat; | |||
| } else if (this->inferenceTypeIn == "INT8") { | |||
| this->inferenceType = TypeId::kNumberTypeInt8; | |||
| } else if (this->inferenceTypeIn == "SAME") { | |||
| this->inferenceType = TypeId::kTypeUnknown; | |||
| } else { | |||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", | |||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME", | |||
| this->inferenceTypeIn.c_str(); | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| @@ -130,7 +124,6 @@ int Flags::Init(int argc, const char **argv) { | |||
| return RET_INPUT_PARAM_INVALID; | |||
| } | |||
| if (this->trainModelIn == "true") { | |||
| this->trainModel = true; | |||
| } else if (this->trainModelIn == "false") { | |||
| @@ -56,10 +56,8 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||
| std::string quantTypeIn; | |||
| QuantType quantType; | |||
| std::string inferenceTypeIn; | |||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||
| // used for parse aware trainning | |||
| std::string inputInferenceTypeIn; | |||
| TypeId inputInferenceType = TypeId::kNumberTypeFloat; | |||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||
| std::string stdDev; | |||
| std::string mean; | |||
| // used for post-trainning-weight | |||
| @@ -51,7 +51,7 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| case QuantType::QuantType_AwareTraining: { | |||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | |||
| fbQuantizer = | |||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | |||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType, flags->stdDev, flags->mean); | |||
| break; | |||
| } | |||
| default: | |||
| @@ -194,11 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| if (ctx.quantType == QuantType_AwareTraining) { | |||
| Optimizer quantNodeOptimizer; | |||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | |||
| if (dTypeTransPass == nullptr) { | |||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); | |||
| dTypeTransPass->SetInputDataDType(ctx.inferenceType); | |||
| dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | |||
| quantNodeOptimizer.AddPass(dTypeTransPass); | |||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | |||
| @@ -71,7 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor | |||
| return RET_OK; | |||
| } | |||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, | |||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const string &stdValues, | |||
| const string &meanValues) | |||
| : FbQuantizer(graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| @@ -80,7 +80,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInf | |||
| sz = 0; | |||
| const float mean = std::stof(meanValues, &sz); | |||
| std::unique_ptr<InputArray> inArr = nullptr; | |||
| if (inputInferType == "FLOAT") { | |||
| if (inferType == kNumberTypeFloat) { | |||
| inArr.reset(new (std::nothrow) InputArray(mean, stdValue)); | |||
| } else { | |||
| inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8)); | |||
| @@ -48,7 +48,7 @@ struct InputArray { | |||
| class AwareQuantizer : public FbQuantizer { | |||
| public: | |||
| AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues, | |||
| AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const std::string &stdValues, | |||
| const std::string &meanValues); | |||
| ~AwareQuantizer() { delete (mInputArray); } | |||
| @@ -116,11 +116,11 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||
| return status; | |||
| } | |||
| if (inputParamDone != node.inputIndex.size()) { | |||
| MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name; | |||
| MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name; | |||
| return RET_ERROR; | |||
| } | |||
| if (outputParamDone != node.outputIndex.size()) { | |||
| MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name; | |||
| MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| @@ -138,7 +138,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||
| MS_ASSERT(outTensor != nullptr); | |||
| auto outputQuantParam = GetTensorQuantParam(outTensor); | |||
| MS_ASSERT(outputQuantParam != nullptr); | |||
| if (!outputQuantParam->inited) { | |||
| if (outputQuantParam == nullptr || !outputQuantParam->inited) { | |||
| MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -204,8 +204,7 @@ class CalcConcat : public QuantParamCalcer { | |||
| auto &inTensor = graph->allTensors.at(i); | |||
| MS_ASSERT(inTensor != nullptr); | |||
| auto inQuantParam = GetTensorQuantParam(inTensor); | |||
| MS_ASSERT(inQuantParam != nullptr); | |||
| if (!inQuantParam->inited) { | |||
| if (inQuantParam == nullptr || !inQuantParam->inited) { | |||
| return RET_ERROR; | |||
| } | |||
| if (numBits == -1) { | |||