From c688a256476dd9c4a9fdca487557898f89fd9bfe Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 24 Oct 2020 14:38:58 +0800 Subject: [PATCH] bug fix --- mindspore/lite/src/kernel_registry.cc | 2 + .../lite/tools/converter/converter_context.h | 22 +++++++++ .../tools/converter/graphdef_transform.cc | 8 ++-- .../graph/dtype_trans_pass.cc | 29 ++++++------ .../graph/tensor_quant_pass.cc | 45 ++++++++++++++----- .../parser/tflite/tflite_dequantize_parser.cc | 2 +- .../parser/tflite/tflite_model_parser.cc | 7 --- .../parser/tflite/tflite_quantize_parser.cc | 2 +- 8 files changed, 79 insertions(+), 38 deletions(-) diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index 1cd0826e99..4ed813a482 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -120,6 +120,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector &in_te kernel->set_desc(key); } return kernel; + } else { + free(parameter); } return nullptr; } diff --git a/mindspore/lite/tools/converter/converter_context.h b/mindspore/lite/tools/converter/converter_context.h index c179c7008c..6cbcdc7da1 100644 --- a/mindspore/lite/tools/converter/converter_context.h +++ b/mindspore/lite/tools/converter/converter_context.h @@ -19,8 +19,10 @@ #include #include +#include #include "include/errorcode.h" #include "src/common/log_adapter.h" +#include "ir/dtype/type_id.h" namespace mindspore { namespace lite { @@ -68,6 +70,26 @@ class NoSupportOp { std::set noSupportOps; std::string fmkType; }; + +class TensorDataType { + public: + ~TensorDataType() = default; + static TensorDataType *GetInstance() { + static TensorDataType tensorDataType; + return &tensorDataType; + } + void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; } + int32_t GetTensorType(int32_t index) const { + if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) { + return TypeId::kTypeUnknown; + } + return 
tensorDataTypeMap.at(index); + } + + private: + TensorDataType() {} + std::map tensorDataTypeMap; +}; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 74371e900d..36e46fdfa5 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -132,9 +132,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // do quantization { - Optimizer fusionOptimizer; - fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); - status = fusionOptimizer.Run(graphDefT); + Optimizer tensorQuantOptimizer; + tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); + tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); + status = tensorQuantOptimizer.Run(graphDefT); if (status != RET_OK) { MS_LOG(ERROR) << "DoQuantize failed!"; return status; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index 5c84694b24..d1fa82617d 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -18,6 +18,7 @@ #include #include #include "tools/common/node_util.h" +#include "tools/converter/converter_context.h" #include "src/common/common.h" #include "src/common/utils.h" @@ -52,12 +53,8 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); auto &graphInIdxes = graph->inputIndex; - - if (this->inputDataDType == TypeId::kTypeUnknown) { - return RET_OK; - } if (this->inputDataDType != TypeId::kNumberTypeFloat32 && 
this->inputDataDType != TypeId::kNumberTypeUInt8 && - this->inputDataDType != TypeId::kNumberTypeInt8) { + this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) { MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; return RET_ERROR; } @@ -67,7 +64,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } - + int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown + ? this->inputDataDType + : TensorDataType::GetInstance()->GetTensorType(graphInIdx); for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto nodeName = (*iter)->name; for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { @@ -75,9 +74,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { STATUS status = RET_OK; // insert dtype cast node between input tensor and input node - if (this->inputDataDType != tensor->dataType) { - iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType, - &status); + if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { + iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status); } if (status != RET_OK) { @@ -93,11 +91,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - if (outputDataDType == TypeId::kTypeUnknown) { - return RET_OK; - } if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && - this->outputDataDType != TypeId::kNumberTypeInt8) { + this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) { MS_LOG(ERROR) << "Invalid outputDataType: " << 
this->outputDataDType; return RET_ERROR; } @@ -108,6 +103,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } + int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown + ? this->outputDataDType + : TensorDataType::GetInstance()->GetTensorType(graphOutIdx); for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto nodeName = (*iter)->name; MS_ASSERT(node != nullptr); @@ -115,9 +113,8 @@ if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { // insert transNode STATUS status = RET_OK; - if (this->outputDataDType != tensor->dataType) { - iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType, - &status); + if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { + iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status); } if (status != RET_OK) { MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc index 8170c9da40..237d25de1c 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc @@ -17,23 +17,42 @@ #include #include #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" +#include "tools/converter/converter_context.h" #include "tools/converter/quantizer/quantize_util.h" #include "tools/common/tensor_util.h" namespace mindspore::lite { STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { + for (auto &node : graph->nodes) { + if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) { + auto 
attr = node->primitive->value.AsQuantDTypeCast(); + auto &inputTensor = graph->allTensors.at(node->inputIndex.front()); + inputTensor->dataType = attr->srcT; + auto &outputTensor = graph->allTensors.at(node->outputIndex.front()); + outputTensor->dataType = attr->dstT; + + if (attr->srcT == TypeId::kNumberTypeUInt8) { + attr->srcT = TypeId::kNumberTypeInt8; + } + if (attr->dstT == TypeId::kNumberTypeUInt8) { + attr->dstT = TypeId::kNumberTypeInt8; + } + } + } + int index = -1; for (auto &tensor : graph->allTensors) { - if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { + index++; + if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { continue; } if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && - tensor->dataType != TypeId::kNumberTypeUInt8) { + tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) { continue; } // perlayer if (tensor->quantParams.size() == 1) { auto &quantParam = tensor->quantParams.front(); - size_t wShapeSize = GetShapeSize(*(tensor.get())); + size_t wShapeSize = tensor->data.empty() ? 
0 : GetShapeSize(*(tensor.get())); void *oriWeightData = tensor->data.data(); if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { std::vector qDatas(wShapeSize); @@ -41,6 +60,9 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { if (tensor->dataType == TypeId::kNumberTypeFloat || tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant auto *weightData = static_cast(oriWeightData); + if (weightData == nullptr) { + continue; + } for (size_t j = 0; j < wShapeSize; j++) { qDatas[j] = quant::QuantizeData(weightData[j], weightQauntParam.get()); } @@ -52,15 +74,18 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { weightQauntParam->zeroPoint -= 128; tensor->quantParams.clear(); tensor->quantParams.emplace_back(weightQauntParam.release()); + TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8); } tensor->dataType = TypeId::kNumberTypeInt8; - tensor->data.clear(); - tensor->data.resize(wShapeSize * sizeof(int8_t)); - auto ret = - memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed: " << ret; - return RET_ERROR; + if (!tensor->data.empty()) { + tensor->data.clear(); + tensor->data.resize(wShapeSize * sizeof(int8_t)); + auto ret = + memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; + } } } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { // quant bias data diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc index b93bb749ae..5bdd03e289 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_dequantize_parser.cc @@ -53,7 +53,7 @@ STATUS 
TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - attr->srcT = kNumberTypeInt8; + attr->srcT = GetTfliteDataType(in_tensor->type); attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.value = attr.release(); op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 898675707a..220163a867 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -76,12 +76,6 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptrzeroPoint = tflite_tensor->quantization->zero_point[i]; } - // change quant param min to 0 to fit ms-lite ops - if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { - quant_param->zeroPoint = quant_param->zeroPoint - 128; - tensor->dataType = TypeId::kNumberTypeInt8; - } - if (!tflite_tensor->quantization->min.empty()) { quant_param->min = tflite_tensor->quantization->min[i]; } @@ -127,7 +121,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit } continue; } - sub_graph->nodes.emplace_back(op.release()); opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index 901aa6afa1..0d9910de39 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -53,7 +53,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u return RET_NULL_PTR; } attr->srcT = 
GetTfliteDataType(in_tensor->type); - attr->dstT = kNumberTypeInt8; + attr->dstT = GetTfliteDataType(out_tensor->type); op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; op->primitive->value.value = attr.release(); } else {