From 551cdfe2f521438549f4b33785c89a6709344546 Mon Sep 17 00:00:00 2001 From: hangq Date: Sat, 8 Aug 2020 17:49:40 +0800 Subject: [PATCH] return unorderd_map rather than vector for LiteSession::GetOutputs --- mindspore/lite/include/lite_session.h | 5 +++-- mindspore/lite/src/lite_session.cc | 13 ++----------- mindspore/lite/src/lite_session.h | 2 +- mindspore/lite/test/ut/src/infer_test.cc | 6 ++++-- 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/mindspore/lite/include/lite_session.h b/mindspore/lite/include/lite_session.h index ea762f0f60..80fec03cf5 100644 --- a/mindspore/lite/include/lite_session.h +++ b/mindspore/lite/include/lite_session.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "include/ms_tensor.h" #include "include/model.h" #include "include/context.h" @@ -85,8 +86,8 @@ class MS_API LiteSession { /// \brief Get output MindSpore Lite MSTensors of model. /// - /// \return A vector of MindSpore Lite MSTensor. - virtual std::vector GetOutputs() const = 0; + /// \return A map of output node name and MindSpore Lite MSTensor. + virtual std::unordered_map> GetOutputs() const = 0; /// \brief Get output MindSpore Lite MSTensors of model by node name. /// diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index aa402415e7..b3ffea7d4b 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -177,17 +177,8 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session:: } } -std::vector LiteSession::GetOutputs() const { - std::vector ret; - for (auto &iter : this->output_map) { - auto &node_output_tensors = iter.second; - for (auto tensor : node_output_tensors) { - if (!IsContain(ret, tensor)) { - ret.emplace_back(tensor); - } - } - } - return ret; +std::unordered_map> LiteSession::GetOutputs() const { + return this->output_map; } int LiteSession::Init(Context *context) { diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 4c45809596..56613f7aa1 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -49,7 +49,7 @@ class LiteSession : public session::LiteSession { int RunGraph(const session::KernelCallBack &before = nullptr, const session::KernelCallBack &after = nullptr) override; - std::vector GetOutputs() const override; + std::unordered_map> GetOutputs() const override; std::vector GetOutputsByName(const std::string &name) const override; diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc index 6bce0ddad2..04c30e241b 100644 --- a/mindspore/lite/test/ut/src/infer_test.cc +++ b/mindspore/lite/test/ut/src/infer_test.cc @@ -130,7 +130,8 @@ TEST_F(InferTest, TestConvNode) { ASSERT_EQ(lite::RET_OK, ret); auto outputs = session->GetOutputs(); ASSERT_EQ(outputs.size(), 1); - auto outTensor = outputs.front(); + ASSERT_EQ(outputs.begin()->second.size(), 1); + auto outTensor = outputs.begin()->second.front(); ASSERT_NE(nullptr, outTensor); ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum()); ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type()); @@ -220,7 +221,8 @@ TEST_F(InferTest, TestAddNode) { ASSERT_EQ(lite::RET_OK, ret); auto outputs = session->GetOutputs(); ASSERT_EQ(outputs.size(), 1); - auto outTensor = outputs.front(); + ASSERT_EQ(outputs.begin()->second.size(), 1); + auto outTensor = outputs.begin()->second.front(); ASSERT_NE(nullptr, outTensor); ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum()); ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());