From ddb54ae0aca04c78df4908a4e65bbcff62f221b9 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Wed, 28 Apr 2021 09:21:56 +0800 Subject: [PATCH] add tcp communicator --- mindspore/ccsrc/ps/CMakeLists.txt | 1 + .../ps/core/communicator/communicator_base.h | 4 +- .../ps/core/communicator/task_executor.h | 2 +- .../ps/core/communicator/tcp_communicator.cc | 91 ++++++++++++++ .../ps/core/communicator/tcp_communicator.h | 111 ++++++++++++++++++ mindspore/ccsrc/ps/core/server_node.cc | 32 +++++ mindspore/ccsrc/ps/core/server_node.h | 15 +++ 7 files changed, 253 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc create mode 100644 mindspore/ccsrc/ps/core/communicator/tcp_communicator.h diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index b15d7d082f..bdf008ca88 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -14,6 +14,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.h b/mindspore/ccsrc/ps/core/communicator/communicator_base.h index ce35420e75..5c10a54d83 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.h +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.h @@ -40,9 +40,9 @@ class CommunicatorBase { using MessageCallback = std::function)>; using HttpMsgCallback = std::function)>; using OnNodeEventCallback = std::function; - using TcpMsgCallBack = std::function conn, + using 
TcpMsgCallback = std::function conn, std::shared_ptr meta, DataPtr data, size_t size)>; - using CertainEventCallBack = std::function; + using CertainEventCallback = std::function; CommunicatorBase() = default; diff --git a/mindspore/ccsrc/ps/core/communicator/task_executor.h b/mindspore/ccsrc/ps/core/communicator/task_executor.h index 16b996f9ed..272484ffe2 100644 --- a/mindspore/ccsrc/ps/core/communicator/task_executor.h +++ b/mindspore/ccsrc/ps/core/communicator/task_executor.h @@ -50,7 +50,7 @@ class TaskExecutor { bool Submit(Fun &&function, Args &&... args) { auto callee = std::bind(function, args...); std::function task = [callee]() -> void { callee(); }; - auto index = 0; + size_t index = 0; for (size_t i = 0; i < submit_timeout_; i++) { std::unique_lock lock(mtx_); if (task_num_ >= max_task_num_) { diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc new file mode 100644 index 0000000000..756077f540 --- /dev/null +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "ps/core/communicator/tcp_communicator.h"
+#include <memory>
+#include <string>
+#include <thread>
+
+namespace mindspore {
+namespace ps {
+namespace core {
+/**
+ * Installs the message and event callbacks on the server node, starts the node and spawns the
+ * thread that stays alive until Stop() clears running_.
+ *
+ * @return Always true. When the communicator is already running nothing is re-registered.
+ */
+bool TcpCommunicator::Start() {
+  if (running_) {
+    MS_LOG(INFO) << "The TCP communicator has already started.";
+    return true;
+  }
+  MS_EXCEPTION_IF_NULL(server_node_);
+
+  // Set message callback. For example, message of push/pull, etc.
+  tcp_msg_callback_ = [this](std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta,
+                             DataPtr data, size_t size) -> void {
+    const TcpUserCommand user_command = static_cast<TcpUserCommand>(meta->user_cmd());
+    // The command value comes from the message meta, so it may not be in the table: use find()
+    // rather than at(), which would throw std::out_of_range before the error could be logged.
+    const auto type_iter = kUserCommandToMsgType.find(user_command);
+    if (type_iter == kUserCommandToMsgType.end()) {
+      MS_LOG(ERROR) << "Tcp server doesn't support command " << user_command;
+      return;
+    }
+    const std::string &msg_type = type_iter->second;
+    // find() rather than operator[]: a lookup must not insert an empty callback for this type.
+    const auto callback_iter = msg_callbacks_.find(msg_type);
+    if (callback_iter == msg_callbacks_.end() || !callback_iter->second) {
+      MS_LOG(ERROR) << "Tcp server doesn't support command " << user_command << " " << msg_type;
+      return;
+    }
+
+    MS_LOG(DEBUG) << "TcpCommunicator receives message for " << msg_type;
+    std::shared_ptr<MessageHandler> tcp_msg_handler =
+      std::make_shared<TcpMsgHandler>(server_node_, conn, meta, data, size);
+    MS_EXCEPTION_IF_NULL(tcp_msg_handler);
+    // The registered callback runs on the task executor, not on the thread delivering the message.
+    task_executor_->Submit(callback_iter->second, tcp_msg_handler);
+  };
+  server_node_->set_handler(tcp_msg_callback_);
+
+  // Set event callback. For example, event of scaling out/in, etc.
+  event_callback_ = [this](const core::NodeEvent &event) -> void {
+    MS_LOG(INFO) << "Server receives event of " << event;
+    // Events without a registered callback are skipped; invoking an empty std::function would
+    // throw std::bad_function_call.
+    const auto event_iter = certain_event_to_callback_.find(event);
+    if (event_iter == certain_event_to_callback_.end() || !event_iter->second) {
+      MS_LOG(INFO) << "No callback is registered for event " << event;
+      return;
+    }
+    event_iter->second();
+  };
+  server_node_->set_event_callback(event_callback_);
+
+  server_node_->Start();
+  running_ = true;
+  // This thread only yields in a loop until Stop() resets running_.
+  // NOTE(review): the loop spins instead of blocking; a condition variable would avoid the busy
+  // wait - confirm how running_thread_ is joined before changing it.
+  running_thread_ = std::thread([this]() {
+    while (running_) {
+      std::this_thread::yield();
+    }
+  });
+  return true;
+}
+
+/**
+ * Finishes and stops the server node, then clears running_ so the loop started by Start() exits.
+ * running_thread_ is not joined here.
+ *
+ * @return Always true.
+ */
+bool TcpCommunicator::Stop() {
+  MS_EXCEPTION_IF_NULL(server_node_);
+  server_node_->Finish();
+  server_node_->Stop();
+  running_ = false;
+  return true;
+}
+
+/**
+ * Registers cb for messages of type msg_type. try_emplace keeps an existing entry, so the first
+ * callback registered for a type wins.
+ */
+void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
+  msg_callbacks_.try_emplace(msg_type, cb);
+}
+
+/**
+ * Registers event_cb for the given node event. try_emplace keeps an existing entry, so the first
+ * callback registered for an event wins.
+ */
+void TcpCommunicator::RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb) {
+  certain_event_to_callback_.try_emplace(event, event_cb);
+}
+
+// Returns the server node pointer passed to the constructor.
+ServerNode *TcpCommunicator::server_node() { return server_node_; }
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h
new file mode 100644
index 0000000000..17be6827f9
--- /dev/null
+++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
+#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
+
+#include <atomic>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "proto/ps.pb.h"
+#include "ps/core/server_node.h"
+#include "ps/core/cluster_metadata.h"
+#include "ps/ps_context.h"
+#include "ps/core/communicator/task_executor.h"
+#include "ps/core/communicator/communicator_base.h"
+#include "ps/core/communicator/tcp_msg_handler.h"
+
+namespace mindspore {
+namespace ps {
+namespace core {
+// User commands carried in the meta of a TCP message.
+enum class TcpUserCommand { kPush, kPull, kCount, kReachThreshold, kResetCount, kGetValue, kPutValue, kCounterEvent };
+
+// Message type name for each TcpUserCommand; message callbacks are registered under these names.
+const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
+  {TcpUserCommand::kPush, "push"},           {TcpUserCommand::kPull, "pull"},
+  {TcpUserCommand::kCount, "count"},         {TcpUserCommand::kReachThreshold, "reachThreshold"},
+  {TcpUserCommand::kResetCount, "resetCnt"}, {TcpUserCommand::kGetValue, "getValue"},
+  {TcpUserCommand::kPutValue, "putValue"},   {TcpUserCommand::kCounterEvent, "counterEvent"},
+};
+
+// Communicator that receives messages and node events through a ServerNode and runs the
+// registered message callbacks on a TaskExecutor.
+class TcpCommunicator : public CommunicatorBase {
+ public:
+  // task_executor: executor that runs the message callbacks.
+  // node: server node used to send and receive; stored as a raw pointer and never deleted by this class.
+  explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, ServerNode *node)
+      : task_executor_(task_executor),
+        running_(false),
+        server_num_(0),
+        worker_num_(0),
+        scheduler_ip_(""),
+        scheduler_port_(0),
+        server_node_(node) {}
+  ~TcpCommunicator() = default;
+
+  bool Start() override;
+  bool Stop() override;
+
+  void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override;
+  void RegisterEventCallback(const core::NodeEvent &event, const CertainEventCallback &event_cb);
+
+  ServerNode *server_node();
+
+  // Serializes pb_msg and sends it to the server with the given rank id, tagged with command.
+  //
+  // pb_msg:  message exposing SerializeAsString().
+  // rank_id: rank of the destination server.
+  // command: user command, passed to ServerNode::Send as an int.
+  // output:  when non-null it is forwarded to ServerNode::Send (presumably filled with the
+  //          response - confirm against ServerNode::Send); when null the overload without an
+  //          output argument is used.
+  //
+  // Returns false when ServerNode::Send reports failure, true otherwise. Raises through
+  // MS_LOG(EXCEPTION) when memcpy_s fails.
+  template <class T>
+  bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command,
+                     std::shared_ptr<std::vector<unsigned char>> *output = nullptr) {
+    const std::string &msg_str = pb_msg.SerializeAsString();
+    std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]);
+    size_t dest_size = msg_str.size();
+    size_t src_size = msg_str.size();
+    // NOTE(review): a message that serializes to zero bytes reaches memcpy_s with a zero
+    // destination size - confirm the memcpy_s in use accepts that.
+    auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size);
+    if (ret != EOK) {
+      MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret;
+    }
+
+    if (output != nullptr) {
+      if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) {
+        MS_LOG(ERROR) << "Sending protobuf message to server " << rank_id << " failed.";
+        return false;
+      }
+    } else {
+      if (!server_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) {
+        MS_LOG(ERROR) << "Sending protobuf message to server " << rank_id << " failed.";
+        return false;
+      }
+    }
+    return true;
+  }
+
+ private:
+  std::shared_ptr<TaskExecutor> task_executor_;
+  // Written by Start()/Stop() and polled by the thread Start() spawns, hence atomic.
+  std::atomic<bool> running_;
+
+  TcpMsgCallback tcp_msg_callback_;
+  OnNodeEventCallback event_callback_;
+  // Each NodeEvent corresponds to a CertainEventCallback to process the event.
+  std::map<core::NodeEvent, CertainEventCallback> certain_event_to_callback_;
+
+  uint32_t server_num_;
+  uint32_t worker_num_;
+
+  std::string scheduler_ip_;
+  uint16_t scheduler_port_;
+
+  ServerNode *server_node_;
+};
+}  // namespace core
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc
index 2cdaa10d7d..e83bc371e9 100644
--- a/mindspore/ccsrc/ps/core/server_node.cc
+++ b/mindspore/ccsrc/ps/core/server_node.cc
@@ -14,6 +14,8 @@
  * limitations under the License.
