| @@ -1006,6 +1006,25 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { | |||||
| return node->has_default(); | return node->has_default(); | ||||
| } | } | ||||
| bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName && | |||||
| (AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) { | |||||
| return true; | |||||
| } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) { | |||||
| auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList); | |||||
| if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { | void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); | auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); | ||||
| @@ -188,6 +188,8 @@ class AnfRuntimeAlgorithm { | |||||
| static bool IsNodeInGraphKernel(const AnfNodePtr &node); | static bool IsNodeInGraphKernel(const AnfNodePtr &node); | ||||
| // check parameter is weight or data | // check parameter is weight or data | ||||
| static bool IsParameterWeight(const ParameterPtr &node); | static bool IsParameterWeight(const ParameterPtr &node); | ||||
| // checkout whether the anf node is include the label_index. | |||||
| static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index); | |||||
| // set stream id of kernel,which will be set in stream assign and be used in stream generate | // set stream id of kernel,which will be set in stream assign and be used in stream generate | ||||
| static void SetStreamId(uint32_t stream_id, AnfNode *node); | static void SetStreamId(uint32_t stream_id, AnfNode *node); | ||||
| // get stream id | // get stream id | ||||
| @@ -1238,7 +1238,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) { | |||||
| MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; | MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; | ||||
| int32_t index = 0; | int32_t index = 0; | ||||
| std::vector<KernelGraphPtr> child_graphs; | std::vector<KernelGraphPtr> child_graphs; | ||||
| auto start_label = graph->get_start_label(); | |||||
| auto start_label_id = AnfAlgo::GetNodeAttr<uint32_t>(graph->get_start_label(), kAttrLabelIndex); | |||||
| auto end_node = graph->get_end_goto(); | auto end_node = graph->get_end_goto(); | ||||
| ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); | ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); | ||||
| std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), | ||||
| @@ -1247,9 +1247,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) { | |||||
| auto kg = graphs_[graph_id]; | auto kg = graphs_[graph_id]; | ||||
| auto nodes = kg->execution_order(); | auto nodes = kg->execution_order(); | ||||
| for (uint32_t i = 0; i < nodes.size(); i++) { | for (uint32_t i = 0; i < nodes.size(); i++) { | ||||
| if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName && | |||||
| (AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) == | |||||
| AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) { | |||||
| if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) { | |||||
| if (i < (nodes.size() - 1)) { | if (i < (nodes.size() - 1)) { | ||||
| new_inputs.push_back(nodes[i + 1]); | new_inputs.push_back(nodes[i + 1]); | ||||
| } else { | } else { | ||||