Browse Source

unified runtime delete result arrow

tags/v1.6.0
limingqi107 4 years ago
parent
commit
1ba78fcbd6
10 changed files with 171 additions and 179 deletions
  1. +10
    -5
      mindspore/ccsrc/runtime/device/device_address.h
  2. +5
    -17
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc
  3. +0
    -6
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.h
  4. +1
    -26
      mindspore/ccsrc/runtime/framework/actor/actor_dump.cc
  5. +5
    -2
      mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc
  6. +0
    -3
      mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h
  7. +116
    -89
      mindspore/ccsrc/runtime/framework/actor/output_actor.cc
  8. +5
    -7
      mindspore/ccsrc/runtime/framework/actor/output_actor.h
  9. +22
    -17
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  10. +7
    -7
      mindspore/ccsrc/vm/backend.cc

+ 10
- 5
mindspore/ccsrc/runtime/device/device_address.h View File

@@ -90,17 +90,26 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual ~DeviceAddress() { ptr_ = nullptr; }

const void *GetPtr() const { return ptr_; }
void set_ptr(void *ptr) { ptr_ = ptr; }
size_t GetSize() const { return size_; }
void SetSize(size_t size) { size_ = size; }

std::string format() const { return format_; }
TypeId type_id() const { return type_id_; }
bool from_mem_pool() const { return from_mem_pool_; }
void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }
void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; }
virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
void *GetMutablePtr() const override { return ptr_; }

virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
KernelWithIndex GetNodeIndex() const {
return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
: KernelWithIndex{node_index_.first.lock(), node_index_.second};
}

virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
TypeId host_type, bool trans_flag) const {
return true;
@@ -115,11 +124,7 @@ class DeviceAddress : public mindspore::DeviceSync {
protected:
const void *ptr() const { return ptr_; }
size_t size() const { return size_; }
void set_ptr(void *ptr) { ptr_ = ptr; }
KernelWithIndex GetNodeIndex() const {
return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
: KernelWithIndex{node_index_.first.lock(), node_index_.second};
}

mutable void *ptr_{nullptr};
size_t size_{0};
string format_{"DefaultFormat"};


+ 5
- 17
mindspore/ccsrc/runtime/framework/actor/abstract_actor.cc View File

@@ -94,19 +94,8 @@ void AbstractActor::EraseInput(const OpContext<DeviceTensor> *context) {

void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Must be the execution order: send result --> send data --> send control, avoid the illegal timing problem.
// 1.Send graph output result.
if (output_result_arrows_.size() != output_result_nodes_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output result arrows is not equal to the output nodes.");
}
size_t output_node_index = 0;
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, output_result_nodes_[output_node_index++],
result_arrow->from_output_index_, result_arrow->to_input_index_, context);
}

// 2.Send output data.
// Must be the execution order: send data --> send control, avoid the illegal timing problem.
// 1.Send output data.
if ((output_data_arrows_.size() != output_data_.size()) ||
(output_data_arrows_.size() != output_data_nodes_.size())) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The size of output data arrows is not equal to the output data.");
@@ -120,7 +109,7 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
++output_data_arrow_index;
}

// 3.Send output control.
// 2.Send output control.
if (output_control_arrows_.size() > 0) {
auto from_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
@@ -128,12 +117,11 @@ void AbstractActor::SendOutput(OpContext<DeviceTensor> *const context) {
}
}

// 4.Send recorder info.
// 3.Send recorder info.
SendRecorderInfo(context);

// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
(output_result_arrows_.size() == 0)) {
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}


+ 0
- 6
mindspore/ccsrc/runtime/framework/actor/abstract_actor.h View File

