| @@ -62,7 +62,7 @@ bool ComputeGraphNode::Register() { | |||
| auto message = CreateMessage(server_url, content); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| tcp_client_->Send(std::move(message)); | |||
| tcp_client_->SendSync(std::move(message)); | |||
| return true; | |||
| } | |||
| @@ -77,7 +77,7 @@ bool ComputeGraphNode::Heartbeat() { | |||
| auto message = CreateMessage(server_url, content); | |||
| MS_EXCEPTION_IF_NULL(message); | |||
| tcp_client_->Send(std::move(message)); | |||
| tcp_client_->SendSync(std::move(message)); | |||
| return true; | |||
| } | |||
| } // namespace topology | |||
| @@ -326,12 +326,9 @@ void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseU | |||
| if (msg->type == MessageBase::Type::KMSG) { | |||
| if (!isHttpKmsg) { | |||
| send_to = msg->to; | |||
| send_from = msg->from.Name() + "@" + advertiseUrl; | |||
| send_from = msg->from; | |||
| send_msg_header.name_len = htonl(static_cast<uint32_t>(msg->name.size())); | |||
| send_msg_header.to_len = htonl(static_cast<uint32_t>(send_to.size())); | |||
| send_msg_header.from_len = htonl(static_cast<uint32_t>(send_from.size())); | |||
| send_msg_header.body_len = htonl(static_cast<uint32_t>(msg->body.size())); | |||
| FillMessageHeader(*msg, &send_msg_header); | |||
| send_io_vec[index].iov_base = &send_msg_header; | |||
| send_io_vec[index].iov_len = sizeof(send_msg_header); | |||
| @@ -431,7 +428,6 @@ int Connection::AddConnnectEventHandler() { | |||
| } | |||
| bool Connection::ParseMessage() { | |||
| std::string magic_id = ""; | |||
| int retval = 0; | |||
| uint32_t recvLen = 0; | |||
| char *recvBuf = nullptr; | |||
| @@ -454,7 +450,7 @@ bool Connection::ParseMessage() { | |||
| if (strncmp(recv_msg_header.magic, RPC_MAGICID, sizeof(RPC_MAGICID) - 1) != 0) { | |||
| MS_LOG(ERROR) << "Failed to check magicid, RPC_MAGICID: " << RPC_MAGICID | |||
| << ", recv magic_id: " << magic_id.c_str(); | |||
| << ", recv magic_id: " << recv_msg_header.magic; | |||
| state = ConnectionState::kDisconnecting; | |||
| return false; | |||
| } | |||
| @@ -30,27 +30,6 @@ | |||
| namespace mindspore { | |||
| namespace distributed { | |||
| namespace rpc { | |||
| /* | |||
| * The MessageHeader contains the stats info about the message body. | |||
| */ | |||
| struct MessageHeader { | |||
| MessageHeader() { | |||
| for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) { | |||
| if (i < sizeof(RPC_MAGICID) - 1) { | |||
| magic[i] = RPC_MAGICID[i]; | |||
| } else { | |||
| magic[i] = '\0'; | |||
| } | |||
| } | |||
| } | |||
| char magic[BUSMAGIC_LEN]; | |||
| uint32_t name_len{0}; | |||
| uint32_t to_len{0}; | |||
| uint32_t from_len{0}; | |||
| uint32_t body_len{0}; | |||
| }; | |||
| /* | |||
| * The SendMetrics is responsible for collecting metrics when sending data through a connection. | |||
| */ | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_ | |||
| #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_TCP_CONSTANTS_H_ | |||
| #include <arpa/inet.h> | |||
| #include <string> | |||
| #include <csignal> | |||
| #include <queue> | |||
| @@ -85,6 +86,45 @@ constexpr int IP_LEN_MAX = 128; | |||
// Kill the current process for safe exiting by raising SIGKILL.
// The `ret` argument is accepted for the caller's convenience but is not used here.
inline void KillProcess(const std::string &ret) {
  raise(SIGKILL);
}
| /* | |||
| * The MessageHeader contains the stats info about the message body. | |||
| */ | |||
| struct MessageHeader { | |||
| MessageHeader() { | |||
| for (unsigned int i = 0; i < BUSMAGIC_LEN; ++i) { | |||
| if (i < sizeof(RPC_MAGICID) - 1) { | |||
| magic[i] = RPC_MAGICID[i]; | |||
| } else { | |||
| magic[i] = '\0'; | |||
| } | |||
| } | |||
| } | |||
| char magic[BUSMAGIC_LEN]; | |||
| uint32_t name_len{0}; | |||
| uint32_t to_len{0}; | |||
| uint32_t from_len{0}; | |||
| uint32_t body_len{0}; | |||
| }; | |||
| // Fill the message header using the given message. | |||
| __attribute__((unused)) static void FillMessageHeader(const MessageBase &message, MessageHeader *header) { | |||
| std::string send_to = message.to; | |||
| std::string send_from = message.from; | |||
| header->name_len = htonl(static_cast<uint32_t>(message.name.size())); | |||
| header->to_len = htonl(static_cast<uint32_t>(send_to.size())); | |||
| header->from_len = htonl(static_cast<uint32_t>(send_from.size())); | |||
| header->body_len = htonl(static_cast<uint32_t>(message.body.size())); | |||
| } | |||
| // Compute and return the byte size of the whole message. | |||
| __attribute__((unused)) static size_t GetMessageSize(const MessageBase &message) { | |||
| std::string send_to = message.to; | |||
| std::string send_from = message.from; | |||
| size_t size = message.name.size() + send_to.size() + send_from.size() + message.body.size() + sizeof(MessageHeader); | |||
| return size; | |||
| } | |||
| #define RPC_ASSERT(expression) \ | |||
| do { \ | |||
| if (!(expression)) { \ | |||
| @@ -129,7 +129,7 @@ void EventLoop::ReleaseResource() { | |||
| } | |||
| } | |||
| int EventLoop::AddTask(std::function<void()> &&task) { | |||
| int EventLoop::AddTask(std::function<int()> &&task) { | |||
| // put func to the queue | |||
| task_queue_mutex_.lock(); | |||
| (void)task_queue_.emplace(std::move(task)); | |||
| @@ -68,7 +68,7 @@ class EventLoop { | |||
| // Add task (eg. send message, reconnect etc.) to task queue of the event loop. | |||
| // These tasks are executed asynchronously. | |||
| int AddTask(std::function<void()> &&task); | |||
| int AddTask(std::function<int()> &&task); | |||
| // Set event handler for events(read/write/..) occurred on the socket fd. | |||
| int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data); | |||
| @@ -69,11 +69,13 @@ bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { | |||
| return rt; | |||
| } | |||
| int TCPClient::Send(std::unique_ptr<MessageBase> &&msg) { | |||
| int TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg) { | |||
| int rt = -1; | |||
| rt = tcp_comm_->Send(msg.release()); | |||
| rt = tcp_comm_->Send(msg.release(), true); | |||
| return rt; | |||
| } | |||
| void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), false); } | |||
| } // namespace rpc | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||
| @@ -41,8 +41,11 @@ class TCPClient { | |||
| // Disconnect from the specified server. | |||
| bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5); | |||
| // Send the message from the source to the destination. | |||
| int Send(std::unique_ptr<MessageBase> &&msg); | |||
| // Send the message from the source to the destination synchronously and return the number of bytes sent by this call. | |||
| int SendSync(std::unique_ptr<MessageBase> &&msg); | |||
| // Send the message from the source to the destination asynchronously. | |||
| void SendAsync(std::unique_ptr<MessageBase> &&msg); | |||
| private: | |||
| // The basic TCP communication component used by the client. | |||
| @@ -130,7 +130,8 @@ void OnAccept(int server, uint32_t events, void *arg) { | |||
| } | |||
| } | |||
| void DoSend(Connection *conn) { | |||
| int DoSend(Connection *conn) { | |||
| int total_send_bytes = 0; | |||
| while (!conn->send_message_queue.empty() || conn->total_send_len != 0) { | |||
| if (conn->total_send_len == 0) { | |||
| conn->FillSendMessage(conn->send_message_queue.front(), conn->source, false); | |||
| @@ -139,6 +140,7 @@ void DoSend(Connection *conn) { | |||
| int sendLen = conn->socket_operation->SendMessage(conn, &conn->send_kernel_msg, &conn->total_send_len); | |||
| if (sendLen > 0) { | |||
| total_send_bytes += sendLen; | |||
| if (conn->total_send_len == 0) { | |||
| // update metrics | |||
| conn->send_metrics->UpdateError(false); | |||
| @@ -158,6 +160,7 @@ void DoSend(Connection *conn) { | |||
| break; | |||
| } | |||
| } | |||
| return total_send_bytes; | |||
| } | |||
| TCPComm::~TCPComm() { | |||
| @@ -354,8 +357,8 @@ void TCPComm::DropMessage(MessageBase *msg) { | |||
| ptr = nullptr; | |||
| } | |||
| int TCPComm::Send(MessageBase *msg) { | |||
| return send_event_loop_->AddTask([msg, this] { | |||
| int TCPComm::Send(MessageBase *msg, bool sync) { | |||
| auto task = [msg, this] { | |||
| std::lock_guard<std::mutex> lock(*conn_mutex_); | |||
| // Search connection by the target address | |||
| Connection *conn = conn_pool_->FindConnection(msg->to.Url()); | |||
| @@ -363,7 +366,8 @@ int TCPComm::Send(MessageBase *msg) { | |||
| MS_LOG(ERROR) << "Can not found remote link and send fail name: " << msg->name.c_str() | |||
| << ", from: " << msg->from.Url().c_str() << ", to: " << msg->to.Url().c_str(); | |||
| DropMessage(msg); | |||
| return; | |||
| int error_no = -1; | |||
| return error_no; | |||
| } | |||
| if (conn->send_message_queue.size() >= SENDMSG_QUEUELEN) { | |||
| @@ -371,7 +375,8 @@ int TCPComm::Send(MessageBase *msg) { | |||
| << ") and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd | |||
| << ", to: " << conn->destination.c_str(); | |||
| DropMessage(msg); | |||
| return; | |||
| int error_no = -1; | |||
| return error_no; | |||
| } | |||
| if (conn->state != ConnectionState::kConnected) { | |||
| @@ -379,7 +384,8 @@ int TCPComm::Send(MessageBase *msg) { | |||
| << " and the name of dropped message is: " << msg->name.c_str() << ", fd: " << conn->socket_fd | |||
| << ", to: " << conn->destination.c_str(); | |||
| DropMessage(msg); | |||
| return; | |||
| int error_no = -1; | |||
| return error_no; | |||
| } | |||
| if (conn->total_send_len == 0) { | |||
| @@ -387,8 +393,13 @@ int TCPComm::Send(MessageBase *msg) { | |||
| } else { | |||
| (void)conn->send_message_queue.emplace(msg); | |||
| } | |||
| DoSend(conn); | |||
| }); | |||
| return DoSend(conn); | |||
| }; | |||
| if (sync) { | |||
| return task(); | |||
| } else { | |||
| return send_event_loop_->AddTask(task); | |||
| } | |||
| } | |||
| void TCPComm::Connect(const std::string &dst_url) { | |||
| @@ -403,7 +414,7 @@ void TCPComm::Connect(const std::string &dst_url) { | |||
| conn = new (std::nothrow) Connection(); | |||
| if (conn == nullptr) { | |||
| MS_LOG(ERROR) << "Failed to create new connection and link fail destination: " << dst_url; | |||
| return; | |||
| return false; | |||
| } | |||
| conn->source = url_; | |||
| conn->destination = dst_url; | |||
| @@ -418,12 +429,12 @@ void TCPComm::Connect(const std::string &dst_url) { | |||
| SocketAddress addr; | |||
| if (!SocketOperation::GetSockAddr(dst_url, &addr)) { | |||
| MS_LOG(ERROR) << "Failed to get socket address to dest url " << dst_url; | |||
| return; | |||
| return false; | |||
| } | |||
| int sock_fd = SocketOperation::CreateSocket(addr.sa.sa_family); | |||
| if (sock_fd < 0) { | |||
| MS_LOG(ERROR) << "Failed to create client tcp socket to dest url " << dst_url; | |||
| return; | |||
| return false; | |||
| } | |||
| conn->socket_fd = sock_fd; | |||
| @@ -439,12 +450,13 @@ void TCPComm::Connect(const std::string &dst_url) { | |||
| conn->socket_operation = nullptr; | |||
| } | |||
| delete conn; | |||
| return; | |||
| return false; | |||
| } | |||
| conn_pool_->AddConnection(conn); | |||
| } | |||
| conn_pool_->AddConnInfo(conn->socket_fd, dst_url, nullptr); | |||
| MS_LOG(INFO) << "Connected to destination: " << dst_url; | |||
| return true; | |||
| }); | |||
| } | |||
| @@ -460,6 +472,7 @@ void TCPComm::Disconnect(const std::string &dst_url) { | |||
| (void)recv_event_loop_->AddTask([dst_url, this] { | |||
| std::lock_guard<std::mutex> lock(*conn_mutex_); | |||
| conn_pool_->DeleteConnection(dst_url); | |||
| return true; | |||
| }); | |||
| } | |||
| @@ -34,7 +34,7 @@ namespace rpc { | |||
| void OnAccept(int server, uint32_t events, void *arg); | |||
| // Send messages buffered in the connection. | |||
| void DoSend(Connection *conn); | |||
| int DoSend(Connection *conn); | |||
| void DoDisconnect(int fd, Connection *conn, uint32_t error, int soError); | |||
| @@ -65,7 +65,8 @@ class TCPComm { | |||
| void Disconnect(const std::string &dst_url); | |||
| // Send the message from the source to the destination. | |||
| int Send(MessageBase *msg); | |||
| // If sync is true, the message is sent directly by the caller; otherwise it is added to the task queue of the send event loop. | |||
| int Send(MessageBase *msg, bool sync = false); | |||
| // Set the message processing handler. | |||
| void SetMessageHandler(MessageHandler handler); | |||
| @@ -115,7 +116,7 @@ class TCPComm { | |||
| std::shared_ptr<std::mutex> conn_mutex_; | |||
| friend void OnAccept(int server, uint32_t events, void *arg); | |||
| friend void DoSend(Connection *conn); | |||
| friend int DoSend(Connection *conn); | |||
| friend int DoConnect(const std::string &to, Connection *conn, ConnectionCallBack event_callback, | |||
| ConnectionCallBack write_callback, ConnectionCallBack read_callback); | |||
| }; | |||
| @@ -69,7 +69,7 @@ void SendActor::SendOutput(OpContext<DeviceTensor> *const context) { | |||
| std::string peer_server_url = peer.second; | |||
| auto message = BuildRpcMessage(send_output, peer_server_url); | |||
| MS_ERROR_IF_NULL_WO_RET_VAL(message); | |||
| client_->Send(std::move(message)); | |||
| client_->SendAsync(std::move(message)); | |||
| } | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #define private public | |||
| #include "distributed/rpc/tcp/tcp_server.h" | |||
| #include "distributed/rpc/tcp/tcp_client.h" | |||
| #include "distributed/rpc/tcp/constants.h" | |||
| #include "common/common_test.h" | |||
| namespace mindspore { | |||
| @@ -165,7 +166,7 @@ TEST_F(TCPTest, SendOneMessage) { | |||
| // Send the message. | |||
| client->Connect(server_url); | |||
| client->Send(std::move(message)); | |||
| client->SendAsync(std::move(message)); | |||
| // Wait timeout: 5s | |||
| WaitForDataMsg(1, 5); | |||
| @@ -182,7 +183,7 @@ TEST_F(TCPTest, SendOneMessage) { | |||
| /// Feature: test sending two messages consecutively. | |||
| /// Description: start a socket server and send two normal messages to it. | |||
| /// Expectation: the server receives the two messages sent from the client. | |||
| TEST_F(TCPTest, sendTwoMessages) { | |||
| TEST_F(TCPTest, SendTwoMessages) { | |||
| Init(); | |||
| // Start the tcp server. | |||
| @@ -205,8 +206,8 @@ TEST_F(TCPTest, sendTwoMessages) { | |||
| // Send messages. | |||
| client->Connect(server_url); | |||
| client->Send(std::move(message1)); | |||
| client->Send(std::move(message2)); | |||
| client->SendAsync(std::move(message1)); | |||
| client->SendAsync(std::move(message2)); | |||
| // Wait timeout: 5s | |||
| WaitForDataMsg(2, 5); | |||
| @@ -230,6 +231,45 @@ TEST_F(TCPTest, StartServerWithRandomPort) { | |||
| EXPECT_LT(0, port); | |||
| server->Finalize(); | |||
| } | |||
| /// Feature: test send the message synchronously. | |||
| /// Description: start a socket server and send the message synchronously. | |||
| /// Expectation: the number of bytes sent could be got synchronously. | |||
| TEST_F(TCPTest, SendSyncMessage) { | |||
| Init(); | |||
| // Start the tcp server. | |||
| auto server_url = "127.0.0.1:8081"; | |||
| std::unique_ptr<TCPServer> server = std::make_unique<TCPServer>(); | |||
| bool ret = server->Initialize(server_url); | |||
| ASSERT_TRUE(ret); | |||
| server->SetMessageHandler([](const std::shared_ptr<MessageBase> &message) -> void { IncrDataMsgNum(1); }); | |||
| // Start the tcp client. | |||
| auto client_url = "127.0.0.1:1234"; | |||
| std::unique_ptr<TCPClient> client = std::make_unique<TCPClient>(); | |||
| ret = client->Initialize(); | |||
| ASSERT_TRUE(ret); | |||
| // Create the message. | |||
| auto message = CreateMessage(server_url, client_url); | |||
| auto msg_size = GetMessageSize(*message); | |||
| // Send the message. | |||
| client->Connect(server_url); | |||
| auto bytes_num = client->SendSync(std::move(message)); | |||
| EXPECT_EQ(msg_size, bytes_num); | |||
| WaitForDataMsg(1, 5); | |||
| EXPECT_EQ(1, GetDataMsgNum()); | |||
| // Destroy | |||
| client->Disconnect(server_url); | |||
| client->Finalize(); | |||
| server->Finalize(); | |||
| } | |||
| } // namespace rpc | |||
| } // namespace distributed | |||
| } // namespace mindspore | |||