Browse Source

Custom data transmission format

tags/v1.2.0-rc1
chendongsheng 4 years ago
parent
commit
c7fe82b43d
17 changed files with 622 additions and 381 deletions
  1. +206
    -162
      mindspore/ccsrc/ps/core/abstract_node.cc
  2. +35
    -23
      mindspore/ccsrc/ps/core/abstract_node.h
  3. +59
    -0
      mindspore/ccsrc/ps/core/message.h
  4. +6
    -3
      mindspore/ccsrc/ps/core/protos/comm.proto
  5. +32
    -25
      mindspore/ccsrc/ps/core/scheduler_node.cc
  6. +6
    -5
      mindspore/ccsrc/ps/core/scheduler_node.h
  7. +29
    -22
      mindspore/ccsrc/ps/core/server_node.cc
  8. +8
    -4
      mindspore/ccsrc/ps/core/server_node.h
  9. +50
    -14
      mindspore/ccsrc/ps/core/tcp_client.cc
  10. +5
    -4
      mindspore/ccsrc/ps/core/tcp_client.h
  11. +11
    -5
      mindspore/ccsrc/ps/core/tcp_message_handler.cc
  12. +7
    -10
      mindspore/ccsrc/ps/core/tcp_message_handler.h
  13. +43
    -8
      mindspore/ccsrc/ps/core/tcp_server.cc
  14. +5
    -2
      mindspore/ccsrc/ps/core/tcp_server.h
  15. +12
    -2
      tests/ut/cpp/ps/core/tcp_client_tests.cc
  16. +96
    -82
      tests/ut/cpp/ps/core/tcp_message_handler_test.cc
  17. +12
    -10
      tests/ut/cpp/ps/core/tcp_pb_server_test.cc

+ 206
- 162
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -20,8 +20,9 @@ namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::REGISTER);
MS_EXCEPTION_IF_NULL(client);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::REGISTER);


RegisterMessage register_message; RegisterMessage register_message;
register_message.set_node_id(node_info_.node_id_); register_message.set_node_id(node_info_.node_id_);
@@ -29,11 +30,8 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
register_message.set_ip(node_info_.ip_); register_message.set_ip(node_info_.ip_);
register_message.set_port(node_info_.port_); register_message.set_port(node_info_.port_);


CommMessage comm_message;
*comm_message.mutable_pb_meta() = {message_meta};
comm_message.set_data(register_message.SerializeAsString());
comm_message.set_user_cmd("");
if (!SendMessageSync(client, comm_message)) {
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(),
register_message.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " register timeout!"; << " the node id:" << node_info_.node_id_ << " register timeout!";
} }
@@ -42,9 +40,11 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) {
<< " the node id:" << node_info_.node_id_ << "is registering to scheduler!"; << " the node id:" << node_info_.node_id_ << "is registering to scheduler!";
} }


void AbstractNode::ProcessRegisterResp(const CommMessage &message) {
void AbstractNode::ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
RegisterRespMessage register_resp_message; RegisterRespMessage register_resp_message;
register_resp_message.ParseFromString(message.data());
register_resp_message.ParseFromArray(data, size);
if (register_resp_message.node_id() != node_info_.node_id_) { if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_; << " is not match the current node id:" << node_info_.node_id_;
@@ -52,28 +52,29 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) {


node_info_.rank_id_ = register_resp_message.rank_id(); node_info_.rank_id_ = register_resp_message.rank_id();


MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_;
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_
<< " registered scheduler success!";
} }


bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) {
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
if (node_role != NodeRole::SERVER) { if (node_role != NodeRole::SERVER) {
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes";
} }


CommMessage &comm_message = const_cast<CommMessage &>(message);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0);
uint64_t request_id = AddMessageTrack(nodes_address_.size());


for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);


*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient((*it).first.second); auto client = GetOrCreateTcpClient((*it).first.second);
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, message.get(), size);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
@@ -84,28 +85,27 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me
on_node_event_message_ = on_node_event_message; on_node_event_message_ = on_node_event_message;
} }


bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
const uint32_t &timeout) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


CommMessage &comm_message = const_cast<CommMessage &>(message);

MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);


*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id); auto client = GetOrCreateTcpClient(rank_id);
return SendMessageSync(client, comm_message, timeout);
return SendMessageSync(client, message_meta, Protos::RAW, data.get(), len, timeout);
} }


bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<CommMessage> &data, const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
const uint32_t &timeout) {
uint64_t request_id = AddMessageTrack(data.size());


if (rank_ids.size() != data.size()) { if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!"; MS_LOG(EXCEPTION) << "The number of rank ids is not equal to the number of data!";
@@ -115,34 +115,32 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);

CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);


auto send = data.at(it);
auto len = lens.at(it);
auto client = GetOrCreateTcpClient(rank_ids.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }


bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
CommMessage *output, const uint32_t &timeout) {
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
int command, VectorPtr *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(message);
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


CommMessage &comm_message = const_cast<CommMessage &>(message);

uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
uint64_t request_id = AddMessageTrack(1);
set_message_callback(request_id, [&]() { set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock(); receive_messages_mutex_.lock();
auto res = receive_messages_[request_id]; auto res = receive_messages_[request_id];
@@ -151,59 +149,59 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id,
receive_messages_mutex_.unlock(); receive_messages_mutex_.unlock();
}); });


MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);


*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id); auto client = GetOrCreateTcpClient(rank_id);
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, message.get(), len);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }


bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<CommMessage> &data, std::vector<CommMessage> *output,
const uint32_t &timeout) {
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
std::vector<VectorPtr> *output, const uint32_t &timeout) {
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(data.size(), 0);
uint64_t request_id = AddMessageTrack(data.size());


if (rank_ids.size() != data.size()) { if (rank_ids.size() != data.size()) {
MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!";
} }


size_t len = rank_ids.size();
size_t size = rank_ids.size();


set_message_callback(request_id, [&]() { set_message_callback(request_id, [&]() {
receive_messages_mutex_.lock(); receive_messages_mutex_.lock();
auto res = receive_messages_[request_id]; auto res = receive_messages_[request_id];
for (size_t it = 0; it < len; ++it) {
for (size_t it = 0; it < size; ++it) {
(*output).push_back(res[rank_ids.at(it)]); (*output).push_back(res[rank_ids.at(it)]);
} }
receive_messages_.erase(request_id); receive_messages_.erase(request_id);
receive_messages_mutex_.unlock(); receive_messages_mutex_.unlock();
}); });


for (size_t it = 0; it < len; ++it) {
for (size_t it = 0; it < size; ++it) {
if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) { if (!CommUtil::ValidateRankId(node_role, rank_ids.at(it))) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::SEND_DATA);
message_meta.set_request_id(request_id);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);
message_meta->set_request_id(request_id);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);
message_meta->set_user_cmd(command);


CommMessage &comm_message = const_cast<CommMessage &>(data.at(it));
*comm_message.mutable_pb_meta() = {message_meta};
auto send = data.at(it);
auto len = data_lens.at(it);


auto client = GetOrCreateTcpClient(rank_ids.at(it)); auto client = GetOrCreateTcpClient(rank_ids.at(it));
client->SendMessage(comm_message);
client->SendMessage(message_meta, Protos::RAW, send.get(), len);
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
@@ -220,55 +218,61 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) {
return res; return res;
} }


uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
const CommMessage &message) {
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


CommMessage &comm_message = const_cast<CommMessage &>(message);

MessageMeta message_meta;
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta.set_rank_id(node_info_.rank_id_);
message_meta.set_role(node_info_.node_role_);
std::shared_ptr<MessageMeta> message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::COLLECTIVE_SEND_DATA);
message_meta->set_rank_id(node_info_.rank_id_);
message_meta->set_role(node_info_.node_role_);


*comm_message.mutable_pb_meta() = {message_meta};
auto client = GetOrCreateTcpClient(rank_id); auto client = GetOrCreateTcpClient(rank_id);
return SendMessageAsync(client, comm_message);
return SendMessageAsync(client, message_meta, Protos::RAW, data, size);
} }


std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role,
const uint32_t &rank_id, CommMessage *output) {
const uint32_t &rank_id, void **output,
size_t *size) {
MS_EXCEPTION_IF_NULL(output);
MS_EXCEPTION_IF_NULL(size);
if (!CommUtil::ValidateRankId(node_role, rank_id)) { if (!CommUtil::ValidateRankId(node_role, rank_id)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!";
} }


receive_callbacks_mutex_.lock();
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); uint64_t rank_request_id = NextExpectedRankRequestId(rank_id);
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = false;
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) {
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = res->data();
*size = res->size();
received_data_.erase(std::make_pair(rank_id, rank_request_id)); received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
} else { } else {
set_receive_callback(rank_id, rank_request_id, [=]() {
receive_callbacks_[std::make_pair(rank_id, rank_request_id)] = [=]() mutable {
receive_callbacks_mutex_.lock(); receive_callbacks_mutex_.lock();
*output = received_data_[std::make_pair(rank_id, rank_request_id)];
auto res = received_data_[std::make_pair(rank_id, rank_request_id)];
*output = res->data();
*size = res->size();
received_data_.erase(std::make_pair(rank_id, rank_request_id)); received_data_.erase(std::make_pair(rank_id, rank_request_id));
receive_messages_done_[std::make_pair(rank_id, rank_request_id)] = true;
MS_LOG(DEBUG) << "Receive data from rank id:" << rank_id << ", the rank request id is:" << rank_request_id;
receive_callbacks_mutex_.unlock(); receive_callbacks_mutex_.unlock();
});
};
} }
receive_callbacks_mutex_.unlock();
return std::make_pair(rank_id, rank_request_id); return std::make_pair(rank_id, rank_request_id);
} }


bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) { bool AbstractNode::CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout) {
std::unique_lock<std::mutex> lock(receive_callbacks_mutex_); std::unique_lock<std::mutex> lock(receive_callbacks_mutex_);
bool res = receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
if (actual_rank_request_ids_.count(request_id.first) &&
(actual_rank_request_ids_[request_id.first] >= request_id.second)) {
return true;
} else {
return false;
}
});
bool res =
receive_cond_.wait_for(lock, std::chrono::seconds(timeout), [&] { return receive_messages_done_[request_id]; });
return res; return res;
} }


@@ -297,17 +301,15 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
} }


bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
MessageMeta meta;
meta.set_cmd(NodeCommand::HEARTBEAT);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::HEARTBEAT);


HeartbeatMessage heartbeat_message; HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_); heartbeat_message.set_node_id(node_info_.node_id_);
heartbeat_message.set_is_node_finish(is_node_finish); heartbeat_message.set_is_node_finish(is_node_finish);


CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(heartbeat_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong())) {
MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; MS_LOG(ERROR) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
} }
return true; return true;
@@ -331,9 +333,11 @@ bool AbstractNode::CheckSchedulerTimeout() const {
return false; return false;
} }


void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
void AbstractNode::ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatRespMessage heartbeat_resp_message; HeartbeatRespMessage heartbeat_resp_message;
heartbeat_resp_message.ParseFromString(message.data());
heartbeat_resp_message.ParseFromArray(data, size);


is_ready_ = heartbeat_resp_message.is_cluster_ready(); is_ready_ = heartbeat_resp_message.is_cluster_ready();
if (is_ready_.load()) { if (is_ready_.load()) {
@@ -359,19 +363,22 @@ void AbstractNode::ProcessHeartbeatResp(const CommMessage &message) {
} }


void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) { void AbstractNode::FetchServers(const std::shared_ptr<TcpClient> &client) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FETCH_SERVER);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FETCH_SERVER);


CommMessage message;
*message.mutable_pb_meta() = {meta};
if (!SendMessageSync(client, message)) {
FetchServersMessage fetch_servers;
fetch_servers.set_node_id(node_info_.node_id_);
if (!SendMessageSync(client, meta, Protos::PROTOBUF, fetch_servers.SerializeAsString().data(),
fetch_servers.ByteSizeLong())) {
MS_LOG(EXCEPTION) << "Fetch servers address timeout!"; MS_LOG(EXCEPTION) << "Fetch servers address timeout!";
} }
} }


void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
void AbstractNode::ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FetchServersRespMessage fetch_servers_resp_message; FetchServersRespMessage fetch_servers_resp_message;
fetch_servers_resp_message.ParseFromString(message.data());
fetch_servers_resp_message.ParseFromArray(data, size);


for (const auto &it : fetch_servers_resp_message.servers_meta()) { for (const auto &it : fetch_servers_resp_message.servers_meta()) {
nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port()); nodes_address_[std::make_pair(NodeRole::SERVER, it.rank_id())] = std::make_pair(it.ip(), it.port());
@@ -381,16 +388,14 @@ void AbstractNode::ProcessFetchServersResp(const CommMessage &message) {
} }


bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) { bool AbstractNode::Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout) {
MessageMeta meta;
meta.set_cmd(NodeCommand::FINISH);
auto meta = std::make_shared<MessageMeta>();
meta->set_cmd(NodeCommand::FINISH);


FinishMessage finish_message; FinishMessage finish_message;
finish_message.set_node_id(node_info_.node_id_); finish_message.set_node_id(node_info_.node_id_);


CommMessage message;
*message.mutable_pb_meta() = {meta};
message.set_data(finish_message.SerializeAsString());
if (!SendMessageSync(client, message)) {
if (!SendMessageSync(client, meta, Protos::PROTOBUF, finish_message.SerializeAsString().data(),
finish_message.ByteSizeLong())) {
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << " send Finish Message timeout!"; << " the node id:" << node_info_.node_id_ << " send Finish Message timeout!";
} }
@@ -412,16 +417,17 @@ bool AbstractNode::InitClientToScheduler() {
std::string scheduler_host = ClusterConfig::scheduler_host(); std::string scheduler_host = ClusterConfig::scheduler_host();
uint16_t scheduler_port = ClusterConfig::scheduler_port(); uint16_t scheduler_port = ClusterConfig::scheduler_port();
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port);
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
if (handlers_.count(message.pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
}
if (handlers_[message.pb_meta().cmd()] != nullptr) {
const auto &handler_ptr = handlers_[message.pb_meta().cmd()];
(this->*handler_ptr)(message);
}
NotifyMessageArrival(message);
});
client_to_scheduler_->SetMessageCallback(
[&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
}
if (handlers_[meta->cmd()] != nullptr) {
const auto &handler_ptr = handlers_[meta->cmd()];
(this->*handler_ptr)(meta, data, size);
}
NotifyMessageArrival(meta);
});


