Browse Source

add the impl of graph scheduler

pull/14216/head
limingqi107 4 years ago
parent
commit
33e010db69
6 changed files with 526 additions and 58 deletions
  1. +2
    -1
      mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
  2. +2
    -7
      mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h
  3. +5
    -3
      mindspore/ccsrc/runtime/framework/device_tensor_store.h
  4. +459
    -0
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  5. +30
    -20
      mindspore/ccsrc/runtime/framework/graph_scheduler.h
  6. +28
    -27
      mindspore/core/utils/log_adapter.h

+ 2
- 1
mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h View File

@@ -30,7 +30,8 @@ namespace runtime {
// and decide whether to loop execution by loop count.
class LoopCountActor : public OpActor<DeviceTensor> {
public:
LoopCountActor(std::string name, size_t loop_count) : OpActor(name), loop_count_(loop_count), current_count_(0) {}
LoopCountActor(std::string name, size_t loop_count)
: OpActor(name), loop_count_(loop_count), current_count_(0), input_controls_num_(0) {}
virtual ~LoopCountActor() = default;

// The loop count actor run when receive the input control.


+ 2
- 7
mindspore/ccsrc/runtime/framework/actor/memory_manager_actor.h View File

@@ -21,7 +21,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include "mindrt/include/actor/actor.h"
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/hardware/device_context.h"

@@ -33,12 +33,7 @@ using mindspore::device::DeviceContext;
class MemoryManagerActor : public ActorBase {
public:
MemoryManagerActor() : ActorBase("MemoryManagerActor") {}
virtual ~MemoryManagerActor() = default;

static std::shared_ptr<MemoryManagerActor> &GetInstance() {
static std::shared_ptr<MemoryManagerActor> instance;
return instance;
}
~MemoryManagerActor() override = default;

// The process entry of memory alloc.
bool AllocateMemory(std::vector<DeviceTensorPtr> alloc_list, const DeviceContext *device_context,


+ 5
- 3
mindspore/ccsrc/runtime/framework/device_tensor_store.h View File

@@ -19,6 +19,7 @@

#include <memory>
#include <unordered_map>
#include "utils/ms_utils.h"
#include "runtime/device/device_address.h"

namespace mindspore {
@@ -32,9 +33,6 @@ using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;
// so they are more suitable for store and can be obtained when they are used by actor.
class DeviceTensorStore {
public:
DeviceTensorStore() = default;
virtual ~DeviceTensorStore() = default;

static DeviceTensorStore &GetInstance() {
static DeviceTensorStore instance;
return instance;
@@ -60,6 +58,10 @@ class DeviceTensorStore {
}

private:
DeviceTensorStore() = default;
~DeviceTensorStore() = default;
DISABLE_COPY_AND_ASSIGN(DeviceTensorStore);

// The data storage of device tensor, key is anfNode ptr.
std::unordered_map<void *, DeviceTensorPtr> device_tensors_;
};


+ 459
- 0
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -0,0 +1,459 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/graph_scheduler.h"
#include <vector>
#include "mindrt/src/actor/actormgr.h"
#include "mindrt/include/async/async.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "utils/config_manager.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
namespace {
// Judge whether the node is the data kernel of the device queue data source actor:
// a CNode whose kernel name is kGetNextOpName. Throws if node is null.
bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName);
}

// Judge whether the node is fed by the host queue data source actor:
// a Parameter that is not a weight. Throws if node is null.
bool IsHostQueueDSActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<Parameter>()) {
    return false;
  }
  return !AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>());
}

// Judge whether the node is executed by a kernel actor:
// any CNode except the one named kGetNextOpName. Throws if node is null.
bool IsKernelActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  return AnfAlgo::GetCNodeName(node) != kGetNextOpName;
}

// Judge whether the device tensor of the node is persistent or not.
// Value nodes and weight parameters are treated as persistent. Throws if node is null.
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->isa<ValueNode>() ||
         (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()));
}

// Look up a kernel actor by name. Returns a non-owning pointer, or nullptr when the name is unknown.
KernelActor *FindKernelActor(const std::unordered_map<std::string, KernelActorPtr> &kernel_actors_map,
                             const std::string &name) {
  const auto found = kernel_actors_map.find(name);
  return (found == kernel_actors_map.end()) ? nullptr : found->second.get();
}

