diff --git a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h index 7a7b7d7f6d..fe922396d8 100644 --- a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h @@ -30,7 +30,8 @@ namespace runtime { // and decide whether to loop execution by loop count. class LoopCountActor : public OpActor { public: - LoopCountActor(std::string name, size_t loop_count) : OpActor(name), loop_count_(loop_count), current_count_(0) {} + LoopCountActor(std::string name, size_t loop_count) + : OpActor(name), loop_count_(loop_count), current_count_(0), input_controls_num_(0) {} virtual ~LoopCountActor() = default; // The loop count actor run when receive the input control. diff --git a/mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h b/mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h index 4e15dabff7..9c9ab16318 100644 --- a/mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h @@ -21,7 +21,7 @@ #include #include #include -#include "mindrt/include/actor/actor.h" +#include "mindrt/include/actor/op_actor.h" #include "runtime/framework/device_tensor_store.h" #include "runtime/hardware/device_context.h" @@ -33,12 +33,7 @@ using mindspore::device::DeviceContext; class MemoryManagerActor : public ActorBase { public: MemoryManagerActor() : ActorBase("MemoryManagerActor") {} - virtual ~MemoryManagerActor() = default; - - static std::shared_ptr &GetInstance() { - static std::shared_ptr instance; - return instance; - } + ~MemoryManagerActor() override = default; // The process entry of memory alloc. 
bool AllocateMemory(std::vector alloc_list, const DeviceContext *device_context, diff --git a/mindspore/ccsrc/runtime/framework/device_tensor_store.h b/mindspore/ccsrc/runtime/framework/device_tensor_store.h index bc2ff03007..328f3fcb48 100644 --- a/mindspore/ccsrc/runtime/framework/device_tensor_store.h +++ b/mindspore/ccsrc/runtime/framework/device_tensor_store.h @@ -19,6 +19,7 @@ #include #include +#include "utils/ms_utils.h" #include "runtime/device/device_address.h" namespace mindspore { @@ -32,9 +33,6 @@ using DeviceTensorPtr = std::shared_ptr; // so they are more suitable for store and can be obtained when they are used by actor. class DeviceTensorStore { public: - DeviceTensorStore() = default; - virtual ~DeviceTensorStore() = default; - static DeviceTensorStore &GetInstance() { static DeviceTensorStore instance; return instance; @@ -60,6 +58,10 @@ class DeviceTensorStore { } private: + DeviceTensorStore() = default; + ~DeviceTensorStore() = default; + DISABLE_COPY_AND_ASSIGN(DeviceTensorStore); + // The data storage of device tensor, key is anfNode ptr. std::unordered_map device_tensors_; }; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc new file mode 100644 index 0000000000..75b009798a --- /dev/null +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/framework/graph_scheduler.h" +#include "mindrt/src/actor/actormgr.h" +#include "mindrt/include/async/async.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "utils/config_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace runtime { +namespace { +bool IsDeviceQueueDSActor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) { + return true; + } + return false; +} + +bool IsHostQueueDSActor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && (!AnfAlgo::IsParameterWeight(node->cast()))) { + return true; + } + return false; +} + +bool IsKernelActor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) { + return true; + } + return false; +} + +// Judge whether the device tensor of the node is persistent or not. 
+bool IsPersistentDeviceTensor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + return true; + } + if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + return true; + } + return false; +} + +KernelActor *FindKernelActor(const std::unordered_map &kernel_actors_map, + const std::string &name) { + auto iter = kernel_actors_map.find(name); + if (iter != kernel_actors_map.end()) { + return iter->second.get(); + } + return nullptr; +} + +DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector &data_source_actors) { + for (auto &actor : data_source_actors) { + MS_EXCEPTION_IF_NULL(actor); + if (actor->GetAID().Name().find("_DeviceQueueDataSourceActor") != string::npos) { + auto device_queue_ds_actor = dynamic_cast(actor.get()); + return device_queue_ds_actor; + } + } + return nullptr; +} + +HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector &data_source_actors) { + for (auto &actor : data_source_actors) { + MS_EXCEPTION_IF_NULL(actor); + if (actor->GetAID().Name().find("_HostQueueDataSourceActor") != string::npos) { + auto device_queue_ds_actor = dynamic_cast(actor.get()); + return device_queue_ds_actor; + } + } + return nullptr; +} + +// Update the reference count of device tensor by the output index of node. 
+void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx); + MS_EXCEPTION_IF_NULL(device_tensor); + device_tensor->IncreaseRefCount(); + device_tensor->ResetRefCountUsed(); +} +} // namespace + +ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, + const std::vector *input_tensors, + GraphExecutionStrategy strategy) { + PersistDeviceTensor(graph); + auto actor_set = Build(graph, device_context); + graph_to_actors_.emplace(graph, actor_set); + Link(actor_set.get(), graph, strategy); + return actor_set.get(); +} + +void GraphScheduler::Schedule(const ActorSet *actor_set) { + MS_EXCEPTION_IF_NULL(actor_set); + auto actorMgr = ActorMgr::GetActorMgrRef(); + MS_EXCEPTION_IF_NULL(actorMgr); + + // Schedule data source actors. + for (auto &data_source_actor : actor_set->data_source_actors_) { + MS_EXCEPTION_IF_NULL(data_source_actor); + auto base_actor = static_cast(data_source_actor); + (void)actorMgr->Spawn(base_actor); + } + + // Schedule kernel actors. + for (auto &kernel_actor : actor_set->kernel_actors_) { + MS_EXCEPTION_IF_NULL(kernel_actor); + auto base_actor = static_cast(kernel_actor); + (void)actorMgr->Spawn(base_actor); + } + + // Schedule loop count actor. + if (actor_set->loop_count_actor_ != nullptr) { + auto base_actor = static_cast(actor_set->loop_count_actor_); + (void)actorMgr->Spawn(base_actor); + } +} + +bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) { + MS_EXCEPTION_IF_NULL(actor_set); + // Construct OpContext. + OpContext op_context; + auto sequential_num = uuids::RandomBasedGenerator::GenerateRandomUuid(); + op_context.sequential_num_ = &sequential_num; + Promise result; + op_context.results_->push_back(result); + + // Trigger no input kernel actor running.
+ for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { + MS_EXCEPTION_IF_NULL(no_input_kernel_actor); + Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context); + } + + // Trigger data source actor running. + for (auto &data_source_actor : actor_set->data_source_actors_) { + MS_EXCEPTION_IF_NULL(data_source_actor); + Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context); + } + + // Trigger kernel actor running in the step execution strategy. + if (strategy == GraphExecutionStrategy::kStep) { + for (auto &kernel_actor : actor_set->kernel_actors_) { + MS_EXCEPTION_IF_NULL(kernel_actor); + Async(kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context); + } + } + + // Get the run result. + auto result_future = result.GetFuture(); + result_future.Wait(); + if (!result_future.IsOK()) { + return false; + } + return true; +} + +ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const { + MS_EXCEPTION_IF_NULL(graph); + auto iter = graph_to_actors_.find(graph); + if (iter != graph_to_actors_.end()) { + return iter->second.get(); + } else { + MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString(); + return nullptr; + } +} + +ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) { + auto actor_set = std::make_shared(); + MS_EXCEPTION_IF_NULL(actor_set); + + auto data_source_actors = BuildDataSourceActor(graph); + actor_set->data_source_actors_.swap(data_source_actors); + + auto kernel_actors = BuildKernelActor(graph, device_context); + actor_set->kernel_actors_.swap(kernel_actors); + + auto loop_count_actor = BuildLoopCountActor(graph); + actor_set->loop_count_actor_ = loop_count_actor; + + return actor_set; +} + +void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) { + MS_EXCEPTION_IF_NULL(actor_set); + MS_EXCEPTION_IF_NULL(graph); + std::unordered_map 
kernel_actors_temp_map; + for (auto &actor : actor_set->kernel_actors_) { + MS_EXCEPTION_IF_NULL(actor); + kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor); + } + + // Iterate over the execution order to link the actors. + auto execution_order = graph->execution_order(); + for (auto &kernel : execution_order) { + if (!IsKernelActor(kernel)) { + continue; + } + auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope()); + // Link the control arrows of kernel actor. + LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy); + + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i); + KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); + auto from_kernel = from_kernel_with_output_idx.first; + + if (IsDeviceQueueDSActor(from_kernel)) { + // Link the data arrows of device queue data source actor. + auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_); + LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx); + } else if (IsHostQueueDSActor(from_kernel)) { + // Link the data arrows of host queue data source actor. + auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_); + LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx); + } else { + // Link the data arrows of kernel actor. + auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope()); + LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx); + } + } + } + + // BuildNoInputKernelActor depends on whether kernel actors have input, so it must run after the kernel actors are linked.
+ auto no_input_kernel_actors = BuildNoInputKernelActor(graph); + actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors); + + // Link the control arrows of loop count actor, which depends on the no input kernel actors. + LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph); +} + +std::vector GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector data_source_actors; + + // Build host queue data source actor. + HostQueueDSActorPtr host_queue_ds_actor = nullptr; + for (auto &input_node : graph->input_nodes()) { + MS_EXCEPTION_IF_NULL(input_node); + if (IsHostQueueDSActor(input_node)) { + if (host_queue_ds_actor == nullptr) { + auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor"; + MS_LOG(INFO) << "Create host queue data source actor: " << actor_name; + auto host_queue = std::make_shared(); + graph_to_host_queue_.emplace(graph, host_queue); + host_queue_ds_actor = std::make_shared(actor_name, 1, host_queue); + data_source_actors.emplace_back(host_queue_ds_actor); + } + host_queue_ds_actor->data_nodes_.emplace_back(input_node); + } + } + + // Build device queue data source actor. 
+ auto execution_order = graph->execution_order(); + auto iter = std::find_if(execution_order.begin(), execution_order.end(), + [](const CNodePtr &node) { return IsDeviceQueueDSActor(node); }); + if (iter != execution_order.end()) { + auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor"; + MS_LOG(INFO) << "Create queue data source actor: " << actor_name; + auto device_queue_ds_actor = std::make_shared(actor_name, 1); + MS_EXCEPTION_IF_NULL(device_queue_ds_actor); + data_source_actors.emplace_back(device_queue_ds_actor); + device_queue_ds_actor->data_kernel_ = *iter; + } + return data_source_actors; +} + +std::vector GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph, + const DeviceContext *device_context) { + MS_EXCEPTION_IF_NULL(graph); + std::vector kernel_actors; + + auto execution_order = graph->execution_order(); + for (auto &kernel : execution_order) { + if (IsKernelActor(kernel)) { + auto kernel_actor = std::make_shared(kernel->fullname_with_scope(), kernel, device_context); + MS_EXCEPTION_IF_NULL(kernel_actor); + kernel_actors.emplace_back(kernel_actor); + } + } + return kernel_actors; +} + +std::vector GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector no_input_kernel_actors; + + auto actor_set = Fetch(graph); + MS_EXCEPTION_IF_NULL(actor_set); + for (auto &kernel_actor : actor_set->kernel_actors_) { + MS_EXCEPTION_IF_NULL(kernel_actor); + if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) { + no_input_kernel_actors.emplace_back(kernel_actor); + } + } + return no_input_kernel_actors; +} + +LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto loop_count = ConfigManager::GetInstance().iter_num(); + auto actor_name = graph->ToString() + "_" + "LoopCountActor"; + auto loop_count_actor = std::make_shared(actor_name, loop_count); + 
MS_EXCEPTION_IF_NULL(loop_count_actor); + return loop_count_actor; +} + +void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_kernel_with_input_idx) { + MS_EXCEPTION_IF_NULL(from_actor); + MS_EXCEPTION_IF_NULL(to_actor); + + auto from_kernel = from_kernel_with_output_idx.first; + MS_EXCEPTION_IF_NULL(from_kernel); + auto from_output_index = from_kernel_with_output_idx.second; + auto to_input_index = to_kernel_with_input_idx.second; + + auto to_aid = to_actor->GetAID(); + auto op_arrow = std::make_shared(from_output_index, to_aid, to_input_index); + from_actor->output_op_arrows_.emplace_back(op_arrow); + to_actor->input_datas_num_++; + + // Update the reference count of device tensor. + UpdateRefCount(from_kernel, from_output_index); +} + +void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_kernel_with_input_idx) { + MS_EXCEPTION_IF_NULL(from_actor); + MS_EXCEPTION_IF_NULL(to_actor); + + auto from_kernel = from_kernel_with_output_idx.first; + MS_EXCEPTION_IF_NULL(from_kernel); + auto from_output_index = from_kernel_with_output_idx.second; + auto to_input_index = to_kernel_with_input_idx.second; + + auto data_nodes = from_actor->data_nodes_; + auto iter = find(data_nodes.begin(), data_nodes.end(), from_kernel); + if (iter == data_nodes.end()) { + MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist."; + } + auto position = IntToSize(std::distance(data_nodes.begin(), iter)); + auto to_aid = to_actor->GetAID(); + auto op_arrow = std::make_shared(position, to_aid, to_input_index); + from_actor->output_op_arrows_.emplace_back(op_arrow); + to_actor->input_datas_num_++; + + // Update the reference count of device tensor. 
+ UpdateRefCount(from_kernel, from_output_index); +} + +void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_kernel_with_input_idx) { + MS_EXCEPTION_IF_NULL(to_actor); + auto from_kernel = from_kernel_with_output_idx.first; + MS_EXCEPTION_IF_NULL(from_kernel); + auto from_output_index = from_kernel_with_output_idx.second; + auto to_input_index = to_kernel_with_input_idx.second; + + if (IsPersistentDeviceTensor(from_kernel)) { + to_actor->device_tensor_store_keys_.emplace_back(to_input_index, static_cast(from_kernel.get())); + } else if (IsKernelActor(from_kernel)) { + MS_EXCEPTION_IF_NULL(from_actor); + auto to_aid = to_actor->GetAID(); + auto op_arrow = std::make_shared(from_output_index, to_aid, to_input_index); + from_actor->output_op_arrows_.emplace_back(op_arrow); + to_actor->input_datas_num_++; + + // Update the reference count of device tensor. + UpdateRefCount(from_kernel, from_output_index); + } +} + +void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, + const KernelGraphPtr &graph, GraphExecutionStrategy strategy) { + MS_EXCEPTION_IF_NULL(from_actor); + MS_EXCEPTION_IF_NULL(to_actor); + MS_EXCEPTION_IF_NULL(graph); + + if (strategy == GraphExecutionStrategy::kStep) { + from_actor->input_controls_num_++; + } + + if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) { + auto to_aid = to_actor->GetAID(); + from_actor->output_op_controls_.emplace_back(to_aid); + to_actor->input_controls_num_++; + } +} + +void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(loop_count_actor); + + auto actor_set = Fetch(graph); + MS_EXCEPTION_IF_NULL(actor_set); + + // Set the source data actor. 
+ for (auto &data_source_actor : actor_set->data_source_actors_) { + MS_EXCEPTION_IF_NULL(data_source_actor); + loop_count_actor->data_source_aids_.emplace_back(data_source_actor->GetAID()); + } + + // Set the no input kernel actor. + for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { + MS_EXCEPTION_IF_NULL(no_input_kernel_actor); + loop_count_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID()); + } +} + +void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0); + DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor); + device_tensor->set_ref_count(SIZE_MAX); + device_tensor->ResetRefCountUsed(); + } + + for (auto &input_node : graph->input_nodes()) { + MS_EXCEPTION_IF_NULL(input_node); + if (IsPersistentDeviceTensor(input_node)) { + auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0); + DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor); + device_tensor->set_ref_count(SIZE_MAX); + device_tensor->ResetRefCountUsed(); + } + } +} + +} // namespace runtime +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index b087929ffe..0aa7bad6c6 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -31,12 +31,11 @@ namespace mindspore { namespace runtime { using mindspore::device::DeviceContext; +using mindspore::session::KernelWithIndex; enum class GraphExecutionStrategy { - // The actor running is triggered only by data. - kPipeline, - // The actor running need be triggered by control in addition. - kStep + kPipeline, // The actor running is triggered only by data. 
+ kStep // The actor running need be triggered by control in addition. }; // The actor set generated by graph transformer is the execution unit of actor runtime. @@ -57,49 +56,60 @@ using ActorSetPtr = std::shared_ptr; class GraphScheduler { public: - GraphScheduler() = default; - virtual ~GraphScheduler() = default; - static GraphScheduler &GetInstance() { static GraphScheduler instance; return instance; } // Transform graph to actor DAG, contains build and link. - ActorSetPtr Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, - const std::vector *input_tensors = nullptr, - GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); + ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, + const std::vector *input_tensors = nullptr, + GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); // Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling // will be supported in the future. - void Schedule(const ActorSetPtr &actor_set); + void Schedule(const ActorSet *actor_set); // The processing entry of actors running. - bool Run(const ActorSetPtr &actor_set); + bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); + + // Fetch the actor set by kernel graph. + ActorSet *Fetch(const KernelGraphPtr &graph) const; private: + GraphScheduler() = default; + ~GraphScheduler() = default; + DISABLE_COPY_AND_ASSIGN(GraphScheduler); + // Transform the nodes of graph to actors. ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context); // Link actors to DAG through the edge connection of graph and graph execution strategy. - void Link(ActorSetPtr actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy); + void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy); // The processing of actors build. 
std::vector BuildDataSourceActor(const KernelGraphPtr &graph); std::vector BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context); + std::vector BuildNoInputKernelActor(const KernelGraphPtr &graph); LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph); // The processing of actors link. - void LinkDataSourceActor(std::vector actors, const KernelGraphPtr &graph); - void LinkKernelActor(std::vector actors, const KernelGraphPtr &graph, - GraphExecutionStrategy strategy); - void LinkLoopCountActor(LoopCountActorPtr actor, const KernelGraphPtr &graph); + void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_to_kernel_with_input_idx); + void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_kernel_with_input_idx); + void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor, + KernelWithIndex from_kernel_with_output_idx, + KernelWithIndex to_kernel_with_input_idx); + void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph, + GraphExecutionStrategy strategy); + void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph); // Persist device tensors of graph's some nodes(such as weights and value nodes). void PersistDeviceTensor(const KernelGraphPtr &graph); - // Judge whether the device tensor of the node is persistent or not. - bool IsPersistentDeviceTensor(const AnfNodePtr &node); - std::unordered_map graph_to_actor_; + std::unordered_map graph_to_actors_; std::unordered_map graph_to_host_queue_; // The second element of pair represents the output index of kernel actor corresponding to the device tensor. 
diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index b00b9bc608..ba03af7a44 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -102,33 +102,34 @@ constexpr std::ostream &operator<<(std::ostream &stream, const T &value) { enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION }; enum SubModuleId : int { - SM_UNKNOWN = 0, // unknown submodule - SM_CORE, // core - SM_ANALYZER, // static analyzer - SM_COMMON, // common - SM_DEBUG, // debug - SM_DEVICE, // device - SM_GE_ADPT, // ge adapter - SM_IR, // IR - SM_KERNEL, // kernel - SM_MD, // MindData - SM_ME, // MindExpression - SM_EXPRESS, // EXPRESS_IR - SM_OPTIMIZER, // optimzer - SM_PARALLEL, // parallel - SM_PARSER, // parser - SM_PIPELINE, // ME pipeline - SM_PRE_ACT, // pre-activate - SM_PYNATIVE, // PyNative - SM_SESSION, // session - SM_UTILS, // utils - SM_VM, // VM - SM_PROFILER, // profiler - SM_PS, // Parameter Server - SM_LITE, // LITE - SM_HCCL_ADPT, // Hccl Adapter - SM_MINDQUANTUM, // MindQuantum - NUM_SUBMODUES // number of submodules + SM_UNKNOWN = 0, // unknown submodule + SM_CORE, // core + SM_ANALYZER, // static analyzer + SM_COMMON, // common + SM_DEBUG, // debug + SM_DEVICE, // device + SM_GE_ADPT, // ge adapter + SM_IR, // IR + SM_KERNEL, // kernel + SM_MD, // MindData + SM_ME, // MindExpression + SM_EXPRESS, // EXPRESS_IR + SM_OPTIMIZER, // optimizer + SM_PARALLEL, // parallel + SM_PARSER, // parser + SM_PIPELINE, // ME pipeline + SM_PRE_ACT, // pre-activate + SM_PYNATIVE, // PyNative + SM_SESSION, // session + SM_UTILS, // utils + SM_VM, // VM + SM_PROFILER, // profiler + SM_PS, // Parameter Server + SM_LITE, // LITE + SM_HCCL_ADPT, // Hccl Adapter + SM_MINDQUANTUM, // MindQuantum + SM_RUNTIME_FRAMEWORK, // Runtime framework + NUM_SUBMODUES // number of submodules }; #ifndef SUBMODULE_ID