client_to_scheduler_->Init(); client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
@@ -447,19 +453,20 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port); auto client = std::make_shared<TcpClient>(ip, port);
client->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) {
switch (message.pb_meta().cmd()) {
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA: case NodeCommand::SEND_DATA:
ProcessSendDataResp(message);
RunMessageCallback(message.pb_meta().request_id());
ProcessSendDataResp(meta, protos, data, size);
RunMessageCallback(meta->request_id());
break; break;
case NodeCommand::COLLECTIVE_SEND_DATA: case NodeCommand::COLLECTIVE_SEND_DATA:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
break; break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
NotifyMessageArrival(message);
NotifyMessageArrival(meta);
}); });
client->Init(); client->Init();
connected_nodes_[rank_id] = client; connected_nodes_[rank_id] = client;
@@ -469,8 +476,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const int &


bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout) { const uint32_t &timeout) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
uint64_t request_id = AddMessageTrack(1);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id); const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message); client->SendMessage(message);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@@ -478,29 +484,55 @@ bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, con
return Wait(request_id, timeout); return Wait(request_id, timeout);
} }


uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message) {
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(1, 0);
const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
client->SendMessage(message);
uint64_t AbstractNode::SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(client);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
uint64_t request_id = AddMessageTrack(1);
meta->set_request_id(request_id);
client->SendMessage(meta, protos, data, size);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
return request_id; return request_id;
} }


void AbstractNode::ProcessSendDataResp(const CommMessage &message) {
// Sends one message through `client` and then blocks in Wait() on the request id assigned to it.
//
// client  - connection used to send; must not be null.
// meta    - message metadata; must not be null. Its request_id field is overwritten here.
// protos  - payload format tag, forwarded unchanged to TcpClient::SendMessage.
// data    - payload bytes; must not be null.
// size    - payload length, forwarded unchanged to TcpClient::SendMessage.
// timeout - forwarded unchanged to Wait().
//
// Returns whatever Wait(request_id, timeout) returns. NOTE(review): callers in this file log a
// "timeout" error on false, so false presumably means no response arrived in time; confirm in Wait().
bool AbstractNode::SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
                                   const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
  MS_EXCEPTION_IF_NULL(client);
  MS_EXCEPTION_IF_NULL(meta);
  MS_EXCEPTION_IF_NULL(data);

  // Obtain a fresh request id from the message tracker (registered with a count of 1) and stamp it
  // on the outgoing metadata before the message is sent.
  const uint64_t tracked_request_id = AddMessageTrack(1);
  meta->set_request_id(tracked_request_id);
  client->SendMessage(meta, protos, data, size);

  MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
                << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << tracked_request_id;

  return Wait(tracked_request_id, timeout);
}

void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::lock_guard<std::mutex> lock(receive_messages_mutex_); std::lock_guard<std::mutex> lock(receive_messages_mutex_);
const MessageMeta &message_meta = message.pb_meta();
const uint32_t &rank_id = message_meta.rank_id();
const uint64_t request_id = message_meta.request_id();
const uint32_t &rank_id = meta->rank_id();
const uint64_t request_id = meta->request_id();
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
auto it = receive_messages_.find(request_id); auto it = receive_messages_.find(request_id);
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
if (size > 0) {
int ret = memcpy_s(received_data.get()->data(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
}
if (it != receive_messages_.end()) { if (it != receive_messages_.end()) {
it->second[rank_id] = message;
it->second[rank_id] = received_data;
} else { } else {
std::unordered_map<uint32_t, CommMessage> res;
res.insert(std::make_pair(rank_id, message));
std::unordered_map<uint32_t, VectorPtr> res;
res.insert(std::make_pair(rank_id, received_data));
receive_messages_[request_id] = res; receive_messages_[request_id] = res;
} }
} }
@@ -509,7 +541,7 @@ void AbstractNode::RunMessageCallback(const uint64_t &request_id) {
message_callbacks_mutex_.lock(); message_callbacks_mutex_.lock();
// When receiving a message's response, Then compare with the desired number of responses, // When receiving a message's response, Then compare with the desired number of responses,
// If they are equal, then call the callback function // If they are equal, then call the callback function
if (message_tracker_[request_id].first == message_tracker_[request_id].second + 1) {
if (CheckMessageTrack(request_id)) {
auto it = message_callbacks_.find(request_id); auto it = message_callbacks_.find(request_id);
if (it != message_callbacks_.end()) { if (it != message_callbacks_.end()) {
message_callbacks_mutex_.unlock(); message_callbacks_mutex_.unlock();
@@ -533,31 +565,31 @@ void AbstractNode::set_message_callback(const uint64_t &request_id, const Messag
message_callbacks_[request_id] = callback; message_callbacks_[request_id] = callback;
} }


void AbstractNode::NotifyMessageArrival(const CommMessage &message) {
void AbstractNode::NotifyMessageArrival(std::shared_ptr<MessageMeta> meta) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_); std::lock_guard<std::mutex> lock(message_tracker_mutex_);
const MessageMeta &message_meta = message.pb_meta();
uint64_t request_id = message_meta.request_id();
uint64_t request_id = meta->request_id();


message_tracker_[request_id].second++; message_tracker_[request_id].second++;
message_tracker_cond_.notify_all(); message_tracker_cond_.notify_all();
} }


// Registers |callback| under the key <rank_id, request_id> in receive_callbacks_, replacing any callback already
// stored for that key. An empty callback is ignored and nothing is stored.
void AbstractNode::set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id,
                                        const MessageCallback &callback) {
  if (callback) {
    std::lock_guard<std::mutex> lock(receive_callbacks_mutex_);
    const auto callback_key = std::make_pair(rank_id, request_id);
    receive_callbacks_[callback_key] = callback;
  }
}

void AbstractNode::RunReceiveCallback(const CommMessage &message) {
void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
receive_callbacks_mutex_.lock(); receive_callbacks_mutex_.lock();
uint32_t rank_id = message.pb_meta().rank_id();
uint32_t rank_id = meta->rank_id();
// When receiving a collective message, Then generate rank request id,compare with the desired rank request id, // When receiving a collective message, Then generate rank request id,compare with the desired rank request id,
// If they are equal, then call the callback function // If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id); uint64_t rank_request_id = NextActualRankRequestId(rank_id);
received_data_[std::make_pair(rank_id, rank_request_id)] = message;
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
int ret = memcpy_s(received_data->data(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
received_data_[std::make_pair(rank_id, rank_request_id)] = received_data;
MS_LOG(DEBUG) << "Run Receive data callback,the rank id:" << rank_id << ", the rank request id is:" << rank_request_id
<< ", the send request id is:" << meta->request_id();
auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id)); auto it = receive_callbacks_.find(std::make_pair(rank_id, rank_request_id));
if (it != receive_callbacks_.end()) { if (it != receive_callbacks_.end()) {
receive_callbacks_mutex_.unlock(); receive_callbacks_mutex_.unlock();
@@ -603,6 +635,18 @@ void AbstractNode::InitCommandHandler() {
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp; handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp;
handlers_[NodeCommand::FINISH] = nullptr; handlers_[NodeCommand::FINISH] = nullptr;
} }

uint64_t AbstractNode::AddMessageTrack(const uint32_t &expected_response) {
std::lock_guard<std::mutex> lock(message_tracker_mutex_);
uint64_t request_id = ++next_request_id_;
message_tracker_[request_id] = std::make_pair(expected_response, 0);
return request_id;
}

// Returns true when, for |request_id|, the expected response count equals the received count plus one.
// Returns false for a request id that is not being tracked.
//
// The lookup uses find() rather than operator[] so that querying an unknown id does not insert an empty
// tracker entry, and so the map is searched only once.
// NOTE(review): the "+ 1" presumably accounts for this check running before NotifyMessageArrival increments
// the received count for the current response — confirm against the call order.
bool AbstractNode::CheckMessageTrack(const uint64_t &request_id) {
  std::lock_guard<std::mutex> lock(message_tracker_mutex_);
  const auto it = message_tracker_.find(request_id);
  if (it == message_tracker_.end()) {
    return false;
  }
  return it->second.first == it->second.second + 1;
}
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 35
- 23
mindspore/ccsrc/ps/core/abstract_node.h View File

@@ -25,6 +25,7 @@
#include <unordered_map> #include <unordered_map>


#include "ps/core/node.h" #include "ps/core/node.h"
#include "ps/core/message.h"


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@@ -34,53 +35,63 @@ class AbstractNode : public Node {
AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {} AbstractNode() : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr) {}
~AbstractNode() override = default; ~AbstractNode() override = default;


typedef void (AbstractNode::*ResponseHandler)(const CommMessage &message);
typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);


bool Broadcast(const enum NodeRole &node_role, const CommMessage &message,
using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;

bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
void set_event_callback(const OnNodeEventMessage &on_node_event_message); void set_event_callback(const OnNodeEventMessage &on_node_event_message);


bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, CommMessage *output,
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<CommMessage> &data,
std::vector<CommMessage> *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds);


uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message);
uint64_t CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id, std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const enum NodeRole &node_role, const uint32_t &rank_id,
CommMessage *output);
void **output, size_t *size);
bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds); bool CollectiveWait(std::pair<uint32_t, uint64_t> request_id, const uint32_t &timeout = kCommTimeoutInSeconds);


