
add inferenceType UINT8 and rectify converter params

tags/v1.0.0
cjh9368, 5 years ago
commit a161a3375a
12 changed files with 122 additions and 38 deletions
  1. mindspore/lite/nnacl/int8/quant_dtype_cast.c (+38, -2)
  2. mindspore/lite/nnacl/int8/quant_dtype_cast.h (+4, -2)
  3. mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc (+46, -13)
  4. mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h (+1, -0)
  5. mindspore/lite/test/run_benchmark_nets.sh (+2, -2)
  6. mindspore/lite/tools/benchmark/benchmark.cc (+4, -1)
  7. mindspore/lite/tools/benchmark/benchmark.h (+6, -3)
  8. mindspore/lite/tools/converter/converter_flags.cc (+9, -10)
  9. mindspore/lite/tools/converter/converter_flags.h (+0, -1)
  10. mindspore/lite/tools/converter/graphdef_transform.cc (+1, -1)
  11. mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+9, -2)
  12. mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc (+2, -1)

mindspore/lite/nnacl/int8/quant_dtype_cast.c (+38, -2)

@@ -18,7 +18,7 @@
 #include "nnacl/int8/quant_dtype_cast.h"
 #include "nnacl/errorcode.h"
 
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -29,7 +29,7 @@ int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int3
   return NNACL_OK;
 }
 
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
@@ -46,3 +46,39 @@ int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int3
   }
   return NNACL_OK;
 }
+
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = quant_values[i] + 128;
+    if (temp > 255) {
+      real_values[i] = (uint8_t)255;
+    } else if (temp < 0) {
+      real_values[i] = 0;
+    } else {
+      real_values[i] = (uint8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
+
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size) {
+  if (quant_values == NULL || real_values == NULL) {
+    return NNACL_PARAM_INVALID;
+  }
+
+  for (int i = 0; i < size; ++i) {
+    int temp = real_values[i] - 128;
+    if (temp > 127) {
+      quant_values[i] = 127;
+    } else if (temp < -128) {
+      quant_values[i] = -128;
+    } else {
+      quant_values[i] = (int8_t)temp;
+    }
+  }
+  return NNACL_OK;
+}
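
Note: the two new helpers only shift values between the signed and unsigned 8-bit quantized domains by 128 and clamp to the target range; no scale or zero point is involved. A minimal usage sketch (the sample buffer and values below are made up for illustration and are not part of this commit):

#include <stdio.h>
#include "nnacl/int8/quant_dtype_cast.h"

int main(void) {
  /* Hypothetical sample covering the full int8 range. */
  int8_t int8_vals[3] = {-128, 0, 127};
  uint8_t uint8_vals[3] = {0};

  /* int8 -> uint8: shift by +128, so -128 -> 0, 0 -> 128, 127 -> 255. */
  DoDequantizeInt8ToUInt8(int8_vals, uint8_vals, 3);

  /* uint8 -> int8: the inverse shift of -128 restores the original values. */
  DoQuantizeToInt8FromUint8(uint8_vals, int8_vals, 3);

  printf("%d %d %d\n", int8_vals[0], int8_vals[1], int8_vals[2]); /* -128 0 127 */
  return 0;
}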

mindspore/lite/nnacl/int8/quant_dtype_cast.h (+4, -2)

@@ -28,8 +28,10 @@ typedef struct QuantDTypeCastParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-int DoDequantizeInt8(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeToInt8(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToFp32(int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
+int DoQuantizeToInt8FromFp32(float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoDequantizeInt8ToUInt8(int8_t *quant_values, uint8_t *real_values, int size);
+int DoQuantizeToInt8FromUint8(uint8_t *real_values, int8_t *quant_values, int size);
 #ifdef __cplusplus
 }
 #endif


mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc (+46, -13)

@@ -53,6 +53,18 @@ int QuantDTypeCastCPUKernel::Init() {
       return RET_ERROR;
     }
     inverse_ = true;
+  } else if (param->srcT == kNumberTypeUInt8 && param->dstT == kNumberTypeInt8) {
+    if (in_tensor->data_type() != kNumberTypeUInt8 || out_tensor->data_type() != kNumberTypeInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = false;
+  } else if (param->srcT == kNumberTypeInt8 && param->dstT == kNumberTypeUInt8) {
+    if (in_tensor->data_type() != kNumberTypeInt8 || out_tensor->data_type() != kNumberTypeUInt8) {
+      MS_LOG(ERROR) << "param data type and tensor data type do not match.";
+      return RET_ERROR;
+    }
+    inverse_ = true;
   } else {
     MS_LOG(ERROR) << "param data type not supported:"
                   << " src: " << param->srcT << " dst: " << param->dstT;
@@ -83,16 +95,25 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     MS_LOG(ERROR) << "QuantDTypeCast need quantization parameters which is not found.";
     return RET_ERROR;
   }
-  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() :
-                   out_tensors_.front()->GetQuantParams().front();
+  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front()
+                                                                  : out_tensors_.front()->GetQuantParams().front();
   int ret;
-  if (inverse_) {
-    ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+  if (uint8_ptr_ == nullptr) {
+    if (inverse_) {
+      ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
+                                   quant_arg.zeroPoint, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromFp32(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
+                                     quant_arg.zeroPoint, num_unit_thread);
+    }
   } else {
-    ret = DoQuantizeToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                           quant_arg.zeroPoint, num_unit_thread);
+    if (inverse_) {
+      ret = DoDequantizeInt8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
+    } else {
+      ret = DoQuantizeToInt8FromUint8(uint8_ptr_ + thread_offset, int8_ptr_ + thread_offset, num_unit_thread);
+    }
   }
 
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
     return RET_ERROR;
@@ -116,12 +137,23 @@ int QuantDTypeCastCPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  if (inverse_) {
-    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->MutableData());
-    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
-  } else {
-    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
-    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
+
+  if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+      out_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    float32_ptr_ = reinterpret_cast<float *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeFloat32 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8) {
+    int8_ptr_ = reinterpret_cast<int8_t *>(in_tensors_[0]->data_c());
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c());
+  } else if (in_tensors_[0]->data_type() == TypeId::kNumberTypeUInt8 &&
+             out_tensors_[0]->data_type() == TypeId::kNumberTypeInt8) {
+    uint8_ptr_ = reinterpret_cast<uint8_t *>(in_tensors_[0]->data_c());
+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data_c());
   }
 
   auto ret = ParallelLaunch(THREAD_POOL_DEFAULT, QuantDTypeCastRun, this, thread_n_num_);
@@ -156,6 +188,7 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector<lite::T
   }
   return kernel;
 }
+REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_QuantDTypeCast, CpuQuantDTypeCastFp32KernelCreator)
 }  // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.h (+1, -0)

@@ -40,6 +40,7 @@ class QuantDTypeCastCPUKernel : public LiteKernel {
   int thread_n_stride_;
   int num_unit_;
   int8_t *int8_ptr_;
+  uint8_t *uint8_ptr_ = nullptr;
   float *float32_ptr_;
   bool inverse_;
 };


mindspore/lite/test/run_benchmark_nets.sh (+2, -2)

@@ -72,8 +72,8 @@ function Run_Converter() {
             continue
         fi
         echo ${model_name} >> "${run_converter_log_file}"
-        echo './converter_lite --fmk=MS --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
-        ./converter_lite --fmk=MS --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
+        echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'' >> "${run_converter_log_file}"
+        ./converter_lite --fmk=MINDIR --modelFile=${models_path}/${model_name} --outputFile=${ms_models_path}/${model_name}
        if [ $? = 0 ]; then
            converter_result='converter mindspore '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
        else


mindspore/lite/tools/benchmark/benchmark.cc (+4, -1)

@@ -186,7 +186,6 @@ int Benchmark::CompareOutput() {
       } else {
         tensor = tensors.front();
       }
-      MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT);
       MS_ASSERT(tensor->GetData() != nullptr);
       float bias = 0;
       switch (msCalibDataType) {
@@ -198,6 +197,10 @@
           bias = CompareData<int8_t>(nodeOrTensorName, tensor->shape(), static_cast<int8_t *>(tensor->MutableData()));
           break;
         }
+        case TypeId::kNumberTypeUInt8: {
+          bias = CompareData<uint8_t>(nodeOrTensorName, tensor->shape(), static_cast<uint8_t *>(tensor->MutableData()));
+          break;
+        }
         case TypeId::kNumberTypeInt32: {
           bias = CompareData<int32_t>(nodeOrTensorName, tensor->shape(), static_cast<int32_t *>(tensor->MutableData()));
           break;


mindspore/lite/tools/benchmark/benchmark.h (+6, -3)

@@ -66,7 +66,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
     AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3);
     // MarkAccuracy
     AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", "");
-    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8", "FLOAT");
+    AddFlag(&BenchmarkFlags::calibDataType, "calibDataType", "Calibration data type. FLOAT | INT32 | INT8 | UINT8",
+            "FLOAT");
     AddFlag(&BenchmarkFlags::accuracyThreshold, "accuracyThreshold", "Threshold of accuracy", 0.5);
   }
 
@@ -222,8 +223,10 @@ class MS_API Benchmark {
   std::vector<mindspore::tensor::MSTensor *> msInputs;
   std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> msOutputs;
   std::unordered_map<std::string, CheckTensor *> calibData;
-  std::unordered_map<std::string, TypeId> dataTypeMap{
-      {"FLOAT", TypeId::kNumberTypeFloat}, {"INT8", TypeId::kNumberTypeInt8}, {"INT32", TypeId::kNumberTypeInt32}};
+  std::unordered_map<std::string, TypeId> dataTypeMap{{"FLOAT", TypeId::kNumberTypeFloat},
+                                                      {"INT8", TypeId::kNumberTypeInt8},
+                                                      {"INT32", TypeId::kNumberTypeInt32},
+                                                      {"UINT8", TypeId::kNumberTypeUInt8}};
   TypeId msCalibDataType = TypeId::kNumberTypeFloat;
 };




mindspore/lite/tools/converter/converter_flags.cc (+9, -10)

@@ -23,14 +23,14 @@ namespace mindspore {
 namespace lite {
 namespace converter {
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MS", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
   AddFlag(&Flags::modelFile, "modelFile",
-          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MS: *.mindir | ONNX: *.onnx", "");
+          "Input model file path. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. SAME | FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType", "Data type of input and output tensors. FLOAT | INT8 | UINT8",
+          "FLOAT");
   AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining | WeightQuant", "");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
@@ -39,7 +39,6 @@ Flags::Flags() {
   AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold", "convWeightQuantChannelThreshold",
           "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
-  AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device."
           " true | false",
@@ -86,24 +85,24 @@ int Flags::Init(int argc, const char **argv) {
     this->inferenceType = TypeId::kNumberTypeFloat;
   } else if (this->inferenceTypeIn == "INT8") {
     this->inferenceType = TypeId::kNumberTypeInt8;
-  } else if (this->inferenceTypeIn == "SAME") {
-    this->inferenceType = TypeId::kTypeUnknown;
+  } else if (this->inferenceTypeIn == "UINT8") {
+    this->inferenceType = TypeId::kNumberTypeUInt8;
   } else {
-    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | SAME",
+    std::cerr << "INPUT INVALID: inferenceType is invalid: %s, supported inferenceType: FLOAT | INT8 | UINT8",
       this->inferenceTypeIn.c_str();
     return RET_INPUT_PARAM_INVALID;
   }
 
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
-  } else if (this->fmkIn == "MS") {
+  } else if (this->fmkIn == "MINDIR") {
     this->fmk = FmkType_MS;
   } else if (this->fmkIn == "TFLITE") {
     this->fmk = FmkType_TFLITE;
   } else if (this->fmkIn == "ONNX") {
     this->fmk = FmkType_ONNX;
   } else {
-    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MS|ONNX";
+    std::cerr << "INPUT ILLEGAL: fmk must be TFLITE|CAFFE|MINDIR|ONNX";
     return RET_INPUT_PARAM_INVALID;
   }




mindspore/lite/tools/converter/converter_flags.h (+0, -1)

@@ -64,7 +64,6 @@ class Flags : public virtual mindspore::lite::FlagParser {
   std::string quantSize;
   std::string bitNum;
   std::string configFile;
-  bool formatTrans = true;
   std::string convWeightQuantChannelThreshold;
   std::string trainModelIn;
   bool trainModel = false;


mindspore/lite/tools/converter/graphdef_transform.cc (+1, -1)

@@ -137,7 +137,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     }
   }
   // format transform
-  if (ctx.formatTrans) {
+  {
     Optimizer formatTransOptimizer;
     auto formatTransPass = new (std::nothrow) FormatTransPass();
     if (formatTransPass == nullptr) {


mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc (+9, -2)

@@ -103,7 +103,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
-  MS_ASSERT(outputDataDType == TypeId::kNumberTypeFloat);
+  if (this->outputDataDType != TypeId::kNumberTypeFloat && this->outputDataDType != TypeId::kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType;
+    return RET_ERROR;
+  }
   auto &graphOutIdxes = graph->outputIndex;
   for (auto graphOutIdx : graphOutIdxes) {
     for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -113,7 +116,11 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
       if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) {
         // insert transNode
         STATUS status = RET_OK;
-        iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        if (inputDataDType == TypeId::kNumberTypeFloat) {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
+        } else {
+          iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
+        }
         if (status != RET_OK) {
           MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
           return status;


mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc (+2, -1)

@@ -46,7 +46,8 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl
     MS_LOG(ERROR) << "output tensor is null";
     return RET_NULL_PTR;
   }
-  if (GetTfliteDataType(in_tensor->type) != kNumberTypeInt8) {
+  if (GetTfliteDataType(out_tensor->type) == kNumberTypeInt8 ||
+      GetTfliteDataType(out_tensor->type) == kNumberTypeUInt8) {
     std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
     if (attr == nullptr) {
       MS_LOG(ERROR) << "new op failed";

