|
|
|
@@ -316,7 +316,17 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { |
|
|
|
if (!is_task_sink) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
rtCtxSetCurrent(rt_context_hccl_); |
|
|
|
|
|
|
|
// Bind hccl context to current thread |
|
|
|
if (graph->is_dynamic_shape()) { |
|
|
|
if (rt_context_hccl_ != nullptr) { |
|
|
|
auto ret = rtCtxSetCurrent(rt_context_hccl_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "Call rtCtxSetCurrent failed, ret[" << ret << "]"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Do HcomExecutorInitialize |
|
|
|
if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { |
|
|
|
MS_LOG(ERROR) << "Init Hccl Executor Failed"; |
|
|
|
@@ -655,6 +665,7 @@ bool AscendKernelRuntime::InitDevice() { |
|
|
|
ret = rtCtxGetCurrent(&rt_context_hccl_); |
|
|
|
if (ret != RT_ERROR_NONE || rt_context_hccl_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
ret = rtCtxCreate(&rt_context_, 0, device_id_); |
|
|
|
@@ -687,6 +698,10 @@ bool AscendKernelRuntime::ResetDevice() { |
|
|
|
} |
|
|
|
rt_context_ = nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// set to nullptr as its not created, only bounded to existing context |
|
|
|
rt_context_hccl_ = nullptr; |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
|