| @@ -149,13 +149,13 @@ bool AbstractNode::Send(const enum NodeRole &node_role, const uint32_t &rank_id, | |||||
| } | } | ||||
| bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | ||||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||||
| const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp, | |||||
| const uint32_t &timeout) { | const uint32_t &timeout) { | ||||
| MS_EXCEPTION_IF_NULL(comm_message_resp); | MS_EXCEPTION_IF_NULL(comm_message_resp); | ||||
| uint64_t request_id = ++next_request_id_; | uint64_t request_id = ++next_request_id_; | ||||
| message_tracker_[request_id] = std::make_pair(data.size(), 0); | message_tracker_[request_id] = std::make_pair(data.size(), 0); | ||||
| if (rank_ids.size() != data.size() || rank_ids.size() != (*comm_message_resp).size()) { | |||||
| if (rank_ids.size() != data.size()) { | |||||
| MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | MS_LOG(EXCEPTION) << "The number of rank ids, data, comm_message_resp should be equal!"; | ||||
| } | } | ||||
| @@ -165,7 +165,7 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> & | |||||
| receive_messages_mutex_.lock(); | receive_messages_mutex_.lock(); | ||||
| auto res = receive_messages_[request_id]; | auto res = receive_messages_[request_id]; | ||||
| for (size_t it = 0; it < len; ++it) { | for (size_t it = 0; it < len; ++it) { | ||||
| comm_message_resp->at(it) = &res[rank_ids.at(it)]; | |||||
| (*comm_message_resp).push_back(res[rank_ids.at(it)]); | |||||
| } | } | ||||
| receive_messages_.erase(request_id); | receive_messages_.erase(request_id); | ||||
| receive_messages_mutex_.unlock(); | receive_messages_mutex_.unlock(); | ||||
| @@ -394,7 +394,7 @@ void AbstractNode::ProcessSendDataResp(const CommMessage &message) { | |||||
| const uint64_t request_id = message_meta.request_id(); | const uint64_t request_id = message_meta.request_id(); | ||||
| auto it = receive_messages_.find(request_id); | auto it = receive_messages_.find(request_id); | ||||
| if (it != receive_messages_.end()) { | if (it != receive_messages_.end()) { | ||||
| it->second.insert(std::make_pair(rank_id, message)); | |||||
| it->second[rank_id] = message; | |||||
| } else { | } else { | ||||
| std::unordered_map<uint32_t, CommMessage> res; | std::unordered_map<uint32_t, CommMessage> res; | ||||
| res.insert(std::make_pair(rank_id, message)); | res.insert(std::make_pair(rank_id, message)); | ||||
| @@ -44,7 +44,7 @@ class AbstractNode : public Node { | |||||
| virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | virtual bool Send(const enum NodeRole &node_role, const uint32_t &rank_id, const std::string &message, | ||||
| CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | CommMessage *comm_message_resp, const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | virtual bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, | ||||
| const std::vector<std::string> &data, std::vector<CommMessage *> *comm_message_resp, | |||||
| const std::vector<std::string> &data, std::vector<CommMessage> *comm_message_resp, | |||||
| const uint32_t &timeout = kCommTimeoutInSeconds); | const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | bool Wait(uint64_t request_id, const uint32_t &timeout = kCommTimeoutInSeconds); | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; | enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1 }; | ||||
| struct NodeInfo { | struct NodeInfo { | ||||
| @@ -46,7 +46,6 @@ int NodeManager::NextRankId(const RegisterMessage ®ister_message) { | |||||
| nodes_info_[node_id] = node_info; | nodes_info_[node_id] = node_info; | ||||
| MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port | MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port | ||||
| << " assign rank id:" << rank_id; | << " assign rank id:" << rank_id; | ||||
| } else if (register_message.role() == NodeRole::WORKER) { | } else if (register_message.role() == NodeRole::WORKER) { | ||||
| rank_id = ++next_worker_rank_id_; | rank_id = ++next_worker_rank_id_; | ||||
| NodeInfo node_info; | NodeInfo node_info; | ||||
| @@ -19,7 +19,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| SchedulerNode::~SchedulerNode() { | SchedulerNode::~SchedulerNode() { | ||||
| MS_LOG(INFO) << "Stop scheduler node!"; | MS_LOG(INFO) << "Stop scheduler node!"; | ||||
| if (!is_already_stopped_) { | if (!is_already_stopped_) { | ||||
| @@ -38,7 +38,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| class SchedulerNode : public Node { | class SchedulerNode : public Node { | ||||
| public: | public: | ||||
| SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | ||||
| @@ -280,7 +280,6 @@ void TcpClient::StartTimer(const uint32_t &time) { | |||||
| void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | ||||
| const event_base &TcpClient::eventbase() { return *event_base_; } | const event_base &TcpClient::eventbase() { return *event_base_; } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -97,7 +97,6 @@ class TcpClient { | |||||
| std::atomic<bool> is_stop_; | std::atomic<bool> is_stop_; | ||||
| std::atomic<bool> is_connected_; | std::atomic<bool> is_connected_; | ||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -368,7 +368,6 @@ int TcpServer::ConnectionNum() const { return connections_.size(); } | |||||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | ||||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -40,7 +40,6 @@ TEST_F(TestClusterConfig, HeartbeatInterval) { | |||||
| EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1"); | EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1"); | ||||
| EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080); | EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080); | ||||
| } | } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,12 +41,12 @@ TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||||
| } | } | ||||
| TEST_F(TestCommUtil, ValidateRankId) { | TEST_F(TestCommUtil, ValidateRankId) { | ||||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | |||||
| ClusterConfig::Init(3, 2, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::WORKER, 2)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::WORKER, 3)); | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | |||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | |||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,6 @@ TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||||
| EXPECT_NO_THROW(client->Init()); | EXPECT_NO_THROW(client->Init()); | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||