|
|
|
@@ -238,11 +238,16 @@ int CudaDriver::device_count() { |
|
|
|
return dev_count; |
|
|
|
} |
|
|
|
|
|
|
|
bool CudaDriver::set_current_device(int index) { |
|
|
|
bool CudaDriver::SetDevice(int index) { |
|
|
|
auto ret = cudaSetDevice(index); |
|
|
|
if (ret != cudaSuccess) { |
|
|
|
MS_LOG(ERROR) << "cudaSetDevice " << index << " failed, ret[" << static_cast<int>(ret) << "], " |
|
|
|
<< cudaGetErrorString(ret); |
|
|
|
MS_LOG(ERROR) |
|
|
|
<< "SetDevice for id:" << index << " failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret) |
|
|
|
<< ". Please make sure that the 'device_id' set in context is in the range:[0, total number of GPU). " |
|
|
|
"If the environment variable 'CUDA_VISIBLE_DEVICES' is set, the total number of GPU will be the number set " |
|
|
|
"in the environment variable 'CUDA_VISIBLE_DEVICES'. For example, if export CUDA_VISIBLE_DEVICES=4,5,6, the " |
|
|
|
"'device_id' can be 0,1,2 at the moment, 'device_id' starts from 0, and 'device_id'=0 means using GPU of " |
|
|
|
"number 4."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
|