|
|
@@ -14,33 +14,39 @@ |
|
|
* limitations under the License. |
|
|
* limitations under the License. |
|
|
*/ |
|
|
*/ |
|
|
#include "src/model_common.h" |
|
|
#include "src/model_common.h" |
|
|
|
|
|
#include "src/ops/while.h" |
|
|
|
|
|
|
|
|
namespace mindspore::lite { |
|
|
namespace mindspore::lite { |
|
|
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { |
|
|
int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { |
|
|
MS_ASSERT(model != nullptr); |
|
|
|
|
|
|
|
|
if (model == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "model is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr || |
|
|
|
|
|
sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "sub_graph is invalid."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
auto *subgraph = new (std::nothrow) Model::SubGraph(); |
|
|
auto *subgraph = new (std::nothrow) Model::SubGraph(); |
|
|
if (subgraph == nullptr) { |
|
|
if (subgraph == nullptr) { |
|
|
MS_LOG(ERROR) << "new subGraph fail!"; |
|
|
MS_LOG(ERROR) << "new subGraph fail!"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
MS_ASSERT(sub_graph.name() != nullptr); |
|
|
|
|
|
|
|
|
|
|
|
subgraph->name_ = sub_graph.name()->c_str(); |
|
|
subgraph->name_ = sub_graph.name()->c_str(); |
|
|
MS_ASSERT(sub_graph.inputIndices() != nullptr); |
|
|
|
|
|
auto in_count = sub_graph.inputIndices()->size(); |
|
|
auto in_count = sub_graph.inputIndices()->size(); |
|
|
for (uint32_t i = 0; i < in_count; ++i) { |
|
|
for (uint32_t i = 0; i < in_count; ++i) { |
|
|
subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i)); |
|
|
subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i)); |
|
|
} |
|
|
} |
|
|
MS_ASSERT(sub_graph.outputIndices() != nullptr); |
|
|
|
|
|
auto out_count = sub_graph.outputIndices()->size(); |
|
|
auto out_count = sub_graph.outputIndices()->size(); |
|
|
for (uint32_t i = 0; i < out_count; ++i) { |
|
|
for (uint32_t i = 0; i < out_count; ++i) { |
|
|
subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i)); |
|
|
subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i)); |
|
|
} |
|
|
} |
|
|
MS_ASSERT(sub_graph.nodeIndices() != nullptr); |
|
|
|
|
|
auto node_count = sub_graph.nodeIndices()->size(); |
|
|
auto node_count = sub_graph.nodeIndices()->size(); |
|
|
for (uint32_t i = 0; i < node_count; ++i) { |
|
|
for (uint32_t i = 0; i < node_count; ++i) { |
|
|
subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i)); |
|
|
subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i)); |
|
|
} |
|
|
} |
|
|
MS_ASSERT(sub_graph.tensorIndices() != nullptr); |
|
|
|
|
|
auto tensor_count = sub_graph.tensorIndices()->size(); |
|
|
auto tensor_count = sub_graph.tensorIndices()->size(); |
|
|
for (uint32_t i = 0; i < tensor_count; ++i) { |
|
|
for (uint32_t i = 0; i < tensor_count; ++i) { |
|
|
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i)); |
|
|
subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i)); |
|
|
@@ -50,6 +56,10 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int VersionVerify(flatbuffers::Verifier *verify) { |
|
|
int VersionVerify(flatbuffers::Verifier *verify) { |
|
|
|
|
|
if (verify == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "verify is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
if (schema::VerifyMetaGraphBuffer(*verify)) { |
|
|
if (schema::VerifyMetaGraphBuffer(*verify)) { |
|
|
return SCHEMA_VERSION::SCHEMA_CUR; |
|
|
return SCHEMA_VERSION::SCHEMA_CUR; |
|
|
} else if (schema::v0::VerifyMetaGraphBuffer(*verify)) { |
|
|
} else if (schema::v0::VerifyMetaGraphBuffer(*verify)) { |
|
|
@@ -58,8 +68,90 @@ int VersionVerify(flatbuffers::Verifier *verify) { |
|
|
return SCHEMA_VERSION::SCHEMA_INVALID; |
|
|
return SCHEMA_VERSION::SCHEMA_INVALID; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int NodeVerify(const Model &model) { |
|
|
|
|
|
auto tensor_size = model.all_tensors_.size(); |
|
|
|
|
|
uint32_t subGraph_size = model.sub_graphs_.size(); |
|
|
|
|
|
|
|
|
|
|
|
for (auto &node : model.all_nodes_) { |
|
|
|
|
|
if (node == nullptr || node->primitive_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "node or its primitive_ is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(), |
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of node->input_indices_ is beyond size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(), |
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of node->output_indices_ is beyond size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto prim = node->primitive_; |
|
|
|
|
|
if (prim->Type() == schema::PrimitiveType_While) { |
|
|
|
|
|
auto whileOp = reinterpret_cast<mindspore::lite::While *>(const_cast<mindspore::lite::PrimitiveC *>(prim)); |
|
|
|
|
|
if (whileOp == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "whileOp is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (static_cast<uint32_t>(whileOp->GetBodySubgraphIndex()) >= subGraph_size || |
|
|
|
|
|
static_cast<uint32_t>(whileOp->GetCondSubgraphIndex()) >= subGraph_size) { |
|
|
|
|
|
MS_LOG(ERROR) << "index of subGraph is beyond subGraph_size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int SubGraphVerify(const Model &model) { |
|
|
|
|
|
auto tensor_size = model.all_tensors_.size(); |
|
|
|
|
|
auto node_size = model.all_nodes_.size(); |
|
|
|
|
|
|
|
|
|
|
|
for (auto &graph : model.sub_graphs_) { |
|
|
|
|
|
if (graph == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "graph is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(graph->input_indices_.begin(), graph->input_indices_.end(), |
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->input_indices_ is beyond tensor_size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(graph->output_indices_.begin(), graph->output_indices_.end(), |
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->output_indices_ is beyond tensor_size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(graph->tensor_indices_.begin(), graph->tensor_indices_.end(), |
|
|
|
|
|
[&tensor_size](const uint32_t &idx) { return idx >= tensor_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->tensor_indices_ is beyond tensor_size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
if (std::any_of(graph->node_indices_.begin(), graph->node_indices_.end(), |
|
|
|
|
|
[&node_size](const uint32_t &idx) { return idx >= node_size; })) { |
|
|
|
|
|
MS_LOG(ERROR) << "Index of graph->node_indices_ is beyond node_size."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int ModelVerify(const Model &model, const int &schema_version) { |
|
|
|
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { |
|
|
|
|
|
return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; |
|
|
|
|
|
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { |
|
|
|
|
|
return NodeVerify(model) == RET_OK; |
|
|
|
|
|
} |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { |
|
|
const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { |
|
|
MS_ASSERT(buf != nullptr); |
|
|
|
|
|
|
|
|
if (buf == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "buf is null."; |
|
|
|
|
|
return nullptr; |
|
|
|
|
|
} |
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { |
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { |
|
|
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf)); |
|
|
return reinterpret_cast<const void *>(schema::GetMetaGraph(buf)); |
|
|
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { |
|
|
} else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) { |
|
|
@@ -69,8 +161,10 @@ const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) { |
|
|
int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) { |
|
|
MS_ASSERT(meta_graph != nullptr); |
|
|
|
|
|
MS_ASSERT(model != nullptr); |
|
|
|
|
|
|
|
|
if (meta_graph == nullptr || model == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "meta_graph or model is null."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
int status = RET_ERROR; |
|
|
int status = RET_ERROR; |
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { |
|
|
if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) { |
|
|
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph), |
|
|
status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph), |
|
|
@@ -135,6 +229,7 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) { |
|
|
delete (model); |
|
|
delete (model); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
return model; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ModelVerify(*model, schema_version) ? model : nullptr; |
|
|
} |
|
|
} |
|
|
} // namespace mindspore::lite |
|
|
} // namespace mindspore::lite |