|
|
|
@@ -24,7 +24,15 @@ namespace device { |
|
|
|
namespace memswap { |
|
|
|
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
execution_order_ = kernel_graph->execution_order(); |
|
|
|
graph_manager_ = kernel_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_manager_); |
|
|
|
auto &kernels = kernel_graph->execution_order(); |
|
|
|
for (const auto &kernel : kernels) { |
|
|
|
if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { |
|
|
|
execution_order_.push_back(kernel); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
size_t kernel_index = 0; |
|
|
|
for (const auto &kernel : execution_order_) { |
|
|
|
// parse topo order of kernel |
|
|
|
@@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { |
|
|
|
} |
|
|
|
|
|
|
|
// parse topo order of user kernel |
|
|
|
SaveUserKernelTopoOrder(kernel_graph); |
|
|
|
SaveUserKernelTopoOrder(); |
|
|
|
|
|
|
|
sort(ordered_tensors_.begin(), ordered_tensors_.end(), |
|
|
|
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); |
|
|
|
@@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { |
|
|
|
mem_copy_manager_->Init(); |
|
|
|
} |
|
|
|
|
|
|
|
void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
FuncGraphManagerPtr manager = kernel_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
NodeUsersMap user_map = manager->node_users(); |
|
|
|
bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
NodeUsersMap &user_map = graph_manager_->node_users(); |
|
|
|
auto iter = user_map.find(kernel); |
|
|
|
bool adjacent_with_communication_op = false; |
|
|
|
if (iter != user_map.end()) { |
|
|
|
AnfNodeIndexSet node_set = iter->second; |
|
|
|
adjacent_with_communication_op = std::any_of( |
|
|
|
node_set.begin(), node_set.end(), |
|
|
|
[](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); |
|
|
|
} |
|
|
|
return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; |
|
|
|
} |
|
|
|
|
|
|
|
void MemSwapManager::SaveUserKernelTopoOrder() { |
|
|
|
NodeUsersMap &user_map = graph_manager_->node_users(); |
|
|
|
for (const auto &kernel : execution_order_) { |
|
|
|
auto iter = user_map.find(kernel); |
|
|
|
if (iter == user_map.end()) { |
|
|
|
@@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGra |
|
|
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); |
|
|
|
for (auto &node_pair : node_set) { |
|
|
|
auto user_kernel = node_pair.first; |
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { |
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; |
|
|
|
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); |
|
|
|
auto &output_idx = kernel_with_index.second; |
|
|
|
if (kernel_with_index.first.get() != kernel.get()) { |
|
|
|
MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; |
|
|
|
} |
|
|
|
kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); |
|
|
|
} |
|
|
|
for (auto &node_user_pair : kernel_exec_info.node_users_map_) { |
|
|
|
@@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() { |
|
|
|
|
|
|
|
size_t output_idx = tensor.output_idx_; |
|
|
|
const AnfNodePtr &kernel = tensor.kernel_; |
|
|
|
if (IsCommunicationRelevantOp(kernel)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); |
|
|
|
auto &node_users_map = kernel_exec_info.node_users_map_; |
|
|
|
|
|
|
|
@@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() { |
|
|
|
|
|
|
|
while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { |
|
|
|
++tensor_size_threshold_idx_; |
|
|
|
if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { |
|
|
|
if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { |
|
|
|
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; |
|
|
|
break; |
|
|
|
} |
|
|
|
|