// Return the first data source actor whose name contains "_DeviceQueueDataSourceActor",
// or nullptr when there is none. Throws if the list holds a null actor before a match is found.
DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  for (const auto &candidate : data_source_actors) {
    MS_EXCEPTION_IF_NULL(candidate);
    const bool name_matches = candidate->GetAID().Name().find("_DeviceQueueDataSourceActor") != std::string::npos;
    if (!name_matches) {
      continue;
    }
    // The actor is selected by name only; the dynamic_cast yields nullptr if the type does not match.
    return dynamic_cast<DeviceQueueDataSourceActor *>(candidate.get());
  }
  return nullptr;
}

// Return the first data source actor whose name contains "_HostQueueDataSourceActor",
// or nullptr when there is none. Throws if the list holds a null actor before a match is found.
HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  for (const auto &candidate : data_source_actors) {
    MS_EXCEPTION_IF_NULL(candidate);
    const bool name_matches = candidate->GetAID().Name().find("_HostQueueDataSourceActor") != std::string::npos;
    if (!name_matches) {
      continue;
    }
    // The actor is selected by name only; the dynamic_cast yields nullptr if the type does not match.
    return dynamic_cast<HostQueueDataSourceActor *>(candidate.get());
  }
  return nullptr;
}

// Update the reference count of device tensor by the output index of node:
// increase the count by one and reset the used count. Throws if the node or its device tensor is null.
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto output_addr = AnfAlgo::GetMutableOutputAddr(node, output_idx);
  MS_EXCEPTION_IF_NULL(output_addr);

  output_addr->IncreaseRefCount();
  output_addr->ResetRefCountUsed();
}
} // namespace

// Transform the kernel graph to an actor set: persist the device tensors of value nodes and weights,
// build the actors and link them according to the execution strategy.
// Returns a non-owning pointer; the actor set is owned by graph_to_actors_.
// input_tensors is not used by this implementation.
// If the graph has been transformed before, the existing actor set is returned and the strategy
// argument of this call is ignored.
ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                    const std::vector<tensor::TensorPtr> *input_tensors,
                                    GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(graph);

  // emplace() never replaces an existing entry. Building a second actor set for the same graph would
  // therefore return a pointer to an object destroyed on return, and Link() would read the old set
  // through Fetch(). Reuse the registered actor set instead.
  const auto iter = graph_to_actors_.find(graph);
  if (iter != graph_to_actors_.end()) {
    MS_LOG(WARNING) << "The graph has already been transformed to actors, reuse them: " << graph->ToString();
    return iter->second.get();
  }

  PersistDeviceTensor(graph);
  auto actor_set = Build(graph, device_context);
  MS_EXCEPTION_IF_NULL(actor_set);
  // Register before linking: the link stage fetches the actor set by graph.
  (void)graph_to_actors_.emplace(graph, actor_set);
  Link(actor_set.get(), graph, strategy);
  return actor_set.get();
}

void GraphScheduler::Schedule(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
auto actorMgr = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actorMgr);

// Schedule dats source actors.
for (auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
auto base_actor = static_cast<ActorReference>(data_source_actor);
(void)actorMgr->Spawn(base_actor);
}

// Schedule kernel actors.
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
auto base_actor = static_cast<ActorReference>(kernel_actor);
(void)actorMgr->Spawn(base_actor);
}

// Schedule loop count actor.
if (actor_set->loop_count_actor_ != nullptr) {
auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
(void)actorMgr->Spawn(base_actor);
}
}

bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
// Construct OpContext.
OpContext<DeviceTensor> op_context;
auto sequential_num = uuids::RandomBasedGenerator::GenerateRandomUuid();
op_context.sequential_num_ = &sequential_num;
Promise<int> result;
op_context.results_->push_back(result);

// Trigger no input kernel actor running.
for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
}

// Trigger data source actor running.
for (auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context);
}

// Trigger kernel actor running in the step execution strategy.
if (strategy == GraphExecutionStrategy::kStep) {
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
Async(kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
}
}

// Get the run result.
auto result_future = result.GetFuture();
result_future.Wait();
if (!result_future.IsOK()) {
return false;
}
return true;
}

// Fetch the actor set registered for the kernel graph.
// Returns a non-owning pointer, or nullptr (after logging an error) when the graph is unknown.
ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const {
  MS_EXCEPTION_IF_NULL(graph);
  const auto found = graph_to_actors_.find(graph);
  if (found == graph_to_actors_.end()) {
    MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString();
    return nullptr;
  }
  return found->second.get();
}

