|
|
|
@@ -32,6 +32,7 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { |
|
|
|
CommMessage comm_message; |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(register_message.SerializeAsString()); |
|
|
|
comm_message.set_user_cmd(""); |
|
|
|
if (!SendMessageSync(client, comm_message)) { |
|
|
|
MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) |
|
|
|
<< " the node id:" << node_info_.node_id_ << " register timeout!"; |
|
|
|
@@ -54,11 +55,12 @@ void AbstractNode::ProcessRegisterResp(const CommMessage &message) { |
|
|
|
MS_LOG(INFO) << "The node id is:" << node_info_.node_id_ << ", and the rank id is:" << node_info_.rank_id_; |
|
|
|
} |
|
|
|
|
|
|
|
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string &message, const uint32_t &timeout) { |
|
|
|
bool AbstractNode::Broadcast(const enum NodeRole &node_role, const CommMessage &message, const uint32_t &timeout) { |
|
|
|
if (node_role != NodeRole::SERVER) { |
|
|
|
MS_LOG(EXCEPTION) << "Currently only supports broadcast to server nodes"; |
|
|
|
} |
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message); |
|
|
|
uint64_t request_id = ++next_request_id_; |
|
|
|
message_tracker_[request_id] = std::make_pair(nodes_address_.size(), 0); |
|
|
|
|
|
|
|
@@ -69,9 +71,7 @@ bool AbstractNode::Broadcast(const enum NodeRole &node_role, const std::string & |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(message); |
|
|
|
auto client = GetOrCreateTcpClient((*it).first.second); |
|
|
|
client->SendMessage(comm_message); |
|
|
|
} |
|
|
|
@@ -84,26 +84,26 @@ void AbstractNode::set_event_callback(const OnNodeEventMessage &on_node_event_me |
|
|
|
on_node_event_message_ = on_node_event_message; |
|
|
|
} |
|
|
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, |
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, |
|
|
|
const uint32_t &timeout) { |
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) { |
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; |
|
|
|
} |
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message); |
|
|
|
|
|
|
|
MessageMeta message_meta; |
|
|
|
message_meta.set_cmd(NodeCommand::SEND_DATA); |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(message); |
|
|
|
auto client = GetOrCreateTcpClient(rank_id); |
|
|
|
return SendMessageSync(client, comm_message, timeout); |
|
|
|
} |
|
|
|
|
|
|
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, |
|
|
|
const std::vector<std::string> &data, const uint32_t &timeout) { |
|
|
|
const std::vector<CommMessage> &data, const uint32_t &timeout) { |
|
|
|
uint64_t request_id = ++next_request_id_; |
|
|
|
message_tracker_[request_id] = std::make_pair(data.size(), 0); |
|
|
|
|
|
|
|
@@ -121,9 +121,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it)); |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(data.at(it)); |
|
|
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_ids.at(it)); |
|
|
|
client->SendMessage(comm_message); |
|
|
|
@@ -133,19 +132,21 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & |
|
|
|
return Wait(request_id, timeout); |
|
|
|
} |
|
|
|
|
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, |
|
|
|
std::string *output, const uint32_t &timeout) { |
|
|
|
bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, const CommMessage &message, |
|
|
|
CommMessage *output, const uint32_t &timeout) { |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) { |
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; |
|
|
|
} |
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message); |
|
|
|
|
|
|
|
uint64_t request_id = ++next_request_id_; |
|
|
|
message_tracker_[request_id] = std::make_pair(1, 0); |
|
|
|
set_message_callback(request_id, [&]() { |
|
|
|
receive_messages_mutex_.lock(); |
|
|
|
auto res = receive_messages_[request_id]; |
|
|
|
*output = res[rank_id].data(); |
|
|
|
*output = res[rank_id]; |
|
|
|
receive_messages_.erase(request_id); |
|
|
|
receive_messages_mutex_.unlock(); |
|
|
|
}); |
|
|
|
@@ -156,9 +157,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(message); |
|
|
|
auto client = GetOrCreateTcpClient(rank_id); |
|
|
|
client->SendMessage(comm_message); |
|
|
|
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) |
|
|
|
@@ -167,7 +166,7 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, |
|
|
|
} |
|
|
|
|
|
|
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, |
|
|
|
const std::vector<std::string> &data, std::vector<std::string> *output, |
|
|
|
const std::vector<CommMessage> &data, std::vector<CommMessage> *output, |
|
|
|
const uint32_t &timeout) { |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
uint64_t request_id = ++next_request_id_; |
|
|
|
@@ -183,7 +182,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & |
|
|
|
receive_messages_mutex_.lock(); |
|
|
|
auto res = receive_messages_[request_id]; |
|
|
|
for (size_t it = 0; it < len; ++it) { |
|
|
|
(*output).push_back(res[rank_ids.at(it)].data()); |
|
|
|
(*output).push_back(res[rank_ids.at(it)]); |
|
|
|
} |
|
|
|
receive_messages_.erase(request_id); |
|
|
|
receive_messages_mutex_.unlock(); |
|
|
|
@@ -200,9 +199,8 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(data.at(it)); |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(data.at(it)); |
|
|
|
|
|
|
|
auto client = GetOrCreateTcpClient(rank_ids.at(it)); |
|
|
|
client->SendMessage(comm_message); |
|
|
|
@@ -223,37 +221,37 @@ bool AbstractNode::Wait(uint64_t request_id, const uint32_t &timeout) { |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t AbstractNode::CollectiveSendAsync(const enum NodeRole &node_role, const uint32_t &rank_id, |
|
|
|
const std::string &message) { |
|
|
|
const CommMessage &message) { |
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) { |
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; |
|
|
|
} |
|
|
|
|
|
|
|
CommMessage &comm_message = const_cast<CommMessage &>(message); |
|
|
|
|
|
|
|
MessageMeta message_meta; |
|
|
|
message_meta.set_cmd(NodeCommand::COLLECTIVE_SEND_DATA); |
|
|
|
message_meta.set_rank_id(node_info_.rank_id_); |
|
|
|
message_meta.set_role(node_info_.node_role_); |
|
|
|
|
|
|
|
CommMessage comm_message; |
|
|
|
*comm_message.mutable_pb_meta() = {message_meta}; |
|
|
|
comm_message.set_data(message); |
|
|
|
auto client = GetOrCreateTcpClient(rank_id); |
|
|
|
return SendMessageAsync(client, comm_message); |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<uint32_t, uint64_t> AbstractNode::CollectiveReceiveAsync(const enum NodeRole &node_role, |
|
|
|
const uint32_t &rank_id, std::string *output) { |
|
|
|
const uint32_t &rank_id, CommMessage *output) { |
|
|
|
if (!CommUtil::ValidateRankId(node_role, rank_id)) { |
|
|
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal!"; |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t rank_request_id = NextExpectedRankRequestId(rank_id); |
|
|
|
if (received_data_.count(std::make_pair(rank_id, rank_request_id)) > 0) { |
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)].data(); |
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)]; |
|
|
|
received_data_.erase(std::make_pair(rank_id, rank_request_id)); |
|
|
|
} else { |
|
|
|
set_receive_callback(rank_id, rank_request_id, [=]() { |
|
|
|
receive_callbacks_mutex_.lock(); |
|
|
|
*output = received_data_[std::make_pair(rank_id, 1)].data(); |
|
|
|
*output = received_data_[std::make_pair(rank_id, rank_request_id)]; |
|
|
|
received_data_.erase(std::make_pair(rank_id, rank_request_id)); |
|
|
|
receive_callbacks_mutex_.unlock(); |
|
|
|
}); |
|
|
|
@@ -415,21 +413,12 @@ bool AbstractNode::InitClientToScheduler() { |
|
|
|
uint16_t scheduler_port = ClusterConfig::scheduler_port(); |
|
|
|
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_host, scheduler_port); |
|
|
|
client_to_scheduler_->SetMessageCallback([&](const TcpClient &client, const CommMessage &message) { |
|
|
|
switch (message.pb_meta().cmd()) { |
|
|
|
case NodeCommand::HEARTBEAT: |
|
|
|
ProcessHeartbeatResp(message); |
|
|
|
break; |
|
|
|
case NodeCommand::REGISTER: |
|
|
|
ProcessRegisterResp(message); |
|
|
|
break; |
|
|
|
case NodeCommand::FETCH_SERVER: |
|
|
|
ProcessFetchServersResp(message); |
|
|
|
break; |
|
|
|
case NodeCommand::FINISH: |
|
|
|
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a finish message response!"; |
|
|
|
break; |
|
|
|
default: |
|
|
|
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; |
|
|
|
if (handlers_.count(message.pb_meta().cmd()) == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "The cmd:" << message.pb_meta().cmd() << " is not supported!"; |
|
|
|
} |
|
|
|
if (handlers_[message.pb_meta().cmd()] != nullptr) { |
|
|
|
const auto &handler_ptr = handlers_[message.pb_meta().cmd()]; |
|
|
|
(this->*handler_ptr)(message); |
|
|
|
} |
|
|
|
NotifyMessageArrival(message); |
|
|
|
}); |
|
|
|
@@ -607,6 +596,13 @@ uint64_t AbstractNode::NextActualRankRequestId(const uint32_t &rank_id) { |
|
|
|
} |
|
|
|
return rank_request_id; |
|
|
|
} |
|
|
|
|
|
|
|
void AbstractNode::InitCommandHandler() { |
|
|
|
handlers_[NodeCommand::HEARTBEAT] = &AbstractNode::ProcessHeartbeatResp; |
|
|
|
handlers_[NodeCommand::REGISTER] = &AbstractNode::ProcessRegisterResp; |
|
|
|
handlers_[NodeCommand::FETCH_SERVER] = &AbstractNode::ProcessFetchServersResp; |
|
|
|
handlers_[NodeCommand::FINISH] = nullptr; |
|
|
|
} |
|
|
|
} // namespace core |
|
|
|
} // namespace ps |
|
|
|
} // namespace mindspore |