@@ -57,8 +57,6 @@ class AbstractActor : public OpActor<DeviceTensor> {
KernelTransformType type() const { return type_; }
const std::vector<const DeviceContext *> &device_contexts() const { return device_contexts_; }
const std::vector<AnfNodePtr> &output_data_nodes() const { return output_data_nodes_; }
const std::vector<AnfNodePtr> &output_result_nodes() const { return output_result_nodes_; }
const std::vector<DataArrowPtr> &output_result_arrows() const { return output_result_arrows_; }
const std::vector<std::pair<size_t, AnfNodePtr>> &device_tensor_store_keys() const {
return device_tensor_store_keys_;
}
@@ -96,10 +94,6 @@ class AbstractActor : public OpActor<DeviceTensor> {
std::vector<AnfNodePtr> output_data_nodes_;
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;

// The output result nodes and output result arrows of graph output.
std::vector<AnfNodePtr> output_result_nodes_;
std::vector<DataArrowPtr> output_result_arrows_;

// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.
// Index is the input position, AnfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;


+ 1
- 26
mindspore/ccsrc/runtime/framework/actor/actor_dump.cc View File

@@ -86,23 +86,6 @@ void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
}

if (actor->output_result_arrows().size() != actor->output_result_nodes().size()) {
MS_LOG(EXCEPTION) << "The size of output result arrows is not equal to the output nodes.";
}
if (actor->output_result_arrows().size() > 0) {
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows().size() << "\n ";
for (size_t i = 0; i < actor->output_result_arrows().size(); ++i) {
auto result_arrow = actor->output_result_arrows()[i];
auto output_node = actor->output_result_nodes()[i];
MS_EXCEPTION_IF_NULL(result_arrow);
MS_EXCEPTION_IF_NULL(output_node);
ofs << "\t\t\tfrom_output_node:" << GetSplitName(output_node->fullname_with_scope())
<< "\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
}
}

void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) {
@@ -233,10 +216,7 @@ void DumpLoopCountActor(const LoopCountActorPtr &actor, std::ofstream &ofs) {
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count() << "\n";
DumpAbstractActor(actor.get(), ofs);

const size_t kOutputControlArrowsNum = 2;
ofs << "\t\toutput_control_arrows:" << kOutputControlArrowsNum << "\n ";
ofs << "\t\t\tto_actor_name:" << actor->output_aid().Name() << "\n";
ofs << "\t\t\tto_actor_name:" << actor->data_prepare_aid().Name() << "\n";
ofs << "\t\t\tto_data_prepare_actor:" << actor->data_prepare_aid().Name() << "\n";
}

void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) {
@@ -250,11 +230,6 @@ void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) {

DumpAbstractActor(actor.get(), ofs);

ofs << "\t\tinput_result_arrows:" << actor->input_result_arrow_aids().size() << "\n ";
for (const auto &input_result_arrow_aid : actor->input_result_arrow_aids()) {
ofs << "\t\t\tfrom_actor_name:" << input_result_arrow_aid.Name() << "\n";
}

ofs << "\t\toutput_address_persisted_nodes:" << actor->output_address_persisted_nodes().size() << "\n ";
for (const auto &output_address_persisted_node : actor->output_address_persisted_nodes()) {
MS_EXCEPTION_IF_NULL(output_address_persisted_node);


+ 5
- 2
mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc View File

@@ -64,8 +64,11 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
}

// Send loop count to output actor.
Async(output_aid_, &OutputActor::CollectLoopCount, current_count_, context);
// Send output control.
auto from_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_control_arrows_) {
Async(output_control, &OpActor::RunOpControl, from_aid, context);
}

// The LoopCountActor exits.
if (current_count_ == loop_count_) {


+ 0
- 3
mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h View File

@@ -52,7 +52,6 @@ class LoopCountActor : public DebugAwareActor {
// Get the member.
size_t loop_count() const { return loop_count_; }
const AID &data_prepare_aid() const { return data_prepare_aid_; }
const AID &output_aid() const { return output_aid_; }

protected:
void Run(OpContext<DeviceTensor> *const context) override;
@@ -69,9 +68,7 @@ class LoopCountActor : public DebugAwareActor {
// The total running count represents the toal step running count.
size_t total_running_count_;

// The output controls contain the data prepare actor and output actor.
AID data_prepare_aid_;
AID output_aid_;
};

using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;


+ 116
- 89
mindspore/ccsrc/runtime/framework/actor/output_actor.cc View File

@@ -25,61 +25,19 @@ void OutputActor::Init() {
if (device_contexts_.size() != output_nodes_.size()) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}
// Check outputs number.
if (output_nodes_.size() != outputs_.size()) {
MS_LOG(EXCEPTION) << "The outputs number is wrong.";
}

// Set the number of actor running dependent messages.
running_dependent_msg_num_ = SizeToInt(outputs_num_ - device_tensor_store_keys_.size());
}

TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create output tensor, output node: " << output_node->fullname_with_scope()
<< ", output index: " << output_index << ", output position: " << output_position;

// Create host tensor, the output tensor should use the infer type, it will be handed correctly by tensor data sync
// when infer type is not equal to device type.
auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
std::vector<int64_t> temp_shape;
auto shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
auto new_device_tensor = device_tensor;
// If the output node whose output address can't be changed, then create the new device tensor and copy the data.
if (output_address_persisted_nodes_.count(output_node) > 0) {
const auto &device_context = device_contexts_[output_index];
MS_EXCEPTION_IF_NULL(device_context);
new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), device_tensor->format(),
device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
new_device_tensor->ResetRefCount();
if (!device_context->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
MS_LOG(ERROR) << "Device(id:" << device_context->device_context_key().device_id_
<< ") memory isn't enough and alloc failed, kernel name: " << output_node->fullname_with_scope()
<< ", alloc size: " << new_device_tensor->GetSize() << "B.";
return nullptr;
}

if (!new_device_tensor->SyncDeviceToDevice(trans::GetRuntimePaddingShape(output_node, output_index),
device_tensor->GetSize(), device_tensor->type_id(),
device_tensor->GetPtr(), device_tensor->format())) {
MS_LOG(ERROR) << "Sync device to device failed, device type: " << new_device_tensor->DeviceType();
return nullptr;
}
}

