Browse Source

!1395 add memory swap manager module

Merge pull request !1395 from zyli2020/add_mem_swap_module_header
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
20d71dfbc3
3 changed files with 321 additions and 7 deletions
  1. +0
    -1
      mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h
  2. +312
    -0
      mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc
  3. +9
    -6
      mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h

+ 0
- 1
mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h View File

@@ -23,7 +23,6 @@
#include <memory>
#include <utility>
#include "session/kernel_graph.h"
#include "device/gpu/cuda_driver.h"
#include "kernel/kernel.h"

using HostAddress = mindspore::kernel::Address;


+ 312
- 0
mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc View File

@@ -0,0 +1,312 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pre_activate/mem_reuse/mem_swap_manager.h"
#include <algorithm>
#include "session/anf_runtime_algorithm.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace device {
namespace memswap {
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
execution_order_ = kernel_graph->execution_order();
FuncGraphManagerPtr manager = kernel_graph->manager();
NodeUsersMap user_map = manager->node_users();
size_t kernel_index = 0;

for (const auto &kernel : execution_order_) {
// parse topo order of kernel
kernel_execution_info_.emplace(kernel.get(), kernel_index++);
// parse tensor info
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();

for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) {
TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx};
ordered_tensors_.push_back(tensor_info);
}
}

// parse topo order of user kernel
for (const auto &kernel : execution_order_) {
auto iter = user_map.find(kernel);
if (iter == user_map.end()) {
continue;
}
AnfNodeIndexSet node_set = iter->second;
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
for (auto &node_pair : node_set) {
auto user_kernel = node_pair.first;
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) {
continue;
}

size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_;
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1);
auto &output_idx = kernel_with_index.second;
kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort);
}
for (auto &node_user_pair : kernel_exec_info.node_users_map_) {
sort(node_user_pair.second.begin(), node_user_pair.second.end());
}
}

sort(ordered_tensors_.begin(), ordered_tensors_.end(),
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });

auto cur_tensor_size = ordered_tensors_.front().tensor_size_;
for (auto &tensor_info : ordered_tensors_) {
if (cur_tensor_size != tensor_info.tensor_size_) {
cur_tensor_size = tensor_info.tensor_size_;
tensor_size_num_++;
}
}
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
tensor_size_threshold_idx_ = 0;

distance_threshold_ = kernel_index / kDistanceInitFactor;
mem_swap_initialized_ = true;
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
mem_copy_manager_->Init();
}

void MemSwapManager::AddSwapInfo() {
for (const auto &tensor : ordered_tensors_) {
size_t tensor_size = tensor.tensor_size_;
if (tensor_size < tensor_size_threshold_) {
break;
}

size_t output_idx = tensor.output_idx_;
const AnfNodePtr &kernel = tensor.kernel_;
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
auto &node_users_map = kernel_exec_info.node_users_map_;

auto iter = node_users_map.find(output_idx);
if (iter == node_users_map.end()) {
continue;
}
auto &node_users = iter->second;
bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) ||
(node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_);
if (!need_swap) {
continue;
}
AddKernelNeedSwap(kernel, true);
HostAddress host_addr;
host_addr.size = tensor_size;
auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr));
if (!ret) {
MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed.";
}
kernel_exec_info.host_addrs_[output_idx] = host_addr;
MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx};
if (node_users.size() > 1) {
AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info);
AddKernelTriggerSwap(execution_order_[node_users[0]], true);
} else {
AddKernelMemSwapInfo(kernel, mem_swap_out_info);
AddKernelTriggerSwap(kernel, true);
}

size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1;
if (swap_in_order <= kernel_exec_info.topo_order_) {
MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
auto swap_in_kernel = execution_order_[swap_in_order];
MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx};
AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info);
AddKernelTriggerSwap(swap_in_kernel, true);

host_addrs_list_.push_back(host_addr);
}
}

void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
const HostAddress &host_address) {
if (swap_kind == SwapKind::kDeviceToHost) {
mem_copy_manager_->AddMemSwapOutTask(device_address, host_address);
} else if (swap_kind == SwapKind::kHostToDevice) {
mem_copy_manager_->AddMemSwapInTask(device_address, host_address);
}
}

bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) { return mem_copy_manager_->SyncMemCopyStream(swap_kind); }

DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) {
if (swap_kind == SwapKind::kDeviceToHost) {
return mem_copy_manager_->UpdateSwapOutQueue();
} else {
return mem_copy_manager_->UpdateSwapInQueue();
}
}

// retreat to find a workable swap scheme
bool MemSwapManager::RetreatSwapInfo() {
if (!trigger_swap_) {
trigger_swap_ = true;
}
if (swap_info_already_set_) {
ResetSwapInfo();
if (distance_threshold_ >= kDistanceLowerBound) {
auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_;
distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1);
}

while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
++tensor_size_threshold_idx_;
if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
break;
}
}

if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) {
MS_LOG(ERROR) << "Retreat swap info failed";
return false;
}
} else {
swap_info_already_set_ = true;
}
AddSwapInfo();
return true;
}

KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(kernel);
auto iter = kernel_execution_info_.find(kernel.get());
if (iter == kernel_execution_info_.end()) {
MS_LOG(EXCEPTION) << "Can not find execution info of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
return const_cast<KernelExecutionInfo &>(iter->second);
}

void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform) {
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
kernel_exec_info.execution_perform_ = perform;
}

void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) {
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
kernel_exec_info.trigger_swap_ = trigger_swap;
}

void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) {
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
kernel_exec_info.need_swap_ = need_swap;
}

void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx,
const std::pair<float, float> &perform) {
MS_EXCEPTION_IF_NULL(kernel);
kernel_swap_perform_[kernel.get()][output_idx] = perform;
}

void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) {
MS_EXCEPTION_IF_NULL(kernel);
mem_swap_info_[kernel.get()].push_back(mem_swap_info);
}

float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const {
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
return kernel_exec_info.execution_perform_;
}

bool MemSwapManager::QueryKerneTriggerSwap(const AnfNodePtr &kernel) const {
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
return kernel_exec_info.trigger_swap_;
}

bool MemSwapManager::QueryKerneNeedSwap(const AnfNodePtr &kernel) const {
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
return kernel_exec_info.need_swap_;
}

const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const {
MS_EXCEPTION_IF_NULL(kernel);
auto iter_kernel = kernel_swap_perform_.find(kernel.get());
if (iter_kernel == kernel_swap_perform_.end()) {
MS_LOG(EXCEPTION) << "Can not find swap performance data of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}

auto &perform_map = iter_kernel->second;
auto iter_output = perform_map.find(output_idx);
if (iter_output == perform_map.end()) {
MS_LOG(EXCEPTION) << "Can not find swap performance data of output[" << output_idx << "] of op["
<< AnfAlgo::GetCNodeName(kernel) << "]";
}
return iter_output->second;
}

const std::vector<MemSwapInfo> &MemSwapManager::QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(kernel);
auto iter = mem_swap_info_.find(kernel.get());
if (iter == mem_swap_info_.end()) {
MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
return iter->second;
}

void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); }

bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const {
auto iter = swap_in_blacklist_.find(device_ptr);
return iter != swap_in_blacklist_.end();
}

const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const {
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
auto &host_addrs = kernel_exec_info.host_addrs_;
auto iter = host_addrs.find(output_idx);
if (iter == host_addrs.end()) {
MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
}
return iter->second;
}

bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const {
return mem_copy_manager_->AllocHostPinnedMem(size, addr);
}

void MemSwapManager::ReleaseHostPinnedMem() {
for (const auto &host_addr : host_addrs_list_) {
if (host_addr.addr) {
mem_copy_manager_->FreeHostPinnedMem(host_addr.addr);
}
}
host_addrs_list_.clear();
}

void MemSwapManager::ClearSwapQueue() { mem_copy_manager_->ClearSwapQueue(); }

void MemSwapManager::ResetSwapInfo() {
ClearSwapQueue();
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
auto &kernel_exec_info = kernel_exec_info_pair.second;
kernel_exec_info.trigger_swap_ = false;
kernel_exec_info.need_swap_ = false;
kernel_exec_info.host_addrs_.clear();
}
ReleaseHostPinnedMem();
swap_in_blacklist_.clear();
mem_swap_info_.clear();
}
} // namespace memswap
} // namespace device
} // namespace mindspore

+ 9
- 6
mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h View File

@@ -23,7 +23,7 @@
#include <memory>
#include <vector>
#include <utility>
#include "pre_activate/mem_reuse/mem_swap_util.h"
#include "pre_activate/mem_reuse/mem_copy_manager.h"

using PerformPair = std::pair<float, float>;
namespace mindspore {
@@ -31,9 +31,9 @@ namespace device {
namespace memswap {
class MemSwapManager {
public:
MemSwapManager()
explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager)
: tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) {
mem_copy_manager_ = std::make_shared<MemCopyManager>();
mem_copy_manager_ = mem_copy_manager;
}

~MemSwapManager() = default;
@@ -75,9 +75,9 @@ class MemSwapManager {

const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const;

size_t AllocHostPinnedMemory(size_t size, HostMemPtr *addr) const;
bool AllocHostPinnedMem(size_t size, void **addr) const;

void FreeHostPinnedMemory();
void ReleaseHostPinnedMem();

void ClearSwapQueue();

@@ -110,10 +110,13 @@ class MemSwapManager {
size_t tensor_size_num_;
size_t distance_threshold_;

MemCopyManagerPtr mem_copy_manager_{nullptr};
MemCopyManagerPtr mem_copy_manager_;
bool mem_swap_initialized_{false};
bool swap_info_already_set_{false};
bool trigger_swap_{false};

static constexpr size_t kDistanceInitFactor = 3;
static constexpr size_t kDistanceLowerBound = 3;
};
using MemSwapManagerPtr = std::shared_ptr<MemSwapManager>;
} // namespace memswap


Loading…
Cancel
Save