Browse Source

!16558 fix bug of copy actor

From: @limingqi107
Reviewed-by: @cristoval,@wilfchen
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
e0893f33ad
17 changed files with 566 additions and 226 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/session/session_basic.cc
  2. +5
    -0
      mindspore/ccsrc/runtime/device/device_address.h
  3. +4
    -1
      mindspore/ccsrc/runtime/framework/actor/actor_common.cc
  4. +1
    -1
      mindspore/ccsrc/runtime/framework/actor/actor_common.h
  5. +30
    -11
      mindspore/ccsrc/runtime/framework/actor/copy_actor.cc
  6. +10
    -7
      mindspore/ccsrc/runtime/framework/actor/copy_actor.h
  7. +54
    -10
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc
  8. +21
    -18
      mindspore/ccsrc/runtime/framework/actor/data_source_actor.h
  9. +9
    -2
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
  10. +1
    -1
      mindspore/ccsrc/runtime/framework/actor/kernel_actor.h
  11. +10
    -3
      mindspore/ccsrc/runtime/framework/actor/switch_actor.cc
  12. +2
    -2
      mindspore/ccsrc/runtime/framework/actor/switch_actor.h
  13. +43
    -9
      mindspore/ccsrc/runtime/framework/device_tensor_store.h
  14. +342
    -148
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  15. +15
    -9
      mindspore/ccsrc/runtime/framework/graph_scheduler.h
  16. +14
    -4
      mindspore/ccsrc/vm/backend.cc
  17. +3
    -0
      mindspore/core/mindrt/include/actor/op_actor.h

+ 2
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -727,6 +727,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
cnode_inputs->push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode;
KernelWithIndex front_node_with_index(anf, 0);
MS_LOG(INFO) << "The " << input_idx << " input of node:" << cnode->fullname_with_scope()
<< " is from front node:" << anf->fullname_with_scope();
graph->CacheInternalParameterToFrontNode(parameter_from_cnode, front_node_with_index);
}
}


+ 5
- 0
mindspore/ccsrc/runtime/device/device_address.h View File

@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"
@@ -53,6 +54,10 @@ namespace mindspore {
namespace device {
enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
static const std::map<DeviceAddressType, std::string> kDeviceTypeToName = {{DeviceAddressType::kUnknown, "Unknown"},
{DeviceAddressType::kAscend, "Ascend"},
{DeviceAddressType::kCPU, "CPU"},
{DeviceAddressType::kGPU, "GPU"}};

class DeviceAddress : public mindspore::DeviceSync {
public:


+ 4
- 1
mindspore/ccsrc/runtime/framework/actor/actor_common.cc View File

@@ -45,8 +45,11 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node) {

bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
if (graph == nullptr) {
return true;
}

// Judge whether node is internal parameter.
const auto &front_node = graph->GetFrontNodeByInternalParameter(node);
if (front_node.first == nullptr) {


+ 1
- 1
mindspore/ccsrc/runtime/framework/actor/actor_common.h View File

@@ -45,7 +45,7 @@ constexpr int kFailure = 1;
int64_t GetMaxThreadNum();

bool IsDeviceQueueDSActor(const AnfNodePtr &node);
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph);
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph = nullptr);
bool IsKernelActor(const AnfNodePtr &node);

// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is


+ 30
- 11
mindspore/ccsrc/runtime/framework/actor/copy_actor.cc View File

@@ -27,7 +27,7 @@ void CopyActor::RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTe
input_op_datas_[sequential_num].emplace_back(input_data);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckCopyCondition(context)) {
FetchInputDeviceTensor(context);
FetchDeviceTensor(context);
AllocateMemory(context);
}
}
@@ -38,20 +38,20 @@ void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *contex
input_op_controls_[sequential_num].emplace_back(input_control);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckCopyCondition(context)) {
FetchInputDeviceTensor(context);
FetchDeviceTensor(context);
AllocateMemory(context);
}
}

void CopyActor::AllocateMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> alloc_list({output_device_tensor_.get()});
std::vector<DeviceTensor *> alloc_list({output_device_tensor_});
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, alloc_list, output_device_context_, context,
GetAID());
}

void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> input_free_list({input_device_tensor_});
std::vector<DeviceTensor *> output_free_list({output_device_tensor_.get()});
std::vector<DeviceTensor *> output_free_list({output_device_tensor_});
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_free_list, input_device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, output_free_list, output_device_context_, context);
}
@@ -59,7 +59,7 @@ void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);

if (!Copy(output_device_tensor_.get(), input_device_tensor_)) {
if (!Copy(output_device_tensor_, input_device_tensor_)) {
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
@@ -115,12 +115,28 @@ bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
return true;
}

void CopyActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_device_context_);

if (device_tensor_store_key_.second != nullptr) {
input_device_tensor_ = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
input_device_context_->GetDeviceAddressType());
if (input_device_tensor_ == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(input_device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}

if (device_tensor_store_keys_.size() > 0) {
const auto &device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_keys_[0].second);
input_device_tensor_ = device_tensor.get();
output_device_tensor_ = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
output_device_context_->GetDeviceAddressType());
if (output_device_tensor_ == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(output_device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
} else {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
@@ -129,6 +145,9 @@ void CopyActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
const auto &input_data = data_iter->second[0];
MS_EXCEPTION_IF_NULL(input_data);
input_device_tensor_ = input_data->data_;

MS_EXCEPTION_IF_NULL(output_);
output_device_tensor_ = output_.get();
}
}

@@ -146,8 +165,8 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
std::string error_info = "The output index is out of range: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto data = std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_.get(),
op_arrow->to_input_index_);
auto data =
std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_, op_arrow->to_input_index_);
Async(op_arrow->to_op_id_, &CopyActor::RunOpData, data, context);
}



+ 10
- 7
mindspore/ccsrc/runtime/framework/actor/copy_actor.h View File

