Browse Source

Unified runtime: code optimization of Run and SendOutput

tags/v1.6.0
limingqi107 4 years ago
parent
commit
ef29b33fa4
21 changed files with 249 additions and 398 deletions
  1. +49
    -7
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc
  2. +20
    -5
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.h
  3. +3
    -4
      mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h
  4. +0
    -2
      mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h
  5. +0
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h
  6. +0
    -1
      mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h
  7. +7
    -35
      mindspore/ccsrc/runtime/framework/actor/copy_actor.cc
  8. +5
    -11
      mindspore/ccsrc/runtime/framework/actor/copy_actor.h
  9. +1
    -17
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc
  10. +1
    -11
      mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h
  11. +12
    -30
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc
  12. +8
    -12
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.h
  13. +17
    -72
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
  14. +6
    -11
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.h
  15. +4
    -8
      mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc
  16. +4
    -4
      mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
  17. +3
    -0
      mindspore/ccsrc/runtime/framework/actor/output_actor.h
  18. +13
    -53
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc
  19. +2
    -8
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h
  20. +93
    -104
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  21. +1
    -2
      mindspore/ccsrc/runtime/framework/graph_scheduler.h

+ 49
- 7
mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc View File

@@ -20,6 +20,31 @@

namespace mindspore {
namespace runtime {
void AbstractActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);

auto is_run = CheckRunningCondition(context);
MS_LOG(DEBUG) << "Actor(" << GetAID().Name() << ") receive the input op data and check running condition:" << is_run;
if (is_run) {
Run(context);
}
}

void AbstractActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);

auto is_run = CheckRunningCondition(context);
MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
<< ") receive the input op control and check running condition:" << is_run;
if (is_run) {
Run(context);
}
}

bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
@@ -67,30 +92,47 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {
}
}

void AbstractActor::SendOutputResult(OpContext<DeviceTensor> *const context) const {
void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
if (output_result_arrows_.size() != output_nodes_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output result arrows is not equal to the output nodes.");
}

size_t output_node_index = 0;
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, output_nodes_[output_node_index],
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, output_nodes_[output_node_index++],
result_arrow->from_output_index_, result_arrow->to_input_index_, context);
++output_node_index;
}
}

void AbstractActor::SendOutputControl(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
// 2.Send output data.
if (output_data_arrows_.size() != output_data_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output data arrows is not equal to the output data.");
}
size_t output_data_arrow_index = 0;
for (auto &output_data : output_data_) {
MS_EXCEPTION_IF_NULL(output_data);
UpdateOutputData(output_data.get(), output_data_arrows_[output_data_arrow_index++].get(), context);
Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context);
}

// 3.Send output control.
if (output_control_arrows_.size() > 0) {
auto from_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, from_aid, context);
}
}

// 4.Send recorder info.
SendRecorderInfo(context);

// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(output_result_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}
} // namespace runtime
} // namespace mindspore

+ 20
- 5
mindspore/ccsrc/runtime/framework/actor/abstract_actor.h View File

@@ -45,6 +45,11 @@ class AbstractActor : public OpActor<DeviceTensor> {

bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }

// The actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
// The actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;

// Get the position of node in the actor.
virtual size_t FetchNodePosition(const AnfNodePtr &node) const { return 0; }

@@ -53,12 +58,19 @@ class AbstractActor : public OpActor<DeviceTensor> {

// Check whether satisfy the actor running condition.
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const;
// The actor run really when satisfy the actor running condition.
virtual void Run(OpContext<DeviceTensor> *const context) {}

// Erase input data and input controls when finish actor running.
void EraseInput(const OpContext<DeviceTensor> *const context);
// Send the output result by output_result_arrows_.
void SendOutputResult(OpContext<DeviceTensor> *const context) const;
// Send the output control by output_control_arrows_.
void SendOutputControl(OpContext<DeviceTensor> *const context) const;
void EraseInput(const OpContext<DeviceTensor> *context);

// Update the output data before send output data.
virtual void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow,
OpContext<DeviceTensor> *const context) {}
// Send output to downstream actors to trigger running.
virtual void SendOutput(OpContext<DeviceTensor> *const context);
// Send recorder info to recorder actor.
virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) const {}

