You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tcp_client.cc 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/comm/tcp_client.h"
  17. #include <arpa/inet.h>
  18. #include <event2/buffer.h>
  19. #include <event2/bufferevent.h>
  20. #include <event2/buffer_compat.h>
  21. #include <event2/event.h>
  22. #include <netinet/in.h>
  23. #include <netinet/tcp.h>
  24. #include <sys/socket.h>
  25. #include <cstdlib>
  26. #include <cstring>
  27. #include <iostream>
  28. #include <utility>
  29. #include <string>
  30. #include "ps/comm/comm_util.h"
  31. namespace mindspore {
  32. namespace ps {
  33. namespace comm {
  34. TcpClient::TcpClient(std::string address, std::uint16_t port)
  35. : event_base_(nullptr),
  36. event_timeout_(nullptr),
  37. buffer_event_(nullptr),
  38. server_address_(std::move(address)),
  39. server_port_(port) {
  40. message_handler_.SetCallback([this](const void *buf, size_t num) {
  41. if (buf == nullptr) {
  42. if (disconnected_callback_) disconnected_callback_(*this, 200);
  43. Stop();
  44. }
  45. if (message_callback_) message_callback_(*this, buf, num);
  46. });
  47. }
  48. TcpClient::~TcpClient() { Stop(); }
  49. std::string TcpClient::GetServerAddress() const { return server_address_; }
  50. void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disconn, const OnRead &read,
  51. const OnTimeout &timeout) {
  52. connected_callback_ = conn;
  53. disconnected_callback_ = disconn;
  54. read_callback_ = read;
  55. timeout_callback_ = timeout;
  56. }
  57. void TcpClient::InitTcpClient() {
  58. if (buffer_event_) {
  59. return;
  60. }
  61. CommUtil::CheckIp(server_address_);
  62. event_base_ = event_base_new();
  63. MS_EXCEPTION_IF_NULL(event_base_);
  64. sockaddr_in sin{};
  65. if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) {
  66. MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!";
  67. }
  68. sin.sin_family = AF_INET;
  69. sin.sin_addr.s_addr = inet_addr(server_address_.c_str());
  70. sin.sin_port = htons(server_port_);
  71. buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE);
  72. MS_EXCEPTION_IF_NULL(buffer_event_);
  73. bufferevent_setcb(buffer_event_, ReadCallback, nullptr, EventCallback, this);
  74. if (bufferevent_enable(buffer_event_, EV_READ | EV_WRITE) == -1) {
  75. MS_LOG(EXCEPTION) << "buffer event enable read and write failed!";
  76. }
  77. int result_code = bufferevent_socket_connect(buffer_event_, reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
  78. if (result_code < 0) {
  79. MS_LOG(EXCEPTION) << "Connect server ip:" << server_address_ << " and port: " << server_port_ << " is failed!";
  80. }
  81. }
  82. void TcpClient::StartWithDelay(int seconds) {
  83. if (buffer_event_) {
  84. return;
  85. }
  86. event_base_ = event_base_new();
  87. timeval timeout_value{};
  88. timeout_value.tv_sec = seconds;
  89. timeout_value.tv_usec = 0;
  90. event_timeout_ = evtimer_new(event_base_, TimeoutCallback, this);
  91. if (evtimer_add(event_timeout_, &timeout_value) == -1) {
  92. MS_LOG(EXCEPTION) << "event timeout failed!";
  93. }
  94. }
  95. void TcpClient::Stop() {
  96. if (buffer_event_) {
  97. bufferevent_free(buffer_event_);
  98. buffer_event_ = nullptr;
  99. }
  100. if (event_timeout_) {
  101. event_free(event_timeout_);
  102. event_timeout_ = nullptr;
  103. }
  104. if (event_base_) {
  105. event_base_free(event_base_);
  106. event_base_ = nullptr;
  107. }
  108. }
  109. void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
  110. const int one = 1;
  111. int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
  112. if (ret < 0) {
  113. MS_LOG(EXCEPTION) << "Set socket no delay failed!";
  114. }
  115. }
  116. void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *arg) {
  117. MS_EXCEPTION_IF_NULL(arg);
  118. auto tcp_client = reinterpret_cast<TcpClient *>(arg);
  119. tcp_client->InitTcpClient();
  120. }
  121. void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
  122. MS_EXCEPTION_IF_NULL(bev);
  123. MS_EXCEPTION_IF_NULL(ctx);
  124. auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
  125. struct evbuffer *input = bufferevent_get_input(const_cast<struct bufferevent *>(bev));
  126. MS_EXCEPTION_IF_NULL(input);
  127. char read_buffer[4096];
  128. int read = 0;
  129. while ((read = EVBUFFER_LENGTH(input)) > 0) {
  130. if (evbuffer_remove(input, &read_buffer, sizeof(read_buffer)) == -1) {
  131. MS_LOG(EXCEPTION) << "Can not drain data from the event buffer!";
  132. }
  133. tcp_client->OnReadHandler(read_buffer, read);
  134. }
  135. }
  136. void TcpClient::OnReadHandler(const void *buf, size_t num) {
  137. MS_EXCEPTION_IF_NULL(buf);
  138. if (read_callback_) {
  139. read_callback_(*this, buf, num);
  140. }
  141. message_handler_.ReceiveMessage(buf, num);
  142. }
  143. void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
  144. MS_EXCEPTION_IF_NULL(bev);
  145. MS_EXCEPTION_IF_NULL(ptr);
  146. auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
  147. if (events & BEV_EVENT_CONNECTED) {
  148. // Connected
  149. if (tcp_client->connected_callback_) {
  150. tcp_client->connected_callback_(*tcp_client);
  151. }
  152. evutil_socket_t fd = bufferevent_getfd(const_cast<struct bufferevent *>(bev));
  153. SetTcpNoDelay(fd);
  154. MS_LOG(INFO) << "Client connected!";
  155. } else if (events & BEV_EVENT_ERROR) {
  156. MS_LOG(ERROR) << "Client connected error!";
  157. if (tcp_client->disconnected_callback_) {
  158. tcp_client->disconnected_callback_(*tcp_client, errno);
  159. }
  160. } else if (events & BEV_EVENT_EOF) {
  161. MS_LOG(ERROR) << "Client connected end of file";
  162. if (tcp_client->disconnected_callback_) {
  163. tcp_client->disconnected_callback_(*tcp_client, 0);
  164. }
  165. }
  166. }
  167. void TcpClient::Start() {
  168. MS_EXCEPTION_IF_NULL(event_base_);
  169. int ret = event_base_dispatch(event_base_);
  170. if (ret == 0) {
  171. MS_LOG(INFO) << "Event base dispatch success!";
  172. } else if (ret == 1) {
  173. MS_LOG(ERROR) << "Event base dispatch failed with no events pending or active!";
  174. } else if (ret == -1) {
  175. MS_LOG(ERROR) << "Event base dispatch failed with error occurred!";
  176. } else {
  177. MS_LOG(EXCEPTION) << "Event base dispatch with unexpect error code!";
  178. }
  179. }
  180. void TcpClient::ReceiveMessage(const OnMessage &cb) { message_callback_ = cb; }
  181. void TcpClient::SendMessage(const void *buf, size_t num) const {
  182. MS_EXCEPTION_IF_NULL(buffer_event_);
  183. if (evbuffer_add(bufferevent_get_output(buffer_event_), buf, num) == -1) {
  184. MS_LOG(EXCEPTION) << "event buffer add failed!";
  185. }
  186. }
  187. } // namespace comm
  188. } // namespace ps
  189. } // namespace mindspore