| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_CCSRC_DISTRIBUTED_CLUSTER_TOPOLOGY_COMMON_H_ | |||
| #include <string> | |||
| #include <chrono> | |||
| namespace mindspore { | |||
| namespace distributed { | |||
| @@ -41,6 +42,9 @@ constexpr char kEnvNodeId[] = "MS_NODE_ID"; | |||
| // For port number conversion. | |||
| static const int kDecimal = 10; | |||
| // The timeout for initializing the cluster topology. | |||
| static const std::chrono::milliseconds kTopoInitTimeout = std::chrono::milliseconds(1000 * 60 * 10); | |||
| // All kinds of messages sent between compute graph nodes and meta server node. | |||
| enum class MessageName { kRegistration, kHeartbeat }; | |||
| } // namespace topology | |||
| @@ -30,6 +30,11 @@ bool MetaServerNode::Initialize() { | |||
| // Init the TCP server. | |||
| RETURN_IF_FALSE_WITH_LOG(InitTCPServer(), "Failed to create the TCP server."); | |||
| start_time_ = Now(); | |||
| // Init the thread for monitoring the state of the cluster topo. | |||
| topo_monitor_ = std::thread(&MetaServerNode::UpdateTopoState, this); | |||
| return true; | |||
| } | |||
| @@ -39,6 +44,10 @@ bool MetaServerNode::Finalize() { | |||
| tcp_server_->Finalize(); | |||
| tcp_server_.reset(); | |||
| } | |||
| // Stop the topo monitor thread. | |||
| enable_monitor_ = false; | |||
| topo_monitor_.join(); | |||
| return true; | |||
| } | |||
| @@ -75,6 +84,7 @@ void MetaServerNode::ProcessRegister(const std::shared_ptr<MessageBase> &message | |||
| // Add the compute graph node into registered nodes. | |||
| const auto &node_id = registration.node_id(); | |||
| std::unique_lock<std::shared_mutex> lock(nodes_mutex_); | |||
| if (nodes_.find(node_id) == nodes_.end()) { | |||
| std::shared_ptr<ComputeGraphNodeState> node_state = std::make_shared<ComputeGraphNodeState>(node_id); | |||
| nodes_[node_id] = node_state; | |||
| @@ -93,13 +103,45 @@ void MetaServerNode::ProcessHeartbeat(const std::shared_ptr<MessageBase> &messag | |||
| // Update the state(timestamp) of this node. | |||
| const auto &node_id = heartbeat.node_id(); | |||
| if (nodes_.find(node_id) == nodes_.end()) { | |||
| std::shared_lock<std::shared_mutex> lock(nodes_mutex_); | |||
| if (nodes_.find(node_id) != nodes_.end()) { | |||
| auto &node = nodes_[node_id]; | |||
| time(&(node->last_update)); | |||
| } else { | |||
| MS_LOG(ERROR) << "Invalid node: " << node_id << "."; | |||
| } | |||
| } | |||
| void MetaServerNode::UpdateTopoState() { | |||
| while (enable_monitor_) { | |||
| if (topo_state_ == TopoState::kInitializing) { | |||
| // Set the state of topo to `kFailed` if the topology is still in process of initializtion but timed out. | |||
| if (ElapsedTime(start_time_) > kTopoInitTimeout) { | |||
| MS_LOG(ERROR) << "Failed to initialize the cluster topology after waiting for " << kTopoInitTimeout.count() | |||
| << " milliseconds."; | |||
| topo_state_ = TopoState::kFailed; | |||
| } | |||
| std::shared_lock<std::shared_mutex> lock(nodes_mutex_); | |||
| if (nodes_.size() == total_node_num_) { | |||
| MS_LOG(INFO) << "The cluster topology has been constructed successfully"; | |||
| topo_state_ = TopoState::kInitialized; | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "The cluster topology is in the process of constructing, current alive node num: (" | |||
| << nodes_.size() << "/" << total_node_num_ << ")"; | |||
| } | |||
| static const size_t interval = 3; | |||
| sleep(interval); | |||
| } | |||
| } | |||
| TopoState MetaServerNode::TopologyState() { return topo_state_; } | |||
| size_t MetaServerNode::GetAliveNodeNum() { | |||
| std::shared_lock<std::shared_mutex> lock(nodes_mutex_); | |||
| return nodes_.size(); | |||
| } | |||
| } // namespace topology | |||
| } // namespace cluster | |||
| } // namespace distributed | |||
| @@ -21,6 +21,9 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <thread> | |||
| #include <chrono> | |||
| #include <shared_mutex> | |||
| #include "distributed/cluster/topology/common.h" | |||
| #include "distributed/rpc/tcp/tcp_server.h" | |||
| #include "distributed/cluster/topology/node_base.h" | |||
| @@ -39,13 +42,32 @@ struct ComputeGraphNodeState { | |||
| time_t last_update; | |||
| }; | |||
| // Indicates the state of the cluster physical topology. | |||
| enum class TopoState { | |||
| // All the nodes of this cluster are in the process of starting up. | |||
| kInitializing = 0, | |||
| // All the nodes of this cluster has been started and registered to the meta server node successfully. | |||
| kInitialized = 1, | |||
| // The topo of this cluster failed to construct at specified time. | |||
| kFailed = 2 | |||
| }; | |||
| // The MetaServerNode is a separate process representing the meta server node which stores all the metadata and status | |||
| // of computation graph nodes. | |||
| class MetaServerNode : public NodeBase { | |||
| public: | |||
| explicit MetaServerNode(const std::string &node_id) : NodeBase(node_id) {} | |||
| explicit MetaServerNode(const std::string &node_id, const size_t &node_num) | |||
| : NodeBase(node_id), total_node_num_(node_num), topo_state_(TopoState::kInitializing), enable_monitor_(true) {} | |||
| ~MetaServerNode() override = default; | |||
| // Get the current topology state. | |||
| TopoState TopologyState(); | |||
| // Get the number of alive compute graph node. | |||
| size_t GetAliveNodeNum(); | |||
| bool Initialize() override; | |||
| bool Finalize() override; | |||
| @@ -62,6 +84,9 @@ class MetaServerNode : public NodeBase { | |||
| // Process the received heartbeat message sent from compute graph nodes. | |||
| void ProcessHeartbeat(const std::shared_ptr<MessageBase> &message); | |||
| // Maintain the state which is type of `TopoState` of this cluster topology. | |||
| void UpdateTopoState(); | |||
| // The meta server address used to manage the tcp server. | |||
| MetaServerAddress meta_server_addr_; | |||
| @@ -73,6 +98,23 @@ class MetaServerNode : public NodeBase { | |||
| // Stores the registered compute graph nodes. | |||
| std::map<std::string, std::shared_ptr<ComputeGraphNodeState>> nodes_; | |||
| mutable std::shared_mutex nodes_mutex_; | |||
| // The total legal number of compute graph nodes. | |||
| size_t total_node_num_; | |||
| // The state of the topology consisting of compute graph nodes. | |||
| TopoState topo_state_; | |||
| // The monitor thread for update the topo state. | |||
| std::thread topo_monitor_; | |||
| // The switch for the topo monitor thread. | |||
| std::atomic<bool> enable_monitor_; | |||
| // The start time of this meta server node. | |||
| std::chrono::high_resolution_clock::time_point start_time_; | |||
| }; | |||
| } // namespace topology | |||
| } // namespace cluster | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <chrono> | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "actor/msg.h" | |||
| @@ -66,6 +67,15 @@ __attribute__((unused)) static std::unique_ptr<MessageBase> CreateMessage(const | |||
| message->body = content; | |||
| return message; | |||
| } | |||
| __attribute__((unused)) static std::chrono::high_resolution_clock::time_point Now() { | |||
| return std::chrono::high_resolution_clock::now(); | |||
| } | |||
| __attribute__((unused)) static std::chrono::milliseconds ElapsedTime( | |||
| const std::chrono::high_resolution_clock::time_point &start_time) { | |||
| return std::chrono::duration_cast<std::chrono::milliseconds>(Now() - start_time); | |||
| } | |||
| } // namespace topology | |||
| } // namespace cluster | |||
| } // namespace distributed | |||
| @@ -40,12 +40,23 @@ TEST_F(TestDynamicNetworking, NodeRegister) { | |||
| common::SetEnv(kEnvMetaServerHost, server_host.c_str()); | |||
| common::SetEnv(kEnvMetaServerPort, server_port.c_str()); | |||
| MetaServerNode msn("meta_server_node"); | |||
| size_t total_node_num = 1; | |||
| MetaServerNode msn("meta_server_node", total_node_num); | |||
| ASSERT_TRUE(msn.Initialize()); | |||
| ComputeGraphNode cgn("compute_graph_node"); | |||
| ASSERT_TRUE(cgn.Initialize()); | |||
| size_t interval = 1; | |||
| size_t retry = 30; | |||
| while (((msn.GetAliveNodeNum() != total_node_num) || (msn.TopologyState() != TopoState::kInitialized)) && | |||
| (retry-- > 0)) { | |||
| sleep(interval); | |||
| } | |||
| ASSERT_EQ(total_node_num, msn.GetAliveNodeNum()); | |||
| ASSERT_EQ(TopoState::kInitialized, msn.TopologyState()); | |||
| cgn.Finalize(); | |||
| msn.Finalize(); | |||
| } | |||