Browse Source

!11969 [MS][LITE][r1.1]fix use of vector in CXX API

From: @lx0095
Reviewed-by: @zhang_xue_tong,@zhang_xue_tong,@hangangqiang
Signed-off-by: @zhang_xue_tong
pull/11969/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
9f0ce32910
5 changed files with 60 additions and 23 deletions
  1. +1
    -1
      mindspore/lite/src/cxx_api/graph/graph.cc
  2. +0
    -2
      mindspore/lite/src/cxx_api/graph/graph_data.h
  3. +53
    -17
      mindspore/lite/src/cxx_api/model/model_impl.cc
  4. +3
    -3
      mindspore/lite/src/cxx_api/tensor/tensor_impl.cc
  5. +3
    -0
      mindspore/lite/src/cxx_api/utils.h

+ 1
- 1
mindspore/lite/src/cxx_api/graph/graph.cc View File

@@ -30,5 +30,5 @@ Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}

bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }

ModelType Graph::ModelType() const { return graph_data_->ModelType(); }
ModelType Graph::ModelType() const { return kMindIR; }
} // namespace mindspore

+ 0
- 2
mindspore/lite/src/cxx_api/graph/graph_data.h View File

@@ -35,8 +35,6 @@ class Graph::GraphData {

std::shared_ptr<lite::Model> lite_model() { return lite_model_; }

enum ModelType ModelType() const { return kMindIR; }

private:
std::shared_ptr<lite::Model> lite_model_;
};


+ 53
- 17
mindspore/lite/src/cxx_api/model/model_impl.cc View File

@@ -156,6 +156,7 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
MS_LOG(DEBUG) << "Empty outputs.";
return kLiteError;
}
outputs->clear();
outputs->insert(outputs->end(), res.begin(), res.end());
return kSuccess;
}
@@ -168,8 +169,13 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
}
std::vector<MSTensor> res;
auto inputs = session_->GetInputs();
for (auto input : inputs) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(input));
if (inputs.empty()) {
MS_LOG(ERROR) << "The inputs of model is null.";
return empty;
}
res.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
if (impl == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return empty;
@@ -179,7 +185,7 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
MS_LOG(ERROR) << "Create tensor failed.";
return empty;
}
res.push_back(tensor);
res[i] = tensor;
}
return res;
}
@@ -192,9 +198,22 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
}
std::vector<MSTensor> res;
auto names = session_->GetOutputTensorNames();
if (names.empty()) {
MS_LOG(ERROR) << "The names of model is null.";
return empty;
}
auto outputs = session_->GetOutputs();
for (auto name : names) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[name]));
if (outputs.empty()) {
MS_LOG(ERROR) << "The outputs of model is null.";
return empty;
}
if (names.size() != outputs.size()) {
MS_LOG(ERROR) << "The size of outputs dose not match the size of names.";
return empty;
}
res.resize(names.size());
for (size_t i = 0; i < names.size(); i++) {
auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
if (impl == nullptr) {
MS_LOG(ERROR) << "Create tensor failed.";
return empty;
@@ -204,7 +223,7 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
MS_LOG(ERROR) << "Create tensor failed.";
return empty;
}
res.push_back(tensor);
res[i] = tensor;
}
return res;
}
@@ -214,27 +233,44 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
MS_LOG(ERROR) << "Session is null.";
return kLiteNullptr;
}
if (inputs.empty()) {
MS_LOG(ERROR) << "Inputs is null.";
return kLiteInputParamInvalid;
}
if (dims.empty()) {
MS_LOG(ERROR) << "Dims is null.";
return kLiteInputParamInvalid;
}
if (inputs.size() != dims.size()) {
MS_LOG(ERROR) << "The size of inputs is not equal to the size of dims.";
MS_LOG(ERROR) << "The size of inputs does not match the size of dims.";
return kLiteInputParamInvalid;
}
auto model_inputs = session_->GetInputs();
if (model_inputs.empty()) {
MS_LOG(ERROR) << "The inputs of model is null.";
return kLiteParamInvalid;
}
if (inputs.size() != model_inputs.size()) {
MS_LOG(ERROR) << "The size of inputs is incorrect.";
return kLiteInputParamInvalid;
}
std::vector<tensor::MSTensor *> inner_input;
for (auto input : inputs) {
inner_input.resize(inputs.size());
std::vector<std::vector<int32_t>> truncated_shape;
truncated_shape.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto input = inputs[i];
if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null.";
return kLiteInputTensorError;
}
inner_input.push_back(input.impl_->lite_tensor());
}
std::vector<std::vector<int32_t>> truncated_shape;
for (size_t i = 0; i < inner_input.size(); i++) {
std::vector<int32_t> tmp =
TruncateShape(dims.at(i), inner_input.at(i)->data_type(), inner_input.at(i)->Size(), false);
if (tmp.empty()) {
MS_LOG(ERROR) << "Input dims[" << i << "]is invalid.";
inner_input[i] = input.impl_->lite_tensor();
std::vector<int32_t> shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false);
if (shape.empty() && !(dims[i].empty())) {
MS_LOG(ERROR) << "Input dims[" << i << "] is invalid.";
return kLiteParamInvalid;
}
truncated_shape.push_back(tmp);
truncated_shape[i] = shape;
}
auto ret = session_->Resize(inner_input, truncated_shape);
return static_cast<StatusCode>(ret);


+ 3
- 3
mindspore/lite/src/cxx_api/tensor/tensor_impl.cc View File

@@ -29,10 +29,10 @@ namespace mindspore {
MSTensor::Impl::Impl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len) {
std::vector<int32_t> truncated_shape = TruncateShape(shape, static_cast<enum TypeId>(type), data_len, true);
if (!truncated_shape.empty()) {
lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast<enum TypeId>(type), truncated_shape, data);
} else {
if (truncated_shape.empty() && !(shape.empty())) {
lite_tensor_ = nullptr;
} else {
lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast<enum TypeId>(type), truncated_shape, data);
}
}



+ 3
- 0
mindspore/lite/src/cxx_api/utils.h View File

@@ -21,6 +21,9 @@ namespace mindspore {
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;
if (shape.empty()) {
return empty;
}
std::vector<int32_t> truncated_shape;
size_t element_size = lite::DataTypeSize(type);
for (auto i : shape) {


Loading…
Cancel
Save