Browse Source

updated cluster config and proto

tags/v1.1.0
anancds 5 years ago
parent
commit
fc380b8071
11 changed files with 267 additions and 78 deletions
  1. +18
    -2
      mindspore/ccsrc/ps/core/cluster_config.cc
  2. +6
    -2
      mindspore/ccsrc/ps/core/cluster_config.h
  3. +34
    -0
      mindspore/ccsrc/ps/core/comm_util.cc
  4. +14
    -0
      mindspore/ccsrc/ps/core/comm_util.h
  5. +44
    -16
      mindspore/ccsrc/ps/core/protos/comm.proto
  6. +7
    -7
      mindspore/ccsrc/ps/core/protos/ps.proto
  7. +55
    -30
      mindspore/ccsrc/ps/core/tcp_client.cc
  8. +15
    -3
      mindspore/ccsrc/ps/core/tcp_client.h
  9. +55
    -13
      mindspore/ccsrc/ps/core/tcp_server.cc
  10. +18
    -4
      mindspore/ccsrc/ps/core/tcp_server.h
  11. +1
    -1
      tests/ut/cpp/ps/core/tcp_pb_server_test.cc

+ 18
- 2
mindspore/ccsrc/ps/core/cluster_config.cc View File

@@ -21,12 +21,16 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {

uint32_t ClusterConfig::worker_num_ = 0; uint32_t ClusterConfig::worker_num_ = 0;
uint32_t ClusterConfig::server_num_ = 0; uint32_t ClusterConfig::server_num_ = 0;
uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval;
std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr; std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr;
uint16_t ClusterConfig::scheduler_port_ = 0; uint16_t ClusterConfig::scheduler_port_ = 0;
// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds.
uint32_t ClusterConfig::heartbeat_interval_ = 3;
// The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds.
uint32_t ClusterConfig::heartbeat_timeout_ = 30;
// Timeout period for cluster preparation is 300 seconds.
uint32_t ClusterConfig::cluster_available_timeout_ = 300;


void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
@@ -53,6 +57,18 @@ std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }


uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }


uint32_t ClusterConfig::heartbeat_timeout() { return heartbeat_timeout_; }

void ClusterConfig::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) {
heartbeat_interval_ = heartbeat_timeout;
}

uint32_t ClusterConfig::cluster_available_timeout() { return cluster_available_timeout_; }

void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) {
cluster_available_timeout_ = cluster_available_timeout;
}

} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 6
- 2
mindspore/ccsrc/ps/core/cluster_config.h View File

@@ -28,8 +28,6 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
constexpr uint32_t kHeartbeatInterval = 3;

class ClusterConfig { class ClusterConfig {
public: public:
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
@@ -40,6 +38,10 @@ class ClusterConfig {
static void set_heartbeat_interval(const uint32_t &heartbeat_interval); static void set_heartbeat_interval(const uint32_t &heartbeat_interval);
static std::string scheduler_host(); static std::string scheduler_host();
static uint16_t scheduler_port(); static uint16_t scheduler_port();
static uint32_t heartbeat_timeout();
static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout);
static uint32_t cluster_available_timeout();
static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout);


private: private:
static uint32_t worker_num_; static uint32_t worker_num_;
@@ -47,6 +49,8 @@ class ClusterConfig {
static uint32_t heartbeat_interval_; static uint32_t heartbeat_interval_;
static std::unique_ptr<std::string> scheduler_host_; static std::unique_ptr<std::string> scheduler_host_;
static uint16_t scheduler_port_; static uint16_t scheduler_port_;
static uint32_t heartbeat_timeout_;
static uint32_t cluster_available_timeout_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps


+ 34
- 0
mindspore/ccsrc/ps/core/comm_util.cc View File

@@ -21,11 +21,17 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <functional> #include <functional>
#include <algorithm>
#include <regex> #include <regex>


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
std::random_device CommUtil::rd;
std::mt19937_64 CommUtil::gen(rd());
std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15};
std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11};

bool CommUtil::CheckIpWithRegex(const std::string &ip) { bool CommUtil::CheckIpWithRegex(const std::string &ip) {
std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)");
std::smatch res; std::smatch res;
@@ -75,6 +81,34 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i
MS_EXCEPTION_IF_NULL(if_address); MS_EXCEPTION_IF_NULL(if_address);
freeifaddrs(if_address); freeifaddrs(if_address);
} }

