Merge pull request !3246 from zyli2020/refine_gpu_mem_swaptags/v0.6.0-beta
| @@ -19,6 +19,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | |||||
| #include <queue> | #include <queue> | ||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -40,29 +41,58 @@ struct TensorInfo { | |||||
| struct KernelExecutionInfo { | struct KernelExecutionInfo { | ||||
| size_t topo_order_{0}; | size_t topo_order_{0}; | ||||
| float execution_perform_{0.0}; | float execution_perform_{0.0}; | ||||
| bool trigger_swap_{false}; | |||||
| bool need_swap_{false}; | |||||
| // output index to topo orders of node users | |||||
| bool trigger_swap_out_{false}; | |||||
| bool trigger_swap_in_{false}; | |||||
| size_t swap_in_task_num_{0}; | |||||
| // Key: output index, value: topo orders of node users | |||||
| std::map<size_t, std::vector<size_t>> node_users_map_; | std::map<size_t, std::vector<size_t>> node_users_map_; | ||||
| // kernel output idx to host addr | |||||
| std::map<size_t, HostAddress> host_addrs_; | |||||
| // Key: output idx, value: (host addr, dirty or not) | |||||
| std::map<size_t, std::pair<HostAddress, bool>> host_addrs_; | |||||
| KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} | |||||
| explicit KernelExecutionInfo(size_t topo_order) | |||||
| : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} | |||||
| KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) | |||||
| KernelExecutionInfo() {} | |||||
| explicit KernelExecutionInfo(size_t topo_order) : KernelExecutionInfo(topo_order, 0.0, false, false, 0) {} | |||||
| KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap_out, bool trigger_swap_in, | |||||
| size_t swap_in_task_num) | |||||
| : topo_order_(topo_order), | : topo_order_(topo_order), | ||||
| execution_perform_(execution_perform), | execution_perform_(execution_perform), | ||||
| trigger_swap_(trigger_swap), | |||||
| need_swap_(need_swap) {} | |||||
| trigger_swap_out_(trigger_swap_out), | |||||
| trigger_swap_in_(trigger_swap_in), | |||||
| swap_in_task_num_(swap_in_task_num) {} | |||||
| }; | }; | ||||
| // trigger swap | |||||
| struct MemSwapInfo { | struct MemSwapInfo { | ||||
| SwapKind swap_kind_; | SwapKind swap_kind_; | ||||
| // kernel need to be swapped | |||||
| AnfNodePtr kernel_{nullptr}; | |||||
| // Topo order of the kernel that needs to be swapped | |||||
| size_t topo_order_; | |||||
| size_t output_idx_{0}; | size_t output_idx_{0}; | ||||
| // Records the swap-out position of the tensor that is being swapped in | |||||
| size_t swap_out_pos_; | |||||
| }; | |||||
| struct SwapInfoComp { | |||||
| bool operator()(const MemSwapInfo &a, const MemSwapInfo &b) { | |||||
| int swap_kind_a = static_cast<int>(a.swap_kind_); | |||||
| int swap_kind_b = static_cast<int>(b.swap_kind_); | |||||
| if (swap_kind_a < swap_kind_b) { | |||||
| return true; | |||||
| } else if (swap_kind_a > swap_kind_b) { | |||||
| return false; | |||||
| } | |||||
| if (a.swap_out_pos_ < b.swap_out_pos_) { | |||||
| return true; | |||||
| } else if (a.swap_out_pos_ > b.swap_out_pos_) { | |||||
| return false; | |||||
| } | |||||
| if (a.topo_order_ < b.topo_order_) { | |||||
| return true; | |||||
| } else if (a.topo_order_ > b.topo_order_) { | |||||
| return false; | |||||
| } | |||||
| return a.output_idx_ < b.output_idx_; | |||||
| } | |||||
| }; | }; | ||||
| class MemCopyManager { | class MemCopyManager { | ||||
| @@ -90,6 +120,7 @@ class MemCopyManager { | |||||
| virtual void ClearSwapQueue() {} | virtual void ClearSwapQueue() {} | ||||
| }; | }; | ||||
| using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>; | using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>; | ||||
| using MemSwapInfoSet = std::set<MemSwapInfo, SwapInfoComp>; | |||||
| } // namespace memswap | } // namespace memswap | ||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,22 +22,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace memswap { | namespace memswap { | ||||
| void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { | |||||
| bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| graph_manager_ = kernel_graph->manager(); | graph_manager_ = kernel_graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(graph_manager_); | MS_EXCEPTION_IF_NULL(graph_manager_); | ||||
| auto &kernels = kernel_graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { | |||||
| execution_order_.push_back(kernel); | |||||
| } | |||||
| } | |||||
| execution_order_ = kernel_graph->execution_order(); | |||||
| size_t kernel_index = 0; | size_t kernel_index = 0; | ||||
| for (const auto &kernel : execution_order_) { | for (const auto &kernel : execution_order_) { | ||||
| // parse topo order of kernel | |||||
| // Parse topo order of kernel | |||||
| (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); | (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); | ||||
| // parse tensor info | |||||
| // Parse tensor info | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto output_sizes = kernel_mod->GetOutputSizeList(); | auto output_sizes = kernel_mod->GetOutputSizeList(); | ||||
| @@ -48,7 +43,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { | |||||
| } | } | ||||
| } | } | ||||
| // parse topo order of user kernel | |||||
| // Parse topo order of user kernel | |||||
| SaveUserKernelTopoOrder(); | SaveUserKernelTopoOrder(); | ||||
| sort(ordered_tensors_.begin(), ordered_tensors_.end(), | sort(ordered_tensors_.begin(), ordered_tensors_.end(), | ||||
| @@ -61,17 +56,103 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { | |||||
| tensor_size_num_++; | tensor_size_num_++; | ||||
| } | } | ||||
| } | } | ||||
| tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; | |||||
| tensor_size_threshold_idx_ = 0; | |||||
| distance_threshold_ = kernel_index / kDistanceInitFactor; | |||||
| if (!InitSwapThreshold(0)) { | |||||
| return false; | |||||
| } | |||||
| mem_swap_initialized_ = true; | mem_swap_initialized_ = true; | ||||
| MS_EXCEPTION_IF_NULL(mem_copy_manager_); | MS_EXCEPTION_IF_NULL(mem_copy_manager_); | ||||
| mem_copy_manager_->Init(); | mem_copy_manager_->Init(); | ||||
| return true; | |||||
| } | |||||
| bool MemSwapManager::InitSwapThreshold(size_t swap_mem_size) { | |||||
| distance_threshold_ = execution_order_.size() / kDistanceInitFactor; | |||||
| distance_decay_step_ = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; | |||||
| if (distance_decay_step_ <= 1) { | |||||
| distance_decay_step_ = 1; | |||||
| } | |||||
| tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; | |||||
| tensor_size_threshold_idx_ = 0; | |||||
| size_t accumulation = 0; | |||||
| while (accumulation < swap_mem_size) { | |||||
| accumulation = 0; | |||||
| for (const auto &tensor_info : ordered_tensors_) { | |||||
| size_t tensor_size = tensor_info.tensor_size_; | |||||
| if (tensor_size < tensor_size_threshold_) { | |||||
| break; | |||||
| } | |||||
| if (!CheckDistanceBetweenKernels(tensor_info)) { | |||||
| continue; | |||||
| } | |||||
| accumulation += tensor_info.tensor_size_; | |||||
| if (accumulation >= swap_mem_size) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| RetreatSwapThreshold(); | |||||
| if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { | |||||
| MS_LOG(ERROR) << "Init swap threshold info failed"; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void MemSwapManager::RetreatSwapThreshold() { | |||||
| if (distance_threshold_ >= kDistanceLowerBound) { | |||||
| bool update_one_decay_step = (distance_threshold_ > distance_decay_step_) && | |||||
| (distance_threshold_ - distance_decay_step_ >= kDistanceLowerBound); | |||||
| if (update_one_decay_step) { | |||||
| distance_threshold_ -= distance_decay_step_; | |||||
| } else if (distance_threshold_ >= kDistanceLowerBound) { | |||||
| size_t new_distance_decay_step = (distance_threshold_ - kDistanceLowerBound) / 4; | |||||
| if (new_distance_decay_step < 1) { | |||||
| new_distance_decay_step = 1; | |||||
| } | |||||
| distance_threshold_ -= new_distance_decay_step; | |||||
| } | |||||
| } | |||||
| while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { | |||||
| ++tensor_size_threshold_idx_; | |||||
| if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { | |||||
| tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| bool MemSwapManager::CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const { | |||||
| const AnfNodePtr &kernel = tensor_info.kernel_; | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| auto &node_users_map = kernel_exec_info.node_users_map_; | |||||
| auto iter = node_users_map.find(tensor_info.output_idx_); | |||||
| if (iter == node_users_map.end()) { | |||||
| return false; | |||||
| } | |||||
| auto &node_users = iter->second; | |||||
| if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) { | |||||
| return true; | |||||
| } | |||||
| for (size_t i = 1; i < node_users.size(); ++i) { | |||||
| if (node_users[i] - node_users[i - 1] > distance_threshold_) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { | bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| if (AnfAlgo::IsCommunicationOp(kernel)) { | |||||
| return true; | |||||
| } | |||||
| NodeUsersMap &user_map = graph_manager_->node_users(); | NodeUsersMap &user_map = graph_manager_->node_users(); | ||||
| auto iter = user_map.find(kernel); | auto iter = user_map.find(kernel); | ||||
| bool adjacent_with_communication_op = false; | bool adjacent_with_communication_op = false; | ||||
| @@ -81,7 +162,7 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { | |||||
| node_set.begin(), node_set.end(), | node_set.begin(), node_set.end(), | ||||
| [](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); | [](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); | ||||
| } | } | ||||
| return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; | |||||
| return adjacent_with_communication_op; | |||||
| } | } | ||||
| void MemSwapManager::SaveUserKernelTopoOrder() { | void MemSwapManager::SaveUserKernelTopoOrder() { | ||||
| @@ -95,7 +176,7 @@ void MemSwapManager::SaveUserKernelTopoOrder() { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | ||||
| for (auto &node_pair : node_set) { | for (auto &node_pair : node_set) { | ||||
| auto user_kernel = node_pair.first; | auto user_kernel = node_pair.first; | ||||
| if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { | |||||
| if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -138,21 +219,18 @@ void MemSwapManager::AddSwapInfo() { | |||||
| if (!need_swap) { | if (!need_swap) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| AddKernelNeedSwap(kernel, true); | |||||
| HostAddress host_addr; | HostAddress host_addr; | ||||
| host_addr.size = tensor_size; | host_addr.size = tensor_size; | ||||
| auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr)); | auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr)); | ||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; | MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; | ||||
| } | } | ||||
| kernel_exec_info.host_addrs_[output_idx] = host_addr; | |||||
| MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; | |||||
| kernel_exec_info.host_addrs_[output_idx] = std::make_pair(host_addr, true); | |||||
| MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx, 0}; | |||||
| if (node_users.size() > 1) { | if (node_users.size() > 1) { | ||||
| AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); | AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); | ||||
| AddKernelTriggerSwap(execution_order_[node_users[0]], true); | |||||
| } else { | } else { | ||||
| AddKernelMemSwapInfo(kernel, mem_swap_out_info); | AddKernelMemSwapInfo(kernel, mem_swap_out_info); | ||||
| AddKernelTriggerSwap(kernel, true); | |||||
| } | } | ||||
| size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; | size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; | ||||
| @@ -160,9 +238,8 @@ void MemSwapManager::AddSwapInfo() { | |||||
| MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | ||||
| } | } | ||||
| auto swap_in_kernel = execution_order_[swap_in_order]; | auto swap_in_kernel = execution_order_[swap_in_order]; | ||||
| MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; | |||||
| MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx, 0}; | |||||
| AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); | AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); | ||||
| AddKernelTriggerSwap(swap_in_kernel, true); | |||||
| host_addrs_list_.push_back(host_addr); | host_addrs_list_.push_back(host_addr); | ||||
| } | } | ||||
| @@ -189,7 +266,7 @@ DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { | |||||
| } | } | ||||
| } | } | ||||
| // retreat to find a workable swap scheme | |||||
| // Retreat to find a workable swap scheme | |||||
| bool MemSwapManager::RetreatSwapInfo() { | bool MemSwapManager::RetreatSwapInfo() { | ||||
| if (!trigger_swap_) { | if (!trigger_swap_) { | ||||
| trigger_swap_ = true; | trigger_swap_ = true; | ||||
| @@ -220,6 +297,114 @@ bool MemSwapManager::RetreatSwapInfo() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void MemSwapManager::AdjustSwapInPos(const AnfNodePtr &kernel, size_t index) { | |||||
| if (kernel_first_move_cache_map_.find(kernel.get()) == kernel_first_move_cache_map_.end()) { | |||||
| CacheCurSwapInfoSet(kernel); | |||||
| } | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| size_t kernel_pos = kernel_exec_info.topo_order_; | |||||
| auto &mem_swap_info = mem_swap_info_cache_list_[index]; | |||||
| if (QueryFirstTimeMovePos(kernel, index)) { | |||||
| best_and_cur_pos_cache_.first = BestSwapInPerformPos(kernel, mem_swap_info); | |||||
| best_and_cur_pos_cache_.second = best_and_cur_pos_cache_.first; | |||||
| size_t best_pos = best_and_cur_pos_cache_.first; | |||||
| if (best_pos != kernel_pos) { | |||||
| MoveSwapInfoPos(best_pos, kernel_pos, mem_swap_info); | |||||
| } | |||||
| AddFirstTimeMovePos(kernel, index, false); | |||||
| return; | |||||
| } | |||||
| auto &cur_pos = best_and_cur_pos_cache_.second; | |||||
| if (cur_pos < kernel_pos) { | |||||
| MoveSwapInfoPos(cur_pos + 1, cur_pos, mem_swap_info); | |||||
| cur_pos++; | |||||
| } | |||||
| } | |||||
| void MemSwapManager::CacheCurSwapInfoSet(const AnfNodePtr &kernel) { | |||||
| if (!kernel_first_move_cache_map_.empty()) { | |||||
| kernel_first_move_cache_map_.clear(); | |||||
| } | |||||
| if (!mem_swap_info_cache_list_.empty()) { | |||||
| mem_swap_info_cache_list_.clear(); | |||||
| } | |||||
| auto mem_swap_info_set = QueryKernelMemSwapInfo(kernel); | |||||
| size_t swap_in_task_cnt = 0; | |||||
| for (auto &mem_swap_info : mem_swap_info_set) { | |||||
| if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { | |||||
| (void)mem_swap_info_cache_list_.push_back(mem_swap_info); | |||||
| kernel_first_move_cache_map_[kernel.get()].push_back(true); | |||||
| swap_in_task_cnt++; | |||||
| } | |||||
| } | |||||
| size_t swap_in_task_num = QueryKernelTriggerSwapInTaskNum(kernel); | |||||
| if (swap_in_task_cnt != swap_in_task_num) { | |||||
| MS_LOG(EXCEPTION) << "Swap_in_task_cnt :" << swap_in_task_cnt | |||||
| << "must equal Swap_in_task_num: " << swap_in_task_num; | |||||
| } | |||||
| } | |||||
| void MemSwapManager::AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time) { | |||||
| auto iter = kernel_first_move_cache_map_.find(kernel.get()); | |||||
| if (iter == kernel_first_move_cache_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | |||||
| auto &first_move_list = iter->second; | |||||
| if (index >= first_move_list.size()) { | |||||
| MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; | |||||
| } | |||||
| first_move_list[index] = first_time; | |||||
| } | |||||
| bool MemSwapManager::QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const { | |||||
| auto iter = kernel_first_move_cache_map_.find(kernel.get()); | |||||
| if (iter == kernel_first_move_cache_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | |||||
| const auto &first_move_list = iter->second; | |||||
| if (index >= first_move_list.size()) { | |||||
| MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; | |||||
| } | |||||
| return first_move_list[index]; | |||||
| } | |||||
| size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const { | |||||
| auto need_swap_kernel = QueryKerneByTopoOrder(mem_swap_info.topo_order_); | |||||
| const PerformPair &perform_pair = QueryKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_); | |||||
| float swap_in_cost_time = perform_pair.second; | |||||
| size_t swap_out_pos = mem_swap_info.swap_out_pos_; | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(trigger_kernel); | |||||
| size_t trigger_kernel_pos = kernel_exec_info.topo_order_; | |||||
| float kernel_execution_time = 0; | |||||
| size_t pos = trigger_kernel_pos; | |||||
| for (; pos > swap_out_pos + 1; pos--) { | |||||
| auto kernel = QueryKerneByTopoOrder(pos - 1); | |||||
| if (QueryKernelTriggerSwapIn(kernel)) { | |||||
| return pos; | |||||
| } | |||||
| kernel_execution_time += QueryKernelExecutionPerform(QueryKerneByTopoOrder(pos)); | |||||
| if (kernel_execution_time >= swap_in_cost_time) { | |||||
| return pos - 1; | |||||
| } | |||||
| } | |||||
| return pos; | |||||
| } | |||||
| void MemSwapManager::MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info) { | |||||
| if (des_pos == src_pos) { | |||||
| MS_LOG(EXCEPTION) << "destination pos can not equal source pos"; | |||||
| } | |||||
| auto des_kernel = QueryKerneByTopoOrder(des_pos); | |||||
| auto src_kernel = QueryKerneByTopoOrder(src_pos); | |||||
| AddKernelMemSwapInfo(des_kernel, mem_swap_info); | |||||
| RemoveKernelMemSwapInfo(src_kernel, mem_swap_info); | |||||
| } | |||||
| KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { | KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| auto iter = kernel_execution_info_.find(kernel.get()); | auto iter = kernel_execution_info_.find(kernel.get()); | ||||
| @@ -234,16 +419,6 @@ void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float p | |||||
| kernel_exec_info.execution_perform_ = perform; | kernel_exec_info.execution_perform_ = perform; | ||||
| } | } | ||||
| void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| kernel_exec_info.trigger_swap_ = trigger_swap; | |||||
| } | |||||
| void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| kernel_exec_info.need_swap_ = need_swap; | |||||
| } | |||||
| void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, | void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, | ||||
| const std::pair<float, float> &perform) { | const std::pair<float, float> &perform) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| @@ -252,7 +427,42 @@ void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t outpu | |||||
| void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { | void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { | ||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| mem_swap_info_[kernel.get()].push_back(mem_swap_info); | |||||
| (void)mem_swap_info_map_[kernel.get()].insert(mem_swap_info); | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | |||||
| kernel_exec_info.trigger_swap_out_ = true; | |||||
| } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { | |||||
| kernel_exec_info.swap_in_task_num_++; | |||||
| kernel_exec_info.trigger_swap_in_ = true; | |||||
| } | |||||
| } | |||||
| void MemSwapManager::RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { | |||||
| auto map_iter = mem_swap_info_map_.find(kernel.get()); | |||||
| if (map_iter == mem_swap_info_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | |||||
| MemSwapInfoSet &mem_swap_info_set = map_iter->second; | |||||
| auto set_iter = mem_swap_info_set.find(mem_swap_info); | |||||
| if (set_iter == mem_swap_info_set.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find memory swap information in mem swap info set"; | |||||
| } | |||||
| mem_swap_info_set.erase(set_iter); | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| if (kernel_exec_info.swap_in_task_num_ > 0) { | |||||
| kernel_exec_info.swap_in_task_num_--; | |||||
| } | |||||
| if (kernel_exec_info.swap_in_task_num_ == 0) { | |||||
| kernel_exec_info.trigger_swap_in_ = false; | |||||
| } | |||||
| if (mem_swap_info_set.empty()) { | |||||
| (void)mem_swap_info_map_.erase(kernel.get()); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { | float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { | ||||
| @@ -262,12 +472,24 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons | |||||
| bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { | bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { | ||||
| const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | ||||
| return kernel_exec_info.trigger_swap_; | |||||
| return kernel_exec_info.trigger_swap_out_ || kernel_exec_info.trigger_swap_in_; | |||||
| } | } | ||||
| bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { | |||||
| bool MemSwapManager::QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const { | |||||
| const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | ||||
| return kernel_exec_info.need_swap_; | |||||
| return kernel_exec_info.trigger_swap_in_; | |||||
| } | |||||
| size_t MemSwapManager::QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const { | |||||
| const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| return kernel_exec_info.swap_in_task_num_; | |||||
| } | |||||
| const AnfNodePtr MemSwapManager::QueryKerneByTopoOrder(size_t index) const { | |||||
| if (index >= execution_order_.size()) { | |||||
| MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; | |||||
| } | |||||
| return execution_order_[index]; | |||||
| } | } | ||||
| const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { | const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { | ||||
| @@ -286,30 +508,75 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern | |||||
| return iter_output->second; | return iter_output->second; | ||||
| } | } | ||||
| const std::vector<MemSwapInfo> &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { | |||||
| const MemSwapInfoSet &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| auto iter = mem_swap_info_.find(kernel.get()); | |||||
| if (iter == mem_swap_info_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| auto iter = mem_swap_info_map_.find(kernel.get()); | |||||
| if (iter == mem_swap_info_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } | |||||
| bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { | |||||
| auto iter = swap_in_blacklist_.find(device_ptr); | |||||
| return iter != swap_in_blacklist_.end(); | |||||
| void MemSwapManager::AssignHostMemory() { | |||||
| for (auto &kernel_exec_info_pair : kernel_execution_info_) { | |||||
| auto &kernel_exec_info = kernel_exec_info_pair.second; | |||||
| auto &host_addrs_map = kernel_exec_info.host_addrs_; | |||||
| for (auto &host_addr_pair : host_addrs_map) { | |||||
| auto &host_addr = host_addr_pair.second.first; | |||||
| auto ret = AllocHostPinnedMem(host_addr.size, reinterpret_cast<void **>(&host_addr.addr)); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << host_addr.size << "] failed."; | |||||
| } | |||||
| host_addrs_list_.push_back(host_addr); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { | |||||
| const HostAddress &MemSwapManager::QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | ||||
| auto &host_addrs = kernel_exec_info.host_addrs_; | auto &host_addrs = kernel_exec_info.host_addrs_; | ||||
| auto iter = host_addrs.find(output_idx); | auto iter = host_addrs.find(output_idx); | ||||
| if (iter == host_addrs.end()) { | if (iter == host_addrs.end()) { | ||||
| MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | ||||
| } | } | ||||
| return iter->second; | |||||
| return (iter->second).first; | |||||
| } | |||||
| void MemSwapManager::AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty) { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| auto &host_addrs = kernel_exec_info.host_addrs_; | |||||
| auto iter = host_addrs.find(output_idx); | |||||
| if (iter == host_addrs.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | |||||
| (iter->second).second = dirty; | |||||
| } | |||||
| bool MemSwapManager::QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const { | |||||
| auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); | |||||
| auto &host_addrs = kernel_exec_info.host_addrs_; | |||||
| auto iter = host_addrs.find(output_idx); | |||||
| if (iter == host_addrs.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; | |||||
| } | |||||
| return (iter->second).second; | |||||
| } | |||||
| void MemSwapManager::ResetHostAddrIsDirty() { | |||||
| for (auto &kernel_exec_info_pair : kernel_execution_info_) { | |||||
| auto &kernel_exec_info = kernel_exec_info_pair.second; | |||||
| auto &host_addrs = kernel_exec_info.host_addrs_; | |||||
| for (auto &host_addr : host_addrs) { | |||||
| host_addr.second.second = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } | |||||
| bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { | |||||
| auto iter = swap_in_blacklist_.find(device_ptr); | |||||
| return iter != swap_in_blacklist_.end(); | |||||
| } | } | ||||
| bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { | bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { | ||||
| @@ -331,13 +598,14 @@ void MemSwapManager::ResetSwapInfo() { | |||||
| ClearSwapQueue(); | ClearSwapQueue(); | ||||
| for (auto &kernel_exec_info_pair : kernel_execution_info_) { | for (auto &kernel_exec_info_pair : kernel_execution_info_) { | ||||
| auto &kernel_exec_info = kernel_exec_info_pair.second; | auto &kernel_exec_info = kernel_exec_info_pair.second; | ||||
| kernel_exec_info.trigger_swap_ = false; | |||||
| kernel_exec_info.need_swap_ = false; | |||||
| kernel_exec_info.trigger_swap_out_ = false; | |||||
| kernel_exec_info.trigger_swap_in_ = false; | |||||
| kernel_exec_info.swap_in_task_num_ = 0; | |||||
| kernel_exec_info.host_addrs_.clear(); | kernel_exec_info.host_addrs_.clear(); | ||||
| } | } | ||||
| ReleaseHostPinnedMem(); | ReleaseHostPinnedMem(); | ||||
| swap_in_blacklist_.clear(); | swap_in_blacklist_.clear(); | ||||
| mem_swap_info_.clear(); | |||||
| mem_swap_info_map_.clear(); | |||||
| } | } | ||||
| } // namespace memswap | } // namespace memswap | ||||
| } // namespace device | } // namespace device | ||||
| @@ -32,7 +32,11 @@ namespace memswap { | |||||
| class MemSwapManager { | class MemSwapManager { | ||||
| public: | public: | ||||
| explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) | explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) | ||||
| : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { | |||||
| : tensor_size_threshold_(0), | |||||
| tensor_size_threshold_idx_(0), | |||||
| tensor_size_num_(1), | |||||
| distance_threshold_(1), | |||||
| distance_decay_step_(1) { | |||||
| mem_copy_manager_ = mem_copy_manager; | mem_copy_manager_ = mem_copy_manager; | ||||
| } | } | ||||
| @@ -42,7 +46,7 @@ class MemSwapManager { | |||||
| ~MemSwapManager() = default; | ~MemSwapManager() = default; | ||||
| void Init(const mindspore::session::KernelGraph *kernel_graph); | |||||
| bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0); | |||||
| void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, | void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, | ||||
| const HostAddress &host_address) const; | const HostAddress &host_address) const; | ||||
| @@ -51,9 +55,10 @@ class MemSwapManager { | |||||
| DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; | DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; | ||||
| // retreat to find a workable swap scheme | |||||
| bool RetreatSwapInfo(); | bool RetreatSwapInfo(); | ||||
| void AdjustSwapInPos(const AnfNodePtr &kernel, size_t index); | |||||
| bool trigger_swap() const { return trigger_swap_; } | bool trigger_swap() const { return trigger_swap_; } | ||||
| bool mem_swap_init() const { return mem_swap_initialized_; } | bool mem_swap_init() const { return mem_swap_initialized_; } | ||||
| @@ -70,16 +75,28 @@ class MemSwapManager { | |||||
| bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; | bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; | ||||
| bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; | |||||
| bool QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const; | |||||
| size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const; | |||||
| const AnfNodePtr QueryKerneByTopoOrder(size_t index) const; | |||||
| const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; | |||||
| void AssignHostMemory(); | |||||
| const std::vector<MemSwapInfo> &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; | |||||
| const HostAddress &QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const; | |||||
| void AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty); | |||||
| bool QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const; | |||||
| void ResetHostAddrIsDirty(); | |||||
| void InsertSwapInBlackList(const void *device_ptr); | void InsertSwapInBlackList(const void *device_ptr); | ||||
| bool FindInSwapInBlackList(const void *device_ptr) const; | bool FindInSwapInBlackList(const void *device_ptr) const; | ||||
| const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; | |||||
| bool AllocHostPinnedMem(size_t size, void **addr) const; | bool AllocHostPinnedMem(size_t size, void **addr) const; | ||||
| void ReleaseHostPinnedMem(); | void ReleaseHostPinnedMem(); | ||||
| @@ -93,27 +110,47 @@ class MemSwapManager { | |||||
| void SaveUserKernelTopoOrder(); | void SaveUserKernelTopoOrder(); | ||||
| void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); | |||||
| bool InitSwapThreshold(size_t swap_mem_size); | |||||
| void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); | |||||
| void RetreatSwapThreshold(); | |||||
| void CacheCurSwapInfoSet(const AnfNodePtr &kernel); | |||||
| void AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time); | |||||
| bool QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const; | |||||
| size_t BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const; | |||||
| void MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info); | |||||
| void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); | void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); | ||||
| void RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); | |||||
| bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const; | |||||
| bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; | bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; | ||||
| std::vector<CNodePtr> execution_order_; | std::vector<CNodePtr> execution_order_; | ||||
| std::vector<TensorInfo> ordered_tensors_; | std::vector<TensorInfo> ordered_tensors_; | ||||
| std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_; | std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_; | ||||
| std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_; | std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_; | ||||
| // trigger swap kernel key : MemSwapInfo of kernel need to be swapped | |||||
| std::unordered_map<void *, std::vector<MemSwapInfo>> mem_swap_info_; | |||||
| // Key: kernel that triggers the swap, value: MemSwapInfoSet of the kernels that need to be swapped | |||||
| std::unordered_map<void *, MemSwapInfoSet> mem_swap_info_map_; | |||||
| std::vector<HostAddress> host_addrs_list_; | std::vector<HostAddress> host_addrs_list_; | ||||
| std::unordered_set<const void *> swap_in_blacklist_; | std::unordered_set<const void *> swap_in_blacklist_; | ||||
| // Key: cached kernel address, value: per-index flags telling whether the position is being moved for the first time | |||||
| std::map<void *, std::vector<bool>> kernel_first_move_cache_map_; | |||||
| std::vector<MemSwapInfo> mem_swap_info_cache_list_; | |||||
| std::pair<size_t, size_t> best_and_cur_pos_cache_; | |||||
| size_t tensor_size_threshold_; | size_t tensor_size_threshold_; | ||||
| size_t tensor_size_threshold_idx_; | size_t tensor_size_threshold_idx_; | ||||
| size_t tensor_size_num_; | size_t tensor_size_num_; | ||||
| size_t distance_threshold_; | size_t distance_threshold_; | ||||
| size_t distance_decay_step_; | |||||
| MemCopyManagerPtr mem_copy_manager_{nullptr}; | MemCopyManagerPtr mem_copy_manager_{nullptr}; | ||||
| FuncGraphManagerPtr graph_manager_{nullptr}; | FuncGraphManagerPtr graph_manager_{nullptr}; | ||||
| @@ -707,6 +707,18 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz | |||||
| return addr; | return addr; | ||||
| } | } | ||||
| // Get the workspace device address of anf_node at the given index as a DeviceAddressPtr. | |||||
| // A null node or missing kernel info is rejected by MS_EXCEPTION_IF_NULL; a missing workspace address | |||||
| // at the index is reported through MS_LOG(EXCEPTION). | |||||
| DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| auto addr = kernel_info->GetMutableWorkspaceAddr(index); | |||||
| if (addr == nullptr) { | |||||
| // The node description is bracketed on both sides so the message stays readable in logs. | |||||
| MS_LOG(EXCEPTION) << "Index " << index << " of node [" << node->DebugString() << "] workspace addr does not exist"; | |||||
| } | |||||
| return addr; | |||||
| } | |||||
| // set infer shapes and types of anf node | // set infer shapes and types of anf node | ||||
| void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | ||||
| const std::vector<std::vector<size_t>> &shapes, AnfNode *node) { | const std::vector<std::vector<size_t>> &shapes, AnfNode *node) { | ||||
| @@ -149,6 +149,8 @@ class AnfRuntimeAlgorithm { | |||||
| static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); | static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); | ||||
| // get workspace device addr of anf_node | // get workspace device addr of anf_node | ||||
| static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); | static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); | ||||
| // get workspace device mutable addr of anf_node | |||||
| static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index); | |||||
| // set infer shapes and types of anf node | // set infer shapes and types of anf node | ||||
| static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, | ||||
| const std::vector<std::vector<size_t>> &shapes, AnfNode *node); | const std::vector<std::vector<size_t>> &shapes, AnfNode *node); | ||||
| @@ -209,6 +209,16 @@ bool CudaDriver::QueryEvent(const DeviceEvent &event) { | |||||
| } | } | ||||
| } | } | ||||
| // Query the time elapsed between two events through cudaEventElapsedTime, writing the result to *cost_time. | |||||
| // Returns true on cudaSuccess; otherwise logs the CUDA error code and its description and returns false. | |||||
| bool CudaDriver::ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end) { | |||||
| const auto status = cudaEventElapsedTime(cost_time, (cudaEvent_t)start, (cudaEvent_t)end); | |||||
| if (status != cudaSuccess) { | |||||
| MS_LOG(ERROR) << "cudaEventElapsedTime failed, ret[" << static_cast<int>(status) << "], " << cudaGetErrorString(status); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| int CudaDriver::device_count() { | int CudaDriver::device_count() { | ||||
| int dev_count; | int dev_count; | ||||
| auto ret = cudaGetDeviceCount(&dev_count); | auto ret = cudaGetDeviceCount(&dev_count); | ||||
| @@ -57,6 +57,7 @@ class CudaDriver { | |||||
| static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0); | static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0); | ||||
| static bool SyncEvent(const DeviceEvent &event); | static bool SyncEvent(const DeviceEvent &event); | ||||
| static bool QueryEvent(const DeviceEvent &event); | static bool QueryEvent(const DeviceEvent &event); | ||||
| static bool ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end); | |||||
| // Encapsulate the cuda APIs associated with device management. | // Encapsulate the cuda APIs associated with device management. | ||||
| static int device_count(); | static int device_count(); | ||||
| @@ -33,6 +33,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace gpu { | namespace gpu { | ||||
| using mindspore::device::memswap::MemSwapInfoSet; | |||||
| using mindspore::device::memswap::MemSwapManager; | using mindspore::device::memswap::MemSwapManager; | ||||
| using mindspore::device::memswap::SwapKind; | using mindspore::device::memswap::SwapKind; | ||||
| bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } | bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } | ||||
| @@ -139,6 +140,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||||
| InitKernelRefCount(graph); | InitKernelRefCount(graph); | ||||
| InitMemorySwapInfo(graph); | InitMemorySwapInfo(graph); | ||||
| InitKernelOutputAddress(graph); | InitKernelOutputAddress(graph); | ||||
| InitKernelWorkspaceAddress(graph); | |||||
| } else { | } else { | ||||
| AssignDynamicMemory(graph); | AssignDynamicMemory(graph); | ||||
| } | } | ||||
| @@ -183,6 +185,56 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Search for a memory swap scheme under which the graph can be launched. | |||||
| // Steps: initialize the swap manager if needed, retreat the swap info until a mock launch succeeds, | |||||
| // assign host memory, run one launch with profiling enabled, then hand over to RefineMemSwapScheme. | |||||
| // Returns false when initialization, a retreat step or the profiling launch fails. | |||||
| bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { | |||||
| // Fail loudly instead of dereferencing a null graph or swap manager below. | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | |||||
| ClearKernelOldOutputAndWorkspace(graph); | |||||
| if (!mem_swap_manager_->mem_swap_init() && !mem_swap_manager_->Init(graph)) { | |||||
| return false; | |||||
| } | |||||
| bool ret = false; | |||||
| while (!ret) { | |||||
| // No further retreat possible: give up on finding a scheme. | |||||
| if (!mem_swap_manager_->RetreatSwapInfo()) { | |||||
| return false; | |||||
| } | |||||
| // Mock launch (mock = true, profiling = false) to validate the current scheme. | |||||
| ret = LaunchKernelDynamic(graph, true, false); | |||||
| if (!ret) { | |||||
| ClearKernelOldOutputAndWorkspace(graph); | |||||
| } | |||||
| } | |||||
| mem_swap_manager_->AssignHostMemory(); | |||||
| // Time profiling | |||||
| if (!LaunchKernelDynamic(graph, false, true)) { | |||||
| return false; | |||||
| } | |||||
| return RefineMemSwapScheme(graph); | |||||
| } | |||||
| // Refine a swap scheme: for every kernel in execution order that triggers swap-in, adjust the position of each | |||||
| // of its swap-in tasks and re-run a mock launch until that launch succeeds. Always returns true once finished. | |||||
| // NOTE(review): the retry loop ends only when a mock launch succeeds; confirm AdjustSwapInPos always converges. | |||||
| bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) { | |||||
| // Fail loudly instead of dereferencing a null graph or swap manager below. | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | |||||
| auto &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) { | |||||
| continue; | |||||
| } | |||||
| const size_t swap_in_task_num = mem_swap_manager_->QueryKernelTriggerSwapInTaskNum(kernel); | |||||
| for (size_t swap_in_task_idx = 0; swap_in_task_idx < swap_in_task_num; ++swap_in_task_idx) { | |||||
| bool ret = false; | |||||
| while (!ret) { | |||||
| mem_swap_manager_->AdjustSwapInPos(kernel, swap_in_task_idx); | |||||
| // Mock launch (mock = true, profiling = false) to validate the adjusted position. | |||||
| ret = LaunchKernelDynamic(graph, true, false); | |||||
| if (!ret) { | |||||
| ClearKernelOldOutputAndWorkspace(graph); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { | void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>(); | MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>(); | ||||
| @@ -209,6 +261,7 @@ void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager); | MS_EXCEPTION_IF_NULL(mem_swap_manager); | ||||
| auto graph_id = graph->graph_id(); | auto graph_id = graph->graph_id(); | ||||
| mem_swap_map_[graph_id] = mem_swap_manager; | mem_swap_map_[graph_id] = mem_swap_manager; | ||||
| is_first_step_map_[graph_id] = true; | |||||
| } | } | ||||
| void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { | void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { | ||||
| @@ -230,6 +283,25 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph | |||||
| } | } | ||||
| } | } | ||||
| // Create one device address per workspace entry of every kernel in the graph's execution order and attach it | |||||
| // to the kernel. The addresses are created with a null pointer and the workspace size. | |||||
| void GPUKernelRuntime::InitKernelWorkspaceAddress(const session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| // Bind by const reference so the size list is not copied for each kernel (valid for a reference or value return). | |||||
| const auto &workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | |||||
| for (size_t i = 0; i < workspace_sizes.size(); ++i) { | |||||
| auto device_address = CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); | |||||
| AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Clear the kernel output addresses and then the kernel workspace addresses of the graph by delegating to | |||||
| // ClearKernelOutputAddress and ClearKernelWorkspaceAddress. | |||||
| void GPUKernelRuntime::ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph) { | |||||
| ClearKernelOutputAddress(graph); | |||||
| ClearKernelWorkspaceAddress(graph); | |||||
| } | |||||
| void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { | void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto &kernels = graph->execution_order(); | auto &kernels = graph->execution_order(); | ||||
| @@ -242,6 +314,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | |||||
| if (device_address->ptr_) { | if (device_address->ptr_) { | ||||
| mem_manager_->FreeMemFromMemPool(device_address); | mem_manager_->FreeMemFromMemPool(device_address); | ||||
| } | } | ||||
| @@ -250,7 +323,24 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap | |||||
| } | } | ||||
| } | } | ||||
| bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | |||||
| // Return to the memory pool every workspace device address of the graph's kernels that still holds a pointer. | |||||
| void GPUKernelRuntime::ClearKernelWorkspaceAddress(const session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| // mem_manager_ is dereferenced in the loop below; reject a null manager up front. | |||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||||
| auto &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| // Bind by const reference so the size list is not copied for each kernel (valid for a reference or value return). | |||||
| const auto &workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | |||||
| for (size_t i = 0; i < workspace_sizes.size(); ++i) { | |||||
| auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | |||||
| // Only addresses that currently hold device memory are freed. | |||||
| if (device_address->ptr_) { | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bool mock, bool profiling) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(mem_reuse_util_); | MS_EXCEPTION_IF_NULL(mem_reuse_util_); | ||||
| // Reset the reference count. | // Reset the reference count. | ||||
| @@ -271,7 +361,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | |||||
| if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { | if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { | ||||
| MS_LOG(EXCEPTION) << "Launch kernel failed."; | MS_LOG(EXCEPTION) << "Launch kernel failed."; | ||||
| } | } | ||||
| FreeKernelDynamicRes(kernel, kernel_workspaces); | |||||
| FreeKernelDynamicRes(kernel); | |||||
| UpdateMemorySwapTask(kernel); | UpdateMemorySwapTask(kernel); | ||||
| } | } | ||||
| CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); | CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); | ||||
| @@ -279,13 +369,39 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| // Launch one kernel on stream_ between two recorded device events and pass the elapsed time reported by | |||||
| // CudaDriver::ElapsedTime to the swap manager via AddKernelExecutionPerform. | |||||
| // NOTE(review): the events are destroyed only at the end of the function; if CHECK_OP_RET_WITH_EXCEPT | |||||
| // raises on failure (presumably it throws, confirm), both events are left undestroyed. | |||||
| void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||||
| const AddressPtrList &workspace, const AddressPtrList &outputs) { | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| float cost_time = 0; | |||||
| DeviceEvent start = nullptr; | |||||
| DeviceEvent end = nullptr; | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create event."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create event."); | |||||
| // Record the start event, launch the kernel, then record the end event, all on the same stream. | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, stream_), "Failed to record event to stream."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(inputs, workspace, outputs, stream_), "Launch kernel failed."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, stream_), "Failed to record event to stream."); | |||||
| // Both events are synchronized before the elapsed time is queried. | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(&cost_time, start, end), "Failed to record elapsed time."); | |||||
| mem_swap_manager_->AddKernelExecutionPerform(kernel, cost_time); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event."); | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(end), "Failed to destroy event."); | |||||
| } | |||||
| bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { | bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { | ||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | MS_EXCEPTION_IF_NULL(mem_swap_manager_); | ||||
| auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); | |||||
| for (auto &mem_swap_info : mem_swap_info_list) { | |||||
| auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); | |||||
| const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); | |||||
| const MemSwapInfoSet &mem_swap_info_set = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); | |||||
| for (auto &mem_swap_info : mem_swap_info_set) { | |||||
| auto need_swap_kernel = mem_swap_manager_->QueryKerneByTopoOrder(mem_swap_info.topo_order_); | |||||
| MS_EXCEPTION_IF_NULL(need_swap_kernel); | |||||
| const HostAddress &host_address = | |||||
| mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_); | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false); | |||||
| if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { | ||||
| mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); | mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); | ||||
| @@ -309,9 +425,11 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { | |||||
| bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { | bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | MS_EXCEPTION_IF_NULL(mem_swap_manager_); | ||||
| ClearKernelOutputAddress(graph); | |||||
| ClearKernelOldOutputAndWorkspace(graph); | |||||
| if (!mem_swap_manager_->mem_swap_init()) { | if (!mem_swap_manager_->mem_swap_init()) { | ||||
| mem_swap_manager_->Init(graph); | |||||
| if (!mem_swap_manager_->Init(graph)) { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| return mem_swap_manager_->RetreatSwapInfo(); | return mem_swap_manager_->RetreatSwapInfo(); | ||||
| } | } | ||||
| @@ -408,29 +526,6 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, | |||||
| return true; | return true; | ||||
| } | } | ||||
| void *GPUKernelRuntime::AttemptMallocMem(size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | |||||
| MS_EXCEPTION_IF_NULL(mem_swap_manager_); | |||||
| auto device_ptr = mem_manager_->MallocMemFromMemPool(size); | |||||
| if (!device_ptr) { | |||||
| if (!mem_swap_manager_->trigger_swap()) { | |||||
| return nullptr; | |||||
| } | |||||
| mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); | |||||
| while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { | |||||
| if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { | |||||
| device_address_swap_out->set_status(DeviceAddressStatus::kInHost); | |||||
| mem_manager_->FreeMemFromMemPool(device_address_swap_out); | |||||
| } | |||||
| } | |||||
| device_ptr = mem_manager_->MallocMemFromMemPool(size); | |||||
| if (!device_ptr) { | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return device_ptr; | |||||
| } | |||||
| bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, | bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, | ||||
| const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, | const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, | ||||
| AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { | AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { | ||||
| @@ -504,13 +599,13 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K | |||||
| kernel_workspaces->emplace_back(nullptr); | kernel_workspaces->emplace_back(nullptr); | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto device_ptr = AttemptMallocMem(workspace_sizes[i]); | |||||
| if (!device_ptr) { | |||||
| auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); | |||||
| if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i])) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | ||||
| MS_EXCEPTION_IF_NULL(workspace); | MS_EXCEPTION_IF_NULL(workspace); | ||||
| workspace->addr = device_ptr; | |||||
| workspace->addr = device_address->ptr_; | |||||
| workspace->size = workspace_sizes[i]; | workspace->size = workspace_sizes[i]; | ||||
| kernel_workspaces->emplace_back(workspace); | kernel_workspaces->emplace_back(workspace); | ||||
| } | } | ||||
| @@ -606,8 +701,7 @@ void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, boo | |||||
| } | } | ||||
| } | } | ||||
| void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| const AddressPtrList &kernel_workspaces) { | |||||
| void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | MS_EXCEPTION_IF_NULL(kernel); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| MS_EXCEPTION_IF_NULL(mem_reuse_util_); | MS_EXCEPTION_IF_NULL(mem_reuse_util_); | ||||
| @@ -652,12 +746,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, | |||||
| } | } | ||||
| } | } | ||||
| // Free the workspace of kernel. | // Free the workspace of kernel. | ||||
| for (size_t i = 0; i < kernel_workspaces.size(); ++i) { | |||||
| auto workspace = kernel_workspaces[i]; | |||||
| if (workspace != nullptr) { | |||||
| MS_EXCEPTION_IF_NULL(workspace->addr); | |||||
| mem_manager_->FreeMemFromMemPool(workspace->addr); | |||||
| workspace->addr = nullptr; | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { | |||||
| auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | |||||
| if (device_address->ptr_) { | |||||
| mem_manager_->FreeMemFromMemPool(device_address); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -53,11 +53,17 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| // The related functions and members for using dynamic memory pool. | // The related functions and members for using dynamic memory pool. | ||||
| void InitKernelRefCount(const session::KernelGraph *graph); | void InitKernelRefCount(const session::KernelGraph *graph); | ||||
| void InitKernelOutputAddress(const session::KernelGraph *graph); | void InitKernelOutputAddress(const session::KernelGraph *graph); | ||||
| void InitKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void InitMemorySwapInfo(const session::KernelGraph *graph); | void InitMemorySwapInfo(const session::KernelGraph *graph); | ||||
| void ClearKernelOutputAddress(const session::KernelGraph *graph); | void ClearKernelOutputAddress(const session::KernelGraph *graph); | ||||
| bool LaunchKernelDynamic(const session::KernelGraph *graph); | |||||
| void ClearKernelWorkspaceAddress(const session::KernelGraph *graph); | |||||
| void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph); | |||||
| bool SearchMemSwapScheme(const session::KernelGraph *graph); | |||||
| bool RefineMemSwapScheme(const session::KernelGraph *graph); | |||||
| bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); | |||||
| void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||||
| const AddressPtrList &workspace, const AddressPtrList &outputs); | |||||
| bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); | bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); | ||||
| void *AttemptMallocMem(size_t size); | |||||
| bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, | ||||
| AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, | AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, | ||||
| AddressPtrList *kernel_outputs); | AddressPtrList *kernel_outputs); | ||||
| @@ -72,7 +78,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, | void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, | ||||
| const DeviceAddressPtrList addr_list, size_t total_size, | const DeviceAddressPtrList addr_list, size_t total_size, | ||||
| std::vector<size_t> size_list); | std::vector<size_t> size_list); | ||||
| void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces); | |||||
| void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel); | |||||
| bool AddMemorySwapTask(const AnfNodePtr &kernel); | bool AddMemorySwapTask(const AnfNodePtr &kernel); | ||||
| bool UpdateMemorySwapInfo(const session::KernelGraph *graph); | bool UpdateMemorySwapInfo(const session::KernelGraph *graph); | ||||
| bool UpdateMemorySwapTask(const AnfNodePtr &kernel); | bool UpdateMemorySwapTask(const AnfNodePtr &kernel); | ||||
| @@ -81,6 +87,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| void ClearSwapQueue(); | void ClearSwapQueue(); | ||||
| std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; | ||||
| std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; | std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; | ||||
| std::unordered_map<uint32_t, bool> is_first_step_map_; | |||||
| MemReuseUtilPtr mem_reuse_util_{nullptr}; | MemReuseUtilPtr mem_reuse_util_{nullptr}; | ||||
| MemSwapManagerPtr mem_swap_manager_{nullptr}; | MemSwapManagerPtr mem_swap_manager_{nullptr}; | ||||
| }; | }; | ||||
| @@ -73,6 +73,14 @@ DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { | |||||
| return workspace_address_list_[index].get(); | return workspace_address_list_[index].get(); | ||||
| } | } | ||||
| DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const { | |||||
| if (index >= workspace_address_list_.size()) { | |||||
| MS_LOG(ERROR) << "Index [" << index << "] out of range"; | |||||
| return nullptr; | |||||
| } | |||||
| return workspace_address_list_[index]; | |||||
| } | |||||
| bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { | bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { | ||||
| if (workspace_address_list_.empty()) { | if (workspace_address_list_.empty()) { | ||||
| // parameter and valuenode | // parameter and valuenode | ||||
| @@ -54,6 +54,7 @@ class KernelInfo : public KernelInfoDevice { | |||||
| bool OutputAddrExist(size_t index) const; | bool OutputAddrExist(size_t index) const; | ||||
| bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); | bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); | ||||
| DeviceAddress *GetWorkspaceAddr(size_t index) const; | DeviceAddress *GetWorkspaceAddr(size_t index) const; | ||||
| DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; | |||||
| bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); | bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); | ||||
| void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); | void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); | ||||
| kernel::KernelMod *MutableKernelMod() const; | kernel::KernelMod *MutableKernelMod() const; | ||||