Browse Source

add model check for session

tags/v1.1.0
yvette 5 years ago
parent
commit
e71c433568
3 changed files with 8 additions and 10 deletions
  1. +5
    -0
      mindspore/lite/src/lite_session.cc
  2. +2
    -9
      mindspore/lite/src/model_common.cc
  3. +1
    -1
      mindspore/lite/src/model_common.h

+ 5
- 0
mindspore/lite/src/lite_session.cc View File

@@ -291,6 +291,11 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
}

int LiteSession::CompileGraph(Model *model) {
if (!ModelVerify(*model)) {
MS_LOG(ERROR) << "wrong model input, please check";
return RET_ERROR;
}

bool expected = false;
if (!is_running_.compare_exchange_strong(expected, true)) {
MS_LOG(ERROR) << "Not support multi-threading";


+ 2
- 9
mindspore/lite/src/model_common.cc View File

@@ -138,14 +138,7 @@ int SubGraphVerify(const Model &model) {
return RET_OK;
}

int ModelVerify(const Model &model, const int &schema_version) {
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK;
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
return NodeVerify(model) == RET_OK;
}
return RET_ERROR;
}
bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; }

const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
if (buf == nullptr) {
@@ -230,6 +223,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
return nullptr;
}

return ModelVerify(*model, schema_version) ? model : nullptr;
return ModelVerify(*model) ? model : nullptr;
}
} // namespace mindspore::lite

+ 1
- 1
mindspore/lite/src/model_common.h View File

@@ -181,7 +181,7 @@ int NodeVerify(const Model &model);

int SubGraphVerify(const Model &model);

int ModelVerify(const Model &model, const int &schema_version);
bool ModelVerify(const Model &model);

const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);



Loading…
Cancel
Save