Browse Source

!9503 add model check for session

From: @lyvette
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/9503/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
175d7b0bbc
3 changed files with 35 additions and 25 deletions
  1. +4
    -0
      mindspore/lite/src/lite_session.cc
  2. +29
    -25
      mindspore/lite/src/model_common.cc
  3. +2
    -0
      mindspore/lite/src/model_common.h

+ 4
- 0
mindspore/lite/src/lite_session.cc View File

@@ -293,6 +293,10 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
}

int LiteSession::CompileGraph(Model *model) {
if (!ModelVerify(*model)) {
MS_LOG(ERROR) << "Wrong model input";
return RET_ERROR;
}
bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading";


+ 29
- 25
mindspore/lite/src/model_common.cc View File

@@ -201,35 +201,39 @@ int SubGraphVerify(const Model &model) {
auto tensor_size = model.all_tensors_.size();
auto node_size = model.all_nodes_.size();

for (auto &graph : model.sub_graphs_) {
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is null.";
return RET_ERROR;
}
if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(),
[&node_size](const uint32_t &idx) { return idx >= node_size; })) {
MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size.";
return RET_ERROR;
if (!model.sub_graphs_.empty()) {
for (auto &graph : model.sub_graphs_) {
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is null.";
return RET_ERROR;
}
if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(),
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) {
MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size.";
return RET_ERROR;
}
if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(),
[&node_size](const uint32_t &idx) { return idx >= node_size; })) {
MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size.";
return RET_ERROR;
}
}
}
return RET_OK;
}

bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; }

Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
@@ -314,6 +318,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
delete model;
return nullptr;
}
return NodeVerify(*model) == RET_OK && SubGraphVerify(*model) == RET_OK ? model : nullptr;
return ModelVerify(*model) ? model : nullptr;
}
} // namespace mindspore::lite

+ 2
- 0
mindspore/lite/src/model_common.h View File

@@ -32,6 +32,8 @@ int NodeVerify(const Model &model);

int SubGraphVerify(const Model &model);

bool ModelVerify(const Model &model);

Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_

Loading…
Cancel
Save