浏览代码

!27370 unified runtime support auto monad in the sink mode

Merge pull request !27370 from limingqi107/bug_fix4
tags/v1.6.0
i-robot Gitee 4 年前
父节点
当前提交
6cd4a0714c
共有 3 个文件被更改,包括 74 次插入和 28 次删除
  1. +4
    -1
      mindspore/ccsrc/runtime/framework/actor/abstract_actor.h
  2. +38
    -22
      mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc
  3. +32
    -5
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc

+ 4
- 1
mindspore/ccsrc/runtime/framework/actor/abstract_actor.h 查看文件

@@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include <utility>
#include <set>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
@@ -95,9 +96,11 @@ class AbstractActor : public OpActor<DeviceTensor> {
std::vector<AnfNodePtr> output_data_nodes_;
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;

// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.
// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.
// Index is the input position, AnfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
// The device tensor stores which have the auto monad attribute.
std::set<AnfNodePtr> auto_monad_device_tensor_stores_;

// The dependent input actors.
std::vector<AID> input_data_arrow_aids_;


+ 38
- 22
mindspore/ccsrc/runtime/framework/actor/super_kernel_actor.cc 查看文件

@@ -82,22 +82,21 @@ void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) {
std::string error_info = "Launch graph exception, graph id: " + std::to_string(graph_->graph_id());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}

for (auto item : ref_node_addr_map_) {
const auto &input_node = item.first;
auto formal_param_addr = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second);
auto formal_param_addr = AnfAlgo::GetMutableOutputAddr(item.first, 0, false);
MS_EXCEPTION_IF_NULL(formal_param_addr);
auto device_address = item.second;
MS_EXCEPTION_IF_NULL(device_address);
MS_LOG(INFO) << "The input ref_node: " << input_node->DebugString()
MS_LOG(INFO) << "The input ref_node: " << item.first->DebugString()
<< " need copy back, from address: " << formal_param_addr->GetPtr()
<< " to address: " << device_address->GetPtr() << ".";
if (!device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(input_node, 0), formal_param_addr->GetSize(),
formal_param_addr->type_id(), formal_param_addr->GetPtr(),
formal_param_addr->format())) {
MS_LOG(EXCEPTION) << "Sync device to device failed.";
<< " to address: " << item.second->GetPtr() << ".";
if (!Copy(item.second, formal_param_addr.get())) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
}
}
ref_node_addr_map_.clear();

PostRun(context);
}

@@ -110,6 +109,7 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {
}

auto &input_nodes = graph_->input_nodes();
// Copy input data.
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
if (IntToSize(input_data->index_) >= input_nodes.size()) {
@@ -120,25 +120,18 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(input_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);

auto &input_device_tensor = input_data->data_;
MS_EXCEPTION_IF_NULL(input_device_tensor);
if (input_device_tensor->DeviceType() != device_address->DeviceType()) {
MS_LOG(ERROR) << "The input data device type:" << input_device_tensor->DeviceType()
<< " is not equal to the graph node device type:" << device_address->DeviceType() << ".";
return false;
}

if (input_device_tensor->GetPtr() == device_address->GetPtr()) {
continue;
}

MS_LOG(INFO) << "The input data of node:" << input_node->DebugString()
<< " need copy from address:" << input_device_tensor->GetPtr()
<< " to address:" << device_address->GetPtr() << ".";
if (!device_address->SyncDeviceToDevice(trans::GetRuntimePaddingShape(input_node, 0),
input_device_tensor->GetSize(), input_device_tensor->type_id(),
input_device_tensor->GetPtr(), input_device_tensor->format())) {
MS_LOG(ERROR) << "Sync device to device failed.";
<< ", type:" << input_device_tensor->DeviceType() << " to address:" << device_address->GetPtr()
<< ", type:" << device_address->DeviceType() << ".";
if (!Copy(device_address.get(), input_device_tensor)) {
MS_LOG(ERROR) << "Copy data failed.";
return false;
}
if (HasAbstractRef(input_node) && ref_node_addr_map_.count(input_node) == 0) {
@@ -146,6 +139,29 @@ bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context) {
}
}

// Check device tensor store.
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto input_device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceAddressType());
MS_EXCEPTION_IF_NULL(input_device_tensor);
if (device_tensor_store_key.first >= input_nodes.size()) {
MS_LOG(ERROR) << "The input index:" << device_tensor_store_key.first << "is out of range:" << input_nodes.size();
return false;
}
auto input_node = input_nodes[device_tensor_store_key.first];
MS_EXCEPTION_IF_NULL(input_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);
if (input_device_tensor->GetPtr() != device_address->GetPtr()) {
MS_LOG(ERROR) << "The input data of node:" << input_node->DebugString()
<< " device address:" << input_device_tensor->GetPtr()
<< ", type:" << input_device_tensor->DeviceType()
<< " is not equal to the graph node device address:" << device_address->GetPtr()
<< ", type:" << device_address->DeviceType() << ".";
return false;
}
}

return true;
}
} // namespace runtime


+ 32
- 5
mindspore/ccsrc/runtime/framework/graph_scheduler.cc 查看文件

@@ -874,7 +874,6 @@ void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const
MS_LOG(INFO) << "The graph:" << graph->graph_id()
<< " has abstract monad input node:" << input_node->DebugString() << ", input index:" << node_index;
LinkControlArrowByAutoMonad(to_actor, input_node, graph);
(void)auto_monad_actors->emplace_back(to_actor);
continue; // No data arrow for monad input.
}

@@ -884,6 +883,32 @@ void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const
// The gather of linking data arrows of kernel by the different from kernel type.
LinkDataArrow(to_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
}

std::vector<CNodePtr> auto_monad_kernels;
// Foreach the execution order to get the auto monad kernels.
auto &execution_order = graph->execution_order();
(void)std::for_each(execution_order.begin(), execution_order.end(), [&](const CNodePtr &kernel) {
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
if (HasAbstractMonad(input_node)) {
(void)auto_monad_kernels.emplace_back(kernel);
continue;
}
}
});
// Foreach auto monad kernels to get the auto monad device tensor stores.
(void)std::for_each(auto_monad_kernels.begin(), auto_monad_kernels.end(), [&](const CNodePtr &kernel) {
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i, false);
auto front_node = FetchFrontNodeByBackendNode(from_kernel_with_output_idx.first, graph);
if (IsPersistentDeviceTensor(front_node)) {
(void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
}
}
});
if (to_actor->auto_monad_device_tensor_stores_.size() > 0) {
(void)auto_monad_actors->emplace_back(to_actor);
}
}

void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
@@ -964,13 +989,9 @@ void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, Abs
const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(graph);
if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
return;
}

auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);

auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
}
@@ -1550,6 +1571,12 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ab
if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
continue;
}
// Find the device tensor store that needs to be processed accurately.
if ((auto_monad_actor->type_ == KernelTransformType::kSuperKernelActor) &&
(auto_monad_actor->auto_monad_device_tensor_stores_.find(device_tensor_store_key.second) ==
auto_monad_actor->auto_monad_device_tensor_stores_.end())) {
continue;
}

// Create the copy actor.
std::string name = "copy_from:" + auto_monad_actor->GetAID().Name() +


正在加载...
取消
保存