| @@ -14,6 +14,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | |||
| @@ -40,9 +40,9 @@ class CommunicatorBase { | |||
| using MessageCallback = std::function<void(std::shared_ptr<MessageHandler>)>; | |||
| using HttpMsgCallback = std::function<void(std::shared_ptr<HttpMessageHandler>)>; | |||
| using OnNodeEventCallback = std::function<void(const NodeEvent &)>; | |||
| using TcpMsgCallBack = std::function<void(std::shared_ptr<core::TcpConnection> conn, | |||
| using TcpMsgCallback = std::function<void(std::shared_ptr<core::TcpConnection> conn, | |||
| std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size)>; | |||
| using CertainEventCallBack = std::function<void(void)>; | |||
| using CertainEventCallback = std::function<void(void)>; | |||
| CommunicatorBase() = default; | |||
| @@ -50,7 +50,7 @@ class TaskExecutor { | |||
| bool Submit(Fun &&function, Args &&... args) { | |||
| auto callee = std::bind(function, args...); | |||
| std::function<void()> task = [callee]() -> void { callee(); }; | |||
| auto index = 0; | |||
| size_t index = 0; | |||
| for (size_t i = 0; i < submit_timeout_; i++) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| if (task_num_ >= max_task_num_) { | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| bool TcpCommunicator::Start() { | |||
| if (running_) { | |||
| MS_LOG(INFO) << "The TCP communicator has already started."; | |||
| return true; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| // Set message callback. For example, message of push/pull, etc. | |||
| tcp_msg_callback_ = std::bind( | |||
| [&](std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data, | |||
| size_t size) -> void { | |||
| TcpUserCommand user_command = static_cast<TcpUserCommand>(meta->user_cmd()); | |||
| const std::string &msg_type = kUserCommandToMsgType.at(user_command); | |||
| if (msg_type == "" || !msg_callbacks_[msg_type]) { | |||
| MS_LOG(ERROR) << "Tcp server doesn't support command " << user_command << " " << msg_type; | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "TcpCommunicator receives message for " << msg_type; | |||
| std::shared_ptr<MessageHandler> tcp_msg_handler = | |||
| std::make_shared<TcpMsgHandler>(server_node_, conn, meta, data, size); | |||
| MS_EXCEPTION_IF_NULL(tcp_msg_handler); | |||
| task_executor_->Submit(msg_callbacks_[msg_type], tcp_msg_handler); | |||
| return; | |||
| }, | |||
| std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); | |||
| server_node_->set_handler(tcp_msg_callback_); | |||
| // Set event callback. For example, event of scaling out/in, etc. | |||
| event_callback_ = std::bind( | |||
| [&](const core::NodeEvent &event) -> void { | |||
| MS_LOG(INFO) << "Server receives event of " << event; | |||
| certain_event_to_callback_[event](); | |||
| }, | |||
| std::placeholders::_1); | |||
| server_node_->set_event_callback(event_callback_); | |||
| server_node_->Start(); | |||
| running_ = true; | |||
| running_thread_ = std::thread([&]() { | |||
| while (running_) { | |||
| std::this_thread::yield(); | |||
| } | |||
| }); | |||
| return true; | |||
| } | |||
| bool TcpCommunicator::Stop() { | |||
| MS_EXCEPTION_IF_NULL(server_node_); | |||
| server_node_->Finish(); | |||
| server_node_->Stop(); | |||
| running_ = false; | |||
| return true; | |||
| } | |||
| void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { | |||
| msg_callbacks_.try_emplace(msg_type, cb); | |||
| return; | |||
| } | |||
| void TcpCommunicator::RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb) { | |||
| certain_event_to_callback_.try_emplace(event, event_cb); | |||
| return; | |||
| } | |||
| ServerNode *TcpCommunicator::server_node() { return server_node_; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,111 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/task_executor.h"
#include "ps/core/communicator/communicator_base.h"
#include "ps/core/communicator/tcp_msg_handler.h"
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
// User commands carried by TCP messages between servers.
// The enumerators are converted to and from int when messages are sent and received, so their
// order (and therefore their numeric values) must stay stable.
enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };

// Maps each user command to the message type string under which callbacks are registered.
// Declared `inline` so that every translation unit including this header shares one map
// instead of constructing its own copy.
inline const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
  {TcpUserCommand::kPush, "push"},           {TcpUserCommand::kPull, "pull"},
  {TcpUserCommand::kCount, "count"},         {TcpUserCommand::kReachThreshold, "reachThreshold"},
  {TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"},
  {TcpUserCommand::kPutValue, "putValue"},   {TcpUserCommand::kCounterEvent, "counterEvent"},
};
| class TcpCommunicator : public CommunicatorBase { | |||
| public: | |||
| explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, ServerNode *node) | |||
| : task_executor_(task_executor), | |||
| running_(false), | |||
| server_num_(0), | |||
| worker_num_(0), | |||
| scheduler_ip_(""), | |||
| scheduler_port_(0), | |||
| server_node_(node) {} | |||
| ~TcpCommunicator() = default; | |||
| bool Start() override; | |||
| bool Stop() override; | |||
| void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override; | |||
| void RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb); | |||
| ServerNode *server_node(); | |||
| template <class T> | |||
| bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command, | |||
| std::shared_ptr<std::vector<unsigned char>> *output = nullptr) { | |||
| const std::string &msg_str = pb_msg.SerializeAsString(); | |||
| std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]); | |||
| size_t dest_size = msg_str.size(); | |||
| size_t src_size = msg_str.size(); | |||
| auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret; | |||
| } | |||
| if (output != nullptr) { | |||
| if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) { | |||
| MS_LOG(ERROR) << "Query leader server whether count is enough failed."; | |||
| return false; | |||
| } | |||
| } else { | |||
| if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) { | |||
| MS_LOG(ERROR) << "Query leader server whether count is enough failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| private: | |||
| std::shared_ptr<TaskExecutor> task_executor_; | |||
| bool running_; | |||
| TcpMsgCallback tcp_msg_callback_; | |||
| OnNodeEventCallback event_callback_; | |||
| // Each NodeEvent corresponds to a CertainEventCallback to process the event. | |||
| std::map<core::NodeEvent, CertainEventCallback> certain_event_to_callback_; | |||
| uint32_t server_num_; | |||
| uint32_t worker_num_; | |||
| std::string scheduler_ip_; | |||
| uint16_t scheduler_port_; | |||
| ServerNode *server_node_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ | |||
| @@ -14,6 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/core/communicator/tcp_communicator.h" | |||
| #include "ps/core/communicator/http_communicator.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -124,6 +126,36 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr<TcpConnection> conn, | |||
| server_->SendMessage(conn, meta, Protos::RAW, data, size); | |||
| } | |||
| std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port, | |||
| const std::shared_ptr<TaskExecutor> &task_executor) { | |||
| std::lock_guard<std::mutex> lock(communicator_mutex_); | |||
| if (!communicators_.count(kHttpCommunicator)) { | |||
| MS_LOG(INFO) << "Create Http communicator."; | |||
| auto http_comm = std::make_shared<HttpCommunicator>(ip, port, task_executor); | |||
| MS_EXCEPTION_IF_NULL(http_comm); | |||
| communicators_[kHttpCommunicator] = http_comm; | |||
| } | |||
| return communicators_[kHttpCommunicator]; | |||
| } | |||
| std::shared_ptr<CommunicatorBase> ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip, | |||
| std::int16_t scheduler_port, uint32_t worker_num, | |||
| uint32_t server_num, | |||
| const std::shared_ptr<TaskExecutor> &task_executor) { | |||
| std::lock_guard<std::mutex> lock(communicator_mutex_); | |||
| if (!communicators_.count(kTcpCommunicator)) { | |||
| MS_LOG(INFO) << "Create Tcp communicator."; | |||
| auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this); | |||
| MS_EXCEPTION_IF_NULL(tcp_comm); | |||
| ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port); | |||
| MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num | |||
| << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip | |||
| << ", Scheduler port:" << scheduler_port; | |||
| communicators_[kTcpCommunicator] = tcp_comm; | |||
| } | |||
| return communicators_[kTcpCommunicator]; | |||
| } | |||
| bool ServerNode::Stop() { | |||
| MS_LOG(INFO) << "Stop server node!"; | |||
| if (!is_already_stopped_.load()) { | |||
| @@ -24,18 +24,25 @@ | |||
| #include <thread> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ps/core/cluster_metadata.h" | |||
| #include "ps/core/communicator/tcp_client.h" | |||
| #include "ps/core/communicator/tcp_server.h" | |||
| #include "ps/core/abstract_node.h" | |||
| #include "ps/core/communicator/task_executor.h" | |||
| #include "ps/core/communicator/communicator_base.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
// Keys under which a server node stores its communicators.
// Declared `inline` so that all translation units including this header refer to the same
// objects.
inline constexpr char kTcpCommunicator[] = "TCP";
inline constexpr char kHttpCommunicator[] = "HTTP";
| class ServerNode : public AbstractNode { | |||
| public: | |||
| ServerNode() : server_(nullptr), server_thread_(nullptr) {} | |||
| ~ServerNode() override = default; | |||
| bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override; | |||
| @@ -48,6 +55,12 @@ class ServerNode : public AbstractNode { | |||
| void set_handler(const RequestHandler &handler); | |||
| void Response(std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| std::shared_ptr<CommunicatorBase> GetOrCreateHttpComm(const std::string &ip, std::int16_t port, | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| std::shared_ptr<CommunicatorBase> GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port, | |||
| uint32_t worker_num, uint32_t server_num, | |||
| const std::shared_ptr<TaskExecutor> &task_executor); | |||
| private: | |||
| void CreateTcpServer(); | |||
| void Initialize(); | |||
| @@ -59,6 +72,8 @@ class ServerNode : public AbstractNode { | |||
| std::shared_ptr<TcpServer> server_; | |||
| std::unique_ptr<std::thread> server_thread_; | |||
| RequestHandler request_handler_; | |||
| std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_; | |||
| std::mutex communicator_mutex_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||