| @@ -208,9 +208,14 @@ void ControlActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send Partial. | |||
| for (const auto &partial_arrow : output_partial_arrows_) { | |||
| MS_EXCEPTION_IF_NULL(partial_arrow); | |||
| MS_EXCEPTION_IF_NULL(output_partial_.first); | |||
| ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial_.first, | |||
| output_partial_.second, IntToSize(partial_arrow->to_input_index_), context); | |||
| if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) { | |||
| MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_ | |||
| << " current:" << input_partials_.size() << " for actor:" << GetAID(); | |||
| } | |||
| auto output_partial = input_partials_[partial_arrow->from_output_index_]; | |||
| MS_EXCEPTION_IF_NULL(output_partial.first); | |||
| ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first, | |||
| output_partial.second, IntToSize(partial_arrow->to_input_index_), context); | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| @@ -72,6 +72,22 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| ActorDispatcher::Send(control_arrow, &OpActor::RunOpControl, source_aid, context); | |||
| } | |||
| } | |||
| // 3.Send output partial in output branch. | |||
| const auto &partial_iter = output_branch_partial_arrows_.find(output_branch_id_); | |||
| if (partial_iter != output_branch_partial_arrows_.end()) { | |||
| for (const auto &partial_arrow : partial_iter->second) { | |||
| MS_EXCEPTION_IF_NULL(partial_arrow); | |||
| if (IntToSize(partial_arrow->from_output_index_) >= input_partials_.size()) { | |||
| MS_LOG(ERROR) << "Invalid partial input:" << partial_arrow->from_output_index_ | |||
| << " current:" << input_partials_.size() << " for actor:" << GetAID(); | |||
| } | |||
| auto output_partial = input_partials_[partial_arrow->from_output_index_]; | |||
| MS_EXCEPTION_IF_NULL(output_partial.first); | |||
| Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial.first, output_partial.second, | |||
| IntToSize(partial_arrow->to_input_index_), context); | |||
| } | |||
| } | |||
| } | |||
| void ExitActor::CopyDeviceAddress() { | |||
| @@ -29,23 +29,22 @@ void GatherActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| ControlActor::FetchInput(context); | |||
| output_partial_ = input_partials_[0]; | |||
| MS_EXCEPTION_IF_NULL(output_partial_.first); | |||
| MS_EXCEPTION_IF_NULL(input_partials_[0].first); | |||
| // Put other real parameter in partial. | |||
| for (const auto &device_tensor : input_device_tensors_) { | |||
| if (device_tensor != nullptr) { | |||
| output_partial_.second.emplace_back(device_tensor); | |||
| input_partials_[0].second.emplace_back(device_tensor); | |||
| } | |||
| } | |||
| } | |||
| void GatherActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| // Send data with branch id. | |||
| const auto &iter = output_data_with_branch_id_arrows_.find(output_partial_.first); | |||
| const auto &iter = output_data_with_branch_id_arrows_.find(input_partials_[0].first); | |||
| if (iter != output_data_with_branch_id_arrows_.end()) { | |||
| for (const auto &data_with_branch_id_arrow : iter->second) { | |||
| ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, output_partial_.second, | |||
| ActorDispatcher::Send(data_with_branch_id_arrow, &EntranceActor::RunOpDataWithBranchID, input_partials_[0].second, | |||
| output_branch_id_, context); | |||
| } | |||
| } | |||
| @@ -27,25 +27,56 @@ StackActor::StackActor(const std::string &name, const std::vector<KernelWithInde | |||
| void StackActor::Init() { | |||
| ControlActor::Init(); | |||
| // The stack actor has 6 parts of input : | |||
| // 1. Directly input data. | |||
| // 2. Direct input partial. | |||
| // 3. Weight. | |||
| // 4. Local tensor. | |||
| // 5. Call input data. | |||
| // 6. Call input partial. | |||
| input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_ - input_parameter_partial_num_; | |||
| if (input_parameter_data_num_ < device_tensor_store_keys_.size() + local_device_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_parameter_data_num_ | |||
| << " device store num:" << device_tensor_store_keys_.size() << " local device tensor num" | |||
| << local_device_tensors_.size() << " for actor:" << GetAID(); | |||
| } | |||
| // Fetch the total number of input partial. | |||
| int total_partials_num = 0; | |||
| for (const auto &formal_parameter : formal_parameters_) { | |||
| if (AnfAlgo::IsCallNode(formal_parameter.first)) { | |||
| break; | |||
| const auto &abstract = formal_parameter.first->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second); | |||
| MS_EXCEPTION_IF_NULL(real_abstract); | |||
| if (real_abstract->isa<abstract::AbstractFunction>()) { | |||
| total_partials_num++; | |||
| } | |||
| ++input_parameter_data_num_; | |||
| } | |||
| input_datas_num_ = formal_parameters_.size() - input_parameter_data_num_; | |||
| if (input_parameter_data_num_ < device_tensor_store_keys_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid input parameter data num:" << input_parameter_data_num_ | |||
| << " device store num:" << device_tensor_store_keys_.size() << " for actor:" << GetAID(); | |||
| // Fetch call input data num. | |||
| input_datas_num_ = formal_parameters_.size() - total_partials_num - input_parameter_data_num_; | |||
| input_partials_num_ = total_partials_num - input_parameter_partial_num_; | |||
| // Fetch call input partial num. | |||
| input_parameter_data_num_ -= (device_tensor_store_keys_.size() + local_device_tensors_.size()); | |||
| // Check if the input num is valid. | |||
| if (input_parameter_data_num_ + input_parameter_partial_num_ + input_datas_num_ + input_partials_num_ + | |||
| device_tensor_store_keys_.size() + local_device_tensors_.size() != | |||
| formal_parameters_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid input num, input parameter data num:" << input_parameter_data_num_ | |||
| << " input parameter partial num:" << input_parameter_partial_num_ | |||
| << " input data num:" << input_datas_num_ << " input partial num:" << input_partials_num_ | |||
| << " device tensor store size:" << device_tensor_store_keys_.size() | |||
| << " need total size:" << formal_parameters_.size() << " for actor:" << GetAID(); | |||
| } | |||
| input_parameter_data_num_ -= device_tensor_store_keys_.size(); | |||
| } | |||
| void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| // The parameters from the inside of the subgraph need to be put into the stack. | |||
| if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size()) { | |||
| if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size() + | |||
| input_parameter_partial_num_ + local_device_tensors_.size()) { | |||
| FillStack(input_data, context); | |||
| } else { | |||
| // The outputs of call nodes are placed directly in the input data. | |||
| @@ -56,6 +87,22 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev | |||
| } | |||
| } | |||
| void StackActor::RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // The parameters from the inside of the subgraph need to be put into the stack. | |||
| if (IntToSize(position) < input_parameter_data_num_ + device_tensor_store_keys_.size() + | |||
| input_parameter_partial_num_ + local_device_tensors_.size()) { | |||
| input_parameter_partial_[context->sequential_num_][position].push(OpPartial(func_graph, input_data)); | |||
| } else { | |||
| input_op_partials_[context->sequential_num_].emplace_back(position, OpPartial(func_graph, input_data)); | |||
| } | |||
| if (CheckRunningCondition(context)) { | |||
| Run(context); | |||
| } | |||
| } | |||
| void StackActor::FillStack(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| @@ -122,6 +169,26 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c | |||
| return false; | |||
| } | |||
| } | |||
| if (input_parameter_partial_num_ != 0) { | |||
| const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_); | |||
| if (partial_iter == input_parameter_partial_.end()) { | |||
| return false; | |||
| } | |||
| if (partial_iter->second.size() != input_parameter_partial_num_) { | |||
| return false; | |||
| } | |||
| auto iter = input_branch_ids_.find(context->sequential_num_); | |||
| if (iter == input_branch_ids_.end() || iter->second.empty()) { | |||
| MS_LOG(ERROR) << "There is no branch id for actor:" << GetAID(); | |||
| } | |||
| size_t branch_id_size = iter->second.size(); | |||
| if (std::any_of(partial_iter->second.begin(), partial_iter->second.end(), | |||
| [branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -133,13 +200,29 @@ void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_LOG(ERROR) << "Invalid input for actor:" << GetAID(); | |||
| } | |||
| for (const auto &one_stack : data_iter->second) { | |||
| if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size()) { | |||
| if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() + | |||
| local_device_tensors_.size() + input_parameter_partial_num_) { | |||
| MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_ | |||
| << " for actor:" << GetAID(); | |||
| } | |||
| input_device_tensors_[one_stack.first] = one_stack.second.top(); | |||
| } | |||
| } | |||
| if (input_parameter_partial_num_ != 0) { | |||
| const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_); | |||
| if (partial_iter == input_parameter_partial_.end()) { | |||
| MS_LOG(ERROR) << "Invalid input for actor:" << GetAID(); | |||
| } | |||
| for (const auto &one_stack : partial_iter->second) { | |||
| if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size() + | |||
| local_device_tensors_.size() + input_parameter_partial_num_) { | |||
| MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_partial_ | |||
| << " for actor:" << GetAID(); | |||
| } | |||
| input_partials_[one_stack.first] = one_stack.second.top(); | |||
| } | |||
| } | |||
| ControlActor::FetchInput(context); | |||
| } | |||
| @@ -160,6 +243,20 @@ void StackActor::EraseInput(const OpContext<DeviceTensor> *const context) { | |||
| one_stack.second.pop(); | |||
| } | |||
| } | |||
| if (input_parameter_partial_num_ != 0) { | |||
| const auto &partial_iter = input_parameter_partial_.find(context->sequential_num_); | |||
| if (partial_iter == input_parameter_partial_.end()) { | |||
| MS_LOG(ERROR) << "Invalid input for actor:" << GetAID(); | |||
| } | |||
| for (auto &one_stack : partial_iter->second) { | |||
| if (one_stack.second.empty()) { | |||
| MS_LOG(ERROR) << "Input index:" << one_stack.first << " is null in actor:" << GetAID(); | |||
| } | |||
| one_stack.second.pop(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -39,9 +39,11 @@ class StackActor : public ControlActor { | |||
| ~StackActor() override = default; | |||
| void Init() override; | |||
| // The input data of the stack actor needs to be pushed into the stack according to the input index, so it is | |||
| // implemented separately. | |||
| // The input data and partial of the stack actor needs to be pushed into the stack according to the input index, | |||
| // so it is implemented separately. | |||
| void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override; | |||
| void RunOpPartial(FuncGraph *func_graph, std::vector<DeviceTensor *> input_data, size_t position, | |||
| OpContext<DeviceTensor> *const context) override; | |||
| protected: | |||
| void FetchInput(OpContext<DeviceTensor> *const context) override; | |||
| @@ -56,12 +58,15 @@ class StackActor : public ControlActor { | |||
| // The device tensors created and stored by the stack. | |||
| std::vector<DeviceTensorPtr> created_device_tensors_; | |||
| // The input data records that the stack actor is copied from the input nodes and needs to be stored in the | |||
| // device tensor in the stack. | |||
| // The input data and partials records that the stack actor is copied from the input nodes and needs to be | |||
| // stored in the device tensor in the stack. | |||
| mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_parameter_data_; | |||
| // Input parameter data num represents the number of actor's input come from funcgraph itself, these inputs | |||
| // will be ranked at the front of input. | |||
| mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<OpPartial>>> input_parameter_partial_; | |||
| // Input parameter num represents the number of actor's input come from funcgraph itself, these inputs will | |||
| // be ranked at the front of input. | |||
| size_t input_parameter_data_num_{0}; | |||
| size_t input_parameter_partial_num_{0}; | |||
| // The backend parameter is used to save the backend node corresponding to the device tensor in the stack. | |||
| // When these device tensors are used as output, they need to be placed in the node of the result arrow, | |||
| // so these nodes need to be saved. | |||
| @@ -54,7 +54,7 @@ void SwitchActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| if (!output_partial_arrows_.empty()) { | |||
| auto func_graph = input_partials_[index + kSwitchCondPos].first; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| output_partial_ = input_partials_[index + kSwitchCondPos]; | |||
| input_partials_[0] = input_partials_[index + kSwitchCondPos]; | |||
| } | |||
| for (auto &output_data : output_data_by_output_index_[kSwitchDefaultOutputNum - 1]) { | |||
| @@ -261,7 +261,7 @@ size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node) | |||
| MS_EXCEPTION_IF_NULL(data_node); | |||
| const auto &iter = data_node_position_map_.find(data_node); | |||
| if (iter == data_node_position_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " is not exist."; | |||
| MS_LOG(EXCEPTION) << "Data node: " << data_node->DebugString() << " is not exist."; | |||
| } | |||
| return iter->second; | |||
| } | |||
| @@ -261,6 +261,65 @@ void CreateDeviceTensorForFrontNode(const KernelWithIndex &front_node_with_index | |||
| MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(node) << " addr:" << address; | |||
| AnfAlgo::SetOutputAddr(address, front_node_with_index.second, node.get()); | |||
| } | |||
| // Check if there is a recursive call to funcgraph, if a calls b, b calls c, and c calls a, it is a recursive call. | |||
| bool IsRecursionFunction(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *checked_funcgraphs, | |||
| const std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> &func_graph_call_relation) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (checked_funcgraphs->find(func_graph) != checked_funcgraphs->end()) { | |||
| return true; | |||
| } | |||
| checked_funcgraphs->emplace(func_graph); | |||
| auto iter = func_graph_call_relation.find(func_graph); | |||
| if (iter == func_graph_call_relation.end()) { | |||
| return false; | |||
| } | |||
| for (const auto &called_func_graph : iter->second) { | |||
| MS_EXCEPTION_IF_NULL(called_func_graph); | |||
| if (IsRecursionFunction(called_func_graph, checked_funcgraphs, func_graph_call_relation)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| // Fetch all inputs of node. | |||
| std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) { | |||
| // The node is divided into the following types: | |||
| // 1. depend and load. | |||
| const auto &node_with_index = | |||
| AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); | |||
| auto real_node = node_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_node); | |||
| std::vector<KernelWithIndex> results; | |||
| // 2. MakeTuple. | |||
| if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) { | |||
| const auto &make_tuple_cnode = real_node->cast<CNodePtr>(); | |||
| const auto &make_tuple_inputs = make_tuple_cnode->inputs(); | |||
| for (size_t i = kMakeTupleInputStartPos; i < make_tuple_inputs.size(); ++i) { | |||
| const auto &sub_results = FetchInputNodeByNode(make_tuple_inputs[i]); | |||
| results.insert(results.end(), sub_results.begin(), sub_results.end()); | |||
| } | |||
| return results; | |||
| } | |||
| // 3. One output node. | |||
| const auto &abstract = real_node->abstract(); | |||
| if (abstract == nullptr || (!abstract->isa<abstract::AbstractTuple>())) { | |||
| if (abstract == nullptr) { | |||
| MS_LOG(WARNING) << "Empty abstract for node:" << real_node->DebugString(); | |||
| } | |||
| return {AnfAlgo::VisitKernelWithReturnType(real_node, 0)}; | |||
| } | |||
| // 4. Abstract is Tuple. | |||
| size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract); | |||
| for (size_t i = 0; i < output_num; ++i) { | |||
| results.emplace_back(real_node, i); | |||
| } | |||
| return results; | |||
| } | |||
| } // namespace | |||
| bool HasAbstractRef(const AnfNodePtr &node) { | |||
| @@ -285,6 +344,72 @@ KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, Kernel | |||
| return front_node_with_index; | |||
| } | |||
| std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return {}; | |||
| } | |||
| std::vector<KernelWithIndex> results; | |||
| // The first input of normal cnode is the primitive of node, and the real input starts from the second input, | |||
| // but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial. | |||
| size_t input_start_pos = kCNodeInputStartPos; | |||
| if (AnfAlgo::IsCallNode(node)) { | |||
| input_start_pos = 0; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| const auto inputs = cnode->inputs(); | |||
| // The first branch of the input of the switch node is the true branch, and the second is the false branch. | |||
| // But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input | |||
| // of the switch node needs to exchange the positions of the two branches. So deal separately. | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||
| if (inputs.size() != kSwitchInputNum) { | |||
| MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString(); | |||
| } | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0)); | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0)); | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0)); | |||
| return results; | |||
| } | |||
| for (size_t i = input_start_pos; i < inputs.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(inputs[i]); | |||
| // skip monad node. | |||
| if (HasAbstractMonad(inputs[i])) { | |||
| continue; | |||
| } | |||
| const auto &sub_results = FetchInputNodeByNode(inputs[i]); | |||
| results.insert(results.end(), sub_results.begin(), sub_results.end()); | |||
| } | |||
| return results; | |||
| } | |||
| abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| if (!abstract->isa<abstract::AbstractTuple>()) { | |||
| if (index != 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString(); | |||
| } | |||
| return abstract; | |||
| } | |||
| auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_abstract); | |||
| const auto &sub_abstracts = tuple_abstract->elements(); | |||
| size_t real_index = index; | |||
| for (const auto &sub_abstract : sub_abstracts) { | |||
| size_t tmp_index = AnfAlgo::GetOutputNumByAbstract(sub_abstract); | |||
| if (real_index >= tmp_index) { | |||
| real_index -= tmp_index; | |||
| continue; | |||
| } | |||
| return FetchAbstractByIndex(sub_abstract, real_index); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Invalid abstract index:" << index << " for abstract:" << abstract->ToString(); | |||
| return nullptr; | |||
| } | |||
| void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, const std::vector<KernelGraphPtr> &graphs, | |||
| const std::vector<DeviceContext *> &device_contexts, const FuncGraphPtr &root_graph, | |||
| const FuncGraphToKernelGraph &func_graph_to_kernel_graphs) { | |||
| @@ -309,12 +434,18 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| ParseCallNodeToFuncGraph(control_nodes); | |||
| FetchNeedStackControlNode(control_nodes); | |||
| ParseUnRecursionCallNode(); | |||
| FetchFrontNodeToKernelGraph(graphs); | |||
| ParseFormalToRealParameter(control_nodes); | |||
| ParseFrontToBackendParameter(graphs, device_contexts); | |||
| CreateDeviceTensorForRootGraphParameter(device_contexts[0]); | |||
| FetchFrontToBackendKernel(graphs, device_contexts); | |||
| FetchHostParameterToWeight(); | |||
| @@ -323,7 +454,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| ParseDeviceContext(control_nodes, graphs, device_contexts, func_graph_to_kernel_graphs); | |||
| FetchFrontValueNode(device_contexts[0]); | |||
| FetchFrontValueNode(control_nodes, device_contexts[0]); | |||
| FetchControlNodeParameter(control_nodes); | |||
| @@ -363,6 +494,11 @@ bool ControlNodeParser::IsRootGraphParameter(const AnfNodePtr &node) { | |||
| return find(root_graph_parameters_.begin(), root_graph_parameters_.end(), node) != root_graph_parameters_.end(); | |||
| } | |||
| bool ControlNodeParser::IsRecursionCallNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| return find(unrecursion_call_nodes_.begin(), unrecursion_call_nodes_.end(), node) == unrecursion_call_nodes_.end(); | |||
| } | |||
| void ControlNodeParser::ParseDeviceContext(const std::vector<AnfNodePtr> &control_nodes, | |||
| const std::vector<KernelGraphPtr> &kernel_graphs, | |||
| const std::vector<DeviceContext *> &device_contexts, | |||
| @@ -604,6 +740,16 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def | |||
| void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) { | |||
| for (const auto &graph : graphs) { | |||
| if (graph->execution_order().empty()) { | |||
| continue; | |||
| } | |||
| for (auto &kernel : graph->execution_order()) { | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(kernel); | |||
| if (front_node != nullptr) { | |||
| front_node_to_kernel_graph_[front_node] = graph; | |||
| } | |||
| } | |||
| const auto &graph_outputs = graph->graph_output_map(); | |||
| for (const auto &backend_to_front : graph_outputs) { | |||
| front_node_to_kernel_graph_[backend_to_front.second.first] = graph; | |||
| @@ -654,7 +800,8 @@ FuncGraphPtr ControlNodeParser::FetchFuncGraphByKernelGraph(const KernelGraph *c | |||
| return nullptr; | |||
| } | |||
| void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) { | |||
| void ControlNodeParser::FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, | |||
| DeviceContext *default_context) { | |||
| MS_EXCEPTION_IF_NULL(default_context); | |||
| for (const auto &formal_to_real_parameter : formal_to_real_parameters_) { | |||
| @@ -688,6 +835,27 @@ void ControlNodeParser::FetchFrontValueNode(DeviceContext *default_context) { | |||
| front_value_nodes_.emplace(front_to_backend_parameters.first, device_context); | |||
| } | |||
| } | |||
| // Create device tensors for those value nodes which direct return by a return node. | |||
| for (const auto &control_node : control_nodes) { | |||
| if (!AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) { | |||
| continue; | |||
| } | |||
| auto input_with_indexs = FetchInputNodeByCNode(control_node); | |||
| auto iter = control_node_to_device_contexts_.find(control_node); | |||
| if (iter == control_node_to_device_contexts_.end() || iter->second.size() != input_with_indexs.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid device context for control node:" << control_node->DebugString(); | |||
| } | |||
| for (size_t i = 0; i < input_with_indexs.size(); ++i) { | |||
| const auto &input_with_index = input_with_indexs[i]; | |||
| if (input_with_index.first->isa<ValueNode>() && (!IsValueNode<FuncGraph>(input_with_index.first)) && | |||
| front_value_nodes_.find({input_with_index, iter->second[i]}) == front_value_nodes_.end()) { | |||
| CreateDeviceTensorForFrontNode(input_with_index, iter->second[i]); | |||
| front_value_nodes_.emplace(input_with_index, iter->second[i]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeParser::ParseFormalToRealParameter(const std::vector<AnfNodePtr> &control_nodes) { | |||
| @@ -977,7 +1145,9 @@ void ControlNodeParser::FetchFrontToBackendKernel(const std::vector<KernelGraphP | |||
| const auto graph_output_map = graph->graph_output_map(); | |||
| for (const auto &output_pair : graph_output_map) { | |||
| front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context}; | |||
| if (output_pair.first.first->isa<CNode>()) { | |||
| front_to_backend_kernels_[output_pair.second] = {output_pair.first, device_context}; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1039,5 +1209,57 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeParser::ParseUnRecursionCallNode() { | |||
| std::unordered_map<FuncGraphPtr, std::set<FuncGraphPtr>> func_graph_call_relation; | |||
| // Collect the call relationship between funcgraphs. | |||
| for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) { | |||
| const auto &call_node = call_node_to_func_graphs.first; | |||
| MS_EXCEPTION_IF_NULL(call_node); | |||
| const auto &func_graph = call_node->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| func_graph_call_relation[func_graph].insert(call_node_to_func_graphs.second.begin(), | |||
| call_node_to_func_graphs.second.end()); | |||
| } | |||
| for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) { | |||
| const auto &call_node = call_node_to_func_graphs.first; | |||
| std::set<FuncGraphPtr> checked_func_graphs{call_node->func_graph()}; | |||
| bool is_recursion_call_node = false; | |||
| if (std::any_of(call_node_to_func_graphs.second.begin(), call_node_to_func_graphs.second.end(), | |||
| [&is_recursion_call_node, &checked_func_graphs, &func_graph_call_relation](const auto &func_graph) { | |||
| return IsRecursionFunction(func_graph, &checked_func_graphs, func_graph_call_relation); | |||
| })) { | |||
| is_recursion_call_node = true; | |||
| } | |||
| if (!is_recursion_call_node && need_stack_control_nodes_.find(call_node) == need_stack_control_nodes_.end()) { | |||
| unrecursion_call_nodes_.emplace(call_node); | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeParser::FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes) { | |||
| for (const auto &control_node : control_nodes) { | |||
| MS_EXCEPTION_IF_NULL(control_node); | |||
| if (AnfAlgo::IsCallNode(control_node)) { | |||
| auto input_with_indexs = FetchInputNodeByCNode(control_node); | |||
| if (std::any_of(input_with_indexs.begin(), input_with_indexs.end(), | |||
| [](const auto &input_with_index) { return AnfAlgo::IsCallNode(input_with_index.first); })) { | |||
| need_stack_control_nodes_.emplace(control_node); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeParser::CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context) { | |||
| MS_EXCEPTION_IF_NULL(default_context); | |||
| for (const auto ¶meter : root_graph_parameters_) { | |||
| KernelWithIndex parameter_with_index(parameter, 0); | |||
| if (front_to_backend_parameters_.find(parameter_with_index) == front_to_backend_parameters_.end()) { | |||
| CreateDeviceTensorForFrontNode(parameter_with_index, default_context); | |||
| front_to_backend_parameters_[parameter_with_index].emplace(parameter, default_context); | |||
| } | |||
| } | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| #include <queue> | |||
| #include <map> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| #include <algorithm> | |||
| #include "utils/hash_map.h" | |||
| #include "runtime/hardware/device_context.h" | |||
| @@ -76,7 +77,10 @@ bool HasAbstractRef(const AnfNodePtr &node); | |||
| // Get the front node corresponding to the backend node, if the front node is not a parameter node, return the | |||
| // corresponding cnode. | |||
| KernelWithIndex GetFrontNodeByKernelGraph(const AnfNodePtr &backend_node, KernelGraph *const graph); | |||
| // Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem. | |||
| std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node); | |||
| // Fetch the sub abstract from the top abstract by the index. | |||
| abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index); | |||
| // ControlNodeParser is used to parse control nodes, and get the edges between nodes. | |||
| class ControlNodeParser { | |||
| public: | |||
| @@ -94,6 +98,7 @@ class ControlNodeParser { | |||
| // 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor. | |||
| bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node); | |||
| bool IsRootGraphParameter(const AnfNodePtr &node); | |||
| bool IsRecursionCallNode(const AnfNodePtr &node); | |||
| const std::vector<AnfNodePtr> &control_node_parameters() const { return control_node_parameters_; } | |||
| const FrontToBackendNodeWithContext &front_to_backend_parameters() const { return front_to_backend_parameters_; } | |||
| @@ -117,7 +122,7 @@ class ControlNodeParser { | |||
| // value nodes will not enter the kernel graph, so these nodes need to be saved separately, and space is allocated for | |||
| // them separately during initialization. | |||
| // The interface is initialized by finding the backend node in the kernel graph that the front node finally sends to. | |||
| void FetchFrontValueNode(DeviceContext *default_context); | |||
| void FetchFrontValueNode(const std::vector<AnfNodePtr> &control_nodes, DeviceContext *default_context); | |||
| // Create branch id for all call node in the control flow. | |||
| void CreateBranchIDForCallNode(const std::vector<AnfNodePtr> &control_nodes); | |||
| @@ -138,6 +143,8 @@ class ControlNodeParser { | |||
| const FormalToRealParameter &formal_to_real_parameters, | |||
| std::set<KernelWithIndex> *total_real_parameters, | |||
| std::set<AnfNodePtr> *invalid_real_parameter); | |||
| // Get all the call nodes without a recursion call relation. | |||
| void ParseUnRecursionCallNode(); | |||
| // Parse the device context of the control node. In a heterogeneous scenario, different device contexts need to be | |||
| // copied between different device memories. The analysis steps: | |||
| @@ -179,7 +186,12 @@ class ControlNodeParser { | |||
| void FetchAutoMonadNode(const std::vector<AnfNodePtr> &control_nodes); | |||
| // Fetch the formal parameter in root graph by parameters in subgraph. | |||
| AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node); | |||
| // Get the control nodes which need to add a stack actor for them. | |||
| // When a control node has input that is a call node, you need to add a stack actor for it. | |||
| void FetchNeedStackControlNode(const std::vector<AnfNodePtr> &control_nodes); | |||
| // When the parameter is directly used as the condition of the switch, there will be no back-end node, and a device | |||
| // tensor needs to be created for it. | |||
| void CreateDeviceTensorForRootGraphParameter(DeviceContext *default_context); | |||
| // In control flow, funcgraph will be cut into multiple kernel graphs for execution, and this relationship is recorded | |||
| // in this map. | |||
| FuncGraphToKernelGraph func_graph_to_kernel_graphs_; | |||
| @@ -220,6 +232,12 @@ class ControlNodeParser { | |||
| mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_; | |||
| // Control nodes without a control node input in the topological sorting of funcgraph. | |||
| mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_; | |||
| // Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node | |||
| // belong. | |||
| std::set<AnfNodePtr> unrecursion_call_nodes_; | |||
| // Those control nodes that need to create the corresponding stack actor, when there is a call node in the inputs | |||
| // of the control node, the stack actor is needed to collect these inputs. | |||
| std::set<AnfNodePtr> need_stack_control_nodes_; | |||
| // In heterogeneous scenario, each parameter has its own device context type, so the device context corresponding | |||
| // to the type needs to be parsed in advance so that it can add some copy operation in the scheduler. | |||
| @@ -35,66 +35,41 @@ std::string GetActorName(const AnfNodePtr &node) { | |||
| } | |||
| } | |||
| // Get all the real input of the frontend node, skip the virtual node like maketuple, tuplegetitem. | |||
| std::vector<KernelWithIndex> FetchInputNodeByCNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return {}; | |||
| // Fetch the depend nodes according to the monad node. | |||
| void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set<AnfNodePtr> *depend_nodes) { | |||
| // Find the real input node, include the monad node and make tuple node. | |||
| const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, | |||
| prim::kPrimMakeTuple}; | |||
| const auto &node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_types); | |||
| auto real_node = node_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_node); | |||
| if (!real_node->isa<CNode>()) { | |||
| return; | |||
| } | |||
| std::vector<KernelWithIndex> results; | |||
| // The first input of normal cnode is the primitive of node, and the real input starts from the second input, | |||
| // but in control flow, the call node has no primitive, and the 0th input is funcgraph or partial. | |||
| size_t input_start_pos = kCNodeInputStartPos; | |||
| if (AnfAlgo::IsCallNode(node)) { | |||
| input_start_pos = 0; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| const auto inputs = cnode->inputs(); | |||
| const auto &real_cnode = real_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(real_cnode); | |||
| const auto &real_inputs = real_cnode->inputs(); | |||
| // The first branch of the input of the switch node is the true branch, and the second is the false branch. | |||
| // But in switch actor, since the false value is 0, it corresponds to the first branch. Therefore, the input | |||
| // of the switch node needs to exchange the positions of the two branches. So deal separately. | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch)) { | |||
| if (inputs.size() != kSwitchInputNum) { | |||
| MS_LOG(EXCEPTION) << "Invalid switch node:" << node->DebugString(); | |||
| // Make tuple node needs to be expanded. | |||
| if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) { | |||
| for (size_t i = 1; i < real_inputs.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(real_inputs[i]); | |||
| FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes); | |||
| } | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchCondPos], 0)); | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchFalseBranchPos], 0)); | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(inputs[kSwitchTrueBranchPos], 0)); | |||
| return results; | |||
| return; | |||
| } | |||
| for (size_t i = input_start_pos; i < inputs.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(inputs[i]); | |||
| // skip monad node. | |||
| if (HasAbstractMonad(inputs[i])) { | |||
| continue; | |||
| } | |||
| const auto &node_with_index = | |||
| AnfAlgo::VisitKernelWithReturnType(inputs[i], 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node_with_index.first); | |||
| for (size_t j = 0; j < output_num; ++j) { | |||
| if (AnfAlgo::CheckPrimitiveType(node_with_index.first, prim::kPrimMakeTuple)) { | |||
| const auto &make_tuple_cnode = node_with_index.first->cast<CNodePtr>(); | |||
| const auto &make_tuple_inputs = make_tuple_cnode->inputs(); | |||
| if (make_tuple_inputs.size() <= j + kMakeTupleInputStartPos) { | |||
| MS_LOG(EXCEPTION) << "Invalid input:" << j + kMakeTupleInputStartPos | |||
| << " for make tuple node:" << make_tuple_cnode->DebugString(); | |||
| } | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(make_tuple_inputs[j + kMakeTupleInputStartPos], 0)); | |||
| } else if (node_with_index.first->isa<ValueNode>()) { | |||
| // When the value node is a value tuple, the value node will have multiple outputs, which need to be directly | |||
| // put into the vector, and the output cannot be obtained through the VisitKernelWithReturnType interface. | |||
| results.emplace_back(node_with_index.first, j); | |||
| } else { | |||
| results.emplace_back(AnfAlgo::VisitKernelWithReturnType(node_with_index.first, j)); | |||
| } | |||
| if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) || | |||
| AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) { | |||
| FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes); | |||
| } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) { | |||
| for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) { | |||
| FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes); | |||
| } | |||
| } else { | |||
| depend_nodes->emplace(real_node); | |||
| } | |||
| return results; | |||
| } | |||
| } // namespace | |||
| @@ -283,6 +258,8 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| // Create a corresponding stack actor for each kernel graph that has a call node as input. | |||
| for (const auto &graph_with_context : parser->call_input_kernel_graphs_) { | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| size_t input_parameter_data_num = 0; | |||
| const auto &graph = graph_with_context.first; | |||
| const auto &device_context = graph_with_context.second; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| @@ -303,10 +280,12 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| } | |||
| // If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end. | |||
| if (AnfAlgo::IsCallNode(front_node_with_index.first)) { | |||
| if (AnfAlgo::IsCallNode(front_node_with_index.first) && | |||
| (parser->IsRecursionCallNode(front_node_with_index.first))) { | |||
| formal_parameters.emplace_back(front_node_with_index); | |||
| } else { | |||
| formal_parameters.insert(formal_parameters.begin(), front_node_with_index); | |||
| input_parameter_data_num++; | |||
| } | |||
| } | |||
| @@ -315,11 +294,79 @@ std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphComp | |||
| stack_actors.emplace_back(stack_actor); | |||
| stack_actor->device_contexts_.insert(stack_actor->device_contexts_.begin(), formal_parameters.size(), | |||
| device_context); | |||
| stack_actor->input_parameter_data_num_ = input_parameter_data_num; | |||
| InsertActor(stack_actor.get()); | |||
| } | |||
| // Create stack actors for control nodes. | |||
| BuildStackActorForControlNode(graph_compiler_info, &stack_actors); | |||
| return stack_actors; | |||
| } | |||
| void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info, | |||
| std::vector<StackActorPtr> *stack_actors) { | |||
| const auto &parser = graph_compiler_info.control_node_parser_; | |||
| MS_EXCEPTION_IF_NULL(parser); | |||
| for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) { | |||
| MS_EXCEPTION_IF_NULL(need_stack_control_node); | |||
| if (!AnfAlgo::IsCallNode(need_stack_control_node)) { | |||
| continue; | |||
| } | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| size_t input_parameter_data_num = 0; | |||
| size_t input_parameter_partials_num = 0; | |||
| // Fetch the control actor of control node. | |||
| auto gather_actor_name = GetActorName(need_stack_control_node); | |||
| auto actor = FetchActor(gather_actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto gather_actor = dynamic_cast<GatherActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(gather_actor); | |||
| if (gather_actor->formal_parameters_.size() > gather_actor->device_contexts_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid device context size:" << gather_actor->device_contexts_.size() | |||
| << " and formal parameter size:" << gather_actor->formal_parameters_.size() | |||
| << " for actor:" << gather_actor->GetAID(); | |||
| } | |||
| // Collect formal parameters and device contexts, skip the value nodes. | |||
| for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) { | |||
| const auto ¶meter = gather_actor->formal_parameters_[i]; | |||
| auto device_context = gather_actor->device_contexts_[i]; | |||
| if (AnfAlgo::IsCallNode(parameter.first) && (parser->IsRecursionCallNode(parameter.first))) { | |||
| formal_parameters.emplace_back(parameter); | |||
| device_contexts.emplace_back(device_context); | |||
| } else if (parameter.first->isa<ValueNode>()) { | |||
| continue; | |||
| } else { | |||
| formal_parameters.insert(formal_parameters.begin(), parameter); | |||
| device_contexts.insert(device_contexts.begin(), device_context); | |||
| const auto &abstract = parameter.first->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| const auto &real_abstract = FetchAbstractByIndex(abstract, parameter.second); | |||
| MS_EXCEPTION_IF_NULL(real_abstract); | |||
| if (real_abstract->isa<abstract::AbstractFunction>()) { | |||
| input_parameter_partials_num++; | |||
| } else { | |||
| input_parameter_data_num++; | |||
| } | |||
| } | |||
| } | |||
| // Create stack actor. | |||
| const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix; | |||
| const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, formal_parameters); | |||
| stack_actor->device_contexts_ = device_contexts; | |||
| stack_actor->input_parameter_data_num_ = input_parameter_data_num; | |||
| stack_actor->input_parameter_partial_num_ = input_parameter_partials_num; | |||
| InsertActor(stack_actor.get()); | |||
| stack_actors->emplace_back(stack_actor); | |||
| } | |||
| } | |||
| void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) { | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| MS_EXCEPTION_IF_NULL(actor_set->control_actors_); | |||
| @@ -362,11 +409,18 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr | |||
| } | |||
| for (auto &gather_actor : control_actor_set->gather_actors_) { | |||
| for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) { | |||
| LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i}, | |||
| parser); | |||
| MS_EXCEPTION_IF_NULL(gather_actor->node_); | |||
| if (parser->need_stack_control_nodes_.find(gather_actor->node_) == parser->need_stack_control_nodes_.end()) { | |||
| for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) { | |||
| LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i}, | |||
| parser); | |||
| } | |||
| } else { | |||
| // If the control actor has a corresponding stack actor, the input should be linked to the stack actor. | |||
| LinkArrowFromStackActor(gather_actor.get()); | |||
| } | |||
| } | |||
| for (auto &entrance_actor : control_actor_set->entrance_actors_) { | |||
| for (const auto &call_node : entrance_actor->call_nodes_) { | |||
| LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, parser); | |||
| @@ -387,6 +441,38 @@ void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const contr | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkArrowFromStackActor(ControlActor *to_actor) { | |||
| MS_EXCEPTION_IF_NULL(to_actor->node_); | |||
| auto stack_actor_name = GetActorName(to_actor->node_) + kStackActorNameSuffix; | |||
| auto actor = FetchActor(stack_actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto stack_actor = dynamic_cast<StackActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(stack_actor); | |||
| for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) { | |||
| const auto &formal_parameter = to_actor->formal_parameters_[to_index]; | |||
| const auto &from_node = formal_parameter.first; | |||
| if (from_node->isa<ValueNode>()) { | |||
| LinkArrowByValueNode(from_node, to_actor, formal_parameter.second, to_index); | |||
| continue; | |||
| } | |||
| // Fetch the arrow type of input. | |||
| size_t from_index = stack_actor->FetchNodePosition(formal_parameter); | |||
| const auto &abstract = formal_parameter.first->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| const auto &real_abstract = FetchAbstractByIndex(abstract, formal_parameter.second); | |||
| MS_EXCEPTION_IF_NULL(real_abstract); | |||
| // Link arrow according to abstract. | |||
| if (real_abstract->isa<abstract::AbstractFunction>()) { | |||
| LinkPartialArrow(stack_actor, to_actor, from_index, to_index); | |||
| } else { | |||
| LinkDataArrow(stack_actor, to_actor, from_index, to_index); | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor, | |||
| const KernelWithIndex &from_node_with_index, | |||
| const KernelWithIndex &to_node_with_index, | |||
| @@ -394,20 +480,7 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act | |||
| const auto &from_node = from_node_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(from_node); | |||
| if (from_node->isa<ValueNode>()) { | |||
| if (IsValueNode<FuncGraph>(from_node)) { | |||
| // Link local partial. | |||
| const auto &func_graph = GetValueNode<FuncGraphPtr>(from_node); | |||
| to_actor->local_partials_[to_node_with_index.second] = OpPartial(func_graph.get(), {}); | |||
| } else { | |||
| // Link device store value node. | |||
| if (!AnfAlgo::OutputAddrExist(from_node, from_node_with_index.second)) { | |||
| MS_LOG(EXCEPTION) << "Invalid output address index:" << from_node_with_index.second | |||
| << " for value node:" << from_node->DebugString(); | |||
| } | |||
| to_actor->local_device_tensors_[to_node_with_index.second] = | |||
| AnfAlgo::GetMutableOutputAddr(from_node, from_node_with_index.second, false).get(); | |||
| to_actor->local_device_tensors_[to_node_with_index.second]->SetNodeIndex(from_node, from_node_with_index.second); | |||
| } | |||
| LinkArrowByValueNode(from_node, to_actor, from_node_with_index.second, to_node_with_index.second); | |||
| } else if (from_node->isa<Parameter>()) { | |||
| LinkArrowByParameter(from_node, to_actor, from_node_with_index, to_node_with_index, parser); | |||
| } else if (AnfAlgo::IsCallNode(from_node)) { | |||
| @@ -436,6 +509,26 @@ void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_act | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, | |||
| size_t from_index, size_t to_index) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| if (IsValueNode<FuncGraph>(value_node)) { | |||
| // Link local partial. | |||
| const auto &func_graph = GetValueNode<FuncGraphPtr>(value_node); | |||
| to_actor->local_partials_[to_index] = OpPartial(func_graph.get(), {}); | |||
| } else { | |||
| // Link device store value node. | |||
| if (!AnfAlgo::OutputAddrExist(value_node, from_index)) { | |||
| MS_LOG(EXCEPTION) << "Invalid output address index:" << from_index | |||
| << " for value node:" << value_node->DebugString(); | |||
| } | |||
| to_actor->local_device_tensors_[to_index] = AnfAlgo::GetMutableOutputAddr(value_node, from_index, false).get(); | |||
| to_actor->local_device_tensors_[to_index]->SetNodeIndex(value_node, from_index); | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr ¶meter, ControlActor *const to_actor, | |||
| const KernelWithIndex &from_node_with_index, | |||
| const KernelWithIndex &to_node_with_index, | |||
| @@ -467,6 +560,11 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont | |||
| if (to_actor->type_ != KernelTransformType::kEntranceActor) { | |||
| // Link arrow from exit actor to control actor. | |||
| const auto &abstract = call_node->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| const auto &real_abstract = FetchAbstractByIndex(abstract, from_node_with_index.second); | |||
| MS_EXCEPTION_IF_NULL(real_abstract); | |||
| const auto &func_graphs = AnfAlgo::GetFuncGraphbyCallNode(from_node); | |||
| for (const auto &func_graph : func_graphs) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -475,10 +573,19 @@ void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, Cont | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto exit_actor = dynamic_cast<ExitActor *>(actor); | |||
| size_t branch_id = parser->FetchBranchIDByCallNode(from_node); | |||
| LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second, | |||
| branch_id); | |||
| if (real_abstract->isa<abstract::AbstractFunction>()) { | |||
| LinkPartialArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second, | |||
| branch_id); | |||
| } else { | |||
| LinkDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second, to_node_with_index.second, | |||
| branch_id); | |||
| } | |||
| } | |||
| if (abstract->isa<abstract::AbstractFunction>()) { | |||
| to_actor->input_partials_num_++; | |||
| } else { | |||
| to_actor->input_datas_num_++; | |||
| } | |||
| to_actor->input_datas_num_++; | |||
| } else { | |||
| // Link arrow from gather actor to entrance actor. | |||
| const auto &actor_name = GetActorName(from_node); | |||
| @@ -604,6 +711,34 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor | |||
| LinkControlArrow(entrance_actor, control_actor); | |||
| } | |||
| } | |||
| // Link auto monad control arrow for control actor. | |||
| std::vector<ControlActor *> control_actors; | |||
| (void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(), | |||
| std::back_inserter(control_actors), [](auto &switch_actor) { return switch_actor.get(); }); | |||
| (void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(), | |||
| std::back_inserter(control_actors), [](auto &gather_actor) { return gather_actor.get(); }); | |||
| (void)std::transform(control_actor_set->exit_actors_.begin(), control_actor_set->exit_actors_.end(), | |||
| std::back_inserter(control_actors), [](auto &exit_actor) { return exit_actor.get(); }); | |||
| for (auto control_actor : control_actors) { | |||
| MS_EXCEPTION_IF_NULL(control_actor); | |||
| const auto &node = control_actor->node_; | |||
| if (node == nullptr) { | |||
| return; | |||
| } | |||
| const auto &cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| for (const auto &input : inputs) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimUpdateState) || | |||
| AnfAlgo::CheckPrimitiveType(input, prim::kPrimDepend) || | |||
| AnfAlgo::CheckPrimitiveType(input, prim::kPrimLoad)) { | |||
| LinkControlArrowByAutoMonad(control_actor, input, parser); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set, | |||
| @@ -613,6 +748,13 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_ | |||
| // Link control arrow from entrance actors or stack actors to no input kernel actors. | |||
| for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { | |||
| // In control flow, when the input of the kernel actor is a parameter, this input needs to be linked to the | |||
| // control actor, so the no-input kernel actor collected in the graph scheduler will also collect this actor, | |||
| // and it needs to be skipped here. | |||
| if ((no_input_kernel_actor->input_datas_num_ != 0) || (no_input_kernel_actor->input_controls_num_ != 0)) { | |||
| continue; | |||
| } | |||
| KernelGraphPtr kernel_graph = nullptr; | |||
| if (no_input_kernel_actor->type_ == KernelTransformType::kSuperKernelActor) { | |||
| const auto &super_kernel_actor = dynamic_cast<SuperKernelActor *>(no_input_kernel_actor.get()); | |||
| @@ -653,6 +795,46 @@ void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_ | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node, | |||
| const ControlNodeParserPtr &parser) { | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| MS_EXCEPTION_IF_NULL(from_node); | |||
| std::set<AnfNodePtr> depend_nodes; | |||
| FetchRealDependNodeByAutoMonad(from_node, &depend_nodes); | |||
| for (const auto &depend_node : depend_nodes) { | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| auto from_actor = FetchActor(depend_node->DebugString()); | |||
| if (AnfAlgo::IsCallNode(depend_node)) { | |||
| int branch_id = parser->FetchBranchIDByCallNode(depend_node); | |||
| const auto &func_graphs = parser->FetchFuncGraphbyCallNode(depend_node); | |||
| if (func_graphs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Failed to get funcgraph by call node:" << depend_node->DebugString(); | |||
| } | |||
| for (const auto func_graph : func_graphs) { | |||
| auto exit_actor_name = func_graph->ToString() + kExitActorNameSuffix; | |||
| auto actor = FetchActor(exit_actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto exit_actor = dynamic_cast<ExitActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(exit_actor); | |||
| LinkControlArrowForExitActor(exit_actor, to_actor, branch_id); | |||
| } | |||
| to_actor->input_controls_num_ -= (func_graphs.size() - 1); | |||
| } else if (from_actor != nullptr) { | |||
| LinkControlArrow(from_actor, to_actor); | |||
| } else { | |||
| auto graph = parser->FetchKernelGraphByFrontNode(depend_node); | |||
| if (graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to find actor for node:" << depend_node->DebugString(); | |||
| } | |||
| from_actor = FetchActor(graph->ToString() + kExitActorNameSuffix); | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| LinkControlArrow(from_actor, to_actor); | |||
| } | |||
| } | |||
| } | |||
| void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set) { | |||
| MS_EXCEPTION_IF_NULL(control_actor_set); | |||
| @@ -871,6 +1053,15 @@ void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor | |||
| (void)to_actor->input_data_arrow_aids_.emplace_back(exit_actor->GetAID()); | |||
| } | |||
| void ControlNodeScheduler::LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, | |||
| size_t from_index, size_t to_index, int branch_id) { | |||
| MS_EXCEPTION_IF_NULL(exit_actor); | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| auto partial_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index); | |||
| (void)exit_actor->output_branch_partial_arrows_[branch_id].emplace_back(partial_arrow); | |||
| (void)to_actor->input_partial_arrow_aids_.emplace_back(exit_actor->GetAID()); | |||
| } | |||
| void ControlNodeScheduler::LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id) { | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| MS_EXCEPTION_IF_NULL(to_actor); | |||
| @@ -50,7 +50,8 @@ class ControlNodeScheduler { | |||
| std::vector<EntranceActorPtr> BuildEntranceActor(const GraphCompilerInfo &graph_compiler_info); | |||
| std::vector<ExitActorPtr> BuildExitActor(const GraphCompilerInfo &graph_compiler_info); | |||
| std::vector<StackActorPtr> BuildStackActor(const GraphCompilerInfo &graph_compiler_info); | |||
| void BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info, | |||
| std::vector<StackActorPtr> *stack_actors); | |||
| // Interface to link control actors. | |||
| void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); | |||
| void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set); | |||
| @@ -67,6 +68,10 @@ class ControlNodeScheduler { | |||
| void LinkArrowByParameter(const AnfNodePtr ¶meter, ControlActor *const to_actor, | |||
| const KernelWithIndex &from_node_with_index, const KernelWithIndex &to_node_with_index, | |||
| const ControlNodeParserPtr &parser); | |||
| void LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor, size_t from_index, | |||
| size_t to_index); | |||
| // Link arrow from stack actor to control actor. | |||
| void LinkArrowFromStackActor(ControlActor *to_actor); | |||
| // Link data arrow between control actor and actor in frame, including kernel actor, output actor, data source actor. | |||
| void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info); | |||
| @@ -75,6 +80,9 @@ class ControlNodeScheduler { | |||
| void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); | |||
| void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info); | |||
| void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); | |||
| void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node, | |||
| const ControlNodeParserPtr &parser); | |||
| // Interface tool to link arrows between actors. | |||
| void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor); | |||
| // Data arrow with branch id is only exists from gather actor to entrance actor. | |||
| @@ -90,6 +98,8 @@ class ControlNodeScheduler { | |||
| void LinkControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id); | |||
| void LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index, | |||
| size_t to_index, int branch_id); | |||
| void LinkPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor, size_t from_index, | |||
| size_t to_index, int branch_id); | |||
| bool IsNoInputActor(const ControlActor *control_actor); | |||
| }; | |||
| } // namespace runtime | |||