| @@ -322,8 +322,11 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An | |||
| } | |||
| const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimMakeTuple}; | |||
| // The output may be the tuple, so need visit all the outputs of node. | |||
| auto outputs_num = AnfAlgo::GetOutputTensorNum(node); | |||
| size_t outputs_num = 1; | |||
| if (IsRealCNodeKernel(node)) { | |||
| outputs_num = AnfAlgo::GetOutputTensorNum(node); | |||
| } | |||
| // The output may be the tuple of node, so need visit all the outputs of node. | |||
| for (size_t i = 0; i < outputs_num; ++i) { | |||
| const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(node, i, false, return_types); | |||
| MS_EXCEPTION_IF_NULL(output_with_index.first); | |||
| @@ -351,6 +354,8 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An | |||
| return ret_empty; | |||
| } | |||
| MS_LOG(INFO) << "Output node: " << output_with_index.first->fullname_with_scope() | |||
| << " with output index: " << output_with_index.second; | |||
| ret.push_back(output_with_index); | |||
| } | |||
| @@ -1925,7 +1925,6 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: | |||
| auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr { | |||
| auto backend_anf = graph->GetBackendAnfByFrontAnf(out); | |||
| if (backend_anf != nullptr) { | |||
| graph->CacheGraphOutputToFrontNodeWithIndex(backend_anf, out); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| @@ -262,6 +262,15 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt | |||
| // Generate kernel graph. | |||
| KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // Cache the backend graph output nodes to front nodes with output index. | |||
| for (auto &output : outputs) { | |||
| auto backend_node = graph->GetBackendAnfByFrontAnf(output); | |||
| if (backend_node != nullptr) { | |||
| graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output); | |||
| } | |||
| } | |||
| return CompileGraphImpl(graph, device_context); | |||
| } | |||
| @@ -1421,7 +1421,12 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| ++number; | |||
| auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output()); | |||
| for (const auto &output_with_index : outputs) { | |||
| std::set<std::pair<int, std::vector<size_t>>> unique_output_positions; | |||
| std::set<KernelWithIndex> unique_outputs; | |||
| for (const auto &output : outputs) { | |||
| unique_outputs.insert(output); | |||
| } | |||
| for (const auto &output_with_index : unique_outputs) { | |||
| MS_EXCEPTION_IF_NULL(output_with_index.first); | |||
| auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph); | |||
| const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index); | |||
| @@ -1429,54 +1434,62 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, | |||
| continue; | |||
| } | |||
| to_actor->device_contexts_[iter->second.second] = graph_compiler_info.device_contexts_[number - 1]; | |||
| // The device tensor of graph out need be taken over by host tensor, so set the max reference count. | |||
| UpdateRefCount(output_with_index.first, output_with_index.second, true); | |||
| // The graph output is from device tensor store. | |||
| if (IsPersistentDeviceTensor(output_with_index.first)) { | |||
| to_actor->device_tensor_store_keys_[iter->second.first].emplace_back(iter->second.second, | |||
| output_with_index.first); | |||
| // Skip duplicate position. | |||
| if (unique_output_positions.count(iter->second) > 0) { | |||
| continue; | |||
| } | |||
| unique_output_positions.insert(iter->second); | |||
| for (auto &output_position : iter->second.second) { | |||
| to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[number - 1]; | |||
| // The device tensor of graph out need be taken over by host tensor, so set the max reference count. | |||
| UpdateRefCount(output_with_index.first, output_with_index.second, true); | |||
| // The graph output is from device tensor store. | |||
| if (IsPersistentDeviceTensor(output_with_index.first)) { | |||
| to_actor->device_tensor_store_keys_[iter->second.first].emplace_back(output_position, | |||
| output_with_index.first); | |||
| continue; | |||
| } | |||
| // The graph output is from kernel actor. | |||
| if (IsKernelActor(output_with_index.first)) { | |||
| const auto &from_actor = | |||
| dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope())); | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), iter->second.second); | |||
| from_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| continue; | |||
| } | |||
| // The graph output is from kernel actor. | |||
| if (IsKernelActor(output_with_index.first)) { | |||
| const auto &from_actor = | |||
| dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope())); | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position); | |||
| from_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| continue; | |||
| } | |||
| // The graph output is from data source actor. | |||
| std::string actor_name; | |||
| DataSourceActor *from_actor = nullptr; | |||
| size_t from_actor_output_index = 0; | |||
| if (IsHostQueueDSActor(output_with_index.first, graph, nullptr, graph_compiler_info.origin_parameters_order_)) { | |||
| actor_name = graph_compiler_info.name_ + "_HostDSActor"; | |||
| const auto &host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name)); | |||
| from_actor_output_index = host_queue_ds_actor->FetchDataNodePosition(output_with_index.first); | |||
| UpdateRefCount(host_queue_ds_actor->data_nodes_[from_actor_output_index], output_with_index.second, true); | |||
| from_actor = static_cast<DataSourceActor *>(host_queue_ds_actor); | |||
| } else if (IsDeviceQueueDSActor(output_with_index.first)) { | |||
| actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); | |||
| from_actor = dynamic_cast<DataSourceActor *>(FetchActor(actor_name)); | |||
| from_actor_output_index = output_with_index.second; | |||
| } | |||
| // The graph output is from data source actor. | |||
| std::string actor_name; | |||
| DataSourceActor *from_actor = nullptr; | |||
| size_t from_actor_output_index = 0; | |||
| if (IsHostQueueDSActor(output_with_index.first, graph, nullptr, graph_compiler_info.origin_parameters_order_)) { | |||
| actor_name = graph_compiler_info.name_ + "_HostDSActor"; | |||
| const auto &host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name)); | |||
| from_actor_output_index = host_queue_ds_actor->FetchDataNodePosition(output_with_index.first); | |||
| UpdateRefCount(host_queue_ds_actor->data_nodes_[from_actor_output_index], output_with_index.second, true); | |||
| from_actor = static_cast<DataSourceActor *>(host_queue_ds_actor); | |||
| } else if (IsDeviceQueueDSActor(output_with_index.first)) { | |||
| actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); | |||
| from_actor = dynamic_cast<DataSourceActor *>(FetchActor(actor_name)); | |||
| from_actor_output_index = output_with_index.second; | |||
| } | |||
| // When the input is a parameter node, it should be connected by gather actor. | |||
| if (from_actor == nullptr) { | |||
| if (output_with_index.first->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find kernel actor for kernel:" << output_with_index.first->fullname_with_scope(); | |||
| } else { | |||
| continue; | |||
| // When the input is a parameter node, it should be connected by gather actor. | |||
| if (from_actor == nullptr) { | |||
| if (output_with_index.first->isa<CNode>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find kernel actor for kernel:" | |||
| << output_with_index.first->fullname_with_scope(); | |||
| } else { | |||
| continue; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), output_position); | |||
| from_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(from_actor); | |||
| auto op_arrow = std::make_shared<DataArrow>(from_actor_output_index, to_actor->GetAID(), iter->second.second); | |||
| from_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| } | |||
| } | |||
| } | |||
| @@ -1852,39 +1865,41 @@ void GraphScheduler::LinkOutputResultArrowForGatherActor(const GraphCompilerInfo | |||
| if (iter == graph_compiler_info.origin_outputs_order_.end()) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "Link output node:" << AnfAlgo::GetNodeDebugString(origin_output_with_index.first) | |||
| << " branch id:" << iter->second.first << " index:" << iter->second.second | |||
| << " for gather actor:" << gather_actor->GetAID(); | |||
| auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), iter->second.second); | |||
| gather_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node]; | |||
| if (backend_nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "No backend node for data node:" << AnfAlgo::GetNodeDebugString(front_node); | |||
| } | |||
| const auto &backend_node = backend_nodes[0].first; | |||
| if (backend_node->isa<Parameter>()) { | |||
| std::string actor_name = graph_compiler_info.name_ + "_HostDSActor"; | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(host_ds_actor); | |||
| for (auto &output_position : iter->second.second) { | |||
| MS_LOG(INFO) << "Link output node:" << AnfAlgo::GetNodeDebugString(origin_output_with_index.first) | |||
| << " branch id:" << iter->second.first << " index:" << output_position | |||
| << " for gather actor:" << gather_actor->GetAID(); | |||
| const auto &data_nodes = host_ds_actor->data_nodes_; | |||
| const auto &node_iter = find(data_nodes.begin(), data_nodes.end(), backend_node); | |||
| if (node_iter == data_nodes.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:" | |||
| << AnfAlgo::GetNodeDebugString(backend_node); | |||
| auto op_arrow = std::make_shared<DataArrow>(i, to_actor->GetAID(), output_position); | |||
| gather_actor->output_result_arrows_.emplace_back(op_arrow); | |||
| const auto &backend_nodes = gather_actor->front_to_backend_parameter_[front_node]; | |||
| if (backend_nodes.empty()) { | |||
| MS_LOG(EXCEPTION) << "No backend node for data node:" << AnfAlgo::GetNodeDebugString(front_node); | |||
| } | |||
| const auto &backend_node = backend_nodes[0].first; | |||
| if (backend_node->isa<Parameter>()) { | |||
| std::string actor_name = graph_compiler_info.name_ + "_HostDSActor"; | |||
| auto actor = FetchActor(actor_name); | |||
| MS_EXCEPTION_IF_NULL(actor); | |||
| auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor); | |||
| MS_EXCEPTION_IF_NULL(host_ds_actor); | |||
| const auto &data_nodes = host_ds_actor->data_nodes_; | |||
| const auto &node_iter = find(data_nodes.begin(), data_nodes.end(), backend_node); | |||
| if (node_iter == data_nodes.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node in host data source actor, node:" | |||
| << AnfAlgo::GetNodeDebugString(backend_node); | |||
| } | |||
| to_actor->device_contexts_[output_position] = host_ds_actor->device_contexts_[node_iter - data_nodes.begin()]; | |||
| } else { | |||
| auto actor_base = FetchActor(backend_node->fullname_with_scope()); | |||
| MS_EXCEPTION_IF_NULL(actor_base); | |||
| auto kernel_actor = dynamic_cast<KernelActor *>(actor_base); | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| to_actor->device_contexts_[output_position] = kernel_actor->device_context_; | |||
| } | |||
| to_actor->device_contexts_[iter->second.second] = | |||
| host_ds_actor->device_contexts_[node_iter - data_nodes.begin()]; | |||
| } else { | |||
| auto actor_base = FetchActor(backend_node->fullname_with_scope()); | |||
| MS_EXCEPTION_IF_NULL(actor_base); | |||
| auto kernel_actor = dynamic_cast<KernelActor *>(actor_base); | |||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||
| to_actor->device_contexts_[iter->second.second] = kernel_actor->device_context_; | |||
| } | |||
| } | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <map> | |||
| #include <set> | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include "runtime/framework/actor/data_source_actor.h" | |||
| @@ -42,10 +43,10 @@ namespace runtime { | |||
| using mindspore::device::DeviceContext; | |||
| using mindspore::session::KernelGraph; | |||
| using mindspore::session::KernelWithIndex; | |||
| // Position of kernel with index, the value pair<branch_id, pos> means the branch id of the kernel and the pos of | |||
| // the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there | |||
| // are multiple branch scenarios, and pos represents the position of the kernel in the branch. | |||
| using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, size_t>, session::KernelWithIndexCmp>; | |||
| // Position of kernel with index, the value pair<branch_id, vector<pos>> means the branch id of the kernel and the pos | |||
| // of the kernel. Generally, there is only one branch, and the branch id is 0 at this time. In control flow, there are | |||
| // multiple branch scenarios, and pos represents the position of the kernel in the branch. | |||
| using KernelMapPosition = std::map<KernelWithIndex, std::pair<int, std::vector<size_t>>, session::KernelWithIndexCmp>; | |||
| using ActorInfo = std::string; | |||
| // The second element of pair represents the output index of op actor corresponding to the graph output node. | |||
| @@ -250,6 +250,7 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } | |||
| MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id) | |||
| : Backend(backend_name), device_name_(device_name), device_id_(device_id) { | |||
| root_graph_ = nullptr; | |||
| auto cut_list = compile::GetMsNonlinearOps(); | |||
| graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name); | |||
| graph_compiler_ = std::make_shared<GraphCompiler>(); | |||
| @@ -258,18 +259,18 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string | |||
| const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphPtr root_graph = WrapPrimitives(func_graph); | |||
| MS_EXCEPTION_IF_NULL(root_graph); | |||
| root_graph_ = WrapPrimitives(func_graph); | |||
| MS_EXCEPTION_IF_NULL(root_graph_); | |||
| // Register a summary callback function, which is called in the final stages of summary. | |||
| graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); | |||
| // Compile root graph. | |||
| graph_id_to_device_context_.clear(); | |||
| control_nodes_.clear(); | |||
| CompileGraph(root_graph); | |||
| CompileGraph(root_graph_); | |||
| // Compile sub graphs. | |||
| FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); | |||
| FuncGraphSet sub_graphs = root_graph_->manager()->func_graphs(); | |||
| for (auto sub_graph : sub_graphs) { | |||
| if (sub_graph != func_graph && sub_graph != nullptr) { | |||
| CompileGraph(sub_graph); | |||
| @@ -277,7 +278,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| } | |||
| // Construct the graph compiler info. | |||
| auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph); | |||
| auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph_); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -494,19 +495,53 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, | |||
| // Fetch outputs. | |||
| MS_EXCEPTION_IF_NULL(actor_set->output_actor_); | |||
| auto &output_tensors = actor_set->output_actor_->outputs(); | |||
| if (output_tensors.size() > 1) { | |||
| VectorRef tmp; | |||
| (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(tmp.elements_), | |||
| [](tensor::TensorPtr &tensor) { return std::move(tensor); }); | |||
| outputs->emplace_back(std::move(tmp)); | |||
| } else if (output_tensors.size() == 1) { | |||
| outputs->emplace_back(std::move(output_tensors.front())); | |||
| if (output_tensors.size() > 0) { | |||
| size_t output_position = 0; | |||
| ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs); | |||
| } | |||
| MS_LOG(INFO) << "Run actor end, actor name: " << actor_info; | |||
| graph_compiler_->Summary(graph_compiler_info.graphs_); | |||
| } | |||
| void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node, | |||
| const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position, | |||
| VectorRef *outputs) { | |||
| // The makeTuple node need expand and recurse. | |||
| if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) { | |||
| auto make_tuple = output_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| VectorRef make_tuple_output; | |||
| for (size_t i = 1; i < make_tuple->inputs().size(); i++) { | |||
| ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output); | |||
| } | |||
| outputs->emplace_back(std::move(make_tuple_output)); | |||
| return; | |||
| } | |||
| // The depend node need get the real node. | |||
| if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) { | |||
| auto depend_node = output_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend_node); | |||
| ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs); | |||
| return; | |||
| } | |||
| // Judge the output whether tuple or not by the outputs number. | |||
| auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node); | |||
| if (outputs_num > 1) { | |||
| VectorRef output_tuple; | |||
| for (size_t i = 0; i < outputs_num; ++i) { | |||
| output_tuple.emplace_back(std::move(output_tensors[*output_position])); | |||
| ++(*output_position); | |||
| } | |||
| outputs->emplace_back(std::move(output_tuple)); | |||
| } else { | |||
| outputs->emplace_back(std::move(output_tensors[*output_position])); | |||
| ++(*output_position); | |||
| } | |||
| } | |||
| std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) { | |||
| MS_EXCEPTION_IF_NULL(root_graph); | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_); | |||
| @@ -535,7 +570,11 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con | |||
| auto outputs = AnfAlgo::GetAllOutputWithIndex(branch_output); | |||
| outputs_num = outputs.size(); | |||
| for (const auto &output : outputs) { | |||
| outputs_order[output] = {branch_id, position++}; | |||
| if (outputs_order.count(output) == 0) { | |||
| outputs_order[output] = {branch_id, {position++}}; | |||
| } else { | |||
| outputs_order[output].second.emplace_back(position++); | |||
| } | |||
| } | |||
| } | |||
| @@ -579,7 +618,11 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo( | |||
| auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output()); | |||
| for (const auto &output : outputs) { | |||
| outputs_order[output] = {runtime::kMainBranchID, position++}; | |||
| if (outputs_order.count(output) == 0) { | |||
| outputs_order[output] = {runtime::kMainBranchID, {position++}}; | |||
| } else { | |||
| outputs_order[output].second.emplace_back(position++); | |||
| } | |||
| } | |||
| } | |||
| @@ -126,6 +126,10 @@ class MindRTBackend : public Backend { | |||
| // The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_. | |||
| void CompileGraph(const FuncGraphPtr &func_graph); | |||
| // Restore the outputs tuple by the origin funcGraph output node and output tensors. | |||
| void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors, | |||
| size_t *output_position, VectorRef *outputs); | |||
| // Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode. | |||
| std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph); | |||
| @@ -148,6 +152,7 @@ class MindRTBackend : public Backend { | |||
| std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_; | |||
| FuncGraphPtr root_graph_; | |||
| GraphPartitionPtr graph_partition_; | |||
| std::shared_ptr<GraphCompiler> graph_compiler_; | |||
| std::string device_name_; | |||