Merge pull request !6369 from cjh9368/quant_same_datatype
@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast.h"
 #include "nnacl/errorcode.h"
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -29,7 +29,7 @@ int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int3
   return NNACL_OK;
 }
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -46,3 +46,39 @@ int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int3
   }
   return NNACL_OK;
 }
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (int i = 0; i < size; ++i) {
+    int temp = quant_values[i] + 128;
+    if (temp > 255) {
+      real_values[i] = (uint8_t)255;
+    } else if (temp < 0) {
+      real_values[i] = 0;
+    } else {
+      real_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+  for (int i = 0; i < size; ++i) {
+    int temp = real_values[i] - 128;
+    if (temp > 127) {
+      quant_values[i] = 127;
+    } else if (temp < -128) {
+      quant_values[i] = -128;
+    } else {
+      quant_values[i] = (int8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
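Note: DoDequantizeInt8ToUInt8 and DoQuantizeToInt8FromUint8 are pure re-biasing, since an int8 and a uint8 quantized tensor share the same scale and their zero points differ by 128; the cast is just value +/- 128 with saturation. A minimal sketch of exercising them (assuming the nnacl headers above are on the include path and quant_dtype_cast.c is linked in):

    #include <cstdint>
    #include <cstdio>
    #include "nnacl/errorcode.h"
    #include "nnacl/int8/quant_dtype_cast.h"

    int main() {
      int8_t s8[4] = {-128, -1, 0, 127};
      uint8_t u8[4] = {0};
      int8_t back[4] = {0};
      // int8 -> uint8: add 128, saturate to [0, 255].
      if (DoDequantizeInt8ToUInt8(s8, u8, 4) != NNACL_OK) return 1;
      // uint8 -> int8: subtract 128, saturate to [-128, 127].
      if (DoQuantizeToInt8FromUint8(u8, back, 4) != NNACL_OK) return 1;
      for (int i = 0; i < 4; ++i) {
        printf("%d -> %d -> %d\n", s8[i], u8[i], back[i]);  // each value round-trips exactly
      }
      return 0;
    }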
@@ -28,8 +28,10 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif
@@ -53,6 +53,18 @@ int QuantDTypeCastCPUKernel::Init() {
       return RET_ERROR;
     }
     inverse_ = true;
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = false;
+  } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -86,13 +98,22 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
   auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front()
                                                                   : out_tensors_.front()->GetQuantParams().front();
   int ret;
-  if (inverse_) {
-    ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+  if (uint8_ptr_ == nullptr) {
+    if (inverse_) {
+      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                     quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+    if (inverse_) {
+      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    }
   }
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
     return RET_ERROR;
@@ -116,12 +137,23 @@ int QuantDTypeCastCPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  if (inverse_) {
-    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
-  } else {
-    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
-    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   }
   auto ret = ParallelLaunch(this->context_->thread_pool_, QuantDTypeCastRun, this, thread_n_num_);
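Note: taken together, Init() picks inverse_ from the (srcT, dstT) pair, Run() binds the int8/uint8/float32 buffers from the tensor data types, and QuantDTypeCast() selects the nnacl helper from uint8_ptr_ and inverse_. A hypothetical helper restating that dispatch (written only for readability, not part of the kernel):

    #include <string>

    // Which nnacl routine handles a given (input dtype, output dtype) pair,
    // and the inverse_ value Init() assigns for it.
    std::string SelectHelper(const std::string &src, const std::string &dst) {
      if (src == "int8" && dst == "float32") return "inverse_=true,  DoDequantizeInt8ToFp32";
      if (src == "float32" && dst == "int8") return "inverse_=false, DoQuantizeToInt8FromFp32";
      if (src == "int8" && dst == "uint8") return "inverse_=true,  DoDequantizeInt8ToUInt8";
      if (src == "uint8" && dst == "int8") return "inverse_=false, DoQuantizeToInt8FromUint8";
      return "unsupported";
    }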
@@ -156,6 +188,7 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::T
   }
   return kernel;
 }
+REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 }  // namespace mindspore::kernel
@@ -40,6 +40,7 @@ class QuantDTypeCastCPUKernel : public LiteKernel {
   int thread_n_stride_;
   int num_unit_;
   int8_t *int8_ptr_;
+  uint8_t *uint8_ptr_ = nullptr;
   float *float32_ptr_;
   bool inverse_;
 };
@@ -72,8 +72,8 @@ function Run_Converter() {
             continue
         fi
         echo ${model_name} >> "${run_converter_log_file}"
-        echo './converter_lite --fmk=MS --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
-        ./converter_lite --fmk=MS --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
+        echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
+        ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
         if [ $? = 0 ]; then
             converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
         else
@@ -186,7 +186,6 @@ int Benchmark::CompareOutput() {
       } else {
         tensor = tensors.front();
       }
-      MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT);
       MS_ASSERT(tensor->GetData() != nullptr);
       float bias = 0;
       switch (msCalibDataType) {
@@ -198,6 +197,10 @@ int Benchmark::CompareOutput() {
           bias = CompareData<int8_t>(nodeOrTensorName, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
           break;
         }
+        case TypeId::kNumberTypeUInt8: {
+          bias = CompareData<uint8_t>(nodeOrTensorName, tensor->shape(), static_cast<uint8_t *>(tensor->MutableData()));
+          break;
+        }
         case TypeId::kNumberTypeInt32: {
           bias = CompareData<int32_t>(nodeOrTensorName, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
           break;
@@ -66,7 +66,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3);
     // MarkAccuracy
     AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", "");
-    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8", "FLOAT");
+    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8 | UINT8",
+            "FLOAT");
     AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5);
   }
@@ -222,8 +223,10 @@ class MS_API Benchmark {
   std::vector<mindspore::tensor::MSTensor *> msInputs;
   std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
   std::unordered_map<std::string, CheckTensor *> calibData;
-  std::unordered_map<std::string, TypeId> dataTypeMap{
-      {"FLOAT", TypeId::kNumberTypeFloat}, {"INT8", TypeId::kNumberTypeInt8}, {"INT32", TypeId::kNumberTypeInt32}};
+  std::unordered_map<std::string, TypeId> dataTypeMap{{"FLOAT", TypeId::kNumberTypeFloat},
+                                                      {"INT8", TypeId::kNumberTypeInt8},
+                                                      {"INT32", TypeId::kNumberTypeInt32},
+                                                      {"UINT8", TypeId::kNumberTypeUInt8}};
   TypeId msCalibDataType = TypeId::kNumberTypeFloat;
 };
@@ -23,14 +23,14 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
+          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8",
+          "FLOAT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
@@ -39,7 +39,6 @@ Flags::Flags() {
   AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold", "convWeightQuantChannelThreshold",
           "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
-  AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device."
           " true | false",
@@ -86,24 +85,24 @@ int Flags::Init(int argc, const char **argv) {
     this->inferenceType = TypeId::kNumberTypeFloat;
   } else if (this->inferenceTypeIn == "INT8") {
     this->inferenceType = TypeId::kNumberTypeInt8;
-  } else if (this->inferenceTypeIn == "SAME") {
-    this->inferenceType = TypeId::kTypeUnknown;
+  } else if (this->inferenceTypeIn == "UINT8") {
+    this->inferenceType = TypeId::kNumberTypeUInt8;
   } else {
-    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8",
       this->inferenceTypeIn.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
-  } else if (this->fmkIn == "MS") {
+  } else if (this->fmkIn == "MINDIR") {
     this->fmk = FmkType_MS;
   } else if (this->fmkIn == "TFLITE") {
     this->fmk = FmkType_TFLITE;
   } else if (this->fmkIn == "ONNX") {
     this->fmk = FmkType_ONNX;
   } else {
-    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
+    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
     return RET_INPUT_PARAM_INVALID;
   }
@@ -64,7 +64,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string quantSize;
   std::string bitNum;
   std::string configFile;
-  bool formatTrans = true;
   std::string convWeightQuantChannelThreshold;
   std::string trainModelIn;
   bool trainModel = false;
@@ -137,7 +137,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
   // format transform
-  if (ctx.formatTrans) {
+  {
     Optimizer formatTransOptimizer;
     auto formatTransPass = new (std::nothrow) FormatTransPass();
     if (formatTransPass == nullptr) {
@@ -103,7 +103,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
-  MS_ASSERT(outputDataDType == TypeId::kNumberTypeFloat);
+  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
+    return RET_ERROR;
+  }
   auto &graphOutIdxes = graph->outputIndex;
   for (auto graphOutIdx : graphOutIdxes) {
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -113,7 +116,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        if (inputDataDType == TypeId::kNumberTypeFloat) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        } else {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+        }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
           return status;
@@ -46,7 +46,8 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
     MS_LOG(ERROR) << "output tensor is null";
     return RET_NULL_PTR;
   }
-  if (GetTfliteDataType(in_tensor->type) != kNumberTypeInt8) {
+  if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
+      GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) {
     std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
     if (attr == nullptr) {
       MS_LOG(ERROR) << "new op failed";