Merge pull request !28765 from zyli2020/parameter_server_cachefeature/build-system-rewrite
| @@ -16,7 +16,9 @@ if(NOT ENABLE_CPU OR WIN32) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/worker_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/ps_worker_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/server_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/ps_server_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_client.cc") | |||
| @@ -150,7 +150,7 @@ class AbstractNode : public Node { | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| protected: | |||
| void Register(const std::shared_ptr<TcpClient> &client); | |||
| virtual void Register(const std::shared_ptr<TcpClient> &client); | |||
| bool Heartbeat(const std::shared_ptr<TcpClient> &client); | |||
| void FetchServers(const std::shared_ptr<TcpClient> &client); | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| #include "ps/core/scheduler_node.h" | |||
| #include "ps/core/node_info.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| // This class is a derived class of SchedulerNode specialized for Parameter Server. It is used to rewrite the specific | |||
| // logic for Parameter Server mode training in SchedulerNode. For example, the Scheduler of Parameter Server will reject | |||
| // the registration request of alive nodes. | |||
| class PSSchedulerNode : public SchedulerNode { | |||
| public: | |||
| PSSchedulerNode() = default; | |||
| ~PSSchedulerNode() override = default; | |||
| private: | |||
| // Determine whether the registration request of the node should be rejected, the registration of the | |||
| // alive node should be rejected. | |||
| bool NeedRejectRegister(const NodeInfo &node_info) override { return node_info.is_alive; } | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_PS_SCHEDULER_NODE_H_ | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/ps_server_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void PSServerNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| MS_EXCEPTION_IF_NULL(message_meta); | |||
| message_meta->set_cmd(NodeCommand::REGISTER); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| register_message.set_role(node_info_.node_role_); | |||
| register_message.set_ip(node_info_.ip_); | |||
| register_message.set_port(node_info_.port_); | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!"; | |||
| if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), | |||
| register_message.ByteSizeLong())) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||
| } else { | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " send register success!"; | |||
| } | |||
| // Registrations of alive nodes or registrations request exceeding the set total number of nodes should be | |||
| // rejected, and the process exits after being rejected. | |||
| if (node_info_.rank_id_ == UINT32_MAX) { | |||
| MS_LOG(EXCEPTION) << "Register is rejected, and finish the node."; | |||
| } | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_ | |||
| #include <memory> | |||
| #include "ps/core/server_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| // This class is a derived class of ServerNode specialized for Parameter Server. It is used to rewrite the logic | |||
| // specific to Parameter Server mode training in ServerNode. For example, the registration of Parameter Server's Server | |||
| // node is synchronous. | |||
| class PSServerNode : public ServerNode { | |||
| public: | |||
| PSServerNode() = default; | |||
| ~PSServerNode() override = default; | |||
| private: | |||
| // The Server node registers to the Scheduler node, and the registration of the Server node of the Parameter Server | |||
| // is synchronous. | |||
| void Register(const std::shared_ptr<TcpClient> &client) override; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_PS_SERVER_NODE_H_ | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/ps_worker_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void PSWorkerNode::Register(const std::shared_ptr<TcpClient> &client) { | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| auto message_meta = std::make_shared<MessageMeta>(); | |||
| MS_EXCEPTION_IF_NULL(message_meta); | |||
| message_meta->set_cmd(NodeCommand::REGISTER); | |||
| message_meta->set_rank_id(node_info_.rank_id_); | |||
| RegisterMessage register_message; | |||
| register_message.set_node_id(node_info_.node_id_); | |||
| register_message.set_role(node_info_.node_role_); | |||
| register_message.set_ip(node_info_.ip_); | |||
| register_message.set_port(node_info_.port_); | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " begin to register to the scheduler!"; | |||
| if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, register_message.SerializeAsString().data(), | |||
| register_message.ByteSizeLong())) { | |||
| MS_LOG(EXCEPTION) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " register timeout!"; | |||
| } else { | |||
| MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << " the node id:" << node_info_.node_id_ << " send register success!"; | |||
| } | |||
| // Registrations of alive nodes or registrations request exceeding the set total number of nodes should be | |||
| // rejected, and the process exits after being rejected. | |||
| if (node_info_.rank_id_ == UINT32_MAX) { | |||
| MS_LOG(EXCEPTION) << "Register is rejected, and finish the node."; | |||
| } | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_ | |||
| #include <memory> | |||
| #include "ps/core/worker_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| // This class is a derived class of WorkerNode specialized for Parameter Server. It is used to rewrite the logic | |||
| // specific to Parameter Server mode training in WorkerNode. For example, the registration of Parameter Server's Worker | |||
| // node is synchronous. | |||
| class PSWorkerNode : public WorkerNode { | |||
| public: | |||
| PSWorkerNode() = default; | |||
| ~PSWorkerNode() override = default; | |||
| private: | |||
| // The Worker node registers to the Scheduler node, and the registration of the Worker node of the Parameter Server | |||
| // is synchronous. | |||
| void Register(const std::shared_ptr<TcpClient> &client) override; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_PS_WORKER_NODE_H_ | |||
| @@ -256,6 +256,20 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr<TcpServer> &server, | |||
| MS_LOG(INFO) << "The node id:" << node_id << " is registering to scheduler."; | |||
| client_mutex_.lock(); | |||
| if (node_manager_.IsNodeRegistered(node_id)) { | |||
| NodeInfo node_info = node_manager_.QueryNodeInfo(node_id); | |||
| if (NeedRejectRegister(node_info)) { | |||
| MS_LOG(WARNING) << "The node(id: " << node_id << ") is alive, register is rejected!"; | |||
| RegisterRespMessage register_rejected_message; | |||
| register_rejected_message.set_node_id(node_id); | |||
| register_rejected_message.set_rank_id(UINT32_MAX); | |||
| if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_rejected_message.SerializeAsString().data(), | |||
| register_rejected_message.ByteSizeLong())) { | |||
| MS_LOG(WARNING) << "Server response rejected message failed."; | |||
| } | |||
| client_mutex_.unlock(); | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "The node id is registered."; | |||
| if (connected_nodes_.count(node_id)) { | |||
| (void)connected_nodes_.erase(node_id); | |||
| @@ -103,6 +103,10 @@ class SchedulerNode : public Node { | |||
| void ProcessSendEvent(const std::shared_ptr<TcpServer> &server, const std::shared_ptr<TcpConnection> &conn, | |||
| const std::shared_ptr<MessageMeta> &meta, const void *data, size_t size); | |||
| // Determine whether the registration request of the node should be rejected, the registration of the | |||
| // alive node should be rejected. | |||
| virtual bool NeedRejectRegister(const NodeInfo &node_info) { return false; } | |||
| // After scheduler collects all registered message, it actively sends finish to the node connected by the client. | |||
| void SendMetadata(const std::shared_ptr<TcpClient> &client, uint32_t rank_id); | |||
| // After scheduler collects all finish message, it actively sends finish to the node connected by the client. | |||
| @@ -29,7 +29,7 @@ static const uint32_t kCPUCoreNum = std::thread::hardware_concurrency(); | |||
| void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; | |||
| server_node_ = std::make_shared<core::ServerNode>(); | |||
| server_node_ = std::make_shared<core::PSServerNode>(); | |||
| MS_LOG(INFO) << "PServer connected successfully."; | |||
| if (!PSContext::instance()->is_server()) { | |||
| @@ -61,7 +61,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/core/ps_server_node.h" | |||
| #include "ps/core/node.h" | |||
| namespace mindspore { | |||
| @@ -24,15 +24,15 @@ void Scheduler::Run() { | |||
| PSContext::instance()->cluster_config().scheduler_port = PSContext::instance()->scheduler_port(); | |||
| PSContext::instance()->cluster_config().initial_worker_num = PSContext::instance()->initial_worker_num(); | |||
| PSContext::instance()->cluster_config().initial_server_num = PSContext::instance()->initial_server_num(); | |||
| if (!scheduler_node_.Start()) { | |||
| if (!scheduler_node_->Start()) { | |||
| MS_LOG(WARNING) << "Scheduler start failed."; | |||
| } | |||
| if (!scheduler_node_.Finish()) { | |||
| if (!scheduler_node_->Finish()) { | |||
| MS_LOG(WARNING) << "Scheduler finis failed."; | |||
| } | |||
| if (!scheduler_node_.Stop()) { | |||
| if (!scheduler_node_->Stop()) { | |||
| MS_LOG(WARNING) << "Scheduler stop failed."; | |||
| } | |||
| exit(1); | |||
| @@ -17,7 +17,9 @@ | |||
| #ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_ | |||
| #define MINDSPORE_CCSRC_PS_SCHEDULER_H_ | |||
| #include <memory> | |||
| #include "ps/core/scheduler_node.h" | |||
| #include "ps/core/ps_scheduler_node.h" | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| @@ -33,11 +35,22 @@ class Scheduler { | |||
| void Run(); | |||
| private: | |||
| Scheduler() = default; | |||
| Scheduler() { | |||
| if (scheduler_node_ == nullptr) { | |||
| bool is_fl_mode = PSContext::instance()->server_mode() == ps::kServerModeFL || | |||
| PSContext::instance()->server_mode() == ps::kServerModeHybrid; | |||
| if (is_fl_mode) { | |||
| scheduler_node_ = std::make_unique<core::SchedulerNode>(); | |||
| } else { | |||
| scheduler_node_ = std::make_unique<core::PSSchedulerNode>(); | |||
| } | |||
| } | |||
| } | |||
| ~Scheduler() = default; | |||
| Scheduler(const Scheduler &) = delete; | |||
| Scheduler &operator=(const Scheduler &) = delete; | |||
| core::SchedulerNode scheduler_node_; | |||
| std::unique_ptr<core::SchedulerNode> scheduler_node_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -35,7 +35,7 @@ | |||
| #include "ps/constants.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/core/worker_node.h" | |||
| #include "ps/core/ps_worker_node.h" | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| @@ -135,7 +135,7 @@ class Worker { | |||
| std::map<size_t, int64_t> key_to_optimId_; | |||
| std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_; | |||
| std::map<std::string, bool> param_to_init_in_server_; | |||
| core::WorkerNode worker_node_; | |||
| core::PSWorkerNode worker_node_; | |||
| EmbeddingPartitioner lookup_partitioner_; | |||
| KVPartitioner sparse_partitioner_; | |||