Browse Source

!24376 unified runtime supports the scheduler of graph sink

Merge pull request !24376 from limingqi107/new_actor_runtime
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
9d50e0c48a
12 changed files with 452 additions and 172 deletions
  1. +27
    -0
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc
  2. +6
    -1
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.h
  3. +2
    -0
      mindspore/ccsrc/runtime/framework/actor/actor_common.h
  4. +1
    -7
      mindspore/ccsrc/runtime/framework/actor/copy_actor.cc
  5. +1
    -1
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
  6. +2
    -26
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc
  7. +0
    -5
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.h
  8. +2
    -11
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
  9. +77
    -0
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc
  10. +73
    -0
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h
  11. +235
    -106
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  12. +26
    -15
      mindspore/ccsrc/runtime/framework/graph_scheduler.h

+ 27
- 0
mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc View File

@@ -15,6 +15,7 @@
*/

#include "runtime/framework/actor/abstract_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "utils/log_adapter.h"

namespace mindspore {
@@ -65,5 +66,31 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
}
}
}

// Forward every graph output of this actor to the output actor for collection.
// Precondition: output_nodes_ and output_result_arrows_ are index-aligned.
void AbstractActor::SendOutputResult(OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(context);
  // Each result arrow must correspond to exactly one recorded output node.
  if (output_result_arrows_.size() != output_nodes_.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output result arrows is not equal to the output nodes.");
  }

  for (size_t i = 0; i < output_result_arrows_.size(); ++i) {
    const auto &arrow = output_result_arrows_[i];
    MS_EXCEPTION_IF_NULL(arrow);
    Async(arrow->to_op_id_, &OutputActor::CollectOutput, output_nodes_[i], arrow->from_output_index_,
          arrow->to_input_index_, context);
  }
}

// Notify every control successor of this actor that it has finished running.
void AbstractActor::SendOutputControl(OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(context);

  // Nothing to do when this actor has no control successors.
  if (output_control_arrows_.empty()) {
    return;
  }
  auto from_aid = const_cast<AID *>(&GetAID());
  for (auto &to_aid : output_control_arrows_) {
    Async(to_aid, &OpActor::RunOpControl, from_aid, context);
  }
}
} // namespace runtime
} // namespace mindspore

+ 6
- 1
mindspore/ccsrc/runtime/framework/actor/abstract_actor.h View File

@@ -55,6 +55,10 @@ class AbstractActor : public OpActor<DeviceTensor> {
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const;
// Erase input data and input controls when finish actor running.
void EraseInput(const OpContext<DeviceTensor> *const context);
// Send the output result by output_result_arrows_.
void SendOutputResult(OpContext<DeviceTensor> *const context) const;
// Send the output control by output_control_arrows_.
void SendOutputControl(OpContext<DeviceTensor> *const context) const;

KernelTransformType type_;

@@ -64,7 +68,8 @@ class AbstractActor : public OpActor<DeviceTensor> {
// The id of recorder actor. Send message to it for recording info.
const AID *recorder_aid_;

// The output result arrows of graph output.
// The output nodes and output result arrows of graph output.
std::vector<AnfNodePtr> output_nodes_;
std::vector<DataArrowPtr> output_result_arrows_;

// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.


+ 2
- 0
mindspore/ccsrc/runtime/framework/actor/actor_common.h View File

@@ -51,6 +51,8 @@ enum class KernelTransformType {
kDeviceDataSourceActor,
kHostDataSourceActor,
kKernelActor,
// Super kernel actor represents the sink executing of graph which is the combination of kernels.
kSuperKernelActor,
kCopyActor,
kLoopCountActor,
kOutputActor,


+ 1
- 7
mindspore/ccsrc/runtime/framework/actor/copy_actor.cc View File

@@ -21,7 +21,6 @@

namespace mindspore {
namespace runtime {

const size_t kInputDeviceContextIndex = 0;
const size_t kOutputDeviceContextIndex = 1;

@@ -162,12 +161,7 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
}

// Send output control.
if (output_control_arrows_.size() > 0) {
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
SendOutputControl(context);
}
} // namespace runtime
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -181,7 +181,7 @@ void DataPrepareActor::SendOutput(OpContext<DeviceTensor> *const context) {

auto source_aid = const_cast<AID *>(&GetAID());
for (auto &kernel_aid : no_input_kernel_aids_) {
Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
Async(kernel_aid, &OpActor::RunOpControl, source_aid, context);
}

// Trigger loop count actor running when there are no data source actor and kernel actor.


+ 2
- 26
mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc View File

@@ -72,7 +72,7 @@ void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {

// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
SendResult(context);
SendOutputResult(context);

// 2.Send output data.
const auto &output_device_tensors = buffers_.front();
@@ -89,12 +89,7 @@ void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
}

// 3.Send output control.
if (output_control_arrows_.size() > 0) {
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
SendOutputControl(context);

// 4.Send recorder info.
if (recorder_aid_ != nullptr) {
@@ -202,14 +197,6 @@ void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const co
SendOutput(context);
}

void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_kernel_, result_arrow->from_output_index_,
result_arrow->to_input_index_, context);
}
}

void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
if (recorder_aid_ != nullptr) {
MS_EXCEPTION_IF_NULL(data_kernel_);
@@ -300,17 +287,6 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
SendOutput(context);
}

void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
if (IntToSize(result_arrow->from_output_index_) >= data_nodes_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range.");
}
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_nodes_[result_arrow->from_output_index_], 0,
result_arrow->to_input_index_, context);
}
}

