|
|
|
@@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT> |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, |
|
|
|
void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, |
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model, |
|
|
|
const mindspore::lite::TensorCache &tensorCache, |
|
|
|
schema::MetaGraphT *subGraphDef) { |
|
|
|
auto opGraph = OpGraphT::Build(subGraphDef); |
|
|
|
auto graphInputs = tensorCache.GetGraphInputs(); |
|
|
|
auto graphOutputs = opGraph->GetOutputNode(); |
|
|
|
|
|
|
|
subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); |
|
|
|
|
|
|
|
for (const auto &output : graphOutputs) { |
|
|
|
auto op = opMap[output->ID()]; |
|
|
|
for (auto outputIndex : op->outputIndex) { |
|
|
|
subGraphDef->outputIndex.emplace_back(outputIndex); |
|
|
|
for (auto outputIndex : tflite_subgraph->outputs) { |
|
|
|
int i = 0; |
|
|
|
bool found = false; |
|
|
|
for (const auto &tfliteOp : tflite_subgraph->operators) { |
|
|
|
int j = 0; |
|
|
|
auto opType = GetTfliteNodeType(tfliteOp, tflite_model); |
|
|
|
std::string opName = opType + "-" + std::to_string(i++); |
|
|
|
for (auto opOutputIndex : tfliteOp->outputs) { |
|
|
|
if (outputIndex == opOutputIndex) { |
|
|
|
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]); |
|
|
|
found = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
j++; |
|
|
|
} |
|
|
|
if (found) { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
SetGraphTensorIndex(tensorCache, subGraph.get()); |
|
|
|
SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get()); |
|
|
|
SetAllTensors(tensorCache, subGraph.get()); |
|
|
|
return subGraph.release(); |
|
|
|
} |
|
|
|
|