KernelTransformType type_;

@@ -68,6 +80,9 @@ class AbstractActor : public OpActor<DeviceTensor> {
// The id of recorder actor. Send message to it for recording info.
const AID *recorder_aid_;

// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;

// The output nodes and output result arrows of graph output.
std::vector<AnfNodePtr> output_nodes_;
std::vector<DataArrowPtr> output_result_arrows_;


+ 3
- 4
mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h View File

@@ -40,17 +40,16 @@ class EntranceActor : public AbstractActor {

void Init() override;

// The entrance actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The entrance actor run when receive the real parameter nodes and branch id.
void CollectRealParametersAndBranchId(const std::vector<KernelWithIndex> &real_parameters, int branch_id,
OpContext<DeviceTensor> *const context);

protected:
void Run(OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;

void SendOutput(OpContext<DeviceTensor> *const context) const;

// Formal parameters of actor, which is the front node.
std::vector<KernelWithIndex> formal_parameters_;



+ 0
- 2
mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.h View File

@@ -44,8 +44,6 @@ class ExitActor : public AbstractActor {
private:
friend class GraphScheduler;

void SendOutput(OpContext<DeviceTensor> *const context) const;

// Formal parameters of actor, which is the front node.
std::vector<KernelWithIndex> formal_parameters_;



+ 0
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/gather_actor.h View File

@@ -47,7 +47,6 @@ class GatherActor : public AbstractActor {

private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;

// Formal parameters of actor, which is the front node.
std::vector<KernelWithIndex> formal_parameters_;


+ 0
- 1
mindspore/ccsrc/runtime/framework/actor/control_flow/stack_actor.h View File

@@ -44,7 +44,6 @@ class StackActor : public MemoryAwareActor {

private:
friend class GraphScheduler;
void SendOutput(OpContext<DeviceTensor> *const context) const;

// Formal parameters record the input front-end node, these nodes may be parameter, kernel, call node.
std::vector<KernelWithIndex> formal_parameters_;


+ 7
- 35
mindspore/ccsrc/runtime/framework/actor/copy_actor.cc View File

@@ -45,26 +45,10 @@ void CopyActor::Init() {
}
}

void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
void CopyActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckRunningCondition(context)) {
FetchDeviceTensor(context);
SendMemoryAllocReq(context);
}
}

void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckRunningCondition(context)) {
FetchDeviceTensor(context);
SendMemoryAllocReq(context);
}
FetchDeviceTensor(context);
SendMemoryAllocReq(context);
}

void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
@@ -146,22 +130,10 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
}
}

void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}

// Send output data.
for (auto &output_data : output_data_) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = output_device_tensor_[0];
Async(output_data->op_id_, &OpActor::RunOpData, output_data.get(), context);
}

// Send output control.
SendOutputControl(context);
void CopyActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *,
OpContext<DeviceTensor> *const) {
MS_EXCEPTION_IF_NULL(output_data);
output_data->data_ = output_device_tensor_[0];
}
} // namespace runtime
} // namespace mindspore

+ 5
- 11
mindspore/ccsrc/runtime/framework/actor/copy_actor.h View File

@@ -42,34 +42,28 @@ class CopyActor : public MemoryAwareActor {

void Init() override;

// The copy actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
// The copy actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;

// The memory related operation interface.
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
// The copy processing after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

protected:
void Run(OpContext<DeviceTensor> *const context) override;
void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow,
OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;

// Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);

// Send output data and output controls when finish copy.
void SendOutput(OpContext<DeviceTensor> *const context) const;

// The input device tensor is saved from the input data or fetched by device_tensor_store_keys_.
std::vector<DeviceTensor *> input_device_tensor_;
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.
std::vector<DeviceTensor *> output_device_tensor_;

// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;

// The output is created in the copy actor build, so can't be the raw pointer.
DeviceTensorPtr output_;
};


+ 1
- 17
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc View File

