Merge pull request !4372 from wangshaocong/litetags/v0.7.0-beta
| @@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor | |||||
| for (size_t i = 0; i < in_shape.size(); i++) { | for (size_t i = 0; i < in_shape.size(); i++) { | ||||
| bool reduce_axis = false; | bool reduce_axis = false; | ||||
| for (int idx = 0; idx < num_axes; ++idx) { | for (int idx = 0; idx < num_axes; ++idx) { | ||||
| if (static_cast<size_t>((*axes)[idx]) == i) { | |||||
| if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) { | |||||
| reduce_axis = true; | reduce_axis = true; | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| for (auto i = 0; i < num_axes_; i++) { | for (auto i = 0; i < num_axes_; i++) { | ||||
| if (axes_[i] < -static_cast<int>(input_rank) || static_cast<size_t>(axes_[i]) >= input_rank) { | |||||
| if (axes_[i] < -static_cast<int>(input_rank) || axes_[i] >= static_cast<int>(input_rank)) { | |||||
| MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" | MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" | ||||
| << -static_cast<int>(input_rank) << ", " << input_rank - 1 << "]."; | << -static_cast<int>(input_rank) << ", " << input_rank - 1 << "]."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| @@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT> | |||||
| } | } | ||||
| } | } | ||||
| void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, | |||||
| void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const mindspore::lite::TensorCache &tensorCache, | |||||
| schema::MetaGraphT *subGraphDef) { | schema::MetaGraphT *subGraphDef) { | ||||
| auto opGraph = OpGraphT::Build(subGraphDef); | |||||
| auto graphInputs = tensorCache.GetGraphInputs(); | auto graphInputs = tensorCache.GetGraphInputs(); | ||||
| auto graphOutputs = opGraph->GetOutputNode(); | |||||
| subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); | subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); | ||||
| for (const auto &output : graphOutputs) { | |||||
| auto op = opMap[output->ID()]; | |||||
| for (auto outputIndex : op->outputIndex) { | |||||
| subGraphDef->outputIndex.emplace_back(outputIndex); | |||||
| for (auto outputIndex : tflite_subgraph->outputs) { | |||||
| int i = 0; | |||||
| bool found = false; | |||||
| for (const auto &tfliteOp : tflite_subgraph->operators) { | |||||
| int j = 0; | |||||
| auto opType = GetTfliteNodeType(tfliteOp, tflite_model); | |||||
| std::string opName = opType + "-" + std::to_string(i++); | |||||
| for (auto opOutputIndex : tfliteOp->outputs) { | |||||
| if (outputIndex == opOutputIndex) { | |||||
| subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]); | |||||
| found = true; | |||||
| break; | |||||
| } | |||||
| j++; | |||||
| } | |||||
| if (found) { | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| SetGraphTensorIndex(tensorCache, subGraph.get()); | |||||
| SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get()); | |||||
| SetAllTensors(tensorCache, subGraph.get()); | SetAllTensors(tensorCache, subGraph.get()); | ||||
| return subGraph.release(); | return subGraph.release(); | ||||
| } | } | ||||
| @@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser { | |||||
| void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); | void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache); | ||||
| void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); | |||||
| void SetGraphTensorIndex(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, | |||||
| const mindspore::lite::TensorCache &tensorCache, | |||||
| schema::MetaGraphT *subGraphDef); | |||||
| STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model, | ||||
| const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph, | const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph, | ||||