|
|
|
@@ -1238,7 +1238,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; |
|
|
|
int32_t index = 0; |
|
|
|
std::vector<KernelGraphPtr> child_graphs; |
|
|
|
auto start_label = graph->get_start_label(); |
|
|
|
auto start_label_id = AnfAlgo::GetNodeAttr<uint32_t>(graph->get_start_label(), kAttrLabelIndex); |
|
|
|
auto end_node = graph->get_end_goto(); |
|
|
|
ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); |
|
|
|
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), |
|
|
|
@@ -1247,9 +1247,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
auto kg = graphs_[graph_id]; |
|
|
|
auto nodes = kg->execution_order(); |
|
|
|
for (uint32_t i = 0; i < nodes.size(); i++) { |
|
|
|
if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName && |
|
|
|
(AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) == |
|
|
|
AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) { |
|
|
|
if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) { |
|
|
|
if (i < (nodes.size() - 1)) { |
|
|
|
new_inputs.push_back(nodes[i + 1]); |
|
|
|
} else { |
|
|
|
|