From: @lyvette Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -291,6 +291,11 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { | |||||
| } | } | ||||
| int LiteSession::CompileGraph(Model *model) { | int LiteSession::CompileGraph(Model *model) { | ||||
| if (!ModelVerify(*model)) { | |||||
| MS_LOG(ERROR) << "wrong model input, please check"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| bool expected = false; | bool expected = false; | ||||
| if (!is_running_.compare_exchange_strong(expected, true)) { | if (!is_running_.compare_exchange_strong(expected, true)) { | ||||
| MS_LOG(ERROR) << "Not support multi-threading"; | MS_LOG(ERROR) << "Not support multi-threading"; | ||||
| @@ -138,14 +138,7 @@ int SubGraphVerify(const Model &model) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ModelVerify(const Model &model, const int &schema_version) { | |||||
| if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { | |||||
| return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; | |||||
| } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { | |||||
| return NodeVerify(model) == RET_OK; | |||||
| } | |||||
| return RET_ERROR; | |||||
| } | |||||
| bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; } | |||||
| const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { | const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { | ||||
| if (buf == nullptr) { | if (buf == nullptr) { | ||||
| @@ -230,6 +223,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return ModelVerify(*model, schema_version) ? model : nullptr; | |||||
| return ModelVerify(*model) ? model : nullptr; | |||||
| } | } | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -181,7 +181,7 @@ int NodeVerify(const Model &model); | |||||
| int SubGraphVerify(const Model &model); | int SubGraphVerify(const Model &model); | ||||
| int ModelVerify(const Model &model, const int &schema_version); | |||||
| bool ModelVerify(const Model &model); | |||||
| const void *GetMetaGraphByVerison(const char *buf, const int &schema_version); | const void *GetMetaGraphByVerison(const char *buf, const int &schema_version); | ||||