@@ -42,7 +42,7 @@ class CopyActor : public MemoryInterfaceActor {
input_datas_num_(0),
input_controls_num_(0),
input_device_tensor_(nullptr),
output_device_tensor_(nullptr) {}
output_(nullptr) {}
~CopyActor() override = default;

// The copy actor run when receive the input data.
@@ -61,8 +61,8 @@ class CopyActor : public MemoryInterfaceActor {

// Check whether satisfy the condition for copy.
bool CheckCopyCondition(OpContext<DeviceTensor> *context) const;
// Fetch the input device tensor for copy.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
// Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *context);

// Copy data from src_device_tensor to dst_device_tensor.
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
@@ -80,16 +80,19 @@ class CopyActor : public MemoryInterfaceActor {
size_t input_controls_num_;

// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, void *>> device_tensor_store_keys_;
std::pair<size_t, AnfNode *> device_tensor_store_key_;

// The device interface for copy.
const DeviceContext *input_device_context_;
const DeviceContext *output_device_context_;

// The input device tensor is saved from the input data.
// The input device tensor is saved from the input data or fetched by device_tensor_store_key_.
DeviceTensor *input_device_tensor_;
// The output device tensor is created in the copy actor build, so can't be the raw pointer.
DeviceTensorPtr output_device_tensor_;
// The output device tensor is saved from the output or fetched by device_tensor_store_key_.
DeviceTensor *output_device_tensor_;

// The output is created in the copy actor build, so can't be the raw pointer.
DeviceTensorPtr output_;
};

using CopyActorPtr = std::shared_ptr<CopyActor>;


+ 54
- 10
mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc View File

@@ -48,16 +48,6 @@ void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
AllocateMemory(context);
}

void DataSourceActor::AllocateMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.back();
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, device_tensors, device_context_, context, GetAID());
}

void DataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.front();
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, device_tensors, device_context_, context);
}

void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") sends output data.";
MS_EXCEPTION_IF_NULL(context);
@@ -98,6 +88,16 @@ void DeviceQueueDataSourceActor::FillDataBuffer() {
buffers_.push(device_tensors);
}

void DeviceQueueDataSourceActor::AllocateMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.back();
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, device_tensors, device_context_, context, GetAID());
}

void DeviceQueueDataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.front();
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, device_tensors, device_context_, context);
}

void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
@@ -151,6 +151,32 @@ void HostQueueDataSourceActor::FillDataBuffer() {
buffers_.push(device_tensors);
}

void HostQueueDataSourceActor::AllocateMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.back();
if (IsSameDeviceType()) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, device_tensors, device_contexts_[0], context,
GetAID());
} else {
for (size_t i = 0; i < device_tensors.size(); ++i) {
std::vector<DeviceTensor *> alloc_list({device_tensors[i]});
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, alloc_list, device_contexts_[i], context,
GetAID());
}
}
}

void HostQueueDataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.front();
if (IsSameDeviceType()) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, device_tensors, device_contexts_[0], context);
} else {
for (size_t i = 0; i < device_tensors.size(); ++i) {
std::vector<DeviceTensor *> free_list({device_tensors[i]});
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, free_list, device_contexts_[i], context);
}
}
}

void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (buffers_.size() == 0) {
@@ -198,5 +224,23 @@ void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
result_arrow->to_input_index_, context);
}
}

size_t HostQueueDataSourceActor::FetchDataNodePosition(const AnfNodePtr &data_node) const {
const auto &iter = data_node_position_map_.find(data_node);
if (iter == data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Data node: " << data_node->fullname_with_scope() << " does not exist.";
}
return iter->second;
}

bool HostQueueDataSourceActor::IsSameDeviceType() const {
for (size_t i = 1; i < device_contexts_.size(); i++) {
if (device_contexts_[i] != device_contexts_[0]) {
return false;
}
}
return true;
}

} // namespace runtime
} // namespace mindspore

+ 21
- 18
mindspore/ccsrc/runtime/framework/actor/data_source_actor.h View File

@@ -39,20 +39,16 @@ using mindspore::device::DeviceContext;
// -> OnMemoryAllocFinish -> FreeMemory -> SendOutput.
class DataSourceActor : public MemoryInterfaceActor {
public:
DataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
const AID memory_manager_aid)
: MemoryInterfaceActor(name),
buffer_capacity_(buffer_capacity),
device_context_(device_context),
memory_manager_aid_(memory_manager_aid) {}
DataSourceActor(std::string name, size_t buffer_capacity, const AID memory_manager_aid)
: MemoryInterfaceActor(name), buffer_capacity_(buffer_capacity), memory_manager_aid_(memory_manager_aid) {}
virtual ~DataSourceActor() = default;

// The process entry of data processing.
void FetchData(OpContext<DeviceTensor> *context);

// The memory related operation interface.
void AllocateMemory(OpContext<DeviceTensor> *context) override;
void FreeMemory(OpContext<DeviceTensor> *context) override;
void AllocateMemory(OpContext<DeviceTensor> *context) override{};
void FreeMemory(OpContext<DeviceTensor> *context) override{};
// Copy data from data source to the device tensor buffer of actor after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{};

@@ -71,16 +67,10 @@ class DataSourceActor : public MemoryInterfaceActor {
// The output result arrows of graph output.
std::vector<OpArrowPtr> output_result_arrows_;

// To trigger kernel actors running by op arrows.
std::vector<OpArrowPtr> output_op_arrows_;

// The buffers store the device tensors.
std::queue<std::vector<DeviceTensor *>> buffers_;
size_t buffer_capacity_;

// The device interface of data copy.
const DeviceContext *device_context_;

// The id of memory manager actor. Send message to it for alloc and free memory during the data processing.
const AID memory_manager_aid_;
};
@@ -90,9 +80,11 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
public:
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
const AID memory_manager_aid)
: DataSourceActor(name, buffer_capacity, device_context, memory_manager_aid) {}
: DataSourceActor(name, buffer_capacity, memory_manager_aid), device_context_(device_context) {}
~DeviceQueueDataSourceActor() override = default;

