Browse Source

Sync parameter output after executing the kernel when using MemScheduler

tags/v1.6.0
tanghuikang 4 years ago
parent
commit
87c7f72cf1
5 changed files with 52 additions and 30 deletions
  1. +3
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  2. +4
    -2
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +2
    -2
      mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc
  4. +41
    -26
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  5. +2
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.h

+ 3
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -188,6 +188,9 @@ class AnfRuntimeAlgorithm {
static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = true);
// get mutable output device addr of anf_node
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = true);
// Convenience overload: get mutable output device addr from a (node, output index) pair,
// forwarding to the (node, index) overload above.
static DeviceAddressPtr GetMutableOutputAddr(const KernelWithIndex &node_output_index, bool skip_nop_node) {
return GetMutableOutputAddr(node_output_index.first, node_output_index.second, skip_nop_node);
}
// check whether output addr is exist or not
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = false);
// check whether workspace addr is exist or not


+ 4
- 2
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -1443,8 +1443,10 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
tensor::TensorPtr tensor = nullptr;
if (real_input->isa<ValueNode>()) {
tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
input_tensor_info->input_tensors_mask.emplace_back(
GetValueNode(real_input)->isa<StringImm>() ? kValueNodeTensorMask : kParameterDataTensorMask);
const auto &value_ptr = GetValueNode(real_input);
MS_EXCEPTION_IF_NULL(value_ptr);
input_tensor_info->input_tensors_mask.emplace_back(value_ptr->isa<StringImm>() ? kValueNodeTensorMask
: kParameterDataTensorMask);
} else if (real_input->isa<Parameter>()) {
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask


+ 2
- 2
mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc View File

@@ -31,8 +31,8 @@ static const size_t ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE_FOR_GRAPH = 8 << 20;
size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size) {
auto device_free_mem_size = free_mem_size();
if (device_free_mem_size < size) {
MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size
<< "] is smaller than required size[" << size << "]";
MS_LOG(WARNING) << "Out of Memory. Request memory size: " << size
<< ", Memory Statistic:" << AscendMemAdapter::GetInstance().DevMemStatistics();
return 0;
}
auto alloc_mem_size = ASCEND_DYNAMIC_MEM_ALLOC_UNIT_SIZE;


+ 41
- 26
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -1352,39 +1352,54 @@ void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &m
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
auto tensor = graph.GetNodeOutputTensor(std::make_pair(kernel, j));
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true);
if (mock) {
if (graph.IsInternalOutput(kernel, j) && device_address != nullptr) {
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
}
for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) {
const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) {
continue;
}
if (tensor == nullptr) {
continue;
}
if (device_address == nullptr) {
tensor->data_sync(false);
tensor->set_device_address(nullptr);
tensor->set_sync_status(kNeedSyncHostToDevice);
continue;
}
if (!SyncStream()) {
MS_LOG(EXCEPTION) << "SyncStream failed";
}
auto origin_ptr = device_address->ptr_;
if (origin_ptr == nullptr) {
device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock);
}
for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock);
}
}

// Sync one node output to its host tensor when running under MemScheduler.
// In a mock (dry-run) pass it only raises the memory priority of internal-output
// addresses so the scheduler keeps them resident; in a real pass it temporarily
// attaches the (possibly scheduler-provided) device pointer, copies the data to
// host, then restores the address state and marks the tensor for host->device sync.
// NOTE(review): the diff rendering interleaved removed lines here (use of `tensor`
// before its declaration, a `device_address` deref inside its null branch); this is
// the coherent reconstruction of the added function.
void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph,
                                         bool mock) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  if (node_output_index.first == nullptr) {
    return;
  }
  auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
  if (mock) {
    // Dry-run pass: only pin internal graph outputs so MemScheduler keeps them in memory.
    if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) {
      mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
    }
    return;
  }
  auto tensor = graph.GetNodeOutputTensor(node_output_index);
  if (tensor == nullptr) {
    // No host tensor is bound to this output; nothing to sync.
    return;
  }
  if (device_address == nullptr) {
    // No device memory backs this output: sync whatever the tensor holds and
    // mark it so it is pushed back to device before the next launch.
    tensor->data_sync(false);
    tensor->set_device_address(nullptr);
    tensor->set_sync_status(kNeedSyncHostToDevice);
    return;
  }
  if (!SyncStream()) {
    MS_LOG(EXCEPTION) << "SyncStream failed";
  }
  // The scheduler may have swapped the device memory out; borrow a valid pointer
  // just for the host copy.
  auto origin_ptr = device_address->ptr_;
  if (origin_ptr == nullptr) {
    device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
  }
  tensor->set_device_address(device_address);
  tensor->data_sync(false);
  tensor->set_device_address(nullptr);
  // Restore the pre-sync pointer (no-op when it was already valid) so the
  // scheduler's ownership bookkeeping stays consistent.
  device_address->ptr_ = origin_ptr;
  tensor->set_sync_status(kNeedSyncHostToDevice);
}

void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,


+ 2
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.h View File

@@ -156,6 +156,8 @@ class KernelRuntime {
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
const AnfNodePtr &kernel, bool mock);
void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
const session::KernelGraph &graph, bool mock);

void AssignCommunicationMem(const session::KernelGraph &graph);
bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);


Loading…
Cancel
Save