size_t HostQueueDataSourceActor::FetchNodePosition(const AnfNodePtr &data_node) const {
MS_EXCEPTION_IF_NULL(data_node);
const auto &iter = data_node_position_map_.find(data_node);


+ 0
- 5
mindspore/ccsrc/runtime/framework/actor/data_source_actor.h View File

@@ -63,9 +63,6 @@ class DataSourceActor : public DebugAwareActor {
// Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching.
virtual void FillDataBuffer() = 0;

// Send output result of graph output to output actor.
virtual void SendResult(OpContext<DeviceTensor> *const context) = 0;

// Send recorder info to recorder actor, only the device queue data source actor need.
virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {}

@@ -102,7 +99,6 @@ class DeviceQueueDataSourceActor : public DataSourceActor {

protected:
void FillDataBuffer() override;
void SendResult(OpContext<DeviceTensor> *const context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) override;

private:
@@ -136,7 +132,6 @@ class HostQueueDataSourceActor : public DataSourceActor {

protected:
void FillDataBuffer() override;
void SendResult(OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;


+ 2
- 11
mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc View File

@@ -422,11 +422,7 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {

// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, kernel_, result_arrow->from_output_index_,
result_arrow->to_input_index_, context);
}
SendOutputResult(context);

// 2.Send output data.
for (auto &output_data : output_data_) {
@@ -435,12 +431,7 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
}

// 3.Send output control.
if (output_control_arrows_.size() > 0) {
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
SendOutputControl(context);

// 4.Send recorder info.
if (recorder_aid_ != nullptr) {


+ 77
- 0
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc View File

@@ -0,0 +1,77 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
void SuperKernelActor::Init() {
MS_EXCEPTION_IF_NULL(graph_);
// Check device contexts number.
if (device_contexts_.size() != device::kDeviceContextsNumOne) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}

// Set the number of actor running dependent messages.
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
}

// Receive one input data message; when every expected input (data and control)
// for the current step has arrived, launch the whole sink graph at once.
void SuperKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);

  // Record the arrived data under the current step's sequential number.
  (void)input_op_datas_[context->sequential_num_].emplace_back(input_data);
  if (!CheckRunningCondition(context)) {
    return;
  }

  // NOTE(review): LaunchGraph's return value is ignored here — confirm whether
  // a failed launch needs to be propagated into the op context.
  device_contexts_[0]->LaunchGraph(graph_);

  // The inputs are invalid after the launch finishes, so erase them before
  // sending the outputs downstream.
  EraseInput(context);
  SendOutput(context);
}

// Receive one input control message; when every expected input (data and
// control) for the current step has arrived, launch the whole sink graph.
void SuperKernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);

  // Record the arrived control under the current step's sequential number.
  (void)input_op_controls_[context->sequential_num_].emplace_back(input_control);
  if (!CheckRunningCondition(context)) {
    return;
  }

  // NOTE(review): LaunchGraph's return value is ignored here — confirm whether
  // a failed launch needs to be propagated into the op context.
  device_contexts_[0]->LaunchGraph(graph_);

  // The inputs are invalid after the launch finishes, so erase them before
  // sending the outputs downstream.
  EraseInput(context);
  SendOutput(context);
}

// Propagate this actor's results: graph outputs first, then control arrows.
// When the actor has no successors at all, it terminates the step itself.
void SuperKernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(context);
  SendOutputResult(context);
  SendOutputControl(context);

  // An actor with no data, control, or result successors marks success here.
  const bool has_output = (output_data_arrows_.size() != 0) || (output_control_arrows_.size() != 0) ||
                          (output_result_arrows_.size() != 0);
  if (!has_output) {
    SET_OPCONTEXT_SUCCESS_RET((*context));
  }
}
} // namespace runtime
} // namespace mindspore

