Browse Source

!10393 added collective send and receive

From: @anancds
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
b96d4315dc
13 changed files with 316 additions and 152 deletions
  1. +137
    -15
      mindspore/ccsrc/ps/core/abstract_node.cc
  2. +39
    -17
      mindspore/ccsrc/ps/core/abstract_node.h
  3. +1
    -0
      mindspore/ccsrc/ps/core/protos/comm.proto
  4. +4
    -19
      mindspore/ccsrc/ps/core/scheduler_node.cc
  5. +1
    -0
      mindspore/ccsrc/ps/core/scheduler_node.h
  6. +23
    -27
      mindspore/ccsrc/ps/core/server_node.cc
  7. +3
    -2
      mindspore/ccsrc/ps/core/server_node.h
  8. +29
    -22
      mindspore/ccsrc/ps/core/tcp_client.cc
  9. +1
    -1
      mindspore/ccsrc/ps/core/tcp_client.h
  10. +31
    -23
      mindspore/ccsrc/ps/core/tcp_server.cc
  11. +1
    -1
      mindspore/ccsrc/ps/core/tcp_server.h
  12. +4
    -25
      mindspore/ccsrc/ps/core/worker_node.cc
  13. +42
    -0
      tests/ut/cpp/ps/core/abstract_node_test.cc

+ 137
- 15
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -53,13 +53,20 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
}

bool AbstractNode::BroadcastToServers(const std::string &message, const uint32_t &timeout) {
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) {
if (node_role != NodeRole::SERVER) {
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
}

uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);

for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
@@ -82,12 +89,14 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,

MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message);
return SendMessageSync(client, comm_message, timeout);
}

bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
@@ -106,6 +115,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
@@ -118,8 +129,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
}

bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
CommMessage *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}
@@ -129,7 +140,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
*comm_message_resp = res[rank_id];
*output = res[rank_id];
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
});
@@ -149,9 +160,9 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
}

bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp,
const std::vector<std::string> &data, std::vector<CommMessage> *output,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(comm_message_resp);
MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);

@@ -165,7 +176,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
receive_messages_mutex_.lock();
auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
(*comm_message_resp).push_back(res[rank_ids.at(it)]);
(*output).push_back(res[rank_ids.at(it)]);
}
receive_messages_.erase(request_id);
receive_messages_mutex_.unlock();
@@ -179,6 +190,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
@@ -200,6 +213,58 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
return res;
}

uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const std::string &message, const uint32_t &timeout) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}

MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(message);
auto client = GetOrCreateTcpClient(rank_id);
return SendMessageAsync(client, comm_message);
}

std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, CommMessage *output) {
if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
}

uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
received_data_.erase(std::make_pair(rank_id, rank_request_id));
} else {
set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, 1)];
received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_callbacks_mutex_.unlock();
});
}
return std::make_pair(rank_id, rank_request_id);
}

bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (actual_rank_request_ids_.count(request_id.first) &&
(actual_rank_request_ids_[request_id.first] >= request_id.second)) {
return true;
} else {
return false;
}
});
return res;
}

void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) {
MS_LOG(INFO) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
@@ -210,7 +275,6 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
std::this_thread::sleep_for(std::chrono::seconds(ClusterConfig::heartbeat_interval()));
}
});
heart_beat_thread_->detach();
}

void AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
@@ -334,11 +398,9 @@ bool AbstractNode::InitClientToScheduler() {
MS_LOG(INFO) << "The worker node start a tcp client!";
client_to_scheduler_->Start();
});
client_to_scheduler_thread_->detach();

client_to_scheduler_->set_disconnected_callback([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterConfig::connect_interval()));
client_to_scheduler_->Stop();
client_to_scheduler_->Init();
});
return client_to_scheduler_->WaitConnected();
@@ -361,6 +423,9 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
@@ -381,10 +446,12 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
return Wait(request_id, timeout);
}

void AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
return request_id;
}

void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
@@ -422,12 +489,12 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.unlock();
}

void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback) {
if (!message_callback) {
void AbstractNode::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
if (!callback) {
return;
}
std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
message_callbacks_[request_id] = message_callback;
message_callbacks_[request_id] = callback;
}

void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
@@ -438,6 +505,61 @@ void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
message_tracker_[request_id].second++;
message_tracker_cond_.notify_all();
}

void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id,
const MessageCallback &callback) {
if (!callback) {
return;
}
std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
receive_callbacks_[std::make_pair(rank_id, request_id)] = callback;
}

void AbstractNode::RunReceiveCallback(const CommMessage &message) {
receive_callbacks_mutex_.lock();
uint32_t rank_id = message.pb_meta().rank_id();
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
// If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id);
received_data_[std::make_pair(rank_id, rank_request_id)] = message;
auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
if (it != receive_callbacks_.end()) {
receive_callbacks_mutex_.unlock();

if (it->second) {
it->second();
}

receive_callbacks_mutex_.lock();
receive_cond_.notify_all();
receive_callbacks_.erase(it);
}
receive_callbacks_mutex_.unlock();
}