void AllocateMemory(OpContext<DeviceTensor> *context) override;
void FreeMemory(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;

protected:
@@ -104,18 +96,24 @@ class DeviceQueueDataSourceActor : public DataSourceActor {

// Input data kernel(for example GetNext) fetches data from device queue.
CNodePtr data_kernel_;

const DeviceContext *device_context_;
};

// The class represents that the data source is host queue.
class HostQueueDataSourceActor : public DataSourceActor {
public:
HostQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
const AID memory_manager_aid, HostTensorQueuePtr host_queue)
: DataSourceActor(name, buffer_capacity, device_context, memory_manager_aid), host_queue_(host_queue) {}
HostQueueDataSourceActor(std::string name, size_t buffer_capacity, const AID memory_manager_aid,
HostTensorQueuePtr host_queue)
: DataSourceActor(name, buffer_capacity, memory_manager_aid), host_queue_(host_queue) {}
~HostQueueDataSourceActor() override = default;

void AllocateMemory(OpContext<DeviceTensor> *context) override;
void FreeMemory(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;

size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;

protected:
void FillDataBuffer() override;
void SendResult(OpContext<DeviceTensor> *context) override;
@@ -123,9 +121,14 @@ class HostQueueDataSourceActor : public DataSourceActor {
private:
friend class GraphScheduler;

bool IsSameDeviceType() const;

HostTensorQueuePtr host_queue_;
// Input data nodes fetch data from host queue.
std::vector<AnfNodePtr> data_nodes_;
// The device contexts corresponding to the data nodes.
std::vector<const DeviceContext *> device_contexts_;

// The location of the data node in the data source actor.
std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;
};


+ 9
- 2
mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc View File

@@ -151,6 +151,7 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens

void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
auto input_size = AnfAlgo::GetInputTensorNum(kernel_);
if (input_device_tensors_.empty()) {
input_device_tensors_.resize(input_size);
@@ -165,8 +166,14 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
}

for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
input_device_tensors_[device_tensor_store_key.first] =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (input_device_tensors_[device_tensor_store_key.first] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}



+ 1
- 1
mindspore/ccsrc/runtime/framework/actor/kernel_actor.h View File

@@ -99,7 +99,7 @@ class KernelActor : public MemoryInterfaceActor {
size_t input_controls_num_;

// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, void *>> device_tensor_store_keys_;
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;

// The device tensors for launch.
std::vector<DeviceTensor *> input_device_tensors_;


+ 10
- 3
mindspore/ccsrc/runtime/framework/actor/switch_actor.cc View File

@@ -198,6 +198,7 @@ bool SwitchActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {

void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
auto input_size = input_datas_num_ + branch_device_tensor_store_keys_.size();
input_device_tensors_.resize(input_size);
auto data_iter = input_op_datas_.find(context->sequential_num_);
@@ -210,8 +211,14 @@ void SwitchActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
data_iter->second.clear();

for (auto &device_tensor_store_key : branch_device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
input_device_tensors_[device_tensor_store_key.first] =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
if (input_device_tensors_[device_tensor_store_key.first] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}

@@ -229,7 +236,7 @@ void SwitchActor::SendOutput(OpContext<DeviceTensor> *context) {
}

void SwitchActor::FreeMemory(OpContext<DeviceTensor> *context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_device_tensors_, device_contexts_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_device_tensors_, device_context_, context);
}

} // namespace runtime


+ 2
- 2
mindspore/ccsrc/runtime/framework/actor/switch_actor.h View File

@@ -76,14 +76,14 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
// The position of the branch output in the input_nodes_.
std::vector<std::vector<size_t>> branch_inputs_pos_;
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, void *>> branch_device_tensor_store_keys_;
std::vector<std::pair<size_t, AnfNode *>> branch_device_tensor_store_keys_;
std::vector<std::vector<AnfNodePtr>> branch_total_inputs_;
std::vector<FuncGraphPtr> branch_func_graph_;

std::vector<DeviceTensor *> input_device_tensors_;

// Save the DeviceContext of input_nodes_, which is used to release the DeviceTensor.
DeviceContext *device_contexts_;
DeviceContext *device_context_;

// The id of memory manager actor. Send message to it for alloc and free memory.
const AID memory_manager_aid_;


+ 43
- 9
mindspore/ccsrc/runtime/framework/device_tensor_store.h View File

@@ -19,12 +19,14 @@

#include <memory>
#include <unordered_map>
#include <vector>
#include "utils/ms_utils.h"
#include "runtime/device/device_address.h"

namespace mindspore {
namespace runtime {
using DeviceTensor = mindspore::device::DeviceAddress;
using DeviceTensorType = mindspore::device::DeviceAddressType;
using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;

// The device tensor mainly includes address ptr, size and reference count,
@@ -38,23 +40,54 @@ class DeviceTensorStore {
return instance;
}

// Support value modifiable, so use the way of array subscript directly.
void Insert(void *key, DeviceTensorPtr value) { device_tensors_[key] = value; }
// Support value modifiable.
void Insert(AnfNode *key, const DeviceTensorPtr &value) {
MS_EXCEPTION_IF_NULL(key);
const auto &iter = device_tensors_.find(key);
if (iter == device_tensors_.end()) {
device_tensors_[key].emplace_back(value);
return;
}

for (size_t i = 0; i < iter->second.size(); ++i) {
if (iter->second[i]->DeviceType() == value->DeviceType()) {
iter->second[i] = value;
return;
}
}
iter->second.emplace_back(value);
}

void Remove(void *key) {
auto iter = device_tensors_.find(key);
void Remove(AnfNode *key) {
MS_EXCEPTION_IF_NULL(key);
const auto &iter = device_tensors_.find(key);
if (iter != device_tensors_.end()) {
(void)device_tensors_.erase(iter);
}
}

DeviceTensorPtr Fetch(void *key) const {
auto iter = device_tensors_.find(key);
std::vector<DeviceTensorPtr> Fetch(AnfNode *key) const {
MS_EXCEPTION_IF_NULL(key);
const auto &iter = device_tensors_.find(key);
if (iter != device_tensors_.end()) {
return iter->second;
} else {
return nullptr;
std::vector<DeviceTensorPtr> empty_value;
return empty_value;
}
}

DeviceTensor *Fetch(AnfNode *key, DeviceTensorType value_type) const {
MS_EXCEPTION_IF_NULL(key);
const auto &iter = device_tensors_.find(key);
if (iter != device_tensors_.end()) {
for (const auto &device_tensor : iter->second) {
if (device_tensor->DeviceType() == value_type) {
return device_tensor.get();
}
}
}
return nullptr;
}

private:
@@ -62,8 +95,9 @@ class DeviceTensorStore {
~DeviceTensorStore() = default;
DISABLE_COPY_AND_ASSIGN(DeviceTensorStore);

// The data storage of device tensor, key is anfNode ptr.
std::unordered_map<void *, DeviceTensorPtr> device_tensors_;
// The data storage of device tensor. Key is the anf node, value is the vector which may contains the device
// tensors from different devices.
std::unordered_map<AnfNode *, std::vector<DeviceTensorPtr>> device_tensors_;
};
} // namespace runtime
} // namespace mindspore


+ 342
- 148
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -16,6 +16,7 @@

#include "runtime/framework/graph_scheduler.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "mindrt/src/actor/actormgr.h"
#include "mindrt/include/async/async.h"
#include "backend/session/anf_runtime_algorithm.h"
@@ -57,6 +58,17 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
UpdateRefCount(device_tensor.get(), is_max_ref_count);
}

AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(backend_node);
MS_EXCEPTION_IF_NULL(graph);
auto front_node = graph->GetFrontAnfByBackendAnf(backend_node);
// PyNative forward graph does not has front node, using backend node instead.
if (front_node == nullptr) {
front_node = backend_node;
}
return front_node;
}

// The branch processing of PrepareDataForValueNode that value type is tensor.
void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
const DeviceContext *device_context) {
@@ -133,50 +145,63 @@ void PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *devi
}

// Prepare the device data for persistent device tensor of weight node from host tensor.
void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, const DeviceContext *device_context) {
void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, const DeviceContext *device_context,
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
const auto &host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
// If the host tensor has the device address, it indicates that the device address of host tensor is new.
if (host_tensor_address != nullptr) {
if (host_tensor_address != device_tensor) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, node.get());
DeviceTensorStore::GetInstance().Insert(node.get(), host_tensor_address);
auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
const auto &front_node = FetchFrontNodeByBackendNode(node, graph);
// Use the device address of host tensor to set device tensor.
if (host_tensor_address != device_tensor) {
if (host_tensor_address == nullptr) {
MS_EXCEPTION_IF_NULL(device_tensor);
host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id());
tensor->set_device_address(host_tensor_address);
UpdateRefCount(host_tensor_address.get(), true);
}
return;
} else {
auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
device_tensor->format(), device_tensor->type_id());
MS_EXCEPTION_IF_NULL(new_device_tensor);
AnfAlgo::SetOutputAddr(new_device_tensor, 0, node.get());
tensor->set_device_address(new_device_tensor);
DeviceTensorStore::GetInstance().Insert(node.get(), new_device_tensor);

new_device_tensor->set_original_ref_count(SIZE_MAX);
new_device_tensor->ResetRefCount();
device_tensor = new_device_tensor;
MS_EXCEPTION_IF_NULL(host_tensor_address);
AnfAlgo::SetOutputAddr(host_tensor_address, 0, node.get());
DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
}

