diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 8429f7c87d..ae38283d08 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -53,6 +53,8 @@ #include "runtime/device/gpu/gpu_stream_assign.h" #include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "runtime/device/gpu/distribution/collective_init.h" #include "utils/ms_utils.h" #include "utils/config_manager.h" #include "utils/ms_context.h" @@ -64,6 +66,25 @@ namespace mindspore { namespace session { namespace gpu { using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; +using CollectiveInitializer = device::gpu::CollectiveInitializer; +using GetLocalRankId = device::gpu::GetLocalRankId; + +void GPUSession::Init(uint32_t device_id) { + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + bool collective_inited = CollectiveInitializer::instance().collective_inited(); + if (collective_inited && collective_handle_ != nullptr) { + auto get_local_rank_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); + MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); + device_id = IntToUint((*get_local_rank_funcptr)()); + } + bool ret = device::gpu::CudaDriver::set_current_device(UintToInt(device_id)); + if (!ret) { + MS_LOG(EXCEPTION) << "GPUSession failed to set current device id."; + } + MS_LOG(INFO) << "Set device id " << device_id << " for gpu session."; + InitDevice(kGPUDevice, device_id); +} void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index ebd3598cdb..27de7c48ac 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -31,7 +31,7 @@ class GPUSession : public SessionBasic { public: GPUSession() = default; ~GPUSession() override = default; - void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); } + void Init(uint32_t device_id) override; protected: GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index c854362a35..4cfeb00a75 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -187,14 +187,6 @@ bool GPUKernelRuntime::InitDevice() { MS_LOG(ERROR) << "No GPU device found."; return false; } - const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); - bool collective_inited = CollectiveInitializer::instance().collective_inited(); - if (collective_inited && collective_handle_ != nullptr) { - auto get_local_rank_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); - MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); - device_id_ = IntToUint((*get_local_rank_funcptr)()); - } if (!GPUDeviceManager::GetInstance().is_device_id_init()) { if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) { MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_);