+ 73
- 0
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h View File

@@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/debug_aware_actor.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h"
#include "backend/kernel_compiler/kernel.h"
#include "ir/anf.h"
#include "ir/tensor.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::device::KernelInfo;
using mindspore::kernel::Address;
using mindspore::kernel::KernelLaunchInfo;
using mindspore::tensor::TensorPtr;

// The super kernel actor represents the sink executing of a graph, i.e. the
// whole kernel graph is launched on the device as one combined unit instead of
// one actor per kernel.
class SuperKernelActor : public DebugAwareActor {
 public:
  // `graph` is the sink kernel graph this actor launches; `device_context` is
  // the single device context it runs on (checked in Init()).
  SuperKernelActor(const std::string &name, const KernelGraphPtr &graph, const DeviceContext *device_context,
                   const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
      : DebugAwareActor(name, KernelTransformType::kSuperKernelActor, recorder_aid, memory_manager_aid, debug_aid),
        graph_(graph) {
    (void)device_contexts_.emplace_back(device_context);
  }
  ~SuperKernelActor() override = default;

  void Init() override;

  // The super kernel actor run when receive the input data.
  void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;

  // The super kernel actor run when receive the input control.
  void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;

 private:
  friend class GraphScheduler;

  // Send output data and output controls when finish kernel launch.
  void SendOutput(OpContext<DeviceTensor> *const context) const;

  // The sink kernel graph launched by this actor.
  KernelGraphPtr graph_;
};

using SuperKernelActorPtr = std::shared_ptr<SuperKernelActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_

+ 235
- 106
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -194,7 +194,7 @@ void GraphScheduler::Clear() {
actor_name_to_actor_.clear();
}

using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, KernelActor *const, const KernelWithIndex &,
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
const KernelWithIndex &, const KernelGraphPtr &);
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;

@@ -364,6 +364,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
auto host_queue = std::make_shared<HostTensorQueue>();
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
actor_set->data_prepare_actor_ =
@@ -386,41 +387,27 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
if (origin_output_with_index.first == nullptr) {
MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " with index: " << output_with_index.second << " has no actor.";
<< " with index: " << output_with_index.second << " has no front node.";
continue;
}

auto actor_output_index = output_with_index.second;
OpActor<DeviceTensor> *actor = nullptr;
if (IsKernelActor(output_kernel, graph_compiler_info.strategy_)) {
actor = FetchActor(output_kernel->fullname_with_scope());
} else if (IsDeviceQueueDSActor(output_kernel, graph_compiler_info.strategy_)) {
std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
actor = FetchActor(actor_name);
} else if (IsHostQueueDSActor(output_kernel, graph, graph_compiler_info.origin_parameters_order_,
graph_compiler_info.strategy_)) {
actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor");
const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(host_ds_actor);
// Get the position of output kernel in the data source actor.
actor_output_index = host_ds_actor->FetchNodePosition(output_kernel);
} else if (IsPersistentDeviceTensor(output_kernel)) {
auto kernel_type = KernelTransformType::kUnknown;
std::string kernel_name = "";
FetchKernelTransformTypeAndName(output_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
if (kernel_name == "") {
MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " is device tensor store.";
continue;
} else {
MS_LOG(INFO) << "Ignore the internal parameter node:" << output_kernel->DebugString();
<< " with index:" << output_with_index.second
<< " is not actor, and the kernel type is:" << kernel_type;
continue;
}

MS_EXCEPTION_IF_NULL(actor);
auto output_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
MS_EXCEPTION_IF_NULL(output_actor);
(void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index));
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
<< " with index:" << actor_output_index
<< " with index: " << output_with_index.second << " to actor:" << output_actor->GetAID().Name()
<< ", from front node:" << origin_output_with_index.first->fullname_with_scope()
<< " with index: " << origin_output_with_index.second;
(void)graph_output_to_actor_.emplace(origin_output_with_index,
GraphOutputPair(dynamic_cast<AbstractActor *>(actor), actor_output_index));
}
}
}
@@ -429,44 +416,14 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActor *> auto_monad_actors;
std::vector<CNodePtr> communication_nodes;
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};

