| @@ -33,11 +33,12 @@ void AbstractNode::Register(const std::shared_ptr<TcpClient> &client) { | |||||
| *comm_message.mutable_pb_meta() = {message_meta}; | *comm_message.mutable_pb_meta() = {message_meta}; | ||||
| comm_message.set_data(register_message.SerializeAsString()); | comm_message.set_data(register_message.SerializeAsString()); | ||||
| if (!SendMessageSync(client, comm_message)) { | if (!SendMessageSync(client, comm_message)) { | ||||
| MS_LOG(EXCEPTION) << "Node register timeout!"; | |||||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||||
| } | } | ||||
| MS_LOG(INFO) << "The node id:" << node_info_.node_id_ | |||||
| << "is registering to scheduler, the request id is:" << message_meta.request_id(); | |||||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << " the node id:" << node_info_.node_id_ << "is registering to scheduler!"; | |||||
| } | } | ||||
| void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | void AbstractNode::ProcessRegisterResp(const CommMessage &message) { | ||||
| @@ -395,7 +396,7 @@ bool AbstractNode::InitClientToScheduler() { | |||||
| client_to_scheduler_->Init(); | client_to_scheduler_->Init(); | ||||
| client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { | client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { | ||||
| MS_LOG(INFO) << "The worker node start a tcp client!"; | |||||
| MS_LOG(INFO) << "The node start a tcp client!"; | |||||
| client_to_scheduler_->Start(); | client_to_scheduler_->Start(); | ||||
| }); | }); | ||||
| @@ -129,6 +129,16 @@ bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &ra | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds) { | |||||
| for (size_t attempt = 0; attempt < max_attempts; ++attempt) { | |||||
| if (func()) { | |||||
| return true; | |||||
| } | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds)); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,6 +45,7 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <thread> | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| @@ -68,6 +69,7 @@ class CommUtil { | |||||
| static std::string GenerateUUID(); | static std::string GenerateUUID(); | ||||
| static std::string NodeRoleToString(const NodeRole &role); | static std::string NodeRoleToString(const NodeRole &role); | ||||
| static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id); | static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id); | ||||
| static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds); | |||||
| private: | private: | ||||
| static std::random_device rd; | static std::random_device rd; | ||||
| @@ -57,7 +57,7 @@ class Node { | |||||
| using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | using OnNodeEventMessage = std::function<void(const NodeEvent &event)>; | ||||
| using MessageCallback = std::function<void()>; | using MessageCallback = std::function<void()>; | ||||
| virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0; | |||||
| virtual bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) = 0; | |||||
| virtual bool Stop() = 0; | virtual bool Stop() = 0; | ||||
| virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; | virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; | ||||
| @@ -105,7 +105,7 @@ void NodeManager::UpdateClusterState() { | |||||
| } | } | ||||
| // 2. update cluster finish state | // 2. update cluster finish state | ||||
| if (finish_nodes_id_.size() == total_node_num_) { | |||||
| if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) { | |||||
| is_cluster_finish_ = true; | is_cluster_finish_ = true; | ||||
| is_cluster_ready_ = true; | is_cluster_ready_ = true; | ||||
| } | } | ||||
| @@ -119,7 +119,9 @@ void NodeManager::UpdateClusterState() { | |||||
| void NodeManager::CheckClusterTimeout() { | void NodeManager::CheckClusterTimeout() { | ||||
| if (total_node_num_ != nodes_info_.size()) { | if (total_node_num_ != nodes_info_.size()) { | ||||
| MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout() | MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout() | ||||
| << " seconds,so finish the cluster"; | |||||
| << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to " | |||||
| << nodes_info_.size(); | |||||
| current_node_num_ = nodes_info_.size(); | |||||
| is_cluster_timeout_ = true; | is_cluster_timeout_ = true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,6 +35,7 @@ | |||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "ps/core/node.h" | #include "ps/core/node.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/convert_utils_base.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -47,6 +48,7 @@ class NodeManager { | |||||
| is_cluster_timeout_(false), | is_cluster_timeout_(false), | ||||
| is_node_timeout_(false), | is_node_timeout_(false), | ||||
| total_node_num_(0), | total_node_num_(0), | ||||
| current_node_num_(-1), | |||||
| next_worker_rank_id_(-1), | next_worker_rank_id_(-1), | ||||
| next_server_rank_id_(-1) {} | next_server_rank_id_(-1) {} | ||||
| virtual ~NodeManager() = default; | virtual ~NodeManager() = default; | ||||
| @@ -75,6 +77,7 @@ class NodeManager { | |||||
| std::atomic<bool> is_cluster_timeout_; | std::atomic<bool> is_cluster_timeout_; | ||||
| std::atomic<bool> is_node_timeout_; | std::atomic<bool> is_node_timeout_; | ||||
| uint32_t total_node_num_; | uint32_t total_node_num_; | ||||
| int32_t current_node_num_; | |||||
| std::atomic<int> next_worker_rank_id_; | std::atomic<int> next_worker_rank_id_; | ||||
| std::atomic<int> next_server_rank_id_; | std::atomic<int> next_server_rank_id_; | ||||
| // worker nodes and server nodes | // worker nodes and server nodes | ||||
| @@ -44,7 +44,7 @@ class SchedulerNode : public Node { | |||||
| SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} | ||||
| ~SchedulerNode() override; | ~SchedulerNode() override; | ||||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||||
| bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; | |||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| @@ -30,7 +30,7 @@ bool ServerNode::Start(const uint32_t &timeout) { | |||||
| StartHeartbeatTimer(client_to_scheduler_); | StartHeartbeatTimer(client_to_scheduler_); | ||||
| if (!WaitForStart(timeout)) { | if (!WaitForStart(timeout)) { | ||||
| MS_LOG(EXCEPTION) << "Start Worker node timeout!"; | |||||
| MS_LOG(ERROR) << "Start Server node timeout!"; | |||||
| } | } | ||||
| MS_LOG(INFO) << "The cluster is ready to use!"; | MS_LOG(INFO) << "The cluster is ready to use!"; | ||||
| @@ -40,7 +40,7 @@ class ServerNode : public AbstractNode { | |||||
| ServerNode() : server_(nullptr), server_thread_(nullptr) {} | ServerNode() : server_(nullptr), server_thread_(nullptr) {} | ||||
| ~ServerNode() override; | ~ServerNode() override; | ||||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||||
| bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; | |||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| @@ -36,6 +36,8 @@ namespace mindspore { | |||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| event_base *TcpClient::event_base_ = nullptr; | event_base *TcpClient::event_base_ = nullptr; | ||||
| std::mutex TcpClient::event_base_mutex_; | |||||
| bool TcpClient::is_started_ = false; | |||||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | TcpClient::TcpClient(const std::string &address, std::uint16_t port) | ||||
| : event_timeout_(nullptr), | : event_timeout_(nullptr), | ||||
| @@ -60,10 +62,6 @@ TcpClient::~TcpClient() { | |||||
| event_free(event_timeout_); | event_free(event_timeout_); | ||||
| event_timeout_ = nullptr; | event_timeout_ = nullptr; | ||||
| } | } | ||||
| if (event_base_) { | |||||
| event_base_free(event_base_); | |||||
| event_base_ = nullptr; | |||||
| } | |||||
| } | } | ||||
| std::string TcpClient::GetServerAddress() const { return server_address_; } | std::string TcpClient::GetServerAddress() const { return server_address_; } | ||||
| @@ -234,6 +232,13 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||||
| } | } | ||||
| void TcpClient::Start() { | void TcpClient::Start() { | ||||
| event_base_mutex_.lock(); | |||||
| if (is_started_) { | |||||
| event_base_mutex_.unlock(); | |||||
| return; | |||||
| } | |||||
| is_started_ = true; | |||||
| event_base_mutex_.unlock(); | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| int ret = event_base_dispatch(event_base_); | int ret = event_base_dispatch(event_base_); | ||||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | ||||
| @@ -260,7 +265,7 @@ void TcpClient::SendMessage(const CommMessage &message) const { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| size_t buf_size = message.ByteSizeLong(); | size_t buf_size = message.ByteSizeLong(); | ||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | ||||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | ||||
| } | } | ||||
| @@ -35,6 +35,7 @@ | |||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "utils/convert_utils_base.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -86,6 +87,9 @@ class TcpClient { | |||||
| OnTimer on_timer_callback_; | OnTimer on_timer_callback_; | ||||
| static event_base *event_base_; | static event_base *event_base_; | ||||
| static std::mutex event_base_mutex_; | |||||
| static bool is_started_; | |||||
| std::mutex connection_mutex_; | std::mutex connection_mutex_; | ||||
| std::condition_variable connection_cond_; | std::condition_variable connection_cond_; | ||||
| event *event_timeout_; | event *event_timeout_; | ||||
| @@ -32,7 +32,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| void TcpConnection::InitConnection() { | void TcpConnection::InitConnection() { | ||||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | tcp_message_handler_.SetCallback([&](const CommMessage &message) { | ||||
| OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | ||||
| @@ -58,7 +57,7 @@ void TcpConnection::SendMessage(const CommMessage &message) const { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | MS_EXCEPTION_IF_NULL(buffer_event_); | ||||
| size_t buf_size = message.ByteSizeLong(); | size_t buf_size = message.ByteSizeLong(); | ||||
| std::vector<unsigned char> serialized(buf_size); | std::vector<unsigned char> serialized(buf_size); | ||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); | |||||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | ||||
| sizeof(buf_size)) == -1) { | sizeof(buf_size)) == -1) { | ||||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | ||||
| @@ -304,7 +303,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||||
| if (read == -1) { | if (read == -1) { | ||||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | ||||
| } | } | ||||
| conn->OnReadHandler(read_buffer, static_cast<size_t>(read)); | |||||
| conn->OnReadHandler(read_buffer, IntToSize(read)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -39,6 +39,7 @@ | |||||
| #include "ps/core/tcp_message_handler.h" | #include "ps/core/tcp_message_handler.h" | ||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/convert_utils_base.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -40,7 +40,7 @@ class WorkerNode : public AbstractNode { | |||||
| WorkerNode() = default; | WorkerNode() = default; | ||||
| ~WorkerNode() override; | ~WorkerNode() override; | ||||
| bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; | |||||
| bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; | |||||
| bool Stop() override; | bool Stop() override; | ||||
| bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; | ||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #include "ps/core/node.h" | |||||
| #include "ps/core/scheduler_node.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestClusterAvailableTimeout : public UT::Common { | |||||
| public: | |||||
| TestClusterAvailableTimeout() = default; | |||||
| ~TestClusterAvailableTimeout() override = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { | |||||
| ClusterConfig::Init(1, 1, std::make_unique<std::string>("127.0.0.1"), 9999); | |||||
| ClusterConfig::set_cluster_available_timeout(3); | |||||
| SchedulerNode node; | |||||
| node.Start(); | |||||
| node.Finish(); | |||||
| node.Stop(); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -27,6 +27,18 @@ class TestCommUtil : public UT::Common { | |||||
| public: | public: | ||||
| TestCommUtil() = default; | TestCommUtil() = default; | ||||
| virtual ~TestCommUtil() = default; | virtual ~TestCommUtil() = default; | ||||
| struct MockRetry { | |||||
| bool operator()(std::string mock) { | |||||
| ++count_; | |||||
| if (count_ > 3) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| int count_{0}; | |||||
| }; | |||||
| void SetUp() override {} | void SetUp() override {} | ||||
| void TearDown() override {} | void TearDown() override {} | ||||
| @@ -47,6 +59,14 @@ TEST_F(TestCommUtil, ValidateRankId) { | |||||
| EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); | ||||
| EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); | ||||
| } | } | ||||
| TEST_F(TestCommUtil, Retry) { | |||||
| bool const ret = CommUtil::Retry([]() -> bool { return false; }, 5, 100); | |||||
| EXPECT_FALSE(ret); | |||||
| MockRetry mock_retry; | |||||
| bool const mock_ret = CommUtil::Retry([&] { return mock_retry("mock"); }, 5, 100); | |||||
| EXPECT_TRUE(mock_ret); | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||