| @@ -360,7 +360,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { | |||||
| MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero" | MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero" | ||||
| << " trace: " << trace::DumpSourceLines(node); | << " trace: " << trace::DumpSourceLines(node); | ||||
| } | } | ||||
| // exclude intputs[0],which is value_node storing attr,inputs left are real input | |||||
| // exclude inputs[0],which is value_node storing attr,inputs left are real input | |||||
| return input_num - 1; | return input_num - 1; | ||||
| } | } | ||||
| @@ -1191,10 +1191,28 @@ FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) | |||||
| std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { | std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const CNodePtr &cnode) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch))) { | |||||
| MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch node." | |||||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || | |||||
| AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer))) { | |||||
| MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a call or switch or switch_layer node." | |||||
| << " trace: " << trace::DumpSourceLines(cnode); | << " trace: " << trace::DumpSourceLines(cnode); | ||||
| } | } | ||||
| auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { | |||||
| auto partial = cnode->input(input_index); | |||||
| MS_EXCEPTION_IF_NULL(partial); | |||||
| if (IsValueNode<KernelGraph>(partial)) { | |||||
| return GetValueNode<KernelGraphPtr>(partial); | |||||
| } | |||||
| auto partial_cnode = partial->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||||
| auto graph_node = partial_cnode->input(kCallKernelGraphIndex); | |||||
| MS_EXCEPTION_IF_NULL(graph_node); | |||||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(graph_value_node); | |||||
| auto graph_value = graph_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(graph_value); | |||||
| auto child_graph = graph_value->cast<KernelGraphPtr>(); | |||||
| return child_graph; | |||||
| }; | |||||
| if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { | if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { | ||||
| auto input1 = cnode->input(kCallKernelGraphIndex); | auto input1 = cnode->input(kCallKernelGraphIndex); | ||||
| MS_EXCEPTION_IF_NULL(input1); | MS_EXCEPTION_IF_NULL(input1); | ||||
| @@ -1204,25 +1222,15 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| return {kernel_graph->cast<KernelGraphPtr>()}; | return {kernel_graph->cast<KernelGraphPtr>()}; | ||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { | ||||
| auto get_switch_kernel_graph = [cnode](size_t input_index) -> KernelGraphPtr { | |||||
| auto partial = cnode->input(input_index); | |||||
| MS_EXCEPTION_IF_NULL(partial); | |||||
| if (IsValueNode<KernelGraph>(partial)) { | |||||
| return GetValueNode<KernelGraphPtr>(partial); | |||||
| } | |||||
| auto partial_cnode = partial->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(partial_cnode); | |||||
| auto graph_node = partial_cnode->input(kCallKernelGraphIndex); | |||||
| MS_EXCEPTION_IF_NULL(graph_node); | |||||
| auto graph_value_node = graph_node->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(graph_value_node); | |||||
| auto graph_value = graph_value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(graph_value); | |||||
| auto child_graph = graph_value->cast<KernelGraphPtr>(); | |||||
| return child_graph; | |||||
| }; | |||||
| return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), | return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), | ||||
| get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; | get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)}; | ||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { | |||||
| std::vector<KernelGraphPtr> child_graphs; | |||||
| for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) { | |||||
| auto child_graph = get_switch_kernel_graph(idx); | |||||
| child_graphs.emplace_back(child_graph); | |||||
| } | |||||
| return child_graphs; | |||||
| } | } | ||||
| return {}; | return {}; | ||||
| } | } | ||||
| @@ -1627,7 +1635,7 @@ void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std:: | |||||
| MS_EXCEPTION_IF_NULL(result); | MS_EXCEPTION_IF_NULL(result); | ||||
| MS_EXCEPTION_IF_NULL(visited); | MS_EXCEPTION_IF_NULL(visited); | ||||
| if (visited->find(anf_node) != visited->end()) { | if (visited->find(anf_node) != visited->end()) { | ||||
| MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; | |||||
| MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited"; | |||||
| return; | return; | ||||
| } | } | ||||
| visited->insert(anf_node); | visited->insert(anf_node); | ||||
| @@ -156,7 +156,7 @@ static std::vector<CNodePtr> GetTargetLabelSetNodes(NotNull<CNodePtr> jump_node, | |||||
| for (auto label_id : target_label_list) { | for (auto label_id : target_label_list) { | ||||
| auto iter = label_id_to_label_set.find(label_id); | auto iter = label_id_to_label_set.find(label_id); | ||||
| if (iter == label_id_to_label_set.end()) { | if (iter == label_id_to_label_set.end()) { | ||||
| MS_LOG(EXCEPTION) << "Connot find LabelSet node has label id " << label_id; | |||||
| MS_LOG(EXCEPTION) << "Cannot find LabelSet node has label id " << label_id; | |||||
| } | } | ||||
| target_labelset_nodes.push_back(iter->second); | target_labelset_nodes.push_back(iter->second); | ||||
| } | } | ||||
| @@ -413,6 +413,16 @@ std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> AscendControlPar | |||||
| const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); | const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); | ||||
| ret.emplace_back(target_graph, args); | ret.emplace_back(target_graph, args); | ||||
| } | } | ||||
| } else if (IsPrimitiveCNode(cnode.get(), prim::kPrimSwitchLayer)) { | |||||
| const std::vector<AnfNodePtr> &switch_layer_inputs = cnode->inputs(); | |||||
| if (switch_layer_inputs.size() <= kCNodeSwitchLayerBranch) { | |||||
| MS_LOG(EXCEPTION) << "Switch layer node " << cnode->DebugString() << " has invalid inputs size " | |||||
| << switch_layer_inputs.size(); | |||||
| } | |||||
| for (auto iter = switch_layer_inputs.begin() + kCNodeSwitchLayerBranch; iter != switch_layer_inputs.end(); ++iter) { | |||||
| const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); | |||||
| ret.emplace_back(target_graph, args); | |||||
| } | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5); | MS_LOG(EXCEPTION) << "Unsupported call node: " << cnode->DebugString(5); | ||||
| } | } | ||||
| @@ -431,7 +441,8 @@ void AscendControlParser::ChildGraphDataAssign( | |||||
| const std::vector<CNodePtr> &nodes = kg->execution_order(); | const std::vector<CNodePtr> &nodes = kg->execution_order(); | ||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch))) { | |||||
| if (!(IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) || | |||||
| IsPrimitiveCNode(node, prim::kPrimSwitchLayer))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -647,12 +658,10 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||||
| MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; | ||||
| } | } | ||||
| auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); | |||||
| MS_EXCEPTION_IF_NULL(branch_tuple); | |||||
| if (!branch_tuple->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; | |||||
| std::vector<AnfNodePtr> branch_partial; | |||||
| for (size_t idx = kCNodeSwitchLayerBranch; idx < cur_node->inputs().size(); idx++) { | |||||
| branch_partial.emplace_back(cur_node->input(idx)); | |||||
| } | } | ||||
| const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs(); | |||||
| // 1 return label | // 1 return label | ||||
| auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); | ||||
| // 2 add depend relationship | // 2 add depend relationship | ||||
| @@ -673,16 +682,17 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||||
| // 3.1 branch kernel graph and args | // 3.1 branch kernel graph and args | ||||
| KernelGraphPtr branch_fg; | KernelGraphPtr branch_fg; | ||||
| std::vector<AnfNodePtr> origin_inputs; | std::vector<AnfNodePtr> origin_inputs; | ||||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i + kCNodeSwitchLayerBranch])); | |||||
| child_graphs.push_back(branch_fg); | child_graphs.push_back(branch_fg); | ||||
| // 3.2 recurse sub graph | // 3.2 recurse sub graph | ||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | ||||
| new_switch_inputs.push_back(branch_label); | new_switch_inputs.push_back(branch_label); | ||||
| AttachOriginalInputsToGraph(kg, origin_inputs); | AttachOriginalInputsToGraph(kg, origin_inputs); | ||||
| } | } | ||||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||||
| cur_node->set_inputs(new_switch_inputs); | cur_node->set_inputs(new_switch_inputs); | ||||
| cur_node->set_abstract(nullptr); | |||||
| cur_node->set_abstract(std::make_shared<abstract::AbstractNone>()); | |||||
| // To adapt to the true and false branches of the switch, the sequence of the branches is reversed. | |||||
| std::reverse(child_graphs.begin(), child_graphs.end()); | |||||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get()); | AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>(child_graphs), cur_node.get()); | ||||
| MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); | MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); | ||||
| } | } | ||||
| @@ -875,7 +875,7 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) { | |||||
| // generate and load task info to device if it is sink mode | // generate and load task info to device if it is sink mode | ||||
| Load(graph); | Load(graph); | ||||
| } | } | ||||
| // sync the inital const tensor to device | |||||
| // sync the initial const tensor to device | |||||
| SyncInitialTenosrToDevice(); | SyncInitialTenosrToDevice(); | ||||
| DumpAllGraphs({graph}); | DumpAllGraphs({graph}); | ||||
| MS_LOG(INFO) << "End"; | MS_LOG(INFO) << "End"; | ||||
| @@ -1634,7 +1634,8 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu | |||||
| std::map<AnfNodePtr, AnfNodePtr> need_replace_list; | std::map<AnfNodePtr, AnfNodePtr> need_replace_list; | ||||
| auto node_list = GetCNodes(TopoSort(graph->get_return())); | auto node_list = GetCNodes(TopoSort(graph->get_return())); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || | |||||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { | |||||
| // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output | ||||
| auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); | ||||
| MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | MS_EXCEPTION_IF_NULL(graph->MutableInputs()); | ||||
| @@ -1186,8 +1186,9 @@ bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int outpu | |||||
| void KernelGraph::UpdateChildGraphOrder() { | void KernelGraph::UpdateChildGraphOrder() { | ||||
| MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | MS_LOG(INFO) << "Update " << ToString() << " child graph order."; | ||||
| SetExecOrderByDefault(); | SetExecOrderByDefault(); | ||||
| auto call_nodes = FindNodeByPrimitive( | |||||
| {std::make_shared<Primitive>(prim::kPrimCall->name()), std::make_shared<Primitive>(prim::kPrimSwitch->name())}); | |||||
| auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()), | |||||
| std::make_shared<Primitive>(prim::kPrimSwitch->name()), | |||||
| std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())}); | |||||
| std::vector<std::weak_ptr<KernelGraph>> child_graph_order; | std::vector<std::weak_ptr<KernelGraph>> child_graph_order; | ||||
| for (auto &call_node : call_nodes) { | for (auto &call_node : call_nodes) { | ||||
| MS_EXCEPTION_IF_NULL(call_node); | MS_EXCEPTION_IF_NULL(call_node); | ||||
| @@ -148,7 +148,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o | |||||
| } | } | ||||
| } | } | ||||
| tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); | tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); | ||||
| // if in paynative mode,data only copyed to host when user want to print data | |||||
| // if in pynative mode,data only copied to host when user want to print data | |||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | ||||
| @@ -499,10 +499,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr | |||||
| auto graph_inputs = graph->MutableInputs(); | auto graph_inputs = graph->MutableInputs(); | ||||
| MS_EXCEPTION_IF_NULL(graph_inputs); | MS_EXCEPTION_IF_NULL(graph_inputs); | ||||
| auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { | ||||
| auto parameter = graph->NewParameter(); | |||||
| MS_EXCEPTION_IF_NULL(parameter); | |||||
| parameter->set_abstract(abstract); | |||||
| auto new_parameter = graph->NewParameter(parameter); | |||||
| auto new_parameter = graph->NewParameter(abstract); | |||||
| parameters.push_back(new_parameter); | parameters.push_back(new_parameter); | ||||
| valid_inputs->push_back(true); | valid_inputs->push_back(true); | ||||
| graph_inputs->push_back(new_parameter); | graph_inputs->push_back(new_parameter); | ||||
| @@ -662,7 +659,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph, | |||||
| return new_cnode; | return new_cnode; | ||||
| } | } | ||||
| CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { | |||||
| CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(node_input); | MS_EXCEPTION_IF_NULL(node_input); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // switch input generalizes partial | // switch input generalizes partial | ||||
| @@ -675,9 +672,11 @@ CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGra | |||||
| } else { | } else { | ||||
| KernelGraphPtr kernel_graph = NewKernelGraph(); | KernelGraphPtr kernel_graph = NewKernelGraph(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get()); | |||||
| auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get()); | |||||
| parameter->set_abstract(cnode->abstract()); | |||||
| auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); | auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())); | ||||
| auto return_node = kernel_graph->NewCNode({primitive, parameter}); | auto return_node = kernel_graph->NewCNode({primitive, parameter}); | ||||
| return_node->set_abstract(cnode->abstract()); | |||||
| kernel_graph->set_return(return_node); | kernel_graph->set_return(return_node); | ||||
| partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); | ||||
| partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input)); | ||||
| @@ -722,10 +721,97 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno | |||||
| return cnode_inputs; | return cnode_inputs; | ||||
| } | } | ||||
| void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(real_input); | |||||
| if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) { | |||||
| MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node."; | |||||
| } | |||||
| auto partial_input = cnode->input(kFirstDataInputIndex); | |||||
| KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input); | |||||
| MS_EXCEPTION_IF_NULL(partial_kernel_graph); | |||||
| auto ret = partial_kernel_graph->get_return(); | |||||
| MS_EXCEPTION_IF_NULL(ret); | |||||
| auto return_input = ret->input(kFirstDataInputIndex); | |||||
| // if kernel graph return node is a function | |||||
| if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) { | |||||
| std::vector<AnfNodePtr> call_inputs = { | |||||
| partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | |||||
| auto return_input_cnode = return_input->cast<CNodePtr>(); | |||||
| auto partial_inputs = return_input_cnode->inputs(); | |||||
| call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end()); | |||||
| auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get()); | |||||
| call_inputs.emplace_back(parameter_for_input); | |||||
| auto call_node = partial_kernel_graph->NewCNode(call_inputs); | |||||
| // update abstract | |||||
| KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]); | |||||
| auto ret_partial = sub_partial_kernel_graph->get_return(); | |||||
| call_node->set_abstract(ret_partial->abstract()); | |||||
| // update return input | |||||
| ret->set_input(kFirstDataInputIndex, call_node); | |||||
| } | |||||
| } | |||||
| std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| std::vector<AnfNodePtr> cnode_inputs = { | |||||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | |||||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | |||||
| MS_EXCEPTION_IF_NULL(attr_input); | |||||
| auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); | |||||
| auto switch_layer_cnode = cnode_input->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(switch_layer_cnode); | |||||
| std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex), | |||||
| switch_layer_cnode->input(kFirstDataInputIndex)}; | |||||
| auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex); | |||||
| MS_EXCEPTION_IF_NULL(make_tuple_node); | |||||
| auto node = make_tuple_node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto make_tuple_inputs = node->inputs(); | |||||
| // there is real input in call, should put it to make_tuple in switch_layer | |||||
| auto real_input = cnode->input(kFirstDataInputIndex); | |||||
| auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input); | |||||
| std::vector<AnfNodePtr> new_make_tuple_inputs = { | |||||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))}; | |||||
| for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) { | |||||
| auto partial_idx = make_tuple_inputs[idx]; | |||||
| MS_EXCEPTION_IF_NULL(cnode->abstract()); | |||||
| // switch_layer node input is partial cnode | |||||
| if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) { | |||||
| auto partial_node = partial_idx->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(partial_node); | |||||
| // update kernel graph when switch_layer node return function | |||||
| CreateCallNodeReturnFunction(partial_node, real_input_back); | |||||
| std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs(); | |||||
| new_partial_inputs.emplace_back(real_input_back); | |||||
| auto new_partial = graph->NewCNode(new_partial_inputs); | |||||
| new_make_tuple_inputs.emplace_back(new_partial); | |||||
| } | |||||
| // switch_layer node input is kernel graph value node | |||||
| if (IsValueNode<KernelGraph>(partial_idx)) { | |||||
| // make_tuple inputs is KernelGraph | |||||
| std::vector<AnfNodePtr> new_partial_inputs; | |||||
| new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))); | |||||
| new_partial_inputs.emplace_back(partial_idx); | |||||
| new_partial_inputs.emplace_back(real_input_back); | |||||
| auto new_partial = graph->NewCNode(new_partial_inputs); | |||||
| new_make_tuple_inputs.emplace_back(new_partial); | |||||
| } | |||||
| } | |||||
| auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs); | |||||
| switch_layer_inputs.emplace_back(new_make_tuple); | |||||
| auto new_switch_layer = graph->NewCNode(switch_layer_inputs); | |||||
| cnode_inputs.emplace_back(new_switch_layer); | |||||
| return cnode_inputs; | |||||
| } | |||||
| std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { | std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| // create primitive of cnode:call(partial or switch) | |||||
| // create primitive of cnode:call(partial or switch or switch_layer) | |||||
| std::vector<AnfNodePtr> cnode_inputs = { | std::vector<AnfNodePtr> cnode_inputs = { | ||||
| graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; | ||||
| auto attr_input = cnode->input(kAnfPrimitiveIndex); | auto attr_input = cnode->input(kAnfPrimitiveIndex); | ||||
| @@ -748,9 +834,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr & | |||||
| return cnode_inputs; | return cnode_inputs; | ||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { | ||||
| return CreateCallSwitchInputs(cnode, graph); | return CreateCallSwitchInputs(cnode, graph); | ||||
| } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) { | |||||
| return CreateCallSwitchLayerInputs(cnode, graph); | |||||
| } | } | ||||
| MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString() | MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString() | ||||
| << "must be partial or switch."; | |||||
| << "must be partial or switch or switch_layer."; | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| @@ -788,7 +876,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, | |||||
| cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); | cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); | ||||
| for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { | for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { | ||||
| auto node_input = cnode->input(index); | auto node_input = cnode->input(index); | ||||
| auto switch_input = CreateSwitchInput(node_input, graph); | |||||
| auto switch_input = CreateSwitchInput(cnode, node_input, graph); | |||||
| cnode_inputs->emplace_back(switch_input); | cnode_inputs->emplace_back(switch_input); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -841,10 +929,17 @@ CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) { | |||||
| // if the cnode is call switch, remove call | // if the cnode is call switch, remove call | ||||
| if (new_cnode->inputs().size() > 1) { | if (new_cnode->inputs().size() > 1) { | ||||
| auto first_input = new_cnode->input(kFirstDataInputIndex); | auto first_input = new_cnode->input(kFirstDataInputIndex); | ||||
| MS_EXCEPTION_IF_NULL(first_input); | |||||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && | if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && | ||||
| AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { | AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) { | ||||
| new_cnode = first_input->cast<CNodePtr>(); | new_cnode = first_input->cast<CNodePtr>(); | ||||
| } | } | ||||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) && | |||||
| AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) { | |||||
| auto abstract = cnode->abstract(); | |||||
| new_cnode = first_input->cast<CNodePtr>(); | |||||
| new_cnode->set_abstract(abstract); | |||||
| } | |||||
| } | } | ||||
| return new_cnode; | return new_cnode; | ||||
| @@ -1842,7 +1937,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // PS embeddingLookup cache check. | // PS embeddingLookup cache check. | ||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in " | |||||
| MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in " | |||||
| "parameter server training mode."; | "parameter server training mode."; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | ||||
| @@ -125,7 +125,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| #endif | #endif | ||||
| private: | private: | ||||
| CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); | |||||
| CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph); | |||||
| std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); | ||||
| std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph); | std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph); | ||||
| void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs); | void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs); | ||||
| @@ -133,6 +133,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs); | void GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs); | ||||
| void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs, | void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs, | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode); | ||||
| std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph); | |||||
| void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input); | |||||
| protected: | protected: | ||||
| friend class Executor; | friend class Executor; | ||||
| @@ -407,6 +407,7 @@ constexpr auto kFirstBranchInSwitch = 2; | |||||
| constexpr auto kCallKernelGraphIndex = 1; | constexpr auto kCallKernelGraphIndex = 1; | ||||
| constexpr auto kSwitchTrueKernelGraphIndex = 2; | constexpr auto kSwitchTrueKernelGraphIndex = 2; | ||||
| constexpr auto kSwitchFalseKernelGraphIndex = 3; | constexpr auto kSwitchFalseKernelGraphIndex = 3; | ||||
| constexpr auto kMakeTupleInSwitchLayerIndex = 2; | |||||
| // index define of control depend | // index define of control depend | ||||
| constexpr auto kControlDependPriorIndex = 1; | constexpr auto kControlDependPriorIndex = 1; | ||||
| constexpr auto kControlDependBehindIndex = 2; | constexpr auto kControlDependBehindIndex = 2; | ||||