protected: protected:
void Register(const std::shared_ptr<TcpClient> &client); void Register(const std::shared_ptr<TcpClient> &client);
void ProcessRegisterResp(const CommMessage &message);
void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false); bool Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish = false);
void FetchServers(const std::shared_ptr<TcpClient> &client);

void ProcessRegisterResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessHeartbeatResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServersResp(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

void StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client);
void UpdateSchedulerTime(); void UpdateSchedulerTime();
bool CheckSchedulerTimeout() const; bool CheckSchedulerTimeout() const;
void ProcessHeartbeatResp(const CommMessage &message);
void FetchServers(const std::shared_ptr<TcpClient> &client);
void ProcessFetchServersResp(const CommMessage &message);
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler(); bool InitClientToScheduler();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id); const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const int &rank_id);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
const uint32_t &timeout = kCommTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, const CommMessage &message);
void ProcessSendDataResp(const CommMessage &message);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta>, const Protos &,
const void *, size_t size, const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t SendMessageAsync(const std::shared_ptr<TcpClient> &client, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size);
void ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void RunMessageCallback(const uint64_t &request_id); void RunMessageCallback(const uint64_t &request_id);
void set_message_callback(const uint64_t &request_id, const MessageCallback &callback); void set_message_callback(const uint64_t &request_id, const MessageCallback &callback);
void NotifyMessageArrival(const CommMessage &message);
void set_receive_callback(const uint32_t &rank_id, const uint64_t &request_id, const MessageCallback &callback);
void RunReceiveCallback(const CommMessage &message);
void NotifyMessageArrival(std::shared_ptr<MessageMeta> meta);
void RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
uint64_t NextExpectedRankRequestId(const uint32_t &rank_id); uint64_t NextExpectedRankRequestId(const uint32_t &rank_id);
uint64_t NextActualRankRequestId(const uint32_t &rank_id); uint64_t NextActualRankRequestId(const uint32_t &rank_id);
void InitCommandHandler(); void InitCommandHandler();
uint64_t AddMessageTrack(const uint32_t &expected_response);
bool CheckMessageTrack(const uint64_t &request_id);


std::unique_ptr<std::thread> heart_beat_thread_; std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_; std::unique_ptr<std::thread> client_to_scheduler_thread_;
@@ -98,15 +109,16 @@ class AbstractNode : public Node {
std::mutex message_tracker_mutex_; std::mutex message_tracker_mutex_;
std::condition_variable message_tracker_cond_; std::condition_variable message_tracker_cond_;


// the key is: request_id, the value is:<rank_id, CommMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, CommMessage>> receive_messages_;
// the key is: request_id, the value is: <rank_id, RecvMessage>
std::unordered_map<uint64_t, std::unordered_map<uint32_t, VectorPtr>> receive_messages_;
std::map<std::pair<uint32_t, uint64_t>, bool> receive_messages_done_;
std::mutex receive_messages_mutex_; std::mutex receive_messages_mutex_;
// the key is: request_id // the key is: request_id
std::unordered_map<uint64_t, MessageCallback> message_callbacks_; std::unordered_map<uint64_t, MessageCallback> message_callbacks_;
std::mutex message_callbacks_mutex_; std::mutex message_callbacks_mutex_;


// the key is <rank_id, rank_request_id> // the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, CommMessage> received_data_;
std::map<std::pair<uint32_t, uint64_t>, std::shared_ptr<std::vector<unsigned char>>> received_data_;
std::mutex receive_callbacks_mutex_; std::mutex receive_callbacks_mutex_;
// the key is <rank_id, rank_request_id> // the key is <rank_id, rank_request_id>
std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_; std::map<std::pair<uint32_t, uint64_t>, MessageCallback> receive_callbacks_;


+ 59
- 0
mindspore/ccsrc/ps/core/message.h View File

@@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_
#define MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_

#include <cstdint>
#include <memory>
#include <string>

namespace mindspore {
namespace ps {
namespace core {
// Identifies how the body of a message is encoded. Stored in MessageHeader::message_proto_.
enum class Protos : uint32_t {
  RAW = 0,         // plain bytes
  PROTOBUF = 1,    // protobuf-serialized body
  FLATBUFFERS = 2  // flatbuffers-serialized body
};

// Commands a message can carry. The numeric values are spelled out explicitly.
// NOTE(review): the names match the NodeCommand values used elsewhere in ps/core; whether the numeric values
// also match the protobuf enum should be confirmed against comm.proto.
enum class Command {
  TERMINATE = 0,
  REGISTER = 1,
  HEARTBEAT = 2,
  SEND_DATA = 3,
  FETCH_SERVER = 4,
  FINISH = 5,
  COLLECTIVE_SEND_DATA = 6,
};

// Role of a node: server, worker or scheduler.
enum class Role {
  SERVER = 0,
  WORKER = 1,
  SCHEDULER = 2
};

// Header describing one message. Every field has a default, so a default-constructed header is
// {RAW, 0, 0}.
struct MessageHeader {
  // encoding of the message body
  Protos message_proto_{Protos::RAW};
  // length of the message meta (presumably in bytes — TODO confirm against the sender)
  uint32_t message_meta_length_{0};
  // length of the message (NOTE(review): whether this includes the meta is not visible here — confirm)
  uint64_t message_length_{0};
};

// Plain-struct description of a message's meta information.
//
// Every member has a zero default, so a default-constructed CommandMeta is fully initialised. The previous
// default of rank_id was 4, which is the protobuf field number of MessageMeta.rank_id rather than a meaningful
// rank; the other members had no initializer at all.
struct CommandMeta {
  // the command of this message,for example: register,heartbeat,data
  Command cmd = Command::TERMINATE;
  // the request id of this message
  uint64_t request_id = 0;
  // the role of the current node: worker,server,scheduler
  Role role = Role::SERVER;
  // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
  int32_t rank_id = 0;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_MESSAGE_H_

+ 6
- 3
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -15,7 +15,6 @@
*/ */


syntax = "proto3"; syntax = "proto3";
import "google/protobuf/any.proto";
package mindspore.ps.core; package mindspore.ps.core;
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;


@@ -44,6 +43,8 @@ message MessageMeta {
NodeRole role = 3; NodeRole role = 3;
// the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1] // the current Node rank id,the worker node range is:[0,numOfWorker-1], the server node range is:[0, numOfServer-1]
int32 rank_id = 4; int32 rank_id = 4;
// User-defined commands
int32 user_cmd = 5;
} }


message RegisterMessage { message RegisterMessage {
@@ -76,6 +77,10 @@ message HeartbeatRespMessage {
bool is_node_timeout = 4; bool is_node_timeout = 4;
} }


// Message carrying a single node id.
// NOTE(review): presumably the request counterpart of FetchServersRespMessage, with node_id identifying the
// requesting node — confirm against the sender.
message FetchServersMessage {
string node_id = 1;
}

message FetchServersRespMessage { message FetchServersRespMessage {
repeated ServersMeta servers_meta = 1; repeated ServersMeta servers_meta = 1;
} }
@@ -95,6 +100,4 @@ message FinishMessage {
message CommMessage { message CommMessage {
MessageMeta pb_meta = 1; MessageMeta pb_meta = 1;
bytes data = 2; bytes data = 2;
// User-defined commands
bytes user_cmd = 3;
} }

+ 32
- 25
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -38,9 +38,13 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
} }


void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
HeartbeatMessage heartbeat_message; HeartbeatMessage heartbeat_message;
heartbeat_message.ParseFromString(message->data());
heartbeat_message.ParseFromArray(data, size);


node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); node_manager_.UpdateHeartbeat(heartbeat_message.node_id());


@@ -60,10 +64,8 @@ void SchedulerNode::ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::sha
heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout()); heartbeat_resp_message.set_is_cluster_timeout(node_manager_.is_cluster_timeout());
heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout()); heartbeat_resp_message.set_is_node_timeout(node_manager_.is_node_timeout());


std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(heartbeat_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, heartbeat_resp_message.SerializeAsString().data(),
heartbeat_resp_message.ByteSizeLong());
} }


