| @@ -149,6 +149,13 @@ size_t EventLoop::AddTask(std::function<int()> &&task) { | |||
| return result; | |||
| } | |||
| size_t EventLoop::RemainingTaskNum() { | |||
| task_queue_mutex_.lock(); | |||
| auto task_num = task_queue_.size(); | |||
| task_queue_mutex_.unlock(); | |||
| return task_num; | |||
| } | |||
| bool EventLoop::Initialize(const std::string &threadName) { | |||
| int retval = InitResource(); | |||
| if (retval != RPC_OK) { | |||
| @@ -70,6 +70,9 @@ class EventLoop { | |||
| // These tasks are executed asynchronously. | |||
| size_t AddTask(std::function<int()> &&task); | |||
| // The number of tasks in the pending task queue. | |||
| size_t RemainingTaskNum(); | |||
| // Set event handler for events(read/write/..) occurred on the socket fd. | |||
| int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data); | |||
| @@ -138,29 +138,40 @@ int DoSend(Connection *conn) { | |||
| conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false); | |||
| conn->send_message_queue.pop(); | |||
| } | |||
| size_t retryCount = 10; | |||
| size_t sendLen = 0; | |||
| int retval = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &sendLen); | |||
| if (retval == IO_RW_OK && sendLen > 0) { | |||
| total_send_bytes += sendLen; | |||
| conn->total_send_len -= sendLen; | |||
| if (conn->total_send_len == 0) { | |||
| // update metrics | |||
| conn->send_metrics->UpdateError(false); | |||
| conn->output_buffer_size -= conn->send_message->body.size(); | |||
| delete conn->send_message; | |||
| conn->send_message = nullptr; | |||
| while (retryCount > 0 && sendLen != conn->total_send_len) { | |||
| int retval = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, conn->total_send_len, &sendLen); | |||
| if (retval == IO_RW_OK && sendLen > 0) { | |||
| conn->total_send_len -= sendLen; | |||
| if (conn->total_send_len == 0) { | |||
| // update metrics | |||
| conn->send_metrics->UpdateError(false); | |||
| conn->output_buffer_size -= conn->send_message->body.size(); | |||
| total_send_bytes += conn->send_message->body.size(); | |||
| delete conn->send_message; | |||
| conn->send_message = nullptr; | |||
| break; | |||
| } | |||
| } else if (retval == IO_RW_OK && sendLen == 0) { | |||
| // EAGAIN | |||
| MS_LOG(ERROR) << "Failed to send message and update the epoll event"; | |||
| (void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR); | |||
| continue; | |||
| } else { | |||
| if (--retryCount > 0) { | |||
| MS_LOG(ERROR) << "Failed to send message and retry(" + std::to_string(retryCount) + ")..."; | |||
| unsigned int time = 1; | |||
| sleep(time); | |||
| continue; | |||
| } else { | |||
| // update metrics | |||
| conn->send_metrics->UpdateError(true, conn->error_code); | |||
| conn->state = ConnectionState::kDisconnecting; | |||
| break; | |||
| } | |||
| } | |||
| } else if (retval == IO_RW_OK && sendLen == 0) { | |||
| // EAGAIN | |||
| (void)conn->recv_event_loop->UpdateEpollEvent(conn->socket_fd, EPOLLOUT | EPOLLIN | EPOLLHUP | EPOLLERR); | |||
| break; | |||
| } else { | |||
| // update metrics | |||
| conn->send_metrics->UpdateError(true, conn->error_code); | |||
| conn->state = ConnectionState::kDisconnecting; | |||
| break; | |||
| } | |||
| } | |||
| return total_send_bytes; | |||
| @@ -445,12 +456,24 @@ bool TCPComm::IsConnected(const std::string &dst_url) { | |||
| return false; | |||
| } | |||
| void TCPComm::Disconnect(const std::string &dst_url) { | |||
| bool TCPComm::Disconnect(const std::string &dst_url) { | |||
| int interval = 100000; | |||
| size_t retry = 30; | |||
| while (recv_event_loop_->RemainingTaskNum() != 0 && send_event_loop_->RemainingTaskNum() != 0 && retry > 0) { | |||
| usleep(interval); | |||
| retry--; | |||
| } | |||
| if (recv_event_loop_->RemainingTaskNum() > 0 || send_event_loop_->RemainingTaskNum() > 0) { | |||
| MS_LOG(ERROR) << "Failed to disconnect from url " << dst_url | |||
| << ", because there are still pending tasks to be executed, please try later."; | |||
| return false; | |||
| } | |||
| (void)recv_event_loop_->AddTask([dst_url, this] { | |||
| std::lock_guard<std::mutex> lock(*conn_mutex_); | |||
| conn_pool_->DeleteConnection(dst_url); | |||
| return true; | |||
| }); | |||
| return true; | |||
| } | |||
| Connection *TCPComm::CreateDefaultConn(const std::string &to) { | |||
| @@ -62,7 +62,7 @@ class TCPComm { | |||
| // Connection operation for a specified destination. | |||
| void Connect(const std::string &dst_url); | |||
| bool IsConnected(const std::string &dst_url); | |||
| void Disconnect(const std::string &dst_url); | |||
| bool Disconnect(const std::string &dst_url); | |||
| // Send the message from the source to the destination. | |||
| // The flag sync means if the message is sent directly or added to the task queue. | |||
| @@ -19,7 +19,7 @@ | |||
| namespace mindspore { | |||
| namespace distributed { | |||
| namespace rpc { | |||
| constexpr int EAGAIN_RETRY = 2; | |||
| constexpr int EAGAIN_RETRY = 100; | |||
| ssize_t TCPSocketOperation::ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) { | |||
| return recv(connection->socket_fd, recvBuf, recvLen, MSG_PEEK); | |||
| @@ -107,18 +107,23 @@ int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re | |||
| int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, size_t totalSendLen, | |||
| size_t *sendLen) { | |||
| int eagainCount = EAGAIN_RETRY; | |||
| *sendLen = 0; | |||
| while (*sendLen != totalSendLen) { | |||
| auto retval = sendmsg(connection->socket_fd, sendMsg, MSG_NOSIGNAL); | |||
| if (retval < 0) { | |||
| --eagainCount; | |||
| if (errno != EAGAIN) { | |||
| MS_LOG(ERROR) << "Failed to call sendmsg and errno is: " << errno; | |||
| connection->error_code = errno; | |||
| return IO_RW_ERROR; | |||
| } else if (eagainCount == 0) { | |||
| MS_LOG(ERROR) << "Failed to call sendmsg after retry " + std::to_string(EAGAIN_RETRY) + " times and errno is: " | |||
| << errno; | |||
| *sendLen = 0; | |||
| break; | |||
| return IO_RW_OK; | |||
| } | |||
| MS_LOG(ERROR) << "retry(" + std::to_string(eagainCount) + "/" + std::to_string(EAGAIN_RETRY) + ") sending ..."; | |||
| } else { | |||
| *sendLen += retval; | |||
| @@ -137,7 +142,7 @@ int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendM | |||
| reinterpret_cast<char *>(sendMsg->msg_iov[i].iov_base) + static_cast<unsigned int>(retval) - tmpBytes; | |||
| sendMsg->msg_iov = &sendMsg->msg_iov[i]; | |||
| sendMsg->msg_iovlen -= (i + 1); | |||
| sendMsg->msg_iovlen -= i; | |||
| break; | |||
| } | |||
| } | |||
| @@ -255,7 +255,7 @@ TEST_F(TCPTest, SendSyncMessage) { | |||
| // Create the message. | |||
| auto message = CreateMessage(server_url, client_url); | |||
| auto msg_size = GetMessageSize(*message); | |||
| auto msg_size = message->body.size(); | |||
| // Send the message. | |||
| client->Connect(server_url); | |||
| @@ -271,6 +271,109 @@ TEST_F(TCPTest, SendSyncMessage) { | |||
| client->Finalize(); | |||
| server->Finalize(); | |||
| } | |||
| /// Feature: test sending large messages. | |||
| /// Description: start a socket server and send several large messages to it. | |||
| /// Expectation: the server received these large messages sented from client. | |||
| TEST_F(TCPTest, SendLargeMessages) { | |||
| Init(); | |||
| // Start the tcp server. | |||
| std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>(); | |||
| bool ret = server->Initialize(); | |||
| ASSERT_TRUE(ret); | |||
| server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); }); | |||
| // Start the tcp client. | |||
| auto client_url = "127.0.0.1:1234"; | |||
| std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>(); | |||
| ret = client->Initialize(); | |||
| ASSERT_TRUE(ret); | |||
| // Send the message. | |||
| auto ip = server->GetIP(); | |||
| auto port = server->GetPort(); | |||
| auto server_url = ip + ":" + std::to_string(port); | |||
| client->Connect(server_url); | |||
| size_t msg_cnt = 5; | |||
| size_t large_msg_size = 1024000; | |||
| for (int i = 0; i < msg_cnt; ++i) { | |||
| auto message = CreateMessage(server_url, client_url, large_msg_size); | |||
| client->SendAsync(std::move(message)); | |||
| } | |||
| // Wait timeout: 15s | |||
| WaitForDataMsg(msg_cnt, 15); | |||
| // Check result | |||
| EXPECT_EQ(msg_cnt, GetDataMsgNum()); | |||
| // Destroy | |||
| client->Disconnect(server_url); | |||
| client->Finalize(); | |||
| server->Finalize(); | |||
| } | |||
| /// Feature: test creating many TCP connections. | |||
| /// Description: create many servers and clients, then connect each client to a server. | |||
| /// Expectation: all the servers and clients are created successfully. | |||
| TEST_F(TCPTest, CreateManyConnectionPairs) { | |||
| Init(); | |||
| std::vector<std::shared_ptr<TCPServer>> servers; | |||
| std::vector<std::shared_ptr<TCPClient>> clients; | |||
| std::vector<std::string> server_urls; | |||
| size_t total_connection_num = 10; | |||
| for (size_t i = 0; i < total_connection_num; ++i) { | |||
| // Start the tcp server. | |||
| std::shared_ptr<TCPServer> server = std::make_shared<TCPServer>(); | |||
| bool ret = server->Initialize(); | |||
| auto ip = server->GetIP(); | |||
| auto port = server->GetPort(); | |||
| ASSERT_TRUE(ret); | |||
| server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); }); | |||
| // Start the tcp client. | |||
| auto client_url = "127.0.0.1:1234"; | |||
| std::shared_ptr<TCPClient> client = std::make_shared<TCPClient>(); | |||
| ret = client->Initialize(); | |||
| ASSERT_TRUE(ret); | |||
| // Send the message. | |||
| auto server_url = ip + ":" + std::to_string(port); | |||
| server_urls.push_back(server_url); | |||
| auto success = client->Connect(server_url); | |||
| EXPECT_EQ(true, success); | |||
| size_t msg_cnt = 100; | |||
| size_t large_msg_size = 10240; | |||
| for (int i = 0; i < msg_cnt; ++i) { | |||
| auto message = CreateMessage(server_url, client_url, large_msg_size); | |||
| client->SendAsync(std::move(message)); | |||
| } | |||
| // Check result | |||
| servers.push_back(server); | |||
| clients.push_back(client); | |||
| } | |||
| // Check result | |||
| EXPECT_EQ(total_connection_num, servers.size()); | |||
| EXPECT_EQ(total_connection_num, clients.size()); | |||
| // Destroy | |||
| for (size_t i = 0; i < total_connection_num; ++i) { | |||
| while (!clients[i]->Disconnect(server_urls[i])) | |||
| ; | |||
| clients[i]->Finalize(); | |||
| servers[i]->Finalize(); | |||
| } | |||
| } | |||
| } // namespace rpc | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||