merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts committags/v0.7.0-betace1f600d1e. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690d]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690d]:temporary avoid list getitem problem" This reverts commit74c258f942. Revert "handle switch_layer in ConstructKernelGraph" This reverts commitcb5367f02d. Revert "update frontend code PR_2948" This reverts commit234ac58340. Revert "merge me commit for remove inline" This reverts commit55c0ebd42b. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commitb42e893125.
| @@ -1031,31 +1031,29 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) | |||||
| return func_graph; | return func_graph; | ||||
| } | } | ||||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { | |||||
| MS_EXCEPTION_IF_NULL(call_node); | |||||
| if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) { | |||||
| MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; | |||||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) { | |||||
| MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node."; | |||||
| } | } | ||||
| auto input1 = call_node->input(1); | |||||
| MS_EXCEPTION_IF_NULL(input1); | |||||
| if (input1->isa<ValueNode>()) { | |||||
| if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { | |||||
| auto input1 = cnode->input(kCallKernelGraphIndex); | |||||
| MS_EXCEPTION_IF_NULL(input1); | |||||
| auto value_node = input1->cast<ValueNodePtr>(); | auto value_node = input1->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(value_node); | MS_EXCEPTION_IF_NULL(value_node); | ||||
| auto kernel_graph = value_node->value(); | auto kernel_graph = value_node->value(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| return {kernel_graph->cast<KernelGraphPtr>()}; | return {kernel_graph->cast<KernelGraphPtr>()}; | ||||
| } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { | |||||
| auto switch_node = input1->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(switch_node); | |||||
| auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { | |||||
| auto partial = switch_node->input(input_index); | |||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||||
| auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { | |||||
| auto partial = cnode->input(input_index); | |||||
| MS_EXCEPTION_IF_NULL(partial); | MS_EXCEPTION_IF_NULL(partial); | ||||
| if (IsValueNode<KernelGraph>(partial)) { | if (IsValueNode<KernelGraph>(partial)) { | ||||
| return GetValueNode<KernelGraphPtr>(partial); | return GetValueNode<KernelGraphPtr>(partial); | ||||
| } | } | ||||
| auto partial_cnode = partial->cast<CNodePtr>(); | auto partial_cnode = partial->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(partial_cnode); | MS_EXCEPTION_IF_NULL(partial_cnode); | ||||
| auto graph_node = partial_cnode->input(1); | |||||
| auto graph_node = partial_cnode->input(kCallKernelGraphIndex); | |||||
| MS_EXCEPTION_IF_NULL(graph_node); | MS_EXCEPTION_IF_NULL(graph_node); | ||||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | auto graph_value_node = graph_node->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(graph_value_node); | MS_EXCEPTION_IF_NULL(graph_value_node); | ||||
| @@ -1064,7 +1062,8 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN | |||||
| auto child_graph = graph_value->cast<KernelGraphPtr>(); | auto child_graph = graph_value->cast<KernelGraphPtr>(); | ||||
| return child_graph; | return child_graph; | ||||
| }; | }; | ||||
| return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; | |||||
| return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), | |||||
| get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; | |||||
| } | } | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| @@ -201,7 +201,7 @@ class AnfRuntimeAlgorithm { | |||||
| static bool IsCommunicationOp(const AnfNodePtr &node); | static bool IsCommunicationOp(const AnfNodePtr &node); | ||||
| static bool IsGetNext(const NotNull<AnfNodePtr> &node); | static bool IsGetNext(const NotNull<AnfNodePtr> &node); | ||||
| static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); | ||||
| static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node); | |||||
| static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode); | |||||
| static bool IsSwitchCall(const CNodePtr &call_node); | static bool IsSwitchCall(const CNodePtr &call_node); | ||||
| static bool IsScalarInput(const CNodePtr &cnode, size_t index); | static bool IsScalarInput(const CNodePtr &cnode, size_t index); | ||||
| static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | static bool IsScalarOutput(const CNodePtr &cnode, size_t index); | ||||
| @@ -361,27 +361,22 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||||
| } | } | ||||
| } | } | ||||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallNode( | |||||
| NotNull<CNodePtr> call_node) { | |||||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlParser::ParseCallSwitchNode( | |||||
| NotNull<CNodePtr> cnode) { | |||||
| std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret; | std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ret; | ||||
| if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { | |||||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; | |||||
| } | |||||
| if (call_node->size() <= kCNodeCallArg) { | |||||
| MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); | |||||
| } | |||||
| const std::vector<AnfNodePtr> &call_node_inputs = call_node->inputs(); | |||||
| auto call_arg = call_node_inputs[kCNodeCallArg]; | |||||
| MS_EXCEPTION_IF_NULL(call_arg); | |||||
| if (IsValueNode<KernelGraph>(call_arg)) { | |||||
| if (IsPrimitiveCNode(cnode.get(), prim::kPrimCall)) { | |||||
| if (cnode->size() <= kCNodeCallArg) { | |||||
| MS_LOG(EXCEPTION) << "Call node " << cnode->DebugString() << " has invalid inputs size " << cnode->size(); | |||||
| } | |||||
| auto call_arg = cnode->input(kCNodeCallArg); | |||||
| MS_EXCEPTION_IF_NULL(call_arg); | |||||
| ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), | ret.emplace_back(GetValueNode<KernelGraphPtr>(call_arg), | ||||
| std::vector<AnfNodePtr>(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); | |||||
| } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { | |||||
| auto switch_cnode = call_arg->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||||
| const std::vector<AnfNodePtr> &switch_inputs = switch_cnode->inputs(); | |||||
| if (switch_inputs.size() <= kCNodeSwitchCond) { | |||||
| MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " | |||||
| std::vector<AnfNodePtr>(cnode->inputs().begin() + kCNodeCallArg + 1, cnode->inputs().end())); | |||||
| } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitch)) { | |||||
| const std::vector<AnfNodePtr> &switch_inputs = cnode->inputs(); | |||||
| if (switch_inputs.size() < kCNodeSwitchLength) { | |||||
| MS_LOG(EXCEPTION) << "Switch node " << cnode->DebugString() << " has invalid inputs size " | |||||
| << switch_inputs.size(); | << switch_inputs.size(); | ||||
| } | } | ||||
| for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { | for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { | ||||
| @@ -389,7 +384,7 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar | |||||
| ret.emplace_back(target_graph, args); | ret.emplace_back(target_graph, args); | ||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); | |||||
| MS_LOG(EXCEPTION) << "Unsupport call node: " << cnode->DebugString(5); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -406,11 +401,11 @@ void AscendControlParser::ChildGraphDataAssign( | |||||
| const std::vector<CNodePtr> &nodes = kg->execution_order(); | const std::vector<CNodePtr> &nodes = kg->execution_order(); | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (!IsPrimitiveCNode(node, prim::kPrimCall)) { | |||||
| if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto child_graph_list = ParseCallNode(NOT_NULL(node)); | |||||
| auto child_graph_list = ParseCallSwitchNode(NOT_NULL(node)); | |||||
| for (auto &[child_graph, args] : child_graph_list) { | for (auto &[child_graph, args] : child_graph_list) { | ||||
| MS_EXCEPTION_IF_NULL(child_graph); | MS_EXCEPTION_IF_NULL(child_graph); | ||||
| const std::vector<AnfNodePtr> ¶ms = child_graph->inputs(); | const std::vector<AnfNodePtr> ¶ms = child_graph->inputs(); | ||||
| @@ -425,7 +420,6 @@ void AscendControlParser::ChildGraphDataAssign( | |||||
| link_list->emplace_back(args[i], params[i]); | link_list->emplace_back(args[i], params[i]); | ||||
| continue; | continue; | ||||
| } | } | ||||
| InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); | InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); | ||||
| } | } | ||||
| } | } | ||||
| @@ -475,30 +469,20 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr | |||||
| for (size_t i = 0; i < nodes.size(); ++i) { | for (size_t i = 0; i < nodes.size(); ++i) { | ||||
| auto &cnode = nodes[i]; | auto &cnode = nodes[i]; | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (cnode->size() < kCNodePrim + 1) { | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; | |||||
| } | |||||
| AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); | |||||
| if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { | |||||
| MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); | |||||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || | |||||
| AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || | |||||
| AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| AnfNodePtr arg = cnode->input(kFirstDataInputIndex); | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| if (IsValueNode<KernelGraph>(arg)) { | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimCall)) { | |||||
| RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | ||||
| } else if (!arg->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); | |||||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) { | |||||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||||
| cnode->set_inputs(arg_cnode->inputs()); | |||||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) { | |||||
| RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | ||||
| } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) { | |||||
| auto arg_cnode = arg->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(arg_cnode); | |||||
| cnode->set_inputs(arg_cnode->inputs()); | |||||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) { | |||||
| RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); | ||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unexpected node: " << cnode->DebugString(); | |||||
| } | } | ||||
| } | } | ||||
| kg->SetExecOrderByDefault(); | kg->SetExecOrderByDefault(); | ||||
| @@ -699,31 +683,22 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr | |||||
| continue; | continue; | ||||
| } | } | ||||
| const auto &from_graph_exe_order = from_graph->execution_order(); | const auto &from_graph_exe_order = from_graph->execution_order(); | ||||
| std::vector<CNodePtr> real_exe_order(from_graph_exe_order.size()); | |||||
| size_t real_exe_order_size = 0; | |||||
| std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), | |||||
| [&real_exe_order_size](const CNodePtr &node) -> bool { | |||||
| return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) | |||||
| ? false | |||||
| : (++real_exe_order_size, true); | |||||
| }); | |||||
| real_exe_order.resize(real_exe_order_size); | |||||
| if (jump_node == nullptr) { | if (jump_node == nullptr) { | ||||
| if (!real_exe_order.empty()) { | |||||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); | |||||
| if (!from_graph_exe_order.empty()) { | |||||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(from_graph_exe_order.rbegin())), NOT_NULL(assign_node)); | |||||
| } else { | } else { | ||||
| InsertDependToGraph(from_graph, NOT_NULL(assign_node)); | InsertDependToGraph(from_graph, NOT_NULL(assign_node)); | ||||
| } | } | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); | |||||
| if (jump_node_iter == real_exe_order.end()) { | |||||
| auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); | |||||
| if (jump_node_iter == from_graph_exe_order.end()) { | |||||
| MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " | MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " | ||||
| << from_graph->ToString(); | << from_graph->ToString(); | ||||
| } | } | ||||
| // insert assign between jump_node -1 and jump_node | // insert assign between jump_node -1 and jump_node | ||||
| if (jump_node_iter != real_exe_order.begin()) { | |||||
| if (jump_node_iter != from_graph_exe_order.begin()) { | |||||
| InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); | ||||
| } | } | ||||
| InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); | ||||
| @@ -772,6 +747,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| std::vector<CNodePtr> execution_order; | std::vector<CNodePtr> execution_order; | ||||
| uint32_t child_order_index = 0; | uint32_t child_order_index = 0; | ||||
| for (auto &node : cnodes) { | for (auto &node : cnodes) { | ||||
| uint32_t child_graph_index = 0; | |||||
| execution_order.push_back(node); | execution_order.push_back(node); | ||||
| if (node == graph->get_end_goto()) { | if (node == graph->get_end_goto()) { | ||||
| continue; | continue; | ||||
| @@ -779,7 +755,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | ||||
| std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | ||||
| for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { | for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { | ||||
| if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { | |||||
| if (!CheckLabelIndex(child_graph_index++, *iter, node)) { | |||||
| MS_LOG(EXCEPTION) << "Check label index fail"; | MS_LOG(EXCEPTION) << "Check label index fail"; | ||||
| } | } | ||||
| if (child_order_index >= graph->child_graph_order().size()) { | if (child_order_index >= graph->child_graph_order().size()) { | ||||
| @@ -791,9 +767,12 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| } | } | ||||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | ||||
| uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | ||||
| if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { | |||||
| if (!CheckLabelIndex(child_graph_index, label_index, node)) { | |||||
| MS_LOG(EXCEPTION) << "Check label index fail"; | MS_LOG(EXCEPTION) << "Check label index fail"; | ||||
| } | } | ||||
| if (child_order_index >= graph->child_graph_order().size()) { | |||||
| MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); | |||||
| } | |||||
| auto child_graph = graph->child_graph_order()[child_order_index++]; | auto child_graph = graph->child_graph_order()[child_order_index++]; | ||||
| auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); | ||||
| execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); | ||||
| @@ -804,15 +783,14 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> | |||||
| return execution_order; | return execution_order; | ||||
| } | } | ||||
| bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, | |||||
| NotNull<KernelGraphPtr> graph) { | |||||
| const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); | |||||
| bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) { | |||||
| auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph); | |||||
| // check index and child order size | // check index and child order size | ||||
| if (child_graph_order.size() <= IntToSize(order_index)) { | |||||
| MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " | |||||
| << child_graph_order.size() << " goto index " << order_index; | |||||
| if (child_graphs.size() <= IntToSize(index)) { | |||||
| MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size " | |||||
| << child_graphs.size() << " goto index " << index; | |||||
| } | } | ||||
| auto child_graph = child_graph_order[order_index]; | |||||
| auto child_graph = child_graphs[index]; | |||||
| MS_EXCEPTION_IF_NULL(child_graph); | MS_EXCEPTION_IF_NULL(child_graph); | ||||
| // get start_label_set_index of child graph | // get start_label_set_index of child graph | ||||
| @@ -822,7 +800,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i | |||||
| MS_EXCEPTION_IF_NULL(cur_label); | MS_EXCEPTION_IF_NULL(cur_label); | ||||
| MS_EXCEPTION_IF_NULL(start_label_set); | MS_EXCEPTION_IF_NULL(start_label_set); | ||||
| MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() | MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() | ||||
| << " index " << start_label_set_index << " current child graph order : " << order_index; | |||||
| << " index " << start_label_set_index; | |||||
| return false; | return false; | ||||
| } else { | } else { | ||||
| return true; | return true; | ||||
| @@ -64,13 +64,13 @@ class AscendControlParser { | |||||
| const CNodePtr &last_label); | const CNodePtr &last_label); | ||||
| static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | ||||
| static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node); | |||||
| static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallSwitchNode( | |||||
| NotNull<CNodePtr> call_node); | |||||
| static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node); | static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node); | ||||
| static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph, | static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph, | ||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | const NotNull<std::set<KernelGraphPtr> *> memo); | ||||
| // root graph order | // root graph order | ||||
| static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, | |||||
| NotNull<KernelGraphPtr> graph); | |||||
| static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); | |||||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | ||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | const NotNull<std::set<KernelGraphPtr> *> memo); | ||||
| }; | }; | ||||
| @@ -885,7 +885,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||||
| std::map<AnfNodePtr, AnfNodePtr> need_replace_list; | std::map<AnfNodePtr, AnfNodePtr> need_replace_list; | ||||
| auto node_list = GetCNodes(TopoSort(graph->get_return())); | auto node_list = GetCNodes(TopoSort(graph->get_return())); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||||
| // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | ||||
| auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | ||||
| MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | ||||
| @@ -898,7 +898,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||||
| MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() | MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() | ||||
| << ", depend node is " << depend->DebugString(); | << ", depend node is " << depend->DebugString(); | ||||
| // insert assign in order to transfer child graph output to parameter | // insert assign in order to transfer child graph output to parameter | ||||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); | |||||
| auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node); | |||||
| for (auto &child_graph : child_graphs) { | for (auto &child_graph : child_graphs) { | ||||
| MS_EXCEPTION_IF_NULL(child_graph); | MS_EXCEPTION_IF_NULL(child_graph); | ||||
| // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert | // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert | ||||
| @@ -67,7 +67,7 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) { | |||||
| return {node}; | return {node}; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> real_inputs; | std::vector<AnfNodePtr> real_inputs; | ||||
| auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast<CNodePtr>()); | |||||
| auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>()); | |||||
| for (const auto &child_graph : child_graphs) { | for (const auto &child_graph : child_graphs) { | ||||
| if (child_graph->get_output_null()) { | if (child_graph->get_output_null()) { | ||||
| continue; | continue; | ||||
| @@ -931,6 +931,18 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi | |||||
| return result; | return result; | ||||
| } | } | ||||
| std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const { | |||||
| std::vector<CNodePtr> result; | |||||
| for (const auto &anf : execution_order_) { | |||||
| for (const auto &primitive : primitive_list) { | |||||
| if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { | |||||
| result.push_back(anf->cast<CNodePtr>()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return result; | |||||
| } | |||||
| void KernelGraph::PrintGraphExecuteOrder() const { | void KernelGraph::PrintGraphExecuteOrder() const { | ||||
| MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; | MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; | ||||
| for (size_t i = 0; i < execution_order_.size(); i++) { | for (size_t i = 0; i < execution_order_.size(); i++) { | ||||
| @@ -1078,11 +1090,12 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu | |||||
| void KernelGraph::UpdateChildGraphOrder() { | void KernelGraph::UpdateChildGraphOrder() { | ||||
| MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | ||||
| SetExecOrderByDefault(); | SetExecOrderByDefault(); | ||||
| auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); | |||||
| auto call_nodes = FindNodeByPrimitive( | |||||
| {std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())}); | |||||
| std::vector<KernelGraphPtr> child_graph_order; | std::vector<KernelGraphPtr> child_graph_order; | ||||
| for (auto &call_node : call_nodes) { | for (auto &call_node : call_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(call_node); | MS_EXCEPTION_IF_NULL(call_node); | ||||
| auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>()); | |||||
| auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>()); | |||||
| for (const auto &child_graph : call_child_graphs) { | for (const auto &child_graph : call_child_graphs) { | ||||
| MS_EXCEPTION_IF_NULL(child_graph); | MS_EXCEPTION_IF_NULL(child_graph); | ||||
| if (child_graph != parent_graph_) { | if (child_graph != parent_graph_) { | ||||
| @@ -131,6 +131,7 @@ class KernelGraph : public FuncGraph { | |||||
| void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } | void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; } | ||||
| // find anf node in graph | // find anf node in graph | ||||
| std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; | ||||
| std::vector<CNodePtr> FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const; | |||||
| // used to dump ir | // used to dump ir | ||||
| std::string ToString() const override; | std::string ToString() const override; | ||||
| @@ -547,45 +547,26 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra | |||||
| MS_EXCEPTION_IF_NULL(node_input); | MS_EXCEPTION_IF_NULL(node_input); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // switch input generalizes partial | // switch input generalizes partial | ||||
| if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || | |||||
| AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { | |||||
| return node_input->cast<CNodePtr>(); | |||||
| } | |||||
| if (node_input->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; | |||||
| } | |||||
| std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; | std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; | ||||
| if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { | |||||
| partial_inputs.emplace_back(node_input); | |||||
| auto partial_node = graph->NewCNode(partial_inputs); | |||||
| return partial_node; | |||||
| if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) { | |||||
| auto partial_node = graph->GetBackendAnfByFrontAnf(node_input); | |||||
| return partial_node->cast<CNodePtr>(); | |||||
| } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { | |||||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | |||||
| } else { | |||||
| KernelGraphPtr kernel_graph = NewKernelGraph(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), true, kernel_graph.get()); | |||||
| auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); | |||||
| auto return_node = kernel_graph->NewCNode({primitive, parameter}); | |||||
| kernel_graph->set_return(return_node); | |||||
| partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | |||||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | |||||
| } | } | ||||
| KernelGraphPtr kernel_graph = NewKernelGraph(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); | |||||
| partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | |||||
| auto partial_node = graph->NewCNode(partial_inputs); | auto partial_node = graph->NewCNode(partial_inputs); | ||||
| return partial_node; | return partial_node; | ||||
| } | } | ||||
| CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto node = anf_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node->inputs().size() < kSwitchInputSize) { | |||||
| MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; | |||||
| } | |||||
| auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name())); | |||||
| std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)}; | |||||
| for (size_t index = 2; index < node->inputs().size(); index++) { | |||||
| auto input = CreateSwitchInput(node->input(index), graph); | |||||
| switch_inputs.emplace_back(input); | |||||
| } | |||||
| auto switch_node = graph->NewCNode(switch_inputs); | |||||
| return switch_node; | |||||
| } | |||||
| std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { | std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| @@ -611,14 +592,33 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & | |||||
| }); | }); | ||||
| return cnode_inputs; | return cnode_inputs; | ||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | ||||
| auto switch_node = HandleSwitchInputs(cnode_input, graph); | |||||
| auto switch_cnode = cnode_input->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||||
| std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), | |||||
| switch_cnode->input(kFirstDataInputIndex)}; | |||||
| for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { | |||||
| auto node = switch_cnode->input(index); | |||||
| // there is real input in call, should put it to true and false branch in switch | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { | |||||
| auto partial_node = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(partial_node); | |||||
| std::vector<AnfNodePtr> partial_inputs = partial_node->inputs(); | |||||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); | |||||
| auto new_partial = graph->NewCNode(partial_inputs); | |||||
| switch_inputs.emplace_back(new_partial); | |||||
| } | |||||
| } | |||||
| if (switch_inputs.size() < kSwitchInputSize) { | |||||
| MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize; | |||||
| } | |||||
| auto switch_node = graph->NewCNode(switch_inputs); | |||||
| cnode_inputs.emplace_back(switch_node); | cnode_inputs.emplace_back(switch_node); | ||||
| return cnode_inputs; | return cnode_inputs; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; | MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; | ||||
| } | } | ||||
| CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { | |||||
| CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| std::vector<AnfNodePtr> cnode_inputs; | std::vector<AnfNodePtr> cnode_inputs; | ||||
| @@ -642,7 +642,22 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||||
| } | } | ||||
| } | } | ||||
| } else if (attr_input->isa<CNode>()) { | } else if (attr_input->isa<CNode>()) { | ||||
| cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); | |||||
| auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); | |||||
| if (cnode->inputs().size() < 2 && AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | |||||
| auto switch_cnode = cnode_input->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||||
| cnode_inputs = switch_cnode->inputs(); | |||||
| } else { | |||||
| cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); | |||||
| } | |||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||||
| cnode_inputs = {graph->GetBackendAnfByFrontAnf(cnode->input(kAnfPrimitiveIndex)), | |||||
| graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))}; | |||||
| for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { | |||||
| auto node_input = cnode->input(index); | |||||
| auto switch_input = CreateSwitchInput(node_input, graph); | |||||
| cnode_inputs.emplace_back(switch_input); | |||||
| } | |||||
| } else { | } else { | ||||
| // get primitive of old node | // get primitive of old node | ||||
| auto prim = AnfAlgo::GetCNodePrimitive(cnode); | auto prim = AnfAlgo::GetCNodePrimitive(cnode); | ||||
| @@ -651,21 +666,33 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||||
| cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; | cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))}; | ||||
| } | } | ||||
| for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { | |||||
| auto anf = cnode->input(input_idx); | |||||
| MS_EXCEPTION_IF_NULL(anf); | |||||
| // anf has been created before | |||||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||||
| continue; | |||||
| } else if (IsValueNode<None>(anf)) { | |||||
| continue; | |||||
| if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | |||||
| for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) { | |||||
| auto anf = cnode->input(input_idx); | |||||
| MS_EXCEPTION_IF_NULL(anf); | |||||
| // anf has been created before | |||||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||||
| continue; | |||||
| } else if (IsValueNode<None>(anf)) { | |||||
| continue; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||||
| } | } | ||||
| TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info())); | ||||
| auto new_cnode = graph->NewCNode(cnode_inputs); | auto new_cnode = graph->NewCNode(cnode_inputs); | ||||
| TraceManager::EndTrace(); | TraceManager::EndTrace(); | ||||
| // if the cnode is call switch, remove call | |||||
| if (new_cnode->inputs().size() > 1) { | |||||
| auto first_input = new_cnode->input(kFirstDataInputIndex); | |||||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && | |||||
| AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { | |||||
| new_cnode = first_input->cast<CNodePtr>(); | |||||
| } | |||||
| } | |||||
| return new_cnode; | return new_cnode; | ||||
| } | } | ||||
| @@ -86,11 +86,7 @@ class SessionBasic { | |||||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, | CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | ||||
| CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); | |||||
| CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); | |||||
| CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); | |||||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | |||||
| CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); | |||||
| // get graph id in child graphs by ME front anf node pointer | // get graph id in child graphs by ME front anf node pointer | ||||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } | ||||
| @@ -112,6 +108,10 @@ class SessionBasic { | |||||
| } | } | ||||
| #endif | #endif | ||||
| private: | |||||
| CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); | |||||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | |||||
| protected: | protected: | ||||
| virtual void SetSummaryNodes(KernelGraph *graph); | virtual void SetSummaryNodes(KernelGraph *graph); | ||||
| // Get graph by graph id ,if not exist return null ptr | // Get graph by graph id ,if not exist return null ptr | ||||
| @@ -277,11 +277,14 @@ const int kValueNodeTensorMask = 2; | |||||
| // define special index in special node | // define special index in special node | ||||
| constexpr auto kAnfPrimitiveIndex = 0; | constexpr auto kAnfPrimitiveIndex = 0; | ||||
| constexpr auto kFirstDataInputIndex = 1; | constexpr auto kFirstDataInputIndex = 1; | ||||
| constexpr auto kAnfPartialFuncGraphIndex = 1; | |||||
| constexpr auto kRealInputNodeIndexInTupleGetItem = 1; | constexpr auto kRealInputNodeIndexInTupleGetItem = 1; | ||||
| constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; | constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; | ||||
| constexpr auto kTupleGetItemInputSize = 3; | constexpr auto kTupleGetItemInputSize = 3; | ||||
| constexpr auto kSwitchInputSize = 4; | constexpr auto kSwitchInputSize = 4; | ||||
| constexpr auto kFirstBranchInSwitch = 2; | |||||
| constexpr auto kCallKernelGraphIndex = 1; | |||||
| constexpr auto kSwitchTrueKernelGraphIndex = 2; | |||||
| constexpr auto kSwitchFalseKernelGraphIndex = 3; | |||||
| // index define of control depend | // index define of control depend | ||||
| constexpr auto kControlDependPriorIndex = 1; | constexpr auto kControlDependPriorIndex = 1; | ||||
| constexpr auto kControlDependBehindIndex = 2; | constexpr auto kControlDependBehindIndex = 2; | ||||