|
|
@@ -310,6 +310,13 @@ void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Tells whether the given kernel graph should be treated as a distributed-training graph.
//
// The graph pointer is first checked with MS_EXCEPTION_IF_NULL. The kernels returned by
// graph->execution_order() are then visited in order, and the function answers true as soon
// as AnfAlgo::IsCommunicationOp accepts one of them. If no kernel is accepted (including
// when the execution order is empty) the answer is false.
bool GPUKernelRuntime::IsDistributedTraining(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const AnfNodePtr &node : graph->execution_order()) {
    if (AnfAlgo::IsCommunicationOp(node)) {
      // One matching kernel is enough; stop scanning.
      return true;
    }
  }
  return false;
}
|
|
|
|
|
|
|
|
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { |
|
|
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
@@ -367,28 +374,28 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) { |
|
|
bool GPUKernelRuntime::RunOneStep(const session::KernelGraph *graph) { |
|
|
bool ret = true; |
|
|
|
|
|
auto graph_id = graph->graph_id(); |
|
|
auto graph_id = graph->graph_id(); |
|
|
if (!is_first_step_map_[graph_id] || graph->is_dynamic_shape()) { |
|
|
if (!is_first_step_map_[graph_id] || graph->is_dynamic_shape()) { |
|
|
// Normally run graph |
|
|
// Normally run graph |
|
|
ret = LaunchKernelDynamic(graph); |
|
|
|
|
|
} else { |
|
|
|
|
|
// Mock run first step |
|
|
|
|
|
ret = LaunchKernelDynamic(graph, true, false); |
|
|
|
|
|
if (ret) { |
|
|
|
|
|
// Normally run graph |
|
|
|
|
|
ret = LaunchKernelDynamic(graph); |
|
|
|
|
|
} else { |
|
|
|
|
|
// Trigger memory swap |
|
|
|
|
|
ret = SearchMemSwapScheme(graph); |
|
|
|
|
|
} |
|
|
|
|
|
is_first_step_map_[graph_id] = false; |
|
|
|
|
|
|
|
|
return LaunchKernelDynamic(graph); |
|
|
} |
|
|
} |
|
|
return ret; |
|
|
|
|
|
|
|
|
is_first_step_map_[graph_id] = false; |
|
|
|
|
|
// Mock run first step |
|
|
|
|
|
bool ret = LaunchKernelDynamic(graph, true, false); |
|
|
|
|
|
if (ret) { |
|
|
|
|
|
// Normally run graph |
|
|
|
|
|
return LaunchKernelDynamic(graph); |
|
|
|
|
|
} |
|
|
|
|
|
if (IsDistributedTraining(graph)) { |
|
|
|
|
|
MS_LOG(ERROR) << "Device memory is not enough, run graph failed!"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
// Trigger memory swap |
|
|
|
|
|
return SearchMemSwapScheme(graph); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { |
|
|
bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { |
|
|
MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; |
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; |
|
|
bool ret = false; |
|
|
bool ret = false; |
|
|
ClearKernelOldOutputAndWorkspace(graph); |
|
|
ClearKernelOldOutputAndWorkspace(graph); |
|
|
if (!mem_swap_manager_->mem_swap_init()) { |
|
|
if (!mem_swap_manager_->mem_swap_init()) { |
|
|
@@ -399,6 +406,7 @@ bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { |
|
|
|
|
|
|
|
|
while (!ret) { |
|
|
while (!ret) { |
|
|
if (!mem_swap_manager_->RetreatSwapInfo()) { |
|
|
if (!mem_swap_manager_->RetreatSwapInfo()) { |
|
|
|
|
|
MS_LOG(ERROR) << "Device memory is not enough, run graph failed!"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
ret = LaunchKernelDynamic(graph, true, false); |
|
|
ret = LaunchKernelDynamic(graph, true, false); |
|
|
@@ -417,7 +425,7 @@ bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) { |
|
|
bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) { |
|
|
MS_LOG(WARNING) << "Refine memory swap scheme, it may take some time, please wait a moment."; |
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Refine memory swap scheme, it may take some time, please wait a moment."; |
|
|
auto &kernels = graph->execution_order(); |
|
|
auto &kernels = graph->execution_order(); |
|
|
for (const auto &kernel : kernels) { |
|
|
for (const auto &kernel : kernels) { |
|
|
if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) { |
|
|
if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) { |
|
|
|