void SchedulerNode::Initialize() { void SchedulerNode::Initialize() {
@@ -89,12 +91,13 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = ClusterConfig::scheduler_host(); std::string scheduler_host = ClusterConfig::scheduler_host();
uint32_t scheduler_port = ClusterConfig::scheduler_port(); uint32_t scheduler_port = ClusterConfig::scheduler_port();
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port); server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
if (handlers_.count(message->pb_meta().cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
const auto &handler_ptr = handlers_[message->pb_meta().cmd()];
(this->*handler_ptr)(server_, conn, message);
const auto &handler_ptr = handlers_[meta->cmd()];
(this->*handler_ptr)(server_, conn, meta, data, size);
}); });


server_->Init(); server_->Init();
@@ -106,10 +109,14 @@ void SchedulerNode::CreateTcpServer() {
} }


void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "The scheduler process a register message!"; MS_LOG(INFO) << "The scheduler process a register message!";
RegisterMessage register_message; RegisterMessage register_message;
register_message.ParseFromString(message->data());
register_message.ParseFromArray(data, size);


// assign worker node and server node rank id // assign worker node and server node rank id
int rank_id = node_manager_.NextRankId(register_message); int rank_id = node_manager_.NextRankId(register_message);
@@ -123,32 +130,32 @@ void SchedulerNode::ProcessRegister(std::shared_ptr<TcpServer> server, std::shar
register_resp_message.set_node_id(node_id); register_resp_message.set_node_id(node_id);
register_resp_message.set_rank_id(rank_id); register_resp_message.set_rank_id(rank_id);


std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(register_resp_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(),
register_resp_message.ByteSizeLong());
} }


void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(server);
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
FinishMessage finish_message; FinishMessage finish_message;
finish_message.ParseFromString(message->data());
finish_message.ParseFromArray(data, size);
node_manager_.AddFinishNode(finish_message); node_manager_.AddFinishNode(finish_message);
MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id(); MS_LOG(INFO) << "Process finish message from node id:" << finish_message.node_id();
server->SendMessage(conn, message);
server->SendMessage(conn, meta, Protos::PROTOBUF, data, size);
} }


void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void SchedulerNode::ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message) {
std::shared_ptr<MessageMeta> meta, const void *data, size_t size) {
FetchServersRespMessage fetch_servers_message; FetchServersRespMessage fetch_servers_message;
std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta(); std::vector<ServersMeta> servers_meta_list = node_manager_.FetchServersMeta();


*fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()}; *fetch_servers_message.mutable_servers_meta() = {servers_meta_list.begin(), servers_meta_list.end()};


std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
comm_message->set_data(fetch_servers_message.SerializeAsString());
server->SendMessage(conn, comm_message);
server->SendMessage(conn, meta, Protos::PROTOBUF, fetch_servers_message.SerializeAsString().data(),
fetch_servers_message.ByteSizeLong());
} }


