|
|
|
@@ -771,16 +771,13 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CNodePtr> execution_order; |
|
|
|
uint32_t child_order_index = 0; |
|
|
|
auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) { |
|
|
|
if (!CheckLabelIndex(index, label_index, node)) { |
|
|
|
KernelGraphPtr cur_child_graph; |
|
|
|
if (!CheckLabelIndex(index, label_index, node, &cur_child_graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
if (child_order_index >= graph->child_graph_order().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph.lock()), memo); |
|
|
|
MS_EXCEPTION_IF_NULL(cur_child_graph); |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(cur_child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -809,18 +806,19 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> |
|
|
|
return execution_order; |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { |
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label, |
|
|
|
KernelGraphPtr *cur_child_graph) { |
|
|
|
auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph); |
|
|
|
// check index and child order size |
|
|
|
if (child_graphs.size() <= IntToSize(index)) { |
|
|
|
MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " |
|
|
|
<< child_graphs.size() << " goto index " << index; |
|
|
|
} |
|
|
|
auto child_graph = child_graphs[index]; |
|
|
|
MS_EXCEPTION_IF_NULL(child_graph); |
|
|
|
*cur_child_graph = child_graphs[index]; |
|
|
|
MS_EXCEPTION_IF_NULL(*cur_child_graph); |
|
|
|
|
|
|
|
// get start_label_set_index of child graph |
|
|
|
auto start_label_set = child_graph->get_start_label(); |
|
|
|
auto start_label_set = (*cur_child_graph)->get_start_label(); |
|
|
|
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex); |
|
|
|
if (label_index != start_label_set_index) { |
|
|
|
MS_EXCEPTION_IF_NULL(cur_label); |
|
|
|
|