|
|
|
@@ -59,8 +59,9 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) { |
|
|
|
if (type != schema::PrimitiveType_PartialFusion) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_ASSERT(node != nullptr); |
|
|
|
MS_ASSERT(node->primitive != nullptr); |
|
|
|
MS_ASSERT(node->primitive->value..AsPartialFusion() != nullptr); |
|
|
|
MS_ASSERT(node->primitive->value.AsPartialFusion() != nullptr); |
|
|
|
auto partial_prim = node->primitive->value.AsPartialFusion(); |
|
|
|
if (partial_prim->sub_graph_index == -1) { |
|
|
|
continue; |
|
|
|
@@ -467,7 +468,9 @@ STATUS SingleSwitchPass::Init() { |
|
|
|
} |
|
|
|
|
|
|
|
// get cond_graph_nodes_ |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value..AsPartialFusion() != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_ != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value.AsPartialFusion() != nullptr); |
|
|
|
first_subgraph_index_ = first_partial_node_->primitive->value.AsPartialFusion()->sub_graph_index; |
|
|
|
auto cond_node_indices = graph_->subGraph.at(first_subgraph_index_)->nodeIndices; |
|
|
|
for (auto &index : cond_node_indices) { |
|
|
|
@@ -623,8 +626,9 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::ConcatCondSubgraphInputAndOutput() { |
|
|
|
if (first_subgraph_index_ == -1) { |
|
|
|
MS_ASSERT(first_partial_node_ != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value..AsPartialFusion() != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value.AsPartialFusion() != nullptr); |
|
|
|
first_partial_node_->primitive->value.AsPartialFusion()->sub_graph_index = -1; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -644,8 +648,9 @@ STATUS SingleSwitchPass::ConcatCondSubgraphInputAndOutput() { |
|
|
|
|
|
|
|
STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() { |
|
|
|
if (second_subgraph_index_ == -1) { |
|
|
|
MS_ASSERT(first_partial_node_ != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value..AsPartialFusion() != nullptr); |
|
|
|
MS_ASSERT(first_partial_node_->primitive->value.AsPartialFusion() != nullptr); |
|
|
|
first_partial_node_->primitive->value.AsPartialFusion()->sub_graph_index = -1; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|