Browse Source

Fix static code check issues.

r1.7
Parallels 4 years ago
parent
commit
8eb7ecf89e
16 changed files with 121 additions and 134 deletions
  1. +13
    -13
      mindspore/ccsrc/distributed/rpc/tcp/connection.cc
  2. +4
    -4
      mindspore/ccsrc/distributed/rpc/tcp/connection.h
  3. +4
    -8
      mindspore/ccsrc/distributed/rpc/tcp/connection_pool.cc
  4. +3
    -2
      mindspore/ccsrc/distributed/rpc/tcp/connection_pool.h
  5. +2
    -2
      mindspore/ccsrc/distributed/rpc/tcp/constants.h
  6. +6
    -13
      mindspore/ccsrc/distributed/rpc/tcp/event_loop.cc
  7. +4
    -4
      mindspore/ccsrc/distributed/rpc/tcp/event_loop.h
  8. +6
    -3
      mindspore/ccsrc/distributed/rpc/tcp/socket_operation.cc
  9. +4
    -4
      mindspore/ccsrc/distributed/rpc/tcp/socket_operation.h
  10. +28
    -16
      mindspore/ccsrc/distributed/rpc/tcp/tcp_client.cc
  11. +15
    -34
      mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.cc
  12. +5
    -5
      mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.h
  13. +4
    -3
      mindspore/ccsrc/distributed/rpc/tcp/tcp_server.cc
  14. +3
    -3
      mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h
  15. +16
    -16
      mindspore/ccsrc/distributed/rpc/tcp/tcp_socket_operation.cc
  16. +4
    -4
      mindspore/ccsrc/distributed/rpc/tcp/tcp_socket_operation.h

+ 13
- 13
mindspore/ccsrc/distributed/rpc/tcp/connection.cc View File