@@ -124,7 +124,7 @@ void DataPrepareActor::Init() {
void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_LOG(INFO) << "Data prepare actor(" << GetAID().Name() << ") prepares data.";
MS_LOG(DEBUG) << "Data prepare actor(" << GetAID().Name() << ") prepares data.";

// Convert actor running data from input tensors.
if (input_tensors.size() > 0) {
@@ -175,22 +175,6 @@ void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const contex
SendOutput(context);
}

void DataPrepareActor::SendOutput(OpContext<DeviceTensor> *const context) {
for (auto &data_source_aid : data_source_aids_) {
Async(data_source_aid, &DataSourceActor::FetchData, context);
}

auto source_aid = const_cast<AID *>(&GetAID());
for (auto &kernel_aid : no_input_kernel_aids_) {
Async(kernel_aid, &OpActor::RunOpControl, source_aid, context);
}

// Trigger loop count actor running when there are no data source actor and kernel actor.
if ((data_source_aids_.size() + no_input_kernel_aids_.size() == 0) && (loop_count_aid_ != nullptr)) {
Async(*loop_count_aid_, &LoopCountActor::RunOpControl, source_aid, context);
}
}

void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context) {
for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {


+ 1
- 11
mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.h View File

@@ -45,8 +45,7 @@ class DataPrepareActor : public DebugAwareActor {
graph_compiler_info_(graph_compiler_info),
strategy_(GraphExecutionStrategy::kPipeline),
host_data_source_actor_(host_data_source_actor),
host_tensor_queue_(host_tensor_queue),
loop_count_aid_(nullptr) {}
host_tensor_queue_(host_tensor_queue) {}
~DataPrepareActor() override = default;

void Init() override;
@@ -65,9 +64,6 @@ class DataPrepareActor : public DebugAwareActor {
private:
friend class GraphScheduler;

// Send output controls when finish data prepare.
void SendOutput(OpContext<DeviceTensor> *const context);

void PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
OpContext<DeviceTensor> *const context);
void PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
@@ -103,12 +99,6 @@ class DataPrepareActor : public DebugAwareActor {
HostQueueDSActorPtr host_data_source_actor_;
HostTensorQueuePtr host_tensor_queue_;

// The output controls contain the data source actors and the no input kernel actors.
std::vector<AID> data_source_aids_;
std::vector<AID> no_input_kernel_aids_;
// If has no data source actor and kernel actor, then need send to loop count actor.
const AID *loop_count_aid_;

// The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
// expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
// continuous memory.


+ 12
- 30
mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc View File

@@ -58,43 +58,21 @@ void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) {
SendMemoryAllocReq(context);
}

void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
void DataSourceActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(output_data);
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(output_result_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}

if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}

// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
SendOutputResult(context);

// 2.Send output data.
const auto &output_device_tensors = buffers_.front();
for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
auto &data_arrow = output_data_arrows_[i];
auto &output_data = output_data_[i];
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(output_data);
if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range.");
}
output_data->data_ = output_device_tensors[data_arrow->from_output_index_];
Async(data_arrow->to_op_id_, &OpActor::RunOpData, output_data.get(), context);
}

// 3.Send output control.
SendOutputControl(context);

// 4.Send recorder info.
if (recorder_aid_ != nullptr) {
SendRecorderInfo(context);
if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range.");
}
output_data->data_ = output_device_tensors[data_arrow->from_output_index_];
}

void DeviceQueueDataSourceActor::Init() {
@@ -180,6 +158,8 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
return;
}

EraseInput(context);

// Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
// the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of
// the current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully,
@@ -197,7 +177,7 @@ void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const co
SendOutput(context);
}

void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
if (recorder_aid_ != nullptr) {
MS_EXCEPTION_IF_NULL(data_kernel_);
Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
@@ -279,6 +259,8 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cons
}
host_queue_->Pop();

EraseInput(context);

// Note that SendMemoryFreeReq must be in front of SendOutput, because SendOutput will trigger SendMemoryAllocReq of
// the next actor and the actor is asynchronous execution. So it is necessary to ensure that SendMemoryFreeReq of
// the current actor is in front of SendMemoryAllocReq of the next actor. One is to reuse the memory more fully,


+ 8
- 12
mindspore/ccsrc/runtime/framework/actor/data_source_actor.h View File

