Browse Source

add model check for session

tags/v1.1.0
yvette 5 years ago
parent
commit
e71c433568
3 changed files with 8 additions and 10 deletions
  1. +5
    -0
      mindspore/lite/src/lite_session.cc
  2. +2
    -9
      mindspore/lite/src/model_common.cc
  3. +1
    -1
      mindspore/lite/src/model_common.h

+ 5
- 0
mindspore/lite/src/lite_session.cc View File

@@ -291,6 +291,11 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
} }


int LiteSession::CompileGraph(Model *model) { int LiteSession::CompileGraph(Model *model) {
if (!ModelVerify(*model)) {
MS_LOG(ERROR) << "wrong model input, please check";
return RET_ERROR;
}

bool expected = false; bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) { if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading"; MS_LOG(ERROR) << "Not support multi-threading";


+ 2
- 9
mindspore/lite/src/model_common.cc View File

@@ -138,14 +138,7 @@ int SubGraphVerify(const Model &model) {
return RET_OK; return RET_OK;
} }


int ModelVerify(const Model &model, const int &schema_version) {
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK;
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
return NodeVerify(model) == RET_OK;
}
return RET_ERROR;
}
bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; }


const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
if (buf == nullptr) { if (buf == nullptr) {
@@ -230,6 +223,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
return nullptr; return nullptr;
} }


return ModelVerify(*model, schema_version) ? model : nullptr;
return ModelVerify(*model) ? model : nullptr;
} }
} // namespace mindspore::lite } // namespace mindspore::lite

+ 1
- 1
mindspore/lite/src/model_common.h View File

@@ -181,7 +181,7 @@ int NodeVerify(const Model &model);


int SubGraphVerify(const Model &model); int SubGraphVerify(const Model &model);


int ModelVerify(const Model &model, const int &schema_version);
bool ModelVerify(const Model &model);


const void *GetMetaGraphByVerison(const char *buf, const int &schema_version); const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);




Loading…
Cancel
Save