// If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
if (device_tensor->GetPtr() != nullptr) {
if (host_tensor_address->GetPtr() != nullptr) {
return;
}
MS_LOG(INFO) << "Prepare device data for weight node: " << node->fullname_with_scope();
tensor->set_device_address(device_tensor);

// Allocate device memory.
if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
<< ", alloc size: " << device_tensor->GetSize();
// Allocate device memory and copy data from host tensor to device.
if (!device_context->AllocateMemory(host_tensor_address.get(), host_tensor_address->GetSize())) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
}

// Copy data from host tensor to device.
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
tensor->data_type(), tensor->data_c())) {
if (!host_tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
}

// Allocate another device memory and copy data from host tensor to another device(if exist).
const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
if (device_tensors.size() > 1) {
auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
MS_EXCEPTION_IF_NULL(another_device_tensor);
auto another_device_type = another_device_tensor->DeviceType();
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(another_device_context);
if (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize())) {
MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
}
if (!another_device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
}
}
}

void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceContext *device_context,
@@ -266,6 +291,16 @@ void EraseValueNodeTensor(const std::vector<int64_t> *tensors_mask, const std::v
}
}
}

TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t graph_index, size_t input_index) {
if (graph_index < graph_compiler_info.input_tensors_.size()) {
const std::vector<TensorPtr> *input_tensors = graph_compiler_info.input_tensors_[graph_index];
if (input_index < input_tensors->size()) {
return input_tensors->at(input_index);
}
}
return nullptr;
}
} // namespace