@@ -32,7 +32,9 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
if (fd != conn->socket_fd) {
MS_LOG(ERROR) << "Failed to reuse connection, delete and close fd: " << fd << ", connfd: " << conn->socket_fd
<< ", event: " << events;
conn->recv_event_loop->DeleteEpollEvent(fd);
if (conn->recv_event_loop->DeleteEpollEvent(fd) != RPC_OK) {
MS_LOG(ERROR) << "Failed to delete epoll event for fd: " << fd;
}
conn->state = ConnectionState::kDisconnecting;
if (conn->event_callback != nullptr) {
conn->event_callback(conn);
@@ -76,7 +78,7 @@ void SocketEventHandler(int fd, uint32_t events, void *context) {
void NewConnectEventHandler(int fd, uint32_t events, void *context) {
int retval = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
conn->socket_operation->NewConnEventHandler(fd, events, context);
conn->socket_operation->NewConnEventHandler(context);

if (conn->state == ConnectionState::kDisconnecting) {
conn->Disconnect(fd);
@@ -170,10 +172,9 @@ void Connection::InitSocketOperation() {
}

bool Connection::ReconnectSourceSocket(int fd, uint32_t events, int *soError, uint32_t error) {
int retval = 0;
socklen_t len = sizeof(*soError);

retval = recv_event_loop->DeleteEpollEvent(fd);
int retval = recv_event_loop->DeleteEpollEvent(fd);
if (retval) {
MS_LOG(ERROR) << "Failed to delete event for fd: " << fd << ", event: " << events;
return false;
@@ -274,7 +275,7 @@ void Connection::CheckMessageType() {
magic_id.resize(sizeof(RPC_MAGICID) - 1);
char *buf = const_cast<char *>(magic_id.data());

int size = socket_operation->ReceivePeek(this, buf, sizeof(RPC_MAGICID) - 1);
ssize_t size = socket_operation->ReceivePeek(this, buf, sizeof(RPC_MAGICID) - 1);
if (size < static_cast<int>(sizeof(RPC_MAGICID) - 1)) {
if (size == 0) {
MS_LOG(INFO) << "Set connection disconnecting for fd: " << socket_fd << ", size: " << size
@@ -321,13 +322,12 @@ std::string Connection::GenerateHttpMessage(MessageBase *msg) {
return postLine + userAgentLine + fromLine + connectLine + hostLine + commonEndLine;
}

void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg, int index) {
index = 0;
void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg) {
if (msg->type == MessageBase::Type::KMSG) {
size_t index = 0;
if (!isHttpKmsg) {
send_to = msg->to;
send_from = msg->from;

FillMessageHeader(*msg, &send_msg_header);

send_io_vec[index].iov_base = &send_msg_header;
@@ -346,7 +346,7 @@ void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseU
send_io_vec[index].iov_len = msg->body.size();
++index;
send_kernel_msg.msg_iov = send_io_vec;
send_kernel_msg.msg_iovlen = IntToSize(index);
send_kernel_msg.msg_iovlen = index;
total_send_len =
UlongToUint(sizeof(send_msg_header)) + msg->name.size() + send_to.size() + send_from.size() + msg->body.size();
send_message = msg;
@@ -358,11 +358,11 @@ void Connection::FillSendMessage(MessageBase *msg, const std::string &advertiseU
return;
} else {
if (advertise_addr_.empty()) {
size_t index = advertiseUrl.find(URL_PROTOCOL_IP_SEPARATOR);
if (index == std::string::npos) {
size_t idx = advertiseUrl.find(URL_PROTOCOL_IP_SEPARATOR);
if (idx == std::string::npos) {
advertise_addr_ = advertiseUrl;
} else {
advertise_addr_ = advertiseUrl.substr(index + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
advertise_addr_ = advertiseUrl.substr(idx + sizeof(URL_PROTOCOL_IP_SEPARATOR) - 1);
}
}
msg->body = GenerateHttpMessage(msg);
@@ -418,7 +418,7 @@ void Connection::FillRecvMessage() {

recv_kernel_msg.msg_iov = recv_io_vec;
recv_kernel_msg.msg_iovlen = IntToSize(i);
total_recv_len = UlongToUint(msg->name.size()) + recv_to.size() + recv_from.size() + msg->body.size();
total_recv_len = msg->name.size() + recv_to.size() + recv_from.size() + msg->body.size();
recv_message = msg;
}



+ 4
- 4
mindspore/ccsrc/distributed/rpc/tcp/connection.h View File

@@ -101,7 +101,7 @@ struct Connection {
void CheckMessageType();

// Fill the message to be sent based on the input message.
void FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg, int index = 0);
void FillSendMessage(MessageBase *msg, const std::string &advertiseUrl, bool isHttpKmsg);

void FillRecvMessage();

@@ -155,9 +155,9 @@ struct Connection {
State recv_state;

// Total length of received and sent messages.
uint32_t total_recv_len;
uint32_t total_send_len;
uint32_t recv_len;
size_t total_recv_len;
size_t total_send_len;
size_t recv_len;

std::string send_to;
std::string send_from;


+ 4
- 8
mindspore/ccsrc/distributed/rpc/tcp/connection_pool.cc View File

@@ -195,14 +195,10 @@ bool ConnectionPool::ReverseConnInfo(int fromFd, int toFd) {
return true;
}

ConnectionPool::~ConnectionPool() {
try {
DeleteAllConnections(&local_conns_);
DeleteAllConnections(&remote_conns_);
DeleteAllConnInfos();
} catch (...) {
MS_LOG(ERROR) << "Failed to release resource for connection pool.";
}
void ConnectionPool::Finalize() {
DeleteAllConnections(&local_conns_);
DeleteAllConnections(&remote_conns_);
DeleteAllConnInfos();
}
} // namespace rpc
} // namespace distributed


+ 3
- 2
mindspore/ccsrc/distributed/rpc/tcp/connection_pool.h View File

@@ -40,7 +40,9 @@ struct ConnectionInfo {
class ConnectionPool {
public:
ConnectionPool() : double_link_(false) {}
~ConnectionPool();
~ConnectionPool() = default;

void Finalize();

/*
* Operations for ConnectionInfo.
@@ -90,7 +92,6 @@ class ConnectionPool {
// each to_url has two fds at most, and each fd has multiple linkinfos
std::map<int, std::set<ConnectionInfo *>> conn_infos_;

friend class Connection;
friend class TCPComm;
};
} // namespace rpc


+ 2
- 2
mindspore/ccsrc/distributed/rpc/tcp/constants.h View File

@@ -47,7 +47,7 @@ constexpr size_t MAX_KMSG_NAME_LEN = 1024;
constexpr size_t MAX_KMSG_BODY_LEN = 104857600;

enum ParseType { kTcpMsg = 1, kHttpReq, kHttpRsp, kUnknown };
enum State { kMagicId = 1, kMsgHeader, kName, kDestination, kSource, kBody };
enum State { kMsgHeader, kBody };
enum ConnectionState { kInit = 1, kConnecting, kConnected, kDisconnecting, kClose };
enum ConnectionType { kTcp = 1, kSSL };
enum ConnectionPriority { kPriorityLow = 1, kPriorityHigh };
@@ -84,7 +84,7 @@ constexpr int RPC_OK = 0;
constexpr int IP_LEN_MAX = 128;

// Kill the process for safe exiting.
inline void KillProcess(const std::string &ret) { raise(SIGKILL); }
inline void KillProcess(const std::string &ret) { (void)raise(SIGKILL); }

/*
* The MessageHeader contains the stats info about the message body.


+ 6
- 13
mindspore/ccsrc/distributed/rpc/tcp/event_loop.cc View File

@@ -27,7 +27,6 @@
#include <atomic>
#include <string>
#include <thread>
#include <csignal>

#include "actor/log.h"
#include "distributed/rpc/tcp/constants.h"
@@ -98,7 +97,8 @@ void QueueReadyCallback(int fd, uint32_t events, void *arg) {
return;
}
uint64_t count;
if (read(evloop->task_queue_event_fd_, &count, sizeof(count)) == sizeof(count)) {
ssize_t retval = read(evloop->task_queue_event_fd_, &count, sizeof(count));
if (retval > 0 && retval == sizeof(count)) {
// take out functions from the queue
std::queue<std::function<void()>> q;

@@ -129,7 +129,7 @@ void EventLoop::ReleaseResource() {
}
}

int EventLoop::AddTask(std::function<int()> &&task) {
ssize_t EventLoop::AddTask(std::function<int()> &&task) {
// put func to the queue
task_queue_mutex_.lock();
(void)task_queue_.emplace(std::move(task));
@@ -141,7 +141,8 @@ int EventLoop::AddTask(std::function<int()> &&task) {
if (result == 1) {
// wakeup event loop
uint64_t one = 1;
if (write(task_queue_event_fd_, &one, sizeof(one)) != sizeof(one)) {
ssize_t retval = write(task_queue_event_fd_, &one, sizeof(one));
if (retval != sizeof(one)) {
MS_LOG(WARNING) << "Failed to write queue Event fd: " << task_queue_event_fd_ << ",errno:" << errno;
}
}
@@ -197,14 +198,6 @@ void EventLoop::Finalize() {
MS_LOG(INFO) << "Stop loop succ";
}

EventLoop::~EventLoop() {
try {
Finalize();
} catch (...) {
MS_LOG(ERROR) << "Failed to finalize the event loop";
}
}

void EventLoop::DeleteEvent(int fd) {
auto iter = events_.find(fd);
if (iter == events_.end()) {
@@ -410,7 +403,7 @@ void EventLoop::RemoveDeletedEvents() {
deleted_events_.clear();
}

int EventLoop::FindDeletedEvent(Event *tev) {
int EventLoop::FindDeletedEvent(const Event *tev) {
std::map<int, std::list<Event *>>::iterator fdIter = deleted_events_.find(tev->fd);
if (fdIter == deleted_events_.end()) {
return 0;


+ 4
- 4
mindspore/ccsrc/distributed/rpc/tcp/event_loop.h View File

@@ -46,7 +46,7 @@ int EventLoopRun(EventLoop *evloop, int timeout);
/*
* The event occurred on the fd.
*/
typedef struct Event {
typedef struct {
int fd;
void *data;
EventHandler handler;
@@ -61,14 +61,14 @@ class EventLoop {
EventLoop() : epoll_fd_(-1), is_stop_(false), loop_thread_(0), task_queue_event_fd_(-1) {}
EventLoop(const EventLoop &) = delete;
EventLoop &operator=(const EventLoop &) = delete;
~EventLoop();
~EventLoop() = default;

bool Initialize(const std::string &threadName);
void Finalize();

// Add task (eg. send message, reconnect etc.) to task queue of the event loop.
// These tasks are executed asynchronously.
int AddTask(std::function<int()> &&task);
ssize_t AddTask(std::function<int()> &&task);

// Set event handler for events(read/write/..) occurred on the socket fd.
int SetEventHandler(int sock_fd, uint32_t events, EventHandler handler, void *data);
@@ -91,7 +91,7 @@ class EventLoop {

// Operate the soft deleted events.
void AddDeletedEvent(Event *event);
int FindDeletedEvent(Event *event);
int FindDeletedEvent(const Event *event);
void RemoveDeletedEvents();

// Event operations.


+ 6
- 3
mindspore/ccsrc/distributed/rpc/tcp/socket_operation.cc View File

@@ -23,7 +23,6 @@
#include <securec.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <csignal>
#include <system_error>

#include "actor/log.h"
@@ -289,7 +288,9 @@ std::string SocketOperation::GetPeer(int sock_fd) {
}
peer = std::string(ipdotdec) + ":" + std::to_string(ntohs(isa.saIn.sin_port));
} else if (isa.sa.sa_family == AF_INET6) {
inet_ntop(AF_INET6, reinterpret_cast<void *>(&isa.saIn6.sin6_addr), ipdotdec, IP_LEN_MAX);
if (inet_ntop(AF_INET6, reinterpret_cast<void *>(&isa.saIn6.sin6_addr), ipdotdec, IP_LEN_MAX) == nullptr) {
MS_LOG(ERROR) << "Failed to call inet_ntop.";
}
peer = std::string(ipdotdec) + ":" + std::to_string(ntohs(isa.saIn6.sin6_port));
} else {
MS_LOG(INFO) << "Unknown fd: " << sock_fd << ", family: " << isa.sa.sa_family;
@@ -363,7 +364,9 @@ int SocketOperation::Accept(int sock_fd) {
MS_LOG(ERROR) << "Failed to call accept, errno: " << errno << ", server: " << sock_fd;
return acceptFd;
}
SetSocketOptions(acceptFd);
if (SetSocketOptions(acceptFd) < 0) {
MS_LOG(ERROR) << "Failed to set socket options for accepted socket: " << acceptFd;
}
return acceptFd;
}
} // namespace rpc


+ 4
- 4
mindspore/ccsrc/distributed/rpc/tcp/socket_operation.h View File

@@ -71,7 +71,7 @@ class SocketOperation {
static int Accept(int sock_fd);

// Call recv with flag MSG_PEEK which means do not delete data in buffer after reading.
virtual int ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) = 0;
virtual ssize_t ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) = 0;

// Try to receive messages up to totalRecvLen (for message header).
virtual int Receive(Connection *connection, char *recvBuf, uint32_t totalRecvLen, uint32_t *recvLen) = 0;
@@ -79,11 +79,11 @@ class SocketOperation {
// Receive message (for message body).
virtual int ReceiveMessage(Connection *connection, struct msghdr *recvMsg, uint32_t recvLen) = 0;

virtual int SendMessage(Connection *connection, struct msghdr *sendMsg, uint32_t *sendLen) = 0;
virtual ssize_t SendMessage(Connection *connection, struct msghdr *sendMsg, size_t *sendLen) = 0;

// Handle connect and connected events.
virtual void NewConnEventHandler(int fd, uint32_t events, void *context) = 0;
virtual void ConnEstablishedEventHandler(int fd, uint32_t events, void *context) = 0;
virtual void NewConnEventHandler(void *context) = 0;
virtual void ConnEstablishedEventHandler(void *context) = 0;
};
} // namespace rpc
} // namespace distributed


+ 28
- 16
mindspore/ccsrc/distributed/rpc/tcp/tcp_client.cc View File

@@ -31,22 +31,33 @@ bool TCPClient::Initialize() {
return rt;
}

void TCPClient::Finalize() { tcp_comm_->Finalize(); }
void TCPClient::Finalize() {
if (tcp_comm_ != nullptr) {
tcp_comm_->Finalize();
tcp_comm_.reset();
tcp_comm_ = nullptr;
}
}

bool TCPClient::Connect(const std::string &dst_url, size_t timeout_in_sec) {
bool rt = false;
tcp_comm_->Connect(dst_url);

int timeout = timeout_in_sec * 1000 * 1000;
size_t usleep_count = 100000;
size_t timeout_in_ms = timeout_in_sec * 1000;
size_t sleep_in_ms = 100;
useconds_t sleep_in_us = 100000;

while (timeout) {
while (true) {
if (tcp_comm_->IsConnected(dst_url)) {
rt = true;
break;
}
timeout = timeout - usleep_count;
usleep(usleep_count);
if (timeout_in_ms > sleep_in_ms) {
timeout_in_ms -= sleep_in_ms;
} else {
break;
}
(void)usleep(sleep_in_us);
}
return rt;
}
@@ -55,25 +66,26 @@ bool TCPClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
bool rt = false;
tcp_comm_->Disconnect(dst_url);

int timeout = timeout_in_sec * 1000 * 1000;
size_t usleep_count = 100000;
size_t timeout_in_ms = timeout_in_sec * 1000;
size_t sleep_in_ms = 100;
useconds_t sleep_in_us = 100000;

while (timeout) {
while (true) {
if (!tcp_comm_->IsConnected(dst_url)) {
rt = true;
break;
}
timeout = timeout - usleep_count;
usleep(usleep_count);
if (timeout_in_ms > sleep_in_ms) {
timeout_in_ms -= sleep_in_ms;
} else {
break;
}
usleep(sleep_in_us);
}
return rt;
}

int TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg) {
int rt = -1;
rt = tcp_comm_->Send(msg.release(), true);
return rt;
}
int TCPClient::SendSync(std::unique_ptr<MessageBase> &&msg) { return tcp_comm_->Send(msg.release(), true); }

void TCPClient::SendAsync(std::unique_ptr<MessageBase> &&msg) { (void)tcp_comm_->Send(msg.release(), false); }
} // namespace rpc


+ 15
- 34
mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.cc View File

@@ -43,7 +43,7 @@ void ConnectedEventHandler(int fd, uint32_t events, void *context) {
uint32_t error = events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP);
int soError = 0;
Connection *conn = reinterpret_cast<Connection *>(context);
conn->socket_operation->ConnEstablishedEventHandler(fd, events, context);
conn->socket_operation->ConnEstablishedEventHandler(context);
if (conn->state == ConnectionState::kDisconnecting) {
DoDisconnect(fd, conn, error, soError);
return;
@@ -128,6 +128,7 @@ void OnAccept(int server, uint32_t events, void *arg) {
delete conn;
return;
}
tcpmgr->conn_pool_->AddConnection(conn);
}

int DoSend(Connection *conn) {
@@ -163,15 +164,7 @@ int DoSend(Connection *conn) {
return total_send_bytes;
}

TCPComm::~TCPComm() {
try {
Finalize();
} catch (...) {
MS_LOG(ERROR) << "Failed to finalize tcp communicator.";
}
}

void TCPComm::SetMessageHandler(MessageHandler handler) { message_handler_ = handler; }
void TCPComm::SetMessageHandler(const MessageHandler &handler) { message_handler_ = handler; }

bool TCPComm::Initialize() {
conn_pool_ = std::make_shared<ConnectionPool>();
@@ -244,7 +237,7 @@ bool TCPComm::StartServerSocket() {
return StartServerSocket(url);
}

int TCPComm::GetServerFd() { return server_fd_; }
int TCPComm::GetServerFd() const { return server_fd_; }

void TCPComm::ReadCallBack(void *connection) {
const int max_recv_count = 3;
@@ -264,7 +257,7 @@ void TCPComm::EventCallBack(void *connection) {

if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
DoSend(conn);
(void)DoSend(conn);
conn->conn_mutex->unlock();
} else if (conn->state == ConnectionState::kDisconnecting) {
conn->conn_mutex->lock();
@@ -276,7 +269,7 @@ void TCPComm::WriteCallBack(void *connection) {
Connection *conn = reinterpret_cast<Connection *>(connection);
if (conn->state == ConnectionState::kConnected) {
conn->conn_mutex->lock();
DoSend(conn);
(void)DoSend(conn);
conn->conn_mutex->unlock();
}
}
@@ -288,25 +281,6 @@ int TCPComm::ReceiveMessage(Connection *conn) {
switch (conn->recv_message_type) {
case ParseType::kTcpMsg:
return conn->ReceiveMessage();

#ifdef HTTP_ENABLED
case ParseType::KHTTP_REQ:
if (httpReqCb) {
return httpReqCb(conn, message_handler_);
} else {
conn->state = ConnectionState::kDisconnecting;
return -1;
}

case ParseType::KHTTP_RSP:
if (httpRspCb) {
return httpRspCb(conn, message_handler_);
} else {
conn->state = ConnectionState::kDisconnecting;
return -1;
}
#endif

default:
return 0;
}
@@ -357,7 +331,7 @@ void TCPComm::DropMessage(MessageBase *msg) {
ptr = nullptr;
}

int TCPComm::Send(MessageBase *msg, bool sync) {
ssize_t TCPComm::Send(MessageBase *msg, bool sync) {
auto task = [msg, this] {
std::lock_guard<std::mutex> lock(*conn_mutex_);
// Search connection by the target address
@@ -476,7 +450,7 @@ void TCPComm::Disconnect(const std::string &dst_url) {
});
}

Connection *TCPComm::CreateDefaultConn(std::string to) {
Connection *TCPComm::CreateDefaultConn(const std::string &to) {
Connection *conn = new (std::nothrow) Connection();
if (conn == nullptr) {
MS_LOG(ERROR) << "Failed to create new connection and reconnect fail to: " << to.c_str();
@@ -513,6 +487,13 @@ void TCPComm::Finalize() {
}
server_fd_ = -1;
}

if (conn_pool_ != nullptr) {
MS_LOG(INFO) << "Delete connection pool.";
conn_pool_->Finalize();
conn_pool_.reset();
conn_pool_ = nullptr;
}
}
} // namespace rpc
} // namespace distributed


+ 5
- 5
mindspore/ccsrc/distributed/rpc/tcp/tcp_comm.h View File

@@ -45,7 +45,7 @@ class TCPComm {
TCPComm() : server_fd_(-1), recv_event_loop_(nullptr), send_event_loop_(nullptr) {}
TCPComm(const TCPComm &) = delete;
TCPComm &operator=(const TCPComm &) = delete;
~TCPComm();
~TCPComm() = default;

// Init the event loop for reading and writing.
bool Initialize();
@@ -66,17 +66,17 @@ class TCPComm {

// Send the message from the source to the destination.
// The flag sync means if the message is sent directly or added to the task queue.
int Send(MessageBase *msg, bool sync = false);
ssize_t Send(MessageBase *msg, bool sync = false);

// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
void SetMessageHandler(const MessageHandler &handler);

// Get the file descriptor of server socket.
int GetServerFd();
int GetServerFd() const;

private:
// Build the connection.
Connection *CreateDefaultConn(std::string to);
Connection *CreateDefaultConn(const std::string &to);

// Send a message.
static void SendExitMsg(const std::string &from, const std::string &to);


+ 4
- 3
mindspore/ccsrc/distributed/rpc/tcp/tcp_server.cc View File

@@ -25,16 +25,17 @@ bool TCPServer::Initialize() { return InitializeImpl(""); }

void TCPServer::Finalize() {
if (tcp_comm_ != nullptr) {
tcp_comm_->Finalize();
tcp_comm_.reset();
tcp_comm_ = nullptr;
}
}

void TCPServer::SetMessageHandler(MessageHandler handler) { tcp_comm_->SetMessageHandler(handler); }
void TCPServer::SetMessageHandler(const MessageHandler &handler) { tcp_comm_->SetMessageHandler(handler); }

std::string TCPServer::GetIP() { return ip_; }
std::string TCPServer::GetIP() const { return ip_; }

uint32_t TCPServer::GetPort() { return port_; }
uint32_t TCPServer::GetPort() const { return port_; }

bool TCPServer::InitializeImpl(const std::string &url) {
if (tcp_comm_ == nullptr) {


+ 3
- 3
mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h View File

@@ -41,11 +41,11 @@ class TCPServer {
void Finalize();

// Set the message processing handler.
void SetMessageHandler(MessageHandler handler);
void SetMessageHandler(const MessageHandler &handler);

// Return the IP and port binded by this server.
std::string GetIP();
uint32_t GetPort();
std::string GetIP() const;
uint32_t GetPort() const;

private:
bool InitializeImpl(const std::string &url);


+ 16
- 16
mindspore/ccsrc/distributed/rpc/tcp/tcp_socket_operation.cc View File

@@ -21,7 +21,7 @@ namespace distributed {
namespace rpc {
constexpr int EAGAIN_RETRY = 2;

int TCPSocketOperation::ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) {
ssize_t TCPSocketOperation::ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) {
return recv(connection->socket_fd, recvBuf, recvLen, MSG_PEEK);
}

@@ -31,9 +31,9 @@ int TCPSocketOperation::Receive(Connection *connection, char *recvBuf, uint32_t

*recvLen = 0;
while (*recvLen != totalRecvLen) {
int retval = recv(fd, curRecvBuf, totalRecvLen - *recvLen, static_cast<int>(0));
ssize_t retval = recv(fd, curRecvBuf, totalRecvLen - *recvLen, static_cast<int>(0));
if (retval > 0) {
*recvLen += IntToUint(retval);
*recvLen += static_cast<uint32_t>(retval);
if (*recvLen == totalRecvLen) {
return UintToInt(totalRecvLen);
}
@@ -57,7 +57,7 @@ int TCPSocketOperation::Receive(Connection *connection, char *recvBuf, uint32_t
}

int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *recvMsg, uint32_t recvLen) {
uint32_t totalRecvLen = recvLen;
ssize_t totalRecvLen = recvLen;

if (totalRecvLen == 0) {
return 0;
@@ -66,7 +66,7 @@ int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re
while (totalRecvLen) {
auto retval = recvmsg(connection->socket_fd, recvMsg, 0);
if (retval > 0) {
totalRecvLen -= IntToSize(retval);
totalRecvLen -= retval;
if (totalRecvLen == 0) {
recvMsg->msg_iovlen = 0;
break;
@@ -90,28 +90,28 @@ int TCPSocketOperation::ReceiveMessage(Connection *connection, struct msghdr *re
}
}
} else if (retval == 0) {
return UintToInt(-1);
return -1;
} else {
if (EAGAIN == errno) {
return recvLen - totalRecvLen;
return UintToInt(recvLen - totalRecvLen);
} else if (ECONNRESET == errno || ECONNABORTED == errno || ENOTCONN == errno || EPIPE == errno) {
connection->error_code = UintToInt(errno);
connection->error_code = errno;
return -1;
} else {
return UintToInt(recvLen - totalRecvLen);
}
}
}
return recvLen;
return UintToInt(recvLen);
}

int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, uint32_t *sendLen) {
ssize_t TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendMsg, size_t *sendLen) {
int eagainCount = EAGAIN_RETRY;
uint32_t totalLen = *sendLen;
int32_t unsendLen = *sendLen;
size_t totalLen = *sendLen;
ssize_t unsendLen = static_cast<ssize_t>(*sendLen);

while (*sendLen != 0) {
int retval = sendmsg(connection->socket_fd, sendMsg, MSG_NOSIGNAL);
auto retval = sendmsg(connection->socket_fd, sendMsg, MSG_NOSIGNAL);
if (retval < 0) {
--eagainCount;
if (errno != EAGAIN) {
@@ -148,7 +148,7 @@ int TCPSocketOperation::SendMessage(Connection *connection, struct msghdr *sendM
}
}
if (unsendLen > 0) {
unsendLen = UintToInt(totalLen - *sendLen);
unsendLen = totalLen - *sendLen;
}
return unsendLen;
}
@@ -159,13 +159,13 @@ void TCPSocketOperation::Close(Connection *connection) {
}

// accept new conn event handle
void TCPSocketOperation::NewConnEventHandler(int fd, uint32_t events, void *context) {
void TCPSocketOperation::NewConnEventHandler(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
conn->state = ConnectionState::kConnected;
return;
}

void TCPSocketOperation::ConnEstablishedEventHandler(int fd, uint32_t events, void *context) {
void TCPSocketOperation::ConnEstablishedEventHandler(void *context) {
Connection *conn = reinterpret_cast<Connection *>(context);
conn->state = ConnectionState::kConnected;
return;


+ 4
- 4
mindspore/ccsrc/distributed/rpc/tcp/tcp_socket_operation.h View File

@@ -25,16 +25,16 @@ namespace distributed {
namespace rpc {
class TCPSocketOperation : public SocketOperation {
public:
int ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) override;
ssize_t ReceivePeek(Connection *connection, char *recvBuf, uint32_t recvLen) override;
int Receive(Connection *connection, char *recvBuf, uint32_t totRecvLen, uint32_t *recvLen) override;
int ReceiveMessage(Connection *connection, struct msghdr *recvMsg, uint32_t recvLen) override;

int SendMessage(Connection *connection, struct msghdr *sendMsg, uint32_t *sendLen) override;
ssize_t SendMessage(Connection *connection, struct msghdr *sendMsg, size_t *sendLen) override;

void Close(Connection *connection) override;

void NewConnEventHandler(int fd, uint32_t events, void *context) override;
void ConnEstablishedEventHandler(int fd, uint32_t events, void *context) override;
void NewConnEventHandler(void *context) override;
void ConnEstablishedEventHandler(void *context) override;
};
} // namespace rpc
} // namespace distributed


Loading…
Cancel
Save