Browse Source

Fix bug of wrong cuda device id

pull/14802/head
caifubi 4 years ago
parent
commit
e76e7d4a27
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc

+ 6
- 2
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -54,6 +54,7 @@ using mindspore::device::memswap::MemSwapInfoSet;
using mindspore::device::memswap::MemSwapManager;
using mindspore::device::memswap::SwapKind;
static const size_t PARAMETER_OUTPUT_INDEX = 0;
static thread_local bool cur_thread_device_inited{false};

bool GPUKernelRuntime::SyncStream() {
if (!GPUDeviceManager::GetInstance().SyncStream(stream_)) {
@@ -70,8 +71,11 @@ bool GPUKernelRuntime::SyncStream() {
bool GPUKernelRuntime::Init() {
enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();

if (device_init_ == true) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id");
if (device_init_) {
if (!cur_thread_device_inited) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SetDevice(UintToInt(device_id_)), "Failed to set device id");
cur_thread_device_inited = true;
}
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
return true;
}


Loading…
Cancel
Save