*/
 #include "ps/core/server_node.h"
+#include "ps/core/communicator/tcp_communicator.h"
+#include "ps/core/communicator/http_communicator.h"
 
 namespace mindspore {
 namespace ps {
@@ -124,6 +126,36 @@ void ServerNode::ProcessCollectiveSendData(std::shared_ptr conn,
   server_->SendMessage(conn, meta, Protos::RAW, data, size);
 }
 
+std::shared_ptr ServerNode::GetOrCreateHttpComm(const std::string &ip, std::int16_t port,  // NOTE(review): int16_t cannot represent ports above 32767 - confirm.
+                                                const std::shared_ptr &task_executor) {  // Returns the communicator cached under kHttpCommunicator, creating it on the first call.
+  std::lock_guard lock(communicator_mutex_);  // Held for the whole lookup-or-create sequence.
+  if (!communicators_.count(kHttpCommunicator)) {  // First call: build the communicator from ip, port and task_executor.
+    MS_LOG(INFO) << "Create Http communicator.";
+    auto http_comm = std::make_shared(ip, port, task_executor);
+    MS_EXCEPTION_IF_NULL(http_comm);
+    communicators_[kHttpCommunicator] = http_comm;
+  }
+  return communicators_[kHttpCommunicator];  // Later calls return the cached instance; their arguments are not used.
+}
+/* Returns the communicator cached under kTcpCommunicator. On the first call it constructs the
+ * communicator from task_executor and this node, then initializes ClusterMetadata with the given
+ * worker/server numbers and scheduler address. Later calls return the cached instance and leave
+ * ClusterMetadata untouched. */
+std::shared_ptr ServerNode::GetOrCreateTcpComm(const std::string &scheduler_ip,
+                                               std::int16_t scheduler_port, uint32_t worker_num,  // NOTE(review): int16_t cannot represent ports above 32767 - confirm.
+                                               uint32_t server_num,
+                                               const std::shared_ptr &task_executor) {
+  std::lock_guard lock(communicator_mutex_);  // Held for the whole lookup-or-create sequence.
+  if (!communicators_.count(kTcpCommunicator)) {
+    MS_LOG(INFO) << "Create Tcp communicator.";
+    auto tcp_comm = std::make_shared(task_executor, this);  // The communicator keeps a pointer to this node.
+    MS_EXCEPTION_IF_NULL(tcp_comm);
+    ClusterMetadata::instance()->Init(worker_num, server_num, scheduler_ip, scheduler_port);
+    MS_LOG(INFO) << "Initialize cluster metadata for server. Worker number:" << worker_num
+                 << ", Server number:" << server_num << ", Scheduler ip:" << scheduler_ip
+                 << ", Scheduler port:" << scheduler_port;
+    communicators_[kTcpCommunicator] = tcp_comm;
+  }
+  return communicators_[kTcpCommunicator];
+}
+
 bool ServerNode::Stop() {
   MS_LOG(INFO) << "Stop server node!";
   if (!is_already_stopped_.load()) {
diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h
index 3622543a63..a1cb8f0f17 100644
--- a/mindspore/ccsrc/ps/core/server_node.h
+++ b/mindspore/ccsrc/ps/core/server_node.h
@@ -24,18 +24,25 @@
 #include
 #include
 #include
+#include
 
 #include "ps/core/cluster_metadata.h"
 #include "ps/core/communicator/tcp_client.h"
 #include "ps/core/communicator/tcp_server.h"
 #include "ps/core/abstract_node.h"
+#include "ps/core/communicator/task_executor.h"
+#include "ps/core/communicator/communicator_base.h"
 
 namespace mindspore {
 namespace ps {
 namespace core {
+constexpr char kTcpCommunicator[] = "TCP";  // Key of the TCP communicator in ServerNode::communicators_.
+constexpr char kHttpCommunicator[] = "HTTP";  // Key of the HTTP communicator in ServerNode::communicators_.
+
 class ServerNode : public AbstractNode {
  public:
   ServerNode() : server_(nullptr), server_thread_(nullptr) {}
+
   ~ServerNode() override = default;
 
   bool Start(const uint32_t &timeout = ClusterMetadata::instance()->cluster_available_timeout()) override;
@@ -48,6 +55,12 @@ class ServerNode : public AbstractNode {
   void set_handler(const RequestHandler &handler);
   void Response(std::shared_ptr conn, std::shared_ptr meta, const void *data, size_t size);
 
+  std::shared_ptr GetOrCreateHttpComm(const std::string &ip, std::int16_t port,  // Get-or-create accessor for the HTTP communicator.
+                                      const std::shared_ptr &task_executor);
+  std::shared_ptr GetOrCreateTcpComm(const std::string &scheduler_ip, std::int16_t scheduler_port,  // Get-or-create accessor for the TCP communicator.
+                                     uint32_t worker_num, uint32_t server_num,
+                                     const std::shared_ptr &task_executor);
+
  private:
   void CreateTcpServer();
   void Initialize();
@@ -59,6 +72,8 @@ class ServerNode : public AbstractNode {
   std::shared_ptr server_;
   std::unique_ptr server_thread_;
   RequestHandler request_handler_;
+  std::unordered_map> communicators_;  // Communicators created so far, keyed by kTcpCommunicator / kHttpCommunicator.
+  std::mutex communicator_mutex_;  // Locked by GetOrCreateHttpComm / GetOrCreateTcpComm while they access communicators_.
 };
 }  // namespace core
 }  // namespace ps