Merge pull request !27390 from tanghuikang/swap_strategy_adjust tags/v1.6.0
| @@ -1406,6 +1406,18 @@ bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t labe | |||
| return false; | |||
| } | |||
| bool AnfRuntimeAlgorithm::IsUpdateParameterKernel(const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto node_name = GetCNodeName(node); | |||
| if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) { | |||
| return false; | |||
| } | |||
| if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end() && node_name.find("Assign") == string::npos) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| @@ -246,6 +246,8 @@ class AnfRuntimeAlgorithm { | |||
| static bool IsParameterWeight(const ParameterPtr &node); | |||
| // check whether the anf node includes the label_index. | |||
| static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index); | |||
| // Check whether the cnode updates a parameter. | |||
| static bool IsUpdateParameterKernel(const CNodePtr &node); | |||
| // set stream id of kernel, which will be set in stream assign and be used in stream generate | |||
| static void SetStreamId(uint32_t stream_id, AnfNode *node); | |||
| // get stream id | |||
| @@ -1346,13 +1346,7 @@ void KernelGraph::SetOptimizerFlag() { | |||
| has_optimizer_ = false; | |||
| for (const auto &cnode : execution_order_) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto node_name = AnfAlgo::GetCNodeName(cnode); | |||
| if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) { | |||
| continue; | |||
| } | |||
| if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) { | |||
| has_optimizer_ = true; | |||
| } else if (node_name.find("Assign") == string::npos) { | |||
| if (!AnfAlgo::IsUpdateParameterKernel(cnode)) { | |||
| continue; | |||
| } | |||
| for (auto &input : cnode->inputs()) { | |||
| @@ -307,10 +307,7 @@ class KernelGraph : public FuncGraph { | |||
| bool has_optimizer() const { return has_optimizer_; } | |||
| bool IsUpdatedParameter(const ParameterPtr ¶m) const { | |||
| if (updated_parameters_.find(param) != updated_parameters_.end()) { | |||
| return true; | |||
| } | |||
| return false; | |||
| return updated_parameters_.find(param) != updated_parameters_.end(); | |||
| } | |||
| // handle graph dependency | |||
| void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) { | |||
| @@ -1324,6 +1324,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| const auto update_parameter = AnfAlgo::IsUpdateParameterKernel(cnode); | |||
| for (size_t j = 0; j < input_num; ++j) { | |||
| auto real_input = AnfAlgo::GetRealInputIndex(kernel, j); | |||
| auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true); | |||
| @@ -1335,6 +1336,14 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem | |||
| GetOrMallocAddress(mem_scheduler, device_address, input); | |||
| input->size = device_address->size_; | |||
| kernel_launch_info->inputs_.emplace_back(input); | |||
| if (update_parameter && input_node->isa<Parameter>()) { | |||
| auto param = input_node->cast<ParameterPtr>(); | |||
| auto abstract = param->abstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| if (abstract->isa<abstract::AbstractRef>()) { | |||
| mem_scheduler->UpdateHighPriorityMem(device_address); | |||
| } | |||
| } | |||
| } | |||
| for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) { | |||
| @@ -1410,6 +1419,7 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m | |||
| if (input_tensors.size() != input_nodes.size()) { | |||
| MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size(); | |||
| } | |||
| mem_scheduler->ClearMemNeedInit(); | |||
| for (size_t i = 0; i < input_tensors.size(); ++i) { | |||
| auto input_node = input_nodes[i]; | |||
| if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| @@ -1418,16 +1428,30 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| auto tensor = input_tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| auto tensor_address = tensor->device_address(); | |||
| if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) { | |||
| auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| const auto tensor_size = LongToSize(tensor->data().nbytes()); | |||
| if (tensor_address == device_address) { | |||
| if (tensor->NeedSyncHostToDevice()) { | |||
| tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(), | |||
| tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| } | |||
| if (mem_scheduler->HasDeviceMem(tensor_address.get())) { | |||
| tensor_address->set_ptr(nullptr); | |||
| } | |||
| continue; | |||
| } | |||
| if (tensor->NeedSyncHostToDevice()) { | |||
| mem_scheduler->AddMemNeedInit(device_address.get()); | |||
| } else if (tensor_address != nullptr) { | |||
| tensor->data_sync(false); | |||
| mem_scheduler->AddMemNeedInit(device_address.get()); | |||
| } | |||
| MemPriority priority = kMemPriorityLow; | |||
| if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) && | |||
| graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) { | |||
| const auto ¶meter = input_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) { | |||
| priority = kMemPriorityHigh; | |||
| } | |||
| auto tensor_size = LongToSize(tensor->data().nbytes()); | |||
| mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority); | |||
| tensor->set_sync_status(kNoNeedSync); | |||
| } | |||
| @@ -22,6 +22,9 @@ | |||
| namespace mindspore { | |||
| namespace device { | |||
| constexpr size_t kFirstGetMemEventIndex = 1; | |||
| constexpr size_t kInitOrMallocMemEventIndex = 0; | |||
| std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) { | |||
| if (pre_compute_events_.size() <= step) { | |||
| MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size(); | |||
| @@ -62,7 +65,7 @@ void MemOffloadStrategy::CountMemUsage() { | |||
| if (mem_events.empty()) { | |||
| continue; | |||
| } | |||
| auto first_event = mem_events[0]; | |||
| auto first_event = mem_events[kInitOrMallocMemEventIndex]; | |||
| const bool is_high_priority = IsHighPriorityMem(first_event->key); | |||
| if (is_high_priority) { | |||
| high_priority_mem_size += first_event->mem_size; | |||
| @@ -83,6 +86,10 @@ void MemOffloadStrategy::CountMemUsage() { | |||
| } | |||
| min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end())); | |||
| mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size; | |||
| if (mem_size_ < min_mem_needed_) { | |||
| MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least " | |||
| << min_mem_needed_; | |||
| } | |||
| } | |||
| bool MemOffloadStrategy::IsHighPriorityMem(const void *key) { | |||
| @@ -94,11 +101,6 @@ bool MemOffloadStrategy::IsHighPriorityMem(const void *key) { | |||
| } | |||
| void MemOffloadStrategy::CheckMemSize() { | |||
| if (mem_size_ < min_mem_needed_) { | |||
| MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least " | |||
| << min_mem_needed_; | |||
| } | |||
| if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) { | |||
| need_swap_ = true; | |||
| } | |||
| @@ -116,19 +118,20 @@ void MemOffloadStrategy::GenEventSpan() { | |||
| if (tensor_events.size() <= 1) { | |||
| continue; | |||
| } | |||
| const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key); | |||
| for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) { | |||
| auto &event = tensor_events[event_index]; | |||
| const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key); | |||
| for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) { | |||
| auto &event = tensor_events[i]; | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| if (event->type != kGet) { | |||
| MS_LOG(EXCEPTION) << "Event should be Get except first event."; | |||
| } | |||
| size_t span = 0; | |||
| if (event_index == 1 && is_high_priority) { | |||
| const auto &last_event = tensor_events[tensor_events.size() - 1]; | |||
| span = event->index + total_step_ - last_event->index; | |||
| } else { | |||
| span = event->index - tensor_events[event_index - 1]->index; | |||
| auto latest_event = tensor_events[i - 1]; | |||
| if (i == kFirstGetMemEventIndex && is_high_priority) { | |||
| latest_event = tensor_events[tensor_events.size() - 1]; | |||
| } | |||
| auto span = GetSpanBetweenMemEvents(latest_event->index, event->index); | |||
| if (is_high_priority && span == 0 && latest_event == event) { | |||
| span = total_step_; | |||
| } | |||
| if (span > 1) { | |||
| const size_t span_mul_size = (span - 1) * event->mem_size; | |||
| @@ -156,7 +159,7 @@ void MemOffloadStrategy::GenSwapEventSet() { | |||
| for (const auto &iter : event_span_) { | |||
| auto span = iter.second.second; | |||
| auto &event = iter.second.first; | |||
| auto start_index = ((total_step_ + event->index - span) % total_step_) + 1; | |||
| auto start_index = ((event->index + total_step_ - span + 1) % total_step_); | |||
| bool revert = false; | |||
| size_t cur_index = start_index; | |||
| while (cur_index != event->index) { | |||
| @@ -196,12 +199,12 @@ void MemOffloadStrategy::GenComputeMemEvents() { | |||
| } | |||
| const bool is_high_priority = IsHighPriorityMem(item.first); | |||
| auto first_event = mem_events[0]; | |||
| auto first_event = mem_events[kInitOrMallocMemEventIndex]; | |||
| MS_EXCEPTION_IF_NULL(first_event); | |||
| const auto &second_event = mem_events[1]; | |||
| MS_EXCEPTION_IF_NULL(second_event); | |||
| if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) { | |||
| first_event->index = second_event->index; | |||
| const auto &first_get_event = mem_events[kFirstGetMemEventIndex]; | |||
| MS_EXCEPTION_IF_NULL(first_get_event); | |||
| if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) { | |||
| first_event->index = first_get_event->index; | |||
| } | |||
| if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) { | |||
| pre_compute_events_[first_event->index].emplace_back(first_event); | |||
| @@ -211,16 +214,21 @@ void MemOffloadStrategy::GenComputeMemEvents() { | |||
| const auto &last_event = mem_events[mem_events.size() - 1]; | |||
| size_t pre_index = is_high_priority ? last_event->index : first_event->index; | |||
| for (size_t i = 1; i < mem_events.size(); ++i) { | |||
| const auto &swap_out_event_index = GetSwapOutEventIndex(item.first, mem_events); | |||
| for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) { | |||
| auto &event = mem_events[i]; | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| if (need_swap_ && swap_events_.find(event) != swap_events_.end()) { | |||
| auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index); | |||
| swap_out_event->key = item.first; | |||
| swap_out_event->mem_size = first_event->mem_size; | |||
| post_compute_events_[pre_index].emplace_back(swap_out_event); | |||
| MemEventType event_type = kSwapOut; | |||
| if (is_high_priority && swap_out_event_index.count(i) == 0) { | |||
| event_type = kFree; | |||
| } | |||
| auto free_or_swap_out_event = std::make_shared<MemEvent>(event_type, pre_index); | |||
| free_or_swap_out_event->key = item.first; | |||
| free_or_swap_out_event->mem_size = first_event->mem_size; | |||
| post_compute_events_[pre_index].emplace_back(free_or_swap_out_event); | |||
| // avoid swap-in-event following init-event | |||
| if (first_event->type != kInit || i != 1) { | |||
| if (i != kFirstGetMemEventIndex || first_event->type != kInit) { | |||
| auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index); | |||
| swap_in_event->key = item.first; | |||
| swap_in_event->mem_size = first_event->mem_size; | |||
| @@ -246,5 +254,39 @@ void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_even | |||
| (void)post_compute_events_[last_event->index].emplace_back(free_event); | |||
| } | |||
| } | |||
| std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key, | |||
| const std::vector<std::shared_ptr<MemEvent>> &mem_events) { | |||
| const auto &update_step_iter = high_priority_updated_step_.find(key); | |||
| if (update_step_iter == high_priority_updated_step_.end() || update_step_iter->second.empty()) { | |||
| return std::set<size_t>(); | |||
| } | |||
| const auto &update_steps = update_step_iter->second; | |||
| size_t update_steps_index = 0; | |||
| std::set<size_t> swap_out_event_index; | |||
| size_t min_swap_index_before_update = SIZE_MAX; | |||
| size_t max_swap_out_step = 0; | |||
| for (size_t i = 0; i < mem_events.size(); ++i) { | |||
| const auto &mem_event = mem_events[i]; | |||
| if (swap_events_.count(mem_event) == 0) { | |||
| continue; | |||
| } | |||
| if (mem_event->index <= update_steps[update_steps_index]) { | |||
| if (i <= min_swap_index_before_update) { | |||
| min_swap_index_before_update = i; | |||
| } | |||
| } else { | |||
| swap_out_event_index.insert(i); | |||
| max_swap_out_step = mem_event->index; | |||
| while (update_steps_index < update_steps.size() && update_steps[update_steps_index] < mem_event->index) { | |||
| ++update_steps_index; | |||
| } | |||
| } | |||
| } | |||
| if (max_swap_out_step <= update_steps[update_steps.size() - 1]) { | |||
| swap_out_event_index.insert(min_swap_index_before_update); | |||
| } | |||
| return swap_out_event_index; | |||
| } | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -41,10 +41,12 @@ class MemOffloadStrategy { | |||
| public: | |||
| MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority, | |||
| const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events, | |||
| const std::set<const void *> &manual_offload_keys, size_t total_step) | |||
| const std::set<const void *> &manual_offload_keys, | |||
| const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step) | |||
| : mem_priority_(mem_priority), | |||
| mem_events_(mem_events), | |||
| manual_offload_keys_(manual_offload_keys), | |||
| high_priority_updated_step_(high_priority_updated_step), | |||
| total_step_(total_step) {} | |||
| virtual ~MemOffloadStrategy() = default; | |||
| @@ -75,10 +77,16 @@ class MemOffloadStrategy { | |||
| void GenComputeMemEvents(); | |||
| void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event); | |||
| std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events); | |||
| size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const { | |||
| return (post_step + total_step_ - pre_step) % total_step_; | |||
| } | |||
| const std::map<const void *, MemPriority> &mem_priority_; | |||
| const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_; | |||
| const std::set<const void *> &manual_offload_keys_; | |||
| std::map<const void *, std::vector<size_t>> high_priority_updated_step_; | |||
| const size_t total_step_; | |||
| std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_; | |||
| std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_; | |||
| @@ -45,10 +45,10 @@ void MemScheduler::Clear() { | |||
| if (mem_handler_ == nullptr) { | |||
| return; | |||
| } | |||
| for (auto &item : high_priority_device_ptr_) { | |||
| for (auto &item : mem_result_) { | |||
| mem_handler_->FreeDevice(item.second); | |||
| } | |||
| high_priority_device_ptr_.clear(); | |||
| mem_result_.clear(); | |||
| } | |||
| void MemScheduler::ClearAllocatedMem() { | |||
| @@ -57,12 +57,11 @@ void MemScheduler::ClearAllocatedMem() { | |||
| } | |||
| for (auto &item : mem_result_) { | |||
| const auto device_ptr = item.second; | |||
| if (device_ptr == nullptr) { | |||
| if (device_ptr != nullptr) { | |||
| mem_handler_->FreeDevice(device_ptr); | |||
| } | |||
| } | |||
| mem_result_.clear(); | |||
| high_priority_device_ptr_.clear(); | |||
| for (const auto &item : swap_host_ptr_) { | |||
| const auto host_ptr = item.second; | |||
| if (host_ptr != nullptr) { | |||
| @@ -125,22 +124,19 @@ bool MemScheduler::PreCompute(void *stream) { | |||
| MS_EXCEPTION_IF_NULL(event); | |||
| MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type; | |||
| if (event->type == kInit || event->type == kMalloc) { | |||
| auto priority = mem_priority_[event->key]; | |||
| auto iter = high_priority_device_ptr_.find(event->key); | |||
| if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) { | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| mem_result_[event->key] = iter->second; | |||
| continue; | |||
| } | |||
| auto device_ptr = mem_handler_->MallocDevice(event->mem_size); | |||
| if (device_ptr == nullptr) { | |||
| return false; | |||
| } | |||
| if (priority != kMemPriorityLow) { | |||
| high_priority_device_ptr_[event->key] = device_ptr; | |||
| const auto &iter = mem_result_.find(event->key); | |||
| const bool new_malloc = iter == mem_result_.end(); | |||
| void *device_ptr; | |||
| if (new_malloc) { | |||
| device_ptr = mem_handler_->MallocDevice(event->mem_size); | |||
| if (device_ptr == nullptr) { | |||
| return false; | |||
| } | |||
| } else { | |||
| device_ptr = iter->second; | |||
| } | |||
| if (event->type == kInit) { | |||
| if (event->type == kInit && (new_malloc || high_priority_mem_need_init_.count(event->key) != 0)) { | |||
| auto host_ptr = init_host_ptr_[event->key]; | |||
| MS_EXCEPTION_IF_NULL(host_ptr); | |||
| mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); | |||
| @@ -160,9 +156,6 @@ bool MemScheduler::PreCompute(void *stream) { | |||
| MS_EXCEPTION_IF_NULL(host_ptr); | |||
| mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); | |||
| mem_result_[event->key] = device_ptr; | |||
| if (mem_priority_[event->key] == kMemPriorityHigh) { | |||
| high_priority_device_ptr_[event->key] = device_ptr; | |||
| } | |||
| if (!from_init) { | |||
| mem_handler_->FreeHost(host_ptr); | |||
| (void)swap_host_ptr_.erase(event->key); | |||
| @@ -211,9 +204,6 @@ bool MemScheduler::PostCompute(void *stream) { | |||
| mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream); | |||
| mem_handler_->FreeDevice(device_ptr); | |||
| (void)mem_result_.erase(event->key); | |||
| if (mem_priority_[event->key] == kMemPriorityHigh) { | |||
| high_priority_device_ptr_.erase(event->key); | |||
| } | |||
| } | |||
| } | |||
| ++current_step_; | |||
| @@ -225,7 +215,8 @@ void MemScheduler::OptMemUsage(float mem_used_factor) { | |||
| MS_EXCEPTION_IF_NULL(mem_handler_); | |||
| if (strategy_ == nullptr) { | |||
| strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_); | |||
| strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, | |||
| high_priority_updated_step_, total_step_); | |||
| if (manual_offload_keys_.empty()) { | |||
| compute_time_.resize(total_step_); | |||
| } else { | |||
| @@ -53,6 +53,14 @@ class MemScheduler { | |||
| void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow); | |||
| bool HasDeviceMem(const void *key) const { return mem_result_.find(key) != mem_result_.end(); } | |||
| void UpdateHighPriorityMem(const void *key) { | |||
| if (need_record_event_) { | |||
| high_priority_updated_step_[key].emplace_back(current_step_); | |||
| } | |||
| } | |||
| void SetTotalStep(size_t step) { | |||
| total_step_ = step; | |||
| step_events_.resize(total_step_); | |||
| @@ -72,6 +80,10 @@ class MemScheduler { | |||
| void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); } | |||
| void AddMemNeedInit(const void *key) { high_priority_mem_need_init_.insert(key); } | |||
| void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); } | |||
| private: | |||
| void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0); | |||
| @@ -86,7 +98,8 @@ class MemScheduler { | |||
| std::map<const void *, void *> mem_result_; | |||
| std::map<const void *, void *> init_host_ptr_; | |||
| std::map<const void *, void *> swap_host_ptr_; | |||
| std::map<const void *, void *> high_priority_device_ptr_; | |||
| std::map<const void *, std::vector<size_t>> high_priority_updated_step_; | |||
| std::set<const void *> high_priority_mem_need_init_; | |||
| size_t total_step_{0}; | |||
| size_t current_step_{0}; | |||
| bool need_record_event_{true}; | |||