@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast.h"
 #include "nnacl/errorcode.h"
 
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }

@@ -29,7 +29,7 @@ int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int3
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }

@@ -46,3 +46,39 @@ int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int3
   }
   return NNACL_OK;
 }
+
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = quant_values[i] + 128;
+    if (temp > 255) {
+      real_values[i] = (uint8_t)255;
+    } else if (temp < 0) {
+      real_values[i] = 0;
+    } else {
+      real_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = real_values[i] - 128;
+    if (temp > 127) {
+      quant_values[i] = 127;
+    } else if (temp < -128) {
+      quant_values[i] = -128;
+    } else {
+      quant_values[i] = (int8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
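The two new helpers only shift between the signed and unsigned 8-bit representations (an offset of 128) with saturation, so a uint8 -> int8 -> uint8 round trip is lossless. A minimal usage sketch; the buffer contents and size are illustrative, not taken from the patch:

    #include <stdint.h>
    #include <stdio.h>
    #include "nnacl/int8/quant_dtype_cast.h"
    #include "nnacl/errorcode.h"

    int main(void) {
      uint8_t src[4] = {0, 127, 128, 255};  /* hypothetical uint8 activations */
      int8_t mid[4];
      uint8_t back[4];

      /* uint8 -> int8: subtract 128, clamp to [-128, 127] */
      if (DoQuantizeToInt8FromUint8(src, mid, 4) != NNACL_OK) {
        return 1;
      }
      /* int8 -> uint8: add 128, clamp to [0, 255] */
      if (DoDequantizeInt8ToUInt8(mid, back, 4) != NNACL_OK) {
        return 1;
      }
      for (int i = 0; i < 4; ++i) {
        /* prints e.g. "0 -> -128 -> 0" */
        printf("%u -> %d -> %u\n", (unsigned)src[i], (int)mid[i], (unsigned)back[i]);
      }
      return 0;
    }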
@@ -28,8 +28,10 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif

@@ -53,6 +53,18 @@ int QuantDTypeCastCPUKernel::Init() {
       return RET_ERROR;
     }
     inverse_ = true;
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = false;
+  } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -83,16 +95,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
     return RET_ERROR;
   }
-  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() :
-                   out_tensors_.front()->GetQuantParams().front();
+  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front()
+                                                                  : out_tensors_.front()->GetQuantParams().front();
   int ret;
-  if (inverse_) {
-    ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+  if (uint8_ptr_ == nullptr) {
+    if (inverse_) {
+      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                     quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+    if (inverse_) {
+      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    }
   }
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
     return RET_ERROR;

@@ -116,12 +137,23 @@ int QuantDTypeCastCPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  if (inverse_) {
-    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
-  } else {
-    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
-    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   }
   auto ret = ParallelLaunch(THREAD_POOL_DEFAULT, QuantDTypeCastRun, this, thread_n_num_);

@@ -156,6 +188,7 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::T
   }
   return kernel;
 }
+REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 }  // namespace mindspore::kernel

@@ -40,6 +40,7 @@ class QuantDTypeCastCPUKernel : public LiteKernel {
   int thread_n_stride_;
   int num_unit_;
   int8_t *int8_ptr_;
+  uint8_t *uint8_ptr_ = nullptr;
   float *float32_ptr_;
   bool inverse_;
 };

@@ -72,8 +72,8 @@ function Run_Converter() {
             continue
         fi
         echo ${model_name} >> "${run_converter_log_file}"
-        echo './converter_lite --fmk=MS --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
-        ./converter_lite --fmk=MS --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
+        echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
+        ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
        if [ $? = 0 ]; then
            converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
        else

@@ -186,7 +186,6 @@ int Benchmark::CompareOutput() {
       } else {
         tensor = tensors.front();
       }
-      MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT);
       MS_ASSERT(tensor->GetData() != nullptr);
       float bias = 0;
       switch (msCalibDataType) {

@@ -198,6 +197,10 @@ int Benchmark::CompareOutput() {
          bias = CompareData<int8_t>(nodeOrTensorName, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
          break;
        }
+        case TypeId::kNumberTypeUInt8: {
+          bias = CompareData<uint8_t>(nodeOrTensorName, tensor->shape(), static_cast<uint8_t *>(tensor->MutableData()));
+          break;
+        }
        case TypeId::kNumberTypeInt32: {
          bias = CompareData<int32_t>(nodeOrTensorName, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
          break;

@@ -66,7 +66,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3);
     // MarkAccuracy
     AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", "");
-    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8", "FLOAT");
+    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8 | UINT8",
+            "FLOAT");
     AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5);
   }

@@ -222,8 +223,10 @@ class MS_API Benchmark {
   std::vector<mindspore::tensor::MSTensor *> msInputs;
   std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
   std::unordered_map<std::string, CheckTensor *> calibData;
-  std::unordered_map<std::string, TypeId> dataTypeMap{
-    {"FLOAT", TypeId::kNumberTypeFloat}, {"INT8", TypeId::kNumberTypeInt8}, {"INT32", TypeId::kNumberTypeInt32}};
+  std::unordered_map<std::string, TypeId> dataTypeMap{{"FLOAT", TypeId::kNumberTypeFloat},
+                                                      {"INT8", TypeId::kNumberTypeInt8},
+                                                      {"INT32", TypeId::kNumberTypeInt32},
+                                                      {"UINT8", TypeId::kNumberTypeUInt8}};
   TypeId msCalibDataType = TypeId::kNumberTypeFloat;
 };
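With the UINT8 entry added, the --calibDataType string can resolve to TypeId::kNumberTypeUInt8 through this map. The actual flag-handling code is not part of this patch, so the following is only an assumed sketch of the lookup pattern, using the dataTypeMap and msCalibDataType members declared above:

    // Hypothetical illustration only; names other than dataTypeMap and
    // msCalibDataType are placeholders.
    std::string calib_type_str = "UINT8";  // e.g. value passed to --calibDataType
    auto iter = dataTypeMap.find(calib_type_str);
    if (iter == dataTypeMap.end()) {
      std::cerr << "unsupported calibDataType: " << calib_type_str << std::endl;
    } else {
      msCalibDataType = iter->second;  // TypeId::kNumberTypeUInt8 here
    }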
@@ -23,14 +23,14 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
+          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8",
+          "FLOAT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");

@@ -39,7 +39,6 @@ Flags::Flags() {
   AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold", "convWeightQuantChannelThreshold",
           "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
-  AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device."
           " true | false",

@@ -86,24 +85,24 @@ int Flags::Init(int argc, const char **argv) {
     this->inferenceType = TypeId::kNumberTypeFloat;
   } else if (this->inferenceTypeIn == "INT8") {
     this->inferenceType = TypeId::kNumberTypeInt8;
-  } else if (this->inferenceTypeIn == "SAME") {
-    this->inferenceType = TypeId::kTypeUnknown;
+  } else if (this->inferenceTypeIn == "UINT8") {
+    this->inferenceType = TypeId::kNumberTypeUInt8;
   } else {
-    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8",
       this->inferenceTypeIn.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
 
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
-  } else if (this->fmkIn == "MS") {
+  } else if (this->fmkIn == "MINDIR") {
     this->fmk = FmkType_MS;
   } else if (this->fmkIn == "TFLITE") {
     this->fmk = FmkType_TFLITE;
   } else if (this->fmkIn == "ONNX") {
     this->fmk = FmkType_ONNX;
   } else {
-    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
+    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
     return RET_INPUT_PARAM_INVALID;
   }

@@ -64,7 +64,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string quantSize;
   std::string bitNum;
   std::string configFile;
-  bool formatTrans = true;
   std::string convWeightQuantChannelThreshold;
   std::string trainModelIn;
   bool trainModel = false;

@@ -137,7 +137,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
   // format transform
-  if (ctx.formatTrans) {
+  {
     Optimizer formatTransOptimizer;
     auto formatTransPass = new (std::nothrow) FormatTransPass();
     if (formatTransPass == nullptr) {

@@ -103,7 +103,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
-  MS_ASSERT(outputDataDType == TypeId::kNumberTypeFloat);
+  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
+    return RET_ERROR;
+  }
   auto &graphOutIdxes = graph->outputIndex;
   for (auto graphOutIdx : graphOutIdxes) {
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {

@@ -113,7 +116,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        if (inputDataDType == TypeId::kNumberTypeFloat) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        } else {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+        }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
           return status;

@@ -46,7 +46,8 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
     MS_LOG(ERROR) << "output tensor is null";
     return RET_NULL_PTR;
   }
-  if (GetTfliteDataType(in_tensor->type) != kNumberTypeInt8) {
+  if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
+      GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) {
     std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
     if (attr == nullptr) {
       MS_LOG(ERROR) << "new op failed";