| @@ -120,6 +120,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te | |||
| kernel->set_desc(key); | |||
| } | |||
| return kernel; | |||
| } else { | |||
| free(parameter); | |||
| } | |||
| return nullptr; | |||
| } | |||
| @@ -19,8 +19,10 @@ | |||
| #include <string> | |||
| #include <set> | |||
| #include <map> | |||
| #include "include/errorcode.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "ir/dtype/type_id.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -68,6 +70,26 @@ class NoSupportOp { | |||
| std::set<std::string> noSupportOps; | |||
| std::string fmkType; | |||
| }; | |||
| class TensorDataType { | |||
| public: | |||
| ~TensorDataType() = default; | |||
| static TensorDataType *GetInstance() { | |||
| static TensorDataType tensorDataType; | |||
| return &tensorDataType; | |||
| } | |||
| void UpdateTensorType(int32_t index, int32_t type) { tensorDataTypeMap[index] = type; } | |||
| int32_t GetTensorType(int32_t index) const { | |||
| if (tensorDataTypeMap.find(index) == tensorDataTypeMap.end()) { | |||
| return TypeId::kTypeUnknown; | |||
| } | |||
| return tensorDataTypeMap.at(index); | |||
| } | |||
| private: | |||
| TensorDataType() {} | |||
| std::map<int32_t, int32_t> tensorDataTypeMap; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_RETURN_CODE_H | |||
| @@ -132,9 +132,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| // do quantization | |||
| { | |||
| Optimizer fusionOptimizer; | |||
| fusionOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| status = fusionOptimizer.Run(graphDefT); | |||
| Optimizer tensorQuantOptimizer; | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| tensorQuantOptimizer.AddPass(new (std::nothrow) TensorQuantPass()); | |||
| status = tensorQuantOptimizer.Run(graphDefT); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoQuantize failed!"; | |||
| return status; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <string> | |||
| #include <set> | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| @@ -52,12 +53,8 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||
| STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto &graphInIdxes = graph->inputIndex; | |||
| if (this->inputDataDType == TypeId::kTypeUnknown) { | |||
| return RET_OK; | |||
| } | |||
| if (this->inputDataDType != TypeId::kNumberTypeFloat32 && this->inputDataDType != TypeId::kNumberTypeUInt8 && | |||
| this->inputDataDType != TypeId::kNumberTypeInt8) { | |||
| this->inputDataDType != TypeId::kNumberTypeInt8 && this->inputDataDType != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -67,7 +64,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| int32_t tensorDataType = this->inputDataDType != TypeId::kTypeUnknown | |||
| ? this->inputDataDType | |||
| : TensorDataType::GetInstance()->GetTensorType(graphInIdx); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto nodeName = (*iter)->name; | |||
| for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { | |||
| @@ -75,9 +74,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| STATUS status = RET_OK; | |||
| // insert dtype cast node between input tensor and input node | |||
| if (this->inputDataDType != tensor->dataType) { | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, this->inputDataDType, tensor->dataType, | |||
| &status); | |||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||
| iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, tensorDataType, tensor->dataType, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| @@ -93,11 +91,8 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||
| STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (outputDataDType == TypeId::kTypeUnknown) { | |||
| return RET_OK; | |||
| } | |||
| if (this->outputDataDType != TypeId::kNumberTypeFloat32 && this->outputDataDType != TypeId::kNumberTypeUInt8 && | |||
| this->outputDataDType != TypeId::kNumberTypeInt8) { | |||
| this->outputDataDType != TypeId::kNumberTypeInt8 && this->outputDataDType != TypeId::kTypeUnknown) { | |||
| MS_LOG(ERROR) << "Invalid outputDataType: " << this->outputDataDType; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -108,6 +103,9 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| int32_t tensorDataType = this->outputDataDType != TypeId::kTypeUnknown | |||
| ? this->outputDataDType | |||
| : TensorDataType::GetInstance()->GetTensorType(graphOutIdx); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto nodeName = (*iter)->name; | |||
| MS_ASSERT(node != nullptr); | |||
| @@ -115,9 +113,8 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { | |||
| if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { | |||
| // insert transNode | |||
| STATUS status = RET_OK; | |||
| if (this->outputDataDType != tensor->dataType) { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, this->outputDataDType, | |||
| &status); | |||
| if (tensorDataType != tensor->dataType && tensorDataType != kTypeUnknown) { | |||
| iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, tensor->dataType, tensorDataType, &status); | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; | |||
| @@ -17,23 +17,42 @@ | |||
| #include <vector> | |||
| #include <cmath> | |||
| #include "tools/converter/legacy_optimizer/graph/tensor_quant_pass.h" | |||
| #include "tools/converter/converter_context.h" | |||
| #include "tools/converter/quantizer/quantize_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| namespace mindspore::lite { | |||
| STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||
| for (auto &node : graph->nodes) { | |||
| if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) { | |||
| auto attr = node->primitive->value.AsQuantDTypeCast(); | |||
| auto &inputTensor = graph->allTensors.at(node->inputIndex.front()); | |||
| inputTensor->dataType = attr->srcT; | |||
| auto &outputTensor = graph->allTensors.at(node->outputIndex.front()); | |||
| outputTensor->dataType = attr->dstT; | |||
| if (attr->srcT == TypeId::kNumberTypeUInt8) { | |||
| attr->srcT = TypeId::kNumberTypeInt8; | |||
| } | |||
| if (attr->dstT == TypeId::kNumberTypeUInt8) { | |||
| attr->dstT = TypeId::kNumberTypeInt8; | |||
| } | |||
| } | |||
| } | |||
| int index = -1; | |||
| for (auto &tensor : graph->allTensors) { | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited || tensor->data.empty()) { | |||
| index++; | |||
| if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) { | |||
| continue; | |||
| } | |||
| if (tensor->dataType != TypeId::kNumberTypeFloat32 && tensor->dataType != TypeId::kNumberTypeFloat && | |||
| tensor->dataType != TypeId::kNumberTypeUInt8) { | |||
| tensor->dataType != TypeId::kNumberTypeUInt8 && tensor->dataType != TypeId::kTypeUnknown) { | |||
| continue; | |||
| } | |||
| // perlayer | |||
| if (tensor->quantParams.size() == 1) { | |||
| auto &quantParam = tensor->quantParams.front(); | |||
| size_t wShapeSize = GetShapeSize(*(tensor.get())); | |||
| size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get())); | |||
| void *oriWeightData = tensor->data.data(); | |||
| if (quantParam->dstDtype == TypeId::kNumberTypeInt8) { | |||
| std::vector<int8_t> qDatas(wShapeSize); | |||
| @@ -41,6 +60,9 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||
| if (tensor->dataType == TypeId::kNumberTypeFloat || | |||
| tensor->dataType == TypeId::kNumberTypeFloat32) { // normal awareing quant | |||
| auto *weightData = static_cast<float *>(oriWeightData); | |||
| if (weightData == nullptr) { | |||
| continue; | |||
| } | |||
| for (size_t j = 0; j < wShapeSize; j++) { | |||
| qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get()); | |||
| } | |||
| @@ -52,15 +74,18 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) { | |||
| weightQauntParam->zeroPoint -= 128; | |||
| tensor->quantParams.clear(); | |||
| tensor->quantParams.emplace_back(weightQauntParam.release()); | |||
| TensorDataType::GetInstance()->UpdateTensorType(index, TypeId::kNumberTypeUInt8); | |||
| } | |||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||
| tensor->data.clear(); | |||
| tensor->data.resize(wShapeSize * sizeof(int8_t)); | |||
| auto ret = | |||
| memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| if (!tensor->data.empty()) { | |||
| tensor->data.clear(); | |||
| tensor->data.resize(wShapeSize * sizeof(int8_t)); | |||
| auto ret = | |||
| memcpy_s(tensor->data.data(), wShapeSize * sizeof(int8_t), qDatas.data(), wShapeSize * sizeof(int8_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) { | |||
| // quant bias data | |||
| @@ -53,7 +53,7 @@ STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = kNumberTypeInt8; | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.value = attr.release(); | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| @@ -76,12 +76,6 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||
| quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i]; | |||
| } | |||
| // change quant param min to 0 to fit ms-lite ops | |||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { | |||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | |||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||
| } | |||
| if (!tflite_tensor->quantization->min.empty()) { | |||
| quant_param->min = tflite_tensor->quantization->min[i]; | |||
| } | |||
| @@ -127,7 +121,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| } | |||
| continue; | |||
| } | |||
| sub_graph->nodes.emplace_back(op.release()); | |||
| opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); | |||
| tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); | |||
| @@ -53,7 +53,7 @@ STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::u | |||
| return RET_NULL_PTR; | |||
| } | |||
| attr->srcT = GetTfliteDataType(in_tensor->type); | |||
| attr->dstT = kNumberTypeInt8; | |||
| attr->dstT = GetTfliteDataType(out_tensor->type); | |||
| op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; | |||
| op->primitive->value.value = attr.release(); | |||
| } else { | |||