| @@ -63,7 +63,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| auto *weight_tensor = inputs.at(kWeightIndex); | auto *weight_tensor = inputs.at(kWeightIndex); | ||||
| // data of second tensor of fc may be nullptr | // data of second tensor of fc may be nullptr | ||||
| auto *restore_data = weight_tensor->data_c(); | auto *restore_data = weight_tensor->data_c(); | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); | ||||
| if (dequant_weight == nullptr) { | if (dequant_weight == nullptr) { | ||||
| MS_LOG(ERROR) << "dequant data is nullptr."; | MS_LOG(ERROR) << "dequant data is nullptr."; | ||||
| @@ -91,7 +91,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (!weight_tensor->GetQuantParams().empty()) { | |||||
| if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { | |||||
| weight_tensor->FreeData(); | weight_tensor->FreeData(); | ||||
| weight_tensor->SetData(restore_data); | weight_tensor->SetData(restore_data); | ||||
| } | } | ||||
| @@ -93,7 +93,7 @@ std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInp | |||||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; } | std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; } | ||||
| std::vector<schema::PrimitiveType> GetUint8OpList() { return int8OpList; } | |||||
| std::vector<schema::PrimitiveType> GetInt8OpList() { return int8OpList; } | |||||
| STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims, | STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims, | ||||
| mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) { | mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) { | ||||
| @@ -42,7 +42,7 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList(); | |||||
| std::vector<schema::PrimitiveType> GetUint8NhwcOpList(); | std::vector<schema::PrimitiveType> GetUint8NhwcOpList(); | ||||
| std::vector<schema::PrimitiveType> GetUint8OpList(); | |||||
| std::vector<schema::PrimitiveType> GetInt8OpList(); | |||||
| class NodeUtils { | class NodeUtils { | ||||
| public: | public: | ||||
| @@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { | |||||
| STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| // modify inputTensor first | |||||
| auto &graphInIdxes = graph->inputIndex; | auto &graphInIdxes = graph->inputIndex; | ||||
| for (auto graphInIdx : graphInIdxes) { | |||||
| MS_ASSERT(graph->allTensors.size() > graphInIdx); | |||||
| auto &graphInTensor = graph->allTensors.at(graphInIdx); | |||||
| graphInTensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | |||||
| if (this->inputDataDType == TypeId::kNumberTypeInt8) { | if (this->inputDataDType == TypeId::kNumberTypeInt8) { | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { | |||||
| for (auto graphInIdx : graphInIdxes) { | for (auto graphInIdx : graphInIdxes) { | ||||
| MS_ASSERT(graphInIdx < graph->allTensors.size()); | MS_ASSERT(graphInIdx < graph->allTensors.size()); | ||||
| auto &tensor = graph->allTensors.at(graphInIdx); | auto &tensor = graph->allTensors.at(graphInIdx); | ||||
| if (tensor->dims.size() != kNHWCDimNumber) { | |||||
| if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| // insert transNode before and after existNode | // insert transNode before and after existNode | ||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||||
| if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { | if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { | ||||
| @@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { | ||||
| MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); | ||||
| auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); | ||||
| if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { | |||||
| continue; | |||||
| } | |||||
| auto &graphInIdxes = graph->inputIndex; | auto &graphInIdxes = graph->inputIndex; | ||||
| if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { | |||||
| continue; | |||||
| } | |||||
| iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; | ||||
| @@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { | |||||
| if (needInsertPost) { | if (needInsertPost) { | ||||
| for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { | ||||
| auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); | |||||
| if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { | |||||
| continue; | |||||
| } | |||||
| iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; | ||||
| @@ -79,6 +79,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor | |||||
| // change quant param min to 0 to fit ms-lite ops | // change quant param min to 0 to fit ms-lite ops | ||||
| if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { | if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { | ||||
| quant_param->zeroPoint = quant_param->zeroPoint - 128; | quant_param->zeroPoint = quant_param->zeroPoint - 128; | ||||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | } | ||||
| if (!tflite_tensor->quantization->min.empty()) { | if (!tflite_tensor->quantization->min.empty()) { | ||||
| @@ -164,11 +165,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> | |||||
| MS_LOG(ERROR) << "obtain const tensor failed"; | MS_LOG(ERROR) << "obtain const tensor failed"; | ||||
| return status; | return status; | ||||
| } | } | ||||
| } else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) { | |||||
| // set in/out tensor to int8 to fit ms-lite op | |||||
| tensor->dataType = TypeId::kNumberTypeInt8; | |||||
| } | } | ||||
| // set tensor attr | // set tensor attr | ||||
| if (isInput || isConst) { | if (isInput || isConst) { | ||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | tensor->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| @@ -145,7 +145,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { | |||||
| STATUS AwareQuantizer::DoQuantize() { | STATUS AwareQuantizer::DoQuantize() { | ||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| auto &node = *iter; | auto &node = *iter; | ||||
| if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||||
| if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node->quantType != schema::QuantType_AwareTraining) { | if (node->quantType != schema::QuantType_AwareTraining) { | ||||
| @@ -388,7 +388,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { | |||||
| } | } | ||||
| } | } | ||||
| if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { | |||||
| if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) { | |||||
| node->quantType = schema::QuantType_AwareTraining; | node->quantType = schema::QuantType_AwareTraining; | ||||
| } else { | } else { | ||||
| node->quantType = schema::QuantType_QUANT_NONE; | node->quantType = schema::QuantType_QUANT_NONE; | ||||