diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 52dd4f48e0..95cc15d929 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -33,11 +33,12 @@ void AbstractNode::Register(const std::shared_ptr &client) { *comm_message.mutable_pb_meta() = {message_meta}; comm_message.set_data(register_message.SerializeAsString()); if (!SendMessageSync(client, comm_message)) { - MS_LOG(EXCEPTION) << "Node register timeout!"; + MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << " the node id:" << node_info_.node_id_ << " register timeout!"; } - MS_LOG(INFO) << "The node id:" << node_info_.node_id_ - << "is registering to scheduler, the request id is:" << message_meta.request_id(); + MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << " the node id:" << node_info_.node_id_ << "is registering to scheduler!"; } void AbstractNode::ProcessRegisterResp(const CommMessage &message) { @@ -395,7 +396,7 @@ bool AbstractNode::InitClientToScheduler() { client_to_scheduler_->Init(); client_to_scheduler_thread_ = std::make_unique([&]() { - MS_LOG(INFO) << "The worker node start a tcp client!"; + MS_LOG(INFO) << "The node start a tcp client!"; client_to_scheduler_->Start(); }); diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 2e1d73cecf..dff9c28d3a 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -129,6 +129,16 @@ bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &ra } return true; } + +bool CommUtil::Retry(const std::function &func, size_t max_attempts, size_t interval_milliseconds) { + for (size_t attempt = 0; attempt < max_attempts; ++attempt) { + if (func()) { + return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(interval_milliseconds)); + } + return false; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index 13ed85db82..41fb373741 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -45,6 +45,7 @@ #include #include #include +#include #include "proto/comm.pb.h" #include "proto/ps.pb.h" @@ -68,6 +69,7 @@ class CommUtil { static std::string GenerateUUID(); static std::string NodeRoleToString(const NodeRole &role); static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id); + static bool Retry(const std::function &func, size_t max_attempts, size_t interval_milliseconds); private: static std::random_device rd; diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index 89f006a42a..bf7e3cbee1 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -57,7 +57,7 @@ class Node { using OnNodeEventMessage = std::function; using MessageCallback = std::function; - virtual bool Start(const uint32_t &timeout = kTimeoutInSeconds) = 0; + virtual bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) = 0; virtual bool Stop() = 0; virtual bool Finish(const uint32_t &timeout = kTimeoutInSeconds) = 0; diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index 76ced579af..eb1d0d609b 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -105,7 +105,7 @@ void NodeManager::UpdateClusterState() { } // 2. update cluster finish state - if (finish_nodes_id_.size() == total_node_num_) { + if (finish_nodes_id_.size() == total_node_num_ || SizeToInt(finish_nodes_id_.size()) == current_node_num_) { is_cluster_finish_ = true; is_cluster_ready_ = true; } @@ -119,7 +119,9 @@ void NodeManager::UpdateClusterState() { void NodeManager::CheckClusterTimeout() { if (total_node_num_ != nodes_info_.size()) { MS_LOG(WARNING) << "The cluster is not ready after " << ClusterConfig::cluster_available_timeout() - << " seconds,so finish the cluster"; + << " seconds,so finish the cluster, and change total node number from " << total_node_num_ << " to " + << nodes_info_.size(); + current_node_num_ = nodes_info_.size(); is_cluster_timeout_ = true; } } diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h index 3cc8eeb60f..27ee8c41f5 100644 --- a/mindspore/ccsrc/ps/core/node_manager.h +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -35,6 +35,7 @@ #include "proto/ps.pb.h" #include "ps/core/node.h" #include "utils/log_adapter.h" +#include "utils/convert_utils_base.h" namespace mindspore { namespace ps { @@ -47,6 +48,7 @@ class NodeManager { is_cluster_timeout_(false), is_node_timeout_(false), total_node_num_(0), + current_node_num_(-1), next_worker_rank_id_(-1), next_server_rank_id_(-1) {} virtual ~NodeManager() = default; @@ -75,6 +77,7 @@ class NodeManager { std::atomic is_cluster_timeout_; std::atomic is_node_timeout_; uint32_t total_node_num_; + int32_t current_node_num_; std::atomic next_worker_rank_id_; std::atomic next_server_rank_id_; // worker nodes and server nodes diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index 86488ea9ac..6f13236797 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -44,7 +44,7 @@ class SchedulerNode : public Node { SchedulerNode() : server_(nullptr), scheduler_thread_(nullptr), update_state_thread_(nullptr) {} ~SchedulerNode() override; - bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; + bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 8a76a5b7d4..e6658749ab 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -30,7 +30,7 @@ bool ServerNode::Start(const uint32_t &timeout) { StartHeartbeatTimer(client_to_scheduler_); if (!WaitForStart(timeout)) { - MS_LOG(EXCEPTION) << "Start Worker node timeout!"; + MS_LOG(ERROR) << "Start Server node timeout!"; } MS_LOG(INFO) << "The cluster is ready to use!"; diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 77c196902f..73d103840e 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -40,7 +40,7 @@ class ServerNode : public AbstractNode { ServerNode() : server_(nullptr), server_thread_(nullptr) {} ~ServerNode() override; - bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; + bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/mindspore/ccsrc/ps/core/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc index d607b819a0..b6528567ba 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -36,6 +36,8 @@ namespace mindspore { namespace ps { namespace core { event_base *TcpClient::event_base_ = nullptr; +std::mutex TcpClient::event_base_mutex_; +bool TcpClient::is_started_ = false; TcpClient::TcpClient(const std::string &address, std::uint16_t port) : event_timeout_(nullptr), @@ -60,10 +62,6 @@ TcpClient::~TcpClient() { event_free(event_timeout_); event_timeout_ = nullptr; } - if (event_base_) { - event_base_free(event_base_); - event_base_ = nullptr; - } } std::string TcpClient::GetServerAddress() const { return server_address_; } @@ -234,6 +232,13 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void } void TcpClient::Start() { + event_base_mutex_.lock(); + if (is_started_) { + event_base_mutex_.unlock(); + return; + } + is_started_ = true; + event_base_mutex_.unlock(); MS_EXCEPTION_IF_NULL(event_base_); int ret = event_base_dispatch(event_base_); MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; @@ -260,7 +265,7 @@ void TcpClient::SendMessage(const CommMessage &message) const { MS_EXCEPTION_IF_NULL(buffer_event_); size_t buf_size = message.ByteSizeLong(); std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), static_cast(buf_size)); + message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { MS_LOG(EXCEPTION) << "Event buffer add header failed!"; } diff --git a/mindspore/ccsrc/ps/core/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h index ce682e0d57..f34982a2bb 100644 --- a/mindspore/ccsrc/ps/core/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -35,6 +35,7 @@ #include "ps/core/cluster_config.h" #include "proto/comm.pb.h" #include "proto/ps.pb.h" +#include "utils/convert_utils_base.h" namespace mindspore { namespace ps { @@ -86,6 +87,9 @@ class TcpClient { OnTimer on_timer_callback_; static event_base *event_base_; + static std::mutex event_base_mutex_; + static bool is_started_; + std::mutex connection_mutex_; std::condition_variable connection_cond_; event *event_timeout_; diff --git a/mindspore/ccsrc/ps/core/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc index afafad7354..4d4466fd2c 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -32,7 +32,6 @@ namespace mindspore { namespace ps { namespace core { - void TcpConnection::InitConnection() { tcp_message_handler_.SetCallback([&](const CommMessage &message) { OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); @@ -58,7 +57,7 @@ void TcpConnection::SendMessage(const CommMessage &message) const { MS_EXCEPTION_IF_NULL(buffer_event_); size_t buf_size = message.ByteSizeLong(); std::vector serialized(buf_size); - message.SerializeToArray(serialized.data(), static_cast(buf_size)); + message.SerializeToArray(serialized.data(), SizeToInt(buf_size)); if (evbuffer_add(bufferevent_get_output(const_cast(buffer_event_)), &buf_size, sizeof(buf_size)) == -1) { MS_LOG(EXCEPTION) << "Event buffer add header failed!"; @@ -304,7 +303,7 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { if (read == -1) { MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; } - conn->OnReadHandler(read_buffer, static_cast(read)); + conn->OnReadHandler(read_buffer, IntToSize(read)); } } diff --git a/mindspore/ccsrc/ps/core/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h index c268775bfd..fcb51ff8de 100644 --- a/mindspore/ccsrc/ps/core/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -39,6 +39,7 @@ #include "ps/core/tcp_message_handler.h" #include "ps/core/cluster_config.h" #include "utils/log_adapter.h" +#include "utils/convert_utils_base.h" namespace mindspore { namespace ps { diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index fe6ac67539..9d2713d81e 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -40,7 +40,7 @@ class WorkerNode : public AbstractNode { WorkerNode() = default; ~WorkerNode() override; - bool Start(const uint32_t &timeout = kTimeoutInSeconds) override; + bool Start(const uint32_t &timeout = ClusterConfig::cluster_available_timeout()) override; bool Stop() override; bool Finish(const uint32_t &timeout = kTimeoutInSeconds) override; diff --git a/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc new file mode 100644 index 0000000000..9f72446006 --- /dev/null +++ b/tests/ut/cpp/ps/core/cluster_available_timeout_test.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "ps/core/node.h" +#include "ps/core/scheduler_node.h" + +namespace mindspore { +namespace ps { +namespace core { +class TestClusterAvailableTimeout : public UT::Common { + public: + TestClusterAvailableTimeout() = default; + ~TestClusterAvailableTimeout() override = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestClusterAvailableTimeout, TestClusterAvailableTimeout) { + ClusterConfig::Init(1, 1, std::make_unique("127.0.0.1"), 9999); + ClusterConfig::set_cluster_available_timeout(3); + SchedulerNode node; + node.Start(); + node.Finish(); + node.Stop(); +} +} // namespace core +} // namespace ps +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/core/common_util_test.cc b/tests/ut/cpp/ps/core/common_util_test.cc index ea67fc5856..f2b3bf2e60 100644 --- a/tests/ut/cpp/ps/core/common_util_test.cc +++ b/tests/ut/cpp/ps/core/common_util_test.cc @@ -27,6 +27,18 @@ class TestCommUtil : public UT::Common { public: TestCommUtil() = default; virtual ~TestCommUtil() = default; + struct MockRetry { + bool operator()(std::string mock) { + ++count_; + + if (count_ > 3) { + return true; + } + return false; + } + + int count_{0}; + }; void SetUp() override {} void TearDown() override {} @@ -47,6 +59,14 @@ TEST_F(TestCommUtil, ValidateRankId) { EXPECT_TRUE(CommUtil::ValidateRankId(NodeRole::SERVER, 1)); EXPECT_FALSE(CommUtil::ValidateRankId(NodeRole::SERVER, 2)); } + +TEST_F(TestCommUtil, Retry) { + bool const ret = CommUtil::Retry([]() -> bool { return false; }, 5, 100); + EXPECT_FALSE(ret); + MockRetry mock_retry; + bool const mock_ret = CommUtil::Retry([&] { return mock_retry("mock"); }, 5, 100); + EXPECT_TRUE(mock_ret); +} } // namespace core } // namespace ps } // namespace mindspore \ No newline at end of file