From 5fc9eb7f6384857a8f1df419aa34874a5234f264 Mon Sep 17 00:00:00 2001 From: yvette Date: Fri, 4 Dec 2020 15:20:59 +0800 Subject: [PATCH] add model check for session --- mindspore/lite/src/lite_session.cc | 4 +++ mindspore/lite/src/model_common.cc | 54 ++++++++++++++++-------------- mindspore/lite/src/model_common.h | 2 ++ 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 67a8c0cbd0..5ab84d8f80 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -293,6 +293,10 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { } int LiteSession::CompileGraph(Model *model) { + if (!ModelVerify(*model)) { + MS_LOG(ERROR) << "Wrong model input"; + return RET_ERROR; + } bool expected = false; if (!is_running_.compare_exchange_strong(expected, true)) { MS_LOG(ERROR) << "Not support multi-threading"; diff --git a/mindspore/lite/src/model_common.cc b/mindspore/lite/src/model_common.cc index 06b861716b..1b5b4d1610 100644 --- a/mindspore/lite/src/model_common.cc +++ b/mindspore/lite/src/model_common.cc @@ -201,35 +201,39 @@ int SubGraphVerify(const Model &model) { auto tensor_size = model.all_tensors_.size(); auto node_size = model.all_nodes_.size(); - for (auto &graph : model.sub_graphs_) { - if (graph == nullptr) { - MS_LOG(ERROR) << "graph is null."; - return RET_ERROR; - } - if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(), - [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { - MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size."; - return RET_ERROR; - } - if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(), - [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { - MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size."; - return RET_ERROR; - } - if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(), - [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { - MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size."; - return RET_ERROR; - } - if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(), - [&node_size](const uint32_t &idx) { return idx >= node_size; })) { - MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size."; - return RET_ERROR; + if (!model.sub_graphs_.empty()) { + for (auto &graph : model.sub_graphs_) { + if (graph == nullptr) { + MS_LOG(ERROR) << "graph is null."; + return RET_ERROR; + } + if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(), + [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { + MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size."; + return RET_ERROR; + } + if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(), + [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { + MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size."; + return RET_ERROR; + } + if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(), + [&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { + MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size."; + return RET_ERROR; + } + if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(), + [&node_size](const uint32_t &idx) { return idx >= node_size; })) { + MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size."; + return RET_ERROR; + } } } return RET_OK; } +bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; } + Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { if (model_buf == nullptr) { MS_LOG(ERROR) << "The model buf is nullptr"; @@ -314,6 +318,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { delete model; return nullptr; } - return NodeVerify(*model) == RET_OK && SubGraphVerify(*model) == RET_OK ? model : nullptr; + return ModelVerify(*model) ? model : nullptr; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/model_common.h b/mindspore/lite/src/model_common.h index dfcb21e477..1faa1dff80 100644 --- a/mindspore/lite/src/model_common.h +++ b/mindspore/lite/src/model_common.h @@ -32,6 +32,8 @@ int NodeVerify(const Model &model); int SubGraphVerify(const Model &model); +bool ModelVerify(const Model &model); + Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_