diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index dc4487ccee..b771419f7f 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -28,10 +28,13 @@ #include "common/utils.h" #include "device/gpu/gpu_memory_manager.h" #include "kernel/common_utils.h" +#include "device/gpu/gpu_memory_copy_manager.h" namespace mindspore { namespace device { namespace gpu { +using mindspore::device::memswap::MemSwapManager; +using mindspore::device::memswap::SwapKind; bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } bool GPUKernelRuntime::Init() { @@ -101,6 +104,12 @@ void GPUKernelRuntime::ReleaseDeviceRes() { } CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); } + // destroy remaining memory swap events and free host memory + if (mem_swap_manager_->trigger_swap()) { + mem_swap_manager_->ClearSwapQueue(); + mem_swap_manager_->ReleaseHostPinnedMem(); + } + GPUDeviceManager::GetInstance().ReleaseDevice(); if (mem_manager_ != nullptr) { mem_manager_->FreeDeviceMemory(); @@ -126,15 +135,29 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { } bool GPUKernelRuntime::Run(session::KernelGraph *graph) { - bool ret; + bool ret = true; auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); + auto iter = mem_swap_map_.find(graph); + if (iter == mem_swap_map_.end()) { + GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared(); + iter = mem_swap_map_.emplace(graph, std::make_shared(gpu_mem_copy_manager)).first; + } + mem_swap_manager_ = iter->second; struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); if (is_enable_dynamic_mem && !is_enable_pynative_infer) 
{ - ret = LaunchKernelDynamic(graph); + while (!LaunchKernelDynamic(graph)) { + ClearKernelOutputAddress(graph); + if (!mem_swap_manager_->mem_swap_init()) { + mem_swap_manager_->Init(graph); + } + if (!mem_swap_manager_->RetreatSwapInfo()) { + return false; + } + } } else { ret = LaunchKernel(graph); } @@ -181,6 +204,27 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph } } +void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (!AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); + } + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } +} + bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); @@ -198,32 +242,157 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { - MS_LOG(ERROR) << "Launch kernel failed."; + auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + if (!ret) { return false; } + if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { + MS_LOG(EXCEPTION) << "Launch kernel failed."; + } FreeKernelDynamicRes(kernel, kernel_workspaces, 
graph_id); + + if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + if (!AddMemSwapTask(kernel)) { + return false; + } + } + + if (mem_swap_manager_->trigger_swap()) { + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); + } } - if (!SyncStream()) { - MS_LOG(ERROR) << "SyncStream failed."; - return false; + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + if (mem_swap_manager_->trigger_swap()) { + mem_swap_manager_->ClearSwapQueue(); + } + return true; +} + +bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) { + auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_list) { + auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); + const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; + auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_); + + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); + } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + auto status = device_address->status(); + if (status == DeviceAddressStatus::kInDeviceToHost) { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + device_address->set_status(DeviceAddressStatus::kInDevice); + } else if (status == DeviceAddressStatus::kInHost) { + if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { + return false; + } + if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); + } + } + } + } + return true; +} + +bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) { 
+ auto ret = mem_manager_->MallocMemFromMemPool(device_address, size); + if (!ret) { + if (!mem_swap_manager_->trigger_swap()) { + return false; + } + + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } + + ret = mem_manager_->MallocMemFromMemPool(device_address, size); + if (!ret) { + return false; + } } return true; } -void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, +void *GPUKernelRuntime::AttemptMallocMem(size_t size) { + auto device_ptr = mem_manager_->MallocMemFromMemPool(size); + if (!device_ptr) { + if (!mem_swap_manager_->trigger_swap()) { + return nullptr; + } + + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } + + device_ptr = mem_manager_->MallocMemFromMemPool(size); + if (!device_ptr) { + return nullptr; + } + } + return device_ptr; +} + +bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { + if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) { + return false; + } + if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) { + return false; + } + if 
(!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) { + return false; + } + return true; +} + +bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_inputs); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - MS_EXCEPTION_IF_NULL(kernel_outputs); - MS_EXCEPTION_IF_NULL(mem_manager_); for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, i); + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); + if (mem_swap_manager_->trigger_swap()) { + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); + } + + auto status = device_address->status(); + switch (status) { + case DeviceAddressStatus::kInDevice: + break; + case DeviceAddressStatus::kInHost: + break; + case DeviceAddressStatus::kInDeviceToHost: { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + device_address->set_status(DeviceAddressStatus::kInDevice); + break; + } + case DeviceAddressStatus::kInHostToDevice: { + while (device_address->status() != DeviceAddressStatus::kInDevice) { + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); + } + } + break; + } + default: + MS_LOG(ERROR) << "Invalid device address status"; + return false; + } + } MS_EXCEPTION_IF_NULL(device_address->ptr_); kernel::AddressPtr input = std::make_shared<kernel::Address>(); MS_EXCEPTION_IF_NULL(input); @@ -231,15 +400,29 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod input->size = device_address->size_; kernel_inputs->emplace_back(input); } + return true; +} + +bool 
GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_outputs) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_outputs); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (mem_swap_manager_->trigger_swap()) { + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } + } auto output_sizes = kernel_mod.GetOutputSizeList(); for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr) { - auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } + if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { + return false; } kernel::AddressPtr output = std::make_shared<kernel::Address>(); MS_EXCEPTION_IF_NULL(output); @@ -247,15 +430,24 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod output->size = output_sizes[i]; kernel_outputs->emplace_back(output); } + return true; +} + +bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_workspaces) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_workspaces); + MS_EXCEPTION_IF_NULL(mem_manager_); auto workspace_sizes = kernel_mod.GetWorkspaceSizeList(); for (size_t i = 0; i < workspace_sizes.size(); ++i) { if (workspace_sizes[i] == 0) { kernel_workspaces->emplace_back(nullptr); continue; } - auto device_ptr = 
mem_manager_->MallocMemFromMemPool(workspace_sizes[i]); + auto device_ptr = AttemptMallocMem(workspace_sizes[i]); if (!device_ptr) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; + return false; } kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); MS_EXCEPTION_IF_NULL(workspace); @@ -263,6 +455,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod workspace->size = workspace_sizes[i]; kernel_workspaces->emplace_back(workspace); } + return true; } void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) { @@ -371,6 +564,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); } } // Free the output of kernel, if output has no reference. @@ -382,6 +576,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); } } // Free the workspace of kernel. 
diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index 6f0eefc27a..ea3ab17160 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -24,10 +24,12 @@ #include #include "device/kernel_runtime.h" #include "device/kernel_runtime_manager.h" +#include "pre_activate/mem_reuse/mem_swap_manager.h" namespace mindspore { namespace device { namespace gpu { +using mindspore::device::memswap::MemSwapManagerPtr; class GPUKernelRuntime : public KernelRuntime { public: GPUKernelRuntime() = default; @@ -51,10 +53,19 @@ class GPUKernelRuntime : public KernelRuntime { // The related functions and members for using dynamic memory pool. void InitKernelRefCount(const session::KernelGraph *graph); void InitKernelOutputAddress(const session::KernelGraph *graph); + void ClearKernelOutputAddress(const session::KernelGraph *graph); bool LaunchKernelDynamic(const session::KernelGraph *graph); - void AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + bool AddMemSwapTask(const AnfNodePtr &kernel); + bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); + void *AttemptMallocMem(size_t size); + bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); + bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); + bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_outputs); + bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); void 
AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); @@ -64,6 +75,8 @@ class GPUKernelRuntime : public KernelRuntime { void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, uint32_t graph_id); std::unordered_map mem_reuse_util_map_; + std::unordered_map mem_swap_map_; + MemSwapManagerPtr mem_swap_manager_{nullptr}; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); } // namespace gpu diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc index a8e36c5c58..75c440b7b3 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc @@ -25,10 +25,7 @@ namespace memswap { void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); execution_order_ = kernel_graph->execution_order(); - FuncGraphManagerPtr manager = kernel_graph->manager(); - NodeUsersMap user_map = manager->node_users(); size_t kernel_index = 0; - for (const auto &kernel : execution_order_) { // parse topo order of kernel kernel_execution_info_.emplace(kernel.get(), kernel_index++); @@ -44,6 +41,31 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { } // parse topo order of user kernel + SaveUserKernelTopoOrder(kernel_graph); + + sort(ordered_tensors_.begin(), ordered_tensors_.end(), + [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); + + auto cur_tensor_size = ordered_tensors_.front().tensor_size_; + for (auto &tensor_info : ordered_tensors_) { + if (cur_tensor_size != tensor_info.tensor_size_) { + cur_tensor_size = tensor_info.tensor_size_; + tensor_size_num_++; + } + } + tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; + tensor_size_threshold_idx_ = 0; + + 
distance_threshold_ = kernel_index / kDistanceInitFactor; + mem_swap_initialized_ = true; + MS_EXCEPTION_IF_NULL(mem_copy_manager_); + mem_copy_manager_->Init(); +} + +void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + FuncGraphManagerPtr manager = kernel_graph->manager(); + NodeUsersMap user_map = manager->node_users(); for (const auto &kernel : execution_order_) { auto iter = user_map.find(kernel); if (iter == user_map.end()) { @@ -66,24 +88,6 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { sort(node_user_pair.second.begin(), node_user_pair.second.end()); } } - - sort(ordered_tensors_.begin(), ordered_tensors_.end(), - [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); - - auto cur_tensor_size = ordered_tensors_.front().tensor_size_; - for (auto &tensor_info : ordered_tensors_) { - if (cur_tensor_size != tensor_info.tensor_size_) { - cur_tensor_size = tensor_info.tensor_size_; - tensor_size_num_++; - } - } - tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; - tensor_size_threshold_idx_ = 0; - - distance_threshold_ = kernel_index / kDistanceInitFactor; - mem_swap_initialized_ = true; - MS_EXCEPTION_IF_NULL(mem_copy_manager_); - mem_copy_manager_->Init(); } void MemSwapManager::AddSwapInfo() { @@ -228,12 +232,12 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons return kernel_exec_info.execution_perform_; } -bool MemSwapManager::QueryKerneTriggerSwap(const AnfNodePtr &kernel) const { +bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); return kernel_exec_info.trigger_swap_; } -bool MemSwapManager::QueryKerneNeedSwap(const AnfNodePtr &kernel) const { +bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = 
SearchKernelExecutionInfo(kernel); return kernel_exec_info.need_swap_; } @@ -254,7 +258,7 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern return iter_output->second; } -const std::vector &MemSwapManager::QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const { +const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); auto iter = mem_swap_info_.find(kernel.get()); if (iter == mem_swap_info_.end()) { diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h index 8fc85a5656..585ddd8f51 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h @@ -63,11 +63,11 @@ class MemSwapManager { const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const; - bool QueryKerneTriggerSwap(const AnfNodePtr &kernel) const; + bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; - bool QueryKerneNeedSwap(const AnfNodePtr &kernel) const; + bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; - const std::vector &QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const; + const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; void InsertSwapInBlackList(const void *device_ptr); @@ -90,6 +90,8 @@ class MemSwapManager { void ResetSwapInfo(); + void SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph); + void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap);