Browse Source

!4372 Modify the method for getting output index of metagraph.

Merge pull request !4372 from wangshaocong/lite
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
49fd9fa978
4 changed files with 28 additions and 12 deletions
  1. +1
    -1
      mindspore/lite/src/ops/reduce.cc
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
  3. +22
    -9
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
  4. +4
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h

+ 1
- 1
mindspore/lite/src/ops/reduce.cc View File

@@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
for (size_t i = 0; i < in_shape.size(); i++) {
bool reduce_axis = false;
for (int idx = 0; idx < num_axes; ++idx) {
if (static_cast<size_t>((*axes)[idx]) == i) {
if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
reduce_axis = true;
break;
}


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc View File

@@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() {
return RET_ERROR;
}
for (auto i = 0; i < num_axes_; i++) {
if (axes_[i] < -static_cast<int>(input_rank) || static_cast<size_t>(axes_[i]) >= input_rank) {
if (axes_[i] < -static_cast<int>(input_rank) || axes_[i] >= static_cast<int>(input_rank)) {
MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in ["
<< -static_cast<int>(input_rank) << ", " << input_rank - 1 << "].";
return RET_ERROR;


+ 22
- 9
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
}
}

void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef) {
auto opGraph = OpGraphT::Build(subGraphDef);
auto graphInputs = tensorCache.GetGraphInputs();
auto graphOutputs = opGraph->GetOutputNode();

subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end());

for (const auto &output : graphOutputs) {
auto op = opMap[output->ID()];
for (auto outputIndex : op->outputIndex) {
subGraphDef->outputIndex.emplace_back(outputIndex);
for (auto outputIndex : tflite_subgraph->outputs) {
int i = 0;
bool found = false;
for (const auto &tfliteOp : tflite_subgraph->operators) {
int j = 0;
auto opType = GetTfliteNodeType(tfliteOp, tflite_model);
std::string opName = opType + "-" + std::to_string(i++);
for (auto opOutputIndex : tfliteOp->outputs) {
if (outputIndex == opOutputIndex) {
subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]);
found = true;
break;
}
j++;
}
if (found) {
break;
}
}
}
}
@@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
return nullptr;
}

SetGraphTensorIndex(tensorCache, subGraph.get());
SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get());
return subGraph.release();
}


+ 4
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h View File

@@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser {

void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);

void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);
void SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::ModelT> &tflite_model,
const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);

STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,


Loading…
Cancel
Save