// Foreach the execution order to link the actors.
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
const auto &graph = graph_compiler_info.graphs_[index];
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::IsCommunicationOp(kernel)) {
(void)communication_nodes.emplace_back(kernel);
}
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
continue;
}
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(kernel_actor);

for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
}
if (HasAbstractMonad(input_node)) {
(void)auto_monad_actors.emplace_back(kernel_actor);
continue; // No data arrow for monad input.
}

KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
// The gather of linking data arrows of kernel by the different from kernel type.
LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
if (graph->is_sink()) {
LinkDataArrowInSinkMode(graph, graph_compiler_info);
} else {
LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes);
}
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
LinkControlArrowBySendRecvNodes(graph);
}

// Link the arrow in the control flow scene.
@@ -523,22 +480,25 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
}
}

// Build device queue data source actor.
const auto &execution_order = graph->execution_order();
const auto &iter =
std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
});
if (iter != execution_order.end()) {
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
InsertActor(device_queue_ds_actor.get());
(void)data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
// The graph sink mode has no device queue data source actor.
if (!graph->is_sink()) {
// Build device queue data source actor.
const auto &execution_order = graph->execution_order();
const auto &iter =
std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
});
if (iter != execution_order.end()) {
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
InsertActor(device_queue_ds_actor.get());
(void)data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
}
}
}

@@ -588,8 +548,11 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
if (graph->is_sink()) {
continue;
}

auto execution_order = graph->execution_order();
// Single op graph in step mode, kernel actor executes synchronously.
bool is_single_op_graph = execution_order.size() == 1;
GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
@@ -615,6 +578,27 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
return kernel_actors;
}

// Create one super kernel actor for every sink graph in the compiler info.
// Non-sink graphs are executed by per-kernel actors and are skipped here.
std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<SuperKernelActorPtr> super_kernel_actors;

  for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
    const auto &graph = graph_compiler_info.graphs_[index];
    const auto &device_context = graph_compiler_info.device_contexts_[index];
    MS_EXCEPTION_IF_NULL(graph);
    if (!graph->is_sink()) {
      continue;
    }

    // The actor name encodes the graph so that LinkDataArrowInSinkMode can
    // fetch it back by name.
    auto actor_name = graph->ToString() + "_SuperKernelActor";
    auto super_kernel_actor =
      std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, nullptr, nullptr);
    MS_EXCEPTION_IF_NULL(super_kernel_actor);
    InsertActor(super_kernel_actor.get());
    (void)super_kernel_actors.emplace_back(super_kernel_actor);
  }
  return super_kernel_actors;
}

LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
return nullptr;
@@ -658,7 +642,6 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf
if (iter != data_source_actors.end()) {
host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
}

auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
&graph_compiler_info, host_queue_ds_actor, host_queue);
@@ -670,12 +653,15 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
const auto &graph = graph_compiler_info.graphs_[index];
MS_EXCEPTION_IF_NULL(graph);
if (graph->is_sink()) {
continue;
}

auto &execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!AnfAlgo::IsCommunicationOp(kernel)) {
continue;
}

auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
auto value = std::make_pair(false, false);
if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
@@ -695,10 +681,17 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf
return data_prepare_actor;
}

std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
GraphExecutionStrategy strategy) {
std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActorPtr> no_input_kernel_actors;
std::vector<AbstractActorPtr> no_input_kernel_actors;

for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
MS_EXCEPTION_IF_NULL(super_kernel_actor);
if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) {
(void)no_input_kernel_actors.emplace_back(super_kernel_actor);
}
}

