Browse Source

!4163 return unorderd_map rather than vector for LiteSession::GetOutputs

Merge pull request !4163 from hangq/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
8e3c8f3d6e
4 changed files with 10 additions and 16 deletions
  1. +3
    -2
      mindspore/lite/include/lite_session.h
  2. +2
    -11
      mindspore/lite/src/lite_session.cc
  3. +1
    -1
      mindspore/lite/src/lite_session.h
  4. +4
    -2
      mindspore/lite/test/ut/src/infer_test.cc

+ 3
- 2
mindspore/lite/include/lite_session.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "include/ms_tensor.h"
#include "include/model.h"
#include "include/context.h"
@@ -85,8 +86,8 @@ class MS_API LiteSession {

/// \brief Get output MindSpore Lite MSTensors of model.
///
/// \return A vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetOutputs() const = 0;
/// \return A map of output node name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;

/// \brief Get output MindSpore Lite MSTensors of model by node name.
///


+ 2
- 11
mindspore/lite/src/lite_session.cc View File

@@ -187,17 +187,8 @@ int LiteSession::RunGraph(const session::KernelCallBack &before, const session::
}
}

std::vector<mindspore::tensor::MSTensor *> LiteSession::GetOutputs() const {
std::vector<mindspore::tensor::MSTensor *> ret;
for (auto &iter : this->output_map) {
auto &node_output_tensors = iter.second;
for (auto tensor : node_output_tensors) {
if (!IsContain(ret, tensor)) {
ret.emplace_back(tensor);
}
}
}
return ret;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> LiteSession::GetOutputs() const {
return this->output_map;
}

int LiteSession::Init(Context *context) {


+ 1
- 1
mindspore/lite/src/lite_session.h View File

@@ -49,7 +49,7 @@ class LiteSession : public session::LiteSession {
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;

std::vector<mindspore::tensor::MSTensor *> GetOutputs() const override;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const override;

std::vector<mindspore::tensor::MSTensor *> GetOutputsByName(const std::string &name) const override;



+ 4
- 2
mindspore/lite/test/ut/src/infer_test.cc View File

@@ -130,7 +130,8 @@ TEST_F(InferTest, TestConvNode) {
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto outTensor = outputs.front();
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(28 * 28 * 32, outTensor->ElementsNum());
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());
@@ -220,7 +221,8 @@ TEST_F(InferTest, TestAddNode) {
ASSERT_EQ(lite::RET_OK, ret);
auto outputs = session->GetOutputs();
ASSERT_EQ(outputs.size(), 1);
auto outTensor = outputs.front();
ASSERT_EQ(outputs.begin()->second.size(), 1);
auto outTensor = outputs.begin()->second.front();
ASSERT_NE(nullptr, outTensor);
ASSERT_EQ(28 * 28 * 3, outTensor->ElementsNum());
ASSERT_EQ(TypeId::kNumberTypeFloat32, outTensor->data_type());


Loading…
Cancel
Save