| @@ -54,6 +54,7 @@ using mindspore::device::memswap::MemSwapInfoSet; | |||||
| using mindspore::device::memswap::MemSwapManager; | using mindspore::device::memswap::MemSwapManager; | ||||
| using mindspore::device::memswap::SwapKind; | using mindspore::device::memswap::SwapKind; | ||||
| static const size_t PARAMETER_OUTPUT_INDEX = 0; | static const size_t PARAMETER_OUTPUT_INDEX = 0; | ||||
| static thread_local bool cur_thread_device_inited{false}; | |||||
| bool GPUKernelRuntime::SyncStream() { | bool GPUKernelRuntime::SyncStream() { | ||||
| if (!GPUDeviceManager::GetInstance().SyncStream(stream_)) { | if (!GPUDeviceManager::GetInstance().SyncStream(stream_)) { | ||||
| @@ -70,8 +71,11 @@ bool GPUKernelRuntime::SyncStream() { | |||||
| bool GPUKernelRuntime::Init() { | bool GPUKernelRuntime::Init() { | ||||
| enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel(); | enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel(); | ||||
| if (device_init_ == true) { | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id"); | |||||
| if (device_init_) { | |||||
| if (!cur_thread_device_inited) { | |||||
| CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id"); | |||||
| cur_thread_device_inited = true; | |||||
| } | |||||
| GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); | GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); | ||||
| return true; | return true; | ||||
| } | } | ||||