uint64_t AbstractNode::NextExpectedRankRequestId(const uint32_t &rank_id) {
std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
uint64_t rank_request_id = 1;
if (expected_rank_request_ids_.count(rank_id)) {
rank_request_id = ++expected_rank_request_ids_[rank_id];
expected_rank_request_ids_[rank_id] = rank_request_id;
} else {
expected_rank_request_ids_[rank_id] = rank_request_id;
}
return rank_request_id;
}

uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) {
std::lock_guard<std::mutex> lock(rank_request_ids_mutex);
uint64_t rank_request_id = 1;
if (actual_rank_request_ids_.count(rank_id)) {
rank_request_id = ++actual_rank_request_ids_[rank_id];
actual_rank_request_ids_[rank_id] = rank_request_id;
} else {
actual_rank_request_ids_[rank_id] = rank_request_id;
}
return rank_request_id;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 39
- 17
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -34,21 +34,26 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default;

bool BroadcastToServers(const std::string &message, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Broadcast(const enum NodeRole &node_role, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message);

virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds);
virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp,
const uint32_t &timeout = kCommTimeoutInSeconds);

bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, CommMessage *output,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<std::string> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);

protected:
void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
@@ -63,34 +68,51 @@ class AbstractNode : public Node {
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
void SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void ProcessSendDataResp(const CommMessage &message);
void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &message_callback);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(const CommMessage &message);
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
void RunReceiveCallback(const CommMessage &message);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id);

std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_;

OnNodeEventMessage on_node_event_message_;
// the map's key is: <node_role,rank_id>, the map's value is: <ip, port>
// the key is: <node_role,rank_id>, the value is: <ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
std::mutex client_mutex_;
// the map's key is: rank_id
std::unordered_map<int, std::shared_ptr<TcpClient>> connected_nodes_;

// the map's key is: request_id, the map's value is: <expected responses, actual responses>
// the key is: request_id, the value is: <expected responses, actual responses>
std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> message_tracker_;
std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_;

// the map's key is: request_id, the map's value is:<rank_id, CommMessage>
// the key is: request_id, the value is:<rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
std::mutex receive_messages_mutex_;
// the map's key is: request_id
// the key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_;

// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
std::mutex receive_callbacks_mutex_;
// the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;
std::condition_variable receive_cond_;

// the key is rank_id, the value is rank_id's expected request_id
std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
// the key is rank_id, the value is rank_id's actual request_id
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex;
};
} // namespace core
} // namespace ps


+ 1
- 0
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -26,6 +26,7 @@ enum NodeCommand {
SEND_DATA = 3;
FETCH_SERVER = 4;
FINISH = 5;
COLLECTIVE_SEND_DATA = 6;
}

enum NodeRole {


+ 4
- 19
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -19,19 +19,10 @@
namespace mindspore {
namespace ps {
namespace core {

SchedulerNode::~SchedulerNode() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
is_ready_ = true;
}
Stop();
}

bool SchedulerNode::Start(const uint32_t &timeout) {
@@ -114,7 +105,6 @@ void SchedulerNode::CreateTcpServer() {
MS_LOG(INFO) << "The scheduler node start a tcp server!";
server_->Start();
});
scheduler_thread_->detach();
}

void SchedulerNode::ProcessRegister(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
@@ -186,20 +176,15 @@ void SchedulerNode::StartUpdateClusterStateTimer() {
}
}
});
update_state_thread_->detach();
}

