|
|
|
@@ -24,6 +24,7 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
int GetPrimitiveType(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
if (primitive == nullptr) { |
|
|
|
return -1; |
|
|
|
} |
|
|
|
@@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) { |
|
|
|
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; } |
|
|
|
|
|
|
|
bool IsPartialNode(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); |
|
|
|
if (schema_version == SCHEMA_CUR) { |
|
|
|
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion; |
|
|
|
@@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) { |
|
|
|
} |
|
|
|
|
|
|
|
int GetPartialGraphIndex(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
int index = -1; |
|
|
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); |
|
|
|
if (schema_version == SCHEMA_CUR) { |
|
|
|
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr); |
|
|
|
index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index(); |
|
|
|
} |
|
|
|
#ifdef ENABLE_V0 |
|
|
|
@@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) { |
|
|
|
} |
|
|
|
|
|
|
|
bool IsWhileNode(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); |
|
|
|
if (schema_version == SCHEMA_CUR) { |
|
|
|
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While; |
|
|
|
@@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) { |
|
|
|
} |
|
|
|
|
|
|
|
int GetWhileBodySubgraphIndex(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
int index = -1; |
|
|
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); |
|
|
|
if (schema_version == SCHEMA_CUR) { |
|
|
|
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); |
|
|
|
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index(); |
|
|
|
} |
|
|
|
#ifdef ENABLE_V0 |
|
|
|
if (schema_version == SCHEMA_V0) { |
|
|
|
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); |
|
|
|
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex(); |
|
|
|
} |
|
|
|
#endif |
|
|
|
@@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) { |
|
|
|
} |
|
|
|
|
|
|
|
int GetWhileCondSubgraphIndex(const void *primitive) { |
|
|
|
MS_ASSERT(primitive != nullptr); |
|
|
|
int index = -1; |
|
|
|
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); |
|
|
|
if (schema_version == SCHEMA_CUR) { |
|
|
|
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); |
|
|
|
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index(); |
|
|
|
} |
|
|
|
#ifdef ENABLE_V0 |
|
|
|
if (schema_version == SCHEMA_V0) { |
|
|
|
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); |
|
|
|
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex(); |
|
|
|
} |
|
|
|
#endif |
|
|
|
|