std::string CommUtil::GenerateUUID() {
std::stringstream ss;
int i;
ss << std::hex;
for (i = 0; i < kGroup1RandomLength; i++) {
ss << dis(gen);
}
ss << "-";
for (i = 0; i < kGroup2RandomLength; i++) {
ss << dis(gen);
}
ss << "-4";
for (i = 0; i < kGroup2RandomLength - 1; i++) {
ss << dis(gen);
}
ss << "-";
ss << dis2(gen);
for (i = 0; i < kGroup3RandomLength - 1; i++) {
ss << dis(gen);
}
ss << "-";
for (i = 0; i < kGroup4RandomLength; i++) {
ss << dis(gen);
}
return ss.str();
}

} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 14
- 0
mindspore/ccsrc/ps/core/comm_util.h View File

@@ -43,17 +43,31 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <utility> #include <utility>
#include <random>
#include <sstream>


#include "utils/log_adapter.h" #include "utils/log_adapter.h"


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
constexpr int kGroup1RandomLength = 8;
constexpr int kGroup2RandomLength = 4;
constexpr int kGroup3RandomLength = 4;
constexpr int kGroup4RandomLength = 4;
constexpr int kGroup5RandomLength = 12;

class CommUtil { class CommUtil {
public: public:
static bool CheckIpWithRegex(const std::string &ip); static bool CheckIpWithRegex(const std::string &ip);
static bool CheckIp(const std::string &ip); static bool CheckIp(const std::string &ip);
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
static std::string GenerateUUID();

static std::random_device rd;
static std::mt19937_64 gen;
static std::uniform_int_distribution<> dis;
static std::uniform_int_distribution<> dis2;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps


+ 44
- 16
mindspore/ccsrc/ps/core/protos/comm.proto View File

@@ -19,36 +19,64 @@ import "google/protobuf/any.proto";
package mindspore.ps.core; package mindspore.ps.core;
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;


enum ClusterCommand {
enum NodeCommand {
TERMINATE = 0; TERMINATE = 0;
REGISTER = 1; REGISTER = 1;
ACK = 2;
HEARTBEAT = 3;
FETCH_WORKERS = 4;
FETCH_SERVERS = 5;
HEARTBEAT = 2;
SEND_DATA = 3;
FETCH_SERVER = 4;
} }


enum Role {
enum NodeRole {
SERVER = 0; SERVER = 0;
WORKER = 1; WORKER = 1;
SCHEDULER = 2; SCHEDULER = 2;
} }


message MessageMeta { message MessageMeta {
// hostname or ip
string hostname = 1;
// the command of this message,for example: register,heartbeat,data
NodeCommand cmd = 1;
// the request id of this message
uint64 request_id = 2;
}

message RegisterMessage {
// ip
string ip = 1;
// the port of this node // the port of this node
int32 port = 2; int32 port = 2;
// the command of this message,for example: register、heartbeat、data
int32 cmd = 3;
// the timestamp of this message
int32 timestamp = 4;
// data type of message
repeated int32 data_type = 5 [packed = true];
// message.data_size
int32 data_size = 6;
// the current Node unique id:0,1,2...
string node_id = 3;
// the role of the node: worker,server,scheduler
NodeRole role = 4;
}

message RegisterRespMessage {
string node_id = 1;
int32 rank_id = 2;
}

message HeartbeatMessage {
// the current Node unique id:0,1,2...
string node_id = 1;
} }


message HeartbeatRespMessage {
// Is the entire system ready to use.
bool is_cluster_ready = 1;
bool is_cluster_finish = 2;
}

message FetchServersRespMessage {
repeated ServersMeta servers_meta = 1;
}

message ServersMeta {
int32 rank_id = 1;
string ip = 2;
int32 port = 3;

}


message CommMessage { message CommMessage {
MessageMeta pb_meta = 1; MessageMeta pb_meta = 1;


+ 7
- 7
mindspore/ccsrc/ps/core/protos/ps.proto View File

@@ -13,17 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */

syntax = "proto3"; syntax = "proto3";
package mindspore.ps.core; package mindspore.ps.core;
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;


message KVMessage {
repeated int32 keys = 1;
repeated float values = 2;
enum PSCommand {
PUSH = 0;
PULL = 1;
} }


message HeartBeatMessage {
// *.*.*.*:port
repeated string host_and_port = 1;
message KVMessage {
PSCommand command = 1;
repeated int32 keys = 2;
repeated float values = 3;
} }

+ 55
- 30
mindspore/ccsrc/ps/core/tcp_client.cc View File

@@ -18,8 +18,8 @@


#include <arpa/inet.h> #include <arpa/inet.h>
#include <event2/buffer.h> #include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/buffer_compat.h> #include <event2/buffer_compat.h>
#include <event2/bufferevent.h>
#include <event2/event.h> #include <event2/event.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <netinet/tcp.h> #include <netinet/tcp.h>
@@ -27,20 +27,23 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <utility>
#include <string> #include <string>
#include <utility>


#include "ps/core/comm_util.h" #include "ps/core/comm_util.h"


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {

event_base *TcpClient::event_base_ = nullptr;

TcpClient::TcpClient(const std::string &address, std::uint16_t port) TcpClient::TcpClient(const std::string &address, std::uint16_t port)
: event_base_(nullptr),
event_timeout_(nullptr),
: event_timeout_(nullptr),
buffer_event_(nullptr), buffer_event_(nullptr),
server_address_(std::move(address)), server_address_(std::move(address)),
server_port_(port) {
server_port_(port),
is_stop_(true) {
message_handler_.SetCallback([this](const CommMessage &message) { message_handler_.SetCallback([this](const CommMessage &message) {
if (message_callback_) { if (message_callback_) {
message_callback_(*this, message); message_callback_(*this, message);
@@ -61,6 +64,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco
} }


void TcpClient::Init() { void TcpClient::Init() {
std::lock_guard<std::mutex> lock(connection_mutex_);
if (buffer_event_) { if (buffer_event_) {
return; return;
} }
@@ -68,7 +72,13 @@ void TcpClient::Init() {
MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!";
} }


event_base_ = event_base_new();
int result = evthread_use_pthreads();
if (result != 0) {
MS_LOG(EXCEPTION) << "Use event pthread failed!";
}
if (event_base_ == nullptr) {
event_base_ = event_base_new();
}
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);


sockaddr_in sin{}; sockaddr_in sin{};
@@ -94,6 +104,7 @@ void TcpClient::Init() {
} }


void TcpClient::StartWithDelay(int seconds) { void TcpClient::StartWithDelay(int seconds) {
std::lock_guard<std::mutex> lock(connection_mutex_);
if (buffer_event_) { if (buffer_event_) {
return; return;
} }
@@ -111,16 +122,28 @@ void TcpClient::StartWithDelay(int seconds) {
} }


void TcpClient::Stop() { void TcpClient::Stop() {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Stop tcp client event buffer!";
if (!is_stop_.load()) {
if (buffer_event_) {
bufferevent_free(buffer_event_);
buffer_event_ = nullptr;
}


if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
if (event_timeout_) {
event_free(event_timeout_);
event_timeout_ = nullptr;
}
is_stop_ = true;
} }
}


void TcpClient::StopEventBase() {
MS_LOG(INFO) << "Stop tcp client event base!";
int ret = event_base_loopbreak(event_base_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Event base loop break failed!";
}
if (event_base_) { if (event_base_) {
event_base_free(event_base_); event_base_free(event_base_);
event_base_ = nullptr; event_base_ = nullptr;
@@ -167,21 +190,12 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) {
message_handler_.ReceiveMessage(buf, num); message_handler_.ReceiveMessage(buf, num);
} }


void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) {
void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg); auto tcp_client = reinterpret_cast<TcpClient *>(arg);
MessageMeta meta;
meta.set_cmd(ClusterCommand::HEARTBEAT);
CommMessage message;
message.set_allocated_pb_meta(&meta);
tcp_client->SendMessage(message);

struct event *ev;
struct timeval timeout {};
timeout.tv_sec = ClusterConfig::heartbeat_interval();
timeout.tv_usec = 0;
ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg);
evtimer_add(ev, &timeout);
if (tcp_client->on_timer_callback_) {
tcp_client->on_timer_callback_(*tcp_client);
}
} }


void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
@@ -211,6 +225,7 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void


void TcpClient::Start() { void TcpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
is_stop_ = false;
int ret = event_base_dispatch(event_base_); int ret = event_base_dispatch(event_base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
@@ -220,6 +235,7 @@ void TcpClient::Start() {
} }


void TcpClient::StartWithNoBlock() { void TcpClient::StartWithNoBlock() {
std::lock_guard<std::mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp client with no block!"; MS_LOG(INFO) << "Start tcp client with no block!";
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK);
@@ -244,15 +260,24 @@ void TcpClient::SendMessage(const CommMessage &message) const {
} }
} }


void TcpClient::SendMessageWithTimer() {
MS_EXCEPTION_IF_NULL(buffer_event_);
void TcpClient::StartTimer(const uint32_t &time) {
MS_EXCEPTION_IF_NULL(event_base_);
struct event *ev = nullptr; struct event *ev = nullptr;
if (time == 0) {
MS_LOG(EXCEPTION) << "The time should not be 0!";
}
struct timeval timeout {}; struct timeval timeout {};
timeout.tv_sec = 0;
timeout.tv_sec = time;
timeout.tv_usec = 0; timeout.tv_usec = 0;
ev = evtimer_new(event_base_, SendHeartBeatCallback, this);
ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this);
MS_EXCEPTION_IF_NULL(ev);
evtimer_add(ev, &timeout); evtimer_add(ev, &timeout);
} }

void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }

const event_base &TcpClient::eventbase() { return *event_base_; }

} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 15
- 3
mindspore/ccsrc/ps/core/tcp_client.h View File

@@ -21,10 +21,15 @@


#include <event2/event.h> #include <event2/event.h>
#include <event2/bufferevent.h> #include <event2/bufferevent.h>
#include <event2/thread.h>

#include <functional> #include <functional>
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <thread>
#include <mutex>
#include <atomic>


#include "proto/comm.pb.h" #include "proto/comm.pb.h"
#include "ps/core/cluster_config.h" #include "ps/core/cluster_config.h"
@@ -40,6 +45,7 @@ class TcpClient {
using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; using OnRead = std::function<void(const TcpClient &, const void *, size_t)>;
using OnTimeout = std::function<void(const TcpClient &)>; using OnTimeout = std::function<void(const TcpClient &)>;
using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>;
using OnTimer = std::function<void(const TcpClient &)>;


explicit TcpClient(const std::string &address, std::uint16_t port); explicit TcpClient(const std::string &address, std::uint16_t port);
virtual ~TcpClient(); virtual ~TcpClient();
@@ -50,11 +56,14 @@ class TcpClient {
void Init(); void Init();
void StartWithDelay(int seconds); void StartWithDelay(int seconds);
void Stop(); void Stop();
static void StopEventBase();
void Start(); void Start();
void StartWithNoBlock(); void StartWithNoBlock();
void SetMessageCallback(const OnMessage &cb); void SetMessageCallback(const OnMessage &cb);
void SendMessage(const CommMessage &message) const; void SendMessage(const CommMessage &message) const;
void SendMessageWithTimer();
void StartTimer(const uint32_t &time);
void set_timer_callback(const OnTimer &timer);
const event_base &eventbase();


protected: protected:
static void SetTcpNoDelay(const evutil_socket_t &fd); static void SetTcpNoDelay(const evutil_socket_t &fd);
@@ -62,7 +71,7 @@ class TcpClient {
static void ReadCallback(struct bufferevent *bev, void *ctx); static void ReadCallback(struct bufferevent *bev, void *ctx);
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
virtual void OnReadHandler(const void *buf, size_t num); virtual void OnReadHandler(const void *buf, size_t num);
static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);


private: private:
OnMessage message_callback_; OnMessage message_callback_;
@@ -72,13 +81,16 @@ class TcpClient {
OnDisconnected disconnected_callback_; OnDisconnected disconnected_callback_;
OnRead read_callback_; OnRead read_callback_;
OnTimeout timeout_callback_; OnTimeout timeout_callback_;
OnTimer on_timer_callback_;


event_base *event_base_;
static event_base *event_base_;
std::mutex connection_mutex_;
event *event_timeout_; event *event_timeout_;
bufferevent *buffer_event_; bufferevent *buffer_event_;


std::string server_address_; std::string server_address_;
std::uint16_t server_port_; std::uint16_t server_port_;
std::atomic<bool> is_stop_;
}; };


} // namespace core } // namespace core


+ 55
- 13
mindspore/ccsrc/ps/core/tcp_server.cc View File

@@ -18,10 +18,10 @@


#include <arpa/inet.h> #include <arpa/inet.h>
#include <event2/buffer.h> #include <event2/buffer.h>
#include <event2/buffer_compat.h>
#include <event2/bufferevent.h> #include <event2/bufferevent.h>
#include <event2/event.h> #include <event2/event.h>
#include <event2/listener.h> #include <event2/listener.h>
#include <event2/buffer_compat.h>
#include <event2/util.h> #include <event2/util.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <csignal> #include <csignal>
@@ -73,7 +73,8 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port)
signal_event_(nullptr), signal_event_(nullptr),
listener_(nullptr), listener_(nullptr),
server_address_(std::move(address)), server_address_(std::move(address)),
server_port_(port) {}
server_port_(port),
is_stop_(true) {}


TcpServer::~TcpServer() { Stop(); } TcpServer::~TcpServer() { Stop(); }


@@ -84,7 +85,14 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon
this->client_accept_ = client_accept; this->client_accept_ = client_accept;
} }


void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; }