GraphScheduler::~GraphScheduler() {
@@ -276,12 +311,12 @@ GraphScheduler::~GraphScheduler() {

// Local maps clear.
actor_name_to_actor_.clear();
output_to_actor_.clear();
graph_output_to_actor_.clear();
}

void GraphScheduler::Initialize() {
// Local maps and vcetors clear.
output_to_actor_.clear();
graph_output_to_actor_.clear();
copy_actors_.clear();

if (init_) {
@@ -325,7 +360,7 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info

actors_.emplace(actor_set->name_, actor_set);

DumpActor(actor_set.get());
DumpActor(actor_set.get(), graph_compiler_info);
if (!CheckActorValid(actor_set.get(), strategy)) {
MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
}
@@ -400,7 +435,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
// Prepare the device data for weights.
PrepareDataForWeightNode(input_node, input_tensor, device_context);
PrepareDataForWeightNode(input_node, input_tensor, device_context, graph);
} else if (IsHostQueueDSActor(input_node, graph)) {
if (std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address()) != nullptr) {
continue;
@@ -470,16 +505,6 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat
return false;
}

// Sync device stream.
const auto &first_kernel_actor = actor_set->kernel_actors_[0];
MS_EXCEPTION_IF_NULL(first_kernel_actor);
const auto &device_context = first_kernel_actor->device_context_;
MS_EXCEPTION_IF_NULL(device_context);
if (!device_context->SyncStream()) {
MS_LOG(ERROR) << "Sync stream failed.";
return false;
}

return true;
}

@@ -513,26 +538,40 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (const auto &output : outputs) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
MS_EXCEPTION_IF_NULL(output_with_index.first);
const auto &front_node = graph->GetFrontAnfByBackendAnf(output_with_index.first);
auto output_kernel = output_with_index.first;
MS_EXCEPTION_IF_NULL(output_kernel);
const auto &front_node = graph->GetFrontAnfByBackendAnf(output_kernel);
if (front_node == nullptr) {
continue;
}

auto origin_output_with_index = KernelWithIndex(front_node, output_with_index.second);
std::string actor_name;
// Only cache the kernel actor and device queue data source actor.
if (IsKernelActor(output_with_index.first)) {
actor_name = output_with_index.first->fullname_with_scope();
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
auto actor_output_index = output_with_index.second;
auto origin_output_with_index = KernelWithIndex(front_node, actor_output_index);
OpActor<DeviceTensor> *actor = nullptr;
if (IsKernelActor(output_kernel)) {
actor = FetchActor(output_kernel->fullname_with_scope());
} else if (IsDeviceQueueDSActor(output_kernel)) {
std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
actor = FetchActor(actor_name);
} else if (IsHostQueueDSActor(output_kernel, graph)) {
actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor");
const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(host_ds_actor);
// Get the position of output kernel in the data source actor.
actor_output_index = host_ds_actor->FetchDataNodePosition(output_kernel);
} else if (IsPersistentDeviceTensor(output_kernel)) {
MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " is device tensor store.";
continue;
} else {
MS_LOG(WARNING) << "Invalid graph output node:" << output_kernel->fullname_with_scope();
continue;
}
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
std::pair<OpActor<DeviceTensor> *, size_t> actor_pair(actor, output_with_index.second);
output_to_actor_.emplace(origin_output_with_index, actor_pair);
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " to actor:" << actor->GetAID().Name() << " with output index:" << actor_output_index;
graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index));
}
}
}
@@ -540,6 +579,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActor *> auto_monad_actors;

// Foreach the execution order to link the actors.
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
@@ -558,19 +598,13 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node);
if (HasAbstractMonad(input_node)) {
auto_monad_actors.emplace_back(kernel_actor);
continue; // No data arrow for monad input.
}

KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);

TensorPtr tensor = nullptr;
if (index < graph_compiler_info.input_tensors_.size()) {
const std::vector<TensorPtr> *input_tensors = graph_compiler_info.input_tensors_[index];
if (i < input_tensors->size()) {
tensor = input_tensors->at(i);
}
}
const auto &tensor = FetchInputTensor(graph_compiler_info, index, i);
// The gather of linking data allows of kernel by the different from kernel type.
LinkDataArrow(kernel_actor, actor_set, graph, from_kernel_with_output_idx, to_kernel_with_input_idx, tensor);
}
@@ -580,6 +614,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
// Link the control arrows of kernel actors.
LinkControlArrowForKernelActor(&(actor_set->kernel_actors_), actor_set->loop_count_actor_.get(), strategy);

// Auto monad actor may modify the device tensor store.
LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);

// BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set);

@@ -634,19 +671,20 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, host_queue);
InsertActor(host_queue_ds_actor.get());
data_source_actors.emplace_back(host_queue_ds_actor);
}

const auto &front_node = graph->GetFrontAnfByBackendAnf(input_node);
const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
// In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
// is saved in the host queue data source actor.
if ((front_node != nullptr) && (front_node_position_temp_map.count(front_node) > 0)) {
if (front_node_position_temp_map.count(front_node) > 0) {
host_queue_ds_actor->data_node_position_map_.emplace(input_node, front_node_position_temp_map[front_node]);
continue;
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
host_queue_ds_actor->device_contexts_.emplace_back(device_context);
host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
front_node_position_temp_map.emplace(front_node, data_node_position);
data_node_position++;
@@ -699,7 +737,7 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
}

auto loop_count = ConfigManager::GetInstance().iter_num();
auto actor_name = graph_compiler_info.name_ + "_" + "LoopCountActor";
auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
@@ -768,8 +806,8 @@ void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_
// Link data arrow for internal parameter, convert internal parameter to actor by internal parameter cache to link.
LinkDataArrowForInternalParameter(from_kernel, graph, to_actor, to_kernel_with_input_idx);
} else if (IsPersistentDeviceTensor(from_kernel)) {
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
static_cast<void *>(from_kernel.get()));
const auto devcie_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, devcie_tensor_store_key.get());
} else {
MS_LOG(EXCEPTION) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
}
@@ -785,19 +823,35 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
// Parameter ---> front node ---> actor.
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
if (output_to_actor_.count(front_node_with_index) == 0) {
MS_LOG(EXCEPTION) << "Can't find actor by node:" << front_node_with_index.first->fullname_with_scope();
const auto &front_output_with_index =
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
auto front_output_node = front_output_with_index.first;
MS_EXCEPTION_IF_NULL(front_output_node);
MS_LOG(INFO) << "Link data arrow for internal parameter:" << internal_parameter->fullname_with_scope()
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
<< " with output index:" << front_output_with_index.second;
if (IsPersistentDeviceTensor(front_output_node)) {
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
return;
}
auto actor_pair = output_to_actor_[front_node_with_index];
if (graph_output_to_actor_.count(front_output_with_index) == 0) {
MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_node->fullname_with_scope()
<< ", internal parameter:" << internal_parameter->fullname_with_scope();
}
auto actor_pair = graph_output_to_actor_[front_output_with_index];

if (IsDeviceQueueDSActor(front_node_with_index.first)) {
if (IsDeviceQueueDSActor(front_output_node)) {
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsKernelActor(front_node_with_index.first)) {
} else if (IsKernelActor(front_output_node)) {
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(front_output_node, graph)) {
auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_nodes_[actor_pair.second], 0);
LinkDataArrowForHostDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
MS_LOG(EXCEPTION) << "Invalid internal parameter: " << internal_parameter->fullname_with_scope();
}
@@ -839,19 +893,19 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
auto to_input_index = to_kernel_with_input_idx.second;

// Get the position of from kernel in the data source actor.
auto iter = from_actor->data_node_position_map_.find(from_kernel);
if (iter == from_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
}
auto position = iter->second;
auto position = from_actor->FetchDataNodePosition(from_kernel);

auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_context_)) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;