void SchedulerNode::StartUpdateClusterStateTimer() { void SchedulerNode::StartUpdateClusterStateTimer() {


+ 6
- 5
mindspore/ccsrc/ps/core/scheduler_node.h View File

@@ -36,13 +36,14 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {

class SchedulerNode : public Node { class SchedulerNode : public Node {
public: public:
SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {}
~SchedulerNode() override; ~SchedulerNode() override;


typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, typedef void (SchedulerNode::*ResponseHandler)(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);


bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override;
bool Stop() override; bool Stop() override;
@@ -53,14 +54,14 @@ class SchedulerNode : public Node {
void InitCommandHandler(); void InitCommandHandler();
void CreateTcpServer(); void CreateTcpServer();
void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessHeartbeat(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessRegister(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void StartUpdateClusterStateTimer(); void StartUpdateClusterStateTimer();
void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessFinish(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);
void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn, void ProcessFetchServers(std::shared_ptr<TcpServer> server, std::shared_ptr<TcpConnection> conn,
std::shared_ptr<CommMessage> message);
std::shared_ptr<MessageMeta> meta, const void *data, size_t size);


std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> scheduler_thread_; std::unique_ptr<std::thread> scheduler_thread_;


+ 29
- 22
mindspore/ccsrc/ps/core/server_node.cc View File

@@ -46,16 +46,16 @@ bool ServerNode::Start(const uint32_t &timeout) {


void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; } void ServerNode::set_handler(const RequestHandler &handler) { request_handler_ = handler; }


void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
message->mutable_pb_meta()->set_role(node_info_.node_role_);
message->mutable_pb_meta()->set_rank_id(node_info_.rank_id_);
const MessageMeta &message_meta = message->pb_meta();
const uint64_t request_id = message_meta.request_id();
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
meta->set_role(node_info_.node_role_);
meta->set_rank_id(node_info_.rank_id_);
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
server_->SendMessage(conn, message);
<< ", the node id is:" << node_info_.node_id_ << " send the request id is:" << meta->request_id();
server_->SendMessage(conn, meta, Protos::RAW, data.get(), size);
} }


void ServerNode::CreateTcpServer() { void ServerNode::CreateTcpServer() {
@@ -63,17 +63,18 @@ void ServerNode::CreateTcpServer() {
std::string server_ip; std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0); server_ = std::make_shared<TcpServer>(server_ip, 0);
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
switch (message->pb_meta().cmd()) {
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
switch (meta->cmd()) {
case NodeCommand::SEND_DATA: case NodeCommand::SEND_DATA:
ProcessSendData(conn, message);
ProcessSendData(conn, meta, protos, data, size);
break; break;
case NodeCommand::COLLECTIVE_SEND_DATA: case NodeCommand::COLLECTIVE_SEND_DATA:
ProcessCollectiveSendData(conn, message);
RunReceiveCallback(*message);
ProcessCollectiveSendData(conn, meta, data, size);
RunReceiveCallback(meta, protos, data, size);
break; break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << message->pb_meta().cmd() << " is not supported!";
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
}); });
server_->Init(); server_->Init();
@@ -99,18 +100,24 @@ void ServerNode::Initialize() {
MS_LOG(INFO) << "Server node init client successful!"; MS_LOG(INFO) << "Server node init client successful!";
} }


void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
request_handler_(conn, message);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char> res(new unsigned char[size]);
int ret = memcpy_s(res.get(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
request_handler_(conn, meta, res, size);
} }


void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(message);
std::shared_ptr<CommMessage> comm_message = std::make_shared<CommMessage>();
*comm_message->mutable_pb_meta() = {message->pb_meta()};
server_->SendMessage(conn, comm_message);
MS_EXCEPTION_IF_NULL(meta);
server_->SendMessage(conn, meta, Protos::RAW, data, size);
} }


bool ServerNode::Stop() { bool ServerNode::Stop() {


+ 8
- 4
mindspore/ccsrc/ps/core/server_node.h View File

@@ -23,6 +23,7 @@
#include <string> #include <string>
#include <thread> #include <thread>
#include <utility> #include <utility>
#include <vector>


#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "ps/core/tcp_client.h" #include "ps/core/tcp_client.h"
@@ -41,16 +42,19 @@ class ServerNode : public AbstractNode {
bool Stop() override; bool Stop() override;
bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override;


using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
using RequestHandler = std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
DataPtr data, size_t size)>;


void set_handler(const RequestHandler &handler); void set_handler(const RequestHandler &handler);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, DataPtr data, size_t size);


private: private:
void CreateTcpServer(); void CreateTcpServer();
void Initialize(); void Initialize();
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
void ProcessSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size);
void ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const void *data, size_t size);


std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
std::unique_ptr<std::thread> server_thread_; std::unique_ptr<std::thread> server_thread_;


+ 50
- 14
mindspore/ccsrc/ps/core/tcp_client.cc View File

@@ -46,11 +46,12 @@ TcpClient::TcpClient(const std::string &address, std::uint16_t port)
server_port_(port), server_port_(port),
is_stop_(true), is_stop_(true),
is_connected_(false) { is_connected_(false) {
message_handler_.SetCallback([this](std::shared_ptr<CommMessage> message) {
if (message_callback_) {
message_callback_(*this, *message);
}
});
message_handler_.SetCallback(
[this](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
if (message_callback_) {
message_callback_(meta, protos, data, size);
}
});
} }


TcpClient::~TcpClient() { TcpClient::~TcpClient() {
@@ -189,7 +190,7 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
void TcpClient::OnReadHandler(const void *buf, size_t num) { void TcpClient::OnReadHandler(const void *buf, size_t num) {
MS_EXCEPTION_IF_NULL(buf); MS_EXCEPTION_IF_NULL(buf);
if (read_callback_) { if (read_callback_) {
read_callback_(*this, buf, num);
read_callback_(buf, num);
} }
message_handler_.ReceiveMessage(buf, num); message_handler_.ReceiveMessage(buf, num);
} }
@@ -198,7 +199,7 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg); auto tcp_client = reinterpret_cast<TcpClient *>(arg);
if (tcp_client->on_timer_callback_) { if (tcp_client->on_timer_callback_) {
tcp_client->on_timer_callback_(*tcp_client);
tcp_client->on_timer_callback_();
} }
} }


@@ -245,7 +246,7 @@ void TcpClient::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!"; << "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
} }


void TcpClient::StartWithNoBlock() { void TcpClient::StartWithNoBlock() {
@@ -256,7 +257,7 @@ void TcpClient::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
} }


void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; }
@@ -265,14 +266,49 @@ bool TcpClient::SendMessage(const CommMessage &message) const {
MS_EXCEPTION_IF_NULL(buffer_event_); MS_EXCEPTION_IF_NULL(buffer_event_);
bufferevent_lock(buffer_event_); bufferevent_lock(buffer_event_);
bool res = true; bool res = true;
size_t buf_size = message.ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message.SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
size_t buf_size = IntToUint(message.ByteSizeLong());
uint32_t meta_size = SizeToUint(message.pb_meta().ByteSizeLong());
Messageheader header;
header.message_proto_ = Protos::PROTOBUF;
header.message_length_ = buf_size;
header.message_meta_length_ = meta_size;
if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, message.pb_meta().SerializeAsString().data(), meta_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, message.data().data(), message.data().length()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}

bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;

Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;

if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!"; MS_LOG(ERROR) << "Event buffer add header failed!";
res = false; res = false;
} }
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false; res = false;
} }


+ 5
- 4
mindspore/ccsrc/ps/core/tcp_client.h View File

@@ -42,10 +42,10 @@ class TcpClient {
public: public:
using OnConnected = std::function<void()>; using OnConnected = std::function<void()>;
using OnDisconnected = std::function<void()>; using OnDisconnected = std::function<void()>;
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
using OnTimer = std::function<void(const TcpClient &)>;
using OnRead = std::function<void(const void *, size_t)>;
using OnTimeout = std::function<void()>;
using OnMessage = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;


explicit TcpClient(const std::string &address, std::uint16_t port); explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient(); virtual ~TcpClient();
@@ -61,6 +61,7 @@ class TcpClient {
void StartWithNoBlock(); void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb); void SetMessageCallback(const OnMessage &cb);
bool SendMessage(const CommMessage &message) const; bool SendMessage(const CommMessage &message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size);
void StartTimer(const uint32_t &time); void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer); void set_timer_callback(const OnTimer &timer);
const event_base &eventbase(); const event_base &eventbase();


+ 11
- 5
mindspore/ccsrc/ps/core/tcp_message_handler.cc View File

