Merge pull request !1459 from zhoufeng/link-assigntags/v0.5.0-beta
| @@ -28,6 +28,9 @@ namespace device { | |||
| namespace ascend { | |||
| static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | |||
| return; | |||
| } | |||
| if (node->size() <= kLabelGotoLabelId) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||
| } | |||
| @@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) { | |||
| } | |||
| static void UpdateLabelSwitch(NotNull<CNodePtr> node) { | |||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { | |||
| return; | |||
| } | |||
| if (node->size() <= kLabelGotoLabelId) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); | |||
| } | |||
| @@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph> | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "Assign label for " << graph->ToString(); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| graph->SetExecOrderByDefault(); | |||
| auto nodes = graph->execution_order(); | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap | |||
| if (memo->find(graph.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(graph.get()); | |||
| MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); | |||
| auto nodes = TopoSort(graph->get_return()); | |||
| graph->SetExecOrderByDefault(); | |||
| auto nodes = graph->execution_order(); | |||
| auto end_goto = graph->get_end_goto(); | |||
| if (end_goto != nullptr) { | |||
| nodes.push_back(end_goto); | |||
| } | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| @@ -53,6 +53,7 @@ class KernelRuntime { | |||
| virtual bool GenTask(const session::KernelGraph *graph); | |||
| bool LaunchKernel(const session::KernelGraph *graph); | |||
| virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | |||
| virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | |||
| #ifdef ENABLE_DUMP_E2E | |||
| DumpConfPtr GetDumpConf(); | |||
| @@ -67,7 +68,6 @@ class KernelRuntime { | |||
| TypeId type_id) = 0; | |||
| virtual bool SyncStream() = 0; | |||
| void AssignStaticMemory(session::KernelGraph *graph); | |||
| void AssignStaticMemoryValueNode(session::KernelGraph *graph); | |||
| void AssignDynamicMemory(session::KernelGraph *graph); | |||
| void ReuseAssignDynamicMemory(session::KernelGraph *graph); | |||
| void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); | |||
| @@ -22,49 +22,78 @@ | |||
| namespace mindspore { | |||
| namespace session { | |||
| static VectorRef GetCallArgs(std::vector<AnfNodePtr>::iterator iter_begin, std::vector<AnfNodePtr>::iterator iter_end) { | |||
| VectorRef call_args; | |||
| for (auto iter = iter_begin; iter != iter_end; ++iter) { | |||
| if (utils::isa<ValueNode>(*iter)) { | |||
| call_args.push_back(GetValueNode(*iter)); | |||
| } else { | |||
| call_args.push_back(*iter); | |||
| void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { | |||
| for (auto &iter : graph_id_map) { | |||
| auto &kg = iter.second; | |||
| MS_EXCEPTION_IF_NULL(kg); | |||
| auto real_inputs = kg->real_inputs(); | |||
| for (auto &it : real_inputs) { | |||
| auto ¶meter = it.first; | |||
| auto &args = it.second; | |||
| for (auto &arg : args) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (arg->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() | |||
| << ", arg:" << arg->DebugString(); | |||
| continue; | |||
| } | |||
| auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); | |||
| if (target_graph_iter == graph_id_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; | |||
| } | |||
| InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); | |||
| } | |||
| } | |||
| } | |||
| return call_args; | |||
| } | |||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||
| std::set<KernelGraphPtr> memo; | |||
| ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); | |||
| ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); | |||
| std::map<uint32_t, KernelGraphPtr> graph_id_map; | |||
| for (auto &g : memo) { | |||
| if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() | |||
| << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); | |||
| } | |||
| graph_id_map[g->graph_id()] = g; | |||
| } | |||
| ChildGraphDataAssign(graph_id_map); | |||
| } | |||
| // Scans `list` from index `start` and returns the first node that is a real kernel and is not a | |||
| // Partial primitive. The final element of `list` is never examined (unchanged from the original | |||
| // loop bound; presumably it is the graph's terminating node -- TODO confirm with the caller). | |||
| // Returns nullptr when no such node exists, including for an empty list or a `start` beyond the | |||
| // searchable range. | |||
| CNodePtr AscendControlParser::GetNextRealKernel(std::vector<CNodePtr> list, size_t start) { | |||
|   // size_t is unsigned, so `list.size() - 1` wraps to a huge value for an empty vector and the | |||
|   // loop would read out of bounds. Compute the exclusive upper bound without underflow instead. | |||
|   const size_t search_end = list.empty() ? 0 : list.size() - 1; | |||
|   for (size_t i = start; i < search_end; ++i) { | |||
|     if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { | |||
|       return list[i]; | |||
|     } | |||
|   } | |||
|   return nullptr; | |||
| } | |||
| NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, const VectorRef &args, | |||
| const CNodePtr &last_label, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); | |||
| // 0. recursive condition | |||
| // 1. recursive condition | |||
| if (memo->find(kg) != memo->end()) { | |||
| MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); | |||
| return NOT_NULL(kg->get_start_label()); | |||
| } | |||
| memo->insert(kg.get()); | |||
| // 2. args replace placeholder | |||
| LinkParentGraph(kg, last_node, last_label, args); | |||
| LinkParentGraph(kg, last_node, last_label, memo); | |||
| // 3. topological sort | |||
| std::vector<CNodePtr> nodes = GetCNodes(TopoSort(kg->get_return())); | |||
| kg->SetExecOrderByDefault(); | |||
| std::vector<CNodePtr> nodes = kg->execution_order(); | |||
| if (nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; | |||
| } | |||
| // 4. insert first_label | |||
| auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| for (auto node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||
| InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); | |||
| break; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); | |||
| kg->set_start_label(start_label); | |||
| // 5. traverse | |||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||
| @@ -79,17 +108,19 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr | |||
| } | |||
| AnfNodePtr arg = cnode->input(kCNodeCallArg); | |||
| if (IsValueNode<KernelGraph>(arg)) { | |||
| RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); | |||
| RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } else if (!arg->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| cnode->set_inputs(cnode->inputs()); | |||
| RecurseSwitch(kg, NOT_NULL(cnode), memo); | |||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||
| cnode->set_inputs(arg_cnode->inputs()); | |||
| RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| cnode->set_inputs(cnode->inputs()); | |||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); | |||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||
| cnode->set_inputs(arg_cnode->inputs()); | |||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } | |||
| } | |||
| @@ -97,16 +128,6 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr | |||
| return NOT_NULL(start_label); | |||
| } | |||
| std::vector<CNodePtr> AscendControlParser::GetCNodes(const std::vector<AnfNodePtr> &in) { | |||
| std::vector<CNodePtr> out; | |||
| for (auto &node : in) { | |||
| if (node->isa<CNode>()) { | |||
| out.push_back(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| return out; | |||
| } | |||
| void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) { | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; | |||
| auto return_node = kg->get_return(); | |||
| @@ -128,11 +149,7 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, | |||
| } | |||
| void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label, const VectorRef &args) { | |||
| if (from_graph_call_node != nullptr) { | |||
| SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); | |||
| } | |||
| const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| auto origin_return = kg->get_return(); | |||
| std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs(); | |||
| // if entry graph, replace return with make_tuple | |||
| @@ -146,7 +163,8 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod | |||
| // else replace return with label_goto | |||
| auto label_goto = | |||
| kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label}); | |||
| InsertDependToGraph(kg, NOT_NULL(label_goto)); | |||
| MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); | |||
| kg->set_end_goto(label_goto); | |||
| } | |||
| } | |||
| @@ -157,13 +175,14 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| // 1 get kernel graph | |||
| auto origin_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; | |||
| auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); | |||
| if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { | |||
| MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; | |||
| return; | |||
| } | |||
| // 2 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node " | |||
| << cur_node->DebugString(); | |||
| // 3 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| if (next_node != nullptr && next_node != kg->get_return()) { | |||
| @@ -173,7 +192,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| // 4 modify call op to goto op | |||
| cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); | |||
| // 5 recurse sub graph | |||
| CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); | |||
| CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); | |||
| new_inputs.push_back(sub_label); | |||
| new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); | |||
| cur_node->set_inputs(new_inputs); | |||
| @@ -182,32 +201,37 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| } | |||
| void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLength) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; | |||
| } | |||
| // 1 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(prim::kPrimLabelSet)}); | |||
| // 2 recurse sub graph | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node " | |||
| << cur_node->DebugString(); | |||
| // 2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| if (next_node != nullptr && next_node != kg->get_return()) { | |||
| InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); | |||
| } | |||
| // 3 recurse sub graph | |||
| auto origin_switch_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | |||
| // 2.1 branch kernel graph and args | |||
| // 3.1 branch kernel graph and args | |||
| CNodePtr partial; | |||
| KernelGraphPtr branch_fg; | |||
| VectorRef call_args; | |||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 2.2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| // 2.3 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||
| std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| } | |||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | |||
| new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| @@ -215,7 +239,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||
| } | |||
| void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| const CNodePtr &next_node, NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); | |||
| if (cur_node->size() < kCNodeSwitchLayerLength) { | |||
| @@ -229,21 +253,24 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||
| } | |||
| auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); | |||
| // 1 return label | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName))}); | |||
| // 2 recurse sub graph | |||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | |||
| // 2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| if (next_node != nullptr && next_node != kg->get_return()) { | |||
| InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); | |||
| } | |||
| // 3 recurse sub graph | |||
| auto origin_switch_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> new_switch_inputs = {std::make_shared<ValueNode>(prim::kPrimLabelSwitch), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| std::vector<AnfNodePtr> new_switch_inputs = { | |||
| std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | |||
| origin_switch_inputs[kCNodeSwitchCond]}; | |||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | |||
| // 2.1 branch kernel graph and args | |||
| // 3.1 branch kernel graph and args | |||
| CNodePtr partial; | |||
| KernelGraphPtr branch_fg; | |||
| VectorRef call_args; | |||
| std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 2.2 add depend relationship | |||
| InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); | |||
| // 2.3 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); | |||
| std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| } | |||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||
| @@ -252,7 +279,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||
| MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); | |||
| } | |||
| std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { | |||
| if (!node.get()->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); | |||
| } | |||
| @@ -263,9 +290,8 @@ std::tuple<CNodePtr, KernelGraphPtr, VectorRef> AscendControlParser::ParsePartia | |||
| } | |||
| auto partial_inputs = partial_cnode->inputs(); | |||
| auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); | |||
| auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); | |||
| return {partial_cnode, branch_kg, call_args}; | |||
| return {partial_cnode, branch_kg}; | |||
| } | |||
| void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, | |||
| @@ -289,31 +315,199 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||
| } | |||
| size_t AscendControlParser::SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, | |||
| size_t input_index) { | |||
| auto output_num = AnfAlgo::GetOutputTensorNum(node); | |||
| if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| return input_index + output_num; | |||
| // Looks up, among the real inputs that to_graph records for `param`, the argument whose owning | |||
| // func graph is from_graph, and returns it. Raises an exception when none of the recorded | |||
| // arguments belongs to from_graph. | |||
| NotNull<AnfNodePtr> AscendControlParser::GetRealInput(NotNull<KernelGraphPtr> from_graph, | |||
|                                                        NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> param) { | |||
|   const std::set<AnfNodePtr> candidates = to_graph->GetRealInput(param); | |||
|   const auto match = std::find_if(candidates.begin(), candidates.end(), [&from_graph](const AnfNodePtr &candidate) { | |||
|     return candidate->func_graph() == from_graph.get(); | |||
|   }); | |||
|   if (match != candidates.end()) { | |||
|     AnfNodePtr real_input = *match; | |||
|     return NOT_NULL(real_input); | |||
|   } | |||
|   MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from " | |||
|                     << from_graph->ToString(); | |||
| } | |||
| auto &graph_inputs = kg->inputs(); | |||
| if (input_index >= graph_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); | |||
| void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, | |||
| NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) { | |||
| if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { | |||
| MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple"; | |||
| CNodePtr cnode_arg = arg.get()->cast<CNodePtr>(); | |||
| CNodePtr cnode_param = param.get()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode_arg); | |||
| MS_EXCEPTION_IF_NULL(cnode_param); | |||
| if (cnode_arg->size() != cnode_param->size()) { | |||
| MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param " | |||
| << param->DebugString() << " size " << cnode_param->size(); | |||
| } | |||
| for (size_t i = 1; i < cnode_param->size(); ++i) { | |||
| LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i))); | |||
| } | |||
| } else if (arg->isa<CNode>()) { | |||
| InsertAssignToGraph(target_graph, arg, param); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type."; | |||
| } | |||
| auto backend_parameter = graph_inputs[input_index]; | |||
| if (node.get()->isa<Parameter>()) { | |||
| MS_EXCEPTION_IF_NULL(backend_parameter); | |||
| MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() | |||
| << "] will be replaced."; | |||
| kg->ReplaceNode(backend_parameter, node); | |||
| return input_index; | |||
| } | |||
| // Runs RecurseGraph over root_graph with a fresh visited-graph set and no label-goto nodes. | |||
| // The execution order it returns is deliberately discarded. | |||
| void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
|   std::set<KernelGraphPtr> visited_graphs; | |||
|   const CNodePtr no_label_goto = nullptr; | |||
|   (void)RecurseGraph(no_label_goto, no_label_goto, root_graph, NOT_NULL(&visited_graphs)); | |||
| } | |||
| std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, | |||
| NotNull<KernelGraphPtr> graph, | |||
| NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; | |||
| auto print_vector = [&](std::vector<CNodePtr> vec) -> void { | |||
| MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order"; | |||
| for (size_t i = 0; i < vec.size(); i++) { | |||
| MS_LOG(INFO) << "[" << i << "][" << vec[i]->DebugString() << "]"; | |||
| } | |||
| }; | |||
| if (memo->find(graph) != memo->end()) { | |||
| return {}; | |||
| } | |||
| memo->insert(graph.get()); | |||
| graph->SetExecOrderByDefault(); | |||
| std::vector<CNodePtr> cnodes = graph->execution_order(); | |||
| std::map<uint32_t, CNodePtr> label_map; | |||
| std::map<CNodePtr, std::vector<uint32_t>> label_switch_map; | |||
| std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); | |||
| std::vector<CNodePtr> execution_order; | |||
| for (auto &node : cnodes) { | |||
| execution_order.push_back(node); | |||
| if (node == graph->get_end_goto()) { | |||
| continue; | |||
| } | |||
| auto label_iter = | |||
| std::find_if(label_map.begin(), label_map.end(), | |||
| [node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; }); | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | |||
| if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { | |||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||
| } | |||
| auto child_graph = graph->child_graph_order()[label_iter->first]; | |||
| if (child_graph == graph->parent_graph()) { | |||
| continue; | |||
| } | |||
| std::map<uint32_t, CNodePtr> child_label_map; | |||
| std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); | |||
| auto child_execution_order = | |||
| RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); | |||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | |||
| std::vector<uint32_t> label_list = label_switch_map.find(node)->second; | |||
| std::reverse(label_list.begin(), label_list.end()); | |||
| for (size_t i = 0; i < label_list.size(); ++i) { | |||
| if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { | |||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||
| } | |||
| auto child_graph = graph->child_graph_order()[label_iter->first + i]; | |||
| if (child_graph == graph->parent_graph()) { | |||
| continue; | |||
| } | |||
| std::map<uint32_t, CNodePtr> child_label_map; | |||
| std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); | |||
| auto child_execution_order = | |||
| RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); | |||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||
| } | |||
| } | |||
| } | |||
| InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); | |||
| return input_index + 1; | |||
| graph->set_execution_order(execution_order); | |||
| print_vector(graph->execution_order()); | |||
| return execution_order; | |||
| } | |||
| void AscendControlParser::SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||
| const VectorRef &args) {} | |||
| // Checks that the label targeted by `cur_label` equals the start label of the child graph stored | |||
| // at position `order_index` of graph->child_graph_order(). | |||
| //   order_index: position in the child graph order; an out-of-range value raises an exception. | |||
| //   label_index: expected label index; when cur_label is a LabelGoto it is replaced by the | |||
| //                label_index attribute read from that node. | |||
| //   cur_label:   the label node being validated. | |||
| //   graph:       graph whose child graph order is consulted. | |||
| // Returns true when the two indices agree; otherwise logs a warning and returns false. | |||
| // Raises an exception when a required label_index attribute or primitive is missing. | |||
| bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, | |||
|                                           NotNull<KernelGraphPtr> graph) { | |||
|   // Query the child order once; a const reference is valid whether the accessor returns by value | |||
|   // or by reference. | |||
|   const auto &child_graphs = graph->child_graph_order(); | |||
|   // check index and child order size | |||
|   if (child_graphs.size() <= static_cast<size_t>(order_index)) { | |||
|     MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " | |||
|                       << child_graphs.size() << " goto index " << order_index; | |||
|   } | |||
|   if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { | |||
|     // check label_goto and start_label in child graph | |||
|     if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) { | |||
|       // The node inspected here is a LabelGoto, so the message names it as such. | |||
|       MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; | |||
|     } | |||
|     auto primitive = AnfAlgo::GetCNodePrimitive(cur_label); | |||
|     MS_EXCEPTION_IF_NULL(primitive); | |||
|     // For a goto, the index to compare is the one stored on the node itself. | |||
|     label_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex)); | |||
|   } | |||
|   // get start_label_set_index of child graph | |||
|   auto child_graph = child_graphs[order_index]; | |||
|   MS_EXCEPTION_IF_NULL(child_graph); | |||
|   auto start_label_set = child_graph->get_start_label(); | |||
|   if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { | |||
|     MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; | |||
|   } | |||
|   auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set); | |||
|   MS_EXCEPTION_IF_NULL(start_primitive); | |||
|   uint32_t start_label_set_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex)); | |||
|   if (label_index != start_label_set_index) { | |||
|     MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() | |||
|                     << " index " << start_label_set_index << " current child graph order : " << order_index; | |||
|     return false; | |||
|   } | |||
|   return true; | |||
| } | |||
| // Builds two lookup tables from `nodes`, visited in order: | |||
| //   - a map from a running slot number to the LabelGoto / LabelSwitch node occupying that slot | |||
| //     (a LabelGoto takes one slot, a LabelSwitch takes one slot per entry of its switch list); | |||
| //   - a map from each LabelSwitch node to the label list stored in its label_switch_list attribute. | |||
| // Raises an exception when a LabelSwitch node lacks that attribute or has no primitive. | |||
| std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode( | |||
|   const std::vector<CNodePtr> &nodes) { | |||
|   std::map<uint32_t, CNodePtr> slot_to_node; | |||
|   std::map<CNodePtr, std::vector<uint32_t>> switch_targets; | |||
|   // record child graph | |||
|   uint32_t next_slot = 0; | |||
|   for (const auto &cnode : nodes) { | |||
|     if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimLabelGoto)) { | |||
|       slot_to_node[next_slot] = cnode; | |||
|       next_slot += 1; | |||
|       continue; | |||
|     } | |||
|     if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimLabelSwitch)) { | |||
|       continue; | |||
|     } | |||
|     if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { | |||
|       MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; | |||
|     } | |||
|     auto switch_primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
|     MS_EXCEPTION_IF_NULL(switch_primitive); | |||
|     const std::vector<uint32_t> targets = | |||
|       GetValue<std::vector<uint32_t>>(switch_primitive->GetAttr(kAttrLabelSwitchList)); | |||
|     switch_targets.insert({cnode, targets}); | |||
|     // One slot per switch branch, all pointing back at the same switch node. | |||
|     for (size_t remaining = targets.size(); remaining > 0; --remaining) { | |||
|       slot_to_node[next_slot] = cnode; | |||
|       next_slot += 1; | |||
|     } | |||
|   } | |||
|   return {slot_to_node, switch_targets}; | |||
| } | |||
| // Resets kg's execution order to the default, collects the kernel graphs referenced by every | |||
| // Call node found in kg (in the order FindNodeByPrimitive returns them), sets kg as the parent | |||
| // of each collected graph that is not kg's own parent, logs the list, and stores it on kg as its | |||
| // child graph order. | |||
| void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) { | |||
|   MS_LOG(INFO) << "graph id:" << kg->graph_id(); | |||
|   kg->SetExecOrderByDefault(); | |||
|   std::vector<KernelGraphPtr> ordered_children; | |||
|   auto call_primitive = std::make_shared<Primitive>(prim::kPrimCall->name()); | |||
|   auto found_calls = kg->FindNodeByPrimitive(call_primitive); | |||
|   for (auto &found_call : found_calls) { | |||
|     MS_EXCEPTION_IF_NULL(found_call); | |||
|     auto callee_graphs = AnfAlgo::GetCallNodeKernelGraph(found_call->cast<CNodePtr>()); | |||
|     for (const auto &callee : callee_graphs) { | |||
|       MS_EXCEPTION_IF_NULL(callee); | |||
|       const bool callee_is_parent = (callee == kg->parent_graph()); | |||
|       if (!callee_is_parent) { | |||
|         callee->set_parent_graph(kg.get()); | |||
|       } | |||
|       ordered_children.push_back(callee); | |||
|     } | |||
|   } | |||
|   size_t position = 0; | |||
|   for (const auto &child : ordered_children) { | |||
|     MS_LOG(INFO) << "child graph[" << position << "][id:" << child->graph_id() << "]"; | |||
|     ++position; | |||
|   } | |||
|   kg->set_child_graph_order(ordered_children); | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,7 @@ | |||
| #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H | |||
| #include <set> | |||
| #include <map> | |||
| #include <vector> | |||
| #include <tuple> | |||
| #include "session/kernel_graph.h" | |||
| @@ -28,31 +29,44 @@ namespace session { | |||
| class AscendControlParser { | |||
| public: | |||
| static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map); | |||
| static void LinkGraph(NotNull<KernelGraphPtr> kg); | |||
| static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node); | |||
| static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, | |||
| NotNull<AnfNodePtr> second_node); | |||
| static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph); | |||
| static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg); | |||
| private: | |||
| static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, | |||
| const CNodePtr &last_label, const VectorRef &args, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, | |||
| static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in); | |||
| static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, | |||
| const CNodePtr &last_label, const VectorRef &args); | |||
| static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, | |||
| const VectorRef &args); | |||
| static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node); | |||
| const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, | |||
| NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param); | |||
| static NotNull<AnfNodePtr> GetRealInput(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, | |||
| NotNull<AnfNodePtr> param); | |||
| static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index); | |||
| static CNodePtr GetNextRealKernel(std::vector<CNodePtr> list, size_t start); | |||
| // root graph order | |||
| static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode( | |||
| const std::vector<CNodePtr> &nodes); | |||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||
| NotNull<KernelGraphPtr> graph); | |||
| static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, | |||
| NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static constexpr size_t kCNodePrim = 0; | |||
| static constexpr size_t kCNodeCallArg = 1; | |||
| @@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||
| for (size_t i = 0; i < cnodes.size(); i++) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { | |||
| auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); | |||
| // if graph is the true branch of while,no need split graph | |||
| if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) { | |||
| continue; | |||
| } | |||
| auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i); | |||
| auto call_list = std::vector<CNodePtr>(1, cnodes[i]); | |||
| after_call_index = i + 1; | |||
| @@ -195,10 +191,10 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co | |||
| // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | |||
| // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | |||
| void UpdateRealInput(KernelGraph *graph) { | |||
| static void UpdateRealInput(KernelGraph *graph) { | |||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | |||
| auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters, | |||
| const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | |||
| auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters, | |||
| const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id(); | |||
| if (args.empty()) { | |||
| @@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) { | |||
| MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() | |||
| << " and args size:" << args.size() << " not equal!"; | |||
| } | |||
| child_graph->SetExecOrderByDefault(); | |||
| for (size_t i = 0; i < parameters.size(); i++) { | |||
| MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString(); | |||
| if (args[i] == parameters[i]) { | |||
| child_graph->SetRealInput(parameters[i], args[i]); | |||
| MS_LOG(INFO) << "Parameter and arg are same"; | |||
| continue; | |||
| } | |||
| // if arg is a parameter ,then reuse this parameter | |||
| if (args[i]->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id() | |||
| << " reuse parameter:" << args[i]->DebugString() | |||
| << " of graph:" << AnfAlgo::GetGraphId(args[i].get()); | |||
| child_graph->ReplaceNode(parameters[i], args[i]); | |||
| continue; | |||
| } | |||
| child_graph->SetRealInput(parameters[i], args[i]); | |||
| } | |||
| }; | |||
| @@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) { | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); | |||
| if (child_graphs.size() == 1) { | |||
| MS_EXCEPTION_IF_NULL(child_graphs[0]); | |||
| bind_call_partial_with_parameter( | |||
| child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()), | |||
| child_graphs[0].get()); | |||
| std::vector<AnfNodePtr> real_args = | |||
| std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); | |||
| std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); | |||
| bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get()); | |||
| call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); | |||
| } else if (child_graphs.size() == 2) { | |||
| auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { | |||
| @@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) { | |||
| std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); | |||
| return ret; | |||
| }; | |||
| bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | |||
| bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | |||
| bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | |||
| bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | |||
| } | |||
| } | |||
| } | |||
| @@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) { | |||
| MS_LOG(INFO) << "start graph id:" << graph->graph_id(); | |||
| graph->UpdateCallRealInput(); | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| if (child_graph == graph->parent_graph()) { | |||
| MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() | |||
| << ",parent graph:" << graph->parent_graph()->graph_id(); | |||
| continue; | |||
| } | |||
| RecurseToUpdateCallRealInput(child_graph.get()); | |||
| } | |||
| } | |||
| @@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL | |||
| GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto graph = ConstructKernelGraph(func_graph); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // split switch | |||
| SplitGraphs(graph); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // insert goto labels and label_sets | |||
| LinkChildGraphs(NOT_NULL(graph)); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // resource initialize | |||
| InitRuntimeResource(); | |||
| // assign label | |||
| AssignLabel(NOT_NULL(graph)); | |||
| if (!graph->executable()) { | |||
| return graph->graph_id(); | |||
| } | |||
| for (auto iter : graphs_) { | |||
| if (iter.second == graph) { | |||
| MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); | |||
| final_graph_id_ = graph->graph_id(); | |||
| } | |||
| MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString(); | |||
| CompileChildGraph(iter.second); | |||
| } | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // recurse compile child graph | |||
| RecurseCompileGraph(graph); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // root graph validate, including generating the execute order and so on | |||
| RootGraphExecutorValidate(NOT_NULL(graph)); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // adjust kernel | |||
| AdjustKernel(graph); | |||
| // root graph validate, including generating the execute order and so on | |||
| RootGraphExecutorValidate(graph.get()); | |||
| MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); | |||
| // assign stream | |||
| AssignStream(graph); | |||
| // build kernel | |||
| BuildKernel(graph); | |||
| // alloc mem | |||
| MemoryAlloc(graph.get()); | |||
| // task generate | |||
| @@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { | |||
| void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); | |||
| opt::AscendBackendIRFusionOptimization(child_graph); | |||
| // select kernel build info | |||
| SelectKernel(*child_graph); | |||
| @@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| runtime_instance->AssignStaticMemoryInput(child_graph.get()); | |||
| runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); | |||
| } | |||
| void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *const outputs) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| DumpIR("./run_graph.ir", kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // if none of child graph and no anf output exists | |||
| if (!kernel_graph->executable()) { | |||
| @@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() { | |||
| } | |||
| } | |||
| KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list) { | |||
| std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list) { | |||
| MS_EXCEPTION_IF_NULL(new_kernel_graph); | |||
| MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id(); | |||
| MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id(); | |||
| // count the output of every anf node | |||
| std::set<AnfNodePtr> has_output_nodes; | |||
| for (auto &anf_node : list) { | |||
| @@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); | |||
| std::vector<AnfNodePtr> call_node_inputs; | |||
| auto graph_inputs = new_kernel_graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| // create new parameter from cnode | |||
| for (auto &anf_node : list) { | |||
| auto cnode = anf_node->cast<CNodePtr>(); | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto input = cnode->inputs()[input_idx]; | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (!input->isa<CNode>()) { | |||
| if (input->isa<Parameter>()) { | |||
| graph_inputs->push_back(input); | |||
| cnode->set_input(input_idx, input); | |||
| continue; | |||
| } | |||
| if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | |||
| } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | |||
| auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | |||
| cnode->set_input(input_idx, new_parameter); | |||
| new_kernel_graph->SetRealInput(new_parameter, input); | |||
| } | |||
| call_node_inputs.push_back(input); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); | |||
| @@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke | |||
| new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); | |||
| } | |||
| MS_LOG(INFO) << "end"; | |||
| return new_kernel_graph; | |||
| return call_node_inputs; | |||
| } | |||
| void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { | |||
| @@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | |||
| // update the root graph child graph order | |||
| graph->UpdateChildGraphOrder(); | |||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||
| // get child list from current graph | |||
| std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); | |||
| auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { | |||
| @@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| for (auto &child_graph_node : child_graph_list) { | |||
| AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); | |||
| } | |||
| ConstructSplitedGraph(child_graph, child_graph_list); | |||
| auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); | |||
| std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); | |||
| auto new_call = graph->NewCNode(new_call_input); | |||
| AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); | |||
| return new_call; | |||
| @@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { | |||
| std::list<AnfNodePtr> depend_input = {}; | |||
| for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { | |||
| auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]); | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| // if call node is the last call of true graph,no need create child graph after that | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||
| depend_input.push_front(call_node); | |||
| if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { | |||
| break; | |||
| } | |||
| } | |||
| depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())))); | |||
| auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end())); | |||
| auto new_return_primitive = | |||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))); | |||
| graph->set_return(graph->NewCNode({new_return_primitive, depend})); | |||
| AnfNodePtr pre_call_node = nullptr; | |||
| AnfNodePtr cur_call_node = nullptr; | |||
| auto iter = depend_input.begin(); | |||
| for (++iter; iter != depend_input.end(); ++iter) { | |||
| pre_call_node = cur_call_node; | |||
| cur_call_node = *iter; | |||
| if (pre_call_node != nullptr && cur_call_node != nullptr) { | |||
| AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node), | |||
| NOT_NULL(pre_call_node)); | |||
| } | |||
| } | |||
| } | |||
| graph->UpdateChildGraphOrder(); | |||
| AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); | |||
| UpdateRealInput(graph.get()); | |||
| auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); | |||
| DumpIR(graph_name, graph); | |||
| MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; | |||
| // recurse to split child graph | |||
| for (auto &child_graph : graph->child_graph_order()) { | |||
| SplitGraph(child_graph); | |||
| if (child_graph != graph->parent_graph()) { | |||
| SplitGraph(child_graph); | |||
| } | |||
| } | |||
| } | |||
| // Links `graph` with its child graphs by delegating to AscendControlParser::LinkGraph. | |||
| void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { | |||
| AscendControlParser::LinkGraph(graph); | |||
| } | |||
| // Validates `graph` by forwarding it unchanged to AscendControlParser::ExecutorValidate. | |||
| void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) { | |||
| return AscendControlParser::ExecutorValidate(graph); | |||
| } | |||
| // Compiles `graph` with CompileChildGraph, then recurses into every entry of its child graph order. | |||
| // A child that is the graph's own parent graph is skipped, so the walk never descends back into the parent. | |||
| // CompileChildGraph rejects a null graph, so no separate null check is done here. | |||
| void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) { | |||
| CompileChildGraph(graph); | |||
| // bind by const reference: copying the shared_ptr for each child only adds refcount traffic | |||
| for (const auto &child_graph : graph->child_graph_order()) { | |||
| if (child_graph == graph->parent_graph()) { | |||
| continue; | |||
| } | |||
| RecurseCompileGraph(child_graph); | |||
| } | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -104,10 +104,10 @@ class AscendSession : public SessionBasic { | |||
| void SelectKernelGraphKernel(const KernelGraph &graph) {} | |||
| void ConvertPredictModel(const KernelGraphPtr graph) {} | |||
| void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} | |||
| void RootGraphExecutorValidate(KernelGraph *graph) {} | |||
| void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); | |||
| KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list); | |||
| void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists); | |||
| void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph); | |||
| std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, | |||
| const std::vector<CNodePtr> &list); | |||
| void RecurseCompileGraph(const KernelGraphPtr &graph); | |||
| // merge execution order list of child graphs | |||
| void MergeGraphExecOrder(); | |||
| @@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() { | |||
| } | |||
| } | |||
| CheckLoop(); | |||
| // resort start label / end goto | |||
| std::vector<CNodePtr> re_order; | |||
| if (start_label_ != nullptr) { | |||
| re_order.push_back(start_label_); | |||
| } | |||
| for (auto &node : execution_order_) { | |||
| if (node == start_label_ || node == end_goto_) { | |||
| continue; | |||
| } | |||
| re_order.push_back(node); | |||
| } | |||
| if (end_goto_ != nullptr) { | |||
| re_order.push_back(end_goto_); | |||
| } | |||
| execution_order_ = re_order; | |||
| } | |||
| void KernelGraph::CheckLoop() { | |||
| @@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode | |||
| void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { | |||
| MS_EXCEPTION_IF_NULL(old_backend_anf); | |||
| MS_EXCEPTION_IF_NULL(new_backend_anf); | |||
| if (old_backend_anf.get() == new_backend_anf.get()) { | |||
| if (old_backend_anf == new_backend_anf) { | |||
| MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString(); | |||
| MS_LOG(EXCEPTION) << "old can't be same with new"; | |||
| } | |||
| if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { | |||
| @@ -569,32 +585,52 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||
| MS_EXCEPTION_IF_NULL(new_anf_node); | |||
| MS_EXCEPTION_IF_NULL(inputs_); | |||
| auto it = node_output_edges_.find(old_anf_node); | |||
| if (it == node_output_edges_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map"; | |||
| } | |||
| auto &outputs = it->second; | |||
| for (auto &output_node : outputs) { | |||
| auto output_cnode = output_node.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| auto &output_node_inputs = output_cnode->inputs(); | |||
| for (size_t i = 1; i < output_node_inputs.size(); i++) { | |||
| if (output_node_inputs[i] == old_anf_node) { | |||
| output_cnode->set_input(i, new_anf_node); | |||
| if (it != node_output_edges_.end()) { | |||
| const auto &outputs = it->second; | |||
| for (auto &output_node : outputs) { | |||
| MS_EXCEPTION_IF_NULL(output_node.first); | |||
| auto output_cnode = output_node.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(output_cnode); | |||
| const auto &output_node_inputs = output_cnode->inputs(); | |||
| for (size_t i = 1; i < output_node_inputs.size(); i++) { | |||
| if (output_node_inputs[i] == old_anf_node) { | |||
| output_cnode->set_input(i, new_anf_node); | |||
| } | |||
| } | |||
| } | |||
| // update graph inputs | |||
| for (size_t i = 0; i < inputs_->size(); i++) { | |||
| if ((*inputs_)[i] == old_anf_node) { | |||
| (*inputs_)[i] = new_anf_node; | |||
| break; | |||
| // update graph inputs | |||
| for (size_t i = 0; i < inputs_->size(); i++) { | |||
| if ((*inputs_)[i] == old_anf_node) { | |||
| MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() | |||
| << ",new graph input:" << new_anf_node->DebugString(); | |||
| (*inputs_)[i] = new_anf_node; | |||
| break; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Inputs of graph id:" << graph_id(); | |||
| for (size_t i = 0; i < inputs().size(); i++) { | |||
| MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString(); | |||
| } | |||
| } | |||
| // update front to backend map | |||
| FrontBackendlMapUpdate(old_anf_node, new_anf_node); | |||
| // update output depend relations | |||
| node_output_edges_[new_anf_node] = it->second; | |||
| (void)node_output_edges_.erase(old_anf_node); | |||
| } | |||
| // update graph inputs in child graph | |||
| auto it_real_inputs = real_inputs_.find(old_anf_node); | |||
| if (it_real_inputs != real_inputs_.end()) { | |||
| // insert new parameter to map | |||
| auto iter = real_inputs_.find(new_anf_node); | |||
| if (iter != real_inputs_.end()) { | |||
| MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; | |||
| iter->second = it_real_inputs->second; | |||
| } else { | |||
| real_inputs_[new_anf_node] = it_real_inputs->second; | |||
| } | |||
| // erase old parameter in map | |||
| real_inputs_.erase(old_anf_node); | |||
| } | |||
| // update front to backend map | |||
| FrontBackendlMapUpdate(old_anf_node, new_anf_node); | |||
| // update output depend relations | |||
| node_output_edges_[new_anf_node] = it->second; | |||
| (void)node_output_edges_.erase(old_anf_node); | |||
| } | |||
| void KernelGraph::UpdateExecuteKernelStreamLabel() { | |||
| @@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { | |||
| } | |||
| } | |||
| // Rebuilds child_graph_order_ from the call nodes found in this graph. | |||
| // Former children have their parent graph cleared first; every kernel graph referenced by a call node, | |||
| // other than this graph's own parent graph, is then re-parented to this graph and appended to the order. | |||
| void KernelGraph::UpdateChildGraphOrder() { | |||
| MS_LOG(INFO) << "graph id:" << graph_id_; | |||
| auto call_cnodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| // detach the children recorded by the previous pass | |||
| for (auto &former_child : child_graph_order_) { | |||
| former_child->set_parent_graph(nullptr); | |||
| } | |||
| child_graph_order_.clear(); | |||
| for (auto &call_cnode : call_cnodes) { | |||
| MS_EXCEPTION_IF_NULL(call_cnode); | |||
| auto callee_graphs = AnfAlgo::GetCallNodeKernelGraph(call_cnode->cast<CNodePtr>()); | |||
| for (const auto &callee : callee_graphs) { | |||
| MS_EXCEPTION_IF_NULL(callee); | |||
| // a call back into the parent graph is not recorded as a child | |||
| if (callee == parent_graph()) { | |||
| continue; | |||
| } | |||
| callee->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>()); | |||
| child_graph_order_.push_back(callee); | |||
| } | |||
| } | |||
| size_t order_index = 0; | |||
| for (const auto &ordered_child : child_graph_order_) { | |||
| MS_LOG(INFO) << "child graph[" << order_index << "][id:" << ordered_child->graph_id() << "]"; | |||
| ++order_index; | |||
| } | |||
| } | |||
| std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { | |||
| std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order; | |||
| if (IsLeafGraph()) { | |||
| @@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { | |||
| // A graph is a leaf when its child graph order holds no entries. | |||
| bool KernelGraph::IsLeafGraph() const { | |||
| return child_graph_order_.empty(); | |||
| } | |||
| std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { | |||
| auto anf_list = TopoSort(get_return()); | |||
| std::vector<CNodePtr> result; | |||
| for (const auto &anf : anf_list) { | |||
| for (const auto &anf : execution_order_) { | |||
| if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { | |||
| result.push_back(anf->cast<CNodePtr>()); | |||
| } | |||
| @@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi | |||
| return result; | |||
| } | |||
| // Returns a copy of the real-input set recorded for `parameter` in real_inputs_, | |||
| // or an empty set when nothing is recorded for it. A null `parameter` is rejected. | |||
| std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| // single lookup: reuse the iterator instead of find() followed by operator[] | |||
| auto iter = real_inputs_.find(parameter); | |||
| if (iter == real_inputs_.end()) { | |||
| return {}; | |||
| } | |||
| return iter->second; | |||
| } | |||
| void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| @@ -674,37 +678,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar | |||
| (void)args.insert(arg); | |||
| } | |||
| // Returns a copy of the real-input set recorded for `parameter` in real_inputs_. | |||
| // A null `parameter` is rejected; a parameter with no recorded entry is reported through MS_LOG(EXCEPTION). | |||
| std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto iter = real_inputs_.find(parameter); | |||
| // handle the failure first so the function ends in a return statement | |||
| if (iter == real_inputs_.end()) { | |||
| MS_LOG(EXCEPTION) << parameter->DebugString() << " not found."; | |||
| } | |||
| return iter->second; | |||
| } | |||
| void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | |||
| for (auto &it : real_inputs_) { | |||
| auto &parameter = it.first; | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto &real_inputs = it.second; | |||
| std::set<AnfNodePtr> new_real_inputs; | |||
| std::vector<AnfNodePtr> new_real_inputs; | |||
| std::set<AnfNodePtr> erase_real_inputs; | |||
| for (auto &real_input : real_inputs) { | |||
| // if real input is a call node ,find the child graph output act as the new real input | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0); | |||
| MS_EXCEPTION_IF_NULL(item_with_index.first); | |||
| if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " erase real input:" << item_with_index.first->DebugString(); | |||
| (void)erase_real_inputs.insert(item_with_index.first); | |||
| auto call_node_outputs = GetCallRealOutputs(item_with_index.first); | |||
| for (auto &call_node_output : call_node_outputs) { | |||
| MS_EXCEPTION_IF_NULL(call_node_output); | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " insert real input:" << call_node_output->DebugString(); | |||
| (void)new_real_inputs.insert(call_node_output); | |||
| } | |||
| new_real_inputs = GetCallRealOutputs(item_with_index.first); | |||
| continue; | |||
| } | |||
| for (auto &erase_node : erase_real_inputs) { | |||
| (void)real_inputs.erase(erase_node); | |||
| } | |||
| for (auto &new_real_input : new_real_inputs) { | |||
| (void)real_inputs.insert(new_real_input); | |||
| } | |||
| } | |||
| for (auto &erase_node : erase_real_inputs) { | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString(); | |||
| (void)real_inputs.erase(erase_node); | |||
| } | |||
| for (auto &new_real_input : new_real_inputs) { | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " insert real input:" << new_real_input->DebugString(); | |||
| (void)real_inputs.insert(new_real_input); | |||
| } | |||
| } | |||
| } | |||
| @@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph { | |||
| void UpdateExecuteKernelStreamLabel(); | |||
| // calculate the leaf graph order of root graph | |||
| std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder(); | |||
| // update the child graph order of graph | |||
| void UpdateChildGraphOrder(); | |||
| // get the child graph of current graph | |||
| std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; } | |||
| // the child graph of current graph | |||
| const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; } | |||
| void set_child_graph_order(const std::vector<std::shared_ptr<KernelGraph>> &order) { child_graph_order_ = order; } | |||
| // checkout whether current graph is leaf graph | |||
| bool IsLeafGraph() const; | |||
| @@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph { | |||
| // find anf node in graph | |||
| std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | |||
| // get real inputs | |||
| const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs() const { return real_inputs_; } | |||
| std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter); | |||
| void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg); | |||
| // used to dump ir | |||
| @@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph { | |||
| void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } | |||
| CNodePtr get_start_label() { return start_label_; } | |||
| void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } | |||
| CNodePtr get_end_goto() { return end_goto_; } | |||
| private: | |||
| // remove value node form graph | |||
| @@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph { | |||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; | |||
| CNodePtr start_label_; | |||
| CNodePtr end_goto_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||
| @@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, | |||
| MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]"; | |||
| auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); | |||
| MS_EXCEPTION_IF_NULL(item_with_index.first); | |||
| MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString(); | |||
| // special handle for maketuple | |||
| if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { | |||
| auto cnode = item_with_index.first->cast<CNodePtr>(); | |||
| @@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| } | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto anf = cnode->inputs()[input_idx]; | |||
| auto anf = cnode->input(input_idx); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| // anf has been created before | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if (anf->isa<ValueNode>()) { | |||
| if (!IsValueNode<FuncGraph>(anf)) { | |||
| // if input is a common value node, | |||
| auto new_value_node = CreateNewValueNode(anf, graph); | |||
| if (new_value_node != nullptr) { | |||
| cnode_inputs.emplace_back(new_value_node); | |||
| } | |||
| } else { | |||
| // if input is a ValueNode<FuncGraph> | |||
| auto new_value_node = CreateValueNodeKernelGraph(anf, graph); | |||
| if (new_value_node != nullptr) { | |||
| cnode_inputs.emplace_back(new_value_node); | |||
| } | |||
| } | |||
| continue; | |||
| } else if (anf->isa<Parameter>()) { | |||
| auto new_parameter = CreateNewParameter(anf, graph); | |||
| cnode_inputs.push_back(new_parameter); | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||
| } | |||
| @@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| for (const auto &node : node_list) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); | |||
| if (!node->isa<CNode>()) { | |||
| MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode"; | |||
| if (node->isa<Parameter>()) { | |||
| (void)CreateNewParameter(node, graph.get()); | |||
| continue; | |||
| } else if (node->isa<ValueNode>()) { | |||
| if (!IsValueNode<FuncGraph>(node)) { | |||
| // if input is a common value node, | |||
| (void)CreateNewValueNode(node, graph.get()); | |||
| } else { | |||
| // if input is a ValueNode<FuncGraph> | |||
| auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node)); | |||
| auto new_value_node = CreateValueNodeKernelGraph(node, graph.get()); | |||
| } | |||
| continue; | |||
| } else { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| // recurse control ops: call, partial | |||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | |||
| MS_EXCEPTION_IF_NULL(attr_input); | |||
| if (IsValueNode<FuncGraph>(attr_input)) { | |||
| // recurse call subgraph | |||
| auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input); | |||
| ConstructKernelGraph(sub_func_graph); | |||
| } else if (IsValueNode<Primitive>(attr_input)) { | |||
| auto prim = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == kPartialOpName) { | |||
| // recurse partial subgraph | |||
| auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex); | |||
| MS_EXCEPTION_IF_NULL(func_graph_node); | |||
| auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); | |||
| ConstructKernelGraph(sub_func_graph); | |||
| } | |||
| } | |||
| // create a new cnode object | |||
| auto new_cnode = CreateNewCNode(cnode, graph.get()); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| @@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| } | |||
| } | |||
| } | |||
| auto graph_inputs = graph->MutableInputs(); | |||
| MS_EXCEPTION_IF_NULL(graph_inputs); | |||
| graph_inputs->clear(); | |||
| for (auto &parameter : func_graph->parameters()) { | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); | |||
| if (backend_parameter == nullptr) { | |||
| // for example "def f(x,y,z) {return x + y}", parameter z is unused | |||
| CreateNewParameterFromParameter(parameter, false, graph.get()); | |||
| MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString(); | |||
| graph_inputs->push_back(backend_parameter); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(context_); | |||
| FuncGraphManagerPtr manager = context_->manager(); | |||
| if (manager) { | |||
| @@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| const std::vector<tensor::TensorPtr> &input_tensors) const { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!kernel_graph->child_graph_order().empty()) { | |||
| // use the last child graph output as the root graph output | |||
| UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); | |||
| return; | |||
| } | |||
| auto anf_outputs = kernel_graph->outputs(); | |||
| for (auto &item : anf_outputs) { | |||
| MS_LOG(INFO) << "update output[" << item->DebugString() << "]"; | |||
| @@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { | |||
| } | |||
| void TraverseGraphMap( | |||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, | |||
| const FuncGraphSet &fgs, | |||
| const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs, | |||
| const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| MS_EXCEPTION_IF_NULL(tr); | |||