// Put device tensor into host tensor.
tensor->set_device_address(new_device_tensor);
return tensor;
}

void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context) {
void OutputActor::RunOpControl(AID *const, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);

current_count_ = loop_count;
++current_count_;
if (loop_count_ == current_count_) {
if (current_outputs_num_ + device_tensor_store_keys_.size() != outputs_num_) {
std::string error_info = "The outputs num is wrong, the total outputs num: " + std::to_string(outputs_num_) +
@@ -106,56 +64,22 @@ void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *c
}
}

void OutputActor::UpdateOutputDeviceAddress() {
// In the running end, when the device tensor of graph output node is set into host tensor, the graph output node
// need be set new device tensor, to avoid that the device tensor context of host tensor be rewritten in the next
// step or next loop. But the graph output nodes corresponding to device tensor store need to be skipped, because
// they are fixed addresses and persistent.
for (size_t i = 0; i < output_nodes_.size(); ++i) {
auto &output_node = output_nodes_[i].first;
auto output_index = output_nodes_[i].second;
// The output node whose output address can't be changed needs to be skipped.
if (output_address_persisted_nodes_.count(output_node) > 0) {
continue;
}

if ((output_node != nullptr) && (!IsPersistentDeviceTensor(output_node))) {
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
// The outputs may have the same output node, so need skip when the node has been set to new device tensor.
if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
continue;
}
const auto &device_context = device_contexts_[i];
MS_EXCEPTION_IF_NULL(device_context);
auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
new_device_tensor->ResetRefCount();
// Support skip nop node.
const auto &real_output_node = AnfAlgo::VisitKernelWithReturnType(output_node, output_index);
AnfAlgo::SetOutputAddr(new_device_tensor, real_output_node.second, real_output_node.first.get());
}
}

output_nodes_.clear();
output_nodes_.resize(outputs_num_);
}

void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(output_node);
void OutputActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(input_data->data_);
MS_EXCEPTION_IF_NULL(context);
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
if (loop_count_ - current_count_ != 1) {
return;
}

auto output_position = IntToSize(input_data->index_);
if (output_position >= outputs_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The input index is of range.");
}

auto tensor = CreateOutputTensor(output_node, output_index, output_position);
auto node_with_index = input_data->data_->GetNodeIndex();
auto tensor = CreateOutputTensor(node_with_index.first, node_with_index.second, output_position);
if (tensor == nullptr) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Create output tensor failed.");
}
@@ -164,7 +88,110 @@ void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_ind
current_outputs_num_++;

// Save the output nodes to clear the device tensor in the running end.
output_nodes_[output_position] = KernelWithIndex(output_node, output_index);
output_nodes_[output_position] = node_with_index;
}

TensorPtr OutputActor::CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index, size_t output_position) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create output tensor, output node: " << output_node->fullname_with_scope()
<< ", output index: " << output_index << ", output position: " << output_position;

// Create host tensor, the output tensor should use the infer type, it will be handed correctly by tensor data sync
// when infer type is not equal to device type.
auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
std::vector<int64_t> temp_shape;
auto shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if (IsPersistentDeviceTensor(output_node)) {
tensor->set_device_address(device_tensor);
return tensor;
}

