Browse Source

!4434 fix bug that lite_session GetInputs return repeated input tensor

Merge pull request !4434 from hangq/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b34b7973be
4 changed files with 37 additions and 16 deletions
  1. +12
    -10
      mindspore/lite/src/lite_session.cc
  2. +8
    -4
      mindspore/lite/src/lite_session.h
  3. +16
    -1
      mindspore/lite/tools/benchmark/benchmark.cc
  4. +1
    -1
      mindspore/lite/tools/benchmark/benchmark.h

+ 12
- 10
mindspore/lite/src/lite_session.cc View File

@@ -92,6 +92,16 @@ void LiteSession::InitGraphInputTensors(const lite::Model *model) {
} }
} }


void LiteSession::InitGraphInputMSTensors(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->input_vec_.empty());
MS_ASSERT(meta_graph != nullptr);
for (auto &input_tensor : this->inputs_) {
MS_ASSERT(input_tensor != nullptr);
this->input_vec_.emplace_back(new lite::tensor::LiteTensor(input_tensor));
}
}

void LiteSession::InitGraphOutputTensors(const lite::Model *model) { void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
auto meta_graph = model->GetMetaGraph(); auto meta_graph = model->GetMetaGraph();
MS_ASSERT(this->outputs_.empty()); MS_ASSERT(this->outputs_.empty());
@@ -169,6 +179,7 @@ void LiteSession::InitGraphOutputMap(const lite::Model *model) {


void LiteSession::InitGraphInOutTensors(const lite::Model *model) { void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphInputTensors(model); InitGraphInputTensors(model);
InitGraphInputMSTensors(model);
InitGraphOutputTensors(model); InitGraphOutputTensors(model);
InitGraphInputMap(model); InitGraphInputMap(model);
InitGraphOutputMap(model); InitGraphOutputMap(model);
@@ -201,16 +212,7 @@ int LiteSession::CompileGraph(Model *model) {
} }


std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const { std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
std::vector<mindspore::tensor::MSTensor *> ret;
for (auto &iter : this->input_map_) {
auto &node_input_tensors = iter.second;
for (auto tensor : node_input_tensors) {
if (!IsContain(ret, tensor)) {
ret.emplace_back(tensor);
}
}
}
return ret;
return this->input_vec_;
} }


int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) { int LiteSession::RunGraph(const session::KernelCallBack &before, const session::KernelCallBack &after) {


+ 8
- 4
mindspore/lite/src/lite_session.h View File

@@ -57,13 +57,15 @@ class LiteSession : public session::LiteSession {
int ConvertTensors(const lite::Model *model); int ConvertTensors(const lite::Model *model);


void InitGraphInOutTensors(const lite::Model *model); void InitGraphInOutTensors(const lite::Model *model);
// init this->inputs_
void InitGraphInputTensors(const lite::Model *model); void InitGraphInputTensors(const lite::Model *model);

// init this->input_vec_
void InitGraphInputMSTensors(const lite::Model *model);
// init this->outputs_
void InitGraphOutputTensors(const lite::Model *model); void InitGraphOutputTensors(const lite::Model *model);

// init this->input_map_
void InitGraphInputMap(const lite::Model *model); void InitGraphInputMap(const lite::Model *model);
// init this->output_map_
void InitGraphOutputMap(const lite::Model *model); void InitGraphOutputMap(const lite::Model *model);


protected: protected:
@@ -74,6 +76,8 @@ class LiteSession : public session::LiteSession {
std::vector<tensor::Tensor *> inputs_; std::vector<tensor::Tensor *> inputs_;
// graph output tensors // graph output tensors
std::vector<tensor::Tensor *> outputs_; std::vector<tensor::Tensor *> outputs_;
// graph input MSTensors
std::vector<mindspore::tensor::MSTensor *> input_vec_;
// graph input node name -- input tensors // graph input node name -- input tensors
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_; std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> input_map_;
// graph output node name -- output tensors // graph output node name -- output tensors


+ 16
- 1
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -49,7 +49,8 @@ int Benchmark::GenerateInputData() {
auto tensorByteSize = tensor->Size(); auto tensorByteSize = tensor->Size();
auto status = GenerateRandomData(tensorByteSize, inputData); auto status = GenerateRandomData(tensorByteSize, inputData);
if (status != 0) { if (status != 0) {
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed %d" << status;
std::cerr << "GenerateRandomData for inTensor failed: " << status << std::endl;
MS_LOG(ERROR) << "GenerateRandomData for inTensor failed:" << status;
return status; return status;
} }
} }
@@ -60,12 +61,14 @@ int Benchmark::LoadInput() {
if (_flags->inDataPath.empty()) { if (_flags->inDataPath.empty()) {
auto status = GenerateInputData(); auto status = GenerateInputData();
if (status != 0) { if (status != 0) {
std::cerr << "Generate input data error " << status << std::endl;
MS_LOG(ERROR) << "Generate input data error " << status; MS_LOG(ERROR) << "Generate input data error " << status;
return status; return status;
} }
} else { } else {
auto status = ReadInputFile(); auto status = ReadInputFile();
if (status != 0) { if (status != 0) {
std::cerr << "ReadInputFile error, " << status << std::endl;
MS_LOG(ERROR) << "ReadInputFile error, " << status; MS_LOG(ERROR) << "ReadInputFile error, " << status;
return status; return status;
} }
@@ -97,6 +100,7 @@ int Benchmark::ReadInputFile() {
char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size); char *binBuf = ReadFile(_flags->input_data_list[i].c_str(), &size);
auto tensorDataSize = cur_tensor->Size(); auto tensorDataSize = cur_tensor->Size();
if (size != tensorDataSize) { if (size != tensorDataSize) {
std::cerr << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size << std::endl;
MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size; MS_LOG(ERROR) << "Input binary file size error, required: %zu, in fact: %zu" << tensorDataSize << size;
return RET_ERROR; return RET_ERROR;
} }
@@ -113,11 +117,13 @@ int Benchmark::ReadCalibData() {
// read calib data // read calib data
std::ifstream inFile(calibDataPath); std::ifstream inFile(calibDataPath);
if (!inFile.good()) { if (!inFile.good()) {
std::cerr << "file: " << calibDataPath << " is not exist" << std::endl;
MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist"; MS_LOG(ERROR) << "file: " << calibDataPath << " is not exist";
return RET_ERROR; return RET_ERROR;
} }


if (!inFile.is_open()) { if (!inFile.is_open()) {
std::cerr << "file: " << calibDataPath << " open failed" << std::endl;
MS_LOG(ERROR) << "file: " << calibDataPath << " open failed"; MS_LOG(ERROR) << "file: " << calibDataPath << " open failed";
inFile.close(); inFile.close();
return RET_ERROR; return RET_ERROR;
@@ -181,6 +187,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
oss << dim << ","; oss << dim << ",";
} }
oss << ") are different"; oss << ") are different";
std::cerr << oss.str() << std::endl;
MS_LOG(ERROR) << "%s", oss.str().c_str(); MS_LOG(ERROR) << "%s", oss.str().c_str();
return RET_ERROR; return RET_ERROR;
} }
@@ -193,6 +200,7 @@ float Benchmark::CompareData(const std::string &nodeName, std::vector<int> msSha
} }


if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
return RET_ERROR; return RET_ERROR;
} }
@@ -524,6 +532,13 @@ int Benchmark::Init() {
return RET_OK; return RET_OK;
} }


Benchmark::~Benchmark() {
for (auto iter : this->calibData) {
delete (iter.second);
}
this->calibData.clear();
}

int RunBenchmark(int argc, const char **argv) { int RunBenchmark(int argc, const char **argv) {
BenchmarkFlags flags; BenchmarkFlags flags;
Option<std::string> err = flags.ParseFlags(argc, argv); Option<std::string> err = flags.ParseFlags(argc, argv);


+ 1
- 1
mindspore/lite/tools/benchmark/benchmark.h View File

@@ -104,7 +104,7 @@ class MS_API Benchmark {
public: public:
explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {} explicit Benchmark(BenchmarkFlags *flags) : _flags(flags) {}


virtual ~Benchmark() = default;
virtual ~Benchmark();


int Init(); int Init();
int RunBenchmark(const std::string &deviceType = "NPU"); int RunBenchmark(const std::string &deviceType = "NPU");


Loading…
Cancel
Save