Merge pull request !26695 from tanghuikang/swap_strategy_adjust (tags/v1.6.0)
| @@ -1294,9 +1294,6 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_ | |||
| kernel_addr->addr = device_address->ptr_; | |||
| } else { | |||
| kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_); | |||
| if (mem_scheduler->IsHighPriorityMem(device_address)) { | |||
| device_address->ptr_ = kernel_addr->addr; | |||
| } | |||
| } | |||
| } | |||
| @@ -1343,37 +1340,29 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem | |||
| } | |||
| void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, | |||
| const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) { | |||
| const session::KernelGraph &graph, const AnfNodePtr &kernel) { | |||
| MS_EXCEPTION_IF_NULL(mem_scheduler); | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) { | |||
| const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true); | |||
| if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) { | |||
| continue; | |||
| if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) { | |||
| SyncNodeOutputTensor(mem_scheduler, input_node_index, graph); | |||
| } | |||
| SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock); | |||
| } | |||
| for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) { | |||
| SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock); | |||
| SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph); | |||
| } | |||
| } | |||
| void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, | |||
| const KernelWithIndex &node_output_index, const session::KernelGraph &graph, | |||
| bool mock) { | |||
| const KernelWithIndex &node_output_index, const session::KernelGraph &graph) { | |||
| MS_EXCEPTION_IF_NULL(mem_scheduler); | |||
| if (node_output_index.first == nullptr) { | |||
| return; | |||
| } | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true); | |||
| if (mock) { | |||
| if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) { | |||
| mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh); | |||
| } | |||
| return; | |||
| } | |||
| auto tensor = graph.GetNodeOutputTensor(node_output_index); | |||
| if (tensor == nullptr) { | |||
| return; | |||
| @@ -1407,22 +1396,20 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m | |||
| MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size(); | |||
| } | |||
| for (size_t i = 0; i < input_tensors.size(); ++i) { | |||
| auto tensor = input_tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto input_node = input_nodes[i]; | |||
| if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| continue; | |||
| } | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| auto tensor = input_tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| MemPriority priority = kMemPriorityLow; | |||
| auto tensor_address = tensor->device_address(); | |||
| if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) { | |||
| tensor->data_sync(false); | |||
| } | |||
| if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) || | |||
| MemPriority priority = kMemPriorityLow; | |||
| if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) && | |||
| graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) { | |||
| tensor->set_device_address(device_address); | |||
| priority = kMemPriorityHigh; | |||
| } | |||
| auto tensor_size = LongToSize(tensor->data().nbytes()); | |||
| @@ -1477,7 +1464,9 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod | |||
| } | |||
| } | |||
| if (mem_scheduler != nullptr) { | |||
| SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock); | |||
| if (!mock) { | |||
| SyncNodeOutputTensors(mem_scheduler, graph, kernel); | |||
| } | |||
| ret = mem_scheduler->PostCompute(stream); | |||
| if (!ret) { | |||
| return ret; | |||
| @@ -1553,9 +1542,43 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock | |||
| } | |||
| LaunchKernelEvent(kernel_post_run_events, kernels[i]); | |||
| } | |||
| if (UseMemScheduler() && !mock) { | |||
| SyncUpdatedParameter(graph, mem_scheduler); | |||
| } | |||
| return true; | |||
| } | |||
| void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph, | |||
| const std::shared_ptr<MemScheduler> &mem_scheduler) { | |||
| MS_EXCEPTION_IF_NULL(mem_scheduler); | |||
| auto &input_nodes = graph.input_nodes(); | |||
| auto &input_tensors = graph.input_tensors(); | |||
| if (input_tensors.size() != input_nodes.size()) { | |||
| MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size(); | |||
| } | |||
| for (size_t i = 0; i < input_tensors.size(); ++i) { | |||
| auto input_node = input_nodes[i]; | |||
| if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| continue; | |||
| } | |||
| auto parameter = input_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| if (!graph.IsUpdatedParameter(parameter)) { | |||
| continue; | |||
| } | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| auto tensor = input_tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh); | |||
| if (device_ptr != nullptr) { | |||
| device_address->set_ptr(device_ptr); | |||
| tensor->set_device_address(device_address); | |||
| tensor->set_sync_status(kNeedSyncDeviceToHost); | |||
| } | |||
| } | |||
| } | |||
| void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| @@ -95,6 +95,7 @@ class KernelRuntime { | |||
| void set_device_id(uint32_t device_id) { device_id_ = device_id; } | |||
| uint32_t device_id() { return device_id_; } | |||
| static bool UseMemScheduler(); | |||
| void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler); | |||
| #ifdef ENABLE_DEBUGGER | |||
| // set debugger | |||
| @@ -156,9 +157,9 @@ class KernelRuntime { | |||
| const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr); | |||
| void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph); | |||
| void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph, | |||
| const AnfNodePtr &kernel, bool mock); | |||
| const AnfNodePtr &kernel); | |||
| void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output, | |||
| const session::KernelGraph &graph, bool mock); | |||
| const session::KernelGraph &graph); | |||
| void AssignCommunicationMem(const session::KernelGraph &graph); | |||
| bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false); | |||
| @@ -43,7 +43,7 @@ void MemOffloadStrategy::Execute() { | |||
| CheckMemSize(); | |||
| if (need_swap_) { | |||
| GenEventSpan(); | |||
| GenNoSwapEventSet(); | |||
| GenSwapEventSet(); | |||
| } | |||
| GenComputeMemEvents(); | |||
| } | |||
| @@ -57,37 +57,41 @@ void MemOffloadStrategy::CountMemUsage() { | |||
| } | |||
| min_mem_used_.resize(total_step_, 0); | |||
| std::vector<size_t> total_mem_used(total_step_, 0); | |||
| size_t high_priority_mem_size = 0; | |||
| for (auto &item : mem_events_) { | |||
| auto &mem_events = item.second; | |||
| if (mem_events.empty()) { | |||
| continue; | |||
| } | |||
| auto first_event = mem_events[0]; | |||
| size_t cur_index = 0; | |||
| if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) { | |||
| first_event = mem_events[1]; | |||
| cur_index = 1; | |||
| } | |||
| auto last_event = mem_events[mem_events.size() - 1]; | |||
| for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) { | |||
| if (start_index < total_step_) { | |||
| const bool is_high_priority = IsHighPriorityMem(first_event->key); | |||
| if (is_high_priority) { | |||
| high_priority_mem_size += first_event->mem_size; | |||
| } else { | |||
| auto last_event = mem_events[mem_events.size() - 1]; | |||
| for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) { | |||
| total_mem_used[start_index] += first_event->mem_size; | |||
| } else { | |||
| MS_LOG(ERROR) << "Error mem event index " << start_index; | |||
| } | |||
| } | |||
| for (; cur_index < mem_events.size(); ++cur_index) { | |||
| auto &event = mem_events[cur_index]; | |||
| // Calculate the minimum memory size for kernel execution. | |||
| for (const auto &event : mem_events) { | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| if (event->index < total_step_) { | |||
| min_mem_used_[event->index] += first_event->mem_size; | |||
| } else { | |||
| MS_LOG(ERROR) << "Error mem event index " << event->index; | |||
| if (event->type != kGet) { | |||
| continue; | |||
| } | |||
| min_mem_used_[event->index] += first_event->mem_size; | |||
| } | |||
| } | |||
| min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end())); | |||
| mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())); | |||
| mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size; | |||
| } | |||
| bool MemOffloadStrategy::IsHighPriorityMem(const void *key) { | |||
| auto iter = mem_priority_.find(key); | |||
| if (iter != mem_priority_.end()) { | |||
| return iter->second == kMemPriorityHigh; | |||
| } | |||
| return false; | |||
| } | |||
| void MemOffloadStrategy::CheckMemSize() { | |||
| @@ -110,48 +114,60 @@ void MemOffloadStrategy::GenEventSpan() { | |||
| } | |||
| for (auto &item : mem_events_) { | |||
| auto &tensor_events = item.second; | |||
| if (tensor_events.empty()) { | |||
| if (tensor_events.size() <= 1) { | |||
| continue; | |||
| } | |||
| auto first_event = tensor_events[0]; | |||
| size_t cur_index = 0; | |||
| if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) { | |||
| first_event = tensor_events[1]; | |||
| cur_index = 1; | |||
| } | |||
| size_t last_index = first_event->index; | |||
| for (; cur_index < tensor_events.size(); ++cur_index) { | |||
| auto &event = tensor_events[cur_index]; | |||
| const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key); | |||
| for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) { | |||
| auto &event = tensor_events[event_index]; | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| auto span = event->index - last_index; | |||
| if (event->type != kGet) { | |||
| MS_LOG(EXCEPTION) << "Event should be Get except fist event."; | |||
| } | |||
| size_t span = 0; | |||
| if (event_index == 1 && is_high_priority) { | |||
| const auto &last_event = tensor_events[tensor_events.size() - 1]; | |||
| span = event->index + total_step_ - last_event->index; | |||
| } else { | |||
| span = event->index - tensor_events[event_index - 1]->index; | |||
| } | |||
| if (span > 1) { | |||
| (void)event_span_.emplace(span, event); | |||
| const size_t span_mul_size = (span - 1) * event->mem_size; | |||
| (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span))); | |||
| } | |||
| last_index = event->index; | |||
| } | |||
| } | |||
| } | |||
| void MemOffloadStrategy::GenNoSwapEventSet() { | |||
| no_swap_events_.clear(); | |||
| void MemOffloadStrategy::GenSwapEventSet() { | |||
| swap_events_.clear(); | |||
| std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end()); | |||
| for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) { | |||
| auto span = iter->first; | |||
| auto &event = iter->second; | |||
| auto start_index = event->index - span + 1; | |||
| for (const auto &iter : event_span_) { | |||
| auto span = iter.second.second; | |||
| auto &event = iter.second.first; | |||
| auto start_index = ((total_step_ + event->index - span) % total_step_) + 1; | |||
| bool revert = false; | |||
| for (size_t i = start_index; i < event->index; ++i) { | |||
| cur_mem_used[i] += event->mem_size; | |||
| if (cur_mem_used[i] > mem_size_) { | |||
| size_t cur_index = start_index; | |||
| while (cur_index != event->index) { | |||
| cur_mem_used[cur_index] += event->mem_size; | |||
| if (cur_mem_used[cur_index] > mem_size_) { | |||
| revert = true; | |||
| } | |||
| cur_index += 1; | |||
| if (cur_index >= total_step_) { | |||
| cur_index = 0; | |||
| } | |||
| } | |||
| if (revert) { | |||
| for (size_t i = start_index; i < event->index; ++i) { | |||
| cur_mem_used[i] -= event->mem_size; | |||
| cur_index = start_index; | |||
| while (cur_index != event->index) { | |||
| cur_mem_used[cur_index] -= event->mem_size; | |||
| cur_index += 1; | |||
| if (cur_index >= total_step_) { | |||
| cur_index = 0; | |||
| } | |||
| } | |||
| } else { | |||
| (void)no_swap_events_.emplace(event); | |||
| (void)swap_events_.emplace(event); | |||
| } | |||
| } | |||
| } | |||
| @@ -166,34 +182,31 @@ void MemOffloadStrategy::GenComputeMemEvents() { | |||
| if (mem_events.empty()) { | |||
| continue; | |||
| } | |||
| // No need to generate events for memory that has only one event, which means it is never used by any kernel. | |||
| if (mem_events.size() <= 1) { | |||
| continue; | |||
| } | |||
| const bool is_high_priority = IsHighPriorityMem(item.first); | |||
| auto first_event = mem_events[0]; | |||
| MS_EXCEPTION_IF_NULL(first_event); | |||
| if (first_event->type == kInit) { | |||
| if (mem_events.size() > 1) { | |||
| auto &second_event = mem_events[1]; | |||
| MS_EXCEPTION_IF_NULL(second_event); | |||
| first_event->index = second_event->index; | |||
| } else { | |||
| continue; | |||
| } | |||
| const auto &second_event = mem_events[1]; | |||
| MS_EXCEPTION_IF_NULL(second_event); | |||
| if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) { | |||
| first_event->index = second_event->index; | |||
| } | |||
| if ((first_event->type == kInit || first_event->type == kMalloc) && | |||
| first_event->index < pre_compute_events_.size()) { | |||
| if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) { | |||
| pre_compute_events_[first_event->index].emplace_back(first_event); | |||
| } else { | |||
| MS_LOG_EXCEPTION << "First event should be init or malloc!"; | |||
| } | |||
| MemPriority priority = kMemPriorityLow; | |||
| auto iter = mem_priority_.find(first_event->key); | |||
| if (iter != mem_priority_.end()) { | |||
| priority = iter->second; | |||
| } | |||
| size_t pre_index = first_event->index; | |||
| const auto &last_event = mem_events[mem_events.size() - 1]; | |||
| size_t pre_index = is_high_priority ? last_event->index : first_event->index; | |||
| for (size_t i = 1; i < mem_events.size(); ++i) { | |||
| auto &event = mem_events[i]; | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow && | |||
| no_swap_events_.find(event) == no_swap_events_.end()) { | |||
| if (need_swap_ && swap_events_.find(event) != swap_events_.end()) { | |||
| auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index); | |||
| swap_out_event->key = item.first; | |||
| swap_out_event->mem_size = first_event->mem_size; | |||
| @@ -208,17 +221,19 @@ void MemOffloadStrategy::GenComputeMemEvents() { | |||
| } | |||
| pre_index = event->index; | |||
| } | |||
| if (priority != kMemPriorityLow) { | |||
| continue; | |||
| } | |||
| auto &last_event = mem_events[mem_events.size() - 1]; | |||
| MS_EXCEPTION_IF_NULL(last_event); | |||
| auto free_event = std::make_shared<MemEvent>(kFree, last_event->index); | |||
| free_event->key = item.first; | |||
| if (last_event->index < post_compute_events_.size()) { | |||
| (void)post_compute_events_[last_event->index].emplace_back(free_event); | |||
| if (!is_high_priority) { | |||
| GenFreeEvent(last_event); | |||
| } | |||
| } | |||
| } | |||
| void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_event) { | |||
| MS_EXCEPTION_IF_NULL(last_event); | |||
| auto free_event = std::make_shared<MemEvent>(kFree, last_event->index); | |||
| free_event->key = last_event->key; | |||
| if (last_event->index < post_compute_events_.size()) { | |||
| (void)post_compute_events_[last_event->index].emplace_back(free_event); | |||
| } | |||
| } | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -58,12 +58,15 @@ class MemOffloadStrategy { | |||
| bool need_swap() const { return need_swap_; } | |||
| bool IsHighPriorityMem(const void *key); | |||
| private: | |||
| void CountMemUsage(); | |||
| void CheckMemSize(); | |||
| void GenEventSpan(); | |||
| void GenNoSwapEventSet(); | |||
| void GenSwapEventSet(); | |||
| void GenComputeMemEvents(); | |||
| void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event); | |||
| const std::map<const void *, MemPriority> &mem_priority_; | |||
| const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_; | |||
| @@ -74,8 +77,8 @@ class MemOffloadStrategy { | |||
| size_t mem_size_{0}; | |||
| std::vector<double> compute_time_; | |||
| bool need_swap_{false}; | |||
| std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_; | |||
| std::set<std::shared_ptr<MemEvent>> no_swap_events_; | |||
| std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_; | |||
| std::set<std::shared_ptr<MemEvent>> swap_events_; | |||
| std::vector<size_t> min_mem_used_; | |||
| size_t mem_used_without_swap_{0}; | |||
| size_t min_mem_needed_{0}; | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace { | |||
| constexpr float kMaxMemReuseFactor = 0.9; | |||
| constexpr float kMaxMemReuseFactor = 1.0; | |||
| constexpr float kMinMemReuseFactor = 0.5; | |||
| constexpr float kRetryFactor = 0.1; | |||
| @@ -51,12 +51,25 @@ void MemScheduler::Clear() { | |||
| high_priority_device_ptr_.clear(); | |||
| } | |||
| bool MemScheduler::IsHighPriorityMem(const void *key) { | |||
| auto iter = mem_priority_.find(key); | |||
| if (iter != mem_priority_.end()) { | |||
| return iter->second == kMemPriorityHigh; | |||
| void MemScheduler::ClearTempMem() { | |||
| if (mem_handler_ == nullptr) { | |||
| return; | |||
| } | |||
| for (auto &item : mem_result_) { | |||
| const auto device_ptr = item.second; | |||
| if (device_ptr == nullptr) { | |||
| mem_handler_->FreeDevice(device_ptr); | |||
| } | |||
| } | |||
| return false; | |||
| mem_result_.clear(); | |||
| high_priority_device_ptr_.clear(); | |||
| for (const auto &item : swap_host_ptr_) { | |||
| const auto host_ptr = item.second; | |||
| if (host_ptr != nullptr) { | |||
| mem_handler_->FreeHost(host_ptr); | |||
| } | |||
| } | |||
| swap_host_ptr_.clear(); | |||
| } | |||
| void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; } | |||
| @@ -88,9 +101,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr | |||
| if (mem_priority_.find(key) == mem_priority_.end()) { | |||
| mem_priority_[key] = priority; | |||
| Record(key, kMalloc, mem_size); | |||
| } else { | |||
| Record(key, kGet, mem_size); | |||
| } | |||
| Record(key, kGet, mem_size); | |||
| return nullptr; | |||
| } | |||
| if (strategy_ == nullptr) { | |||
| @@ -101,9 +113,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr | |||
| auto ptr = iter->second; | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| return ptr; | |||
| } else { | |||
| MS_LOG_EXCEPTION << "Mem extender get nullptr result!"; | |||
| } | |||
| return nullptr; | |||
| } | |||
| bool MemScheduler::PreCompute(void *stream) { | |||
| @@ -151,6 +162,9 @@ bool MemScheduler::PreCompute(void *stream) { | |||
| MS_EXCEPTION_IF_NULL(host_ptr); | |||
| mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); | |||
| mem_result_[event->key] = device_ptr; | |||
| if (mem_priority_[event->key] == kMemPriorityHigh) { | |||
| high_priority_device_ptr_[event->key] = device_ptr; | |||
| } | |||
| if (!from_init) { | |||
| mem_handler_->FreeHost(host_ptr); | |||
| (void)swap_host_ptr_.erase(event->key); | |||
| @@ -199,6 +213,9 @@ bool MemScheduler::PostCompute(void *stream) { | |||
| mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream); | |||
| mem_handler_->FreeDevice(device_ptr); | |||
| (void)mem_result_.erase(event->key); | |||
| if (mem_priority_[event->key] == kMemPriorityHigh) { | |||
| high_priority_device_ptr_.erase(event->key); | |||
| } | |||
| } | |||
| } | |||
| ++current_step_; | |||
| @@ -221,6 +238,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) { | |||
| } | |||
| void MemScheduler::Optimize() { | |||
| AdjustFirstEventIndex(); | |||
| float mem_used_factor = kMaxMemReuseFactor; | |||
| while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) { | |||
| OptMemUsage(mem_used_factor); | |||
| @@ -247,11 +265,30 @@ void MemScheduler::Optimize() { | |||
| if (ret) { | |||
| optimized_ = true; | |||
| } else { | |||
| ClearTempMem(); | |||
| mem_used_factor -= kRetryFactor; | |||
| } | |||
| } | |||
| } | |||
| void MemScheduler::AdjustFirstEventIndex() { | |||
| for (const auto &item : mem_events_) { | |||
| const auto &mem_events = item.second; | |||
| if (mem_events.empty()) { | |||
| continue; | |||
| } | |||
| auto &first_event = mem_events[0]; | |||
| MS_EXCEPTION_IF_NULL(first_event); | |||
| const auto &priority_iter = mem_priority_.find(item.first); | |||
| const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh); | |||
| if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) { | |||
| const auto &second_event = mem_events[1]; | |||
| MS_EXCEPTION_IF_NULL(second_event); | |||
| first_event->index = second_event->index; | |||
| } | |||
| } | |||
| } | |||
| void MemScheduler::Update() { | |||
| if (!optimized_) { | |||
| return; | |||
| @@ -70,7 +70,7 @@ class MemScheduler { | |||
| void Clear(); | |||
| bool IsHighPriorityMem(const void *key); | |||
| void ClearTempMem(); | |||
| void SetMemPriority(const void *key, MemPriority priority); | |||
| @@ -79,6 +79,8 @@ class MemScheduler { | |||
| void OptMemUsage(float mem_used_factor = 1.0f); | |||
| void AdjustFirstEventIndex(); | |||
| std::map<const void *, MemPriority> mem_priority_; | |||
| std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_; | |||
| std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_; | |||