Merge pull request !4295 from wenchunjiang/remove_inline_1tags/v0.7.0-beta
| @@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) | |||
| return func_graph; | |||
| } | |||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) { | |||
| MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; | |||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) { | |||
| MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."; | |||
| } | |||
| auto input1 = call_node->input(1); | |||
| MS_EXCEPTION_IF_NULL(input1); | |||
| if (input1->isa<ValueNode>()) { | |||
| if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { | |||
| auto input1 = cnode->input(kCallKernelGraphIndex); | |||
| MS_EXCEPTION_IF_NULL(input1); | |||
| auto value_node = input1->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto kernel_graph = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| return {kernel_graph->cast<KernelGraphPtr>()}; | |||
| } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { | |||
| auto switch_node = input1->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_node); | |||
| auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { | |||
| auto partial = switch_node->input(input_index); | |||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||
| auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { | |||
| auto partial = cnode->input(input_index); | |||
| MS_EXCEPTION_IF_NULL(partial); | |||
| if (IsValueNode<KernelGraph>(partial)) { | |||
| return GetValueNode<KernelGraphPtr>(partial); | |||
| } | |||
| auto partial_cnode = partial->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||
| auto graph_node = partial_cnode->input(1); | |||
| auto graph_node = partial_cnode->input(kCallKernelGraphIndex); | |||
| MS_EXCEPTION_IF_NULL(graph_node); | |||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(graph_value_node); | |||
| @@ -1064,7 +1062,8 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN | |||
| auto child_graph = graph_value->cast<KernelGraphPtr>(); | |||
| return child_graph; | |||
| }; | |||
| return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; | |||
| return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), | |||
| get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; | |||
| } | |||
| return {}; | |||
| } | |||
| @@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm { | |||
| static bool IsCommunicationOp(const AnfNodePtr &node); | |||
| static bool IsGetNext(const NotNull<AnfNodePtr> &node); | |||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | |||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | |||
| static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode); | |||
| static bool IsSwitchCall(const CNodePtr &call_node); | |||
| static bool IsScalarInput(const CNodePtr &cnode, size_t index); | |||
| static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | |||
| @@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
| } | |||
| } | |||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode( | |||
| NotNull<CNodePtr> call_node) { | |||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallSwitchNode( | |||
| NotNull<CNodePtr> cnode) { | |||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret; | |||
| if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { | |||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; | |||
| } | |||
| if (call_node->size() <= kCNodeCallArg) { | |||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); | |||
| } | |||
| const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs(); | |||
| auto call_arg = call_node_inputs[kCNodeCallArg]; | |||
| MS_EXCEPTION_IF_NULL(call_arg); | |||
| if (IsValueNode<KernelGraph>(call_arg)) { | |||
| if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) { | |||
| if (cnode->size() <= kCNodeCallArg) { | |||
| MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size(); | |||
| } | |||
| auto call_arg = cnode->input(kCNodeCallArg); | |||
| MS_EXCEPTION_IF_NULL(call_arg); | |||
| ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), | |||
| std::vector<AnfNodePtr>(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); | |||
| } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { | |||
| auto switch_cnode = call_arg->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs(); | |||
| if (switch_inputs.size() <= kCNodeSwitchCond) { | |||
| MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " | |||
| std::vector<AnfNodePtr>(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end())); | |||
| } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) { | |||
| const std::vector<AnfNodePtr> &switch_inputs = cnode->inputs(); | |||
| if (switch_inputs.size() < kCNodeSwitchLength) { | |||
| MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size " | |||
| << switch_inputs.size(); | |||
| } | |||
| for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { | |||
| @@ -389,7 +384,7 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar | |||
| ret.emplace_back(target_graph, args); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); | |||
| MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5); | |||
| } | |||
| return ret; | |||
| } | |||
| @@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign( | |||
| const std::vector<CNodePtr> &nodes = kg->execution_order(); | |||
| for (auto &node : nodes) { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimCall)) { | |||
| if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) { | |||
| continue; | |||
| } | |||
| auto child_graph_list = ParseCallNode(NOT_NULL(node)); | |||
| auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node)); | |||
| for (auto &[child_graph, args] : child_graph_list) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| const std::vector<AnfNodePtr> ¶ms = child_graph->inputs(); | |||
| @@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign( | |||
| link_list->emplace_back(args[i], params[i]); | |||
| continue; | |||
| } | |||
| InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); | |||
| } | |||
| } | |||
| @@ -475,30 +469,20 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr | |||
| for (size_t i = 0; i < nodes.size(); ++i) { | |||
| auto &cnode = nodes[i]; | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (cnode->size() < kCNodePrim + 1) { | |||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||
| } | |||
| AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); | |||
| if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { | |||
| MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); | |||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || | |||
| AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { | |||
| continue; | |||
| } | |||
| AnfNodePtr arg = cnode->input(kFirstDataInputIndex); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| if (IsValueNode<KernelGraph>(arg)) { | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimCall)) { | |||
| RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } else if (!arg->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||
| cnode->set_inputs(arg_cnode->inputs()); | |||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { | |||
| RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) { | |||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||
| cnode->set_inputs(arg_cnode->inputs()); | |||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { | |||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString(); | |||
| } | |||
| } | |||
| kg->SetExecOrderByDefault(); | |||
| @@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr | |||
| continue; | |||
| } | |||
| const auto &from_graph_exe_order = from_graph->execution_order(); | |||
| std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size()); | |||
| size_t real_exe_order_size = 0; | |||
| std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), | |||
| [&real_exe_order_size](const CNodePtr &node) -> bool { | |||
| return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) | |||
| ? false | |||
| : (++real_exe_order_size, true); | |||
| }); | |||
| real_exe_order.resize(real_exe_order_size); | |||
| if (jump_node == nullptr) { | |||
| if (!real_exe_order.empty()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); | |||
| if (!from_graph_exe_order.empty()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node)); | |||
| } else { | |||
| InsertDependToGraph(from_graph, NOT_NULL(assign_node)); | |||
| } | |||
| continue; | |||
| } | |||
| auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); | |||
| if (jump_node_iter == real_exe_order.end()) { | |||
| auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); | |||
| if (jump_node_iter == from_graph_exe_order.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " | |||
| << from_graph->ToString(); | |||
| } | |||
| // insert assign between jump_node -1 and jump_node | |||
| if (jump_node_iter != real_exe_order.begin()) { | |||
| if (jump_node_iter != from_graph_exe_order.begin()) { | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | |||
| } | |||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | |||
| @@ -772,6 +747,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||
| std::vector<CNodePtr> execution_order; | |||
| uint32_t child_order_index = 0; | |||
| for (auto &node : cnodes) { | |||
| uint32_t child_graph_index = 0; | |||
| execution_order.push_back(node); | |||
| if (node == graph->get_end_goto()) { | |||
| continue; | |||
| @@ -779,7 +755,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | |||
| std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | |||
| for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { | |||
| if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { | |||
| if (!CheckLabelIndex(child_graph_index++, *iter, node)) { | |||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||
| } | |||
| if (child_order_index >= graph->child_graph_order().size()) { | |||
| @@ -791,9 +767,12 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||
| } | |||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | |||
| uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | |||
| if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { | |||
| if (!CheckLabelIndex(child_graph_index, label_index, node)) { | |||
| MS_LOG(EXCEPTION) << "Check label index fail"; | |||
| } | |||
| if (child_order_index >= graph->child_graph_order().size()) { | |||
| MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); | |||
| } | |||
| auto child_graph = graph->child_graph_order()[child_order_index++]; | |||
| auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | |||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | |||
| @@ -804,15 +783,14 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||
| return execution_order; | |||
| } | |||
| bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, | |||
| NotNull<KernelGraphPtr> graph) { | |||
| const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); | |||
| bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { | |||
| auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph); | |||
| // check index and child order size | |||
| if (child_graph_order.size() <= IntToSize(order_index)) { | |||
| MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " | |||
| << child_graph_order.size() << " goto index " << order_index; | |||
| if (child_graphs.size() <= IntToSize(index)) { | |||
| MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " | |||
| << child_graphs.size() << " goto index " << index; | |||
| } | |||
| auto child_graph = child_graph_order[order_index]; | |||
| auto child_graph = child_graphs[index]; | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| // get start_label_set_index of child graph | |||
| @@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i | |||
| MS_EXCEPTION_IF_NULL(cur_label); | |||
| MS_EXCEPTION_IF_NULL(start_label_set); | |||
| MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() | |||
| << " index " << start_label_set_index << " current child graph order : " << order_index; | |||
| << " index " << start_label_set_index; | |||
| return false; | |||
| } else { | |||
| return true; | |||
| @@ -64,13 +64,13 @@ class AscendControlParser { | |||
| const CNodePtr &last_label); | |||
| static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node); | |||
| static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallSwitchNode( | |||
| NotNull<CNodePtr> call_node); | |||
| static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| // root graph order | |||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||
| NotNull<KernelGraphPtr> graph); | |||
| static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); | |||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| }; | |||
| @@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||
| std::map<AnfNodePtr, AnfNodePtr> need_replace_list; | |||
| auto node_list = GetCNodes(TopoSort(graph->get_return())); | |||
| for (auto &node : node_list) { | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||
| // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | |||
| auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | |||
| MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | |||
| @@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||
| MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() | |||
| << ", depend node is " << depend->DebugString(); | |||
| // insert assign in order to transfer child graph output to parameter | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); | |||
| auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node); | |||
| for (auto &child_graph : child_graphs) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert | |||
| @@ -67,7 +67,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) { | |||
| return {node}; | |||
| } | |||
| std::vector<AnfNodePtr> real_inputs; | |||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>()); | |||
| auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>()); | |||
| for (const auto &child_graph : child_graphs) { | |||
| if (child_graph->get_output_null()) { | |||
| continue; | |||
| @@ -931,6 +931,18 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi | |||
| return result; | |||
| } | |||
| std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const { | |||
| std::vector<CNodePtr> result; | |||
| for (const auto &anf : execution_order_) { | |||
| for (const auto &primitive : primitive_list) { | |||
| if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { | |||
| result.push_back(anf->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| void KernelGraph::PrintGraphExecuteOrder() const { | |||
| MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; | |||
| for (size_t i = 0; i < execution_order_.size(); i++) { | |||
| @@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu | |||
| void KernelGraph::UpdateChildGraphOrder() { | |||
| MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | |||
| SetExecOrderByDefault(); | |||
| auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||
| auto call_nodes = FindNodeByPrimitive( | |||
| {std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())}); | |||
| std::vector<KernelGraphPtr> child_graph_order; | |||
| for (auto &call_node : call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||
| auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>()); | |||
| for (const auto &child_graph : call_child_graphs) { | |||
| MS_EXCEPTION_IF_NULL(child_graph); | |||
| if (child_graph != parent_graph_) { | |||
| @@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph { | |||
| void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } | |||
| // find anf node in graph | |||
| std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | |||
| std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const; | |||
| // used to dump ir | |||
| std::string ToString() const override; | |||
| @@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra | |||
| MS_EXCEPTION_IF_NULL(node_input); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // switch input generalizes partial | |||
| if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || | |||
| AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { | |||
| return node_input->cast<CNodePtr>(); | |||
| } | |||
| if (node_input->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; | |||
| } | |||
| std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; | |||
| if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { | |||
| partial_inputs.emplace_back(node_input); | |||
| auto partial_node = graph->NewCNode(partial_inputs); | |||
| return partial_node; | |||
| if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) { | |||
| auto partial_node = graph->GetBackendAnfByFrontAnf(node_input); | |||
| return partial_node->cast<CNodePtr>(); | |||
| } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { | |||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | |||
| } else { | |||
| KernelGraphPtr kernel_graph = NewKernelGraph(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); | |||
| auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); | |||
| auto return_node = kernel_graph->NewCNode({primitive, parameter}); | |||
| kernel_graph->set_return(return_node); | |||
| partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | |||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | |||
| } | |||
| KernelGraphPtr kernel_graph = NewKernelGraph(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); | |||
| partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | |||
| auto partial_node = graph->NewCNode(partial_inputs); | |||
| return partial_node; | |||
| } | |||
| CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto node = anf_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->inputs().size() < kSwitchInputSize) { | |||
| MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; | |||
| } | |||
| auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name())); | |||
| std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)}; | |||
| for (size_t index = 2; index < node->inputs().size(); index++) { | |||
| auto input = CreateSwitchInput(node->input(index), graph); | |||
| switch_inputs.emplace_back(input); | |||
| } | |||
| auto switch_node = graph->NewCNode(switch_inputs); | |||
| return switch_node; | |||
| } | |||
| std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -611,14 +592,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & | |||
| }); | |||
| return cnode_inputs; | |||
| } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | |||
| auto switch_node = HandleSwitchInputs(cnode_input, graph); | |||
| auto switch_cnode = cnode_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), | |||
| switch_cnode->input(kFirstDataInputIndex)}; | |||
| for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { | |||
| auto node = switch_cnode->input(index); | |||
| // there is real input in call, should put it to true and false branch in switch | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { | |||
| auto partial_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(partial_node); | |||
| std::vector<AnfNodePtr> partial_inputs = partial_node->inputs(); | |||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); | |||
| auto new_partial = graph->NewCNode(partial_inputs); | |||
| switch_inputs.emplace_back(new_partial); | |||
| } | |||
| } | |||
| if (switch_inputs.size() < kSwitchInputSize) { | |||
| MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; | |||
| } | |||
| auto switch_node = graph->NewCNode(switch_inputs); | |||
| cnode_inputs.emplace_back(switch_node); | |||
| return cnode_inputs; | |||
| } | |||
| MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; | |||
| } | |||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { | |||
| CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| std::vector<AnfNodePtr> cnode_inputs; | |||
| @@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| } | |||
| } | |||
| } else if (attr_input->isa<CNode>()) { | |||
| cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); | |||
| auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); | |||
| if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | |||
| auto switch_cnode = cnode_input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| cnode_inputs = switch_cnode->inputs(); | |||
| } else { | |||
| cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); | |||
| } | |||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||
| cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), | |||
| graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; | |||
| for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { | |||
| auto node_input = cnode->input(index); | |||
| auto switch_input = CreateSwitchInput(node_input, graph); | |||
| cnode_inputs.emplace_back(switch_input); | |||
| } | |||
| } else { | |||
| // get primitive of old node | |||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | |||
| @@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; | |||
| } | |||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto anf = cnode->input(input_idx); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| // anf has been created before | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if (IsValueNode<None>(anf)) { | |||
| continue; | |||
| if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||
| for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { | |||
| auto anf = cnode->input(input_idx); | |||
| MS_EXCEPTION_IF_NULL(anf); | |||
| // anf has been created before | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if (IsValueNode<None>(anf)) { | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||
| } | |||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | |||
| auto new_cnode = graph->NewCNode(cnode_inputs); | |||
| TraceManager::EndTrace(); | |||
| // if the cnode is call switch, remove call | |||
| if (new_cnode->inputs().size() > 1) { | |||
| auto first_input = new_cnode->input(kFirstDataInputIndex); | |||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && | |||
| AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { | |||
| new_cnode = first_input->cast<CNodePtr>(); | |||
| } | |||
| } | |||
| return new_cnode; | |||
| } | |||
| @@ -86,11 +86,7 @@ class SessionBasic { | |||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | |||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); | |||
| CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); | |||
| CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); | |||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | |||
| CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | |||
| @@ -112,6 +108,10 @@ class SessionBasic { | |||
| } | |||
| #endif | |||
| private: | |||
| CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); | |||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | |||
| protected: | |||
| virtual void SetSummaryNodes(KernelGraph *graph); | |||
| // Get graph by graph id ,if not exist return null ptr | |||
| @@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2; | |||
| // define special index in special node | |||
| constexpr auto kAnfPrimitiveIndex = 0; | |||
| constexpr auto kFirstDataInputIndex = 1; | |||
| constexpr auto kAnfPartialFuncGraphIndex = 1; | |||
| constexpr auto kRealInputNodeIndexInTupleGetItem = 1; | |||
| constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; | |||
| constexpr auto kTupleGetItemInputSize = 3; | |||
| constexpr auto kSwitchInputSize = 4; | |||
| constexpr auto kFirstBranchInSwitch = 2; | |||
| constexpr auto kCallKernelGraphIndex = 1; | |||
| constexpr auto kSwitchTrueKernelGraphIndex = 2; | |||
| constexpr auto kSwitchFalseKernelGraphIndex = 3; | |||
| // index define of control depend | |||
| constexpr auto kControlDependPriorIndex = 1; | |||
| constexpr auto kControlDependBehindIndex = 2; | |||