if (output_position >= device_contexts_.size()) {
MS_LOG(ERROR) << "The output position is of range: " << output_position;
return nullptr;
}
auto device_context = device_contexts_[output_position];
MS_EXCEPTION_IF_NULL(device_context);
if (device_context->GetDeviceAddressType() != device_tensor->DeviceType()) {
MS_LOG(ERROR) << "The device type is wrong.";
return nullptr;
}

// Create the device address and put it into host tensor.
if (output_node_to_tensor_device_address_.count({output_node, output_index}) > 0) {
tensor->set_device_address(output_node_to_tensor_device_address_[{output_node, output_index}]);
} else {
auto tensor_device_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id());
MS_EXCEPTION_IF_NULL(tensor_device_address);
tensor->set_device_address(tensor_device_address);
output_node_to_tensor_device_address_[{output_node, output_index}] = tensor_device_address;
}
return tensor;
}

void OutputActor::UpdateOutputDeviceAddress() {
// In the running end, when the device ptr of graph output node is set into host tensor, the graph output node
// need be set new device ptr, to avoid that the device ptr context of host tensor be rewritten in the next
// step or next loop. But the graph output nodes corresponding to device tensor store need to be skipped, because
// they are fixed addresses and persistent.
for (size_t i = 0; i < output_nodes_.size(); ++i) {
auto &output_node = output_nodes_[i].first;
auto output_index = output_nodes_[i].second;
auto &tensor = outputs_[i];
if ((output_node == nullptr) || (IsPersistentDeviceTensor(output_node))) {
continue;
}

MS_EXCEPTION_IF_NULL(tensor);
auto tensor_device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
MS_EXCEPTION_IF_NULL(tensor_device_address);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);

// Update tensor device address by device tensor of output node.
tensor_device_address->set_original_ref_count(device_tensor->original_ref_count());
tensor_device_address->ResetRefCount();
auto node_with_index = device_tensor->GetNodeIndex();
tensor_device_address->SetNodeIndex(node_with_index.first, node_with_index.second);
// The outputs may have the same output node, so need skip when the node has been done.
if (device_tensor->GetPtr() == nullptr) {
continue;
}

// If the output node whose output address ptr can't be changed, then alloc the new device memory and copy the data.
if (output_address_persisted_nodes_.count(output_node) > 0) {
auto device_context = device_contexts_[i];
MS_EXCEPTION_IF_NULL(device_context);
if (!device_context->AllocateMemory(tensor_device_address.get(), tensor_device_address->GetSize())) {
MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
<< ") memory isn't enough and alloc failed, kernel name: "
<< output_node->fullname_with_scope() << ", alloc size: " << tensor_device_address->GetSize()
<< "B.";
}
if (!tensor_device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(output_node, output_index),
device_tensor->GetSize(), device_tensor->type_id(),
device_tensor->GetPtr(), device_tensor->format())) {
MS_LOG(EXCEPTION) << "Sync device to device failed, device type: " << tensor_device_address->DeviceType();
}
} else {
// Move the device ptr from device_tensor to tensor_device_address.
tensor_device_address->set_ptr(device_tensor->GetMutablePtr());
tensor_device_address->set_from_mem_pool(device_tensor->from_mem_pool());
device_tensor->set_ptr(nullptr);
device_tensor->set_from_mem_pool(false);
}
}

output_node_to_tensor_device_address_.clear();
output_nodes_.clear();
output_nodes_.resize(outputs_num_);
}
} // namespace runtime
} // namespace mindspore

+ 5
- 7
mindspore/ccsrc/runtime/framework/actor/output_actor.h View File

@@ -22,6 +22,7 @@
#include <memory>
#include <utility>
#include <algorithm>
#include <map>
#include <unordered_map>
#include <set>
#include "runtime/framework/control_node_parser.h"
@@ -56,11 +57,10 @@ class OutputActor : public AbstractActor {
void Init() override;

// The output actor collects loop count when receive the input control of loop count actor.
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context);
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;

// The output actor collects output result when receive the data of actor.
void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *const context);
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;

