| @@ -21,12 +21,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| uint32_t ClusterConfig::worker_num_ = 0; | uint32_t ClusterConfig::worker_num_ = 0; | ||||
| uint32_t ClusterConfig::server_num_ = 0; | uint32_t ClusterConfig::server_num_ = 0; | ||||
| uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval; | |||||
| std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr; | std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr; | ||||
| uint16_t ClusterConfig::scheduler_port_ = 0; | uint16_t ClusterConfig::scheduler_port_ = 0; | ||||
| // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. | |||||
| uint32_t ClusterConfig::heartbeat_interval_ = 3; | |||||
| // The timeout for worker node and server node sending heartbeat packets to scheduler node is 30 seconds. | |||||
| uint32_t ClusterConfig::heartbeat_timeout_ = 30; | |||||
| // Timeout period for cluster preparation is 300 seconds. | |||||
| uint32_t ClusterConfig::cluster_available_timeout_ = 300; | |||||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | ||||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | ||||
| @@ -53,6 +57,18 @@ std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } | |||||
| uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | ||||
| uint32_t ClusterConfig::heartbeat_timeout() { return heartbeat_timeout_; } | |||||
| void ClusterConfig::set_heartbeat_timeout(const uint32_t &heartbeat_timeout) { | |||||
| heartbeat_interval_ = heartbeat_timeout; | |||||
| } | |||||
| uint32_t ClusterConfig::cluster_available_timeout() { return cluster_available_timeout_; } | |||||
| void ClusterConfig::set_cluster_available_timeout(const uint32_t &cluster_available_timeout) { | |||||
| cluster_available_timeout_ = cluster_available_timeout; | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -28,8 +28,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| constexpr uint32_t kHeartbeatInterval = 3; | |||||
| class ClusterConfig { | class ClusterConfig { | ||||
| public: | public: | ||||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | ||||
| @@ -40,6 +38,10 @@ class ClusterConfig { | |||||
| static void set_heartbeat_interval(const uint32_t &heartbeat_interval); | static void set_heartbeat_interval(const uint32_t &heartbeat_interval); | ||||
| static std::string scheduler_host(); | static std::string scheduler_host(); | ||||
| static uint16_t scheduler_port(); | static uint16_t scheduler_port(); | ||||
| static uint32_t heartbeat_timeout(); | |||||
| static void set_heartbeat_timeout(const uint32_t &heartbeat_timeout); | |||||
| static uint32_t cluster_available_timeout(); | |||||
| static void set_cluster_available_timeout(const uint32_t &cluster_available_timeout); | |||||
| private: | private: | ||||
| static uint32_t worker_num_; | static uint32_t worker_num_; | ||||
| @@ -47,6 +49,8 @@ class ClusterConfig { | |||||
| static uint32_t heartbeat_interval_; | static uint32_t heartbeat_interval_; | ||||
| static std::unique_ptr<std::string> scheduler_host_; | static std::unique_ptr<std::string> scheduler_host_; | ||||
| static uint16_t scheduler_port_; | static uint16_t scheduler_port_; | ||||
| static uint32_t heartbeat_timeout_; | |||||
| static uint32_t cluster_available_timeout_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -21,11 +21,17 @@ | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <functional> | #include <functional> | ||||
| #include <algorithm> | |||||
| #include <regex> | #include <regex> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| std::random_device CommUtil::rd; | |||||
| std::mt19937_64 CommUtil::gen(rd()); | |||||
| std::uniform_int_distribution<> CommUtil::dis = std::uniform_int_distribution<>{0, 15}; | |||||
| std::uniform_int_distribution<> CommUtil::dis2 = std::uniform_int_distribution<>{8, 11}; | |||||
| bool CommUtil::CheckIpWithRegex(const std::string &ip) { | bool CommUtil::CheckIpWithRegex(const std::string &ip) { | ||||
| std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | ||||
| std::smatch res; | std::smatch res; | ||||
| @@ -75,6 +81,34 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i | |||||
| MS_EXCEPTION_IF_NULL(if_address); | MS_EXCEPTION_IF_NULL(if_address); | ||||
| freeifaddrs(if_address); | freeifaddrs(if_address); | ||||
| } | } | ||||
| std::string CommUtil::GenerateUUID() { | |||||
| std::stringstream ss; | |||||
| int i; | |||||
| ss << std::hex; | |||||
| for (i = 0; i < kGroup1RandomLength; i++) { | |||||
| ss << dis(gen); | |||||
| } | |||||
| ss << "-"; | |||||
| for (i = 0; i < kGroup2RandomLength; i++) { | |||||
| ss << dis(gen); | |||||
| } | |||||
| ss << "-4"; | |||||
| for (i = 0; i < kGroup2RandomLength - 1; i++) { | |||||
| ss << dis(gen); | |||||
| } | |||||
| ss << "-"; | |||||
| ss << dis2(gen); | |||||
| for (i = 0; i < kGroup3RandomLength - 1; i++) { | |||||
| ss << dis(gen); | |||||
| } | |||||
| ss << "-"; | |||||
| for (i = 0; i < kGroup4RandomLength; i++) { | |||||
| ss << dis(gen); | |||||
| } | |||||
| return ss.str(); | |||||
| } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,17 +43,31 @@ | |||||
| #include <functional> | #include <functional> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <random> | |||||
| #include <sstream> | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| constexpr int kGroup1RandomLength = 8; | |||||
| constexpr int kGroup2RandomLength = 4; | |||||
| constexpr int kGroup3RandomLength = 4; | |||||
| constexpr int kGroup4RandomLength = 4; | |||||
| constexpr int kGroup5RandomLength = 12; | |||||
| class CommUtil { | class CommUtil { | ||||
| public: | public: | ||||
| static bool CheckIpWithRegex(const std::string &ip); | static bool CheckIpWithRegex(const std::string &ip); | ||||
| static bool CheckIp(const std::string &ip); | static bool CheckIp(const std::string &ip); | ||||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | ||||
| static std::string GenerateUUID(); | |||||
| static std::random_device rd; | |||||
| static std::mt19937_64 gen; | |||||
| static std::uniform_int_distribution<> dis; | |||||
| static std::uniform_int_distribution<> dis2; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -19,36 +19,64 @@ import "google/protobuf/any.proto"; | |||||
| package mindspore.ps.core; | package mindspore.ps.core; | ||||
| option optimize_for = LITE_RUNTIME; | option optimize_for = LITE_RUNTIME; | ||||
| enum ClusterCommand { | |||||
| enum NodeCommand { | |||||
| TERMINATE = 0; | TERMINATE = 0; | ||||
| REGISTER = 1; | REGISTER = 1; | ||||
| ACK = 2; | |||||
| HEARTBEAT = 3; | |||||
| FETCH_WORKERS = 4; | |||||
| FETCH_SERVERS = 5; | |||||
| HEARTBEAT = 2; | |||||
| SEND_DATA = 3; | |||||
| FETCH_SERVER = 4; | |||||
| } | } | ||||
| enum Role { | |||||
| enum NodeRole { | |||||
| SERVER = 0; | SERVER = 0; | ||||
| WORKER = 1; | WORKER = 1; | ||||
| SCHEDULER = 2; | SCHEDULER = 2; | ||||
| } | } | ||||
| message MessageMeta { | message MessageMeta { | ||||
| // hostname or ip | |||||
| string hostname = 1; | |||||
| // the command of this message,for example: register,heartbeat,data | |||||
| NodeCommand cmd = 1; | |||||
| // the request id of this message | |||||
| uint64 request_id = 2; | |||||
| } | |||||
| message RegisterMessage { | |||||
| // ip | |||||
| string ip = 1; | |||||
| // the port of this node | // the port of this node | ||||
| int32 port = 2; | int32 port = 2; | ||||
| // the command of this message,for example: register、heartbeat、data | |||||
| int32 cmd = 3; | |||||
| // the timestamp of this message | |||||
| int32 timestamp = 4; | |||||
| // data type of message | |||||
| repeated int32 data_type = 5 [packed = true]; | |||||
| // message.data_size | |||||
| int32 data_size = 6; | |||||
| // the current Node unique id:0,1,2... | |||||
| string node_id = 3; | |||||
| // the role of the node: worker,server,scheduler | |||||
| NodeRole role = 4; | |||||
| } | |||||
| message RegisterRespMessage { | |||||
| string node_id = 1; | |||||
| int32 rank_id = 2; | |||||
| } | |||||
| message HeartbeatMessage { | |||||
| // the current Node unique id:0,1,2... | |||||
| string node_id = 1; | |||||
| } | } | ||||
| message HeartbeatRespMessage { | |||||
| // Is the entire system ready to use. | |||||
| bool is_cluster_ready = 1; | |||||
| bool is_cluster_finish = 2; | |||||
| } | |||||
| message FetchServersRespMessage { | |||||
| repeated ServersMeta servers_meta = 1; | |||||
| } | |||||
| message ServersMeta { | |||||
| int32 rank_id = 1; | |||||
| string ip = 2; | |||||
| int32 port = 3; | |||||
| } | |||||
| message CommMessage { | message CommMessage { | ||||
| MessageMeta pb_meta = 1; | MessageMeta pb_meta = 1; | ||||
| @@ -13,17 +13,17 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| syntax = "proto3"; | syntax = "proto3"; | ||||
| package mindspore.ps.core; | package mindspore.ps.core; | ||||
| option optimize_for = LITE_RUNTIME; | option optimize_for = LITE_RUNTIME; | ||||
| message KVMessage { | |||||
| repeated int32 keys = 1; | |||||
| repeated float values = 2; | |||||
| enum PSCommand { | |||||
| PUSH = 0; | |||||
| PULL = 1; | |||||
| } | } | ||||
| message HeartBeatMessage { | |||||
| // *.*.*.*:port | |||||
| repeated string host_and_port = 1; | |||||
| message KVMessage { | |||||
| PSCommand command = 1; | |||||
| repeated int32 keys = 2; | |||||
| repeated float values = 3; | |||||
| } | } | ||||
| @@ -18,8 +18,8 @@ | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/bufferevent.h> | |||||
| #include <event2/buffer_compat.h> | #include <event2/buffer_compat.h> | ||||
| #include <event2/bufferevent.h> | |||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <netinet/in.h> | #include <netinet/in.h> | ||||
| #include <netinet/tcp.h> | #include <netinet/tcp.h> | ||||
| @@ -27,20 +27,23 @@ | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <utility> | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include "ps/core/comm_util.h" | #include "ps/core/comm_util.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| event_base *TcpClient::event_base_ = nullptr; | |||||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | TcpClient::TcpClient(const std::string &address, std::uint16_t port) | ||||
| : event_base_(nullptr), | |||||
| event_timeout_(nullptr), | |||||
| : event_timeout_(nullptr), | |||||
| buffer_event_(nullptr), | buffer_event_(nullptr), | ||||
| server_address_(std::move(address)), | server_address_(std::move(address)), | ||||
| server_port_(port) { | |||||
| server_port_(port), | |||||
| is_stop_(true) { | |||||
| message_handler_.SetCallback([this](const CommMessage &message) { | message_handler_.SetCallback([this](const CommMessage &message) { | ||||
| if (message_callback_) { | if (message_callback_) { | ||||
| message_callback_(*this, message); | message_callback_(*this, message); | ||||
| @@ -61,6 +64,7 @@ void TcpClient::SetCallback(const OnConnected &conn, const OnDisconnected &disco | |||||
| } | } | ||||
| void TcpClient::Init() { | void TcpClient::Init() { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| if (buffer_event_) { | if (buffer_event_) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -68,7 +72,13 @@ void TcpClient::Init() { | |||||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | ||||
| } | } | ||||
| event_base_ = event_base_new(); | |||||
| int result = evthread_use_pthreads(); | |||||
| if (result != 0) { | |||||
| MS_LOG(EXCEPTION) << "Use event pthread failed!"; | |||||
| } | |||||
| if (event_base_ == nullptr) { | |||||
| event_base_ = event_base_new(); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| sockaddr_in sin{}; | sockaddr_in sin{}; | ||||
| @@ -94,6 +104,7 @@ void TcpClient::Init() { | |||||
| } | } | ||||
| void TcpClient::StartWithDelay(int seconds) { | void TcpClient::StartWithDelay(int seconds) { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| if (buffer_event_) { | if (buffer_event_) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -111,16 +122,28 @@ void TcpClient::StartWithDelay(int seconds) { | |||||
| } | } | ||||
| void TcpClient::Stop() { | void TcpClient::Stop() { | ||||
| if (buffer_event_) { | |||||
| bufferevent_free(buffer_event_); | |||||
| buffer_event_ = nullptr; | |||||
| } | |||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| MS_LOG(INFO) << "Stop tcp client event buffer!"; | |||||
| if (!is_stop_.load()) { | |||||
| if (buffer_event_) { | |||||
| bufferevent_free(buffer_event_); | |||||
| buffer_event_ = nullptr; | |||||
| } | |||||
| if (event_timeout_) { | |||||
| event_free(event_timeout_); | |||||
| event_timeout_ = nullptr; | |||||
| if (event_timeout_) { | |||||
| event_free(event_timeout_); | |||||
| event_timeout_ = nullptr; | |||||
| } | |||||
| is_stop_ = true; | |||||
| } | } | ||||
| } | |||||
| void TcpClient::StopEventBase() { | |||||
| MS_LOG(INFO) << "Stop tcp client event base!"; | |||||
| int ret = event_base_loopbreak(event_base_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "Event base loop break failed!"; | |||||
| } | |||||
| if (event_base_) { | if (event_base_) { | ||||
| event_base_free(event_base_); | event_base_free(event_base_); | ||||
| event_base_ = nullptr; | event_base_ = nullptr; | ||||
| @@ -167,21 +190,12 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) { | |||||
| message_handler_.ReceiveMessage(buf, num); | message_handler_.ReceiveMessage(buf, num); | ||||
| } | } | ||||
| void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | MS_EXCEPTION_IF_NULL(arg); | ||||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | auto tcp_client = reinterpret_cast<TcpClient *>(arg); | ||||
| MessageMeta meta; | |||||
| meta.set_cmd(ClusterCommand::HEARTBEAT); | |||||
| CommMessage message; | |||||
| message.set_allocated_pb_meta(&meta); | |||||
| tcp_client->SendMessage(message); | |||||
| struct event *ev; | |||||
| struct timeval timeout {}; | |||||
| timeout.tv_sec = ClusterConfig::heartbeat_interval(); | |||||
| timeout.tv_usec = 0; | |||||
| ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg); | |||||
| evtimer_add(ev, &timeout); | |||||
| if (tcp_client->on_timer_callback_) { | |||||
| tcp_client->on_timer_callback_(*tcp_client); | |||||
| } | |||||
| } | } | ||||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | ||||
| @@ -211,6 +225,7 @@ void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||||
| void TcpClient::Start() { | void TcpClient::Start() { | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| is_stop_ = false; | |||||
| int ret = event_base_dispatch(event_base_); | int ret = event_base_dispatch(event_base_); | ||||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | ||||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | ||||
| @@ -220,6 +235,7 @@ void TcpClient::Start() { | |||||
| } | } | ||||
| void TcpClient::StartWithNoBlock() { | void TcpClient::StartWithNoBlock() { | ||||
| std::lock_guard<std::mutex> lock(connection_mutex_); | |||||
| MS_LOG(INFO) << "Start tcp client with no block!"; | MS_LOG(INFO) << "Start tcp client with no block!"; | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); | int ret = event_base_loop(event_base_, EVLOOP_NONBLOCK); | ||||
| @@ -244,15 +260,24 @@ void TcpClient::SendMessage(const CommMessage &message) const { | |||||
| } | } | ||||
| } | } | ||||
| void TcpClient::SendMessageWithTimer() { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||||
| void TcpClient::StartTimer(const uint32_t &time) { | |||||
| MS_EXCEPTION_IF_NULL(event_base_); | |||||
| struct event *ev = nullptr; | struct event *ev = nullptr; | ||||
| if (time == 0) { | |||||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||||
| } | |||||
| struct timeval timeout {}; | struct timeval timeout {}; | ||||
| timeout.tv_sec = 0; | |||||
| timeout.tv_sec = time; | |||||
| timeout.tv_usec = 0; | timeout.tv_usec = 0; | ||||
| ev = evtimer_new(event_base_, SendHeartBeatCallback, this); | |||||
| ev = event_new(event_base_, -1, EV_PERSIST, TimerCallback, this); | |||||
| MS_EXCEPTION_IF_NULL(ev); | |||||
| evtimer_add(ev, &timeout); | evtimer_add(ev, &timeout); | ||||
| } | } | ||||
| void TcpClient::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | |||||
| const event_base &TcpClient::eventbase() { return *event_base_; } | |||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,10 +21,15 @@ | |||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <event2/bufferevent.h> | #include <event2/bufferevent.h> | ||||
| #include <event2/thread.h> | |||||
| #include <functional> | #include <functional> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <thread> | |||||
| #include <mutex> | |||||
| #include <atomic> | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "ps/core/cluster_config.h" | #include "ps/core/cluster_config.h" | ||||
| @@ -40,6 +45,7 @@ class TcpClient { | |||||
| using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | using OnRead = std::function<void(const TcpClient &, const void *, size_t)>; | ||||
| using OnTimeout = std::function<void(const TcpClient &)>; | using OnTimeout = std::function<void(const TcpClient &)>; | ||||
| using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; | using OnMessage = std::function<void(const TcpClient &, const CommMessage &)>; | ||||
| using OnTimer = std::function<void(const TcpClient &)>; | |||||
| explicit TcpClient(const std::string &address, std::uint16_t port); | explicit TcpClient(const std::string &address, std::uint16_t port); | ||||
| virtual ~TcpClient(); | virtual ~TcpClient(); | ||||
| @@ -50,11 +56,14 @@ class TcpClient { | |||||
| void Init(); | void Init(); | ||||
| void StartWithDelay(int seconds); | void StartWithDelay(int seconds); | ||||
| void Stop(); | void Stop(); | ||||
| static void StopEventBase(); | |||||
| void Start(); | void Start(); | ||||
| void StartWithNoBlock(); | void StartWithNoBlock(); | ||||
| void SetMessageCallback(const OnMessage &cb); | void SetMessageCallback(const OnMessage &cb); | ||||
| void SendMessage(const CommMessage &message) const; | void SendMessage(const CommMessage &message) const; | ||||
| void SendMessageWithTimer(); | |||||
| void StartTimer(const uint32_t &time); | |||||
| void set_timer_callback(const OnTimer &timer); | |||||
| const event_base &eventbase(); | |||||
| protected: | protected: | ||||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | static void SetTcpNoDelay(const evutil_socket_t &fd); | ||||
| @@ -62,7 +71,7 @@ class TcpClient { | |||||
| static void ReadCallback(struct bufferevent *bev, void *ctx); | static void ReadCallback(struct bufferevent *bev, void *ctx); | ||||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | ||||
| virtual void OnReadHandler(const void *buf, size_t num); | virtual void OnReadHandler(const void *buf, size_t num); | ||||
| static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg); | |||||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||||
| private: | private: | ||||
| OnMessage message_callback_; | OnMessage message_callback_; | ||||
| @@ -72,13 +81,16 @@ class TcpClient { | |||||
| OnDisconnected disconnected_callback_; | OnDisconnected disconnected_callback_; | ||||
| OnRead read_callback_; | OnRead read_callback_; | ||||
| OnTimeout timeout_callback_; | OnTimeout timeout_callback_; | ||||
| OnTimer on_timer_callback_; | |||||
| event_base *event_base_; | |||||
| static event_base *event_base_; | |||||
| std::mutex connection_mutex_; | |||||
| event *event_timeout_; | event *event_timeout_; | ||||
| bufferevent *buffer_event_; | bufferevent *buffer_event_; | ||||
| std::string server_address_; | std::string server_address_; | ||||
| std::uint16_t server_port_; | std::uint16_t server_port_; | ||||
| std::atomic<bool> is_stop_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| @@ -18,10 +18,10 @@ | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/buffer_compat.h> | |||||
| #include <event2/bufferevent.h> | #include <event2/bufferevent.h> | ||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <event2/listener.h> | #include <event2/listener.h> | ||||
| #include <event2/buffer_compat.h> | |||||
| #include <event2/util.h> | #include <event2/util.h> | ||||
| #include <sys/socket.h> | #include <sys/socket.h> | ||||
| #include <csignal> | #include <csignal> | ||||
| @@ -73,7 +73,8 @@ TcpServer::TcpServer(const std::string &address, std::uint16_t port) | |||||
| signal_event_(nullptr), | signal_event_(nullptr), | ||||
| listener_(nullptr), | listener_(nullptr), | ||||
| server_address_(std::move(address)), | server_address_(std::move(address)), | ||||
| server_port_(port) {} | |||||
| server_port_(port), | |||||
| is_stop_(true) {} | |||||
| TcpServer::~TcpServer() { Stop(); } | TcpServer::~TcpServer() { Stop(); } | ||||
| @@ -84,7 +85,14 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon | |||||
| this->client_accept_ = client_accept; | this->client_accept_ = client_accept; | ||||
| } | } | ||||
| void TcpServer::set_timer_callback(const OnTimer &timer) { on_timer_callback_ = timer; } | |||||
| void TcpServer::Init() { | void TcpServer::Init() { | ||||
| int result = evthread_use_pthreads(); | |||||
| if (result != 0) { | |||||
| MS_LOG(EXCEPTION) << "Use event pthread failed!"; | |||||
| } | |||||
| base_ = event_base_new(); | base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(base_); | MS_EXCEPTION_IF_NULL(base_); | ||||
| if (!CommUtil::CheckIp(server_address_)) { | if (!CommUtil::CheckIp(server_address_)) { | ||||
| @@ -128,6 +136,7 @@ void TcpServer::Start() { | |||||
| std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | std::unique_lock<std::recursive_mutex> lock(connection_mutex_); | ||||
| MS_LOG(INFO) << "Start tcp server!"; | MS_LOG(INFO) << "Start tcp server!"; | ||||
| MS_EXCEPTION_IF_NULL(base_); | MS_EXCEPTION_IF_NULL(base_); | ||||
| is_stop_ = false; | |||||
| int ret = event_base_dispatch(base_); | int ret = event_base_dispatch(base_); | ||||
| MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; | ||||
| MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) | ||||
| @@ -147,21 +156,42 @@ void TcpServer::StartWithNoBlock() { | |||||
| MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | MSLOG_IF(mindspore::EXCEPTION, ret < -1, AbortedError) << "Event base loop with unexpect error code!"; | ||||
| } | } | ||||
| void TcpServer::StartTimerOnlyOnce(const uint32_t &time) { | |||||
| MS_EXCEPTION_IF_NULL(base_); | |||||
| if (time == 0) { | |||||
| MS_LOG(EXCEPTION) << "The time should not be 0!"; | |||||
| } | |||||
| struct event *ev = nullptr; | |||||
| struct timeval timeout {}; | |||||
| timeout.tv_sec = time; | |||||
| timeout.tv_usec = 0; | |||||
| ev = evtimer_new(base_, TimerCallback, this); | |||||
| MS_EXCEPTION_IF_NULL(ev); | |||||
| evtimer_add(ev, &timeout); | |||||
| } | |||||
| void TcpServer::Stop() { | void TcpServer::Stop() { | ||||
| MS_LOG(INFO) << "Stop tcp server!"; | MS_LOG(INFO) << "Stop tcp server!"; | ||||
| if (signal_event_ != nullptr) { | |||||
| event_free(signal_event_); | |||||
| signal_event_ = nullptr; | |||||
| } | |||||
| if (!is_stop_.load()) { | |||||
| int ret = event_base_loopbreak(base_); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "event base loop break failed!"; | |||||
| } | |||||
| if (signal_event_ != nullptr) { | |||||
| event_free(signal_event_); | |||||
| signal_event_ = nullptr; | |||||
| } | |||||
| if (listener_ != nullptr) { | |||||
| evconnlistener_free(listener_); | |||||
| listener_ = nullptr; | |||||
| } | |||||
| if (listener_ != nullptr) { | |||||
| evconnlistener_free(listener_); | |||||
| listener_ = nullptr; | |||||
| } | |||||
| if (base_ != nullptr) { | |||||
| event_base_free(base_); | |||||
| base_ = nullptr; | |||||
| if (base_ != nullptr) { | |||||
| event_base_free(base_); | |||||
| base_ = nullptr; | |||||
| } | |||||
| is_stop_ = true; | |||||
| } | } | ||||
| } | } | ||||
| @@ -287,6 +317,14 @@ void TcpServer::EventCallback(struct bufferevent *bev, std::int16_t events, void | |||||
| } | } | ||||
| } | } | ||||
| void TcpServer::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto tcp_server = reinterpret_cast<TcpServer *>(arg); | |||||
| if (tcp_server->on_timer_callback_) { | |||||
| tcp_server->on_timer_callback_(*tcp_server); | |||||
| } | |||||
| } | |||||
| void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | void TcpServer::SendMessage(const TcpConnection &conn, const CommMessage &message) { conn.SendMessage(message); } | ||||
| void TcpServer::SendMessage(const CommMessage &message) { | void TcpServer::SendMessage(const CommMessage &message) { | ||||
| @@ -299,6 +337,10 @@ void TcpServer::SendMessage(const CommMessage &message) { | |||||
| uint16_t TcpServer::BoundPort() const { return server_port_; } | uint16_t TcpServer::BoundPort() const { return server_port_; } | ||||
| int TcpServer::ConnectionNum() const { return connections_.size(); } | |||||
| const std::map<evutil_socket_t, const TcpConnection *> &TcpServer::Connections() const { return connections_; } | |||||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -21,17 +21,23 @@ | |||||
| #include <event2/bufferevent.h> | #include <event2/bufferevent.h> | ||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <event2/listener.h> | #include <event2/listener.h> | ||||
| #include <event2/thread.h> | |||||
| #include <exception> | #include <exception> | ||||
| #include <functional> | #include <functional> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | |||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | |||||
| #include <vector> | #include <vector> | ||||
| #include <thread> | |||||
| #include <atomic> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "proto/comm.pb.h" | |||||
| #include "ps/core/tcp_message_handler.h" | #include "ps/core/tcp_message_handler.h" | ||||
| #include "ps/core/cluster_config.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -40,7 +46,7 @@ class TcpServer; | |||||
| class TcpConnection { | class TcpConnection { | ||||
| public: | public: | ||||
| explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) | explicit TcpConnection(struct bufferevent *bev, const evutil_socket_t &fd, const TcpServer *server) | ||||
| : buffer_event_(bev), fd_(0), server_(server) {} | |||||
| : buffer_event_(bev), fd_(fd), server_(server) {} | |||||
| virtual ~TcpConnection() = default; | virtual ~TcpConnection() = default; | ||||
| virtual void InitConnection(); | virtual void InitConnection(); | ||||
| @@ -65,24 +71,29 @@ class TcpServer { | |||||
| using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | using OnConnected = std::function<void(const TcpServer &, const TcpConnection &)>; | ||||
| using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | using OnDisconnected = std::function<void(const TcpServer &, const TcpConnection &)>; | ||||
| using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | using OnAccepted = std::function<const TcpConnection *(const TcpServer &)>; | ||||
| using OnTimer = std::function<void(const TcpServer &)>; | |||||
| explicit TcpServer(const std::string &address, std::uint16_t port); | explicit TcpServer(const std::string &address, std::uint16_t port); | ||||
| virtual ~TcpServer(); | virtual ~TcpServer(); | ||||
| void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | void SetServerCallback(const OnConnected &client_conn, const OnDisconnected &client_disconn, | ||||
| const OnAccepted &client_accept); | const OnAccepted &client_accept); | ||||
| void set_timer_callback(const OnTimer &timer); | |||||
| void Init(); | void Init(); | ||||
| void Start(); | void Start(); | ||||
| void StartWithNoBlock(); | void StartWithNoBlock(); | ||||
| void StartTimerOnlyOnce(const uint32_t &time); | |||||
| void Stop(); | void Stop(); | ||||
| void SendToAllClients(const char *data, size_t len); | void SendToAllClients(const char *data, size_t len); | ||||
| void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | void AddConnection(const evutil_socket_t &fd, const TcpConnection *connection); | ||||
| void RemoveConnection(const evutil_socket_t &fd); | void RemoveConnection(const evutil_socket_t &fd); | ||||
| OnServerReceiveMessage GetServerReceive() const; | OnServerReceiveMessage GetServerReceive() const; | ||||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | void SetMessageCallback(const OnServerReceiveMessage &cb); | ||||
| static void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||||
| void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||||
| void SendMessage(const CommMessage &message); | void SendMessage(const CommMessage &message); | ||||
| uint16_t BoundPort() const; | uint16_t BoundPort() const; | ||||
| int ConnectionNum() const; | |||||
| const std::map<evutil_socket_t, const TcpConnection *> &Connections() const; | |||||
| protected: | protected: | ||||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | ||||
| @@ -90,6 +101,7 @@ class TcpServer { | |||||
| static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); | static void SignalCallback(evutil_socket_t sig, std::int16_t events, void *server); | ||||
| static void ReadCallback(struct bufferevent *, void *connection); | static void ReadCallback(struct bufferevent *, void *connection); | ||||
| static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | static void EventCallback(struct bufferevent *, std::int16_t events, void *server); | ||||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||||
| virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | virtual TcpConnection *onCreateConnection(struct bufferevent *bev, const evutil_socket_t &fd); | ||||
| struct event_base *base_; | struct event_base *base_; | ||||
| @@ -97,6 +109,7 @@ class TcpServer { | |||||
| struct evconnlistener *listener_; | struct evconnlistener *listener_; | ||||
| std::string server_address_; | std::string server_address_; | ||||
| std::uint16_t server_port_; | std::uint16_t server_port_; | ||||
| std::atomic<bool> is_stop_; | |||||
| std::map<evutil_socket_t, const TcpConnection *> connections_; | std::map<evutil_socket_t, const TcpConnection *> connections_; | ||||
| OnConnected client_connection_; | OnConnected client_connection_; | ||||
| @@ -104,6 +117,7 @@ class TcpServer { | |||||
| OnAccepted client_accept_; | OnAccepted client_accept_; | ||||
| std::recursive_mutex connection_mutex_; | std::recursive_mutex connection_mutex_; | ||||
| OnServerReceiveMessage message_callback_; | OnServerReceiveMessage message_callback_; | ||||
| OnTimer on_timer_callback_; | |||||
| }; | }; | ||||
| } // namespace core | } // namespace core | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -37,7 +37,7 @@ class TestTcpServer : public UT::Common { | |||||
| KVMessage kv_message; | KVMessage kv_message; | ||||
| kv_message.ParseFromString(message.data()); | kv_message.ParseFromString(message.data()); | ||||
| EXPECT_EQ(2, kv_message.keys_size()); | EXPECT_EQ(2, kv_message.keys_size()); | ||||
| server.SendMessage(conn, message); | |||||
| const_cast<TcpServer&>(server).SendMessage(conn, message); | |||||
| }); | }); | ||||
| server_->Init(); | server_->Init(); | ||||
| server_->Start(); | server_->Start(); | ||||