| @@ -104,7 +104,7 @@ class ControlActor : public AbstractActor { | |||
| // The branch id is the unique identifier of the control actor. In the control flow, there are multiple control | |||
| // actors calling the same subgraph at the same time. At this time, the output of the subgraph needs to be returned | |||
| // to the calling place according to the branch id. | |||
| int output_branch_id_; | |||
| int output_branch_id_{0}; | |||
| // Partial data in local. When partial is only funcgraph without real parameter, it is stored inside the actor. | |||
| std::unordered_map<size_t, OpPartial> local_partials_; | |||
| @@ -79,14 +79,18 @@ void ExitActor::CopyDeviceAddress() { | |||
| if (node_ != nullptr) { | |||
| return; | |||
| } | |||
| if (input_device_tensors_.size() != is_need_copy_device_tensors_.size()) { | |||
| MS_LOG(ERROR) << "Invalid input device tensor size:" << input_device_tensors_.size() | |||
| << " need:" << is_need_copy_device_tensors_.size() << " for actor:" << GetAID(); | |||
| } | |||
| std::vector<DeviceTensor *> new_device_tensors; | |||
| for (size_t i = 0; i < input_device_tensors_.size(); ++i) { | |||
| auto input_device_tensor = input_device_tensors_[i]; | |||
| MS_EXCEPTION_IF_NULL(input_device_tensor); | |||
| const KernelWithIndex &node_with_index = input_device_tensor->GetNodeIndex(); | |||
| MS_EXCEPTION_IF_NULL(node_with_index.first); | |||
| if (device_contexts_[i] == nullptr) { | |||
| // If device context is empty, it means that the input is from a parameter, so there is no need to copy a new device tensor. | |||
| if (!is_need_copy_device_tensors_[i]) { | |||
| new_device_tensors.emplace_back(input_device_tensor); | |||
| continue; | |||
| } | |||
| @@ -68,7 +68,9 @@ class ExitActor : public ControlActor { | |||
| // The exit actor needs to create a new device address and take out the ptr from the device tensor coming from | |||
| // the kernel actor. These newly created device tensors are stored in the created device tensors. | |||
| std::vector<DeviceTensorPtr> created_device_tensors_; | |||
| // In exit actor, we need to copy a new device tensor for the output of the kernel actor, but parameter is not | |||
| // needed. This mark is used to record whether it needs to be copied. | |||
| std::vector<bool> is_need_copy_device_tensors_; | |||
| // Output data. | |||
| // The output branch data corresponds to the output_data_arrows_ one by one. | |||
| std::unordered_map<int, std::vector<std::pair<size_t, OpDataUniquePtr<DeviceTensor>>>> output_branch_data_; | |||
| @@ -47,7 +47,7 @@ void StackActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Dev | |||
| MS_EXCEPTION_IF_NULL(input_data->data_); | |||
| auto &sequential_num = context->sequential_num_; | |||
| // The parameters from the inside of the subgraph need to be put into the stack. | |||
| if (IntToSize(input_data->index_) < input_parameter_data_num_) { | |||
| if (IntToSize(input_data->index_) < input_parameter_data_num_ + device_tensor_store_keys_.size()) { | |||
| input_parameter_data_[sequential_num][input_data->index_].push(input_data->data_); | |||
| } else { | |||
| // The outputs of call nodes are placed directly in the input data. | |||
| @@ -94,7 +94,7 @@ void StackActor::FetchInput(OpContext<DeviceTensor> *const context) { | |||
| MS_LOG(ERROR) << "Invalid input for actor:" << GetAID(); | |||
| } | |||
| for (const auto &one_stack : data_iter->second) { | |||
| if (one_stack.first >= input_parameter_data_num_) { | |||
| if (one_stack.first >= input_parameter_data_num_ + device_tensor_store_keys_.size()) { | |||
| MS_LOG(ERROR) << "Invalid input index:" << one_stack.first << " need:" << input_parameter_data_num_ | |||
| << " for actor:" << GetAID(); | |||
| } | |||
| @@ -398,6 +398,23 @@ void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &nod | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } | |||
| } else if (node_value->isa<BoolImm>()) { | |||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false); | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| if (device_tensor->GetPtr() != nullptr) { | |||
| return; | |||
| } | |||
| if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) { | |||
| SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *device_context, node->fullname_with_scope(), | |||
| device_tensor->GetSize()); | |||
| } | |||
| auto value = GetValue<bool>(node_value); | |||
| if (!device_tensor->SyncHostToDevice({}, sizeof(bool), kNumberTypeBool, &value)) { | |||
| std::string error_info = "SyncHostToDevice failed, node name: " + node->DebugString(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } | |||
| } | |||
| @@ -206,7 +206,7 @@ void CreateDeviceTensorForValueNode(const KernelWithIndex &front_node_with_index | |||
| MS_EXCEPTION_IF_NULL(front_node); | |||
| const auto &node_value = front_node->cast<ValueNodePtr>()->value(); | |||
| if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>())) { | |||
| if ((!node_value->isa<tensor::Tensor>()) && (!node_value->isa<ValueTuple>()) && (!node_value->isa<BoolImm>())) { | |||
| return; | |||
| } | |||
| @@ -315,16 +315,16 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| ParseFrontToBackendParameter(graphs, device_contexts); | |||
| FetchFrontToBackendKernel(graphs, device_contexts); | |||
| FetchHostParameterToWeight(); | |||
| FetchCallInputKernelGraph(graphs, device_contexts); | |||
| FetchFrontValueNode(device_contexts[0]); | |||
| FetchFrontToBackendKernel(graphs, device_contexts); | |||
| ParseDeviceContext(control_nodes, graphs, device_contexts, func_graph_to_kernel_graphs); | |||
| FetchFrontValueNode(device_contexts[0]); | |||
| FetchControlNodeParameter(control_nodes); | |||
| FetchAutoMonadNode(control_nodes); | |||
| @@ -447,6 +447,9 @@ void ControlNodeParser::ParseDeviceContextForPartialNode(const std::vector<AnfNo | |||
| // Get the device contexts for the real parameters. | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| // In partial node, the first input is always a partial, maybe a funcgraph or a partial node, so we need | |||
| // to insert an empty device context for it. | |||
| device_contexts.emplace_back(nullptr); | |||
| for (size_t i = 0; i < inputs.size() - kPartialInputStartPos; ++i) { | |||
| if (i >= iter->second.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid device context index:" << i << " for funcgraph:" << func_graph->ToString() | |||
| @@ -480,6 +483,9 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP | |||
| } | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| // In call node, the first input is always a partial, maybe a funcgraph or a partial node, so we need | |||
| // to insert an empty device context for it. | |||
| device_contexts.emplace_back(nullptr); | |||
| const auto &cnode = control_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const auto &inputs = cnode->inputs(); | |||
| @@ -503,6 +509,7 @@ void ControlNodeParser::ParseDeviceContextForCallNode(const std::vector<AnfNodeP | |||
| } | |||
| void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *default_context) { | |||
| MS_EXCEPTION_IF_NULL(default_context); | |||
| // Collect the call relationship between funcgraphs. | |||
| FuncGraphCallRelation func_graph_call_relation; | |||
| for (const auto &call_node_to_func_graphs : call_node_to_func_graphs_) { | |||
| @@ -62,7 +62,7 @@ using FrontToBackendNodeWithContext = std::map<KernelWithIndex, std::set<std::pa | |||
| using FrontToBackendKernelWithContext = std::map<KernelWithIndex, std::pair<KernelWithIndex, DeviceContext *>>; | |||
| using FuncGraphToKernelGraph = std::unordered_map<FuncGraphPtr, std::vector<KernelGraphPtr>>; | |||
| using HostParameterToWeight = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>; | |||
| using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, DeviceContext *>>; | |||
| using NodeWithDeviceContext = std::set<std::pair<KernelWithIndex, const DeviceContext *>>; | |||
| using RealToFormalNode = std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>>; | |||
| using FormalToRealParameter = std::unordered_map<AnfNodePtr, std::set<KernelWithIndex>>; | |||
| using RealToFormalParameter = std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>>; | |||
| @@ -226,57 +226,52 @@ std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompil | |||
| } | |||
| } | |||
| if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid graphs num:" << graph_compiler_info.graphs_.size() | |||
| << " and contexts num:" << graph_compiler_info.device_contexts_.size(); | |||
| } | |||
| // 2. Replace the device address in the kernel actor when calling funcgraph, that is to say in the data exchange | |||
| // between kernel graph and the control node, in fact, it is the output of the kernel graph. | |||
| for (const auto func_graph_to_kernel_graphs : parser->func_graph_to_kernel_graphs_) { | |||
| for (const auto &kernel_graph : func_graph_to_kernel_graphs.second) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // If the graph does not have kernel, it means there is no internal calculation in it, the output is parameter, | |||
| // and no exit actor is needed. | |||
| if (kernel_graph->execution_order().empty()) { | |||
| continue; | |||
| } | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| const auto &graph_outputs = kernel_graph->graph_output_map(); | |||
| std::vector<const DeviceContext *> device_contexts; | |||
| for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) { | |||
| const auto &kernel_graph = graph_compiler_info.graphs_[i]; | |||
| const auto &device_context = graph_compiler_info.device_contexts_[i]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| if (kernel_graph->execution_order().empty()) { | |||
| continue; | |||
| } | |||
| for (const auto &backend_to_front : graph_outputs) { | |||
| if (HasAbstractMonad(backend_to_front.second.first)) { | |||
| continue; | |||
| } | |||
| // Collect inputs of exit actor. | |||
| formal_parameters.emplace_back(backend_to_front.second); | |||
| // Get the device contexts of the exit actor's cnode inputs. | |||
| const AnfNodePtr &backend_node = backend_to_front.first.first; | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| if ((!backend_node->isa<CNode>())) { | |||
| device_contexts.emplace_back(nullptr); | |||
| continue; | |||
| } | |||
| std::vector<bool> is_need_copy_device_tensors; | |||
| std::vector<KernelWithIndex> formal_parameters; | |||
| const auto &graph_outputs = kernel_graph->graph_output_map(); | |||
| const auto &actor_name = backend_node->fullname_with_scope(); | |||
| const auto &actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| const auto &kernel_actor = dynamic_cast<KernelActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| if (kernel_actor->device_contexts_.empty() || kernel_actor->device_contexts_[0] == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to get device context for kernel:" << backend_node->DebugString(); | |||
| } | |||
| device_contexts.emplace_back(kernel_actor->device_contexts_[0]); | |||
| for (const auto &backend_to_front : graph_outputs) { | |||
| if (HasAbstractMonad(backend_to_front.second.first)) { | |||
| continue; | |||
| } | |||
| const auto &actor_name = kernel_graph->ToString() + kExitActorNameSuffix; | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr); | |||
| exit_actors.emplace_back(exit_actor); | |||
| if (exit_actor->device_contexts_.size() != device_contexts.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid device context size:" << device_contexts.size() | |||
| << " need:" << exit_actor->device_contexts_.size() << " for actor:" << exit_actor->GetAID(); | |||
| // Collect inputs of exit actor. | |||
| formal_parameters.emplace_back(backend_to_front.second); | |||
| // Get the device contexts of the exit actor's cnode inputs. | |||
| const AnfNodePtr &backend_node = backend_to_front.first.first; | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| if ((!backend_node->isa<CNode>())) { | |||
| is_need_copy_device_tensors.emplace_back(false); | |||
| continue; | |||
| } | |||
| exit_actor->device_contexts_.swap(device_contexts); | |||
| InsertActor(exit_actor.get()); | |||
| is_need_copy_device_tensors.emplace_back(true); | |||
| } | |||
| const auto &actor_name = kernel_graph->ToString() + kExitActorNameSuffix; | |||
| const auto &exit_actor = std::make_shared<ExitActor>(actor_name, formal_parameters, nullptr); | |||
| exit_actor->is_need_copy_device_tensors_ = is_need_copy_device_tensors; | |||
| std::vector<const DeviceContext *> device_contexts(formal_parameters.size(), device_context); | |||
| exit_actor->device_contexts_.swap(device_contexts); | |||
| exit_actors.emplace_back(exit_actor); | |||
| InsertActor(exit_actor.get()); | |||
| } | |||
| return exit_actors; | |||
| } | |||