diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index b44262b597..fb156ebb63 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -771,16 +771,13 @@ std::vector AscendControlParser::RecurseGraph(NotNull } std::vector execution_order; - uint32_t child_order_index = 0; auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) { - if (!CheckLabelIndex(index, label_index, node)) { + KernelGraphPtr cur_child_graph; + if (!CheckLabelIndex(index, label_index, node, &cur_child_graph)) { MS_LOG(EXCEPTION) << "Check label index fail"; } - if (child_order_index >= graph->child_graph_order().size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph.lock()), memo); + MS_EXCEPTION_IF_NULL(cur_child_graph); + auto child_execution_order = RecurseGraph(NOT_NULL(cur_child_graph), memo); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); }; @@ -809,18 +806,19 @@ std::vector AscendControlParser::RecurseGraph(NotNull return execution_order; } -bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { +bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label, + KernelGraphPtr *cur_child_graph) { auto child_graphs = AnfAlgo::GetNodeAttr>(cur_label, kAttrChildGraph); // check index and child order size if (child_graphs.size() <= IntToSize(index)) { MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " << child_graphs.size() << " goto index " << index; } - auto child_graph = child_graphs[index]; - MS_EXCEPTION_IF_NULL(child_graph); + *cur_child_graph = child_graphs[index]; + MS_EXCEPTION_IF_NULL(*cur_child_graph); // get start_label_set_index of child graph - auto start_label_set = child_graph->get_start_label(); + auto start_label_set = (*cur_child_graph)->get_start_label(); uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); if (label_index != start_label_set_index) { MS_EXCEPTION_IF_NULL(cur_label); diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index ddecd0a81a..096c9b4406 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -74,7 +74,8 @@ class AscendControlParser { static void AttachChildGraphToReturnNode(NotNull graph, const NotNull *> memo); // root graph order - static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); + static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode, + KernelGraphPtr *cur_child_graph); static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); static void AttachOriginalInputsToGraph(NotNull graph, const std::vector orig_inputs); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 25f54f1a58..ef9b51ce45 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -240,7 +240,7 @@ class KernelGraph : public FuncGraph { // valid inputs std::vector valid_inputs_; - // child graph execute order in root graph + // child graph execute order in parent graph std::vector> child_graph_order_; // input_tensors of control parameter