void TcpServer::Init() { void TcpServer::Init() {
int result = evthread_use_pthreads();
if (result != 0) {
MS_LOG(EXCEPTION) << "Use event pthread failed!";
}

base_ = event_base_new(); base_ = event_base_new();
MS_EXCEPTION_IF_NULL(base_); MS_EXCEPTION_IF_NULL(base_);
if (!CommUtil::CheckIp(server_address_)) { if (!CommUtil::CheckIp(server_address_)) {
@@ -128,6 +136,7 @@ void TcpServer::Start() {
std::unique_lock<std::recursive_mutex> lock(connection_mutex_); std::unique_lock<std::recursive_mutex> lock(connection_mutex_);
MS_LOG(INFO) << "Start tcp server!"; MS_LOG(INFO) << "Start tcp server!";
MS_EXCEPTION_IF_NULL(base_); MS_EXCEPTION_IF_NULL(base_);
is_stop_ = false;
int ret = event_base_dispatch(base_); int ret = event_base_dispatch(base_);
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
@@ -147,21 +156,42 @@ void TcpServer::StartWithNoBlock() {
MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!";
} }


void TcpServer::StartTimerOnlyOnce(const uint32_t &time) {
MS_EXCEPTION_IF_NULL(base_);
if (time == 0) {
MS_LOG(EXCEPTION) << "The time should not be 0!";
}
struct event *ev = nullptr;
struct timeval timeout {};
timeout.tv_sec = time;
timeout.tv_usec = 0;
ev = evtimer_new(base_, TimerCallback, this);
MS_EXCEPTION_IF_NULL(ev);
evtimer_add(ev, &timeout);
}

void TcpServer::Stop() { void TcpServer::Stop() {
MS_LOG(INFO) << "Stop tcp server!"; MS_LOG(INFO) << "Stop tcp server!";
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}
if (!is_stop_.load()) {
int ret = event_base_loopbreak(base_);
if (ret != 0) {
MS_LOG(EXCEPTION) << "event base loop break failed!";
}
if (signal_event_ != nullptr) {
event_free(signal_event_);
signal_event_ = nullptr;
}


if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
}
if (listener_ != nullptr) {
evconnlistener_free(listener_);
listener_ = nullptr;
}


if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
if (base_ != nullptr) {
event_base_free(base_);
base_ = nullptr;
}
is_stop_ = true;
} }
} }


@@ -287,6 +317,14 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void
} }
} }