for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
@@ -730,6 +723,78 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
return no_input_kernel_actors;
}

// Link the data arrows for a sink graph: every graph input that is produced by
// a host data source actor gets an arrow into the graph's super kernel actor.
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph,
                                             const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(graph);
  // The whole sink graph is executed by a single super kernel actor, fetched
  // by the name chosen in BuildSuperKernelActor.
  const auto to_actor_name = graph->ToString() + "_SuperKernelActor";
  auto *to_actor = dynamic_cast<SuperKernelActor *>(FetchActor(to_actor_name));
  MS_EXCEPTION_IF_NULL(to_actor);

  for (const auto &input_node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input_node);
    auto kernel_type = KernelTransformType::kUnknown;
    std::string kernel_name = "";
    FetchKernelTransformTypeAndName(input_node, graph, graph_compiler_info, &kernel_type, &kernel_name);

    KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0);
    KernelWithIndex to_kernel_with_input_idx = std::make_pair(nullptr, 0);
    // Only host data source inputs have a producing actor in sink mode.
    AbstractActor *from_actor = nullptr;
    if (kernel_type == KernelTransformType::kHostDataSourceActor) {
      from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
    }

    if ((from_actor != nullptr) && (kKernelTypeToLinkFunc.count(kernel_type) > 0)) {
      (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
                                                  to_kernel_with_input_idx, graph);
    }
  }
}

// Link the data arrows for a non-sink graph: walks the execution order, links
// every kernel actor's inputs (data arrows, auto-monad control arrows), and
// collects communication nodes and monad-consuming actors for the caller.
void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
const GraphCompilerInfo &graph_compiler_info,
std::vector<KernelActor *> *const auto_monad_actors,
std::vector<CNodePtr> *const communication_nodes) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(auto_monad_actors);
MS_EXCEPTION_IF_NULL(communication_nodes);

// Inputs wrapped by these primitives carry auto-monad ordering dependencies.
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
auto &execution_order = graph->execution_order();
// Foreach the execution order to link the actors.
for (const auto &kernel : execution_order) {
MS_EXCEPTION_IF_NULL(kernel);
// Communication ops are reported back so the caller can serialize them.
if (AnfAlgo::IsCommunicationOp(kernel)) {
(void)communication_nodes->emplace_back(kernel);
}
// Skip nodes that have no kernel actor under the current strategy.
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
continue;
}
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(kernel_actor);

for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
}
if (HasAbstractMonad(input_node)) {
(void)auto_monad_actors->emplace_back(kernel_actor);
continue; // No data arrow for monad input.
}

// Resolve the real producing kernel behind e.g. tuple-get-item wrappers.
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
// The gather of linking data arrows of kernel by the different from kernel type.
LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
}

// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
LinkControlArrowBySendRecvNodes(graph);
}

void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx) {
@@ -784,7 +849,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
}
}

void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx,
const KernelGraphPtr &graph) {
@@ -797,7 +862,7 @@ void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, Ker
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
}

void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, KernelActor *to_actor,
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx,
const KernelGraphPtr &graph) {
@@ -831,13 +896,16 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Ker
}
auto actor_pair = graph_output_to_actor_[front_output_with_index];
MS_EXCEPTION_IF_NULL(actor_pair.first);
MS_EXCEPTION_IF_NULL(actor_pair.second.first);
MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
<< " with index:" << front_output_with_index.second
<< ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
<< ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
<< ", from actor:" << actor_pair.first->GetAID().Name()
<< " node:" << actor_pair.second.first->fullname_with_scope()
<< " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name()
<< " with index:" << to_kernel_with_input_idx.second;
real_from_actor = actor_pair.first;
real_from_kernel_with_output_idx = KernelWithIndex(nullptr, actor_pair.second);
real_from_kernel_with_output_idx = actor_pair.second;
kernel_type = actor_pair.first->type_;
}

@@ -848,7 +916,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Ker
to_kernel_with_input_idx, graph);
}

void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
@@ -884,7 +952,7 @@ void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor,
}
}

void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx,
const KernelGraphPtr &) {
@@ -898,7 +966,7 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_act
LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
}