// The graph output need be set new device address every step or loop, to avoid that the device address
// context of tensor be rewritten in the next step or next loop.
@@ -69,7 +69,6 @@ class OutputActor : public AbstractActor {
// Get the member.
size_t loop_count() const { return loop_count_; }
size_t outputs_num() const { return outputs_num_; }
const std::vector<AID> &input_result_arrow_aids() const { return input_result_arrow_aids_; }
const std::set<AnfNodePtr> &output_address_persisted_nodes() const { return output_address_persisted_nodes_; }
std::vector<TensorPtr> &outputs() { return outputs_; }

@@ -83,9 +82,6 @@ class OutputActor : public AbstractActor {
size_t loop_count_;
size_t current_count_;

// The dependent input result arrow actors.
std::vector<AID> input_result_arrow_aids_;

// The outputs.
std::vector<TensorPtr> outputs_;
std::vector<KernelWithIndex> output_nodes_;
@@ -94,6 +90,8 @@ class OutputActor : public AbstractActor {
std::set<AnfNodePtr> output_address_persisted_nodes_;
size_t outputs_num_;
size_t current_outputs_num_;

std::map<KernelWithIndex, DeviceTensorPtr> output_node_to_tensor_device_address_;
};

using OutputActorPtr = std::shared_ptr<OutputActor>;


+ 22
- 17
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -1364,13 +1364,12 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
AddControlArrow(no_output_actor, loop_count_actor);
}

// Loop count actor --> output actor.
AddControlArrow(loop_count_actor, actor_set->output_actor_.get());

// Loop count actor --> data prepare actor.
MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();

// Loop count actor --> output actor.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
loop_count_actor->output_aid_ = actor_set->output_actor_->GetAID();
}

void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
@@ -1415,6 +1414,10 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
// The graph output is from device tensor store.
if (IsPersistentDeviceTensor(output_with_index.first)) {
(void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, output_with_index.second, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// The output actor need use the relevant information of node to create output tensor.
device_tensor->SetNodeIndex(output_with_index.first, output_with_index.second);
continue;
}

@@ -1540,16 +1543,22 @@ void GraphScheduler::AddResultArrow(AbstractActor *const from_actor, OutputActor
MS_EXCEPTION_IF_NULL(from_kernel);

auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
(void)from_actor->output_result_arrows_.emplace_back(result_arrow);
(void)from_actor->output_result_nodes_.emplace_back(from_kernel);
(void)to_actor->input_result_arrow_aids_.emplace_back(from_actor->GetAID());
(void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
(void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
to_actor->input_datas_num_++;
(void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());

auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// The output actor need use the relevant information of node to create output tensor.
device_tensor->SetNodeIndex(from_kernel, from_output_index);

if (from_actor->type_ == KernelTransformType::kSuperKernelActor) {
(void)to_actor->output_address_persisted_nodes_.insert(from_kernel);
}

// The device tensor of graph out need be taken over by host tensor, so set the max reference count.
UpdateRefCount(from_kernel, from_output_index, true);
UpdateRefCount(device_tensor.get(), true);
}

void GraphScheduler::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
@@ -1574,13 +1583,11 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
<< ", actual control num: " << actor->input_control_arrow_aids_.size();
}

if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kLoopCountActor) &&
(actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0) &&
(actor->output_result_arrows_.size() == 0)) {
if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->output_data_arrows_.size() == 0) &&
(actor->output_control_arrows_.size() == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
}
if ((actor->type_ != KernelTransformType::kOutputActor) &&
(actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
if ((actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
(actor->input_controls_num_ == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
}
@@ -1607,11 +1614,9 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
// Check the output actor.
auto output_actor = actor_set->output_actor_;
MS_EXCEPTION_IF_NULL(output_actor);
if (output_actor->input_result_arrow_aids_.size() + output_actor->device_tensor_store_keys_.size() !=
output_actor->outputs_num_) {
if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num_) {
MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
<< output_actor->outputs_num_
<< ", the input result arrows num: " << output_actor->input_result_arrow_aids_.size()
<< output_actor->outputs_num_ << ", the input data arrows num: " << output_actor->input_datas_num_
<< ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
}
}


+ 7
- 7
mindspore/ccsrc/vm/backend.cc View File

@@ -866,6 +866,13 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
}
}

MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);

// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();

// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();
@@ -873,13 +880,6 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
size_t output_position = 0;
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
}

MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);

// Update device address for output node of graph.
// Summary processing will use the output device address, so must be after the summary processing.
actor_set->output_actor_->UpdateOutputDeviceAddress();
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}



Loading…
Cancel
Save