void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_server = reinterpret_cast<TcpServer *>(arg);
if (tcp_server->on_timer_callback_) {
tcp_server->on_timer_callback_(*tcp_server);
}
}

void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); }


void TcpServer::SendMessage(const CommMessage &message) { void TcpServer::SendMessage(const CommMessage &message) {
@@ -299,6 +337,10 @@ void TcpServer::SendMessage(const CommMessage &message) {


uint16_t TcpServer::BoundPort() const { return server_port_; } uint16_t TcpServer::BoundPort() const { return server_port_; }


int TcpServer::ConnectionNum() const { return connections_.size(); }

const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; }

void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; }
} // namespace core } // namespace core
} // namespace ps } // namespace ps


+ 18
- 4
mindspore/ccsrc/ps/core/tcp_server.h View File

@@ -21,17 +21,23 @@
#include <event2/bufferevent.h> #include <event2/bufferevent.h>
#include <event2/event.h> #include <event2/event.h>
#include <event2/listener.h> #include <event2/listener.h>
#include <event2/thread.h>

#include <exception> #include <exception>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <memory>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <memory>
#include <vector> #include <vector>
#include <thread>
#include <atomic>


#include "utils/log_adapter.h"
#include "proto/comm.pb.h"
#include "ps/core/tcp_message_handler.h" #include "ps/core/tcp_message_handler.h"
#include "ps/core/cluster_config.h"
#include "utils/log_adapter.h"


namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@@ -40,7 +46,7 @@ class TcpServer;
class TcpConnection { class TcpConnection {
public: public:
explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server)
: buffer_event_(bev), fd_(0), server_(server) {}
: buffer_event_(bev), fd_(fd), server_(server) {}
virtual ~TcpConnection() = default; virtual ~TcpConnection() = default;


virtual void InitConnection(); virtual void InitConnection();
@@ -65,24 +71,29 @@ class TcpServer {
using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>;
using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>;
using OnTimer = std::function<void(const TcpServer &)>;


explicit TcpServer(const std::string &address, std::uint16_t port); explicit TcpServer(const std::string &address, std::uint16_t port);
virtual ~TcpServer(); virtual ~TcpServer();


void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn,
const OnAccepted &client_accept); const OnAccepted &client_accept);
void set_timer_callback(const OnTimer &timer);
void Init(); void Init();
void Start(); void Start();
void StartWithNoBlock(); void StartWithNoBlock();
void StartTimerOnlyOnce(const uint32_t &time);
void Stop(); void Stop();
void SendToAllClients(const char *data, size_t len); void SendToAllClients(const char *data, size_t len);
void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection);
void RemoveConnection(const evutil_socket_t &fd); void RemoveConnection(const evutil_socket_t &fd);
OnServerReceiveMessage GetServerReceive() const; OnServerReceiveMessage GetServerReceive() const;
void SetMessageCallback(const OnServerReceiveMessage &cb); void SetMessageCallback(const OnServerReceiveMessage &cb);
static void SendMessage(const TcpConnection &conn, const CommMessage &message);
void SendMessage(const TcpConnection &conn, const CommMessage &message);
void SendMessage(const CommMessage &message); void SendMessage(const CommMessage &message);
uint16_t BoundPort() const; uint16_t BoundPort() const;
int ConnectionNum() const;
const std::map<evutil_socket_t, const TcpConnection *> &Connections() const;