void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx,
const KernelGraphPtr &) {
@@ -919,7 +987,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor
LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
}

void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx,
const KernelGraphPtr &) {
@@ -954,7 +1022,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor
LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
}

void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
@@ -1020,7 +1088,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
UpdateRefCount(copy_actor->output_.get());
}

void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);
@@ -1098,7 +1166,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
}
}

void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(skipped_node);
auto to_aid = to_actor->GetAID();
@@ -1287,6 +1355,11 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun

// Collect the actors which have no output.
std::vector<MemoryAwareActor *> no_output_actors;
for (auto &super_actor : actor_set->super_kernel_actors_) {
if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
(void)no_output_actors.emplace_back(super_actor.get());
}
}
for (auto &kernel_actor : actor_set->kernel_actors_) {
// The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
@@ -1380,16 +1453,17 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
continue;
}
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position);
auto position = from_actor->FetchNodePosition(output_with_index.first);
// If the from actor has the multi nodes, then use the real output position.
if (position != 0) {
op_arrow->from_output_index_ = SizeToInt(position);
}
(void)from_actor->output_result_arrows_.emplace_back(op_arrow);
(void)from_actor->output_nodes_.emplace_back(output_with_index.first);

// Update the real compute node in the host data source actor.
if (kernel_type == KernelTransformType::kHostDataSourceActor) {
auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
UpdateRefCount(host_queue_ds_actor->data_nodes_[position], output_with_index.second, true);
auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first);
auto real_node = host_queue_ds_actor->FetchNode(position);
from_actor->output_nodes_[from_actor->output_nodes_.size() - 1] = real_node;
UpdateRefCount(real_node, output_with_index.second, true);
}
}
}
@@ -1474,6 +1548,15 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
return true;
}

// Check the super kernel actors.
for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) {
MS_EXCEPTION_IF_NULL(super_kernel_actor);
if (super_kernel_actor->output_data_arrows_.size() + super_kernel_actor->output_control_arrows_.size() == 0) {
MS_LOG(ERROR) << super_kernel_actor->GetAID().Name() << " has no user.";
return false;
}
}

// Check the kernel actors.
for (const auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
@@ -1602,11 +1685,17 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con
const GraphCompilerInfo &graph_compiler_info,
KernelTransformType *const kernel_type,
std::string *const kernel_name) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(kernel_type);
MS_EXCEPTION_IF_NULL(kernel_name);

if (graph->is_sink() && ((node == nullptr) || node->isa<CNode>())) {
*kernel_type = KernelTransformType::kSuperKernelActor;
*kernel_name = graph->ToString() + "_SuperKernelActor";
return;
}

MS_EXCEPTION_IF_NULL(node);
if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) {
*kernel_type = KernelTransformType::kDeviceDataSourceActor;
*kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
@@ -1682,9 +1771,14 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
DumpKernelActor(kernel_actor.get(), ofs);
}

ofs << "\n\n[Super kernel actors:" << actor_set->super_kernel_actors_.size() << "]\n";
for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) {
DumpSuperKernelActor(super_kernel_actor.get(), ofs);
}

ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
DumpKernelActor(no_input_kernel_actor.get(), ofs);
DumpNoInputKernelActor(no_input_kernel_actor.get(), ofs);
}

ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
@@ -1778,11 +1872,18 @@ void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream
}
}

