You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tcp_server.cc 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/comm/tcp_server.h"
  17. #include <arpa/inet.h>
  18. #include <event2/buffer.h>
  19. #include <event2/bufferevent.h>
  20. #include <event2/event.h>
  21. #include <event2/listener.h>
  22. #include <event2/buffer_compat.h>
  23. #include <event2/util.h>
  24. #include <sys/socket.h>
  25. #include <csignal>
  26. #include <utility>
  27. #include "ps/comm/comm_util.h"
  28. namespace mindspore {
  29. namespace ps {
  30. namespace comm {
  31. void TcpConnection::InitConnection(const evutil_socket_t &fd, const struct bufferevent *bev, const TcpServer *server) {
  32. MS_EXCEPTION_IF_NULL(bev);
  33. MS_EXCEPTION_IF_NULL(server);
  34. buffer_event_ = const_cast<struct bufferevent *>(bev);
  35. fd_ = fd;
  36. server_ = const_cast<TcpServer *>(server);
  37. tcp_message_handler_.SetCallback([this, server](const void *buf, size_t num) {
  38. OnServerReceiveMessage message_callback = server->GetServerReceiveMessage();
  39. if (message_callback) message_callback(*server, *this, buf, num);
  40. });
  41. }
  42. void TcpConnection::OnReadHandler(const void *buffer, size_t num) { tcp_message_handler_.ReceiveMessage(buffer, num); }
  43. void TcpConnection::SendMessage(const void *buffer, size_t num) const {
  44. if (bufferevent_write(buffer_event_, buffer, num) == -1) {
  45. MS_LOG(ERROR) << "Write message to buffer event failed!";
  46. }
  47. }
  48. TcpServer *TcpConnection::GetServer() const { return server_; }
  49. evutil_socket_t TcpConnection::GetFd() const { return fd_; }
  50. TcpServer::TcpServer(std::string address, std::uint16_t port)
  51. : base_(nullptr),
  52. signal_event_(nullptr),
  53. listener_(nullptr),
  54. server_address_(std::move(address)),
  55. server_port_(port) {}
  56. TcpServer::~TcpServer() { Stop(); }
  57. void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
  58. const OnAccepted &client_accept) {
  59. this->client_connection_ = client_conn;
  60. this->client_disconnection_ = client_disconn;
  61. this->client_accept_ = client_accept;
  62. }
  63. void TcpServer::InitServer() {
  64. base_ = event_base_new();
  65. MS_EXCEPTION_IF_NULL(base_);
  66. CommUtil::CheckIp(server_address_);
  67. struct sockaddr_in sin {};
  68. if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
  69. MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
  70. }
  71. sin.sin_family = AF_INET;
  72. sin.sin_port = htons(server_port_);
  73. sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
  74. listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
  75. LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
  76. reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
  77. MS_EXCEPTION_IF_NULL(listener_);
  78. signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this));
  79. MS_EXCEPTION_IF_NULL(signal_event_);
  80. if (event_add(signal_event_, nullptr) < 0) {
  81. MS_LOG(EXCEPTION) << "Cannot create signal event.";
  82. }
  83. }
  84. void TcpServer::Start() {
  85. std::unique_lock<std::recursive_mutex> l(connection_mutex_);
  86. MS_EXCEPTION_IF_NULL(base_);
  87. int ret = event_base_dispatch(base_);
  88. if (ret == 0) {
  89. MS_LOG(INFO) << "Event base dispatch success!";
  90. } else if (ret == 1) {
  91. MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
  92. } else if (ret == -1) {
  93. MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
  94. } else {
  95. MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
  96. }
  97. }
  98. void TcpServer::Stop() {
  99. if (signal_event_ != nullptr) {
  100. event_free(signal_event_);
  101. signal_event_ = nullptr;
  102. }
  103. if (listener_ != nullptr) {
  104. evconnlistener_free(listener_);
  105. listener_ = nullptr;
  106. }
  107. if (base_ != nullptr) {
  108. event_base_free(base_);
  109. base_ = nullptr;
  110. }
  111. }
  112. void TcpServer::SendToAllClients(const char *data, size_t len) {
  113. MS_EXCEPTION_IF_NULL(data);
  114. std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  115. for (auto it = connections_.begin(); it != connections_.end(); ++it) {
  116. it->second->SendMessage(data, len);
  117. }
  118. }
  119. void TcpServer::AddConnection(const evutil_socket_t &fd, const TcpConnection *connection) {
  120. MS_EXCEPTION_IF_NULL(connection);
  121. std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  122. connections_.insert(std::make_pair(fd, connection));
  123. }
  124. void TcpServer::RemoveConnection(const evutil_socket_t &fd) {
  125. std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  126. connections_.erase(fd);
  127. }
  128. void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) {
  129. auto server = reinterpret_cast<class TcpServer *>(data);
  130. auto base = reinterpret_cast<struct event_base *>(server->base_);
  131. MS_EXCEPTION_IF_NULL(server);
  132. MS_EXCEPTION_IF_NULL(base);
  133. struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
  134. if (!bev) {
  135. MS_LOG(ERROR) << "Error constructing buffer event!";
  136. event_base_loopbreak(base);
  137. return;
  138. }
  139. TcpConnection *conn = server->onCreateConnection();
  140. MS_EXCEPTION_IF_NULL(conn);
  141. conn->InitConnection(fd, bev, server);
  142. server->AddConnection(fd, conn);
  143. bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast<void *>(conn));
  144. if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
  145. MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
  146. }
  147. }
  148. TcpConnection *TcpServer::onCreateConnection() {
  149. TcpConnection *conn = nullptr;
  150. if (client_accept_)
  151. conn = const_cast<TcpConnection *>(client_accept_(this));
  152. else
  153. conn = new TcpConnection();
  154. return conn;
  155. }
  156. OnServerReceiveMessage TcpServer::GetServerReceiveMessage() const { return message_callback_; }
  157. void TcpServer::SignalCallback(evutil_socket_t, std::int16_t, void *data) {
  158. auto server = reinterpret_cast<class TcpServer *>(data);
  159. MS_EXCEPTION_IF_NULL(server);
  160. struct event_base *base = server->base_;
  161. struct timeval delay = {0, 0};
  162. MS_LOG(ERROR) << "Caught an interrupt signal; exiting cleanly in 0 seconds.";
  163. if (event_base_loopexit(base, &delay) == -1) {
  164. MS_LOG(EXCEPTION) << "event base loop exit failed.";
  165. }
  166. }
  167. void TcpServer::ReadCallback(struct bufferevent *bev, void *connection) {
  168. MS_EXCEPTION_IF_NULL(bev);
  169. MS_EXCEPTION_IF_NULL(connection);
  170. auto conn = static_cast<class TcpConnection *>(connection);
  171. struct evbuffer *buf = bufferevent_get_input(bev);
  172. char read_buffer[4096];
  173. auto read = 0;
  174. while ((read = EVBUFFER_LENGTH(buf)) > 0) {
  175. if (evbuffer_remove(buf, &read_buffer, sizeof(read_buffer)) == -1) {
  176. MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
  177. }
  178. conn->OnReadHandler(read_buffer, static_cast<size_t>(read));
  179. }
  180. }
  181. void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void *data) {
  182. MS_EXCEPTION_IF_NULL(bev);
  183. MS_EXCEPTION_IF_NULL(data);
  184. auto conn = reinterpret_cast<TcpConnection *>(data);
  185. TcpServer *srv = conn->GetServer();
  186. if (events & BEV_EVENT_EOF) {
  187. // Notify about disconnection
  188. if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
  189. // Free connection structures
  190. srv->RemoveConnection(conn->GetFd());
  191. bufferevent_free(bev);
  192. } else if (events & BEV_EVENT_ERROR) {
  193. // Free connection structures
  194. srv->RemoveConnection(conn->GetFd());
  195. bufferevent_free(bev);
  196. // Notify about disconnection
  197. if (srv->client_disconnection_) srv->client_disconnection_(srv, conn);
  198. } else {
  199. MS_LOG(ERROR) << "unhandled event!";
  200. }
  201. }
  202. void TcpServer::ReceiveMessage(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
  203. void TcpServer::SendMessage(const TcpConnection &conn, const void *data, size_t num) {
  204. MS_EXCEPTION_IF_NULL(data);
  205. auto mc = const_cast<TcpConnection &>(conn);
  206. mc.SendMessage(data, num);
  207. }
  208. void TcpServer::SendMessage(const void *data, size_t num) {
  209. MS_EXCEPTION_IF_NULL(data);
  210. std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
  211. for (auto it = connections_.begin(); it != connections_.end(); ++it) {
  212. SendMessage(*it->second, data, num);
  213. }
  214. }
  215. } // namespace comm
  216. } // namespace ps
  217. } // namespace mindspore