// Update the reference count of device tensor.
UpdateRefCount(from_actor->data_nodes_[position], from_output_index);
// Update the reference count of device tensor.
UpdateRefCount(from_actor->data_nodes_[position], from_output_index);
}
}

void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
@@ -859,6 +913,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);

auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);
auto from_output_index = from_kernel_with_output_idx.second;
@@ -889,8 +944,8 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;

std::string name =
"copy_actor_" + from_kernel->fullname_with_scope() + "_output_index_" + std::to_string(from_output_index);
std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
"_output_index:" + std::to_string(from_output_index);
CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
// Link between from actor and copy actor.
if (copy_actor == nullptr) {
@@ -903,6 +958,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor

// LInk.
const DeviceContext *from_devcie_context = nullptr;
auto from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
auto op_arrow_to_copy = std::make_shared<OpArrow>(from_output_index, copy_actor->GetAID(), 0);
if (IsDeviceQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
@@ -912,13 +968,20 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
from_devcie_context = real_from_actor->device_context_;
real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsHostQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
auto position = real_from_actor->FetchDataNodePosition(from_kernel);
from_devcie_context = real_from_actor->device_contexts_[position];
op_arrow_to_copy->from_output_index_ = position;
real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
from_device_tensor =
AnfAlgo::GetMutableOutputAddr(real_from_actor->data_nodes_[position], from_output_index, false);
}
copy_actor->input_datas_num_++;

// Set the member of the copy actor.
const auto &from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
MS_EXCEPTION_IF_NULL(from_device_tensor);
copy_actor->output_device_tensor_ = to_devcie_context->CreateDeviceAddress(
copy_actor->output_ = to_devcie_context->CreateDeviceAddress(
nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id());
MS_EXCEPTION_IF_NULL(from_devcie_context);
copy_actor->input_device_context_ = from_devcie_context;
@@ -932,7 +995,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
auto op_arrow_from_copy = std::make_shared<OpArrow>(0, to_actor->GetAID(), to_input_index);
copy_actor->output_op_arrows_.emplace_back(op_arrow_from_copy);
to_actor->input_datas_num_++;
UpdateRefCount(copy_actor->output_device_tensor_.get());
UpdateRefCount(copy_actor->output_.get());
}

void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
@@ -1051,12 +1114,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
for (const auto &output : outputs) {
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, false);
MS_EXCEPTION_IF_NULL(output_with_index.first);
AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(output_with_index.first);
if (front_node == nullptr) {
// PyNative forward graph does not has front node, using backend node instead.
front_node = output_with_index.first;
}

const auto &front_node = FetchFrontNodeByBackendNode(output_with_index.first, graph);
auto origin_output_with_index = KernelWithIndex(front_node, output_with_index.second);
const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
if (iter == graph_compiler_info.origin_outputs_order_.end()) {
@@ -1089,11 +1147,7 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
if (IsHostQueueDSActor(output_with_index.first, graph)) {
actor_name = graph_compiler_info.name_ + "_HostDSActor";
const auto &host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
auto position_iter = host_queue_ds_actor->data_node_position_map_.find(output_with_index.first);
if (position_iter == host_queue_ds_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << output_with_index.first->fullname_with_scope() << " is not exist.";
}
from_actor_output_index = position_iter->second;
from_actor_output_index = host_queue_ds_actor->FetchDataNodePosition(output_with_index.first);
UpdateRefCount(host_queue_ds_actor->data_nodes_[from_actor_output_index], output_with_index.second, true);
from_actor = static_cast<DataSourceActor *>(host_queue_ds_actor);
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
@@ -1108,6 +1162,61 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
}
}

void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
const size_t kNeedUpdateDeviceTensorStoreNum = 2;
for (auto &kernel_actor : auto_monad_actors) {
MS_EXCEPTION_IF_NULL(kernel_actor);
for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
continue;
}

// Create the copy actor.
std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
"_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(copy_actor);
copy_actors_.emplace_back(copy_actor);
InsertActor(copy_actor.get());

// Set the member of the copy actor.
copy_actor->device_tensor_store_key_ = std::pair<size_t, AnfNode *>(0, device_tensor_store_key.second);
auto input_device_context = kernel_actor->device_context_;
copy_actor->input_device_context_ = input_device_context;
auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
? device_tensors[1]
: device_tensors[0];
MS_EXCEPTION_IF_NULL(another_device_tensor);
auto another_device_type = another_device_tensor->DeviceType();
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(another_device_context);
copy_actor->output_device_context_ = another_device_context;

// LInk from copy actor to kernel actor users.
if (kernel_actor->output_op_controls_.size() == 0) {
MS_LOG(WARNING) << "The kernel actor has no control arrow:" << kernel_actor->GetAID().Name();
}
for (auto &output_contorl : kernel_actor->output_op_controls_) {
copy_actor->output_op_controls_.emplace_back(output_contorl);
auto to_actor = FetchActor(output_contorl.Name());
MS_EXCEPTION_IF_NULL(to_actor);
if (output_contorl.Name().find("_LoopCountActor") != string::npos) {
auto real_to_actor = dynamic_cast<LoopCountActor *>(to_actor);
real_to_actor->input_controls_num_++;
} else {
auto real_to_actor = dynamic_cast<KernelActor *>(to_actor);
real_to_actor->input_controls_num_++;
}
}
// Link from kernel actor to copy actor.
kernel_actor->output_op_controls_.emplace_back(copy_actor->GetAID());
copy_actor->input_controls_num_++;
}
}
}

bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
MS_EXCEPTION_IF_NULL(actor_set);
// Check the data source actors.
@@ -1152,7 +1261,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt

const size_t kCopyActorInputDataNum = 1;
auto input_data_num = copy_actor->input_datas_num_;
auto device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
auto device_tensor_store_num = (copy_actor->device_tensor_store_key_.second == nullptr) ? 0 : 1;
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
<< " is wrong, input data num: " << input_data_num
@@ -1175,8 +1284,11 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
}

void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &graph : graph_compiler_info.graphs_) {
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context);