@@ -48,27 +48,23 @@ class DataSourceActor : public DebugAwareActor {

void Init() override;

// The process entry of data processing.
void FetchData(OpContext<DeviceTensor> *const context);

protected:
friend class GraphScheduler;

void Run(OpContext<DeviceTensor> *const context) override { FetchData(context); }

// The process entry of data processing.
void FetchData(OpContext<DeviceTensor> *const context);

// Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching.
virtual void FillDataBuffer() = 0;

// Send recorder info to recorder actor, only the device queue data source actor need.
virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {}

// Send output to downstream actors to trigger computing after fetching data finished.
void SendOutput(OpContext<DeviceTensor> *const context);
void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrow *data_arrow,
OpContext<DeviceTensor> *const context) override;

// The buffers store the device tensors.
std::queue<std::vector<DeviceTensor *>> buffers_;
size_t buffer_capacity_;

// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;
};

// The class represents that the data source is device queue.
@@ -95,7 +91,7 @@ class DeviceQueueDataSourceActor : public DataSourceActor {

protected:
void FillDataBuffer() override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;

private:
friend class GraphScheduler;


+ 17
- 72
mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc View File

@@ -74,59 +74,26 @@ void KernelActor::Init() {
auto device_address = output_device_tensors_[data_arrow->from_output_index_];
auto data =
std::make_unique<OpData<DeviceTensor>>(data_arrow->to_op_id_, device_address, data_arrow->to_input_index_);
(void)output_data_.emplace_back(data.get());
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(std::move(data));
(void)output_data_by_output_index_[data_arrow->from_output_index_].emplace_back(data.get());
(void)output_data_.emplace_back(std::move(data));
}
}

void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
void KernelActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);

auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);
if (input_data->data_ == nullptr) {
std::string error_info =
"Input data of actor:" + GetAID().Name() + " num:" + std::to_string(input_data->index_) + " is empty";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckRunningCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_contexts_[0]->UpdateDynamicShape(kernel_);
}

FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
} else {
OnMemoryAllocFinish(context);
}
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_contexts_[0]->UpdateDynamicShape(kernel_);
}
}

void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);

auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckRunningCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_contexts_[0]->UpdateDynamicShape(kernel_);
}

FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
} else {
OnMemoryAllocFinish(context);
}
FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
} else {
OnMemoryAllocFinish(context);
}
}

@@ -410,40 +377,18 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
if (memory_free_list_.size() > 0) {
SendMemoryFreeReq(context);
}
SendOutput(context);
}

void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(kernel_);
if (strategy_ == GraphExecutionStrategy::kStep) {
return;
}

// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
SendOutputResult(context);

// 2.Send output data.
for (auto &output_data : output_data_) {
MS_EXCEPTION_IF_NULL(output_data);
Async(output_data->op_id_, &OpActor::RunOpData, output_data, context);
if (strategy_ == GraphExecutionStrategy::kPipeline) {
SendOutput(context);
}
}

// 3.Send output control.
SendOutputControl(context);

// 4.Send recorder info.
void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
if (recorder_aid_ != nullptr) {
MS_EXCEPTION_IF_NULL(kernel_);
Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_,
device_contexts_[0], context);
}

// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(output_result_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}
} // namespace runtime
} // namespace mindspore

+ 6
- 11
mindspore/ccsrc/runtime/framework/actor/kernel_actor.h View File

