diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 367e7d2ec6..00fab503e2 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -63,7 +63,7 @@ static void RecursiveReplaceNode(NotNull kg, NotNull } for (auto &child : kg->child_graph_order()) { - RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); + RecursiveReplaceNode(NOT_NULL(child.lock()), main_parameter, parameter_reuse_set, memo); } } @@ -177,13 +177,14 @@ void AscendControlParser::AttachChildGraphToReturnNode(NotNull g return; } memo->insert(graph.get()); - const std::vector> &child_graph_order = graph->child_graph_order(); + const std::vector> &child_graph_order = graph->child_graph_order(); if (child_graph_order.empty()) { return; } std::vector depend_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; - for (auto &cg : child_graph_order) { + for (auto &kg : child_graph_order) { + std::shared_ptr cg = kg.lock(); MS_EXCEPTION_IF_NULL(cg); auto fg = cg->cast(); MS_EXCEPTION_IF_NULL(fg); @@ -207,7 +208,7 @@ void AscendControlParser::LinkGraph(NotNull kg) { memo.clear(); // assign label resource device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); - AttachChildGraphToReturnNode(kg, NOT_NULL(&memo)); + // AttachChildGraphToReturnNode(kg, NOT_NULL(&memo)); } void AscendControlParser::EraseParameter(NotNull root_graph, @@ -428,7 +429,7 @@ void AscendControlParser::ChildGraphDataAssign( } kg->SetExecOrderByDefault(); for (auto &child_graph : kg->child_graph_order()) { - ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo); + ChildGraphDataAssign(NOT_NULL(child_graph.lock()), link_list, memo); } } @@ -772,7 +773,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); } auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph.lock()), memo); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); }; @@ -791,6 +792,10 @@ std::vector AscendControlParser::RecurseGraph(NotNull uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); recurse_child_graph(child_graph_index, label_index, node); } + // erase kAttrChildGraph after finish using + if (AnfAlgo::HasNodeAttr(kAttrChildGraph, node)) { + AnfAlgo::EraseNodeAttr(kAttrChildGraph, node); + } } graph->set_execution_order(execution_order); graph->PrintGraphExecuteOrder(); diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 7d7a24106e..4d4ad4ad06 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -860,7 +860,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu memo->insert(graph.get()); graph->UpdateChildGraphOrder(); for (auto &child_graph : graph->child_graph_order()) { - CreateMultiBranchOutput(NOT_NULL(child_graph), memo); + CreateMultiBranchOutput(NOT_NULL(child_graph.lock()), memo); } std::map need_replace_list; auto node_list = GetCNodes(TopoSort(graph->get_return())); @@ -932,7 +932,7 @@ void AscendSession::IrFusionPass(const NotNull graph, NotNullchild_graph_order()) { - IrFusionPass(NOT_NULL(child_graph), memo); + IrFusionPass(NOT_NULL(child_graph.lock()), memo); } } @@ -1016,7 +1016,7 @@ void AscendSession::HardwareOptimize(NotNull graph, HardwareOptimize(graph.get()); for (auto &child_graph : graph->child_graph_order()) { - HardwareOptimize(NOT_NULL(child_graph), memo); + HardwareOptimize(NOT_NULL(child_graph.lock()), memo); } MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id(); } @@ -1035,7 +1035,7 @@ void AscendSession::AssignStaticMemory(NotNull graph, runtime_instance->AssignStaticMemoryInput(graph.get().get()); runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); for (auto &child_graph : graph->child_graph_order()) { - AssignStaticMemory(NOT_NULL(child_graph), memo); + AssignStaticMemory(NOT_NULL(child_graph.lock()), memo); } MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id(); } @@ -1048,9 +1048,11 @@ void AscendSession::UpdateRefOutputMap(NotNull graph, memo->insert(graph.get()); for (auto &child_graph : graph->child_graph_order()) { - UpdateRefOutputMap(NOT_NULL(child_graph), memo); + std::shared_ptr child_graph_ptr = child_graph.lock(); + MS_EXCEPTION_IF_NULL(child_graph_ptr); + UpdateRefOutputMap(NOT_NULL(child_graph_ptr), memo); // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); + auto child_ref_map = child_graph_ptr->GetRefMap(); for (auto &item : child_ref_map) { if (graph->IsInRefOutputMap(item.first)) { MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 9cba4ba6de..3c856bf06c 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -922,8 +922,9 @@ std::vector> KernelGraph::GetLeafGraphOrder() { leaf_graph_order.push_back(shared_from_this()->cast()); } else { for (const auto &child_graph : child_graph_order_) { - MS_EXCEPTION_IF_NULL(child_graph); - auto child_leaf_graph_order = child_graph->GetLeafGraphOrder(); + std::shared_ptr child_graph_ptr = child_graph.lock(); + MS_EXCEPTION_IF_NULL(child_graph_ptr); + auto child_leaf_graph_order = child_graph_ptr->GetLeafGraphOrder(); std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order)); } } @@ -1103,13 +1104,13 @@ void KernelGraph::UpdateChildGraphOrder() { SetExecOrderByDefault(); auto call_nodes = FindNodeByPrimitive( {std::make_shared(prim::kPrimCall->name()), std::make_shared(prim::kPrimSwitch->name())}); - std::vector child_graph_order; + std::vector> child_graph_order; for (auto &call_node : call_nodes) { MS_EXCEPTION_IF_NULL(call_node); auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast()); for (const auto &child_graph : call_child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); - if (child_graph != parent_graph_) { + if (child_graph != parent_graph_.lock()) { auto shared_this = std::dynamic_pointer_cast(shared_from_this()); MS_EXCEPTION_IF_NULL(shared_this); child_graph->set_parent_graph(shared_this); @@ -1118,7 +1119,9 @@ void KernelGraph::UpdateChildGraphOrder() { } } for (size_t i = 0; i < child_graph_order.size(); ++i) { - MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; + std::shared_ptr child_graph = child_graph_order[i].lock(); + MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph->graph_id() << "]"; } child_graph_order_ = child_graph_order; } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 46e38b4ac6..4c830bc861 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -114,8 +114,8 @@ class KernelGraph : public FuncGraph { // calculate the leaf graph order of root graph std::vector> GetLeafGraphOrder(); // the child graph of current graph - const std::vector> &child_graph_order() const { return child_graph_order_; } - void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } + const std::vector> &child_graph_order() const { return child_graph_order_; } + void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } // checkout whether current graph is leaf graph bool IsLeafGraph() const; @@ -126,9 +126,9 @@ class KernelGraph : public FuncGraph { // get input_tensors pointer of control parameter std::shared_ptr> input_ctrl_tensors() const { return input_ctrl_tensors_; } // get parent kernel graph - std::shared_ptr parent_graph() const { return parent_graph_; } + std::weak_ptr parent_graph() const { return parent_graph_; } // set parent kernel graph - void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } + void set_parent_graph(const std::weak_ptr &parent_graph) { parent_graph_ = parent_graph; } // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; std::vector FindNodeByPrimitive(const std::vector &primitive_list) const; @@ -227,13 +227,13 @@ class KernelGraph : public FuncGraph { std::vector valid_inputs_; // child graph execute order in root graph - std::vector> child_graph_order_; + std::vector> child_graph_order_; // input_tensors of control parameter std::shared_ptr> input_ctrl_tensors_; // parameter graph - std::shared_ptr parent_graph_; + std::weak_ptr parent_graph_; CNodePtr start_label_; CNodePtr end_goto_; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc index b15df0d60b..78a964027f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc @@ -89,7 +89,7 @@ static void AssignLabelForLabelSet(NotNull } for (auto &cg : graph->child_graph_order()) { - AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); + AssignLabelForLabelSet(NOT_NULL(cg.lock()), label_id, memo); } } @@ -120,7 +120,7 @@ static void AssignLabelForGotoSwitch(NotNullchild_graph_order()) { - AssignLabelForGotoSwitch(NOT_NULL(cg), memo); + AssignLabelForGotoSwitch(NOT_NULL(cg.lock()), memo); } }