Merge pull request !6314 from cjh9368/weight_quant (tags/v1.0.0)
| @@ -34,7 +34,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath | |||||
| std::ofstream output(outputPath + ".ms", std::ofstream::binary); | std::ofstream output(outputPath + ".ms", std::ofstream::binary); | ||||
| if (!output.is_open()) { | if (!output.is_open()) { | ||||
| MS_LOG(ERROR) << "ofstream open failed"; | |||||
| MS_LOG(ERROR) << "Output file path is error"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -30,9 +30,8 @@ Flags::Flags() { | |||||
| AddFlag(&Flags::weightFile, "weightFile", | AddFlag(&Flags::weightFile, "weightFile", | ||||
| "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); | ||||
| AddFlag(&Flags::inferenceTypeIn, "inferenceType", | AddFlag(&Flags::inferenceTypeIn, "inferenceType", | ||||
| "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT"); | |||||
| "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT"); | |||||
| AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); | AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", ""); | ||||
| AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT"); | |||||
| AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); | ||||
| AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); | AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); | ||||
| AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); | ||||
| @@ -41,8 +40,10 @@ Flags::Flags() { | |||||
| "16"); | "16"); | ||||
| AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); | AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); | ||||
| AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); | AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); | ||||
| AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device." | |||||
| " true | false", "false"); | |||||
| AddFlag(&Flags::trainModelIn, "trainModel", | |||||
| "whether the model is going to be trained on device." | |||||
| " true | false", | |||||
| "false"); | |||||
| } | } | ||||
| int Flags::Init(int argc, const char **argv) { | int Flags::Init(int argc, const char **argv) { | ||||
| @@ -80,22 +81,15 @@ int Flags::Init(int argc, const char **argv) { | |||||
| std::cerr << "INPUT MISSING: fmk is necessary"; | std::cerr << "INPUT MISSING: fmk is necessary"; | ||||
| return RET_INPUT_PARAM_LACK; | return RET_INPUT_PARAM_LACK; | ||||
| } | } | ||||
| if (this->inputInferenceTypeIn == "FLOAT") { | |||||
| this->inputInferenceType = TypeId::kNumberTypeFloat; | |||||
| } else if (this->inputInferenceTypeIn == "INT8") { | |||||
| this->inputInferenceType = TypeId::kNumberTypeInt8; | |||||
| } else { | |||||
| std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", | |||||
| this->inputInferenceTypeIn.c_str(); | |||||
| return RET_INPUT_PARAM_INVALID; | |||||
| } | |||||
| if (this->inferenceTypeIn == "FLOAT") { | if (this->inferenceTypeIn == "FLOAT") { | ||||
| this->inferenceType = TypeId::kNumberTypeFloat; | this->inferenceType = TypeId::kNumberTypeFloat; | ||||
| } else if (this->inferenceTypeIn == "INT8") { | } else if (this->inferenceTypeIn == "INT8") { | ||||
| this->inferenceType = TypeId::kNumberTypeInt8; | this->inferenceType = TypeId::kNumberTypeInt8; | ||||
| } else if (this->inferenceTypeIn == "SAME") { | |||||
| this->inferenceType = TypeId::kTypeUnknown; | |||||
| } else { | } else { | ||||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", | |||||
| std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME", | |||||
| this->inferenceTypeIn.c_str(); | this->inferenceTypeIn.c_str(); | ||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -130,7 +124,6 @@ int Flags::Init(int argc, const char **argv) { | |||||
| return RET_INPUT_PARAM_INVALID; | return RET_INPUT_PARAM_INVALID; | ||||
| } | } | ||||
| if (this->trainModelIn == "true") { | if (this->trainModelIn == "true") { | ||||
| this->trainModel = true; | this->trainModel = true; | ||||
| } else if (this->trainModelIn == "false") { | } else if (this->trainModelIn == "false") { | ||||
| @@ -56,10 +56,8 @@ class Flags : public virtual mindspore::lite::FlagParser { | |||||
| std::string quantTypeIn; | std::string quantTypeIn; | ||||
| QuantType quantType; | QuantType quantType; | ||||
| std::string inferenceTypeIn; | std::string inferenceTypeIn; | ||||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||||
| // used for parse aware trainning | // used for parse aware trainning | ||||
| std::string inputInferenceTypeIn; | |||||
| TypeId inputInferenceType = TypeId::kNumberTypeFloat; | |||||
| TypeId inferenceType = TypeId::kNumberTypeFloat; | |||||
| std::string stdDev; | std::string stdDev; | ||||
| std::string mean; | std::string mean; | ||||
| // used for post-trainning-weight | // used for post-trainning-weight | ||||
| @@ -51,7 +51,7 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||||
| case QuantType::QuantType_AwareTraining: { | case QuantType::QuantType_AwareTraining: { | ||||
| MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | MS_LOG(INFO) << "create AwareTrainingQuantizer!"; | ||||
| fbQuantizer = | fbQuantizer = | ||||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); | |||||
| std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inferenceType, flags->stdDev, flags->mean); | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| @@ -194,11 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| if (ctx.quantType == QuantType_AwareTraining) { | if (ctx.quantType == QuantType_AwareTraining) { | ||||
| Optimizer quantNodeOptimizer; | Optimizer quantNodeOptimizer; | ||||
| auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); | ||||
| if (dTypeTransPass == nullptr) { | |||||
| MS_LOG(ERROR) << "new dTypeTransPass failed"; | |||||
| return RET_MEMORY_FAILED; | |||||
| } | |||||
| dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); | |||||
| dTypeTransPass->SetInputDataDType(ctx.inferenceType); | |||||
| dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | dTypeTransPass->SetOutputDataDType(ctx.inferenceType); | ||||
| quantNodeOptimizer.AddPass(dTypeTransPass); | quantNodeOptimizer.AddPass(dTypeTransPass); | ||||
| quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); | ||||
| @@ -71,7 +71,7 @@ STATUS InputArray::SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensor | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, | |||||
| AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const string &stdValues, | |||||
| const string &meanValues) | const string &meanValues) | ||||
| : FbQuantizer(graph) { | : FbQuantizer(graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| @@ -80,7 +80,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInf | |||||
| sz = 0; | sz = 0; | ||||
| const float mean = std::stof(meanValues, &sz); | const float mean = std::stof(meanValues, &sz); | ||||
| std::unique_ptr<InputArray> inArr = nullptr; | std::unique_ptr<InputArray> inArr = nullptr; | ||||
| if (inputInferType == "FLOAT") { | |||||
| if (inferType == kNumberTypeFloat) { | |||||
| inArr.reset(new (std::nothrow) InputArray(mean, stdValue)); | inArr.reset(new (std::nothrow) InputArray(mean, stdValue)); | ||||
| } else { | } else { | ||||
| inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8)); | inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8)); | ||||
| @@ -48,7 +48,7 @@ struct InputArray { | |||||
| class AwareQuantizer : public FbQuantizer { | class AwareQuantizer : public FbQuantizer { | ||||
| public: | public: | ||||
| AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues, | |||||
| AwareQuantizer(schema::MetaGraphT *graph, const TypeId &inferType, const std::string &stdValues, | |||||
| const std::string &meanValues); | const std::string &meanValues); | ||||
| ~AwareQuantizer() { delete (mInputArray); } | ~AwareQuantizer() { delete (mInputArray); } | ||||
| @@ -116,11 +116,11 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { | |||||
| return status; | return status; | ||||
| } | } | ||||
| if (inputParamDone != node.inputIndex.size()) { | if (inputParamDone != node.inputIndex.size()) { | ||||
| MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name; | |||||
| MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (outputParamDone != node.outputIndex.size()) { | if (outputParamDone != node.outputIndex.size()) { | ||||
| MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name; | |||||
| MS_LOG(WARNING) << "Can not determine outputTensor quantParam, node " << node.name; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -138,7 +138,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { | |||||
| MS_ASSERT(outTensor != nullptr); | MS_ASSERT(outTensor != nullptr); | ||||
| auto outputQuantParam = GetTensorQuantParam(outTensor); | auto outputQuantParam = GetTensorQuantParam(outTensor); | ||||
| MS_ASSERT(outputQuantParam != nullptr); | MS_ASSERT(outputQuantParam != nullptr); | ||||
| if (!outputQuantParam->inited) { | |||||
| if (outputQuantParam == nullptr || !outputQuantParam->inited) { | |||||
| MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; | MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -204,8 +204,7 @@ class CalcConcat : public QuantParamCalcer { | |||||
| auto &inTensor = graph->allTensors.at(i); | auto &inTensor = graph->allTensors.at(i); | ||||
| MS_ASSERT(inTensor != nullptr); | MS_ASSERT(inTensor != nullptr); | ||||
| auto inQuantParam = GetTensorQuantParam(inTensor); | auto inQuantParam = GetTensorQuantParam(inTensor); | ||||
| MS_ASSERT(inQuantParam != nullptr); | |||||
| if (!inQuantParam->inited) { | |||||
| if (inQuantParam == nullptr || !inQuantParam->inited) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (numBits == -1) { | if (numBits == -1) { | ||||