bool SchedulerNode::Stop() {
MS_LOG(INFO) << "Stop scheduler node!";
if (!is_already_stopped_) {
is_already_stopped_ = true;
update_state_thread_->join();
server_->Stop();
if (scheduler_thread_->joinable()) {
scheduler_thread_->join();
}
if (update_state_thread_->joinable()) {
update_state_thread_->join();
}
scheduler_thread_->join();
is_ready_ = true;
}
return true;


+ 1
- 0
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -38,6 +38,7 @@
namespace mindspore {
namespace ps {
namespace core {

class SchedulerNode : public Node {
public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}


+ 23
- 27
mindspore/ccsrc/ps/core/server_node.cc View File

@@ -20,18 +20,7 @@ namespace ps {
namespace core {
ServerNode::~ServerNode() {
MS_LOG(INFO) << "Stop server node!";
if (!is_already_stopped_.load()) {
server_->Stop();
client_to_scheduler_->Stop();
client_to_scheduler_->StopEventBase();
if (server_thread_->joinable()) {
server_thread_->join();
}
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
is_already_stopped_ = true;
}
Stop();
}

bool ServerNode::Start(const uint32_t &timeout) {
@@ -78,6 +67,10 @@ void ServerNode::CreateTcpServer() {
case NodeCommand::SEND_DATA:
ProcessSendData(server, conn, message);
break;
case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(server, conn, message);
RunReceiveCallback(message);
break;
default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
@@ -87,7 +80,6 @@ void ServerNode::CreateTcpServer() {
MS_LOG(INFO) << "The server node start a tcp server!";
server_->Start();
});
server_thread_->detach();
}

void ServerNode::Initialize() {
@@ -106,27 +98,31 @@ void ServerNode::Initialize() {
}

void ServerNode::ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message) {
if (request_handler_) {
request_handler_(server, conn, message.pb_meta(), message.data());
}
request_handler_(server, conn, message.pb_meta(), message.data());
}

void ServerNode::ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn,
const CommMessage &message) {
CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message.pb_meta()};
const_cast<TcpServer &>(server).SendMessage(conn, comm_message);
}

bool ServerNode::Stop() {
MS_LOG(INFO) << "Stop server node!";
if (!is_already_stopped_.load()) {
server_->Stop();
is_already_stopped_ = true;
is_finish_ = true;
heart_beat_thread_->join();
client_to_scheduler_->Stop();
client_to_scheduler_->StopEventBase();
if (server_thread_->joinable()) {
server_thread_->join();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
client_to_scheduler_thread_->join();
server_->Stop();
server_thread_->join();
}
return true;
}


+ 3
- 2
mindspore/ccsrc/ps/core/server_node.h View File

@@ -44,8 +44,8 @@ class ServerNode : public AbstractNode {
bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;

using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn,
const MessageMeta message_meta, const std::string &message)>;
using RequestHandler = std::function<void(const TcpServer &server, const TcpConnection &conn, const MessageMeta meta,
const std::string &message)>;

void set_handler(const RequestHandler &handler);
void Response(const TcpServer &server, const TcpConnection &conn, const MessageMeta &message_meta,
@@ -55,6 +55,7 @@ class ServerNode : public AbstractNode {
void CreateTcpServer();
void Initialize();
void ProcessSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);
void ProcessCollectiveSendData(const TcpServer &server, const TcpConnection &conn, const CommMessage &message);

std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_;


+ 29
- 22
mindspore/ccsrc/ps/core/tcp_client.cc View File

@@ -51,7 +51,20 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
});
}

TcpClient::~TcpClient() { Stop(); }
TcpClient::~TcpClient() {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
if (event_base_) {
event_base_free(event_base_);
event_base_ = nullptr;
}
}

std::string TcpClient::GetServerAddress() const { return server_address_; }

@@ -69,9 +82,9 @@ bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
void TcpClient::Init() {
std::lock_guard<std::mutex> lock(connection_mutex_);
if (buffer_event_) {
return;
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
is_stop_ = false;
if (!CommUtil::CheckIp(server_address_)) {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
}
@@ -82,8 +95,9 @@ void TcpClient::Init() {
}
if (event_base_ == nullptr) {
event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_);
is_stop_ = false;
}
MS_EXCEPTION_IF_NULL(event_base_);

sockaddr_in sin{};
if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
@@ -127,26 +141,18 @@ void TcpClient::StartWithDelay(int seconds) {

void TcpClient::Stop() {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp client event buffer!";
if (!is_stop_.load()) {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}

if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
MS_LOG(INFO) << "Stop tcp client!";
if (event_base_got_break(event_base_)) {
MS_LOG(DEBUG) << "The event base has stopped!";
is_stop_ = true;
return;
}
}

void TcpClient::StopEventBase() {
MS_LOG(INFO) << "Stop tcp client event base!";
int ret = event_base_loopbreak(event_base_);
if (ret != 0) {
MS_LOG(ERROR) << "Event base loop break failed!";
if (!is_stop_.load()) {
is_stop_ = true;
int ret = event_base_loopbreak(event_base_);
if (ret != 0) {
MS_LOG(ERROR) << "Event base loop break failed!";
}
}
}

@@ -280,6 +286,7 @@ void TcpClient::StartTimer(const uint32_t &time) {
void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }

const event_base &TcpClient::eventbase() { return *event_base_; }

} // namespace core
} // namespace ps
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/ps/core/tcp_client.h View File

@@ -58,7 +58,6 @@ class TcpClient {
void Init();
void StartWithDelay(int seconds);
void Stop();
static void StopEventBase();
void Start();
void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb);
@@ -97,6 +96,7 @@ class TcpClient {
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
};

} // namespace core
} // namespace ps
} // namespace mindspore


+ 31
- 23
mindspore/ccsrc/ps/core/tcp_server.cc View File

@@ -32,6 +32,7 @@
namespace mindspore {
namespace ps {
namespace core {

void TcpConnection::InitConnection() {
tcp_message_handler_.SetCallback([&](const CommMessage &message) {
OnServerReceiveMessage on_server_receive = server_->GetServerReceive();
@@ -76,7 +77,22 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port)
server_port_(port),
is_stop_(true) {}

TcpServer::~TcpServer() { Stop(); }
TcpServer::~TcpServer() {
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}

if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
}

if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
}

void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept) {
@@ -136,7 +152,6 @@ void TcpServer::Init() {
}

void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_dispatch(base_);
@@ -148,7 +163,7 @@ void TcpServer::Start() {
}

void TcpServer::StartWithNoBlock() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server with no block!";
MS_EXCEPTION_IF_NULL(base_);
int ret = event_base_loop(base_, EVLOOP_NONBLOCK);
@@ -187,33 +202,25 @@ void TcpServer::StartTimer(const uint32_t &time) {
}

void TcpServer::Stop() {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp server!";
if (event_base_got_break(base_)) {
MS_LOG(DEBUG) << "The event base has stopped!";
is_stop_ = true;
return;
}
if (!is_stop_.load()) {
is_stop_ = true;
int ret = event_base_loopbreak(base_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "event base loop break failed!";
}
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}

if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
MS_LOG(ERROR) << "Event base loop break failed!";
}

if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
is_stop_ = true;
}
}

void TcpServer::SendToAllClients(const char *data, size_t len) {
MS_EXCEPTION_IF_NULL(data);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
for (auto it = connections_.begin(); it != connections_.end(); ++it) {
it->second->SendMessage(data, len);
}
@@ -221,12 +228,12 @@ void TcpServer::SendToAllClients(const char *data, size_t len) {

void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
MS_EXCEPTION_IF_NULL(connection);
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
connections_.insert(std::make_pair(fd, connection));
}

void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);
TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second);
delete connection;
connections_.erase(fd);
@@ -352,7 +359,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }

void TcpServer::SendMessage(const CommMessage &message) {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
std::lock_guard<std::mutex> lock(connection_mutex_);

for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(*it->second, message);
@@ -368,6 +375,7 @@ int TcpServer::ConnectionNum() const { return connections_.size(); }
const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }

void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }

} // namespace core
} // namespace ps
} // namespace mindspore

+ 1
- 1
mindspore/ccsrc/ps/core/tcp_server.h View File

@@ -121,7 +121,7 @@ class TcpServer {
OnConnected client_connection_;
OnDisconnected client_disconnection_;
OnAccepted client_accept_;
std::recursive_mutex connection_mutex_;
std::mutex connection_mutex_;
OnServerReceiveMessage message_callback_;
OnTimerOnce on_timer_once_callback_;
OnTimer on_timer_callback_;


+ 4
- 25
mindspore/ccsrc/ps/core/worker_node.cc View File

@@ -21,24 +21,7 @@ namespace ps {
namespace core {
WorkerNode::~WorkerNode() {
MS_LOG(INFO) << "Stop worker node!";
if (!is_already_stopped_.load()) {
is_ready_ = true;
is_timeout_ = true;
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_->StopEventBase();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
is_already_stopped_ = true;
}
Stop();
}
bool WorkerNode::Start(const uint32_t &timeout) {
MS_LOG(INFO) << "Starting worker node!";
@@ -78,19 +61,15 @@ bool WorkerNode::Stop() {
if (!is_already_stopped_.load()) {
is_ready_ = true;
is_timeout_ = true;
is_finish_ = true;
heart_beat_thread_->join();
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
client_to_scheduler_->StopEventBase();
if (client_to_scheduler_thread_->joinable()) {
client_to_scheduler_thread_->join();
}
if (heart_beat_thread_->joinable()) {
heart_beat_thread_->join();
}
client_to_scheduler_thread_->join();
is_already_stopped_ = true;
}
return true;


+ 42
- 0
tests/ut/cpp/ps/core/abstract_node_test.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/common_test.h"
#define protected public
#include "ps/core/worker_node.h"
#undef protected

namespace mindspore {
namespace ps {
namespace core {
class TestAbstractNode : public UT::Common {
public:
TestAbstractNode() = default;
virtual ~TestAbstractNode() = default;

void SetUp() override {}
void TearDown() override {}
};

TEST_F(TestAbstractNode, NextExpectedRankRequestId) {
WorkerNode workerNode;
ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(0));
ASSERT_EQ(2, workerNode.NextExpectedRankRequestId(0));
ASSERT_EQ(1, workerNode.NextExpectedRankRequestId(1));
}
} // namespace core
} // namespace ps
} // namespace mindspore

Loading…
Cancel
Save