diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 416ea49e63..331e0cd1d0 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -19,6 +19,17 @@ #include "session/ascend_control_parser.h" #include "session/anf_runtime_algorithm.h" +static constexpr size_t kCNodePrim = 0; +static constexpr size_t kCNodeCallArg = 1; +static constexpr size_t kCNodeSwitchCond = 1; +static constexpr size_t kCNodeSwitchTrue = 2; +static constexpr size_t kCNodeSwitchFalse = 3; +static constexpr size_t kCNodeSwitchLength = 4; +static constexpr size_t kCNodePartialLength = 2; +static constexpr size_t kCNodePartialFunc = 1; +static constexpr size_t kCNodeSwitchLayerBranch = 2; +static constexpr size_t kCNodeSwitchLayerLength = 3; + namespace mindspore { namespace session { @@ -61,7 +72,7 @@ void AscendControlParser::LinkGraph(NotNull kg) { ChildGraphDataAssign(graph_id_map); } -CNodePtr AscendControlParser::GetNextRealKernel(std::vector list, size_t start) { +CNodePtr AscendControlParser::GetNextRealKernel(const std::vector &list, size_t start) { for (size_t i = start; i < list.size() - 1; ++i) { if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { return list[i]; @@ -83,11 +94,11 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullinsert(kg.get()); // 2. args replace placeholder - LinkParentGraph(kg, last_node, last_label, memo); + LinkParentGraph(kg, last_node, last_label); // 3. topological sort kg->SetExecOrderByDefault(); - std::vector nodes = kg->execution_order(); + const std::vector &nodes = kg->execution_order(); if (nodes.empty()) { MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; } @@ -149,9 +160,9 @@ void AscendControlParser::InsertControlDependToGraph(NotNull kg, } void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label, NotNull *> memo) { + const CNodePtr &last_label) { auto origin_return = kg->get_return(); - std::vector origin_return_inputs = origin_return->inputs(); + const std::vector &origin_return_inputs = origin_return->inputs(); // if entry graph, replace return with make_tuple if (from_graph_call_node == nullptr || last_label == nullptr) { MS_LOG(INFO) << kg->ToString() << " is entry graph."; @@ -173,7 +184,7 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNullDebugString(); // 1 get kernel graph - auto origin_inputs = cur_node->inputs(); + const std::vector &origin_inputs = cur_node->inputs(); std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; if (!IsValueNode(origin_inputs[kCNodeCallArg])) { MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; @@ -217,15 +228,14 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNullinputs(); + const std::vector &origin_switch_inputs = cur_node->inputs(); std::vector new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { // 3.1 branch kernel graph and args - CNodePtr partial; KernelGraphPtr branch_fg; - std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -249,9 +259,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); MS_EXCEPTION_IF_NULL(branch_tuple); if (!branch_tuple->isa()) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; + MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; } - auto branch_partial = utils::cast(branch_tuple)->inputs(); + const std::vector &branch_partial = utils::cast(branch_tuple)->inputs(); // 1 return label auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); // 2 add depend relationship @@ -260,15 +270,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); } // 3 recurse sub graph - auto origin_switch_inputs = cur_node->inputs(); + const std::vector &origin_switch_inputs = cur_node->inputs(); std::vector new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = 0; i < branch_partial.size(); ++i) { // 3.1 branch kernel graph and args - CNodePtr partial; KernelGraphPtr branch_fg; - std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); @@ -315,18 +324,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul InsertDependToGraph(kg, NOT_NULL(assign_node)); } -NotNull AscendControlParser::GetRealInput(NotNull from_graph, - NotNull to_graph, NotNull param) { - std::set args_list = to_graph->GetRealInput(param); - for (auto arg : args_list) { - if (arg->func_graph() == from_graph.get()) { - return NOT_NULL(arg); - } - } - MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from " - << from_graph->ToString(); -} - void AscendControlParser::LinkArgsToParam(NotNull to_graph, NotNull target_graph, NotNull arg, NotNull param) { if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { @@ -369,10 +366,10 @@ std::vector AscendControlParser::RecurseGraph(const CNodePtr &cur_labe return {}; } memo->insert(graph.get()); - + const std::vector> &child_graph_order = graph->child_graph_order(); graph->SetExecOrderByDefault(); - std::vector cnodes = graph->execution_order(); + const std::vector &cnodes = graph->execution_order(); std::map label_map; std::map> label_switch_map; std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); @@ -388,10 +385,10 @@ std::vector AscendControlParser::RecurseGraph(const CNodePtr &cur_labe std::find_if(label_map.begin(), label_map.end(), [node](const std::map::value_type iter) { return iter.second == node; }); if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { - if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { + if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { MS_LOG(EXCEPTION) << "Check label index fail"; } - auto child_graph = graph->child_graph_order()[label_iter->first]; + auto child_graph = child_graph_order[label_iter->first]; if (child_graph == graph->parent_graph()) { continue; } @@ -407,7 +404,7 @@ std::vector AscendControlParser::RecurseGraph(const CNodePtr &cur_labe if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { MS_LOG(EXCEPTION) << "Check label index fail"; } - auto child_graph = graph->child_graph_order()[label_iter->first + i]; + auto child_graph = child_graph_order[label_iter->first + i]; if (child_graph == graph->parent_graph()) { continue; } @@ -426,10 +423,11 @@ std::vector AscendControlParser::RecurseGraph(const CNodePtr &cur_labe bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, NotNull graph) { + const std::vector> &child_graph_order = graph->child_graph_order(); // check index and child order size - if (graph->child_graph_order().size() <= static_cast(order_index)) { + if (child_graph_order.size() <= IntToSize(order_index)) { MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " - << graph->child_graph_order().size() << " goto index " << order_index; + << child_graph_order.size() << " goto index " << order_index; } if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { @@ -443,7 +441,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i label_index = label_goto_index; } // get start_label_set_index of child graph - auto child_graph = graph->child_graph_order()[order_index]; + auto child_graph = child_graph_order[order_index]; MS_EXCEPTION_IF_NULL(child_graph); auto start_label_set = child_graph->get_start_label(); if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { @@ -468,8 +466,7 @@ std::tuple, std::map, std::map label_list = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); label_switch_map.insert({node, label_list}); for (size_t i = 0; i < label_list.size(); ++i) { - label_map[index] = node; - ++index; + label_map[index++] = node; } } } diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 0f08d39c82..0201d27618 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -49,16 +49,15 @@ class AscendControlParser { NotNull *> memo); static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label, NotNull *> memo); + const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); static void LinkArgsToParam(NotNull to_graph, NotNull target_graph, NotNull arg, NotNull param); - static NotNull GetRealInput(NotNull from_graph, NotNull to_graph, - NotNull param); + static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - static CNodePtr GetNextRealKernel(std::vector list, size_t start); + static CNodePtr GetNextRealKernel(const std::vector &list, size_t start); // root graph order static std::tuple, std::map>> GetLabelNode( @@ -67,20 +66,7 @@ class AscendControlParser { NotNull graph); static std::vector RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, NotNull graph, NotNull *> memo); - - static constexpr size_t kCNodePrim = 0; - static constexpr size_t kCNodeCallArg = 1; - static constexpr size_t kCNodeSwitchCond = 1; - static constexpr size_t kCNodeSwitchTrue = 2; - static constexpr size_t kCNodeSwitchFalse = 3; - static constexpr size_t kCNodeSwitchLength = 4; - static constexpr size_t kCNodePartialLength = 2; - static constexpr size_t kCNodePartialFunc = 1; - static constexpr size_t kCNodeSwitchLayerCond = 1; - static constexpr size_t kCNodeSwitchLayerBranch = 2; - static constexpr size_t kCNodeSwitchLayerLength = 3; }; - } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index b99c99443d..ef6ab83727 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -256,7 +256,6 @@ static void UpdateRealInput(KernelGraph *graph) { void RecurseToUpdateCallRealInput(KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "start graph id:" << graph->graph_id(); - graph->UpdateCallRealInput(); for (auto &child_graph : graph->child_graph_order()) { if (child_graph == graph->parent_graph()) { MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() @@ -265,6 +264,8 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) { } RecurseToUpdateCallRealInput(child_graph.get()); } + // this action should from bottom to top + graph->UpdateCallRealInput(); } } // namespace @@ -280,27 +281,20 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL GraphId AscendSession::CompileGraph(NotNull func_graph) { MS_LOG(INFO) << "start"; auto graph = ConstructKernelGraph(func_graph); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // split switch SplitGraphs(graph); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // insert goto labels and label_sets LinkChildGraphs(NOT_NULL(graph)); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // resource initialize InitRuntimeResource(); // assign label AssignLabel(NOT_NULL(graph)); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // recurse compile child graph RecurseCompileGraph(graph); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // root graph valiate,include genearte execute order and so on RootGraphExecutorValidate(NOT_NULL(graph)); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // adjust kernel AdjustKernel(graph); - MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // assign stream AssignStream(graph); // build kernel @@ -313,7 +307,6 @@ GraphId AscendSession::CompileGraph(NotNull func_graph) { LoadTask(graph); // return the graph id to backend auto graph_id = graph->graph_id(); - MS_LOG(INFO) << "Compile graph " << graph_id << " success"; return graph_id; } diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 1db932cd30..09cab04f55 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -606,10 +606,6 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf break; } } - MS_LOG(INFO) << "Inputs of graph id:" << graph_id(); - for (size_t i = 0; i < inputs().size(); i++) { - MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString(); - } } // update front to backend map FrontBackendlMapUpdate(old_anf_node, new_anf_node); @@ -713,6 +709,9 @@ void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " insert real input:" << new_real_input->DebugString(); (void)real_inputs.insert(new_real_input); + if (new_real_input->isa()) { + ReplaceNode(parameter, new_real_input); + } } } }