| @@ -14,6 +14,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | ||||
| @@ -40,9 +40,9 @@ class CommunicatorBase { | |||||
| using MessageCallback = std::function<void(std::shared_ptr<MessageHandler>)>; | using MessageCallback = std::function<void(std::shared_ptr<MessageHandler>)>; | ||||
| using HttpMsgCallback = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | using HttpMsgCallback = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | ||||
| using OnNodeEventCallback = std::function<void(const NodeEvent &)>; | using OnNodeEventCallback = std::function<void(const NodeEvent &)>; | ||||
| using TcpMsgCallBack = std::function<void(std::shared_ptr<core::TcpConnection> conn, | |||||
| using TcpMsgCallback = std::function<void(std::shared_ptr<core::TcpConnection> conn, | |||||
| std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size)>; | std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size)>; | ||||
| using CertainEventCallBack = std::function<void(void)>; | |||||
| using CertainEventCallback = std::function<void(void)>; | |||||
| CommunicatorBase() = default; | CommunicatorBase() = default; | ||||
| @@ -50,7 +50,7 @@ class TaskExecutor { | |||||
| bool Submit(Fun &&function, Args &&... args) { | bool Submit(Fun &&function, Args &&... args) { | ||||
| auto callee = std::bind(function, args...); | auto callee = std::bind(function, args...); | ||||
| std::function<void()> task = [callee]() -> void { callee(); }; | std::function<void()> task = [callee]() -> void { callee(); }; | ||||
| auto index = 0; | |||||
| size_t index = 0; | |||||
| for (size_t i = 0; i < submit_timeout_; i++) { | for (size_t i = 0; i < submit_timeout_; i++) { | ||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| if (task_num_ >= max_task_num_) { | if (task_num_ >= max_task_num_) { | ||||
| @@ -0,0 +1,91 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/core/communicator/tcp_communicator.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| bool TcpCommunicator::Start() { | |||||
| if (running_) { | |||||
| MS_LOG(INFO) << "The TCP communicator has already started."; | |||||
| return true; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| // Set message callback. For example, message of push/pull, etc. | |||||
| tcp_msg_callback_ = std::bind( | |||||
| [&](std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data, | |||||
| size_t size) -> void { | |||||
| TcpUserCommand user_command = static_cast<TcpUserCommand>(meta->user_cmd()); | |||||
| const std::string &msg_type = kUserCommandToMsgType.at(user_command); | |||||
| if (msg_type == "" || !msg_callbacks_[msg_type]) { | |||||
| MS_LOG(ERROR) << "Tcp server doesn't support command " << user_command << " " << msg_type; | |||||
| return; | |||||
| } | |||||
| MS_LOG(DEBUG) << "TcpCommunicator receives message for " << msg_type; | |||||
| std::shared_ptr<MessageHandler> tcp_msg_handler = | |||||
| std::make_shared<TcpMsgHandler>(server_node_, conn, meta, data, size); | |||||
| MS_EXCEPTION_IF_NULL(tcp_msg_handler); | |||||
| task_executor_->Submit(msg_callbacks_[msg_type], tcp_msg_handler); | |||||
| return; | |||||
| }, | |||||
| std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); | |||||
| server_node_->set_handler(tcp_msg_callback_); | |||||
| // Set event callback. For example, event of scaling out/in, etc. | |||||
| event_callback_ = std::bind( | |||||
| [&](const core::NodeEvent &event) -> void { | |||||
| MS_LOG(INFO) << "Server receives event of " << event; | |||||
| certain_event_to_callback_[event](); | |||||
| }, | |||||
| std::placeholders::_1); | |||||
| server_node_->set_event_callback(event_callback_); | |||||
| server_node_->Start(); | |||||
| running_ = true; | |||||
| running_thread_ = std::thread([&]() { | |||||
| while (running_) { | |||||
| std::this_thread::yield(); | |||||
| } | |||||
| }); | |||||
| return true; | |||||
| } | |||||
| bool TcpCommunicator::Stop() { | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| server_node_->Finish(); | |||||
| server_node_->Stop(); | |||||
| running_ = false; | |||||
| return true; | |||||
| } | |||||
| void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { | |||||
| msg_callbacks_.try_emplace(msg_type, cb); | |||||
| return; | |||||
| } | |||||
| void TcpCommunicator::RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb) { | |||||
| certain_event_to_callback_.try_emplace(event, event_cb); | |||||
| return; | |||||
| } | |||||
| ServerNode *TcpCommunicator::server_node() { return server_node_; } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||||
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_msg_handler.h"
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
// User commands carried by TCP messages. The numeric value travels in MessageMeta::user_cmd() and is cast back
// to this enum by TcpCommunicator when a message arrives.
enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };

// Maps each user command to the message type string under which callbacks are registered.
// Declared `inline` so that every translation unit including this header shares one map instead of building
// its own internal-linkage copy.
inline const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
  {TcpUserCommand::kPush, "push"},
  {TcpUserCommand::kPull, "pull"},
  {TcpUserCommand::kCount, "count"},
  {TcpUserCommand::kReachThreshold, "reachThreshold"},
  {TcpUserCommand::kResetCount, "resetCnt"},
  {TcpUserCommand::kGetValue, "getValue"},
  {TcpUserCommand::kPutValue, "putValue"},
  {TcpUserCommand::kCounterEvent, "counterEvent"},
};
| class TcpCommunicator : public CommunicatorBase { | |||||
| public: | |||||
| explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, ServerNode *node) | |||||
| : task_executor_(task_executor), | |||||
| running_(false), | |||||
| server_num_(0), | |||||
| worker_num_(0), | |||||
| scheduler_ip_(""), | |||||
| scheduler_port_(0), | |||||
| server_node_(node) {} | |||||
| ~TcpCommunicator() = default; | |||||
| bool Start() override; | |||||
| bool Stop() override; | |||||
| void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override; | |||||
| void RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb); | |||||
| ServerNode *server_node(); | |||||
| template <class T> | |||||
| bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command, | |||||
| std::shared_ptr<std::vector<unsigned char>> *output = nullptr) { | |||||
| const std::string &msg_str = pb_msg.SerializeAsString(); | |||||
| std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]); | |||||
| size_t dest_size = msg_str.size(); | |||||
| size_t src_size = msg_str.size(); | |||||
| auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret; | |||||
| } | |||||
| if (output != nullptr) { | |||||
| if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) { | |||||
| MS_LOG(ERROR) << "Query leader server whether count is enough failed."; | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) { | |||||
| MS_LOG(ERROR) << "Query leader server whether count is enough failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| private: | |||||
| std::shared_ptr<TaskExecutor> task_executor_; | |||||
| bool running_; | |||||
| TcpMsgCallback tcp_msg_callback_; | |||||
| OnNodeEventCallback event_callback_; | |||||
| // Each NodeEvent corresponds to a CertainEventCallback to process the event. | |||||
| std::map<core::NodeEvent, CertainEventCallback> certain_event_to_callback_; | |||||
| uint32_t server_num_; | |||||
| uint32_t worker_num_; | |||||
| std::string scheduler_ip_; | |||||
| uint16_t scheduler_port_; | |||||
| ServerNode *server_node_; | |||||
| }; | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||||
| @@ -14,6 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/core/server_node.h" | #include "ps/core/server_node.h" | ||||
| #include "ps/core/communicator/tcp_communicator.h" | |||||
| #include "ps/core/communicator/http_communicator.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -124,6 +126,36 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, | |||||
| server_->SendMessage(conn, meta, Protos::RAW, data, size); | server_->SendMessage(conn, meta, Protos::RAW, data, size); | ||||
| } | } | ||||
| std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port, | |||||
| const std::shared_ptr<TaskExecutor> &task_executor) { | |||||
| std::lock_guard<std::mutex> lock(communicator_mutex_); | |||||
| if (!communicators_.count(kHttpCommunicator)) { | |||||
| MS_LOG(INFO) << "Create Http communicator."; | |||||
| auto http_comm = std::make_shared<HttpCommunicator>(ip, port, task_executor); | |||||
| MS_EXCEPTION_IF_NULL(http_comm); | |||||
| communicators_[kHttpCommunicator] = http_comm; | |||||
| } | |||||
| return communicators_[kHttpCommunicator]; | |||||
| } | |||||
| std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip, | |||||
| std::int16_t scheduler_port, uint32_t worker_num, | |||||
| uint32_t server_num, | |||||
| const std::shared_ptr<TaskExecutor> &task_executor) { | |||||
| std::lock_guard<std::mutex> lock(communicator_mutex_); | |||||
| if (!communicators_.count(kTcpCommunicator)) { | |||||
| MS_LOG(INFO) << "Create Tcp communicator."; | |||||
| auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this); | |||||
| MS_EXCEPTION_IF_NULL(tcp_comm); | |||||
| ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port); | |||||
| MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num | |||||
| << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip | |||||
| << ", Scheduler port:" << scheduler_port; | |||||
| communicators_[kTcpCommunicator] = tcp_comm; | |||||
| } | |||||
| return communicators_[kTcpCommunicator]; | |||||
| } | |||||
| bool ServerNode::Stop() { | bool ServerNode::Stop() { | ||||
| MS_LOG(INFO) << "Stop server node!"; | MS_LOG(INFO) << "Stop server node!"; | ||||
| if (!is_already_stopped_.load()) { | if (!is_already_stopped_.load()) { | ||||
| @@ -24,18 +24,25 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include <unordered_map> | |||||
| #include "ps/core/cluster_metadata.h" | #include "ps/core/cluster_metadata.h" | ||||
| #include "ps/core/communicator/tcp_client.h" | #include "ps/core/communicator/tcp_client.h" | ||||
| #include "ps/core/communicator/tcp_server.h" | #include "ps/core/communicator/tcp_server.h" | ||||
| #include "ps/core/abstract_node.h" | #include "ps/core/abstract_node.h" | ||||
| #include "ps/core/communicator/task_executor.h" | |||||
| #include "ps/core/communicator/communicator_base.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| constexpr char kTcpCommunicator[] = "TCP"; | |||||
| constexpr char kHttpCommunicator[] = "HTTP"; | |||||
| class ServerNode : public AbstractNode { | class ServerNode : public AbstractNode { | ||||
| public: | public: | ||||
| ServerNode() : server_(nullptr), server_thread_(nullptr) {} | ServerNode() : server_(nullptr), server_thread_(nullptr) {} | ||||
| ~ServerNode() override = default; | ~ServerNode() override = default; | ||||
| bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; | bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; | ||||
| @@ -48,6 +55,12 @@ class ServerNode : public AbstractNode { | |||||
| void set_handler(const RequestHandler &handler); | void set_handler(const RequestHandler &handler); | ||||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | ||||
| std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, std::int16_t port, | |||||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||||
| std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port, | |||||
| uint32_t worker_num, uint32_t server_num, | |||||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||||
| private: | private: | ||||
| void CreateTcpServer(); | void CreateTcpServer(); | ||||
| void Initialize(); | void Initialize(); | ||||
| @@ -59,6 +72,8 @@ class ServerNode : public AbstractNode { | |||||
| std::shared_ptr<TcpServer> server_; | std::shared_ptr<TcpServer> server_; | ||||
| std::unique_ptr<std::thread> server_thread_; | std::unique_ptr<std::thread> server_thread_; | ||||
| RequestHandler request_handler_; | RequestHandler request_handler_; | ||||
| std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_; | |||||
| std::mutex communicator_mutex_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||