@@ -58,10 +58,6 @@ class KernelActor : public DebugAwareActor {

void Init() override;

// The kernel actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
// The kernel actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The kernel actor run when receive the input control and input tensors, used in step mode.
void RunOpControlWithInputTensor(AID *const input_control, OpContext<DeviceTensor> *const context,
const std::vector<TensorPtr> *input_tensors);
@@ -77,6 +73,10 @@ class KernelActor : public DebugAwareActor {
// The callback after debug finished.
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;

protected:
void Run(OpContext<DeviceTensor> *const context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;

private:
friend class GraphScheduler;

@@ -92,9 +92,6 @@ class KernelActor : public DebugAwareActor {
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
void PostLaunchKernel(OpContext<DeviceTensor> *const context);

// Send output data and output controls when finish kernel launch.
void SendOutput(OpContext<DeviceTensor> *const context) const;

// The info of kernel.
CNodePtr kernel_;
KernelInfo *kernel_info_;
@@ -127,10 +124,8 @@ class KernelActor : public DebugAwareActor {
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;

// Cache unique output data by output index to modify the output data effectively.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpData<DeviceTensor> *> output_data_;
// Cache output data by output index to modify the output data effectively.
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
};

using KernelActorPtr = std::shared_ptr<KernelActor>;


+ 4
- 8
mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc View File

@@ -25,15 +25,11 @@

namespace mindspore {
namespace runtime {
void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
void LoopCountActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckRunningCondition(context)) {
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
// LoopCountActor exits, because other processors which are not in actor also will process device tensor.
Async(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID());
}
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
// LoopCountActor exits, because other processors which are not in actor also will process device tensor.
Async(memory_manager_aid_, &MemoryManagerActor::Wait, context, GetAID());
}

void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {


+ 4
- 4
mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h View File

@@ -43,9 +43,6 @@ class LoopCountActor : public DebugAwareActor {

~LoopCountActor() override = default;

// The loop count actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;

// The callback waits for the memory manager actor to finish all the message processing.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

@@ -54,11 +51,14 @@ class LoopCountActor : public DebugAwareActor {
// The callback after debug finished.
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;

protected:
void Run(OpContext<DeviceTensor> *const context) override;
void SendOutput(OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;

void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
void SendOutput(OpContext<DeviceTensor> *const context);

// The loop count is constant, the current count is increased after each step running finished.
size_t loop_count_;


+ 3
- 0
mindspore/ccsrc/runtime/framework/actor/output_actor.h View File

@@ -76,6 +76,9 @@ class OutputActor : public AbstractActor {
size_t loop_count_;
size_t current_count_;

// The dependent input result arrow actors.
std::vector<AID> input_result_arrow_aids_;

// The outputs.
std::vector<TensorPtr> outputs_;
std::vector<KernelWithIndex> output_nodes_;


+ 13
- 53
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc View File

@@ -32,68 +32,28 @@ void SuperKernelActor::Init() {
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
}

void SuperKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id();

auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);
if (CheckRunningCondition(context)) {
MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id();
try {
auto ret = device_contexts_[0]->LaunchGraph(graph_);
if (!ret) {
std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
} catch (const std::exception &e) {
MsException::Instance().SetException();
try {
auto ret = device_contexts_[0]->LaunchGraph(graph_);
if (!ret) {
std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// The inputs are invalid and need to be erased when the kernel launch finishes.
EraseInput(context);
SendOutput(context);
} catch (const std::exception &e) {
MsException::Instance().SetException();
std::string error_info = "Launch graph exception, graph id: " + graph_->graph_id();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}

void SuperKernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);

auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckRunningCondition(context)) {
MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name() << ") launches graph: " << graph_->graph_id();
try {
auto ret = device_contexts_[0]->LaunchGraph(graph_);
if (!ret) {
std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
} catch (const std::exception &e) {
MsException::Instance().SetException();
std::string error_info = "Launch graph failed, graph id: " + graph_->graph_id();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}

// The inputs are invalid and need to be erased when the kernel launch finishes.
EraseInput(context);
SendOutput(context);
}
// The inputs are invalid and need to be erased when the kernel launch finishes.
EraseInput(context);
SendOutput(context);
}

void SuperKernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
SendOutputResult(context);
SendOutputControl(context);

// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(output_result_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}
} // namespace runtime
} // namespace mindspore

+ 2
- 8
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.h View File

@@ -42,18 +42,12 @@ class SuperKernelActor : public DebugAwareActor {

void Init() override;

// The super kernel actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;

// The super kernel actor run when receive the input control.
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
protected:
void Run(OpContext<DeviceTensor> *const context) override;

private:
friend class GraphScheduler;

// Send the output data and output controls when the kernel launch finishes.
void SendOutput(OpContext<DeviceTensor> *const context) const;

KernelGraphPtr graph_;
};



+ 93
- 104
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -63,42 +63,42 @@ inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
}

// Convert the actors vector by the actor set.
std::vector<ActorReference> CollectActors(const ActorSet *actor_set) {
std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<ActorReference> actors;
std::vector<AbstractActorPtr> actors;

if (actor_set->data_prepare_actor_ != nullptr) {
(void)actors.emplace_back(static_cast<ActorReference>(actor_set->data_prepare_actor_));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
}
for (auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
(void)actors.emplace_back(static_cast<ActorReference>(data_source_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
}
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
(void)actors.emplace_back(static_cast<ActorReference>(kernel_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
}
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
MS_EXCEPTION_IF_NULL(super_kernel_actor);
(void)actors.emplace_back(static_cast<ActorReference>(super_kernel_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
}
for (auto &switch_actor : actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
(void)actors.emplace_back(static_cast<ActorReference>(switch_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
}
for (auto &gather_actor : actor_set->gather_actors_) {
MS_EXCEPTION_IF_NULL(gather_actor);
(void)actors.emplace_back(static_cast<ActorReference>(gather_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
}
for (auto &copy_actor : actor_set->copy_actors_) {
MS_EXCEPTION_IF_NULL(copy_actor);
(void)actors.emplace_back(static_cast<ActorReference>(copy_actor));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
}
if (actor_set->loop_count_actor_ != nullptr) {
(void)actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
}
if (actor_set->output_actor_ != nullptr) {
(void)actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
(void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
}

return actors;
@@ -294,8 +294,8 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
(void)actors_.emplace(actor_set->name_, actor_set);

DumpActor(actor_set.get(), graph_compiler_info);
if (!CheckActorValid(actor_set.get(), graph_compiler_info.strategy_)) {
MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
CheckActorValid(actor_set.get());
}
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";

@@ -1072,6 +1072,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
// Link.
(void)from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
copy_actor->input_datas_num_++;
(void)copy_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());

// Set the member of the copy actor.
auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
@@ -1093,6 +1094,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
auto op_arrow_from_copy = std::make_shared<DataArrow>(0, to_actor->GetAID(), to_input_index);
(void)copy_actor->output_data_arrows_.emplace_back(op_arrow_from_copy);
to_actor->input_datas_num_++;
(void)to_actor->input_data_arrow_aids_.emplace_back(copy_actor->GetAID());
UpdateRefCount(copy_actor->output_.get());
}

@@ -1171,6 +1173,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
<< ", to actor: " << to_actor->GetAID().Name();
(void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}
}

@@ -1190,6 +1193,7 @@ void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, cons
<< ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
(void)from_actor->output_control_arrows_.emplace_back(to_aid);
to_actor->input_controls_num_++;
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}
}

@@ -1216,16 +1220,19 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
if (input_actor != nullptr) {
(void)input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
from_send_actor->input_controls_num_++;
(void)from_send_actor->input_control_arrow_aids_.emplace_back(input_actor->GetAID());
}
}

// from_send_actor --> from_recv_actor
(void)from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
from_recv_actor->input_controls_num_++;
(void)from_recv_actor->input_control_arrow_aids_.emplace_back(from_send_actor->GetAID());

// from_recv_actor --> to_allreduce_actor
(void)from_recv_actor->output_control_arrows_.emplace_back(to_allreduce_actor->GetAID());
to_allreduce_actor->input_controls_num_++;
(void)to_allreduce_actor->input_control_arrow_aids_.emplace_back(from_recv_actor->GetAID());
}

for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
@@ -1246,10 +1253,12 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
// from_allreduce_actor --> to_send_actor
(void)from_allreduce_actor->output_control_arrows_.emplace_back(to_send_actor->GetAID());
to_send_actor->input_controls_num_++;
(void)to_send_actor->input_control_arrow_aids_.emplace_back(from_allreduce_actor->GetAID());

// to_send_actor --> to_recv_actor
(void)to_send_actor->output_control_arrows_.emplace_back(to_recv_actor->GetAID());
to_recv_actor->input_controls_num_++;
(void)to_recv_actor->input_control_arrow_aids_.emplace_back(to_send_actor->GetAID());

// to_recv_actor --> outputs of from_allreduce_actor
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
@@ -1257,6 +1266,7 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
if (output_actor != nullptr) {
(void)to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
output_actor->input_controls_num_++;
(void)output_actor->input_control_arrow_aids_.emplace_back(to_recv_actor->GetAID());
}
}

@@ -1309,6 +1319,7 @@ void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNode
MS_EXCEPTION_IF_NULL(to_actor);
(void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}

// Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
@@ -1322,6 +1333,7 @@ void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNode
if ((from_actor != nullptr) && (to_actor != nullptr)) {
(void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}
}
}
@@ -1335,20 +1347,25 @@ void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_
// Data prepare actor --> data source actor.
for (auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
(void)data_prepare_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
(void)data_prepare_actor->output_control_arrows_.emplace_back(data_source_actor->GetAID());
data_source_actor->input_controls_num_++;
(void)data_source_actor->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID());
}

// Data prepare actor --> no input kernel actor.
for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
(void)data_prepare_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
(void)data_prepare_actor->output_control_arrows_.emplace_back(no_input_kernel_actor->GetAID());
no_input_kernel_actor->input_controls_num_++;
(void)no_input_kernel_actor->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID());
}

// Data prepare actor --> loop count actor.
if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
(actor_set->loop_count_actor_ != nullptr)) {
data_prepare_actor->loop_count_aid_ = &(actor_set->loop_count_actor_->GetAID());
(void)data_prepare_actor->output_control_arrows_.emplace_back(actor_set->loop_count_actor_->GetAID());
actor_set->loop_count_actor_->input_controls_num_++;
(void)actor_set->loop_count_actor_->input_control_arrow_aids_.emplace_back(data_prepare_actor->GetAID());
}
}

@@ -1392,6 +1409,7 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
for (auto &no_output_actor : no_output_actors) {
(void)no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID());
loop_count_actor->input_controls_num_++;
(void)loop_count_actor->input_control_arrow_aids_.emplace_back(no_output_actor->GetAID());
}

// Loop count actor --> data prepare actor.
@@ -1463,6 +1481,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position);
(void)from_actor->output_result_arrows_.emplace_back(op_arrow);
(void)from_actor->output_nodes_.emplace_back(output_with_index.first);
(void)to_actor->input_result_arrow_aids_.emplace_back(from_actor->GetAID());

// Update the real compute node in the host data source actor.
if (kernel_type == KernelTransformType::kHostDataSourceActor) {
@@ -1525,6 +1544,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
// Link from kernel actor to copy actor.
(void)kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
copy_actor->input_controls_num_++;
(void)copy_actor->input_control_arrow_aids_.emplace_back(kernel_actor->GetAID());
}
}
}
@@ -1539,82 +1559,60 @@ void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const
OpActor<DeviceTensor> *to_actor, const size_t to_index,
const size_t branch_index) {}

bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
MS_EXCEPTION_IF_NULL(actor_set);
// Check the data source actors.
for (const auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
data_source_actor->output_control_arrows_.size() ==
0) {
MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
return false;
}
}

if (strategy == GraphExecutionStrategy::kStep) {
return true;
}

// Check the super kernel actors.
for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) {
MS_EXCEPTION_IF_NULL(super_kernel_actor);
if (super_kernel_actor->output_data_arrows_.size() + super_kernel_actor->output_control_arrows_.size() == 0) {
MS_LOG(ERROR) << super_kernel_actor->GetAID().Name() << " has no user.";
return false;
}
}

// Check the kernel actors.
for (const auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
if (kernel_actor->output_data_arrows_.size() + kernel_actor->output_control_arrows_.size() == 0) {
MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user.";
return false;
}

auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
auto input_data_num = kernel_actor->input_datas_num_;
auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size();
if (input_data_num + device_tensor_store_num != input_num) {
MS_LOG(ERROR) << "The input building of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_)
<< " is wrong, input data num: " << input_data_num
<< ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num;
return false;
}
}

// Check the copy actors.
for (const auto &copy_actor : actor_set->copy_actors_) {
MS_EXCEPTION_IF_NULL(copy_actor);
if (copy_actor->output_data_arrows_.size() + copy_actor->output_control_arrows_.size() == 0) {
MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
return false;
}

const size_t kCopyActorInputDataNum = 1;
auto input_data_num = copy_actor->input_datas_num_;
size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
<< " is wrong, input data num: " << input_data_num
<< ", device tensor store num: " << device_tensor_store_num
<< ", total input num: " << kCopyActorInputDataNum;
return false;
auto actors = CollectActors(actor_set);
for (auto &actor : actors) {
MS_EXCEPTION_IF_NULL(actor);
if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
(actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name()
<< " is wrong, expect data num: " << actor->input_datas_num_
<< ", actual data num: " << actor->input_data_arrow_aids_.size()
<< ", expect control num: " << actor->input_controls_num_
<< ", actual control num: " << actor->input_control_arrow_aids_.size();
}

if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kLoopCountActor) &&
(actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0) &&
(actor->output_result_arrows_.size() == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
}
if ((actor->type_ != KernelTransformType::kOutputActor) &&
(actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
(actor->input_controls_num_ == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
}

// Check the input of kernel actors and copy actors.
if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
size_t expect_toal_input_num = 1;
if (actor->type_ == KernelTransformType::kKernelActor) {
auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
MS_EXCEPTION_IF_NULL(kernel_actor);
expect_toal_input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
}
auto input_data_num = actor->input_datas_num_;
auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
if (input_data_num + device_tensor_store_num != expect_toal_input_num) {
MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name()
<< " is wrong, input data num: " << input_data_num
<< ", device tensor store num: " << device_tensor_store_num
<< ", total input num: " << expect_toal_input_num;
}
}
}

// Check the loop count actor.
const auto &loop_count_actor = actor_set->loop_count_actor_;
if ((loop_count_actor != nullptr) &&
(actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) {
if (loop_count_actor->input_controls_num_ == 0) {
MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
return false;
}
// Check the output actor.
auto output_actor = actor_set->output_actor_;
MS_EXCEPTION_IF_NULL(output_actor);
if (output_actor->input_result_arrow_aids_.size() + output_actor->device_tensor_store_keys_.size() !=
output_actor->outputs_num_) {
MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
<< output_actor->outputs_num_
<< ", the input result arrows num: " << output_actor->input_result_arrow_aids_.size()
<< ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
}

return true;
}

void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
@@ -1819,14 +1817,6 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf

void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
<< "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
<< "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
<< "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
<< "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
<< "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";

if (actor->device_contexts_.size() > 0) {
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
for (const auto &device_context : actor->device_contexts_) {
@@ -1903,14 +1893,6 @@ void GraphScheduler::DumpDataPrepareActor(const DataPrepareActor *actor, std::of
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
DumpAbstractActor(actor, ofs);

ofs << "\t\toutput_control_arrows:" << actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() << "\n ";
for (const auto &aid : actor->data_source_aids_) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
for (const auto &aid : actor->no_input_kernel_aids_) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}

ofs << "\t\tcontinuous_memory_nodes:" << actor->continuous_memory_nodes_.size() << "\n ";
for (const auto &iter : actor->continuous_memory_nodes_) {
MS_EXCEPTION_IF_NULL(iter.first.first);
@@ -2023,7 +2005,13 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
<< "\toutputs_num:" << actor->outputs_num_ << "\n";

DumpAbstractActor(actor, ofs);

ofs << "\t\tinput_result_arrows:" << actor->input_result_arrow_aids_.size() << "\n ";
for (const auto &input_result_arrow_aid : actor->input_result_arrow_aids_) {
ofs << "\t\t\tfrom_actor_name:" << input_result_arrow_aid.Name() << "\n";
}
}

void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
@@ -2043,7 +2031,8 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
ofs << "\tgraph_id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink() << "\n";
ofs << "\tgraph_id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink()
<< "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n";

for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);


+ 1
- 2
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -207,8 +207,7 @@ class GraphScheduler {
const size_t branch_index = SIZE_MAX);

// Check whether the actor set is valid.
bool CheckActorValid(const ActorSet *actor_set,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) const;
void CheckActorValid(const ActorSet *actor_set) const;

// Persist device tensors of graph's some nodes(such as weights and value nodes).
void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);


Loading…
Cancel
Save