insert assign at the end of graph clear log clear log 2 handle replace call of unreuse args handle bug of replace nodetags/v0.6.0-beta
| @@ -102,7 +102,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap | |||||
| memo->insert(graph.get()); | memo->insert(graph.get()); | ||||
| MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); | MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); | ||||
| graph->SetExecOrderByDefault(); | |||||
| auto nodes = graph->execution_order(); | auto nodes = graph->execution_order(); | ||||
| auto end_goto = graph->get_end_goto(); | auto end_goto = graph->get_end_goto(); | ||||
| if (end_goto != nullptr) { | if (end_goto != nullptr) { | ||||
| @@ -128,6 +128,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap | |||||
| for (auto &cg : graph->child_graph_order()) { | for (auto &cg : graph->child_graph_order()) { | ||||
| AssignLabelForGotoSwitch(NOT_NULL(cg), memo); | AssignLabelForGotoSwitch(NOT_NULL(cg), memo); | ||||
| } | } | ||||
| graph->SetExecOrderByDefault(); | |||||
| } | } | ||||
| void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) { | ||||
| @@ -199,7 +199,6 @@ class AnfRuntimeAlgorithm { | |||||
| static bool IsScalarInput(const CNodePtr &cnode, size_t index); | static bool IsScalarInput(const CNodePtr &cnode, size_t index); | ||||
| static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | ||||
| static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list); | static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list); | ||||
| static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph); | |||||
| // get fix output precision of cnode. | // get fix output precision of cnode. | ||||
| static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); | static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); | ||||
| // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | // get fix output precision from prev node, input_idx is the input index of current node related to prev node. | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "utils/union_find_set.h" | #include "utils/union_find_set.h" | ||||
| #include "device/ascend/ascend_label_assign.h" | |||||
| static constexpr size_t kCNodePrim = 0; | static constexpr size_t kCNodePrim = 0; | ||||
| static constexpr size_t kCNodeCallArg = 1; | static constexpr size_t kCNodeCallArg = 1; | ||||
| @@ -35,17 +36,25 @@ namespace mindspore { | |||||
| namespace session { | namespace session { | ||||
| static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) { | static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) { | ||||
| auto &nodes = parent_graph->execution_order(); | auto &nodes = parent_graph->execution_order(); | ||||
| CNodePtr last_jump_node = nullptr; | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { | |||||
| return node; | |||||
| } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && | |||||
| (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || | |||||
| child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { | |||||
| return node; | |||||
| if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) { | |||||
| if (child_graph->get_start_label() == node->input(kCNodeCallArg)) { | |||||
| return node; | |||||
| } | |||||
| last_jump_node = node; | |||||
| } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) { | |||||
| if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || | |||||
| child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) { | |||||
| return node; | |||||
| } | |||||
| last_jump_node = node; | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); | |||||
| return nullptr; | |||||
| if (last_jump_node == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); | |||||
| } | |||||
| return last_jump_node; | |||||
| } | } | ||||
| static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | ||||
| @@ -90,6 +99,9 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union | |||||
| if (!arg->isa<Parameter>()) { | if (!arg->isa<Parameter>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) { | |||||
| continue; | |||||
| } | |||||
| union_find_set->Union(arg, para); | union_find_set->Union(arg, para); | ||||
| } | } | ||||
| } | } | ||||
| @@ -133,24 +145,28 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> | |||||
| } | } | ||||
| } | } | ||||
| static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key, | |||||
| const std::set<AnfNodePtr> ¶meter_reuse_set) { | |||||
| AnfNodePtr main_parameter = key; | |||||
| std::set<AnfNodePtr> root_inputs_set; | |||||
| const auto &root_inputs_vector = root_kg->inputs(); | |||||
| root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); | |||||
| for (auto &node : parameter_reuse_set) { | |||||
| if (root_inputs_set.find(node) != root_inputs_set.end()) { | |||||
| main_parameter = node; | |||||
| break; | |||||
| } | |||||
| } | |||||
| return main_parameter; | |||||
| } | |||||
| static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) { | static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) { | ||||
| auto parameter_reuse_sets = parameter_set->GetSets(); | auto parameter_reuse_sets = parameter_set->GetSets(); | ||||
| for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { | for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { | ||||
| if (parameter_reuse_set.size() <= 1) { | if (parameter_reuse_set.size() <= 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| AnfNodePtr main_parameter = key; | |||||
| std::set<AnfNodePtr> root_inputs_set; | |||||
| const auto &root_inputs_vector = root_kg->inputs(); | |||||
| root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); | |||||
| for (auto &node : parameter_reuse_set) { | |||||
| if (root_inputs_set.find(node) != root_inputs_set.end()) { | |||||
| main_parameter = node; | |||||
| break; | |||||
| } | |||||
| } | |||||
| auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set); | |||||
| std::set<KernelGraphPtr> memo; | std::set<KernelGraphPtr> memo; | ||||
| RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); | RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); | ||||
| } | } | ||||
| @@ -168,6 +184,7 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { | |||||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | ||||
| std::set<KernelGraphPtr> memo; | std::set<KernelGraphPtr> memo; | ||||
| (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); | (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); | ||||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); | |||||
| std::map<uint32_t, KernelGraphPtr> graph_id_map; | std::map<uint32_t, KernelGraphPtr> graph_id_map; | ||||
| for (auto &g : memo) { | for (auto &g : memo) { | ||||
| MS_EXCEPTION_IF_NULL(g); | MS_EXCEPTION_IF_NULL(g); | ||||
| @@ -177,12 +194,13 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||||
| } | } | ||||
| graph_id_map[g->graph_id()] = g; | graph_id_map[g->graph_id()] = g; | ||||
| } | } | ||||
| // Insert Assign | |||||
| ChildGraphDataAssign(graph_id_map); | |||||
| // Make UnionFindSet | // Make UnionFindSet | ||||
| UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg); | UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg); | ||||
| // Reuse Parameter | // Reuse Parameter | ||||
| ReuseParameter(kg, NOT_NULL(¶meter_set)); | ReuseParameter(kg, NOT_NULL(¶meter_set)); | ||||
| // Insert Assign | |||||
| ChildGraphDataAssign(graph_id_map); | |||||
| } | } | ||||
| void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | ||||
| @@ -193,6 +211,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||||
| void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { | void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { | ||||
| for (auto &iter : graph_id_map) { | for (auto &iter : graph_id_map) { | ||||
| auto &kg = iter.second; | auto &kg = iter.second; | ||||
| MS_LOG(INFO) << "Data assign graph:" << kg->graph_id(); | |||||
| MS_EXCEPTION_IF_NULL(kg); | MS_EXCEPTION_IF_NULL(kg); | ||||
| std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo; | std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo; | ||||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs(); | const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs(); | ||||
| @@ -206,8 +225,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr | |||||
| } else { | } else { | ||||
| memo.emplace(parameter, arg); | memo.emplace(parameter, arg); | ||||
| } | } | ||||
| if (arg->isa<Parameter>()) { | |||||
| auto unreuse_args_map = kg->unreuse_args(); | |||||
| auto unreuse_arg_iter = unreuse_args_map.find(arg); | |||||
| if (unreuse_arg_iter == unreuse_args_map.end()) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| MS_EXCEPTION_IF_NULL(parameter); | MS_EXCEPTION_IF_NULL(parameter); | ||||
| if (!arg->isa<Parameter>()) { | |||||
| MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << "."; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() | MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() | ||||
| << ", arg:" << arg->DebugString(); | << ", arg:" << arg->DebugString(); | ||||
| continue; | continue; | ||||
| @@ -220,6 +245,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr | |||||
| NOT_NULL(parameter)); | NOT_NULL(parameter)); | ||||
| } | } | ||||
| } | } | ||||
| kg->SetExecOrderByDefault(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -353,7 +379,6 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||||
| // 5 recurse sub graph | // 5 recurse sub graph | ||||
| CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); | CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); | ||||
| new_inputs.push_back(sub_label); | new_inputs.push_back(sub_label); | ||||
| new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); | |||||
| cur_node->set_inputs(new_inputs); | cur_node->set_inputs(new_inputs); | ||||
| cur_node->set_abstract(nullptr); | cur_node->set_abstract(nullptr); | ||||
| MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | ||||
| @@ -394,7 +419,6 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||||
| } | } | ||||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | ||||
| new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); | |||||
| cur_node->set_inputs(new_switch_inputs); | cur_node->set_inputs(new_switch_inputs); | ||||
| cur_node->set_abstract(nullptr); | cur_node->set_abstract(nullptr); | ||||
| MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); | MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); | ||||
| @@ -477,6 +501,16 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr | |||||
| auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); | auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); | ||||
| if (assign_node != nullptr) { | if (assign_node != nullptr) { | ||||
| auto jump_node = GetJumpNode(from_graph, to_graph); | auto jump_node = GetJumpNode(from_graph, to_graph); | ||||
| const auto &from_graph_exe_order = from_graph->execution_order(); | |||||
| auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); | |||||
| if (jump_node_iter == from_graph_exe_order.end()) { | |||||
| MS_EXCEPTION_IF_NULL(jump_node); | |||||
| MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); | |||||
| } | |||||
| // insert assign between jump_node -1 and jump_node | |||||
| if (jump_node_iter != from_graph_exe_order.begin()) { | |||||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||||
| } | |||||
| if (jump_node != nullptr) { | if (jump_node != nullptr) { | ||||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | ||||
| } | } | ||||
| @@ -501,8 +535,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, | |||||
| auto assign_node = kg->NewCNode(inputs); | auto assign_node = kg->NewCNode(inputs); | ||||
| MS_EXCEPTION_IF_NULL(assign_node); | MS_EXCEPTION_IF_NULL(assign_node); | ||||
| assign_node->set_abstract(to->abstract()); | assign_node->set_abstract(to->abstract()); | ||||
| // append the assign at the end of from graph | |||||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||||
| return assign_node; | return assign_node; | ||||
| } | } | ||||
| @@ -527,7 +559,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| std::vector<CNodePtr> execution_order; | std::vector<CNodePtr> execution_order; | ||||
| uint32_t child_order_index = 0; | uint32_t child_order_index = 0; | ||||
| for (auto &node : cnodes) { | for (auto &node : cnodes) { | ||||
| execution_order.push_back(node); | execution_order.push_back(node); | ||||
| if (node == graph->get_end_goto()) { | if (node == graph->get_end_goto()) { | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "utils/base_ref.h" | #include "utils/base_ref.h" | ||||
| #include "utils/contract.h" | #include "utils/contract.h" | ||||
| #include "utils/union_find_set.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -202,7 +202,8 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt | |||||
| } | } | ||||
| static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, const std::vector<AnfNodePtr> &args, | static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, const std::vector<AnfNodePtr> &args, | ||||
| KernelGraph *child_graph) { | |||||
| const KernelGraphPtr &graph, KernelGraphPtr child_graph, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_EXCEPTION_IF_NULL(child_graph); | MS_EXCEPTION_IF_NULL(child_graph); | ||||
| MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); | MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); | ||||
| if (args.empty()) { | if (args.empty()) { | ||||
| @@ -214,18 +215,25 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, | |||||
| } | } | ||||
| child_graph->SetExecOrderByDefault(); | child_graph->SetExecOrderByDefault(); | ||||
| for (size_t i = 0; i < parameters.size(); i++) { | for (size_t i = 0; i < parameters.size(); i++) { | ||||
| MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]" | |||||
| << args[i]->DebugString(); | |||||
| if (args[i] == parameters[i]) { | if (args[i] == parameters[i]) { | ||||
| child_graph->SetRealInput(parameters[i], args[i]); | |||||
| MS_LOG(INFO) << "Parameter and arg are same."; | MS_LOG(INFO) << "Parameter and arg are same."; | ||||
| continue; | continue; | ||||
| } | } | ||||
| child_graph->SetRealInput(parameters[i], args[i]); | child_graph->SetRealInput(parameters[i], args[i]); | ||||
| if (memo->find(child_graph) != memo->end() || !args[i]->isa<Parameter>()) { | |||||
| MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id(); | |||||
| child_graph->AddUnreuseArgs(args[i], graph); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of | ||||
| // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] | ||||
| static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) { | |||||
| static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_EXCEPTION_IF_NULL(memo.get()); | |||||
| auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); | ||||
| for (auto &call_node : call_nodes) { | for (auto &call_node : call_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(call_node); | MS_EXCEPTION_IF_NULL(call_node); | ||||
| @@ -235,7 +243,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) { | |||||
| std::vector<AnfNodePtr> real_args = | std::vector<AnfNodePtr> real_args = | ||||
| std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); | std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); | ||||
| std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); | std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); | ||||
| BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get()); | |||||
| BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); | |||||
| if (split_flag) { | if (split_flag) { | ||||
| call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); | call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); | ||||
| } | } | ||||
| @@ -256,8 +264,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) { | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| }; | }; | ||||
| BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); | |||||
| BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); | |||||
| BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); | |||||
| BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -306,8 +314,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { | |||||
| LinkChildGraphs(NOT_NULL(root_graph)); | LinkChildGraphs(NOT_NULL(root_graph)); | ||||
| // resource initialize | // resource initialize | ||||
| InitRuntimeResource(); | InitRuntimeResource(); | ||||
| // assign label | |||||
| AssignLabel(NOT_NULL(root_graph)); | |||||
| // recurse compile child root_graph | // recurse compile child root_graph | ||||
| std::set<KernelGraphPtr> memo; | std::set<KernelGraphPtr> memo; | ||||
| RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); | RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); | ||||
| @@ -665,12 +671,6 @@ void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const { | |||||
| MS_LOG(INFO) << "Finish!"; | MS_LOG(INFO) << "Finish!"; | ||||
| } | } | ||||
| void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const { | |||||
| MS_LOG(INFO) << "Start!"; | |||||
| device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph); | |||||
| MS_LOG(INFO) << "Finish!"; | |||||
| } | |||||
| void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { | ||||
| MS_LOG(INFO) << "Start!"; | MS_LOG(INFO) << "Start!"; | ||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| @@ -1591,14 +1591,17 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||||
| auto input = cnode->inputs()[input_idx]; | auto input = cnode->inputs()[input_idx]; | ||||
| MS_EXCEPTION_IF_NULL(input); | MS_EXCEPTION_IF_NULL(input); | ||||
| AnfNodePtr new_parameter = nullptr; | AnfNodePtr new_parameter = nullptr; | ||||
| // check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one | |||||
| // parameter in graph inputs and one arg in call node | |||||
| auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input); | |||||
| if (call_input_it != call_node_inputs.end()) { | |||||
| cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]); | |||||
| continue; | |||||
| } | |||||
| // value node consider move to new graph | // value node consider move to new graph | ||||
| if (input->isa<ValueNode>()) { | if (input->isa<ValueNode>()) { | ||||
| cnode->set_input(input_idx, input); | cnode->set_input(input_idx, input); | ||||
| continue; | continue; | ||||
| } else if (input->isa<Parameter>()) { | |||||
| // parameter reuse and should attention mulptiple use of one parameter | |||||
| cnode->set_input(input_idx, input); | |||||
| new_parameter = input; | |||||
| } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { | ||||
| // if is cnode and not in current child graph | // if is cnode and not in current child graph | ||||
| new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); | ||||
| @@ -1607,12 +1610,8 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt | |||||
| // if is a cnode and in current graph | // if is a cnode and in current graph | ||||
| continue; | continue; | ||||
| } | } | ||||
| // if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node | |||||
| // args | |||||
| if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) { | |||||
| new_graph_inputs.push_back(new_parameter); | |||||
| call_node_inputs.push_back(input); | |||||
| } | |||||
| new_graph_inputs.push_back(new_parameter); | |||||
| call_node_inputs.push_back(input); | |||||
| } | } | ||||
| } | } | ||||
| // set graph inputs of new graph | // set graph inputs of new graph | ||||
| @@ -1640,7 +1639,7 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) { | |||||
| // if root graph output is a call node ,the root graph is condition graph of 'if' sentence | // if root graph output is a call node ,the root graph is condition graph of 'if' sentence | ||||
| auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; | auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; | ||||
| if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { | if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { | ||||
| SplitGraph(root_graph, {prim::kPrimReturn}); | |||||
| SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo)); | |||||
| for (auto &child_graph : root_graph->child_graph_order()) { | for (auto &child_graph : root_graph->child_graph_order()) { | ||||
| RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); | RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); | ||||
| } | } | ||||
| @@ -1681,7 +1680,8 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, | |||||
| return new_call; | return new_call; | ||||
| } | } | ||||
| void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { | |||||
| void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||||
| MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); | MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); | ||||
| bool split_flag = false; | bool split_flag = false; | ||||
| auto apply_list = GetCNodes(TopoSort(graph->get_return())); | auto apply_list = GetCNodes(TopoSort(graph->get_return())); | ||||
| @@ -1719,14 +1719,13 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri | |||||
| split_flag = true; | split_flag = true; | ||||
| } | } | ||||
| AscendControlParser::UpdateChildGraphOrder(graph); | AscendControlParser::UpdateChildGraphOrder(graph); | ||||
| UpdateRealInput(graph, split_flag); | |||||
| UpdateRealInput(graph, split_flag, memo); | |||||
| MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; | MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; | ||||
| // recurse to split child graph | |||||
| } | } | ||||
| void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) { | ||||
| memo->insert(graph.get()); | memo->insert(graph.get()); | ||||
| SplitGraph(graph, {prim::kPrimCall}); | |||||
| SplitGraph(graph, {prim::kPrimCall}, memo); | |||||
| for (auto &child_graph : graph->child_graph_order()) { | for (auto &child_graph : graph->child_graph_order()) { | ||||
| if (memo->find(child_graph) == memo->end()) { | if (memo->find(child_graph) == memo->end()) { | ||||
| RecurseSplitGraph(NOT_NULL(child_graph), memo); | RecurseSplitGraph(NOT_NULL(child_graph), memo); | ||||
| @@ -77,7 +77,6 @@ class AscendSession : public SessionBasic { | |||||
| void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const; | void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const; | ||||
| void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const; | |||||
| void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| void MemoryAlloc(KernelGraph *kernel_graph) const; | void MemoryAlloc(KernelGraph *kernel_graph) const; | ||||
| void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; | ||||
| @@ -100,7 +99,8 @@ class AscendSession : public SessionBasic { | |||||
| void SetFinalGraphOutput(const ValuePtr &value); | void SetFinalGraphOutput(const ValuePtr &value); | ||||
| void SetFinalGraphOutput(const VectorRef &vec_output); | void SetFinalGraphOutput(const VectorRef &vec_output); | ||||
| void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims); | |||||
| void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims, | |||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||||
| // split graphs with recurse from root graph | // split graphs with recurse from root graph | ||||
| void SplitGraphs(NotNull<KernelGraphPtr> root_graph); | void SplitGraphs(NotNull<KernelGraphPtr> root_graph); | ||||
| void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs); | ||||
| @@ -103,6 +103,23 @@ AnfNodePtr MakeValueNode(const AnfNodePtr &node) { | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); | AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); | ||||
| return new_value_node; | return new_value_node; | ||||
| } | } | ||||
| bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { | |||||
| if (left == right) { | |||||
| return true; | |||||
| } | |||||
| if (left == nullptr || right == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { | |||||
| return false; | |||||
| } | |||||
| if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) { | |||||
| return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) == | |||||
| AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| std::vector<AnfNodePtr> KernelGraph::outputs() const { | std::vector<AnfNodePtr> KernelGraph::outputs() const { | ||||
| auto graph_output = output(); | auto graph_output = output(); | ||||
| @@ -219,6 +236,19 @@ void KernelGraph::SetExecOrderByDefault() { | |||||
| if (node == start_label_ || node == end_goto_) { | if (node == start_label_ || node == end_goto_) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsSameLabel(node, end_goto_)) { | |||||
| end_goto_ = node; | |||||
| MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); | |||||
| continue; | |||||
| } | |||||
| if (IsSameLabel(node, start_label_)) { | |||||
| start_label_ = node; | |||||
| MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); | |||||
| continue; | |||||
| } | |||||
| re_order.push_back(node); | re_order.push_back(node); | ||||
| } | } | ||||
| if (end_goto_ != nullptr) { | if (end_goto_ != nullptr) { | ||||
| @@ -748,10 +778,9 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP | |||||
| } | } | ||||
| // update front to backend map | // update front to backend map | ||||
| FrontBackendlMapUpdate(old_anf_node, new_anf_node); | FrontBackendlMapUpdate(old_anf_node, new_anf_node); | ||||
| // update output depend relations | |||||
| node_output_edges_[new_anf_node.get()] = it->second; | |||||
| (void)node_output_edges_.erase(old_anf_node); | |||||
| } | } | ||||
| // if change the ir of graph, regenerate execution order of graph | |||||
| SetExecOrderByDefault(); | |||||
| // update graph inputs in child graph | // update graph inputs in child graph | ||||
| auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), | auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), | ||||
| [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { | [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { | ||||
| @@ -767,7 +796,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP | |||||
| return n.first == new_anf_node.get(); | return n.first == new_anf_node.get(); | ||||
| }); | }); | ||||
| if (iter != real_inputs_.end()) { | if (iter != real_inputs_.end()) { | ||||
| MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; | |||||
| MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited."; | |||||
| iter->second = old_args; | iter->second = old_args; | ||||
| } else { | } else { | ||||
| real_inputs_.emplace_back(new_anf_node, old_args); | real_inputs_.emplace_back(new_anf_node, old_args); | ||||
| @@ -824,6 +853,10 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar | |||||
| } | } | ||||
| } | } | ||||
| void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) { | |||||
| unreuse_args_[arg] = from_graph; | |||||
| } | |||||
| void KernelGraph::UpdateCallRealInput() { | void KernelGraph::UpdateCallRealInput() { | ||||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | MS_LOG(INFO) << "Update graph id: " << graph_id_; | ||||
| std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map; | std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map; | ||||
| @@ -836,6 +869,17 @@ void KernelGraph::UpdateCallRealInput() { | |||||
| // if real input is a call node ,find the child graph output act as the new real input | // if real input is a call node ,find the child graph output act as the new real input | ||||
| auto tmp_real_input = GetCallRealOutputs(real_input); | auto tmp_real_input = GetCallRealOutputs(real_input); | ||||
| std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); | std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); | ||||
| // replace the call in unreuse_args_ | |||||
| auto unreuse_arg_it = unreuse_args_.find(real_input); | |||||
| if (unreuse_arg_it != unreuse_args_.end()) { | |||||
| auto old_graph = unreuse_arg_it->second; | |||||
| for (auto new_real_input : new_real_inputs) { | |||||
| // if call reference graph output is parameter, it will be allowed to reuse | |||||
| if (!new_real_input->isa<Parameter>()) { | |||||
| unreuse_args_[new_real_input] = old_graph; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| real_inputs_map.emplace_back(parameter, new_real_inputs); | real_inputs_map.emplace_back(parameter, new_real_inputs); | ||||
| } | } | ||||
| @@ -130,6 +130,9 @@ class KernelGraph : public FuncGraph { | |||||
| // get real inputs | // get real inputs | ||||
| const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; } | const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; } | ||||
| void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); | void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); | ||||
| // mark unreused args | |||||
| void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph); | |||||
| const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; } | |||||
| // used to dump ir | // used to dump ir | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| // update the real input if the node is a call | // update the real input if the node is a call | ||||
| @@ -198,6 +201,7 @@ class KernelGraph : public FuncGraph { | |||||
| std::shared_ptr<KernelGraph> parent_graph_; | std::shared_ptr<KernelGraph> parent_graph_; | ||||
| // record real parameters,inputs_ is the formal parameters | // record real parameters,inputs_ is the formal parameters | ||||
| std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_; | std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_; | ||||
| std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_; | |||||
| CNodePtr start_label_; | CNodePtr start_label_; | ||||
| CNodePtr end_goto_; | CNodePtr end_goto_; | ||||
| @@ -99,6 +99,19 @@ class ControlIfbyIfbyIf(nn.Cell): | |||||
| return out | return out | ||||
| class ControlSimpleWhile(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.addn = op.AddN() | |||||
| def construct(self, x, y, input_data): | |||||
| out = input_data | |||||
| while x: | |||||
| out = self.addn([input_data, input_data, input_data]) | |||||
| x = y | |||||
| return out | |||||
| class ControlMixedWhileIf(nn.Cell): | class ControlMixedWhileIf(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -204,6 +217,22 @@ def test_if_by_if_by_if(): | |||||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_simple_while(): | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| x = np.array(True).astype(np.bool) | |||||
| y = np.array(False).astype(np.bool) | |||||
| input_shape = (127, 7, 53, 31) | |||||
| input_data = np.random.randn(*input_shape).astype(np.float32) | |||||
| net = ControlSimpleWhile() | |||||
| output = net(Tensor(x), Tensor(y), Tensor(input_data)) | |||||
| expect = input_data * 3 | |||||
| assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||