From a161a3375a6837de454a7b76f7f38b5c3cdb2154 Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Wed, 16 Sep 2020 19:26:15 +0800
Subject: [PATCH] add inferenceType UINT8 and rectify converter params

---
 mindspore/lite/nnacl/int8/quant_dtype_cast.c  | 40 ++++++++++++-
 mindspore/lite/nnacl/int8/quant_dtype_cast.h  |  6 +-
 .../kernel/arm/base/quant_dtype_cast.cc       | 59 +++++++++++++++----
 .../kernel/arm/base/quant_dtype_cast.h        |  1 +
 mindspore/lite/test/run_benchmark_nets.sh     |  4 +-
 mindspore/lite/tools/benchmark/benchmark.cc   |  5 +-
 mindspore/lite/tools/benchmark/benchmark.h    |  9 ++-
 .../lite/tools/converter/converter_flags.cc   | 19 +++---
 .../lite/tools/converter/converter_flags.h    |  1 -
 .../tools/converter/graphdef_transform.cc     |  2 +-
 .../graph/dtype_trans_pass.cc                 | 11 +++-
 .../parser/tflite/tflite_quantize_parser.cc   |  3 +-
 12 files changed, 122 insertions(+), 38 deletions(-)

diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast.c b/mindspore/lite/nnacl/int8/quant_dtype_cast.c
index 1184bf0c10..3728db12e3 100644
--- a/mindspore/lite/nnacl/int8/quant_dtype_cast.c
+++ b/mindspore/lite/nnacl/int8/quant_dtype_cast.c
@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast.h"
 #include "nnacl/errorcode.h"
 
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -29,7 +29,7 @@ int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -46,3 +46,39 @@ int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   }
   return NNACL_OK;
 }
+
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = quant_values[i] + 128;
+    if (temp > 255) {
+      real_values[i] = (uint8_t)255;
+    } else if (temp < 0) {
+      real_values[i] = 0;
+    } else {
+      real_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = real_values[i] - 128;
+    if (temp > 127) {
+      quant_values[i] = 127;
+    } else if (temp < -128) {
+      quant_values[i] = -128;
+    } else {
+      quant_values[i] = (int8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast.h b/mindspore/lite/nnacl/int8/quant_dtype_cast.h
index 8e79fd5b55..941d1952e6 100644
--- a/mindspore/lite/nnacl/int8/quant_dtype_cast.h
+++ b/mindspore/lite/nnacl/int8/quant_dtype_cast.h
@@ -28,8 +28,10 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
index 6c913aba38..5967718079 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
@@ -53,6 +53,18 @@ int QuantDTypeCastCPUKernel::Init() {
       return RET_ERROR;
     }
     inverse_ = true;
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = false;
+  } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -83,16 +95,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
     return RET_ERROR;
   }
-  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() :
-                   out_tensors_.front()->GetQuantParams().front();
+  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front()
+                                                                  : out_tensors_.front()->GetQuantParams().front();
   int ret;
-  if (inverse_) {
-    ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+  if (uint8_ptr_ == nullptr) {
+    if (inverse_) {
+      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                     quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+    if (inverse_) {
+      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    }
   }
+
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
     return RET_ERROR;
@@ -116,12 +137,23 @@ int QuantDTypeCastCPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  if (inverse_) {
-    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
-  } else {
-    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
-    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+
+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   }
 
   auto ret = ParallelLaunch(THREAD_POOL_DEFAULT, QuantDTypeCastRun, this, thread_n_num_);
@@ -156,6 +188,7 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector
> "${run_converter_log_file}"
-        echo './converter_lite --fmk=MS --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
-        ./converter_lite --fmk=MS --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
+        echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
+        ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
         if [ $? = 0 ]; then
             converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
         else
diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc
index b9b80cc52d..2ae68aa0cc 100644
--- a/mindspore/lite/tools/benchmark/benchmark.cc
+++ b/mindspore/lite/tools/benchmark/benchmark.cc
@@ -186,7 +186,6 @@ int Benchmark::CompareOutput() {
     } else {
       tensor = tensors.front();
     }
-    MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT);
     MS_ASSERT(tensor->GetData() != nullptr);
     float bias = 0;
     switch (msCalibDataType) {
@@ -198,6 +197,10 @@ int Benchmark::CompareOutput() {
         bias = CompareData(nodeOrTensorName, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
         break;
       }
+      case TypeId::kNumberTypeUInt8: {
+        bias = CompareData(nodeOrTensorName, tensor->shape(), static_cast<uint8_t *>(tensor->MutableData()));
+        break;
+      }
       case TypeId::kNumberTypeInt32: {
         bias = CompareData(nodeOrTensorName, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
         break;
diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h
index b69a6a9bf9..e38790d9d2 100644
--- a/mindspore/lite/tools/benchmark/benchmark.h
+++ b/mindspore/lite/tools/benchmark/benchmark.h
@@ -66,7 +66,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3);
     // MarkAccuracy
     AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", "");
-    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8", "FLOAT");
+    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8 | UINT8",
+            "FLOAT");
     AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5);
   }
 
@@ -222,8 +223,10 @@ class MS_API Benchmark {
   std::vector msInputs;
   std::unordered_map> msOutputs;
   std::unordered_map calibData;
-  std::unordered_map<std::string, TypeId> dataTypeMap{
-    {"FLOAT", TypeId::kNumberTypeFloat}, {"INT8", TypeId::kNumberTypeInt8}, {"INT32", TypeId::kNumberTypeInt32}};
+  std::unordered_map<std::string, TypeId> dataTypeMap{{"FLOAT", TypeId::kNumberTypeFloat},
+                                                      {"INT8", TypeId::kNumberTypeInt8},
+                                                      {"INT32", TypeId::kNumberTypeInt32},
+                                                      {"UINT8", TypeId::kNumberTypeUInt8}};
   TypeId msCalibDataType = TypeId::kNumberTypeFloat;
 };
 
diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index 09f0406ee2..0567628114 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -23,14 +23,14 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
+          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8",
+          "FLOAT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
@@ -39,7 +39,6 @@ Flags::Flags() {
   AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
           "convWeightQuantChannelThreshold", "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
-  AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device."
           " true | false",
@@ -86,24 +85,24 @@ int Flags::Init(int argc, const char **argv) {
     this->inferenceType = TypeId::kNumberTypeFloat;
   } else if (this->inferenceTypeIn == "INT8") {
     this->inferenceType = TypeId::kNumberTypeInt8;
-  } else if (this->inferenceTypeIn == "SAME") {
-    this->inferenceType = TypeId::kTypeUnknown;
+  } else if (this->inferenceTypeIn == "UINT8") {
+    this->inferenceType = TypeId::kNumberTypeUInt8;
   } else {
-    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8",
       this->inferenceTypeIn.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
 
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
-  } else if (this->fmkIn == "MS") {
+  } else if (this->fmkIn == "MINDIR") {
     this->fmk = FmkType_MS;
   } else if (this->fmkIn == "TFLITE") {
     this->fmk = FmkType_TFLITE;
   } else if (this->fmkIn == "ONNX") {
     this->fmk = FmkType_ONNX;
   } else {
-    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
+    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
     return RET_INPUT_PARAM_INVALID;
   }
 
diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h
index 6a3bd36bfa..aa1881a253 100644
--- a/mindspore/lite/tools/converter/converter_flags.h
+++ b/mindspore/lite/tools/converter/converter_flags.h
@@ -64,7 +64,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string quantSize;
   std::string bitNum;
   std::string configFile;
-  bool formatTrans = true;
   std::string convWeightQuantChannelThreshold;
   std::string trainModelIn;
   bool trainModel = false;
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 920947d7ac..1f9ec8f689 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -137,7 +137,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
   // format transform
-  if (ctx.formatTrans) {
+  {
     Optimizer formatTransOptimizer;
     auto formatTransPass = new (std::nothrow) FormatTransPass();
     if (formatTransPass == nullptr) {
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
index e52cc50457..9c2b4b924a 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
@@ -103,7 +103,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
-  MS_ASSERT(outputDataDType == TypeId::kNumberTypeFloat);
+  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
+    return RET_ERROR;
+  }
   auto &graphOutIdxes = graph->outputIndex;
   for (auto graphOutIdx : graphOutIdxes) {
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -113,7 +116,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        if (inputDataDType == TypeId::kNumberTypeFloat) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        } else {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+        }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
           return status;
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc
index 67fd6e5d1f..49d12f6e3f 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc
@@ -46,7 +46,8 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr &tfl
       MS_LOG(ERROR) << "output tensor is null";
       return RET_NULL_PTR;
     }
-    if (GetTfliteDataType(in_tensor->type) != kNumberTypeInt8) {
+    if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
+        GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) {
       std::unique_ptr attr = std::make_unique();
       if (attr == nullptr) {
         MS_LOG(ERROR) << "new op failed";
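
Note on the int8 <-> uint8 mapping introduced in nnacl above: uint8 and int8 quantized tensors can share the same scale while their zero points differ by 128 (zp_int8 = zp_uint8 - 128), so converting between the two representations is a clamped +/-128 offset per element. The standalone C sketch below only illustrates that relationship; it is not part of the patch, and its helper names and the scale/zero-point values are hypothetical stand-ins for the DoDequantizeInt8ToUInt8 / DoQuantizeToInt8FromUint8 functions added above.

#include <stdint.h>
#include <stdio.h>

/* Illustrative only: mirrors the clamped +/-128 offset performed by the
 * new nnacl int8 <-> uint8 conversion helpers. */
static uint8_t int8_to_uint8(int8_t q) {
  int temp = q + 128; /* same scale, zero point shifted by +128 */
  if (temp > 255) return 255;
  if (temp < 0) return 0;
  return (uint8_t)temp;
}

static int8_t uint8_to_int8(uint8_t q) {
  int temp = q - 128; /* inverse shift, clamped to the int8 range */
  if (temp > 127) return 127;
  if (temp < -128) return -128;
  return (int8_t)temp;
}

int main(void) {
  /* A uint8-quantized value with scale s and zero point zp_u represents
   * real = s * (q_u - zp_u). The int8 view uses zp_i = zp_u - 128, so the
   * same real value is preserved when q_i = q_u - 128. */
  const float scale = 0.5f;
  const int zp_u = 128; /* hypothetical uint8 zero point */
  const int zp_i = zp_u - 128; /* corresponding int8 zero point */

  uint8_t q_u = 200;
  int8_t q_i = uint8_to_int8(q_u);

  printf("uint8 view: %.2f\n", scale * (q_u - zp_u)); /* 36.00 */
  printf("int8  view: %.2f\n", scale * (q_i - zp_i)); /* 36.00 */
  printf("round trip: %d\n", int8_to_uint8(q_i)); /* 200 */
  return 0;
}

This correspondence is what the converter's new inferenceType=UINT8 option and the benchmark's calibDataType=UINT8 option appear to rely on: the runtime keeps int8 data internally and only re-offsets at the graph boundary through the QuantDTypeCast changes above.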