// Transform the nodes of graph to actors: create a new actor set and fill it with the data source
// actors, the kernel actors and the loop count actor. The no input kernel actors are not filled here.
ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) {
  auto actor_set = std::make_shared<ActorSet>();
  MS_EXCEPTION_IF_NULL(actor_set);

  // The set is freshly created, so assigning the built lists is the same as swapping them in.
  actor_set->data_source_actors_ = BuildDataSourceActor(graph);
  actor_set->kernel_actors_ = BuildKernelActor(graph, device_context);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph);

  return actor_set;
}

// Link actors to DAG through the edge connection of graph and graph execution strategy.
// For every kernel of the execution order that is run by a kernel actor, the control arrows and
// one data arrow per input are linked; afterwards the no input kernel actors are collected and the
// loop count actor is linked.
void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(graph);

  // Index the kernel actors by actor name; kernels are looked up below by their full name with scope.
  std::unordered_map<std::string, KernelActorPtr> name_to_kernel_actor;
  for (auto &kernel_actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    name_to_kernel_actor.emplace(kernel_actor->GetAID().Name(), kernel_actor);
  }

  // Foreach the execution order to link the actors.
  const auto &execution_order = graph->execution_order();
  for (const auto &kernel : execution_order) {
    if (!IsKernelActor(kernel)) {
      continue;
    }
    auto *to_actor = FindKernelActor(name_to_kernel_actor, kernel->fullname_with_scope());
    // Link the control arrows of kernel actor.
    LinkControlArrowForKernelActor(to_actor, actor_set->loop_count_actor_.get(), graph, strategy);

    const size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
      const KernelWithIndex from_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, input_idx);
      const KernelWithIndex to_with_input_idx = std::make_pair(kernel, input_idx);
      const auto &from_kernel = from_with_output_idx.first;

      if (IsDeviceQueueDSActor(from_kernel)) {
        // Link the data arrows of device queue data source actor.
        auto *device_ds_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForDeviceDSActor(device_ds_actor, to_actor, from_with_output_idx, to_with_input_idx);
        continue;
      }
      if (IsHostQueueDSActor(from_kernel)) {
        // Link the data arrows of host queue data source actor.
        auto *host_ds_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForHostDSActor(host_ds_actor, to_actor, from_with_output_idx, to_with_input_idx);
        continue;
      }
      // Link the data arrows of kernel actor. The source actor may be null here, e.g. when the
      // input comes from a node that has no kernel actor; the callee decides how to handle that.
      auto *from_actor = FindKernelActor(name_to_kernel_actor, from_kernel->fullname_with_scope());
      LinkDataArrowForKernelActor(from_actor, to_actor, from_with_output_idx, to_with_input_idx);
    }
  }

  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(graph);

  // Link the control arrows of loop count actor, which depends on the no input kernel actors.
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
}

std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<DataSourceActorPtr> data_source_actors;

// Build host queue data source actor.
HostQueueDSActorPtr host_queue_ds_actor = nullptr;
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
auto host_queue = std::make_shared<HostTensorQueue>();
graph_to_host_queue_.emplace(graph, host_queue);
host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor);
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
}
}

// Build device queue data source actor.
auto execution_order = graph->execution_order();
auto iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
}
return data_source_actors;
}

std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<KernelActorPtr> kernel_actors;

auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context);
MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor);
}
}
return kernel_actors;
}

std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<KernelActorPtr> no_input_kernel_actors;

auto actor_set = Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
no_input_kernel_actors.emplace_back(kernel_actor);
}
}
return no_input_kernel_actors;
}

// Build the loop count actor of the graph. The loop count is read from the iter_num of ConfigManager.
LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto iter_num = ConfigManager::GetInstance().iter_num();
  auto name = graph->ToString() + "_" + "LoopCountActor";

  auto actor = std::make_shared<LoopCountActor>(name, iter_num);
  MS_EXCEPTION_IF_NULL(actor);
  return actor;
}

// Link a data arrow from the device queue data source actor to a kernel actor.
// The arrow goes from the output index of the source kernel to the input index of the target kernel;
// the data input number of the target actor and the reference count of the source output are updated.
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
                                                   KernelWithIndex from_kernel_with_output_idx,
                                                   KernelWithIndex to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  const auto &source_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(source_kernel);

  auto output_index = from_kernel_with_output_idx.second;
  auto input_index = to_kernel_with_input_idx.second;
  auto target_aid = to_actor->GetAID();

  from_actor->output_op_arrows_.emplace_back(std::make_shared<OpArrow>(output_index, target_aid, input_index));
  ++to_actor->input_datas_num_;

  // Update the reference count of device tensor.
  UpdateRefCount(source_kernel, output_index);
}

