Browse Source

index check

tags/v1.0.0
chenjianping 5 years ago
parent
commit
b3b4fc4e6b
1 changed files with 16 additions and 2 deletions
  1. +16
    -2
      mindspore/lite/internal/src/lite_session.cc

+ 16
- 2
mindspore/lite/internal/src/lite_session.cc View File

@@ -105,6 +105,10 @@ int LiteSession::CompileGraph(Model *model) {
InitFuncs(); InitFuncs();
g_model = model; g_model = model;
for (auto in : g_model->input_indices_) { for (auto in : g_model->input_indices_) {
if (in >= g_model->all_tensors_.size() || in < 0) {
LITE_LOG_ERROR("Invalid input indices!");
return RET_PARAM_INVALID;
}
g_model->all_tensors_[in]->data_ = g_allocator.Malloc(g_model->all_tensors_[in]->Size()); g_model->all_tensors_[in]->data_ = g_allocator.Malloc(g_model->all_tensors_[in]->Size());
} }
g_infershape_interrupt = false; g_infershape_interrupt = false;
@@ -118,7 +122,12 @@ int LiteSession::CompileGraph(Model *model) {
TensorPtrVector LiteSession::GetInputs() const { TensorPtrVector LiteSession::GetInputs() const {
TensorPtrVector in(g_model->input_indices_.size()); TensorPtrVector in(g_model->input_indices_.size());
for (size_t i = 0; i < g_model->input_indices_.size(); ++i) { for (size_t i = 0; i < g_model->input_indices_.size(); ++i) {
in.at(i) = g_model->all_tensors_[g_model->input_indices_[i]];
auto index = g_model->input_indices_[i];
if (index < 0 || index >= g_model->all_tensors_.size()) {
LITE_ERROR_LOG("Invalid input index: %u", index);
return TensorPtrVector();
}
in.at(i) = g_model->all_tensors_[index];
} }
return in; return in;
} }
@@ -130,7 +139,12 @@ TensorPtrVector LiteSession::GetOutputsByNodeName(const String &node_name) const
TensorPtrVector LiteSession::GetOutputs() const { TensorPtrVector LiteSession::GetOutputs() const {
TensorPtrVector out(g_model->output_indices_.size()); TensorPtrVector out(g_model->output_indices_.size());
for (size_t i = 0; i < g_model->output_indices_.size(); ++i) { for (size_t i = 0; i < g_model->output_indices_.size(); ++i) {
out.at(i) = g_model->all_tensors_[g_model->output_indices_[i]];
auto index = g_model->output_indices_[i];
if (index < 0 || index >= g_model->all_tensors_.size()) {
LITE_ERROR_LOG("Invalid output index: %u", index);
return TensorPtrVector();
}
out.at(i) = g_model->all_tensors_[index];
} }
return out; return out;
} }


Loading…
Cancel
Save