Merge pull request !7518 from anancds/tcp-servertags/v1.1.0
| @@ -7,6 +7,10 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc") | |||
| endif() | |||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/comm_util.h" | |||
| #include <arpa/inet.h> | |||
| #include <cstdio> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <regex> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| bool CommUtil::CheckIpWithRegex(const std::string &ip) { | |||
| std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | |||
| std::smatch res; | |||
| if (regex_match(ip, res, pattern)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void CommUtil::CheckIp(const std::string &ip) { | |||
| if (!CheckIpWithRegex(ip)) { | |||
| MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!"; | |||
| } | |||
| int64_t uAddr = inet_addr(ip.c_str()); | |||
| if (INADDR_NONE == uAddr) { | |||
| MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| #include <event2/buffer.h> | |||
| #include <event2/event.h> | |||
| #include <event2/http.h> | |||
| #include <event2/keyvalq_struct.h> | |||
| #include <event2/listener.h> | |||
| #include <event2/util.h> | |||
| #include <cstdio> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <functional> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| class CommUtil { | |||
| public: | |||
| static bool CheckIpWithRegex(const std::string &ip); | |||
| static void CheckIp(const std::string &ip); | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| @@ -38,7 +38,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| typedef std::map<std::string, std::list<std::string>> HttpHeaders; | |||
| using HttpHeaders = std::map<std::string, std::list<std::string>>; | |||
| class HttpMessageHandler { | |||
| public: | |||
| @@ -16,6 +16,7 @@ | |||
| #include "ps/comm/http_server.h" | |||
| #include "ps/comm/http_message_handler.h" | |||
| #include "ps/comm/comm_util.h" | |||
| #ifdef WIN32 | |||
| #include <WinSock2.h> | |||
| @@ -41,28 +42,10 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| HttpServer::~HttpServer() { | |||
| if (event_http_) { | |||
| evhttp_free(event_http_); | |||
| event_http_ = nullptr; | |||
| } | |||
| if (event_base_) { | |||
| event_base_free(event_base_); | |||
| event_base_ = nullptr; | |||
| } | |||
| } | |||
| HttpServer::~HttpServer() { Stop(); } | |||
| bool HttpServer::InitServer() { | |||
| if (!CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "Server address" << server_address_ << " illegal!"; | |||
| } | |||
| int64_t uAddr = inet_addr(server_address_.c_str()); | |||
| if (INADDR_NONE == uAddr) { | |||
| MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; | |||
| } | |||
| if (server_port_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "Server port:" << server_port_ << " illegal!"; | |||
| } | |||
| CommUtil::CheckIp(server_address_); | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| @@ -76,15 +59,6 @@ bool HttpServer::InitServer() { | |||
| return true; | |||
| } | |||
| bool HttpServer::CheckIp(const std::string &ip) { | |||
| std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | |||
| std::smatch res; | |||
| if (regex_match(ip, res, pattern)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void HttpServer::SetTimeOut(int seconds) { | |||
| MS_EXCEPTION_IF_NULL(event_http_); | |||
| if (seconds < 0) { | |||
| @@ -93,7 +67,7 @@ void HttpServer::SetTimeOut(int seconds) { | |||
| evhttp_set_timeout(event_http_, seconds); | |||
| } | |||
| void HttpServer::SetAllowedMethod(HttpMethodsSet methods) { | |||
| void HttpServer::SetAllowedMethod(u_int16_t methods) { | |||
| MS_EXCEPTION_IF_NULL(event_http_); | |||
| evhttp_set_allowed_methods(event_http_, methods); | |||
| } | |||
| @@ -114,12 +88,11 @@ void HttpServer::SetMaxBodySize(size_t num) { | |||
| evhttp_set_max_body_size(event_http_, num); | |||
| } | |||
| bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) { | |||
| bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *function) { | |||
| if ((!is_init_) && (!InitServer())) { | |||
| MS_LOG(EXCEPTION) << "Init http server failed!"; | |||
| } | |||
| HandlerFunc func = function; | |||
| if (!func) { | |||
| if (!function) { | |||
| return false; | |||
| } | |||
| @@ -128,15 +101,13 @@ bool HttpServer::RegisterRoute(const std::string &url, handle_t *function) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| HttpMessageHandler httpReq(req); | |||
| httpReq.InitHttpMessage(); | |||
| handle_t *f = reinterpret_cast<handle_t *>(arg); | |||
| f(&httpReq); | |||
| OnRequestReceive *func = reinterpret_cast<OnRequestReceive *>(arg); | |||
| (*func)(&httpReq); | |||
| }; | |||
| handle_t **pph = func.target<handle_t *>(); | |||
| MS_EXCEPTION_IF_NULL(pph); | |||
| MS_EXCEPTION_IF_NULL(event_http_); | |||
| // O SUCCESS,-1 ALREADY_EXIST,-2 FAILURE | |||
| int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(*pph)); | |||
| int ret = evhttp_set_cb(event_http_, url.c_str(), TransFunc, reinterpret_cast<void *>(function)); | |||
| if (ret == 0) { | |||
| MS_LOG(INFO) << "Ev http register handle of:" << url.c_str() << " success."; | |||
| } else if (ret == -1) { | |||
| @@ -48,26 +48,21 @@ typedef enum eHttpMethod { | |||
| HM_PATCH = 1 << 8 | |||
| } HttpMethod; | |||
| typedef u_int16_t HttpMethodsSet; | |||
| typedef void(handle_t)(HttpMessageHandler *); | |||
| class HttpServer { | |||
| public: | |||
| // Server address only support IPV4 now, and should be in format of "x.x.x.x" | |||
| explicit HttpServer(const std::string &address, std::int16_t port) | |||
| explicit HttpServer(const std::string &address, std::uint16_t port) | |||
| : server_address_(address), server_port_(port), event_base_(nullptr), event_http_(nullptr), is_init_(false) {} | |||
| ~HttpServer(); | |||
| typedef std::function<handle_t> HandlerFunc; | |||
| using OnRequestReceive = std::function<void(HttpMessageHandler *)>; | |||
| bool InitServer(); | |||
| static bool CheckIp(const std::string &ip); | |||
| void SetTimeOut(int seconds = 5); | |||
| // Default allowed methods: GET, POST, HEAD, PUT, DELETE | |||
| void SetAllowedMethod(HttpMethodsSet methods); | |||
| void SetAllowedMethod(u_int16_t methods); | |||
| // Default to ((((unsigned long long)0xffffffffUL) << 32) | 0xffffffffUL) | |||
| void SetMaxHeaderSize(std::size_t num); | |||
| @@ -76,7 +71,7 @@ class HttpServer { | |||
| void SetMaxBodySize(std::size_t num); | |||
| // Return: true if success, false if failed, check log to find failure reason | |||
| bool RegisterRoute(const std::string &url, handle_t *func); | |||
| bool RegisterRoute(const std::string &url, OnRequestReceive *func); | |||
| bool UnRegisterRoute(const std::string &url); | |||
| bool Start(); | |||
| @@ -84,7 +79,7 @@ class HttpServer { | |||
| private: | |||
| std::string server_address_; | |||
| std::int16_t server_port_; | |||
| std::uint16_t server_port_; | |||
| struct event_base *event_base_; | |||
| struct evhttp *event_http_; | |||
| bool is_init_; | |||
| @@ -0,0 +1,220 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_client.h" | |||
| #include <arpa/inet.h> | |||
| #include <event2/buffer.h> | |||
| #include <event2/bufferevent.h> | |||
| #include <event2/buffer_compat.h> | |||
| #include <event2/event.h> | |||
| #include <netinet/in.h> | |||
| #include <netinet/tcp.h> | |||
| #include <sys/socket.h> | |||
| #include <cstdlib> | |||
| #include <cstring> | |||
| #include <iostream> | |||
| #include <utility> | |||
| #include <string> | |||
| #include "ps/comm/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| TcpClient::TcpClient(std::string address, std::uint16_t port) | |||
| : event_base_(nullptr), | |||
| event_timeout_(nullptr), | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port) { | |||
| message_handler_.SetCallback([this](const void *buf, size_t num) { | |||
| if (buf == nullptr) { | |||
| if (disconnected_callback_) disconnected_callback_(*this, 200); | |||
| Stop(); | |||
| } | |||
| if (message_callback_) message_callback_(*this, buf, num); | |||
| }); | |||
| } | |||
| TcpClient::~TcpClient() { Stop(); } | |||
| std::string TcpClient::GetServerAddress() const { return server_address_; } | |||
| void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, | |||
| const OnTimeout &timeout) { | |||
| connected_callback_ = conn; | |||
| disconnected_callback_ = disconn; | |||
| read_callback_ = read; | |||
| timeout_callback_ = timeout; | |||
| } | |||
| void TcpClient::InitTcpClient() { | |||
| if (buffer_event_) { | |||
| return; | |||
| } | |||
| CommUtil::CheckIp(server_address_); | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| sockaddr_in sin{}; | |||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; | |||
| } | |||
| sin.sin_family = AF_INET; | |||
| sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); | |||
| sin.sin_port = htons(server_port_); | |||
| buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this); | |||
| if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) { | |||
| MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; | |||
| } | |||
| int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin)); | |||
| if (result_code < 0) { | |||
| MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!"; | |||
| } | |||
| } | |||
| void TcpClient::StartWithDelay(int seconds) { | |||
| if (buffer_event_) { | |||
| return; | |||
| } | |||
| event_base_ = event_base_new(); | |||
| timeval timeout_value{}; | |||
| timeout_value.tv_sec = seconds; | |||
| timeout_value.tv_usec = 0; | |||
| event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this); | |||
| if (evtimer_add(event_timeout_, &timeout_value) == -1) { | |||
| MS_LOG(EXCEPTION) << "event timeout failed!"; | |||
| } | |||
| } | |||
| void TcpClient::Stop() { | |||
| if (buffer_event_) { | |||
| bufferevent_free(buffer_event_); | |||
| buffer_event_ = nullptr; | |||
| } | |||
| if (event_timeout_) { | |||
| event_free(event_timeout_); | |||
| event_timeout_ = nullptr; | |||
| } | |||
| if (event_base_) { | |||
| event_base_free(event_base_); | |||
| event_base_ = nullptr; | |||
| } | |||
| } | |||
| void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { | |||
| const int one = 1; | |||
| int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int)); | |||
| if (ret < 0) { | |||
| MS_LOG(EXCEPTION) << "Set socket no delay failed!"; | |||
| } | |||
| } | |||
| void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| tcp_client->InitTcpClient(); | |||
| } | |||
| void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ctx); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ctx); | |||
| struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev)); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| char read_buffer[4096]; | |||
| int read = 0; | |||
| while ((read = EVBUFFER_LENGTH(input)) > 0) { | |||
| if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | |||
| } | |||
| tcp_client->OnReadHandler(read_buffer, read); | |||
| } | |||
| } | |||
| void TcpClient::OnReadHandler(const void *buf, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buf); | |||
| if (read_callback_) { | |||
| read_callback_(*this, buf, num); | |||
| } | |||
| message_handler_.ReceiveMessage(buf, num); | |||
| } | |||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ptr); | |||
| if (events & BEV_EVENT_CONNECTED) { | |||
| // Connected | |||
| if (tcp_client->connected_callback_) { | |||
| tcp_client->connected_callback_(*tcp_client); | |||
| } | |||
| evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev)); | |||
| SetTcpNoDelay(fd); | |||
| MS_LOG(INFO) << "Client connected!"; | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(ERROR) << "Client connected error!"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(*tcp_client, errno); | |||
| } | |||
| } else if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(ERROR) << "Client connected end of file"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(*tcp_client, 0); | |||
| } | |||
| } | |||
| } | |||
| void TcpClient::Start() { | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| int ret = event_base_dispatch(event_base_); | |||
| if (ret == 0) { | |||
| MS_LOG(INFO) << "Event base dispatch success!"; | |||
| } else if (ret == 1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; | |||
| } else if (ret == -1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| } | |||
| void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; } | |||
| void TcpClient::SendMessage(const void *buf, size_t num) const { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) { | |||
| MS_LOG(EXCEPTION) << "event buffer add failed!"; | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include <event2/event.h> | |||
| #include <event2/bufferevent.h> | |||
| #include <functional> | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| class TcpClient { | |||
| public: | |||
| using OnMessage = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnConnected = std::function<void(const TcpClient &)>; | |||
| using OnDisconnected = std::function<void(const TcpClient &, int)>; | |||
| using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | |||
| using OnTimeout = std::function<void(const TcpClient &)>; | |||
| explicit TcpClient(std::string address, std::uint16_t port); | |||
| virtual ~TcpClient(); | |||
| std::string GetServerAddress() const; | |||
| void SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read, | |||
| const OnTimeout &timeout); | |||
| void InitTcpClient(); | |||
| void StartWithDelay(int seconds); | |||
| void Stop(); | |||
| void ReceiveMessage(const OnMessage &cb); | |||
| void SendMessage(const void *buf, size_t num) const; | |||
| void Start(); | |||
| protected: | |||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | |||
| static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg); | |||
| static void ReadCallback(struct bufferevent *bev, void *ctx); | |||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | |||
| virtual void OnReadHandler(const void *buf, size_t num); | |||
| private: | |||
| TcpMessageHandler message_handler_; | |||
| OnMessage message_callback_; | |||
| OnConnected connected_callback_; | |||
| OnDisconnected disconnected_callback_; | |||
| OnRead read_callback_; | |||
| OnTimeout timeout_callback_; | |||
| event_base *event_base_; | |||
| event *event_timeout_; | |||
| bufferevent *buffer_event_; | |||
| std::string server_address_; | |||
| std::uint16_t server_port_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include <iostream> | |||
| #include <utility> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| void TcpMessageHandler::SetCallback(messageReceive message_receive) { message_callback_ = std::move(message_receive); } | |||
| void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| if (message_callback_) { | |||
| message_callback_(buffer, num); | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| using messageReceive = std::function<void(const void *buffer, size_t len)>; | |||
| class TcpMessageHandler { | |||
| public: | |||
| TcpMessageHandler() = default; | |||
| virtual ~TcpMessageHandler() = default; | |||
| void SetCallback(messageReceive cb); | |||
| void ReceiveMessage(const void *buffer, size_t num); | |||
| private: | |||
| messageReceive message_callback_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| @@ -0,0 +1,259 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_server.h" | |||
| #include <arpa/inet.h> | |||
| #include <event2/buffer.h> | |||
| #include <event2/bufferevent.h> | |||
| #include <event2/event.h> | |||
| #include <event2/listener.h> | |||
| #include <event2/buffer_compat.h> | |||
| #include <event2/util.h> | |||
| #include <sys/socket.h> | |||
| #include <csignal> | |||
| #include <utility> | |||
| #include "ps/comm/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| buffer_event_ = const_cast<struct bufferevent *>(bev); | |||
| fd_ = fd; | |||
| server_ = const_cast<TcpServer *>(server); | |||
| tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) { | |||
| OnServerReceiveMessage message_callback = server->GetServerReceiveMessage(); | |||
| if (message_callback) message_callback(*server, *this, buf, num); | |||
| }); | |||
| } | |||
| void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); } | |||
| void TcpConnection::SendMessage(const void *buffer, size_t num) const { | |||
| if (bufferevent_write(buffer_event_, buffer, num) == -1) { | |||
| MS_LOG(ERROR) << "Write message to buffer event failed!"; | |||
| } | |||
| } | |||
| TcpServer *TcpConnection::GetServer() const { return server_; } | |||
| evutil_socket_t TcpConnection::GetFd() const { return fd_; } | |||
| TcpServer::TcpServer(std::string address, std::uint16_t port) | |||
| : base_(nullptr), | |||
| signal_event_(nullptr), | |||
| listener_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port) {} | |||
| TcpServer::~TcpServer() { Stop(); } | |||
| void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | |||
| const OnAccepted &client_accept) { | |||
| this->client_connection_ = client_conn; | |||
| this->client_disconnection_ = client_disconn; | |||
| this->client_accept_ = client_accept; | |||
| } | |||
| void TcpServer::InitServer() { | |||
| base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| CommUtil::CheckIp(server_address_); | |||
| struct sockaddr_in sin {}; | |||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; | |||
| } | |||
| sin.sin_family = AF_INET; | |||
| sin.sin_port = htons(server_port_); | |||
| sin.sin_addr.s_addr = inet_addr(server_address_.c_str()); | |||
| listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this), | |||
| LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1, | |||
| reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin)); | |||
| MS_EXCEPTION_IF_NULL(listener_); | |||
| signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this)); | |||
| MS_EXCEPTION_IF_NULL(signal_event_); | |||
| if (event_add(signal_event_, nullptr) < 0) { | |||
| MS_LOG(EXCEPTION) << "Cannot create signal event."; | |||
| } | |||
| } | |||
| void TcpServer::Start() { | |||
| std::unique_lock<std::recursive_mutex> l(connection_mutex_); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| int ret = event_base_dispatch(base_); | |||
| if (ret == 0) { | |||
| MS_LOG(INFO) << "Event base dispatch success!"; | |||
| } else if (ret == 1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!"; | |||
| } else if (ret == -1) { | |||
| MS_LOG(ERROR) << "Event base dispatch failed with error occurred!"; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!"; | |||
| } | |||
| } | |||
| void TcpServer::Stop() { | |||
| if (signal_event_ != nullptr) { | |||
| event_free(signal_event_); | |||
| signal_event_ = nullptr; | |||
| } | |||
| if (listener_ != nullptr) { | |||
| evconnlistener_free(listener_); | |||
| listener_ = nullptr; | |||
| } | |||
| if (base_ != nullptr) { | |||
| event_base_free(base_); | |||
| base_ = nullptr; | |||
| } | |||
| } | |||
| void TcpServer::SendToAllClients(const char *data, size_t len) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | |||
| it->second->SendMessage(data, len); | |||
| } | |||
| } | |||
| void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) { | |||
| MS_EXCEPTION_IF_NULL(connection); | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| connections_.insert(std::make_pair(fd, connection)); | |||
| } | |||
| void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| connections_.erase(fd); | |||
| } | |||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| auto base = reinterpret_cast<struct event_base *>(server->base_); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | |||
| if (!bev) { | |||
| MS_LOG(ERROR) << "Error constructing buffer event!"; | |||
| event_base_loopbreak(base); | |||
| return; | |||
| } | |||
| TcpConnection *conn = server->onCreateConnection(); | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| conn->InitConnection(fd, bev, server); | |||
| server->AddConnection(fd, conn); | |||
| bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn)); | |||
| if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { | |||
| MS_LOG(EXCEPTION) << "buffer event enable read and write failed!"; | |||
| } | |||
| } | |||
| TcpConnection *TcpServer::onCreateConnection() { | |||
| TcpConnection *conn = nullptr; | |||
| if (client_accept_) | |||
| conn = const_cast<TcpConnection *>(client_accept_(this)); | |||
| else | |||
| conn = new TcpConnection(); | |||
| return conn; | |||
| } | |||
| OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; } | |||
| void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| struct event_base *base = server->base_; | |||
| struct timeval delay = {0, 0}; | |||
| MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds."; | |||
| if (event_base_loopexit(base, &delay) == -1) { | |||
| MS_LOG(EXCEPTION) << "event base loop exit failed."; | |||
| } | |||
| } | |||
| void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(connection); | |||
| auto conn = static_cast<class TcpConnection *>(connection); | |||
| struct evbuffer *buf = bufferevent_get_input(bev); | |||
| char read_buffer[4096]; | |||
| auto read = 0; | |||
| while ((read = EVBUFFER_LENGTH(buf)) > 0) { | |||
| if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) { | |||
| MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!"; | |||
| } | |||
| conn->OnReadHandler(read_buffer, static_cast<size_t>(read)); | |||
| } | |||
| } | |||
| void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto conn = reinterpret_cast<TcpConnection *>(data); | |||
| TcpServer *srv = conn->GetServer(); | |||
| if (events & BEV_EVENT_EOF) { | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| bufferevent_free(bev); | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| // Free connection structures | |||
| srv->RemoveConnection(conn->GetFd()); | |||
| bufferevent_free(bev); | |||
| // Notify about disconnection | |||
| if (srv->client_disconnection_) srv->client_disconnection_(srv, conn); | |||
| } else { | |||
| MS_LOG(ERROR) << "unhandled event!"; | |||
| } | |||
| } | |||
| void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| auto mc = const_cast<TcpConnection &>(conn); | |||
| mc.SendMessage(data, num); | |||
| } | |||
| void TcpServer::SendMessage(const void *data, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | |||
| for (auto it = connections_.begin(); it != connections_.end(); ++it) { | |||
| SendMessage(*it->second, data, num); | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| #include <event2/buffer.h> | |||
| #include <event2/bufferevent.h> | |||
| #include <event2/event.h> | |||
| #include <event2/listener.h> | |||
| #include <exception> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| class TcpServer; | |||
| class TcpConnection { | |||
| public: | |||
| TcpConnection() : buffer_event_(nullptr), fd_(0), server_(nullptr) {} | |||
| virtual ~TcpConnection() = default; | |||
| virtual void InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server); | |||
| void SendMessage(const void *buffer, size_t num) const; | |||
| virtual void OnReadHandler(const void *buffer, size_t numBytes); | |||
| TcpServer *GetServer() const; | |||
| evutil_socket_t GetFd() const; | |||
| protected: | |||
| TcpMessageHandler tcp_message_handler_; | |||
| struct bufferevent *buffer_event_; | |||
| evutil_socket_t fd_; | |||
| TcpServer *server_; | |||
| }; | |||
| using OnServerReceiveMessage = | |||
| std::function<void(const TcpServer &tcp_server, const TcpConnection &conn, const void *buffer, size_t num)>; | |||
| class TcpServer { | |||
| public: | |||
| using OnConnected = std::function<void(const TcpServer *, const TcpConnection *)>; | |||
| using OnDisconnected = std::function<void(const TcpServer *, const TcpConnection *)>; | |||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer *)>; | |||
| explicit TcpServer(std::string address, std::uint16_t port); | |||
| virtual ~TcpServer(); | |||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | |||
| const OnAccepted &client_accept); | |||
| void InitServer(); | |||
| void Start(); | |||
| void Stop(); | |||
| void SendToAllClients(const char *data, size_t len); | |||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | |||
| void RemoveConnection(const evutil_socket_t &fd); | |||
| void ReceiveMessage(const OnServerReceiveMessage &cb); | |||
| static void SendMessage(const TcpConnection &conn, const void *data, size_t num); | |||
| void SendMessage(const void *data, size_t num); | |||
| OnServerReceiveMessage GetServerReceiveMessage() const; | |||
| protected: | |||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | |||
| int socklen, void *server); | |||
| static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); | |||
| static void ReadCallback(struct bufferevent *, void *connection); | |||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | |||
| virtual TcpConnection *onCreateConnection(); | |||
| private: | |||
| struct event_base *base_; | |||
| struct event *signal_event_; | |||
| struct evconnlistener *listener_; | |||
| std::string server_address_; | |||
| std::uint16_t server_port_; | |||
| std::map<evutil_socket_t, const TcpConnection *> connections_; | |||
| OnConnected client_connection_; | |||
| OnDisconnected client_disconnection_; | |||
| OnAccepted client_accept_; | |||
| std::recursive_mutex connection_mutex_; | |||
| OnServerReceiveMessage message_callback_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| @@ -31,7 +31,7 @@ namespace comm { | |||
| class TestHttpServer : public UT::Common { | |||
| public: | |||
| TestHttpServer() {} | |||
| TestHttpServer() = default; | |||
| static void testGetHandler(HttpMessageHandler *resp) { | |||
| std::string host = resp->GetRequestHost(); | |||
| @@ -58,16 +58,44 @@ class TestHttpServer : public UT::Common { | |||
| void SetUp() override { | |||
| server_ = new HttpServer("0.0.0.0", 9999); | |||
| server_->RegisterRoute("/httpget", [](HttpMessageHandler *resp) { | |||
| EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); | |||
| EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); | |||
| EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); | |||
| EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); | |||
| resp->QuickResponse(200, "get request success!\n"); | |||
| }); | |||
| server_->RegisterRoute("/handler", TestHttpServer::testGetHandler); | |||
| std::function<void(HttpMessageHandler *)> http_get_func = std::bind( | |||
| [](HttpMessageHandler *resp) { | |||
| EXPECT_STREQ(resp->GetPathParam("key1").c_str(), "value1"); | |||
| EXPECT_STREQ(resp->GetUriQuery().c_str(), "key1=value1"); | |||
| EXPECT_STREQ(resp->GetRequestUri().c_str(), "/httpget?key1=value1"); | |||
| EXPECT_STREQ(resp->GetUriPath().c_str(), "/httpget"); | |||
| resp->QuickResponse(200, "get request success!\n"); | |||
| }, | |||
| std::placeholders::_1); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = std::bind( | |||
| [](HttpMessageHandler *resp) { | |||
| std::string host = resp->GetRequestHost(); | |||
| EXPECT_STREQ(host.c_str(), "127.0.0.1"); | |||
| std::string path_param = resp->GetPathParam("key1"); | |||
| std::string header_param = resp->GetHeadParam("headerKey"); | |||
| std::string post_param = resp->GetPostParam("postKey"); | |||
| std::string post_message = resp->GetPostMsg(); | |||
| EXPECT_STREQ(path_param.c_str(), "value1"); | |||
| EXPECT_STREQ(header_param.c_str(), "headerValue"); | |||
| EXPECT_STREQ(post_param.c_str(), "postValue"); | |||
| EXPECT_STREQ(post_message.c_str(), "postKey=postValue"); | |||
| const std::string rKey("headKey"); | |||
| const std::string rVal("headValue"); | |||
| const std::string rBody("post request success!\n"); | |||
| resp->AddRespHeadParam(rKey, rVal); | |||
| resp->AddRespString(rBody); | |||
| resp->SetRespCode(200); | |||
| resp->SendResponse(); | |||
| }, | |||
| std::placeholders::_1); | |||
| server_->RegisterRoute("/httpget", &http_get_func); | |||
| server_->RegisterRoute("/handler", &http_handler_func); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_.reset(new std::thread([&]() { server_->Start(); })); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { server_->Start(); }); | |||
| http_server_thread_->detach(); | |||
| } | |||
| @@ -110,14 +138,18 @@ TEST_F(TestHttpServer, messageHandler) { | |||
| pclose(file); | |||
| } | |||
| TEST_F(TestHttpServer, portException) { | |||
| TEST_F(TestHttpServer, portErrorNoException) { | |||
| HttpServer *server_exception = new HttpServer("0.0.0.0", -1); | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = | |||
| std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| EXPECT_NO_THROW(server_exception->RegisterRoute("/handler", &http_handler_func)); | |||
| } | |||
| TEST_F(TestHttpServer, addressException) { | |||
| HttpServer *server_exception = new HttpServer("12344.0.0.0", 9998); | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", TestHttpServer::testGetHandler), std::exception); | |||
| std::function<void(HttpMessageHandler *)> http_handler_func = | |||
| std::bind(TestHttpServer::testGetHandler, std::placeholders::_1); | |||
| ASSERT_THROW(server_exception->RegisterRoute("/handler", &http_handler_func), std::exception); | |||
| } | |||
| } // namespace comm | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ps/comm/tcp_client.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| class TestTcpClient : public UT::Common { | |||
| public: | |||
| TestTcpClient() = default; | |||
| }; | |||
| TEST_F(TestTcpClient, InitClientIPError) { | |||
| auto client = new TcpClient("127.0.0.13543", 9000); | |||
| client->ReceiveMessage( | |||
| [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); | |||
| ASSERT_THROW(client->InitTcpClient(), std::exception); | |||
| } | |||
| TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||
| auto client = new TcpClient("127.0.0.1", -1); | |||
| client->ReceiveMessage( | |||
| [](const TcpClient &client, const void *buffer, size_t num) { client.SendMessage(buffer, num); }); | |||
| EXPECT_NO_THROW(client->InitTcpClient()); | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_client.h" | |||
| #include "ps/comm/tcp_server.h" | |||
| #include "common/common_test.h" | |||
| #include <thread> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| class TestTcpServer : public UT::Common { | |||
| public: | |||
| TestTcpServer() = default; | |||
| void SetUp() override { | |||
| server_ = new TcpServer("127.0.0.1", 9000); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | |||
| server_->ReceiveMessage([](const TcpServer &server, const TcpConnection &conn, const void *buffer, size_t num) { | |||
| EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); | |||
| server.SendMessage(conn, buffer, num); | |||
| }); | |||
| server_->InitServer(); | |||
| server_->Start(); | |||
| }); | |||
| http_server_thread_->detach(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||
| } | |||
| void TearDown() override { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||
| client_->Stop(); | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(2000)); | |||
| server_->Stop(); | |||
| } | |||
| TcpClient *client_; | |||
| TcpServer *server_; | |||
| const std::string test_message_ = "TCP_MESSAGE"; | |||
| }; | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| client_ = new TcpClient("127.0.0.1", 9000); | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||
| client_->ReceiveMessage([](const TcpClient &client, const void *buffer, size_t num) { | |||
| EXPECT_STREQ(std::string(reinterpret_cast<const char *>(buffer), num).c_str(), "TCP_MESSAGE"); | |||
| }); | |||
| client_->InitTcpClient(); | |||
| client_->SendMessage(test_message_.c_str(), test_message_.size()); | |||
| client_->Start(); | |||
| }); | |||
| http_client_thread->detach(); | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||