// Link a data arrow from the host queue data source actor to a kernel actor.
// The output index of the arrow is the position of the source node in the data nodes of the host
// actor. The data input number of the target actor and the reference count of the source output
// are updated. Raises an exception when the source node is not a data node of the host actor.
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
                                                 KernelWithIndex from_kernel_with_output_idx,
                                                 KernelWithIndex to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);

  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  auto from_output_index = from_kernel_with_output_idx.second;
  auto to_input_index = to_kernel_with_input_idx.second;

  // Search the data nodes in place; this function runs once per linked input, so copying the list
  // each time is wasted work.
  const auto &data_nodes = from_actor->data_nodes_;
  const auto iter = std::find(data_nodes.begin(), data_nodes.end(), from_kernel);
  if (iter == data_nodes.end()) {
    MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
  }
  // The distance is non-negative here, so convert it straight to size_t without narrowing to int.
  auto position = static_cast<size_t>(std::distance(data_nodes.begin(), iter));
  auto to_aid = to_actor->GetAID();
  auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
  from_actor->output_op_arrows_.emplace_back(op_arrow);
  to_actor->input_datas_num_++;

  // Update the reference count of device tensor.
  UpdateRefCount(from_kernel, from_output_index);
}

// Link the data input of a kernel actor that comes from a persistent node or from another kernel actor.
// - Persistent source (value node or weight): only the key (input index, node address) is recorded in
//   the target actor; no arrow is created and from_actor is not used.
// - Kernel source: a data arrow is added to from_actor, the data input number of the target actor is
//   increased and the reference count of the source output is updated.
// - Any other source is ignored.
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
                                                 KernelWithIndex from_kernel_with_output_idx,
                                                 KernelWithIndex to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(to_actor);
  const auto &source_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(source_kernel);
  auto output_index = from_kernel_with_output_idx.second;
  auto input_index = to_kernel_with_input_idx.second;

  if (IsPersistentDeviceTensor(source_kernel)) {
    to_actor->device_tensor_store_keys_.emplace_back(input_index, static_cast<void *>(source_kernel.get()));
    return;
  }
  if (!IsKernelActor(source_kernel)) {
    return;
  }

  MS_EXCEPTION_IF_NULL(from_actor);
  auto target_aid = to_actor->GetAID();
  from_actor->output_op_arrows_.emplace_back(std::make_shared<OpArrow>(output_index, target_aid, input_index));
  ++to_actor->input_datas_num_;

  // Update the reference count of device tensor.
  UpdateRefCount(source_kernel, output_index);
}

// Link the control arrows of a kernel actor.
// - In the step strategy the kernel actor expects one more input control.
// - When opt::IsNotRealUsedByOthers reports true for the kernel, a control arrow is added from the
//   kernel actor to the loop count actor.
void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor,
                                                    const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);

  const bool is_step_strategy = (strategy == GraphExecutionStrategy::kStep);
  if (is_step_strategy) {
    ++from_actor->input_controls_num_;
  }

  if (!opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) {
    return;
  }
  from_actor->output_op_controls_.emplace_back(to_actor->GetAID());
  ++to_actor->input_controls_num_;
}

// Record in the loop count actor the AIDs of the data source actors and of the no input kernel
// actors of the graph's registered actor set. Throws if the graph has no registered actor set.
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(loop_count_actor);
  auto registered_set = Fetch(graph);
  MS_EXCEPTION_IF_NULL(registered_set);

  // Set the source data actor.
  auto &ds_aids = loop_count_actor->data_source_aids_;
  for (auto &ds_actor : registered_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(ds_actor);
    ds_aids.emplace_back(ds_actor->GetAID());
  }

  // Set the no input kernel actor.
  auto &kernel_aids = loop_count_actor->no_input_kernel_aids_;
  for (auto &kernel_actor : registered_set->no_input_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    kernel_aids.emplace_back(kernel_actor->GetAID());
  }
}

void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);

for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_ref_count(SIZE_MAX);
device_tensor->ResetRefCountUsed();
}

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
device_tensor->set_ref_count(SIZE_MAX);
device_tensor->ResetRefCountUsed();
}
}
}

} // namespace runtime
} // namespace mindspore

+ 30
- 20
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -31,12 +31,11 @@
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;

enum class GraphExecutionStrategy {
// The actor running is triggered only by data.
kPipeline,
// The actor running need be triggered by control in addition.
kStep
kPipeline, // The actor running is triggered only by data.
kStep // The actor running need be triggered by control in addition.
};

