|
|
|
@@ -29,12 +29,12 @@ Flags::Flags() { |
|
|
|
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); |
|
|
|
AddFlag(&Flags::weightFile, "weightFile", |
|
|
|
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); |
|
|
|
AddFlag(&Flags::inferenceType, "inferenceType", |
|
|
|
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); |
|
|
|
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", ""); |
|
|
|
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT"); |
|
|
|
AddFlag(&Flags::inferenceTypeIn, "inferenceType", |
|
|
|
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT"); |
|
|
|
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining", ""); |
|
|
|
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT"); |
|
|
|
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); |
|
|
|
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); |
|
|
|
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); |
|
|
|
AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0"); |
|
|
|
AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); |
|
|
|
AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); |
|
|
|
@@ -77,14 +77,24 @@ int Flags::Init(int argc, const char **argv) { |
|
|
|
} |
|
|
|
if (this->inputInferenceTypeIn == "FLOAT") { |
|
|
|
this->inputInferenceType = TypeId::kNumberTypeFloat; |
|
|
|
} else if (this->inputInferenceTypeIn == "UINT8") { |
|
|
|
this->inputInferenceType = TypeId::kNumberTypeUInt8; |
|
|
|
} else if (this->inputInferenceTypeIn == "INT8") { |
|
|
|
this->inputInferenceType = TypeId::kNumberTypeInt8; |
|
|
|
} else { |
|
|
|
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); |
|
|
|
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s, supported inputInferenceType: FLOAT | INT8", |
|
|
|
this->inputInferenceTypeIn.c_str(); |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
if (this->inferenceTypeIn == "FLOAT") { |
|
|
|
this->inferenceType = TypeId::kNumberTypeFloat; |
|
|
|
} else if (this->inferenceTypeIn == "INT8") { |
|
|
|
this->inferenceType = TypeId::kNumberTypeInt8; |
|
|
|
} else { |
|
|
|
std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8", |
|
|
|
this->inferenceTypeIn.c_str(); |
|
|
|
return 1; |
|
|
|
} |
|
|
|
|
|
|
|
if (this->fmkIn == "CAFFE") { |
|
|
|
this->fmk = FmkType_CAFFE; |
|
|
|
} else if (this->fmkIn == "MS") { |
|
|
|
|