|
|
|
@@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
} |
|
|
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
auto &node = *iter; |
|
|
|
auto nodeName = node->name; |
|
|
|
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { |
|
|
|
if (node->inputIndex.at(inputIndexIdx) == graphInIdx) { |
|
|
|
auto nodeName = (*iter)->name; |
|
|
|
for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { |
|
|
|
if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) { |
|
|
|
STATUS status = RET_OK; |
|
|
|
|
|
|
|
// insert dtype cast node between input tensor and input node |
|
|
|
@@ -108,11 +107,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
auto &graphOutIdxes = graph->outputIndex; |
|
|
|
for (auto graphOutIdx : graphOutIdxes) { |
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
auto &node = *iter; |
|
|
|
auto nodeName = node->name; |
|
|
|
auto nodeName = (*iter)->name; |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) { |
|
|
|
if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) { |
|
|
|
for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) { |
|
|
|
if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { |
|
|
|
// insert transNode |
|
|
|
STATUS status = RET_OK; |
|
|
|
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status); |
|
|
|
@@ -135,7 +133,6 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &node = *iter; |
|
|
|
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -143,8 +140,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { |
|
|
|
if (GetCNodeTType(**iter) == PrimitiveType_Shape) { |
|
|
|
needInsertPost = false; |
|
|
|
} |
|
|
|
auto nodeName = node->name; |
|
|
|
if (node->inputIndex.size() < kMinInputNum) { |
|
|
|
auto nodeName = (*iter)->name; |
|
|
|
if ((*iter)->inputIndex.size() < kMinInputNum) { |
|
|
|
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|