|
|
|
@@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { |
|
|
|
|
|
|
|
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
// modify inputTensor first |
|
|
|
auto &graphInIdxes = graph->inputIndex; |
|
|
|
for (auto graphInIdx : graphInIdxes) { |
|
|
|
MS_ASSERT(graph->allTensors.size() > graphInIdx); |
|
|
|
auto &graphInTensor = graph->allTensors.at(graphInIdx); |
|
|
|
graphInTensor->dataType = TypeId::kNumberTypeInt8; |
|
|
|
} |
|
|
|
|
|
|
|
if (this->inputDataDType == TypeId::kNumberTypeInt8) { |
|
|
|
return RET_OK; |
|
|
|
@@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
for (auto graphInIdx : graphInIdxes) { |
|
|
|
MS_ASSERT(graphInIdx < graph->allTensors.size()); |
|
|
|
auto &tensor = graph->allTensors.at(graphInIdx); |
|
|
|
if (tensor->dims.size() != kNHWCDimNumber) { |
|
|
|
if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
// insert transNode before and after existNode |
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { |
|
|
|
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { |
|
|
|
@@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { |
|
|
|
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); |
|
|
|
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); |
|
|
|
if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &graphInIdxes = graph->inputIndex; |
|
|
|
if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; |
|
|
|
@@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
|
|
|
|
if (needInsertPost) { |
|
|
|
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { |
|
|
|
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); |
|
|
|
if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; |
|
|
|
|