if (actor->output_result_arrows_.size() != actor->output_nodes_.size()) {
MS_LOG(EXCEPTION) << "The size of output result arrows is not equal to the output nodes.";
}
if (actor->output_result_arrows_.size() > 0) {
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
for (size_t i = 0; i < actor->output_result_arrows_.size(); ++i) {
auto result_arrow = actor->output_result_arrows_[i];
auto output_node = actor->output_nodes_[i];
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
MS_EXCEPTION_IF_NULL(output_node);
ofs << "\t\t\tfrom_output_node:" << output_node->fullname_with_scope()
<< "tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
@@ -1882,6 +1983,34 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of
ofs << "\n";
}

void GraphScheduler::DumpSuperKernelActor(const SuperKernelActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";

const auto &graph = actor->graph_;
MS_EXCEPTION_IF_NULL(graph);

ofs << "\t\tgraph id:" << graph->graph_id() << "\tgraphl_name:" << graph->ToString()
<< "\tis_sink:" << graph->is_sink() << "\tinputs_num:" << (graph->input_nodes()).size()
<< "\tkernels_num:" << (graph->execution_order()).size() << "\n";

DumpAbstractActor(actor, ofs);
ofs << "\n";
}

void GraphScheduler::DumpNoInputKernelActor(const AbstractActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
if (actor->type_ == KernelTransformType::kKernelActor) {
auto kernel_actor = dynamic_cast<const KernelActor *>(actor);
MS_EXCEPTION_IF_NULL(kernel_actor);
DumpKernelActor(kernel_actor, ofs);
} else if (actor->type_ == KernelTransformType::kSuperKernelActor) {
auto super_kernel_actor = dynamic_cast<const SuperKernelActor *>(actor);
MS_EXCEPTION_IF_NULL(super_kernel_actor);
DumpSuperKernelActor(super_kernel_actor, ofs);
}
}

void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
@@ -1906,7 +2035,7 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
ofs << "\tgraph id:" << graph->graph_id() << "\n";
ofs << "\tgraph id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink() << "\n";

for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);


+ 26
- 15
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -32,6 +32,7 @@
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
@@ -45,16 +46,18 @@ using mindspore::session::KernelGraph;
using mindspore::session::KernelWithIndex;
using ActorInfo = std::string;

// The second element of pair represents the output index of abstract actor corresponding to the graph output node.
using GraphOutputPair = std::pair<AbstractActor *, size_t>;
// The second element of pair represents the output node and output index of abstract actor corresponding to the graph
// output node.
using GraphOutputPair = std::pair<AbstractActor *, KernelWithIndex>;

// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
// The data prepare actor is used to prepare data for device tensor store and host tensor queue to represent the begin
// of one step.
// The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
// The kernel actor is used to receive the device tensors to luanch kernel. Specifically notice the no input
// kernel actor, it means that this actor has no input device tensor, need be triggered externally.
// The kernel actor is used to receive the device tensors to luanch kernel.
// The Super kernel actor is used to represent the sink executing of graph which is the combination of kernels.
// The no input kernel actor means that this actor has no input arrow and needs to be triggered externally.
// The switch actor is used to run different branches in the control flow scenario.
// The gather actor is used to collect the inputs of graph and send branch id to loop count actor in multi-branch
// output scenario.
@@ -67,8 +70,9 @@ struct ActorSet {
DataPrepareActorPtr data_prepare_actor_{nullptr};
std::vector<DataSourceActorPtr> data_source_actors_;
std::vector<KernelActorPtr> kernel_actors_;
std::vector<SuperKernelActorPtr> super_kernel_actors_;
// No input kernel actors need be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
std::vector<AbstractActorPtr> no_input_kernel_actors_;
std::vector<SwitchActorPtr> switch_actors_;
std::vector<GatherActorPtr> gather_actors_;
std::vector<CopyActorPtr> copy_actors_;
@@ -125,12 +129,13 @@ class GraphScheduler {
std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
DataPrepareActorPtr BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
const std::vector<DataSourceActorPtr> &data_source_actors,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);

// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
@@ -138,38 +143,42 @@ class GraphScheduler {

// The processing of actors linking.
// 1. The processing of linking data arrows.
void LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
std::vector<KernelActor *> *const auto_monad_actors,
std::vector<CNodePtr> *const communication_nodes);
// The gather of linking data arrows of kernel, it will call following functions by the different from actor type.
void LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx);
void LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx);
// Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForInternalParameter(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForDeviceTensorStore(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
void LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
void LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx, const KernelGraphPtr &graph);
// Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
void LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
void LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
const KernelWithIndex &from_kernel_with_output_idx,
const KernelWithIndex &to_kernel_with_input_idx);

// 2. The processing of linking control arrows.
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph);
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node);
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);

@@ -220,6 +229,8 @@ class GraphScheduler {
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
void DumpSuperKernelActor(const SuperKernelActor *actor, std::ofstream &ofs) const;
void DumpNoInputKernelActor(const AbstractActor *actor, std::ofstream &ofs) const;
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
void DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const;


Loading…
Cancel
Save