| @@ -58,10 +58,8 @@ void TaskEmitActionForMindRT(const ResourcePtr &res) { | |||
| auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr); | |||
| MS_EXCEPTION_IF_NULL(mindrt_bc_ptr); | |||
| auto cut_list = compile::GetMsNonlinearOps(); | |||
| auto mindrt_compile = std::make_shared<compile::GraphCompiler>(mindrt_bc_ptr, cut_list); | |||
| // The output of graph compiler is graph id. | |||
| res->results()[kOutput] = mindrt_compile->CompileGraphs(res->func_graph()); | |||
| res->results()[kOutput] = mindrt_bc_ptr->CompileGraphs(res->func_graph()); | |||
| } | |||
| void ExecuteActionForMindRT(const ResourcePtr &res) { | |||
| @@ -995,7 +995,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend); | |||
| MS_EXCEPTION_IF_NULL(mindrt_backend); | |||
| auto graph_id = mindrt_backend->CompileGraph({app_init}); | |||
| auto graph_id = mindrt_backend->CompileGraphs(func_graph); | |||
| VectorRef args; | |||
| if (need_run) { | |||
| (void)mindrt_backend->RunGraph(graph_id, args); | |||
| @@ -226,10 +226,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph); | |||
| // Transform graph to actor DAG, contains build and link. | |||
| const auto &actor_set = GraphScheduler::GetInstance().Transform(graph, device_context_); | |||
| GraphScheduler::GetInstance().Schedule(actor_set); | |||
| return graph->graph_id(); | |||
| } | |||
| @@ -262,7 +258,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph); | |||
| // Transform graph to actor DAG, contains build and link. | |||
| GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); | |||
| GraphScheduler::GetInstance().Transform({graph}, {device_context_}, input_tensors, nullptr, | |||
| GraphExecutionStrategy::kStep); | |||
| run_op_graphs_[graph_info] = graph; | |||
| return graph->graph_id(); | |||
| } | |||
| @@ -308,24 +308,32 @@ void GraphScheduler::Initialize() { | |||
| (void)actorMgr->Spawn(base_actor, false); | |||
| } | |||
| ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, | |||
| const std::vector<tensor::TensorPtr> *input_tensors, | |||
| GraphExecutionStrategy strategy) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin."; | |||
| ActorSet *GraphScheduler::Transform(const std::vector<KernelGraphPtr> &graphs, | |||
| const std::vector<DeviceContext *> &device_contexts, | |||
| const std::vector<TensorPtr> *input_tensors, | |||
| const std::vector<AnfNodePtr> *control_nodes, GraphExecutionStrategy strategy) { | |||
| if (graphs.size() != device_contexts.size()) { | |||
| MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device_contexts."; | |||
| } | |||
| Initialize(); | |||
| PersistDeviceTensor(graph); | |||
| auto actor_set = Build(graph, device_context); | |||
| graph_to_actors_.emplace(graph, actor_set); | |||
| Link(actor_set.get(), graph, strategy); | |||
| if (!CheckActorValid(actor_set.get())) { | |||
| MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid."; | |||
| std::vector<ActorSetPtr> actor_sets; | |||
| for (size_t i = 0; i < graphs.size(); ++i) { | |||
| auto graph = graphs[i]; | |||
| auto device_context = device_contexts[i]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin."; | |||
| PersistDeviceTensor(graph); | |||
| auto actor_set = Build(graph, device_context); | |||
| actor_sets.emplace_back(actor_set); | |||
| graph_to_actors_.emplace(graph, actor_set); | |||
| Link(actor_set.get(), graph, strategy); | |||
| if (!CheckActorValid(actor_set.get())) { | |||
| MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid."; | |||
| } | |||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end."; | |||
| } | |||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end."; | |||
| return actor_set.get(); | |||
| return actor_sets[0].get(); | |||
| } | |||
| void GraphScheduler::Schedule(const ActorSet *actor_set) { | |||
| @@ -69,8 +69,9 @@ class GraphScheduler { | |||
| void Initialize(); | |||
| // Transform each kernel graph, paired by index with its device context, to an actor DAG; contains build and link. | |||
| ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, | |||
| ActorSet *Transform(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts, | |||
| const std::vector<TensorPtr> *input_tensors = nullptr, | |||
| const std::vector<AnfNodePtr> *control_nodes = nullptr, | |||
| GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); | |||
| // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling | |||
| @@ -18,6 +18,7 @@ | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include "vm/transform.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "pipeline/pynative/pynative_execute.h" | |||
| #include "ir/anf.h" | |||
| @@ -230,27 +231,85 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } | |||
| #endif | |||
| MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id) | |||
| : Backend(backend_name), device_name_(device_name), device_id_(device_id) {} | |||
| GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) { | |||
| MS_LOG(INFO) << "Compile graph begin."; | |||
| // Get and set the device context. | |||
| const auto &cur_device_name = GetCNodeTarget(nodes[0]); | |||
| const auto &device_context = | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | |||
| device_context->Initialize(); | |||
| runtime::GraphCompiler::GetInstance().set_device_context(device_context); | |||
| // Transform nodes to inputs and outputs. | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(nodes); | |||
| : Backend(backend_name), device_name_(device_name), device_id_(device_id) { | |||
| auto cut_list = compile::GetMsNonlinearOps(); | |||
| graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name); | |||
| } | |||
| GraphId MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphPtr root_graph = WrapPrimitives(func_graph); | |||
| MS_EXCEPTION_IF_NULL(root_graph); | |||
| // Compile root graph. | |||
| auto root_graph_id = CompileGraph(root_graph); | |||
| // Compile sub graphs. | |||
| FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); | |||
| for (auto sub_graph : sub_graphs) { | |||
| if (sub_graph != func_graph && sub_graph != nullptr) { | |||
| (void)CompileGraph(sub_graph); | |||
| } | |||
| } | |||
| // Transform graph to actor DAG, and schedule the actor DAG. | |||
| std::vector<KernelGraphPtr> graphs; | |||
| std::vector<DeviceContext *> device_contexts; | |||
| for (const auto &graph_id_to_context : graph_to_device_context_) { | |||
| graphs.emplace_back(runtime::GraphCompiler::GetInstance().Fetch(graph_id_to_context.first)); | |||
| device_contexts.emplace_back(graph_id_to_context.second); | |||
| } | |||
| const auto &actor_set = | |||
| runtime::GraphScheduler::GetInstance().Transform(graphs, device_contexts, nullptr, &control_nodes_); | |||
| runtime::GraphScheduler::GetInstance().Schedule(actor_set); | |||
| return root_graph_id; | |||
| } | |||
| GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(graph_partition_); | |||
| // Split graph to segments. | |||
| const auto &segments = graph_partition_->Partition(func_graph); | |||
| MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); | |||
| // Foreach the segments to compile graph. | |||
| for (const auto &segment : segments) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| // Compile the normal nodes, which doesn't contain the cut node. | |||
| if (!segment->is_cut_) { | |||
| if (segment->nodes_.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "The segments size is 0."; | |||
| } | |||
| MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); | |||
| // Get and set the device context. | |||
| const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]); | |||
| const auto &device_context = | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | |||
| device_context->Initialize(); | |||
| runtime::GraphCompiler::GetInstance().set_device_context(device_context); | |||
| // Transform nodes to inputs and outputs. | |||
| FuncGraphPtr fg; | |||
| AnfNodePtrList inputs; | |||
| AnfNodePtrList outputs; | |||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); | |||
| // Compile graph. | |||
| auto graph_id = runtime::GraphCompiler::GetInstance().CompileGraph(segment->nodes_, outputs); | |||
| graph_to_device_context_[graph_id] = device_context; | |||
| } else { | |||
| // Compile the cut node. | |||
| auto cut_node = segment->nodes_[0]; | |||
| MS_EXCEPTION_IF_NULL(cut_node); | |||
| MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); | |||
| control_nodes_.push_back(cut_node); | |||
| } | |||
| } | |||
| // Compile graph. | |||
| auto graph_id = runtime::GraphCompiler::GetInstance().CompileGraph(nodes, outputs); | |||
| MS_LOG(INFO) << "Compile graph end, graph id: " << graph_id; | |||
| return graph_id; | |||
| return graph_to_device_context_.begin()->first; | |||
| } | |||
| VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { | |||
| @@ -29,11 +29,12 @@ | |||
| #include "vm/graph_partition.h" | |||
| #include "vm/vm.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "runtime/hardware/device_context.h" | |||
| namespace mindspore { | |||
| namespace compile { | |||
| using OpRunInfo = session::OpRunInfo; | |||
| using DeviceContext = device::DeviceContext; | |||
| enum SwitchCondStatus { | |||
| kCondOk = 0, | |||
| kCondAlreadyRun, | |||
| @@ -94,8 +95,9 @@ class MindRTBackend : public Backend { | |||
| MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id); | |||
| ~MindRTBackend() override = default; | |||
| // Compile kernel graph from anf nodes list in the graph mode. | |||
| GraphId CompileGraph(const AnfNodePtrList &nodes); | |||
| // The parameter root_graph is a root graph, and the root graph may contain multiple sub graphs, | |||
| // the return is the kernelGraph id of the root graph. It traverses all sub graphs to call CompileGraph, then transforms the compiled graphs to the actor DAG and schedules it. | |||
| GraphId CompileGraphs(const FuncGraphPtr &root_graph); | |||
| // Compile single op kernel graph in the pyNative mode. | |||
| GraphId CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask); | |||
| @@ -106,6 +108,17 @@ class MindRTBackend : public Backend { | |||
| VectorRef RunGraph(const GraphInfo &graph_info, const VectorRef &args); | |||
| private: | |||
| // The parameter func_graph is a graph, which can be either a root graph or a sub graph, | |||
| // the return is meant to be the corresponding kernelGraph id of the graph. NOTE(review): the implementation returns the first key of graph_to_device_context_, an unordered map -- confirm that id belongs to func_graph. | |||
| GraphId CompileGraph(const FuncGraphPtr &func_graph); | |||
| // When compiling a FuncGraph, it is divided according to the control nodes, yielding the control nodes and several | |||
| // node segments. Node segments are compiled into kernelGraphs, which are expressed as GraphId and bound to | |||
| // the corresponding device_context. | |||
| std::unordered_map<GraphId, DeviceContext *> graph_to_device_context_; | |||
| std::vector<AnfNodePtr> control_nodes_; | |||
| GraphPartitionPtr graph_partition_; | |||
| std::string device_name_; | |||
| uint32_t device_id_; | |||
| }; | |||
| @@ -521,69 +521,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { | |||
| return rt; | |||
| } | |||
| GraphCompiler::GraphCompiler(const std::shared_ptr<MindRTBackend> &backend, const std::vector<PrimitivePtr> &cut_list) | |||
| : backend_(backend) { | |||
| MS_EXCEPTION_IF_NULL(backend_); | |||
| if (backend_ == nullptr) { | |||
| MS_LOG(ERROR) << "The backend isn't created."; | |||
| return; | |||
| } | |||
| graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name()); | |||
| } | |||
| uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphPtr root_graph = WrapPrimitives(func_graph); | |||
| MS_EXCEPTION_IF_NULL(root_graph); | |||
| // Compile root graph. | |||
| auto root_graph_id = CompileGraph(root_graph); | |||
| // Compile sub graphs. | |||
| FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); | |||
| for (auto sub_graph : sub_graphs) { | |||
| if (sub_graph != func_graph && sub_graph != nullptr) { | |||
| (void)CompileGraph(sub_graph); | |||
| } | |||
| } | |||
| return root_graph_id; | |||
| } | |||
| uint32_t GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(graph_partition_); | |||
| MS_EXCEPTION_IF_NULL(backend_); | |||
| // Split graph to segments. | |||
| const auto &segments = graph_partition_->Partition(func_graph); | |||
| MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size(); | |||
| // Foreach the segments to compile graph. | |||
| std::vector<uint32_t> graph_ids; | |||
| for (const auto &segment : segments) { | |||
| MS_EXCEPTION_IF_NULL(segment); | |||
| // Compile the normal nodes, which doesn't contain the cut node. | |||
| if (!segment->is_cut_) { | |||
| if (segment->nodes_.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "The segments size is 0."; | |||
| } | |||
| MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope(); | |||
| // Compile the anfNodes list to kernelGraph, return the graph id of kernelGraph. | |||
| auto graph_id = backend_->CompileGraph(segment->nodes_); | |||
| graph_ids.emplace_back(graph_id); | |||
| } else { | |||
| // Compile the cut node. | |||
| auto cut_node = segment->nodes_[0]; | |||
| MS_EXCEPTION_IF_NULL(cut_node); | |||
| MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope(); | |||
| } | |||
| } | |||
| return graph_ids[0]; | |||
| } | |||
// Tells callers whether the mindRT runtime is in use. GPU and CPU use mindRT currently, and other hardwares will
// use it in the future. During the transitional stage the answer is always false.
bool IsMindRTUsed() {
  constexpr bool kMindRTEnabled = false;
  return kMindRTEnabled;
}
| @@ -44,7 +44,7 @@ extern const char kGeVm[]; | |||
| namespace compile { | |||
| extern std::vector<PrimitivePtr> nonlinear_ops; | |||
| const std::vector<PrimitivePtr> &GetMsNonlinearOps(); | |||
| FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph); | |||
| using VmEvalFunc = std::function<BaseRef(const VectorRef &)>; | |||
| using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>; | |||
| @@ -131,27 +131,6 @@ class CompileGraphs { | |||
| BackendPtr backend_; | |||
| }; | |||
| // The graph compiler used with mindRT, which transforms the funcGraph to kernelGraphs and returns the graph id of the | |||
| // kernelGraph. | |||
| class GraphCompiler { | |||
| public: | |||
| GraphCompiler(const std::shared_ptr<MindRTBackend> &backend, | |||
| const std::vector<PrimitivePtr> &cut_list = nonlinear_ops); | |||
| ~GraphCompiler() = default; | |||
| // The parameter root_graph is a root graph, and the root graph may contain multiple sub graphs, | |||
| // the return is the kernelGraph id of the root graph. It traverses all sub graphs to call CompileGraph. | |||
| uint32_t CompileGraphs(const FuncGraphPtr &root_graph); | |||
| private: | |||
| // The parameter func_graph is a graph, which can be either a root graph or a sub graph, | |||
| // the return is the id of the first kernelGraph compiled from the graph. | |||
| uint32_t CompileGraph(const FuncGraphPtr &func_graph); | |||
| std::shared_ptr<MindRTBackend> backend_; | |||
| GraphPartitionPtr graph_partition_; | |||
| }; | |||
| // Judge whether to use mindRT. GPU and CPU use mindRT currently, and other hardwares will use it in the future. | |||
| bool IsMindRTUsed(); | |||