From e76e7d4a27a792b60e8b977bed215e901279a4c0 Mon Sep 17 00:00:00 2001 From: caifubi Date: Tue, 6 Apr 2021 21:23:16 +0800 Subject: [PATCH] Fix bug of wrong cuda device id --- mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 504b722392..ac77ae927c 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -54,6 +54,7 @@ using mindspore::device::memswap::MemSwapInfoSet; using mindspore::device::memswap::MemSwapManager; using mindspore::device::memswap::SwapKind; static const size_t PARAMETER_OUTPUT_INDEX = 0; +static thread_local bool cur_thread_device_inited{false}; bool GPUKernelRuntime::SyncStream() { if (!GPUDeviceManager::GetInstance().SyncStream(stream_)) { @@ -70,8 +71,11 @@ bool GPUKernelRuntime::SyncStream() { bool GPUKernelRuntime::Init() { enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel(); - if (device_init_ == true) { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id"); + if (device_init_) { + if (!cur_thread_device_inited) { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id"); + cur_thread_device_inited = true; + } GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); return true; }