From 30376246c0cf65685a52fa7124459424cd37203d Mon Sep 17 00:00:00 2001 From: lixian Date: Tue, 2 Feb 2021 11:33:49 +0800 Subject: [PATCH] fix use of vector in CXX API --- mindspore/lite/src/cxx_api/graph/graph.cc | 2 +- mindspore/lite/src/cxx_api/graph/graph_data.h | 2 - .../lite/src/cxx_api/model/model_impl.cc | 70 ++++++++++++++----- .../lite/src/cxx_api/tensor/tensor_impl.cc | 6 +- mindspore/lite/src/cxx_api/utils.h | 3 + 5 files changed, 60 insertions(+), 23 deletions(-) diff --git a/mindspore/lite/src/cxx_api/graph/graph.cc b/mindspore/lite/src/cxx_api/graph/graph.cc index cdacd62df5..e1c57cdeff 100644 --- a/mindspore/lite/src/cxx_api/graph/graph.cc +++ b/mindspore/lite/src/cxx_api/graph/graph.cc @@ -30,5 +30,5 @@ Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } -ModelType Graph::ModelType() const { return graph_data_->ModelType(); } +ModelType Graph::ModelType() const { return kMindIR; } } // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/graph/graph_data.h b/mindspore/lite/src/cxx_api/graph/graph_data.h index fdd2aec516..584035db8d 100644 --- a/mindspore/lite/src/cxx_api/graph/graph_data.h +++ b/mindspore/lite/src/cxx_api/graph/graph_data.h @@ -35,8 +35,6 @@ class Graph::GraphData { std::shared_ptr lite_model() { return lite_model_; } - enum ModelType ModelType() const { return kMindIR; } - private: std::shared_ptr lite_model_; }; diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc index ba929be211..0aa48361a3 100644 --- a/mindspore/lite/src/cxx_api/model/model_impl.cc +++ b/mindspore/lite/src/cxx_api/model/model_impl.cc @@ -156,6 +156,7 @@ Status ModelImpl::Predict(const std::vector &inputs, std::vectorclear(); outputs->insert(outputs->end(), res.begin(), res.end()); return kSuccess; } @@ -168,8 +169,13 @@ std::vector ModelImpl::GetInputs() { } std::vector res; auto inputs = session_->GetInputs(); - for (auto input : inputs) { - auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(input)); + if (inputs.empty()) { + MS_LOG(ERROR) << "The inputs of model is null."; + return empty; + } + res.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(inputs[i])); if (impl == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; return empty; @@ -179,7 +185,7 @@ std::vector ModelImpl::GetInputs() { MS_LOG(ERROR) << "Create tensor failed."; return empty; } - res.push_back(tensor); + res[i] = tensor; } return res; } @@ -192,9 +198,22 @@ std::vector ModelImpl::GetOutputs() { } std::vector res; auto names = session_->GetOutputTensorNames(); + if (names.empty()) { + MS_LOG(ERROR) << "The names of model is null."; + return empty; + } auto outputs = session_->GetOutputs(); - for (auto name : names) { - auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(outputs[name])); + if (outputs.empty()) { + MS_LOG(ERROR) << "The outputs of model is null."; + return empty; + } + if (names.size() != outputs.size()) { + MS_LOG(ERROR) << "The size of outputs dose not match the size of names."; + return empty; + } + res.resize(names.size()); + for (size_t i = 0; i < names.size(); i++) { + auto impl = std::shared_ptr(new (std::nothrow) MSTensor::Impl(outputs[names[i]])); if (impl == nullptr) { MS_LOG(ERROR) << "Create tensor failed."; return empty; @@ -204,7 +223,7 @@ std::vector ModelImpl::GetOutputs() { MS_LOG(ERROR) << "Create tensor failed."; return empty; } - res.push_back(tensor); + res[i] = tensor; } return res; } @@ -214,27 +233,44 @@ Status ModelImpl::Resize(const std::vector &inputs, const std::vector< MS_LOG(ERROR) << "Session is null."; return kLiteNullptr; } + if (inputs.empty()) { + MS_LOG(ERROR) << "Inputs is null."; + return kLiteInputParamInvalid; + } + if (dims.empty()) { + MS_LOG(ERROR) << "Dims is null."; + return kLiteInputParamInvalid; + } if (inputs.size() != dims.size()) { - MS_LOG(ERROR) << "The size of inputs is not equal to the size of dims."; + MS_LOG(ERROR) << "The size of inputs does not match the size of dims."; + return kLiteInputParamInvalid; + } + auto model_inputs = session_->GetInputs(); + if (model_inputs.empty()) { + MS_LOG(ERROR) << "The inputs of model is null."; return kLiteParamInvalid; } + if (inputs.size() != model_inputs.size()) { + MS_LOG(ERROR) << "The size of inputs is incorrect."; + return kLiteInputParamInvalid; + } std::vector inner_input; - for (auto input : inputs) { + inner_input.resize(inputs.size()); + std::vector> truncated_shape; + truncated_shape.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) { MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null."; return kLiteInputTensorError; } - inner_input.push_back(input.impl_->lite_tensor()); - } - std::vector> truncated_shape; - for (size_t i = 0; i < inner_input.size(); i++) { - std::vector tmp = - TruncateShape(dims.at(i), inner_input.at(i)->data_type(), inner_input.at(i)->Size(), false); - if (tmp.empty()) { - MS_LOG(ERROR) << "Input dims[" << i << "]is invalid."; + inner_input[i] = input.impl_->lite_tensor(); + std::vector shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false); + if (shape.empty() && !(dims[i].empty())) { + MS_LOG(ERROR) << "Input dims[" << i << "] is invalid."; return kLiteParamInvalid; } - truncated_shape.push_back(tmp); + truncated_shape[i] = shape; } auto ret = session_->Resize(inner_input, truncated_shape); return static_cast(ret); diff --git a/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc b/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc index 46f968f175..a0ec1677ba 100644 --- a/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc +++ b/mindspore/lite/src/cxx_api/tensor/tensor_impl.cc @@ -29,10 +29,10 @@ namespace mindspore { MSTensor::Impl::Impl(const std::string &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len) { std::vector truncated_shape = TruncateShape(shape, static_cast(type), data_len, true); - if (!truncated_shape.empty()) { - lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast(type), truncated_shape, data); - } else { + if (truncated_shape.empty() && !(shape.empty())) { lite_tensor_ = nullptr; + } else { + lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast(type), truncated_shape, data); } } diff --git a/mindspore/lite/src/cxx_api/utils.h b/mindspore/lite/src/cxx_api/utils.h index 0e1967486f..22771f2602 100644 --- a/mindspore/lite/src/cxx_api/utils.h +++ b/mindspore/lite/src/cxx_api/utils.h @@ -21,6 +21,9 @@ namespace mindspore { static std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, bool verify_size) { std::vector empty; + if (shape.empty()) { + return empty; + } std::vector truncated_shape; size_t element_size = lite::DataTypeSize(type); for (auto i : shape) {