| @@ -120,6 +120,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te | |||||
| kernel->set_desc(key); | kernel->set_desc(key); | ||||
| } | } | ||||
| return kernel; | return kernel; | ||||
| } else { | |||||
| free(parameter); | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -19,8 +19,10 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <set> | #include <set> | ||||
| #include <map> | |||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "ir/dtype/type_id.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -68,6 +70,26 @@ class NoSupportOp { | |||||
| std::set<std::string> noSupportOps; | std::set<std::string> noSupportOps; | ||||
| std::string fmkType; | std::string fmkType; | ||||
| }; | }; | ||||
| class TensorDataType { | |||||
| public: | |||||
| ~TensorDataType() = default; | |||||
| static TensorDataType *GetInstance() { | |||||
| static TensorDataType tensorDataType; | |||||
| return &tensorDataType; | |||||
| } | |||||
| void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; } | |||||
| int32_t GetTensorType(int32_t index) const { | |||||
| if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) { | |||||
| return TypeId::kTypeUnknown; | |||||
| } | |||||
| return tensorDataTypeMap.at(index); | |||||
| } | |||||
| private: | |||||
| TensorDataType() {} | |||||
| std::map<int32_t, int32_t> tensorDataTypeMap; | |||||
| }; | |||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H | ||||
| @@ -132,9 +132,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||||
| // do quantization | // do quantization | ||||
| { | { | ||||
| Optimizer fusionOptimizer; | |||||
| fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||||
| status = fusionOptimizer.Run(graphDefT); | |||||
| Optimizer tensorQuantOptimizer; | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||||
| status = tensorQuantOptimizer.Run(graphDefT); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "DoQuantize failed!"; | MS_LOG(ERROR) << "DoQuantize failed!"; | ||||
| return status; | return status; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <set> | #include <set> | ||||
| #include "tools/common/node_util.h" | #include "tools/common/node_util.h" | ||||
| #include "tools/converter/converter_context.h" | |||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -52,12 +53,8 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||||
| STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| auto &graphInIdxes = graph->inputIndex; | auto &graphInIdxes = graph->inputIndex; | ||||
| if (this->inputDataDType == TypeId::kTypeUnknown) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 && | if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 && | ||||
| this->inputDataDType != TypeId::kNumberTypeInt8) { | |||||
| this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; | MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -67,7 +64,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown | |||||
| ? this->inputDataDType | |||||
| : TensorDataType::GetInstance()->GetTensorType(graphInIdx); | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| auto nodeName = (*iter)->name; | auto nodeName = (*iter)->name; | ||||
| for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { | for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { | ||||
| @@ -75,9 +74,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| STATUS status = RET_OK; | STATUS status = RET_OK; | ||||
| // insert dtype cast node between input tensor and input node | // insert dtype cast node between input tensor and input node | ||||
| if (this->inputDataDType != tensor->dataType) { | |||||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType, | |||||
| &status); | |||||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| @@ -93,11 +91,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| if (outputDataDType == TypeId::kTypeUnknown) { | |||||
| return RET_OK; | |||||
| } | |||||
| if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && | if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && | ||||
| this->outputDataDType != TypeId::kNumberTypeInt8) { | |||||
| this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) { | |||||
| MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType; | MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -108,6 +103,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown | |||||
| ? this->outputDataDType | |||||
| : TensorDataType::GetInstance()->GetTensorType(graphOutIdx); | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| auto nodeName = (*iter)->name; | auto nodeName = (*iter)->name; | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| @@ -115,9 +113,8 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { | if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { | ||||
| // insert transNode | // insert transNode | ||||
| STATUS status = RET_OK; | STATUS status = RET_OK; | ||||
| if (this->outputDataDType != tensor->dataType) { | |||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType, | |||||
| &status); | |||||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status); | |||||
| } | } | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; | MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; | ||||
| @@ -17,23 +17,42 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | ||||
| #include "tools/converter/converter_context.h" | |||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #include "tools/common/tensor_util.h" | #include "tools/common/tensor_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | ||||
| for (auto &node : graph->nodes) { | |||||
| if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) { | |||||
| auto attr = node->primitive->value.AsQuantDTypeCast(); | |||||
| auto &inputTensor = graph->allTensors.at(node->inputIndex.front()); | |||||
| inputTensor->dataType = attr->srcT; | |||||
| auto &outputTensor = graph->allTensors.at(node->outputIndex.front()); | |||||
| outputTensor->dataType = attr->dstT; | |||||
| if (attr->srcT == TypeId::kNumberTypeUInt8) { | |||||
| attr->srcT = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| if (attr->dstT == TypeId::kNumberTypeUInt8) { | |||||
| attr->dstT = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| } | |||||
| } | |||||
| int index = -1; | |||||
| for (auto &tensor : graph->allTensors) { | for (auto &tensor : graph->allTensors) { | ||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||||
| index++; | |||||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | ||||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||||
| tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| // perlayer | // perlayer | ||||
| if (tensor->quantParams.size() == 1) { | if (tensor->quantParams.size() == 1) { | ||||
| auto &quantParam = tensor->quantParams.front(); | auto &quantParam = tensor->quantParams.front(); | ||||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||||
| size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get())); | |||||
| void *oriWeightData = tensor->data.data(); | void *oriWeightData = tensor->data.data(); | ||||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | ||||
| std::vector<int8_t> qDatas(wShapeSize); | std::vector<int8_t> qDatas(wShapeSize); | ||||
| @@ -41,6 +60,9 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | if (tensor->dataType == TypeId::kNumberTypeFloat || | ||||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | ||||
| auto *weightData = static_cast<float *>(oriWeightData); | auto *weightData = static_cast<float *>(oriWeightData); | ||||
| if (weightData == nullptr) { | |||||
| continue; | |||||
| } | |||||
| for (size_t j = 0; j < wShapeSize; j++) { | for (size_t j = 0; j < wShapeSize; j++) { | ||||
| qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | ||||
| } | } | ||||
| @@ -52,15 +74,18 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||||
| weightQauntParam->zeroPoint -= 128; | weightQauntParam->zeroPoint -= 128; | ||||
| tensor->quantParams.clear(); | tensor->quantParams.clear(); | ||||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | tensor->quantParams.emplace_back(weightQauntParam.release()); | ||||
| TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8); | |||||
| } | } | ||||
| tensor->dataType = TypeId::kNumberTypeInt8; | tensor->dataType = TypeId::kNumberTypeInt8; | ||||
| tensor->data.clear(); | |||||
| tensor->data.resize(wShapeSize * sizeof(int8_t)); | |||||
| auto ret = | |||||
| memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return RET_ERROR; | |||||
| if (!tensor->data.empty()) { | |||||
| tensor->data.clear(); | |||||
| tensor->data.resize(wShapeSize * sizeof(int8_t)); | |||||
| auto ret = | |||||
| memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | } | ||||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | ||||
| // quant bias data | // quant bias data | ||||
| @@ -53,7 +53,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = kNumberTypeInt8; | |||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | attr->dstT = GetTfliteDataType(out_tensor->type); | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | ||||
| @@ -76,12 +76,6 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||||
| quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i]; | quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i]; | ||||
| } | } | ||||
| // change quant param min to 0 to fit ms-lite ops | |||||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { | |||||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | |||||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| if (!tflite_tensor->quantization->min.empty()) { | if (!tflite_tensor->quantization->min.empty()) { | ||||
| quant_param->min = tflite_tensor->quantization->min[i]; | quant_param->min = tflite_tensor->quantization->min[i]; | ||||
| } | } | ||||
| @@ -127,7 +121,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| sub_graph->nodes.emplace_back(op.release()); | sub_graph->nodes.emplace_back(op.release()); | ||||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | ||||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | ||||
| @@ -53,7 +53,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->srcT = GetTfliteDataType(in_tensor->type); | attr->srcT = GetTfliteDataType(in_tensor->type); | ||||
| attr->dstT = kNumberTypeInt8; | |||||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else { | } else { | ||||