protected: protected:
static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr,
@@ -90,6 +101,7 @@ class TcpServer {
static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server);
static void ReadCallback(struct bufferevent *, void *connection); static void ReadCallback(struct bufferevent *, void *connection);
static void EventCallback(struct bufferevent *, std::int16_t events, void *server); static void EventCallback(struct bufferevent *, std::int16_t events, void *server);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd);


struct event_base *base_; struct event_base *base_;
@@ -97,6 +109,7 @@ class TcpServer {
struct evconnlistener *listener_; struct evconnlistener *listener_;
std::string server_address_; std::string server_address_;
std::uint16_t server_port_; std::uint16_t server_port_;
std::atomic<bool> is_stop_;


std::map<evutil_socket_t, const TcpConnection *> connections_; std::map<evutil_socket_t, const TcpConnection *> connections_;
OnConnected client_connection_; OnConnected client_connection_;
@@ -104,6 +117,7 @@ class TcpServer {
OnAccepted client_accept_; OnAccepted client_accept_;
std::recursive_mutex connection_mutex_; std::recursive_mutex connection_mutex_;
OnServerReceiveMessage message_callback_; OnServerReceiveMessage message_callback_;
OnTimer on_timer_callback_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps


+ 1
- 1
tests/ut/cpp/ps/core/tcp_pb_server_test.cc View File

@@ -37,7 +37,7 @@ class TestTcpServer : public UT::Common {
KVMessage kv_message; KVMessage kv_message;
kv_message.ParseFromString(message.data()); kv_message.ParseFromString(message.data());
EXPECT_EQ(2, kv_message.keys_size()); EXPECT_EQ(2, kv_message.keys_size());
server.SendMessage(conn, message);
const_cast<TcpServer&>(server).SendMessage(conn, message);
}); });
server_->Init(); server_->Init();
server_->Start(); server_->Start();


Loading…
Cancel
Save