| @@ -316,7 +316,17 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { | |||||
| if (!is_task_sink) { | if (!is_task_sink) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| rtCtxSetCurrent(rt_context_hccl_); | |||||
| // Bind hccl context to current thread | |||||
| if (graph->is_dynamic_shape()) { | |||||
| if (rt_context_hccl_ != nullptr) { | |||||
| auto ret = rtCtxSetCurrent(rt_context_hccl_); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| MS_LOG(EXCEPTION) << "Call rtCtxSetCurrent failed, ret[" << ret << "]"; | |||||
| } | |||||
| } | |||||
| } | |||||
| // Do HcomExecutorInitialize | // Do HcomExecutorInitialize | ||||
| if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { | if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { | ||||
| MS_LOG(ERROR) << "Init Hccl Executor Failed"; | MS_LOG(ERROR) << "Init Hccl Executor Failed"; | ||||
| @@ -655,6 +665,7 @@ bool AscendKernelRuntime::InitDevice() { | |||||
| ret = rtCtxGetCurrent(&rt_context_hccl_); | ret = rtCtxGetCurrent(&rt_context_hccl_); | ||||
| if (ret != RT_ERROR_NONE || rt_context_hccl_ == nullptr) { | if (ret != RT_ERROR_NONE || rt_context_hccl_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]"; | MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]"; | ||||
| return false; | |||||
| } | } | ||||
| ret = rtCtxCreate(&rt_context_, 0, device_id_); | ret = rtCtxCreate(&rt_context_, 0, device_id_); | ||||
| @@ -687,6 +698,10 @@ bool AscendKernelRuntime::ResetDevice() { | |||||
| } | } | ||||
| rt_context_ = nullptr; | rt_context_ = nullptr; | ||||
| } | } | ||||
| // set to nullptr as its not created, only bounded to existing context | |||||
| rt_context_hccl_ = nullptr; | |||||
| return true; | return true; | ||||
| } | } | ||||