Browse Source

Fix ps cache broadcast error.

tags/v1.6.0
ZPaC 4 years ago
parent
commit
4c1ef4cef6
4 changed files with 19 additions and 3 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc
  2. +12
    -3
      mindspore/ccsrc/ps/core/abstract_node.cc
  3. +2
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
  4. +3
    -0
      mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc

+ 2
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -118,10 +118,12 @@ void GPUSession::Init(uint32_t device_id) {
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
if (collective_inited) {
if (collective_handle_ != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
rank_id_ = GetRankId();
}
}


+ 12
- 3
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -69,9 +69,18 @@ bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message,
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
}

uint64_t request_id = AddMessageTrack(nodes_address_.size());
uint32_t broadcast_size = 0;
std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
if (addr.first.first == node_role) {
++broadcast_size;
}
});
uint64_t request_id = AddMessageTrack(broadcast_size);

for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
if (it->first.first != node_role) {
continue;
}
auto message_meta = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::SEND_DATA);
@@ -626,7 +635,7 @@ void AbstractNode::ProcessFetchServersResp(const std::shared_ptr<MessageMeta> &m

nodes_address_.clear();
for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port();
}
}
@@ -862,7 +871,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
return connected_nodes_[key];
} else {
if (nodes_address_.find(key) == nodes_address_.end()) {
MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed!";
MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed. Role: " << role << ", rank: " << rank_id;
}
if (config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";


+ 2
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -98,10 +98,12 @@ bool GPUKernelRuntime::Init() {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
}
device_init_ = true;



+ 3
- 0
mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc View File

@@ -79,6 +79,7 @@ void GPUDeviceContext::Initialize() {
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_context_key_.device_id_);
}

MS_LOG(INFO) << "Set GPU device id index " << device_context_key_.device_id_;
// Set device id and initialize device resource.
if (!InitDevice()) {
MS_LOG(EXCEPTION) << "GPU InitDevice failed.";
@@ -91,10 +92,12 @@ void GPUDeviceContext::Initialize() {

// Initialize NCCL.
if (collective_inited && collective_handle_ != nullptr) {
MS_LOG(INFO) << "Start initializing NCCL communicator for device " << device_context_key_.device_id_;
auto init_nccl_comm_funcptr =
reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
(*init_nccl_comm_funcptr)();
MS_LOG(INFO) << "End initializing NCCL communicator.";
}

#ifndef ENABLE_SECURITY


Loading…
Cancel
Save