| @@ -21,6 +21,9 @@ | |||
| namespace mindspore { | |||
| namespace runtime { | |||
| void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_->GetPtr()); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| auto &sequential_num = context->sequential_num_; | |||
| (void)input_op_datas_[sequential_num].emplace_back(input_data); | |||
| @@ -37,9 +37,11 @@ void StackActor::Init() { | |||
| // 6. Call input partial. | |||
| input_datas_num_ = formal_parameters_.size() - input_stack_data_num_ - input_stack_partials_num_; | |||
| if (input_stack_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_stack_data_num_ | |||
| << " device store num:" << device_tensor_store_keys_.size() << " local device tensor num" | |||
| << local_device_tensors_.size() << " for actor:" << GetAID(); | |||
| MS_LOG(EXCEPTION) << "Invalid input stack data num:" << input_stack_data_num_ | |||
| << " device store num:" << device_tensor_store_keys_.size() | |||
| << " local device tensor num:" << local_device_tensors_.size() | |||
| << " input stack data num:" << input_stack_data_num_ | |||
| << " input stack partial num:" << input_stack_partials_num_ << " for actor:" << GetAID(); | |||
| } | |||
| // Fetch the total number of input partial. | |||
| @@ -63,8 +65,8 @@ void StackActor::Init() { | |||
| if (input_stack_data_num_ + input_stack_partials_num_ + input_datas_num_ + input_partials_num_ + | |||
| device_tensor_store_keys_.size() + local_device_tensors_.size() != | |||
| formal_parameters_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_stack_data_num_ | |||
| << " input parameter partial num:" << input_stack_partials_num_ | |||
| MS_LOG(EXCEPTION) << "Invalid input num, input stack data num:" << input_stack_data_num_ | |||
| << " input stack partial num:" << input_stack_partials_num_ | |||
| << " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_ | |||
| << " device tensor store size:" << device_tensor_store_keys_.size() | |||
| << " need total size:" << formal_parameters_.size() << " for actor:" << GetAID(); | |||
| @@ -39,8 +39,8 @@ void SwitchActor::Init() { | |||
| } | |||
| auto data = std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, nullptr, data_arrow->to_input_index_); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| (void)output_data_.emplace_back(std::move(data)); | |||
| (void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get()); | |||
| (void)output_data_.emplace_back(std::move(data)); | |||
| } | |||
| } | |||
| @@ -689,6 +689,56 @@ void AddFormalToRealParameter(const AnfNodePtr &formal_parameter, const AnfNodeP | |||
| (*formal_to_real_parameters)[{formal_parameter, i}].insert(real_parameters.begin(), real_parameters.end()); | |||
| } | |||
| } | |||
| // Recursively traverse the input to confirm whether there is an input of recursive call. | |||
| bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes, | |||
| std::set<AnfNodePtr> unrecursion_call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(checked_nodes); | |||
| if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) { | |||
| return true; | |||
| } | |||
| checked_nodes->emplace(node); | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| for (const auto &input : inputs) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) || | |||
| (!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| // Get the level of the control node, recursively traverse all the inputs of the node, and find the largest level | |||
| // among them. | |||
| size_t ParseControlNodeLevel(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes, | |||
| const mindspore::HashMap<AnfNodePtr, size_t> &node_to_level) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(checked_nodes); | |||
| if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) { | |||
| return 0; | |||
| } | |||
| checked_nodes->emplace(node); | |||
| auto iter = node_to_level.find(node); | |||
| if (iter != node_to_level.end()) { | |||
| return iter->second; | |||
| } | |||
| size_t level = 0; | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| for (const auto &input : inputs) { | |||
| size_t tmp_level = ParseControlNodeLevel(input, checked_nodes, node_to_level); | |||
| level = (tmp_level > level ? tmp_level : level); | |||
| } | |||
| return level; | |||
| } | |||
| } // namespace | |||
| KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) { | |||
| @@ -870,6 +920,8 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| ParseNeedStackKernelGraph(kernel_graph_to_device_contexts); | |||
| ParseNodeLevel(control_nodes); | |||
| ParseNeedStackControlNode(control_nodes); | |||
| ParseFormalToRealParameter(control_nodes); | |||
| @@ -961,7 +1013,7 @@ bool ControlNodeParser::IsRecursionKernelGraph(const KernelGraphPtr &graph) { | |||
| MS_LOG(EXCEPTION) << "Invalid kernel graph:" << graph->ToString(); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(group_info_iter->second); | |||
| if (!group_info_iter->second->is_call_input_) { | |||
| if (!group_info_iter->second->need_stack_) { | |||
| return false; | |||
| } | |||
| for (const auto &front_input_node : group_info_iter->second->front_input_nodes_) { | |||
| @@ -1303,7 +1355,7 @@ bool ControlNodeParser::IsCallInputKernelGraph(KernelGraph *const graph) { | |||
| bool ControlNodeParser::IsCallInputKernelGraphGroup(const std::string &group_name) { | |||
| for (const auto &graph_group : kernel_graph_group_infos_) { | |||
| if (group_name.find(graph_group->group_name_) != std ::string::npos) { | |||
| return graph_group->is_call_input_; | |||
| return graph_group->need_stack_; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Invalid kernel graph group name:" << group_name; | |||
| @@ -1697,28 +1749,6 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod | |||
| return sub_front_node_to_root_front_node_[sub_front_node]; | |||
| } | |||
| bool IsFirstControlNode(const AnfNodePtr &node, std::set<AnfNodePtr> *checked_nodes, | |||
| std::set<AnfNodePtr> unrecursion_call_nodes) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(checked_nodes); | |||
| if (!node->isa<CNode>() || checked_nodes->find(node) != checked_nodes->end()) { | |||
| return true; | |||
| } | |||
| checked_nodes->emplace(node); | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| for (const auto &input : inputs) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if ((AnfAlgo::IsCallNode(input) && unrecursion_call_nodes.find(input) == unrecursion_call_nodes.end()) || | |||
| (!IsFirstControlNode(input, checked_nodes, unrecursion_call_nodes))) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) { | |||
| for (const auto &control_node : control_nodes) { | |||
| std::set<AnfNodePtr> checked_nodes; | |||
| @@ -1806,10 +1836,10 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector<AnfNodePtr> | |||
| if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) { | |||
| need_stack_control_nodes_.emplace(control_node); | |||
| } | |||
| } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial)) { | |||
| auto input_with_indexs = FetchInputNodeByCNode(control_node); | |||
| if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(), | |||
| [this](const auto &input_with_index) { return IsRecursionCallNode(input_with_index.first); })) { | |||
| } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || | |||
| AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) { | |||
| if (!IsInputInSameLevel(control_node)) { | |||
| need_stack_control_nodes_.emplace(control_node); | |||
| MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString(); | |||
| } | |||
| @@ -1850,7 +1880,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte | |||
| continue; | |||
| } | |||
| if (AnfAlgo::IsCallNode(front_node_with_index.first)) { | |||
| kernel_graph_group_info->is_call_input_ = true; | |||
| kernel_graph_group_info->need_stack_ = true; | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(front_node_with_index.first, prim::kPrimTupleGetItem)) { | |||
| @@ -1877,7 +1907,7 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte | |||
| } | |||
| kernel_graphs_to_group_info_[kernel_graph] = kernel_graph_group_info; | |||
| if (kernel_graph_group_info->is_call_input_) { | |||
| if (kernel_graph_group_info->need_stack_) { | |||
| call_input_kernel_graphs_.emplace(kernel_graph.get()); | |||
| } | |||
| } | |||
| @@ -1890,6 +1920,97 @@ void ControlNodeParser::ParseNeedStackKernelGraph(const KernelGraphToDeviceConte | |||
| } | |||
| } | |||
// Parse the level of each control node and of kernel-graph group outputs. A node's level
// counts how many recursive-call results its inputs transitively depend on within the
// current funcgraph; the stack-actor machinery uses it to decide which inputs to cache.
void ControlNodeParser::ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes) {
  size_t level = 0;
  // 1. Parse levels of control nodes.
  for (const auto &control_node : control_nodes) {
    if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
      // A return node closes the current funcgraph: record its level, then reset the
      // counter and give all of the funcgraph's parameters the fresh (zero) level.
      node_to_level_[control_node] = level;
      level = 0;
      const auto &func_graph = control_node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      const auto &parameters = func_graph->parameters();
      for (const auto &parameter : parameters) {
        node_to_level_[parameter] = level;
      }
      continue;
    } else if (AnfAlgo::IsCallNode(control_node) && IsRecursionCallNode(control_node)) {
      // Each recursive call node raises the level for everything that depends on its result.
      ++level;
      node_to_level_[control_node] = level;
    } else {
      // Ordinary control node: level is the maximum level found among its inputs.
      std::set<AnfNodePtr> checked_nodes;
      node_to_level_[control_node] = ParseControlNodeLevel(control_node, &checked_nodes, node_to_level_);
    }
  }
  // 2. Parse the levels of kernel graph outputs.
  // The outputs of a graph group take the highest level found among the group's front inputs
  // (inputs without a recorded level are ignored here).
  for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
    level = 0;
    for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
      const auto &input_node = front_input_node.first.first;
      auto iter = node_to_level_.find(input_node);
      if (iter != node_to_level_.end() && level < iter->second) {
        level = iter->second;
      }
    }
    for (const auto &front_output_node : kernel_graph_group_info->front_output_nodes_) {
      const auto &output_node = front_output_node.first.first;
      node_to_level_[output_node] = level;
    }
  }
  // Parse the levels of kernel graph groups.
  // Unlike step 2, a front input with no recorded level is an error at this point.
  for (const auto &kernel_graph_group_info : kernel_graph_group_infos_) {
    size_t max_level = 0;
    for (const auto &front_input_node : kernel_graph_group_info->front_input_nodes_) {
      const auto &input_node = front_input_node.first.first;
      auto iter = node_to_level_.find(input_node);
      if (iter == node_to_level_.end()) {
        MS_LOG(EXCEPTION) << "Failed to get level by input node:" << input_node->DebugString()
                          << " for kernel graph:" << kernel_graph_group_info->group_name_;
      }
      max_level = (max_level > iter->second ? max_level : iter->second);
    }
    // A group whose inputs carry a non-zero level depends on recursion results, so it
    // needs a stack actor and all of its graphs are registered as call-input graphs.
    if (max_level > 0) {
      kernel_graph_group_info->need_stack_ = true;
      kernel_graph_group_info->level_ = max_level;
      for (const auto &kernel_graph : kernel_graph_group_info->graphs_) {
        call_input_kernel_graphs_.emplace(kernel_graph.get());
      }
    }
  }
}
| bool ControlNodeParser::IsInputInSameLevel(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return true; | |||
| } | |||
| auto input_with_indexes = FetchInputNodeByCNode(node); | |||
| size_t level = SIZE_MAX; | |||
| for (const auto &input_with_index : input_with_indexes) { | |||
| auto input_node = input_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<ValueNode>()) { | |||
| continue; | |||
| } | |||
| auto iter = node_to_level_.find(input_node); | |||
| if (iter == node_to_level_.end()) { | |||
| MS_LOG(EXCEPTION) << "Failed to find level by input:" << input_node->DebugString() | |||
| << " for node:" << node->DebugString(); | |||
| } | |||
| if (level == SIZE_MAX) { | |||
| level = iter->second; | |||
| continue; | |||
| } | |||
| if (level != iter->second) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) { | |||
| MS_EXCEPTION_IF_NULL(default_context); | |||
| for (const auto ¶meter : root_graph_parameters_) { | |||
| @@ -82,8 +82,11 @@ using CallNodeToFuncGraph = mindspore::HashMap<AnfNodePtr, std::set<FuncGraphPtr | |||
| using KernelGraphToDeviceContext = mindspore::HashMap<KernelGraphPtr, DeviceContext *>; | |||
| // In the control flow, heterogeneous kernel graphs need to be reconnected in the same group, and the kernel graph | |||
| // group info is used to store the inputs and outputs of the group. | |||
| // Need stack indicates whether a stack actor needs to be created for the group. | |||
| // Level indicates the level of the output of the graph in the group. | |||
| struct KernelGraphGroupInfo { | |||
| bool is_call_input_; | |||
| bool need_stack_{0}; | |||
| size_t level_; | |||
| std::string group_name_; | |||
| std::set<KernelGraphPtr> graphs_; | |||
| std::map<KernelWithIndex, const DeviceContext *> front_input_nodes_; | |||
| @@ -128,6 +131,7 @@ class ControlNodeParser { | |||
| // If there is a recursive call node in the input of the kernel graph, the graph is recursive. | |||
| bool IsRecursionKernelGraph(const KernelGraphPtr &graph); | |||
| bool IsSameKernelGraphGroup(const AnfNodePtr &node, const KernelGraphPtr &graph); | |||
| bool IsInputInSameLevel(const AnfNodePtr &node); | |||
| const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; } | |||
| const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; } | |||
| @@ -217,6 +221,8 @@ class ControlNodeParser { | |||
| // When a control node or kernel graph has input that is a call node, you need to add a stack actor for it. | |||
| void ParseNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes); | |||
| void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts); | |||
| // Parse the level of inputs and outputs of graphs and all control nodes. | |||
| void ParseNodeLevel(const std::vector<AnfNodePtr> &control_nodes); | |||
| // When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device | |||
| // tensor needs to be created for it. | |||
| void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context); | |||
| @@ -241,6 +247,15 @@ class ControlNodeParser { | |||
| // id needs to be sent to the gather actor corresponding to the funcgraph, and the gather will send the branch id | |||
| // to its output switch actor. | |||
| mindspore::HashMap<AnfNodePtr, int> call_node_to_branch_id_; | |||
| // Level indicates that the input of the node depends on the number of the recursive call node in the funcgraph. | |||
| // During graph scheduler, the input needs to be graded according to the input's dependence on the recursive call | |||
| // node, and according to this level, the lower-level inputs are pushed in the stack actor. When arranging, first | |||
| // sort the call nodes in the funcgraph according to their topological relationships, and then confirm the | |||
| // dependencies of other nodes on these call nodes in turn. | |||
| // For example, the dependencies are a -> b, b -> d, c -> d, where b is a call node, then the level of a and c is 0, | |||
| // and the level of b and d is 1. Since d has inputs with different levels (b and c), it is necessary to add a | |||
| // stack to d. | |||
| mindspore::HashMap<AnfNodePtr, size_t> node_to_level_; | |||
| CallNodeToFuncGraph call_node_to_func_graphs_; | |||
| // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. | |||
| // When initializing the weights, all related weights need to be recorded as the same device tensor. | |||
| @@ -81,16 +81,6 @@ void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> | |||
| } | |||
| } | |||
// Judge whether the arrow from `from_node` into `graph` is a control-flow arrow, i.e. one
// that should be skipped when linking ordinary data arrows for the kernel graph.
bool IsControlFlowArrow(const ControlNodeParserPtr &parser, const KernelGraphPtr &graph, const AnfNodePtr &from_node) {
  MS_EXCEPTION_IF_NULL(parser);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_node);
  // NOTE(review): from_node is null-checked above, so the (from_node == nullptr) /
  // (from_node != nullptr) tests below are vestigial — confirm before simplifying.
  bool is_call_input_kernl_graph = parser->IsCallInputKernelGraph(graph.get());
  // Skipped when: the graph has no call input and the source is not a parameter, or the
  // source is a persistent device tensor, or the source lives in the same graph group.
  return ((!is_call_input_kernl_graph) && ((from_node == nullptr) || (!from_node->isa<Parameter>()))) ||
         (from_node != nullptr && IsPersistentDeviceTensor(from_node)) ||
         (from_node != nullptr && parser->IsSameKernelGraphGroup(from_node, graph));
}
| // Parameter and ref node can not copy the device tensor. | |||
| bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| @@ -317,10 +307,10 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| // Create a corresponding stack actor for each kernel graph that has a call node as input. | |||
| for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) { | |||
| if (!kernel_graph_group_info->is_call_input_) { | |||
| if (!kernel_graph_group_info->need_stack_) { | |||
| continue; | |||
| } | |||
| const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix; | |||
| size_t input_parameter_data_num = 0; | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| @@ -329,8 +319,13 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| // If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end. | |||
| const auto &from_node = node_with_context.first.first; | |||
| MS_EXCEPTION_IF_NULL(from_node); | |||
| const auto &graph = (from_node->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(from_node) : nullptr); | |||
| if (parser->IsRecursionCallNode(from_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) { | |||
| auto iter = parser->node_to_level_.find(from_node); | |||
| if (iter == parser->node_to_level_.end()) { | |||
| MS_LOG(EXCEPTION) << "Failed to get level by from node:" << from_node->DebugString() | |||
| << " in graph:" << kernel_graph_group_info->group_name_; | |||
| } | |||
| if (iter->second == kernel_graph_group_info->level_ && | |||
| (!(parser->IsRootGraphParameter(from_node) && IsPersistentDeviceTensor(from_node)))) { | |||
| formal_parameters.emplace_back(node_with_context.first); | |||
| device_contexts.emplace_back(node_with_context.second); | |||
| } else { | |||
| @@ -339,7 +334,6 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| input_parameter_data_num++; | |||
| } | |||
| } | |||
| const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix; | |||
| const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters); | |||
| stack_actors.emplace_back(stack_actor); | |||
| stack_actor->device_contexts_.swap(device_contexts); | |||
| @@ -360,6 +354,7 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo | |||
| for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) { | |||
| MS_EXCEPTION_IF_NULL(need_stack_control_node); | |||
| const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix; | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| size_t input_parameter_data_num = 0; | |||
| @@ -372,12 +367,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| control_actor_name = func_graph->ToString() + kExitActorNameSuffix; | |||
| } else if (AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimPartial) || | |||
| AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitchLayer) || | |||
| AnfAlgo::IsCallNode(need_stack_control_node)) { | |||
| control_actor_name = GetActorName(need_stack_control_node); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid control node:" << need_stack_control_node->DebugString(); | |||
| } | |||
| auto iter = parser->node_to_level_.find(need_stack_control_node); | |||
| if (iter == parser->node_to_level_.end()) { | |||
| MS_LOG(EXCEPTION) << "Failed to get level for need stack control node:" << need_stack_control_node->DebugString(); | |||
| } | |||
| size_t control_node_level = iter->second; | |||
| auto actor = FetchActor(control_actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto control_actor = dynamic_cast<ControlActor *>(actor); | |||
| @@ -392,13 +395,20 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo | |||
| for (size_t i = 0; i < control_actor->formal_parameters_.size(); ++i) { | |||
| const auto ¶meter = control_actor->formal_parameters_[i]; | |||
| auto device_context = control_actor->device_contexts_[i]; | |||
| const auto &graph = | |||
| (parameter.first->isa<CNode>() ? parser->FetchKernelGraphByFrontNode(parameter.first) : nullptr); | |||
| if (parser->IsRecursionCallNode(parameter.first) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) { | |||
| if (parameter.first->isa<ValueNode>()) { | |||
| continue; | |||
| } | |||
| iter = parser->node_to_level_.find(parameter.first); | |||
| if (iter == parser->node_to_level_.end()) { | |||
| MS_LOG(EXCEPTION) << "Failed to get level for formal parameter:" << parameter.first->DebugString() | |||
| << " for need stack control node:" << need_stack_control_node->DebugString(); | |||
| } | |||
| if (control_node_level == iter->second && | |||
| (!(parser->IsRootGraphParameter(parameter.first) && IsPersistentDeviceTensor(parameter.first)))) { | |||
| formal_parameters.emplace_back(parameter); | |||
| device_contexts.emplace_back(device_context); | |||
| } else if (parameter.first->isa<ValueNode>()) { | |||
| continue; | |||
| } else { | |||
| formal_parameters.insert(formal_parameters.begin(), parameter); | |||
| device_contexts.insert(device_contexts.begin(), device_context); | |||
| @@ -415,7 +425,6 @@ void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo | |||
| } | |||
| } | |||
| // Create stack actor. | |||
| const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix; | |||
| const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters); | |||
| stack_actor->device_contexts_ = device_contexts; | |||
| stack_actor->input_stack_data_num_ = input_parameter_data_num; | |||
| @@ -495,9 +504,20 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); | |||
| const auto &parser = graph_compiler_info.control_node_parser_; | |||
| for (auto &switch_actor : control_actor_set->switch_actors_) { | |||
| for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) { | |||
| LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i}, | |||
| parser); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| if (parser->need_stack_control_nodes_.find(switch_actor->node_) == parser->need_stack_control_nodes_.end()) { | |||
| for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) { | |||
| LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i}, | |||
| parser); | |||
| } | |||
| } else { | |||
| // If the control actor has a corresponding stack actor, the input should be linked to the stack actor. | |||
| auto stack_actor_name = GetActorName(switch_actor->node_) + kStackActorNameSuffix; | |||
| auto actor = FetchActor(stack_actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto stack_actor = dynamic_cast<StackActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(stack_actor); | |||
| LinkArrowFromStackActor(stack_actor, switch_actor.get()); | |||
| } | |||
| } | |||
| @@ -601,7 +621,14 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| const auto &switch_actor = dynamic_cast<SwitchActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(switch_actor); | |||
| LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); | |||
| const auto &abstract = from_node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| if (abstract->isa<abstract::AbstractFunction>()) { | |||
| LinkPartialArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); | |||
| } else { | |||
| LinkDataArrow(switch_actor, to_actor, from_node_with_index.second, to_node_with_index.second); | |||
| } | |||
| } else if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) { | |||
| // Link arrow from gather actor | |||
| const auto &actor_name = GetActorName(from_node); | |||
| @@ -934,7 +961,7 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_ | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto actor_name = kernel_graph->ToString() + kStackActorNameSuffix; | |||
| auto actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kStackActorNameSuffix; | |||
| if (!parser->IsCallInputKernelGraph(kernel_graph.get())) { | |||
| const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get()); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -1135,10 +1162,8 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap | |||
| auto input_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); | |||
| auto input = input_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (sink_input_node_linked.count(input) > 0) { | |||
| continue; | |||
| } | |||
| if ((!input->isa<Parameter>()) || HasAbstractMonad(input) || IsPersistentDeviceTensor(input)) { | |||
| if (sink_input_node_linked.count(input) > 0 || HasAbstractMonad(input) || parser == nullptr || | |||
| (!parser->IsControlFlowDataArrow(graph, input))) { | |||
| continue; | |||
| } | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(input); | |||
| @@ -1159,16 +1184,23 @@ void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &grap | |||
| // If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond | |||
| // to the front parameter, but the node in the internal parameter. | |||
| const auto &from_node = from_node_with_index.first; | |||
| if (IsControlFlowArrow(parser, graph, from_node)) { | |||
| continue; | |||
| } | |||
| // Fetch actor and link. | |||
| auto type = FetchKernelTransformType(kernel, graph, {}); | |||
| auto to_actor = FetchActor(type, "", kernel, graph); | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| size_t from_index = 0; | |||
| if (AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) { | |||
| const auto &actor_name = GetActorName(from_node); | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| from_actor = dynamic_cast<ControlActor *>(actor); | |||
| } else { | |||
| from_index = from_actor->FetchNodePosition(from_node_with_index); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto from_index = from_actor->FetchNodePosition(from_node_with_index); | |||
| auto to_index = i; | |||
| if (type == KernelTransformType::kSuperKernelActor) { | |||
| auto super_kernel_actor = dynamic_cast<SuperKernelActor *>(to_actor); | |||
| @@ -559,7 +559,8 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu | |||
| MS_EXCEPTION_IF_NULL(cut_node); | |||
| MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString(); | |||
| control_nodes_.push_back(cut_node); | |||
| if (AnfAlgo::IsCallNode(cut_node)) { | |||
| if (AnfAlgo::IsCallNode(cut_node) || AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) || | |||
| AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) { | |||
| const auto &func_graph = cut_node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>()); | |||