for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
@@ -1185,17 +1297,30 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
}

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);
if (!IsPersistentDeviceTensor(input_node)) {
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_tensor);
const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
UpdateRefCount(device_tensor.get(), true);

// If the device tensor store of this device type is not exist, then create the new device tensor of this type.
if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
<< ", type:" << device_context->GetDeviceAddressType();
auto other_type_device_tensor = device_context->CreateDeviceAddress(
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
DeviceTensorStore::GetInstance().Insert(front_node.get(), other_type_device_tensor);
UpdateRefCount(other_type_device_tensor.get(), true);
}
}
}
@@ -1226,7 +1351,7 @@ OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string actor_name)
return iter->second;
}

void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
MS_EXCEPTION_IF_NULL(actor_set);
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@@ -1246,36 +1371,58 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
return;
}

ofs << "[Data source actors]\n";
ofs << "[Device tensor stores]\n";
DumpDeviceTensorStore(graph_compiler_info, ofs);

ofs << "\n\n[Data source actors]\n";
for (const auto &data_source_actor : actor_set->data_source_actors_) {
DumpDSActor(data_source_actor.get(), ofs);
ofs << "\n";
}

ofs << "\n[Kernel actors]\n";
ofs << "\n\n[Kernel actors]\n";
for (const auto &kernel_actor : actor_set->kernel_actors_) {
DumpKernelActor(kernel_actor.get(), ofs);
ofs << "\n";
}

ofs << "\n[No input kernel actors]\n";
ofs << "\n\n[No input kernel actors]\n";
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
DumpKernelActor(no_input_kernel_actor.get(), ofs);
ofs << "\n";
}

ofs << "\n[Loop count actor]\n";
ofs << "\n\n[Copy actors]\n";
for (const auto &copy_actor : actor_set->copy_actors_) {
DumpCopyActor(copy_actor.get(), ofs);
}

ofs << "\n\n[Loop count actor]\n";
const auto &loop_count_actor = actor_set->loop_count_actor_;
if (loop_count_actor != nullptr) {
DumpLoopCountActor(loop_count_actor.get(), ofs);
ofs << "\n";
}

ofs << "\n[Output actor]\n";
ofs << "\n\n[Output actor]\n";
const auto &output_actor = actor_set->output_actor_;
if (output_actor != nullptr) {
DumpOutputActor(output_actor.get(), ofs);
ofs << "\n";
}
}

void GraphScheduler::DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);

const auto &output_op_arrows = actor->output_op_arrows();
ofs << "\t\toutput_data_arrows:" << output_op_arrows.size() << "\n ";
for (const auto &data_arrow : output_op_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}

const auto &output_op_controls = actor->output_op_controls();
ofs << "\t\toutput_control_arrows:" << output_op_controls.size() << "\n ";
for (const auto &aid : output_op_controls) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
}

@@ -1283,13 +1430,12 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
MS_EXCEPTION_IF_NULL(actor);
const auto &actor_name = actor->GetAID().Name();

MS_EXCEPTION_IF_NULL(actor->device_context_);
ofs << "\tactor_name:" << actor_name << "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
<< "\n";

if (actor_name.find("_DeviceDSActor") != string::npos) {
// Dump the member info of device queue data source actor.
const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor->device_context_);
ofs << "\tactor_name:" << actor_name
<< "\tdevice_context:" << device_queue_ds_actor->device_context_->device_context_key().ToString() << "\n";
const auto &data_kernel = device_queue_ds_actor->data_kernel_;
MS_EXCEPTION_IF_NULL(data_kernel);
ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
@@ -1303,6 +1449,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
}
} else if (actor_name.find("_HostDSActor") != string::npos) {
// Dump the member info of host queue data source actor.
ofs << "\tactor_name:" << actor_name << "\n";
const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
@@ -1312,17 +1459,12 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
MS_EXCEPTION_IF_NULL(device_tensor);
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
<< "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
<< "\toriginal_ref_count:" << device_tensor->original_ref_count()
<< "\tdevice_context:" << host_queue_ds_actor->device_contexts_[i]->device_context_key().ToString() << "\n";
}
}

ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n ";
for (const auto &data_arrow : actor->output_op_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}
DumpBaseActor(actor, ofs);

ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
@@ -1331,6 +1473,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
ofs << "\n";
}

void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
@@ -1370,24 +1513,12 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of

ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n ";
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
const auto &node = reinterpret_cast<AnfNode *>(device_tensor_store_key.second);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
<< "\tfrom_node_name:" << node->fullname_with_scope() << "\n";
<< "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
}

ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n ";
for (const auto &data_arrow : actor->output_op_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}

ofs << "\t\toutput_control_arrows:" << actor->output_op_controls_.size() << "\n ";
for (const auto &aid : actor->output_op_controls_) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
DumpBaseActor(actor, ofs);

ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
@@ -1396,6 +1527,7 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
ofs << "\n";
}

void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const {
@@ -1411,5 +1543,67 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
}
}

// Dump the runtime information of a copy actor to the given stream: the two
// device contexts it copies between, its input counts, the single output
// device tensor (if allocated), the device tensor store key (if any), and the
// common actor links via DumpBaseActor.
void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor);
  MS_EXCEPTION_IF_NULL(actor->input_device_context_);
  MS_EXCEPTION_IF_NULL(actor->output_device_context_);

  const auto &input_context_key = actor->input_device_context_->device_context_key();
  const auto &output_context_key = actor->output_device_context_->device_context_key();
  ofs << "\tactor_name:" << actor->GetAID().Name()
      << "\tinput_device_context:" << input_context_key.ToString()
      << "\toutput_device_context:" << output_context_key.ToString()
      << "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
      << "\n";

  // A copy actor owns at most one output device tensor; dump it when present.
  const auto &copy_output = actor->output_;
  if (copy_output != nullptr) {
    ofs << "\t\toutput_index:" << 0 << "\tptr:" << copy_output->GetPtr() << "\tsize:" << copy_output->GetSize()
        << "\toriginal_ref_count:" << copy_output->original_ref_count() << "\n ";
  }

  // The store key pair is (input index, from node); only dump when the node is set.
  const auto &store_key = actor->device_tensor_store_key_;
  if (store_key.second != nullptr) {
    ofs << "\t\tdevice_tensor_stores:" << 1 << "\n ";
    ofs << "\t\t\tto_input_index:" << store_key.first
        << "\tfrom_node_name:" << store_key.second->fullname_with_scope() << "\n";
  }

  DumpBaseActor(actor, ofs);
  ofs << "\n";
}

