|
|
|
@@ -53,6 +53,8 @@ |
|
|
|
#include "runtime/device/gpu/gpu_stream_assign.h" |
|
|
|
#include "runtime/device/gpu/kernel_info_setter.h" |
|
|
|
#include "runtime/device/kernel_runtime_manager.h" |
|
|
|
#include "runtime/device/gpu/cuda_driver.h" |
|
|
|
#include "runtime/device/gpu/distribution/collective_init.h" |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "utils/config_manager.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
@@ -64,6 +66,25 @@ namespace mindspore { |
|
|
|
namespace session { |
|
|
|
namespace gpu { |
|
|
|
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; |
|
|
|
using CollectiveInitializer = device::gpu::CollectiveInitializer; |
|
|
|
using GetLocalRankId = device::gpu::GetLocalRankId; |
|
|
|
|
|
|
|
void GPUSession::Init(uint32_t device_id) { |
|
|
|
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); |
|
|
|
bool collective_inited = CollectiveInitializer::instance().collective_inited(); |
|
|
|
if (collective_inited && collective_handle_ != nullptr) { |
|
|
|
auto get_local_rank_funcptr = |
|
|
|
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id")); |
|
|
|
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); |
|
|
|
device_id = IntToUint((*get_local_rank_funcptr)()); |
|
|
|
} |
|
|
|
bool ret = device::gpu::CudaDriver::set_current_device(UintToInt(device_id)); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(EXCEPTION) << "GPUSession failed to set current device id."; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Set device id " << device_id << " for gpu session."; |
|
|
|
InitDevice(kGPUDevice, device_id); |
|
|
|
} |
|
|
|
|
|
|
|
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
|