| @@ -49,6 +49,11 @@ namespace { | |||
| constexpr size_t kNopNodeInputSize = 2; | |||
| constexpr size_t kNopNodeRealInputIndex = 1; | |||
| constexpr size_t kReturnDataIndex = 1; | |||
| constexpr size_t kSwitchTrueBranchIndex = 2; | |||
| constexpr size_t kPartialFuncGraphPos = 1; | |||
| constexpr size_t kSwitchLayerBranchPos = 2; | |||
| constexpr size_t kSwitchTrueBranchPos = 2; | |||
| constexpr size_t kMakeTupleInputStartPos = 1; | |||
| const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad}; | |||
| @@ -142,6 +147,54 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, | |||
| return inputs->push_back(std::make_pair(node, output_index)); | |||
| } | |||
| // Fetch all outputs of control nodes, visited nodes indicates the call node that has been processed. In control flow, | |||
| // there are recursive calls between funcgraphs, so the processed call nodes are recorded to prevent infinite loops. | |||
| std::vector<KernelWithIndex> GetAllOutputByControlFlowNode(const KernelWithIndex &output_with_index, | |||
| std::set<AnfNodePtr> *visited_call_nodes) { | |||
| std::vector<KernelWithIndex> ret; | |||
| const auto &node = output_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||
| const auto &switch_cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(switch_cnode); | |||
| const auto &switch_inputs = switch_cnode->inputs(); | |||
| auto output_vector = AnfAlgo::GetAllOutputWithIndex(switch_inputs[kSwitchTrueBranchIndex], visited_call_nodes); | |||
| (void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret)); | |||
| } else if (AnfAlgo::IsCallNode(node)) { | |||
| if (visited_call_nodes != nullptr) { | |||
| if (visited_call_nodes->find(node) != visited_call_nodes->end()) { | |||
| return ret; | |||
| } else { | |||
| visited_call_nodes->emplace(node); | |||
| } | |||
| } | |||
| // The output of the call node is the output of the funcgraph actually called. | |||
| const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(node); | |||
| for (const auto &func_graph : func_graphs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // The call in the graph kernel does not need to be parsed, and the node is directly output. | |||
| if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { | |||
| ret.emplace_back(output_with_index); | |||
| break; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph->output()); | |||
| const auto &func_graph_output = | |||
| AnfAlgo::VisitKernelWithReturnType(func_graph->output(), output_with_index.second); | |||
| std::set<AnfNodePtr> tmp_visited_nodes = {node}; | |||
| auto output_vector = AnfAlgo::GetAllOutputWithIndex( | |||
| func_graph_output.first, (visited_call_nodes == nullptr ? &tmp_visited_nodes : visited_call_nodes)); | |||
| if (output_with_index.second < output_vector.size()) { | |||
| ret.emplace_back(output_vector[output_with_index.second]); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| // ops pair that dynamic input order is differ from the fixed shape ops | |||
| // pair: <real_input->ori_input, ori_input->real_input> | |||
| static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = { | |||
| @@ -339,7 +392,8 @@ std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node | |||
| return ret; | |||
| } | |||
| std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node) { | |||
| std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const AnfNodePtr &node, | |||
| std::set<AnfNodePtr> *visited_call_nodes) { | |||
| std::vector<KernelWithIndex> ret; | |||
| std::vector<KernelWithIndex> ret_empty; | |||
| @@ -348,7 +402,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An | |||
| auto make_tuple = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| for (size_t i = 1; i < make_tuple->inputs().size(); i++) { | |||
| auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i)); | |||
| auto make_tuple_output = GetAllOutputWithIndex(make_tuple->input(i), visited_call_nodes); | |||
| (void)std::copy(make_tuple_output.begin(), make_tuple_output.end(), std::back_inserter(ret)); | |||
| } | |||
| return ret; | |||
| @@ -358,7 +412,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { | |||
| auto depend_node = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend)); | |||
| auto real_output = GetAllOutputWithIndex(depend_node->input(kRealInputIndexInDepend), visited_call_nodes); | |||
| (void)std::copy(real_output.begin(), real_output.end(), std::back_inserter(ret)); | |||
| return ret; | |||
| } | |||
| @@ -393,20 +447,16 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An | |||
| // The makeTuple node need recurse. | |||
| if (AnfAlgo::CheckPrimitiveType(output_with_index.first, prim::kPrimMakeTuple)) { | |||
| auto output_vector = GetAllOutputWithIndex(output_with_index.first); | |||
| auto output_vector = GetAllOutputWithIndex(output_with_index.first, visited_call_nodes); | |||
| (void)std::copy(output_vector.begin(), output_vector.end(), std::back_inserter(ret)); | |||
| continue; | |||
| } | |||
| // Ignore the output of front call node. | |||
| if (output_with_index.first->isa<CNode>()) { | |||
| auto cnode = output_with_index.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto inputs = cnode->inputs(); | |||
| if (inputs[0]->isa<CNode>()) { | |||
| MS_LOG(INFO) << "The output is call node: " << output_with_index.first->DebugString(); | |||
| return ret_empty; | |||
| } | |||
| // Fetch outputs by control nodes. | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || AnfAlgo::IsCallNode(node)) { | |||
| const auto &control_node_output = GetAllOutputByControlFlowNode(output_with_index, visited_call_nodes); | |||
| (void)std::copy(control_node_output.begin(), control_node_output.end(), std::back_inserter(ret)); | |||
| continue; | |||
| } | |||
| // The InitDataSetQueue node has no output. | |||
| @@ -2527,5 +2577,100 @@ size_t OpRuntimeInfo::output_tensor_size(size_t index) const { | |||
| } | |||
| return output_tensor_size_[index]; | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsCallNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| if (inputs.empty() || inputs[0] == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Invalid call node:" << node->DebugString(); | |||
| } | |||
| return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])); | |||
| } | |||
| std::set<FuncGraphPtr> AnfRuntimeAlgorithm::GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| std::set<FuncGraphPtr> func_graphs; | |||
| if (!node->isa<CNode>()) { | |||
| return func_graphs; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &call_input0 = cnode->input(0); | |||
| MS_EXCEPTION_IF_NULL(call_input0); | |||
| if (AnfAlgo::IsCallNode(call_input0)) { | |||
| return AnfAlgo::GetFuncGraphbyCallNode(call_input0, ++call_depth); | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitch)) { | |||
| // First input node of call is switch node. | |||
| const auto &switch_inputs = call_input0->cast<CNodePtr>()->inputs(); | |||
| for (size_t i = kSwitchTrueBranchPos; i < switch_inputs.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(switch_inputs[i]); | |||
| (void)func_graphs.emplace(GetFuncGraphFromPartial(switch_inputs[i], call_depth)); | |||
| } | |||
| } else if (AnfAlgo::CheckPrimitiveType(call_input0, prim::kPrimSwitchLayer)) { | |||
| // First input node of call is switch layer node. | |||
| const auto &tuple_node = cnode->cast<CNodePtr>()->input(kSwitchLayerBranchPos); | |||
| if (!AnfAlgo::CheckPrimitiveType(tuple_node, prim::kPrimMakeTuple)) { | |||
| MS_LOG(EXCEPTION) << "Invalid input tuple node:" << tuple_node->DebugString() | |||
| << " for switch layer node:" << cnode->DebugString(); | |||
| } | |||
| const auto &tuple_inputs = tuple_node->cast<CNodePtr>()->inputs(); | |||
| for (size_t i = kMakeTupleInputStartPos; i < tuple_inputs.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(tuple_inputs[i]); | |||
| func_graphs.emplace(GetFuncGraphFromPartial(tuple_inputs[i], call_depth)); | |||
| } | |||
| } else if (IsPartial(call_input0)) { | |||
| // First input node of call is partial node or value node of funcgraph. | |||
| (void)func_graphs.emplace(GetFuncGraphFromPartial(call_input0, call_depth)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); | |||
| } | |||
| return func_graphs; | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsPartial(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| return (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) || | |||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial); | |||
| } | |||
| FuncGraphPtr AnfRuntimeAlgorithm::GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (depth == 1) { | |||
| if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) { | |||
| // Value node of funcgraph. | |||
| return GetValueNode<FuncGraphPtr>(node); | |||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { | |||
| // Partial cnode. | |||
| const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); | |||
| return GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos]); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid partial construct node:" << node->DebugString(); | |||
| } | |||
| } | |||
| // Get funcgraph in the output of inner call. | |||
| if (node->isa<ValueNode>() && IsValueNode<FuncGraph>(node)) { | |||
| return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(node)->output(), depth - 1); | |||
| } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { | |||
| const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); | |||
| return GetFuncGraphFromPartial(GetValueNode<FuncGraphPtr>(partial_inputs[kPartialFuncGraphPos])->output(), | |||
| depth - 1); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid partial node:" << node->DebugString(); | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -83,7 +83,8 @@ class AnfRuntimeAlgorithm { | |||
| prim::kPrimMakeTuple}); | |||
| static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node, | |||
| const std::vector<PrimitivePtr> &return_types = {}); | |||
| static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node); | |||
| static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node, | |||
| std::set<AnfNodePtr> *visited_call_nodes = nullptr); | |||
| // get cnode primitive | |||
| static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); | |||
| static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); | |||
| @@ -329,6 +330,20 @@ class AnfRuntimeAlgorithm { | |||
| static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph); | |||
| static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); | |||
| static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); | |||
| // Check whether node is a call node, there are two types of call nodes: | |||
| // 1. First input of node is a cnode. | |||
| // 2. First input of node is a funcgraph value node. | |||
| static bool IsCallNode(const AnfNodePtr &node); | |||
| // Find all funcgraphs that the call node will call. | |||
| static std::set<FuncGraphPtr> GetFuncGraphbyCallNode(const AnfNodePtr &node, size_t call_depth = 1); | |||
| // Check whether node has a partial structure, a node is a partial structure whicih: | |||
| // 1. a partial cnode. | |||
| // 2. a funcgraph value node. | |||
| static bool IsPartial(const AnfNodePtr &node); | |||
| // Get funcgraph in partial structure. | |||
| // Depth represents the number of layers of the call. When the first input of the call node is a call node, | |||
| // the funcgraph in the return value of the inner call needs to be returned. | |||
| static FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node, size_t depth = 1); | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -1373,6 +1373,23 @@ bool KernelGraph::IsDatasetGraph() const { | |||
| std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } | |||
| bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) { | |||
| std::vector<AnfNodePtr> child_graph_results; | |||
| for (const auto &child_graph_result : child_graph_result_) { | |||
| MS_EXCEPTION_IF_NULL(child_graph_result); | |||
| if (AnfAlgo::CheckPrimitiveType(child_graph_result, prim::kPrimMakeTuple)) { | |||
| const auto cnode = child_graph_result->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| child_graph_results.insert(child_graph_results.end(), inputs.begin(), inputs.end()); | |||
| } else { | |||
| child_graph_results.emplace_back(child_graph_result); | |||
| } | |||
| } | |||
| return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end(); | |||
| } | |||
| KernelGraph::~KernelGraph() { | |||
| try { | |||
| // Release the kernel resource. | |||
| @@ -260,6 +260,7 @@ class KernelGraph : public FuncGraph { | |||
| void UpdateChildGraphOrder(); | |||
| const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; } | |||
| void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); } | |||
| bool IsChildGraphResult(const AnfNodePtr &node); | |||
| void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { | |||
| child_graph_result_ = child_graph_result; | |||
| } | |||
| @@ -54,6 +54,7 @@ bool CheckValidFuncGraphInput(const AnfNodePtr &node) { | |||
| // Get the funcgraph in partial node. | |||
| FuncGraphPtr GetFuncGraphFromPartial(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| const auto &partial_inputs = node->cast<CNodePtr>()->inputs(); | |||
| return GetValueNode<FuncGraphPtr>(partial_inputs[1]); | |||
| } | |||
| @@ -313,7 +314,7 @@ std::vector<AnfNodePtr> FetchFuncGraphOutput(const FuncGraphPtr &func_graph, std | |||
| if (find((*call_nodes).begin(), (*call_nodes).end(), real_output.first) != (*call_nodes).end()) { | |||
| return outputs; | |||
| } | |||
| if (!IsCallNode(real_output.first)) { | |||
| if (!AnfAlgo::IsCallNode(real_output.first)) { | |||
| outputs.push_back(real_output.first); | |||
| return outputs; | |||
| } | |||
| @@ -349,7 +350,7 @@ std::vector<AnfNodePtr> FetchOutputByCallNode(const AnfNodePtr &call_node, std:: | |||
| } else if (AnfAlgo::CheckPrimitiveType(graph_output, prim::kPrimSwitch)) { | |||
| const auto &switch_outputs = FetchOutputBySwitchNode(graph_output, call_nodes, switch_nodes); | |||
| (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); | |||
| } else if (IsCallNode(graph_output)) { | |||
| } else if (AnfAlgo::IsCallNode(graph_output)) { | |||
| const auto &call_outputs = FetchOutputByCallNode(graph_output, call_nodes, switch_nodes); | |||
| (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); | |||
| } else if (graph_output->isa<CNode>()) { | |||
| @@ -388,7 +389,7 @@ std::vector<AnfNodePtr> FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s | |||
| } else if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { | |||
| const auto &switch_outputs = FetchOutputBySwitchNode(inputs[i], call_nodes, switch_nodes); | |||
| (void)outputs.insert(outputs.end(), switch_outputs.begin(), switch_outputs.end()); | |||
| } else if (IsCallNode(inputs[i])) { | |||
| } else if (AnfAlgo::IsCallNode(inputs[i])) { | |||
| const auto &call_outputs = FetchOutputByCallNode(inputs[i], call_nodes, switch_nodes); | |||
| (void)outputs.insert(outputs.end(), call_outputs.begin(), call_outputs.end()); | |||
| } else { | |||
| @@ -486,7 +487,7 @@ FuncGraphPtr FetchFuncGraphInNode(const auto &node) { | |||
| AnfNodePtr FetchRealOutputByCallNode(const AnfNodePtr &node, std::set<AnfNodePtr> *call_nodes) { | |||
| const auto &real_node = AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first; | |||
| if (!IsCallNode(real_node)) { | |||
| if (!AnfAlgo::IsCallNode(real_node)) { | |||
| return real_node; | |||
| } | |||
| if ((*call_nodes).find(real_node) != (*call_nodes).end()) { | |||
| @@ -513,15 +514,6 @@ bool HasAbstractRef(const AnfNodePtr &node) { | |||
| return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); | |||
| } | |||
| bool IsCallNode(const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| const auto &inputs = cnode->inputs(); | |||
| return inputs[0]->isa<CNode>() || (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])); | |||
| } | |||
| bool IsSubCallNode(const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| @@ -604,7 +596,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) { | |||
| func_graphs.emplace_back(func_graph); | |||
| } | |||
| } | |||
| } else if (IsCallNode(cnode)) { | |||
| } else if (AnfAlgo::IsCallNode(cnode)) { | |||
| return FetchFuncGraphbyCallNode(cnode); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unable to identify call node" << node->DebugString(); | |||
| @@ -618,7 +610,7 @@ std::vector<FuncGraphPtr> FetchFuncGraphbyCallNode(const AnfNodePtr &node) { | |||
| } | |||
| size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *call_nodes) { | |||
| if (!IsCallNode(node)) { | |||
| if (!AnfAlgo::IsCallNode(node)) { | |||
| MS_LOG(EXCEPTION) << "Invalid call node:" << AnfAlgo::GetNodeDebugString(node); | |||
| } | |||
| if (find((*call_nodes).begin(), (*call_nodes).end(), node) != (*call_nodes).end()) { | |||
| @@ -631,7 +623,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> | |||
| const auto &output = func_graph->output(); | |||
| const auto &real_output = AnfAlgo::VisitKernelWithReturnType(output, 0); | |||
| if (IsCallNode(real_output.first)) { | |||
| if (AnfAlgo::IsCallNode(real_output.first)) { | |||
| size_t output_num = FetchOutputSizebyCallNode(real_output.first, call_nodes); | |||
| if (output_num > 0) { | |||
| return output_num; | |||
| @@ -642,7 +634,7 @@ size_t FetchOutputSizebyCallNode(const AnfNodePtr &node, std::vector<AnfNodePtr> | |||
| const auto &inputs = tuple_cnode->inputs(); | |||
| size_t i = 1; | |||
| for (; i < inputs.size(); ++i) { | |||
| if (IsCallNode(inputs[i])) { | |||
| if (AnfAlgo::IsCallNode(inputs[i])) { | |||
| size_t call_output_num = FetchOutputSizebyCallNode(inputs[i], call_nodes); | |||
| if (call_output_num == 0) { | |||
| break; | |||
| @@ -872,7 +864,7 @@ void ControlNodeParser::FetchValueNodeBySwitchNode(const AnfNodePtr &switch_node | |||
| if (node_value->isa<tensor::Tensor>()) { | |||
| (void)((*value_nodes).emplace_back(input)); | |||
| } | |||
| } else if (IsCallNode(input)) { | |||
| } else if (AnfAlgo::IsCallNode(input)) { | |||
| // If input is a call not, should check the switch node in its input. | |||
| const auto &call_node = input->cast<CNodePtr>(); | |||
| const auto &call_inputs = call_node->inputs(); | |||
| @@ -1050,7 +1042,7 @@ void ControlNodeParser::FetchFrontToFrontParameter( | |||
| std::vector<AnfNodePtr> call_inputs; | |||
| call_inputs.assign(inputs.begin() + SizeToInt(kCallInputStartPos), inputs.end()); | |||
| switch_input_parse(inputs[0], call_inputs); | |||
| } else if (IsCallNode(inputs[0])) { | |||
| } else if (AnfAlgo::IsCallNode(inputs[0])) { | |||
| continue; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "First input node of call node is not switch, node:" | |||
| @@ -1098,7 +1090,7 @@ std::vector<AnfNodePtr> ControlNodeParser::FetchControlNodeParameter(const std:: | |||
| void ControlNodeParser::FetchFuncGraphCallNum(const std::vector<AnfNodePtr> &control_nodes) { | |||
| for (const auto &control_node : control_nodes) { | |||
| if (IsCallNode(control_node)) { | |||
| if (AnfAlgo::IsCallNode(control_node)) { | |||
| const auto &func_graphs = FetchFuncGraphbyCallNode(control_node); | |||
| for (const auto &func_graph : func_graphs) { | |||
| @@ -1123,7 +1115,7 @@ void ControlNodeParser::FetchCallInputKernelGraph(const std::vector<KernelGraphP | |||
| const auto inputs = graph->input_nodes(); | |||
| for (const auto &input : inputs) { | |||
| const auto &internal_parameter_with_index = graph->GetFrontNodeByInternalParameter(input); | |||
| if (internal_parameter_with_index.first != nullptr && IsCallNode(internal_parameter_with_index.first)) { | |||
| if (internal_parameter_with_index.first != nullptr && AnfAlgo::IsCallNode(internal_parameter_with_index.first)) { | |||
| call_input_kernel_graphs_[graph] = device_context; | |||
| call_node_to_backend_parameters_[internal_parameter_with_index] = {input, device_context}; | |||
| } | |||
| @@ -1162,12 +1154,12 @@ std::vector<AnfNodePtr> FetchInputParameterbyControlNode(const AnfNodePtr &node, | |||
| for (size_t i = kSwitchTrueBranchPos; i < kSwitchInputNum; ++i) { | |||
| if (inputs[i]->isa<Parameter>()) { | |||
| (void)parameters.emplace_back(inputs[i]); | |||
| } else if (IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { | |||
| } else if (AnfAlgo::IsCallNode(inputs[i]) || AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimSwitch)) { | |||
| const auto &sub_parameters = FetchInputParameterbyControlNode(inputs[i], switch_nodes, call_nodes); | |||
| (void)parameters.insert(parameters.end(), sub_parameters.begin(), sub_parameters.end()); | |||
| } | |||
| } | |||
| } else if (IsCallNode(node)) { | |||
| } else if (AnfAlgo::IsCallNode(node)) { | |||
| if ((*call_nodes).find(node) != (*call_nodes).end()) { | |||
| return parameters; | |||
| } | |||
| @@ -1296,7 +1288,7 @@ void ControlNodeParser::FetchFuncGraphToParameter(const std::vector<AnfNodePtr> | |||
| } else if (AnfAlgo::CheckPrimitiveType(inputs[0], prim::kPrimSwitchLayer)) { | |||
| // Switchlayer node. | |||
| FetchParameterBySwitchLayerNode(inputs[0], inputs, &func_graph_to_parameters_); | |||
| } else if (IsCallNode(inputs[0])) { | |||
| } else if (AnfAlgo::IsCallNode(inputs[0])) { | |||
| continue; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unable to identify call node" << switch_cnode->DebugString(); | |||
| @@ -1373,7 +1365,7 @@ void ControlNodeParser::FetchBackendOutputByFrontOutput(const AnfNodePtr &front_ | |||
| for (const auto &switch_output : switch_outputs) { | |||
| FetchBackendOutputByFrontOutput(switch_output, call_nodes, switch_nodes, results); | |||
| } | |||
| } else if (IsCallNode(front_output)) { | |||
| } else if (AnfAlgo::IsCallNode(front_output)) { | |||
| // Output is a call. | |||
| const auto &call_outputs = FetchOutputByCallNode(front_output, call_nodes, switch_nodes); | |||
| @@ -1429,7 +1421,7 @@ void ControlNodeParser::FetchBackendInputNodebyFrontNode( | |||
| } | |||
| } else if (real_parameter->isa<ValueNode>()) { | |||
| (void)formal_to_real_parameters_[formal_parameter].emplace_back(real_parameter, 0); | |||
| } else if (IsCallNode(real_parameter)) { | |||
| } else if (AnfAlgo::IsCallNode(real_parameter)) { | |||
| const auto func_graphs = FetchFuncGraphbyCallNode(real_parameter); | |||
| for (const auto func_graph : func_graphs) { | |||
| FetchBackendInputNodebyFrontNode(func_graph->output(), formal_parameter, front_to_backend_parameters); | |||
| @@ -56,11 +56,6 @@ using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::vector<AnfNode | |||
| using NodeWithDeviceContext = std::vector<std::pair<AnfNodePtr, DeviceContext *>>; | |||
| using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>; | |||
| // Check whether node is a call node, there are two types of call nodes: | |||
| // 1. First input of node is a cnode. | |||
| // 2. First input of node is a funcgraph value node. | |||
| bool IsCallNode(const AnfNodePtr &node); | |||
| // Check if the call node is the input of another call node. | |||
| bool IsSubCallNode(const AnfNodePtr &node); | |||
| @@ -1689,8 +1689,9 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(kernel_type); | |||
| MS_EXCEPTION_IF_NULL(kernel_name); | |||
| if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>())) { | |||
| // In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored | |||
| // in the graph and should be obtained from the super kernel actor. | |||
| if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>() || graph->IsChildGraphResult(node))) { | |||
| *kernel_type = KernelTransformType::kSuperKernelActor; | |||
| *kernel_name = graph->ToString() + "_SuperKernelActor"; | |||
| return; | |||
| @@ -980,13 +980,6 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con | |||
| AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first; | |||
| size_t position = 0; | |||
| auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output); | |||
| if (runtime::IsCallNode(root_output)) { | |||
| std::vector<AnfNodePtr> call_nodes; | |||
| size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes); | |||
| for (size_t i = 0; i < call_output_num; ++i) { | |||
| (void)outputs.emplace_back(root_output, i); | |||
| } | |||
| } | |||
| outputs_num = outputs.size(); | |||
| for (const auto &output : outputs) { | |||
| if (outputs_order.count(output) == 0) { | |||