@@ -35,8 +35,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
header_[++header_index_] = *(buffer_data + i); header_[++header_index_] = *(buffer_data + i);
--num; --num;
if (header_index_ == kHeaderLen - 1) { if (header_index_ == kHeaderLen - 1) {
message_length_ = *reinterpret_cast<const size_t *>(header_);
remaining_length_ = message_length_;
message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
message_header_.message_meta_length_ =
*reinterpret_cast<const uint32_t *>(header_ + sizeof(message_header_.message_proto_));
message_header_.message_length_ = *reinterpret_cast<const size_t *>(
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
remaining_length_ = message_header_.message_length_;
message_buffer_.reset(new unsigned char[remaining_length_]); message_buffer_.reset(new unsigned char[remaining_length_]);
buffer_data += (i + 1); buffer_data += (i + 1);
break; break;
@@ -57,10 +61,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
} }


if (remaining_length_ == 0) { if (remaining_length_ == 0) {
std::shared_ptr<CommMessage> pb_message = std::make_shared<CommMessage>();
pb_message->ParseFromArray(message_buffer_.get(), message_length_);
if (message_callback_) { if (message_callback_) {
message_callback_(pb_message);
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
pb_message->ParseFromArray(message_buffer_.get(), message_header_.message_meta_length_);
message_callback_(pb_message, message_header_.message_proto_,
message_buffer_.get() + message_header_.message_meta_length_,
message_header_.message_length_ - message_header_.message_meta_length_);
} }
message_buffer_.reset(); message_buffer_.reset();
message_buffer_ = nullptr; message_buffer_ = nullptr;


+ 7
- 10
mindspore/ccsrc/ps/core/tcp_message_handler.h View File

@@ -24,24 +24,20 @@
#include <vector> #include <vector>


#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/core/message.h"
#include "proto/comm.pb.h" #include "proto/comm.pb.h"
#include "proto/ps.pb.h" #include "proto/ps.pb.h"


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
using messageReceive = std::function<void(std::shared_ptr<CommMessage>)>;
constexpr int kHeaderLen = 8;
using messageReceive = std::function<void(std::shared_ptr<MessageMeta>, const Protos &, const void *, size_t size)>;
constexpr int kHeaderLen = 16;


class TcpMessageHandler { class TcpMessageHandler {
public: public:
TcpMessageHandler() TcpMessageHandler()
: is_parsed_(false),
message_buffer_(nullptr),
message_length_(0),
remaining_length_(0),
header_index_(-1),
last_copy_len_(0) {}
: is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default; virtual ~TcpMessageHandler() = default;


void SetCallback(const messageReceive &cb); void SetCallback(const messageReceive &cb);
@@ -51,11 +47,12 @@ class TcpMessageHandler {
messageReceive message_callback_; messageReceive message_callback_;
bool is_parsed_; bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_; std::unique_ptr<unsigned char> message_buffer_;
size_t message_length_;
size_t remaining_length_; size_t remaining_length_;
char header_[8];
char header_[16];
int header_index_; int header_index_;
size_t last_copy_len_; size_t last_copy_len_;
MessageHeader message_header_;
std::string mBuffer;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps


+ 43
- 8
mindspore/ccsrc/ps/core/tcp_server.cc View File

@@ -54,13 +54,39 @@ bool TcpConnection::SendMessage(std::shared_ptr<CommMessage> message) const {
bufferevent_lock(buffer_event_); bufferevent_lock(buffer_event_);
bool res = true; bool res = true;
size_t buf_size = message->ByteSizeLong(); size_t buf_size = message->ByteSizeLong();
std::vector<unsigned char> serialized(buf_size);
message->SerializeToArray(serialized.data(), SizeToInt(buf_size));
if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) { if (bufferevent_write(buffer_event_, &buf_size, sizeof(buf_size)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!"; MS_LOG(ERROR) << "Event buffer add header failed!";
res = false; res = false;
} }
if (bufferevent_write(buffer_event_, serialized.data(), buf_size) == -1) {
if (bufferevent_write(buffer_event_, message->SerializeAsString().data(), buf_size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
bufferevent_unlock(buffer_event_);
return res;
}

bool TcpConnection::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) const {
MS_EXCEPTION_IF_NULL(buffer_event_);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
bufferevent_lock(buffer_event_);
bool res = true;
Messageheader header;
header.message_proto_ = protos;
header.message_meta_length_ = SizeToUint(meta->ByteSizeLong());
header.message_length_ = size + header.message_meta_length_;

if (bufferevent_write(buffer_event_, &header, sizeof(header)) == -1) {
MS_LOG(ERROR) << "Event buffer add header failed!";
res = false;
}
if (bufferevent_write(buffer_event_, meta->SerializeAsString().data(), meta->ByteSizeLong()) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false;
}
if (bufferevent_write(buffer_event_, data, size) == -1) {
MS_LOG(ERROR) << "Event buffer add protobuf data failed!"; MS_LOG(ERROR) << "Event buffer add protobuf data failed!";
res = false; res = false;
} }
@@ -158,7 +184,7 @@ void TcpServer::Start() {
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!"; << "Event base dispatch failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpected error code!";
} }


void TcpServer::StartWithNoBlock() { void TcpServer::StartWithNoBlock() {
@@ -169,7 +195,7 @@ void TcpServer::StartWithNoBlock() {
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!";
MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpected error code!";
} }


void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
@@ -260,10 +286,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);


server->AddConnection(fd, conn); server->AddConnection(fd, conn);
conn->InitConnection([=](std::shared_ptr<CommMessage> message) {
conn->InitConnection([=](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) {
OnServerReceiveMessage on_server_receive = server->GetServerReceive(); OnServerReceiveMessage on_server_receive = server->GetServerReceive();
if (on_server_receive) { if (on_server_receive) {
on_server_receive(conn, message);
on_server_receive(conn, meta, protos, data, size);
} }
}); });
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
@@ -274,6 +300,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
} }


std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { std::shared_ptr<TcpConnection> TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) {
MS_EXCEPTION_IF_NULL(bev);
std::shared_ptr<TcpConnection> conn = nullptr; std::shared_ptr<TcpConnection> conn = nullptr;
if (client_accept_) { if (client_accept_) {
conn = (client_accept_(*this)); conn = (client_accept_(*this));
@@ -367,9 +394,17 @@ bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr
return conn->SendMessage(message); return conn->SendMessage(message);
} }


bool TcpServer::SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
return conn->SendMessage(meta, protos, data, size);
}

void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) { void TcpServer::SendMessage(std::shared_ptr<CommMessage> message) {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(message);
std::lock_guard<std::mutex> lock(connection_mutex_);


for (auto it = connections_.begin(); it != connections_.end(); ++it) { for (auto it = connections_.begin(); it != connections_.end(); ++it) {
SendMessage(it->second, message); SendMessage(it->second, message);


+ 5
- 2
mindspore/ccsrc/ps/core/tcp_server.h View File

@@ -36,7 +36,6 @@


#include "ps/core/tcp_message_handler.h" #include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h" #include "utils/convert_utils_base.h"


namespace mindspore { namespace mindspore {
@@ -55,6 +54,7 @@ class TcpConnection {
virtual void InitConnection(const messageReceive &callback); virtual void InitConnection(const messageReceive &callback);
virtual void SendMessage(const void *buffer, size_t num) const; virtual void SendMessage(const void *buffer, size_t num) const;
bool SendMessage(std::shared_ptr<CommMessage> message) const; bool SendMessage(std::shared_ptr<CommMessage> message) const;
bool SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data, size_t size) const;
virtual void OnReadHandler(const void *buffer, size_t numBytes); virtual void OnReadHandler(const void *buffer, size_t numBytes);
TcpServer *GetServer() const; TcpServer *GetServer() const;
const evutil_socket_t &GetFd() const; const evutil_socket_t &GetFd() const;
@@ -69,7 +69,8 @@ class TcpConnection {
}; };


using OnServerReceiveMessage = using OnServerReceiveMessage =
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message)>;
std::function<void(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t size)>;


class TcpServer { class TcpServer {
public: public:
@@ -100,6 +101,8 @@ class TcpServer {
OnServerReceiveMessage GetServerReceive() const; OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb); void SetMessageCallback(const OnServerReceiveMessage &cb);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message); bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message);
bool SendMessage(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const Protos &protos,
const void *data, size_t sizee);
void SendMessage(std::shared_ptr<CommMessage> message); void SendMessage(std::shared_ptr<CommMessage> message);
uint16_t BoundPort() const; uint16_t BoundPort() const;
std::string BoundIp() const; std::string BoundIp() const;


+ 12
- 2
tests/ut/cpp/ps/core/tcp_client_tests.cc View File

@@ -30,7 +30,12 @@ class TestTcpClient : public UT::Common {
TEST_F(TestTcpClient, InitClientIPError) { TEST_F(TestTcpClient, InitClientIPError) {
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000); auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);


client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);

client->SendMessage(message);
});


ASSERT_THROW(client->Init(), std::exception); ASSERT_THROW(client->Init(), std::exception);
} }
@@ -38,10 +43,15 @@ TEST_F(TestTcpClient, InitClientIPError) {
TEST_F(TestTcpClient, InitClientPortErrorNoException) { TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = std::make_unique<TcpClient>("127.0.0.1", -1); auto client = std::make_unique<TcpClient>("127.0.0.1", -1);


client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); });
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
message.ParseFromArray(data, size);
client->SendMessage(message);
});


EXPECT_NO_THROW(client->Init()); EXPECT_NO_THROW(client->Init());
} }

} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 96
- 82
tests/ut/cpp/ps/core/tcp_message_handler_test.cc View File

