Merge pull request !7911 from anancds/kv-patchtags/v1.1.0
| @@ -100,6 +100,11 @@ message("onnx proto path is :" ${ONNX_PROTO}) | |||
| ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) | |||
| list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) | |||
| include_directories("${CMAKE_BINARY_DIR}/ps/comm") | |||
| file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto") | |||
| ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) | |||
| list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) | |||
| if (ENABLE_DEBUGGER) | |||
| # debugger: compile proto files | |||
| include_directories("${CMAKE_BINARY_DIR}/debug/debugger") | |||
| @@ -290,7 +295,7 @@ if (CMAKE_SYSTEM_NAME MATCHES "Windows") | |||
| target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) | |||
| else () | |||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore mindspore::pslite mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| if (${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| endif() | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| import "google/protobuf/any.proto"; | |||
| package mindspore.ps; | |||
| option optimize_for = LITE_RUNTIME; | |||
| message MessageMeta { | |||
| // hostname or ip | |||
| string hostname = 1; | |||
| // the port of this node | |||
| int32 port = 2; | |||
| // the command of this message,for example: register、heartbeat、data | |||
| int32 cmd = 3; | |||
| // the timestamp of this message | |||
| int32 timestamp = 4; | |||
| // data type of message | |||
| repeated int32 data_type = 5 [packed = true]; | |||
| // message.data_size | |||
| int32 data_size = 6; | |||
| } | |||
| message CommMessage { | |||
| MessageMeta pb_meta = 1; | |||
| bytes data = 2; | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| message KVMessage { | |||
| repeated int32 keys = 1; | |||
| repeated float values = 2; | |||
| } | |||
| message HeartBeatMessage { | |||
| // *.*.*.*:port | |||
| repeated string host_and_port = 1; | |||
| } | |||
| @@ -36,18 +36,16 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| TcpClient::TcpClient(std::string address, std::uint16_t port) | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| : event_base_(nullptr), | |||
| event_timeout_(nullptr), | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port) { | |||
| message_handler_.SetCallback([this](const void *buf, size_t num) { | |||
| if (buf == nullptr) { | |||
| if (disconnected_callback_) disconnected_callback_(*this, 200); | |||
| Stop(); | |||
| message_handler_.SetCallback([this](const CommMessage &message) { | |||
| if (message_callback_) { | |||
| message_callback_(*this, message); | |||
| } | |||
| if (message_callback_) message_callback_(*this, buf, num); | |||
| }); | |||
| } | |||
| @@ -63,7 +61,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco | |||
| timeout_callback_ = timeout; | |||
| } | |||
| void TcpClient::InitTcpClient() { | |||
| void TcpClient::Init() { | |||
| if (buffer_event_) { | |||
| return; | |||
| } | |||
| @@ -139,7 +137,7 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { | |||
| void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| tcp_client->InitTcpClient(); | |||
| tcp_client->Init(); | |||
| } | |||
| void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||
| @@ -150,10 +148,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| char read_buffer[4096]; | |||
| int read = 0; | |||
| while ((read = EVBUFFER_LENGTH(input)) > 0) { | |||
| if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) { | |||
| while (EVBUFFER_LENGTH(input) > 0) { | |||
| int read = evbuffer_remove(input, &read_buffer, sizeof(read_buffer)); | |||
| if (read == -1) { | |||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | |||
| } | |||
| tcp_client->OnReadHandler(read_buffer, read); | |||
| @@ -196,25 +194,38 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||
| void TcpClient::Start() { | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| int ret = event_base_dispatch(event_base_); | |||
| if (ret == 0) { | |||
| MS_LOG(INFO) << "Event base dispatch success!"; | |||
| } else if (ret == 1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; | |||
| } else if (ret == -1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| << "Event base dispatch failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| void TcpClient::StartWithNoBlock() { | |||
| MS_LOG(INFO) << "Start tcp client with no block!"; | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | |||
| } | |||
| void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } | |||
| void TcpClient::SetMessageCallback(const OnMessage &cb) { message_callback_ = cb; } | |||
| void TcpClient::SendMessage(const void *buf, size_t num) const { | |||
| void TcpClient::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add failed!"; | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), &buf_size, sizeof(buf_size)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||
| } | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), serialized.data(), buf_size) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,10 @@ | |||
| #include <event2/bufferevent.h> | |||
| #include <functional> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "proto/comm.pb.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -30,24 +34,25 @@ namespace comm { | |||
| class TcpClient { | |||
| public: | |||
| using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnConnected = std::function<void(const TcpClient &)>; | |||
| using OnDisconnected = std::function<void(const TcpClient &, int)>; | |||
| using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnTimeout = std::function<void(const TcpClient &)>; | |||
| using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; | |||
| explicit TcpClient(std::string address, std::uint16_t port); | |||
| explicit TcpClient(const std::string &address, std::uint16_t port); | |||
| virtual ~TcpClient(); | |||
| std::string GetServerAddress() const; | |||
| void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, | |||
| const OnTimeout &timeout); | |||
| void InitTcpClient(); | |||
| void Init(); | |||
| void StartWithDelay(int seconds); | |||
| void Stop(); | |||
| void ReceiveMessage(const OnMessage &cb); | |||
| void SendMessage(const void *buf, size_t num) const; | |||
| void Start(); | |||
| void StartWithNoBlock(); | |||
| void SetMessageCallback(const OnMessage &cb); | |||
| void SendMessage(const CommMessage &message) const; | |||
| protected: | |||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | |||
| @@ -57,8 +62,9 @@ class TcpClient { | |||
| virtual void OnReadHandler(const void *buf, size_t num); | |||
| private: | |||
| TcpMessageHandler message_handler_; | |||
| OnMessage message_callback_; | |||
| TcpMessageHandler message_handler_; | |||
| OnConnected connected_callback_; | |||
| OnDisconnected disconnected_callback_; | |||
| OnRead read_callback_; | |||
| @@ -71,6 +77,7 @@ class TcpClient { | |||
| std::string server_address_; | |||
| std::uint16_t server_port_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -15,6 +15,8 @@ | |||
| */ | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include <arpa/inet.h> | |||
| #include <iostream> | |||
| #include <utility> | |||
| @@ -22,15 +24,55 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); } | |||
| void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | |||
| void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| auto buffer_data = reinterpret_cast<const unsigned char *>(buffer); | |||
| while (num > 0) { | |||
| if (remaining_length_ == 0) { | |||
| for (int i = 0; i < 4 && num > 0; ++i) { | |||
| header_[++header_index_] = *(buffer_data + i); | |||
| --num; | |||
| if (header_index_ == 3) { | |||
| message_length_ = *reinterpret_cast<const uint32_t *>(header_); | |||
| message_length_ = ntohl(message_length_); | |||
| remaining_length_ = message_length_; | |||
| message_buffer_.reset(new unsigned char[remaining_length_]); | |||
| buffer_data += i; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (remaining_length_ > 0) { | |||
| uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||
| remaining_length_ -= copy_len; | |||
| num -= copy_len; | |||
| if (message_callback_) { | |||
| message_callback_(buffer, num); | |||
| int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len); | |||
| last_copy_len_ += copy_len; | |||
| buffer_data += copy_len; | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| if (remaining_length_ == 0) { | |||
| CommMessage pb_message; | |||
| pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_); | |||
| if (message_callback_) { | |||
| message_callback_(pb_message); | |||
| } | |||
| message_buffer_.reset(); | |||
| message_buffer_ = nullptr; | |||
| header_index_ = 0; | |||
| last_copy_len_ = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -19,26 +19,43 @@ | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "utils/log_adapter.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| using messageReceive = std::function<void(const void *buffer, size_t len)>; | |||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||
| class TcpMessageHandler { | |||
| public: | |||
| TcpMessageHandler() = default; | |||
| TcpMessageHandler() | |||
| : is_parsed_(false), | |||
| message_buffer_(nullptr), | |||
| message_length_(0), | |||
| remaining_length_(0), | |||
| header_index_(-1), | |||
| last_copy_len_(0) {} | |||
| virtual ~TcpMessageHandler() = default; | |||
| void SetCallback(messageReceive cb); | |||
| void SetCallback(const messageReceive &cb); | |||
| void ReceiveMessage(const void *buffer, size_t num); | |||
| private: | |||
| messageReceive message_callback_; | |||
| bool is_parsed_; | |||
| std::unique_ptr<unsigned char> message_buffer_; | |||
| size_t message_length_; | |||
| uint32_t remaining_length_; | |||
| char header_[4]; | |||
| int header_index_; | |||
| uint32_t last_copy_len_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| @@ -33,16 +33,12 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| buffer_event_ = const_cast<struct bufferevent *>(bev); | |||
| fd_ = fd; | |||
| server_ = const_cast<TcpServer *>(server); | |||
| tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) { | |||
| OnServerReceiveMessage message_callback = server->GetServerReceiveMessage(); | |||
| if (message_callback) message_callback(*server, *this, buf, num); | |||
| void TcpConnection::InitConnection() { | |||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | |||
| OnServerReceiveMessage on_server_receive = server_->GetServerReceive(); | |||
| if (on_server_receive) { | |||
| on_server_receive(*server_, *this, message); | |||
| } | |||
| }); | |||
| } | |||
| @@ -54,11 +50,26 @@ void TcpConnection::SendMessage(const void *buffer, size_t num) const { | |||
| } | |||
| } | |||
| TcpServer *TcpConnection::GetServer() const { return server_; } | |||
| TcpServer *TcpConnection::GetServer() const { return const_cast<TcpServer *>(server_); } | |||
| evutil_socket_t TcpConnection::GetFd() const { return fd_; } | |||
| const evutil_socket_t &TcpConnection::GetFd() const { return fd_; } | |||
| TcpServer::TcpServer(std::string address, std::uint16_t port) | |||
| void TcpConnection::SendMessage(const CommMessage &message) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| std::vector<unsigned char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), &buf_size, | |||
| sizeof(buf_size)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add header failed!"; | |||
| } | |||
| if (evbuffer_add(bufferevent_get_output(const_cast<struct bufferevent *>(buffer_event_)), serialized.data(), | |||
| buf_size) == -1) { | |||
| MS_LOG(EXCEPTION) << "Event buffer add protobuf data failed!"; | |||
| } | |||
| } | |||
| TcpServer::TcpServer(const std::string &address, std::uint16_t port) | |||
| : base_(nullptr), | |||
| signal_event_(nullptr), | |||
| listener_(nullptr), | |||
| @@ -74,7 +85,7 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon | |||
| this->client_accept_ = client_accept; | |||
| } | |||
| void TcpServer::InitServer() { | |||
| void TcpServer::Init() { | |||
| base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| CommUtil::CheckIp(server_address_); | |||
| @@ -101,19 +112,26 @@ void TcpServer::InitServer() { | |||
| } | |||
| void TcpServer::Start() { | |||
| std::unique_lock<std::recursive_mutex> l(connection_mutex_); | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Start tcp server!"; | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| int ret = event_base_dispatch(base_); | |||
| if (ret == 0) { | |||
| MS_LOG(INFO) << "Event base dispatch success!"; | |||
| } else if (ret == 1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; | |||
| } else if (ret == -1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | |||
| << "Event base dispatch failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base dispatch failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| void TcpServer::StartWithNoBlock() { | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| MS_LOG(INFO) << "Start tcp server with no block!"; | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| int ret = event_base_loop(base_, EVLOOP_NONBLOCK); | |||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base loop success!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base loop failed with no events pending or active!"; | |||
| MSLOG_IF(mindspore::ERROR, ret == -1, NoExceptionType) << "Event base loop failed with error occurred!"; | |||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | |||
| } | |||
| void TcpServer::Stop() { | |||
| @@ -150,6 +168,8 @@ void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *co | |||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| TcpConnection *connection = const_cast<TcpConnection *>(connections_.find(fd)->second); | |||
| delete connection; | |||
| connections_.erase(fd); | |||
| } | |||
| @@ -166,10 +186,10 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| return; | |||
| } | |||
| TcpConnection *conn = server->onCreateConnection(); | |||
| TcpConnection *conn = server->onCreateConnection(bev, fd); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| conn->InitConnection(fd, bev, server); | |||
| conn->InitConnection(); | |||
| server->AddConnection(fd, conn); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); | |||
| if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | |||
| @@ -177,17 +197,18 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st | |||
| } | |||
| } | |||
| TcpConnection *TcpServer::onCreateConnection() { | |||
| TcpConnection *TcpServer::onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd) { | |||
| TcpConnection *conn = nullptr; | |||
| if (client_accept_) | |||
| conn = const_cast<TcpConnection *>(client_accept_(this)); | |||
| else | |||
| conn = new TcpConnection(); | |||
| if (client_accept_) { | |||
| conn = const_cast<TcpConnection *>(client_accept_(*this)); | |||
| } else { | |||
| conn = new TcpConnection(bev, fd, this); | |||
| } | |||
| return conn; | |||
| } | |||
| OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; } | |||
| OnServerReceiveMessage TcpServer::GetServerReceive() const { return message_callback_; } | |||
| void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| @@ -207,9 +228,9 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||
| auto conn = static_cast<class TcpConnection *>(connection); | |||
| struct evbuffer *buf = bufferevent_get_input(bev); | |||
| char read_buffer[4096]; | |||
| auto read = 0; | |||
| while ((read = EVBUFFER_LENGTH(buf)) > 0) { | |||
| if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) { | |||
| while (EVBUFFER_LENGTH(buf) > 0) { | |||
| int read = evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)); | |||
| if (read == -1) { | |||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | |||
| } | |||
| conn->OnReadHandler(read_buffer, static_cast<size_t>(read)); | |||
| @@ -219,43 +240,47 @@ void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||
| void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| struct evbuffer *output = bufferevent_get_output(bev); | |||
| size_t remain = evbuffer_get_length(output); | |||
| auto conn = reinterpret_cast<TcpConnection *>(data); | |||
| TcpServer *srv = conn->GetServer(); | |||
| if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(INFO) << "Event buffer end of file!"; | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); | |||
| if (srv->client_disconnection_) { | |||
| srv->client_disconnection_(*srv, *conn); | |||
| } | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| bufferevent_free(bev); | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(ERROR) << "Event buffer remain data: " << remain; | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| bufferevent_free(bev); | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); | |||
| if (srv->client_disconnection_) { | |||
| srv->client_disconnection_(*srv, *conn); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unhandled event!"; | |||
| } | |||
| } | |||
| void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | |||
| void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto mc = const_cast<TcpConnection &>(conn); | |||
| mc.SendMessage(data, num); | |||
| } | |||
| void TcpServer::SendMessage(const void *data, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| void TcpServer::SendMessage(const CommMessage &message) { | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | |||
| SendMessage(*it->second, data, num); | |||
| SendMessage(*it->second, message); | |||
| } | |||
| } | |||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,8 @@ | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| @@ -38,46 +40,49 @@ namespace comm { | |||
| class TcpServer; | |||
| class TcpConnection { | |||
| public: | |||
| TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {} | |||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) | |||
| : buffer_event_(bev), fd_(0), server_(server) {} | |||
| virtual ~TcpConnection() = default; | |||
| virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server); | |||
| void SendMessage(const void *buffer, size_t num) const; | |||
| virtual void InitConnection(); | |||
| virtual void SendMessage(const void *buffer, size_t num) const; | |||
| void SendMessage(const CommMessage &message) const; | |||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | |||
| TcpServer *GetServer() const; | |||
| evutil_socket_t GetFd() const; | |||
| const evutil_socket_t &GetFd() const; | |||
| protected: | |||
| TcpMessageHandler tcp_message_handler_; | |||
| struct bufferevent *buffer_event_; | |||
| evutil_socket_t fd_; | |||
| TcpServer *server_; | |||
| const TcpServer *server_; | |||
| TcpMessageHandler tcp_message_handler_; | |||
| }; | |||
| using OnServerReceiveMessage = | |||
| std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>; | |||
| std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const CommMessage &)>; | |||
| class TcpServer { | |||
| public: | |||
| using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>; | |||
| using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>; | |||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>; | |||
| using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | |||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | |||
| explicit TcpServer(std::string address, std::uint16_t port); | |||
| explicit TcpServer(const std::string &address, std::uint16_t port); | |||
| virtual ~TcpServer(); | |||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | |||
| const OnAccepted &client_accept); | |||
| void InitServer(); | |||
| void Init(); | |||
| void Start(); | |||
| void StartWithNoBlock(); | |||
| void Stop(); | |||
| void SendToAllClients(const char *data, size_t len); | |||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | |||
| void RemoveConnection(const evutil_socket_t &fd); | |||
| void ReceiveMessage(const OnServerReceiveMessage &cb); | |||
| static void SendMessage(const TcpConnection &conn, const void *data, size_t num); | |||
| void SendMessage(const void *data, size_t num); | |||
| OnServerReceiveMessage GetServerReceiveMessage() const; | |||
| OnServerReceiveMessage GetServerReceive() const; | |||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | |||
| static void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||
| void SendMessage(const CommMessage &message); | |||
| protected: | |||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | |||
| @@ -85,9 +90,8 @@ class TcpServer { | |||
| static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); | |||
| static void ReadCallback(struct bufferevent *, void *connection); | |||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | |||
| virtual TcpConnection *onCreateConnection(); | |||
| virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | |||
| private: | |||
| struct event_base *base_; | |||
| struct event *signal_event_; | |||
| struct evconnlistener *listener_; | |||
| @@ -101,6 +105,7 @@ class TcpServer { | |||
| std::recursive_mutex connection_mutex_; | |||
| OnServerReceiveMessage message_callback_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -24,6 +24,7 @@ | |||
| #include <iostream> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -31,7 +32,9 @@ namespace comm { | |||
| class TestHttpServer : public UT::Common { | |||
| public: | |||
| TestHttpServer() = default; | |||
| TestHttpServer() : server_(nullptr) {} | |||
| virtual ~TestHttpServer() = default; | |||
| static void testGetHandler(std::shared_ptr<HttpMessageHandler> resp) { | |||
| std::string host = resp->GetRequestHost(); | |||
| @@ -57,7 +60,7 @@ class TestHttpServer : public UT::Common { | |||
| } | |||
| void SetUp() override { | |||
| server_ = new HttpServer("0.0.0.0", 9999); | |||
| server_ = std::make_unique<HttpServer>("0.0.0.0", 9999); | |||
| OnRequestReceive http_get_func = std::bind( | |||
| [](std::shared_ptr<HttpMessageHandler> resp) { | |||
| EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); | |||
| @@ -106,7 +109,7 @@ class TestHttpServer : public UT::Common { | |||
| } | |||
| private: | |||
| HttpServer *server_; | |||
| std::unique_ptr<HttpServer> server_; | |||
| }; | |||
| TEST_F(TestHttpServer, httpGetQequest) { | |||
| @@ -143,13 +146,13 @@ TEST_F(TestHttpServer, messageHandler) { | |||
| } | |||
| TEST_F(TestHttpServer, portErrorNoException) { | |||
| HttpServer *server_exception = new HttpServer("0.0.0.0", -1); | |||
| auto server_exception = std::make_unique<HttpServer>("0.0.0.0", -1); | |||
| OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); | |||
| } | |||
| TEST_F(TestHttpServer, addressException) { | |||
| HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); | |||
| auto server_exception = std::make_unique<HttpServer>("12344.0.0.0", 9998); | |||
| OnRequestReceive http_handler_func = std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | |||
| } | |||
| @@ -14,6 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #include "ps/comm/tcp_client.h" | |||
| @@ -26,19 +28,19 @@ class TestTcpClient : public UT::Common { | |||
| }; | |||
| TEST_F(TestTcpClient, InitClientIPError) { | |||
| auto client = new TcpClient("127.0.0.13543", 9000); | |||
| client->ReceiveMessage( | |||
| [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000); | |||
| client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); | |||
| ASSERT_THROW(client->InitTcpClient(), std::exception); | |||
| ASSERT_THROW(client->Init(), std::exception); | |||
| } | |||
| TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||
| auto client = new TcpClient("127.0.0.1", -1); | |||
| client->ReceiveMessage( | |||
| [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1); | |||
| client->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { client.SendMessage(message); }); | |||
| EXPECT_NO_THROW(client->InitTcpClient()); | |||
| EXPECT_NO_THROW(client->Init()); | |||
| } | |||
| } // namespace comm | |||
| @@ -18,6 +18,7 @@ | |||
| #include "ps/comm/tcp_server.h" | |||
| #include "common/common_test.h" | |||
| #include <memory> | |||
| #include <thread> | |||
| namespace mindspore { | |||
| @@ -25,16 +26,20 @@ namespace ps { | |||
| namespace comm { | |||
| class TestTcpServer : public UT::Common { | |||
| public: | |||
| TestTcpServer() = default; | |||
| TestTcpServer() : client_(nullptr), server_(nullptr) {} | |||
| virtual ~TestTcpServer() = default; | |||
| void SetUp() override { | |||
| server_ = new TcpServer("127.0.0.1", 9000); | |||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 9998); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | |||
| server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) { | |||
| EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); | |||
| server.SendMessage(conn, buffer, num); | |||
| server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| KVMessage kv_message; | |||
| kv_message.ParseFromString(message.data()); | |||
| EXPECT_EQ(2, kv_message.keys_size()); | |||
| server.SendMessage(conn, message); | |||
| }); | |||
| server_->InitServer(); | |||
| server_->Init(); | |||
| server_->Start(); | |||
| }); | |||
| http_server_thread_->detach(); | |||
| @@ -47,21 +52,32 @@ class TestTcpServer : public UT::Common { | |||
| server_->Stop(); | |||
| } | |||
| TcpClient *client_; | |||
| TcpServer *server_; | |||
| const std::string test_message_ = "TCP_MESSAGE"; | |||
| std::unique_ptr<TcpClient> client_; | |||
| std::unique_ptr<TcpServer> server_; | |||
| }; | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| client_ = new TcpClient("127.0.0.1", 9000); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", 9998); | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||
| client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) { | |||
| EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); | |||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | |||
| KVMessage kv_message; | |||
| kv_message.ParseFromString(message.data()); | |||
| EXPECT_EQ(2, kv_message.keys_size()); | |||
| }); | |||
| client_->InitTcpClient(); | |||
| client_->SendMessage(test_message_.c_str(), test_message_.size()); | |||
| client_->Init(); | |||
| CommMessage comm_message; | |||
| KVMessage kv_message; | |||
| std::vector<int> keys{1, 2}; | |||
| std::vector<int> values{3, 4}; | |||
| *kv_message.mutable_keys() = {keys.begin(), keys.end()}; | |||
| *kv_message.mutable_values() = {values.begin(), values.end()}; | |||
| comm_message.set_data(kv_message.SerializeAsString()); | |||
| client_->SendMessage(comm_message); | |||
| client_->Start(); | |||
| }); | |||
| http_client_thread->detach(); | |||