// The actor set generated by graph transformer is the execution unit of actor runtime.
@@ -57,49 +56,60 @@ using ActorSetPtr = std::shared_ptr<ActorSet>;

class GraphScheduler {
public:
GraphScheduler() = default;
virtual ~GraphScheduler() = default;

static GraphScheduler &GetInstance() {
static GraphScheduler instance;
return instance;
}

// Transform graph to actor DAG, contains build and link.
ActorSetPtr Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
const std::vector<tensor::TensorPtr> *input_tensors = nullptr,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
const std::vector<tensor::TensorPtr> *input_tensors = nullptr,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);

// Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
// will be supported in the future.
void Schedule(const ActorSetPtr &actor_set);
void Schedule(const ActorSet *actor_set);

// The processing entry of actors running.
bool Run(const ActorSetPtr &actor_set);
bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);

// Fetch the actor set by kernel graph.
ActorSet *Fetch(const KernelGraphPtr &graph) const;

private:
GraphScheduler() = default;
~GraphScheduler() = default;
DISABLE_COPY_AND_ASSIGN(GraphScheduler);

// Transform the nodes of graph to actors.
ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
// Link actors to DAG through the edge connection of graph and graph execution strategy.
void Link(ActorSetPtr actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);

// The processing of actors build.
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph);
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);

// The processing of actors link.
void LinkDataSourceActor(std::vector<DataSourceActorPtr> actors, const KernelGraphPtr &graph);
void LinkKernelActor(std::vector<KernelActorPtr> actors, const KernelGraphPtr &graph,
GraphExecutionStrategy strategy);
void LinkLoopCountActor(LoopCountActorPtr actor, const KernelGraphPtr &graph);
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_to_kernel_with_input_idx);
void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
GraphExecutionStrategy strategy);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);

// Persist device tensors of graph's some nodes(such as weights and value nodes).
void PersistDeviceTensor(const KernelGraphPtr &graph);
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node);

std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actor_;
std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_;
std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;

// The second element of pair represents the output index of kernel actor corresponding to the device tensor.


+ 28
- 27
mindspore/core/utils/log_adapter.h View File

@@ -102,33 +102,34 @@ constexpr std::ostream &operator<<(std::ostream &stream, const T &value) {
enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION };

enum SubModuleId : int {
SM_UNKNOWN = 0, // unknown submodule
SM_CORE, // core
SM_ANALYZER, // static analyzer
SM_COMMON, // common
SM_DEBUG, // debug
SM_DEVICE, // device
SM_GE_ADPT, // ge adapter
SM_IR, // IR
SM_KERNEL, // kernel
SM_MD, // MindData
SM_ME, // MindExpression
SM_EXPRESS, // EXPRESS_IR
SM_OPTIMIZER, // optimzer
SM_PARALLEL, // parallel
SM_PARSER, // parser
SM_PIPELINE, // ME pipeline
SM_PRE_ACT, // pre-activate
SM_PYNATIVE, // PyNative
SM_SESSION, // session
SM_UTILS, // utils
SM_VM, // VM
SM_PROFILER, // profiler
SM_PS, // Parameter Server
SM_LITE, // LITE
SM_HCCL_ADPT, // Hccl Adapter
SM_MINDQUANTUM, // MindQuantum
NUM_SUBMODUES // number of submodules
SM_UNKNOWN = 0, // unknown submodule
SM_CORE, // core
SM_ANALYZER, // static analyzer
SM_COMMON, // common
SM_DEBUG, // debug
SM_DEVICE, // device
SM_GE_ADPT, // ge adapter
SM_IR, // IR
SM_KERNEL, // kernel
SM_MD, // MindData
SM_ME, // MindExpression
SM_EXPRESS, // EXPRESS_IR
SM_OPTIMIZER, // optimzer
SM_PARALLEL, // parallel
SM_PARSER, // parser
SM_PIPELINE, // ME pipeline
SM_PRE_ACT, // pre-activate
SM_PYNATIVE, // PyNative
SM_SESSION, // session
SM_UTILS, // utils
SM_VM, // VM
SM_PROFILER, // profiler
SM_PS, // Parameter Server
SM_LITE, // LITE
SM_HCCL_ADPT, // Hccl Adapter
SM_MINDQUANTUM, // MindQuantum
SM_RUNTIME_FRAMEWORK, // Runtime framework
NUM_SUBMODUES // number of submodules
};

#ifndef SUBMODULE_ID


Loading…
Cancel
Save