// Dump all persistent device tensors cached in the DeviceTensorStore for every
// graph in graph_compiler_info: first the device tensors of graph value nodes,
// then those of persistent graph inputs (e.g. weights), keyed by front node.
// Fix: the emitted labels previously read "devcie tensor ..." (typo); they now
// correctly read "device tensor ...".
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    ofs << "\tgraph id:" << graph->graph_id() << "\n";

    // Dump the device tensors attached to the graph's value nodes.
    for (auto &value_node : graph->graph_value_nodes()) {
      MS_EXCEPTION_IF_NULL(value_node);
      // Skip value nodes that have no device address at output 0.
      if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
        continue;
      }
      ofs << "\t\tdevice tensor key:" << value_node->fullname_with_scope() << "\n";
      const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(value_node.get());
      for (const auto &device_tensor : device_tensors) {
        ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
            << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
            << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
      }
    }

    // Dump the device tensors of persistent graph inputs; the store is keyed by
    // the corresponding front node, so resolve it from the backend input node.
    for (auto &input_node : graph->input_nodes()) {
      MS_EXCEPTION_IF_NULL(input_node);
      if (!IsPersistentDeviceTensor(input_node)) {
        continue;
      }
      const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
      ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\n";
      const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
      for (const auto &device_tensor : device_tensors) {
        ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
            << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
            << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
      }
    }
    ofs << "\n";
  }
}
} // namespace runtime
} // namespace mindspore

+ 15
- 9
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -38,10 +38,17 @@ namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
using KernelMapActor = std::unordered_map<std::string, KernelActorPtr>;
using KernelMapPosition = std::map<KernelWithIndex, size_t, session::KernelWithIndexCmp>;
using ActorInfo = std::string;

// The second element of pair represents the output index of op actor corresponding to the graph output node.
using GraphOutputPair = std::pair<OpActor<DeviceTensor> *, size_t>;

// OpArrowPair represent data edge between from actor and to actor.
// The first element of pair is the AID of from actor, and
// second element is op arrow between actors.
using OpArrowPair = std::pair<AID, OpArrowPtr>;

enum class GraphExecutionStrategy {
kPipeline, // The actor running is triggered only by data.
kStep // The actor running need be triggered by control in addition.
@@ -102,11 +109,6 @@ struct ActorSet {
};
using ActorSetPtr = std::shared_ptr<ActorSet>;

// OpArrowPair represent data edge between from actor and to actor.
// The first element of pair is the AID of from actor, and
// second element is op arrow between actors.
using OpArrowPair = std::pair<AID, OpArrowPtr>;

class GraphScheduler {
public:
static GraphScheduler &GetInstance() {
@@ -187,6 +189,7 @@ class GraphScheduler {
void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);

// The processing of actors link dynamically.
// Analyze necessary input data of current actor, generate and cache op arrow
@@ -212,22 +215,25 @@ class GraphScheduler {
OpActor<DeviceTensor> *FetchActor(const std::string actor_name) const;

// Display the actor information of corresponding kernel graph.
void DumpActor(const ActorSet *actor_set) const;
void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
void DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const;
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const;
void DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const;

// The global maps, only be cleared in the deconstruction.
std::unordered_map<ActorInfo, ActorSetPtr> actors_;
std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;
// The second element of pair represents the output index of op actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<OpActor<DeviceTensor> *, size_t>> device_tensor_to_actor_;
std::unordered_map<DeviceTensorPtr, GraphOutputPair> device_tensor_to_actor_;

// The local maps and vectors, will be cleared at the beginning of each graph transform.
std::unordered_map<std::string, OpActor<DeviceTensor> *> actor_name_to_actor_;
// The second element of pair represents the output index of op actor corresponding to the graph output front node.
std::map<KernelWithIndex, std::pair<OpActor<DeviceTensor> *, size_t>, session::KernelWithIndexCmp> output_to_actor_;
std::map<KernelWithIndex, GraphOutputPair, session::KernelWithIndexCmp> graph_output_to_actor_;
// Because the copy actors are built during the link stage, record all copy actors created in the link process so
// they can be pushed into the actor set after linking.
std::vector<CopyActorPtr> copy_actors_;


+ 14
- 4
mindspore/ccsrc/vm/backend.cc View File

@@ -409,10 +409,8 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
input_tensors.emplace_back(input_tensor);
}

VectorRef outputs;

// Run actor DAG.
VectorRef outputs;
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (pynative_mode) {
@@ -425,11 +423,23 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
MS_EXCEPTION_IF_NULL(actor_set);
runtime::GraphScheduler::GetInstance().PrepareRun(actor_set, graph_compiler_info, input_tensors);

if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
}

// Sync device stream.
const auto &first_device_context = graph_compiler_info.device_contexts_[0];
MS_EXCEPTION_IF_NULL(first_device_context);
if (!first_device_context->SyncStream()) {
MS_LOG(EXCEPTION) << "Sync stream failed:" << first_device_context->device_context_key().ToString();
}
for (size_t i = 0; i < graph_compiler_info.device_contexts_.size(); ++i) {
const auto &device_context = graph_compiler_info.device_contexts_[i];
if ((device_context != first_device_context) && (!device_context->SyncStream())) {
MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
}
}

// Fetch outputs.
MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
auto &output_tensors = actor_set->output_actor_->outputs();


+ 3
- 0
mindspore/core/mindrt/include/actor/op_actor.h View File

@@ -86,6 +86,9 @@ class OpActor : public ActorBase {
// The op actor run when receive the input control.
virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {}

std::vector<OpArrowPtr> output_op_arrows() const { return output_op_arrows_; }
std::vector<AID> output_op_controls() const { return output_op_controls_; }

protected:
// The op data.
std::unordered_map<uuids::uuid *, std::vector<OpDataPtr<T>>> input_op_datas_;


Loading…
Cancel
Save