@@ -33,130 +33,144 @@ class TestTcpMessageHandler : public UT::Common {
void TearDown() override {} void TearDown() override {}
}; };


TEST_F(TestTcpMessageHandler, 8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});


std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[1011];
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);

char result[1018];

MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);

MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
handler.ReceiveMessage(result, buf_size + kHeaderLen);
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
handler.ReceiveMessage(result, 1018);
} }


TEST_F(TestTcpMessageHandler, 8_Header_1003_Data_8_Header_1003_Data) {
TEST_F(TestTcpMessageHandler, 16Header_2meta_1000Data_16Header_2meta_1000Data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 1000); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 1000);
});


std::string data(1000, 'a'); std::string data(1000, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[2022] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + kHeaderLen + buf_size + kHeaderLen, buf_size, serialized.data(), buf_size);

char result[2036];

MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);

MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


handler.ReceiveMessage(result, 2 * buf_size + kHeaderLen * 2);
memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());

memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() + data.length(), meta.ByteSizeLong(),
meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen * 2 + meta.ByteSizeLong() * 2 + data.length(), data.length(), data.data(),
data.length());

handler.ReceiveMessage(result, 2036);
} }


TEST_F(TestTcpMessageHandler, 8_Header_4084_Data_4_Header_4_header_4084_data) {
TEST_F(TestTcpMessageHandler, 16header_2meta_4070data_8header_8header_2meta_4070data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4081); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4070);
});

std::string data(4070, 'a');


std::string data(4081, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0}; char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}


ret = memcpy_s(result + kHeaderLen + buf_size, 4, &buf_size, 4);
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);

MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());

memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), 8, &header, 8);
handler.ReceiveMessage(result, 4096); handler.ReceiveMessage(result, 4096);


auto temp = reinterpret_cast<char *>(&buf_size);
ret = memcpy_s(result, 4, temp + 4, 4);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
auto temp = reinterpret_cast<char *>(&header);
memcpy_s(result, 8, temp + 8, 8);
memcpy_s(result + 8, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + 8 + 2, data.length(), data.data(), data.length());


handler.ReceiveMessage(result, 4088);
handler.ReceiveMessage(result, 4080);
} }


TEST_F(TestTcpMessageHandler, 8_Header_4080_Data_8_Header_4080_data) {
TEST_F(TestTcpMessageHandler, 16Header_2meta_4062Data_16Header_2meta_4062_data) {
TcpMessageHandler handler; TcpMessageHandler handler;
handler.SetCallback([this](std::shared_ptr<CommMessage> message) { EXPECT_EQ(message->data().size(), 4077); });
handler.SetCallback([this](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
EXPECT_EQ(meta->ByteSizeLong(), 2);
EXPECT_EQ(size, 4062);
});

std::string data(4062, 'a');


std::string data(4077, 'a');
CommMessage message;
message.set_data(data);
size_t buf_size = message.ByteSizeLong();
char result[4096] = {0}; char result[4096] = {0};
int ret = memcpy_s(result, kHeaderLen, &buf_size, kHeaderLen);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
std::vector<char> serialized(buf_size);
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
ret = memcpy_s(result + kHeaderLen, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}


ret = memcpy_s(result + kHeaderLen + buf_size, kHeaderLen, &buf_size, kHeaderLen);
MessageMeta meta;
meta.set_request_id(1);
EXPECT_EQ(meta.ByteSizeLong(), 2);

MessageHeader header;
header.message_proto_ = Protos::RAW;
header.message_meta_length_ = meta.ByteSizeLong();
header.message_length_ = data.length() + meta.ByteSizeLong();
int ret = memcpy_s(result, kHeaderLen, &header, kHeaderLen);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


memcpy_s(result + kHeaderLen, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong(), data.length(), data.data(), data.length());
memcpy_s(result + kHeaderLen + meta.ByteSizeLong() + data.length(), kHeaderLen, &header, kHeaderLen);

handler.ReceiveMessage(result, 4096); handler.ReceiveMessage(result, 4096);


ret = memcpy_s(result, buf_size, serialized.data(), buf_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
}
memcpy_s(result, meta.ByteSizeLong(), meta.SerializeAsString().data(), meta.ByteSizeLong());
memcpy_s(result + meta.ByteSizeLong(), data.length(), data.data(), data.length());


handler.ReceiveMessage(result, 4080);
handler.ReceiveMessage(result, 4064);
} }
} // namespace core } // namespace core
} // namespace ps } // namespace ps

+ 12
- 10
tests/ut/cpp/ps/core/tcp_pb_server_test.cc View File

@@ -33,11 +33,12 @@ class TestTcpServer : public UT::Common {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0); server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<std::thread> http_server_thread_(nullptr); std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([=]() { http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<CommMessage> message) {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
KVMessage kv_message; KVMessage kv_message;
kv_message.ParseFromString(message->data());
kv_message.ParseFromArray(data, size);
EXPECT_EQ(2, kv_message.keys_size()); EXPECT_EQ(2, kv_message.keys_size());
server_->SendMessage(conn, message);
server_->SendMessage(conn, meta, protos, data, size);
}); });
server_->Init(); server_->Init();
server_->Start(); server_->Start();
@@ -61,23 +62,24 @@ TEST_F(TestTcpServer, ServerSendMessage) {
std::cout << server_->BoundPort() << std::endl; std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr); std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() { http_client_thread = std::make_unique<std::thread>([&]() {
client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) {
KVMessage kv_message;
kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size());
client_->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &, const void *data, size_t size) {
KVMessage message;
message.ParseFromArray(data, size);
EXPECT_EQ(2, message.keys_size());
}); });


client_->Init(); client_->Init();


CommMessage comm_message;
KVMessage kv_message; KVMessage kv_message;
std::vector<int> keys{1, 2}; std::vector<int> keys{1, 2};
std::vector<int> values{3, 4}; std::vector<int> values{3, 4};
*kv_message.mutable_keys() = {keys.begin(), keys.end()}; *kv_message.mutable_keys() = {keys.begin(), keys.end()};
*kv_message.mutable_values() = {values.begin(), values.end()}; *kv_message.mutable_values() = {values.begin(), values.end()};


comm_message.set_data(kv_message.SerializeAsString());
client_->SendMessage(comm_message);
auto message_meta = std::make_shared<MessageMeta>();
message_meta->set_cmd(NodeCommand::SEND_DATA);

client_->SendMessage(message_meta, Protos::RAW, kv_message.SerializeAsString().data(), kv_message.ByteSizeLong());


client_->Start(); client_->Start();
}); });


Loading…
Cancel
Save