| @@ -100,8 +100,8 @@ message("onnx proto path is :" ${ONNX_PROTO}) | |||
| ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) | |||
| list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) | |||
| include_directories("${CMAKE_BINARY_DIR}/ps/comm") | |||
| file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto") | |||
| include_directories("${CMAKE_BINARY_DIR}/ps/core") | |||
| file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/core/protos/*.proto") | |||
| ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) | |||
| list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) | |||
| @@ -5,12 +5,13 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_client.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | |||
| endif() | |||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/cluster_config.h" | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| uint32_t ClusterConfig::worker_num_ = 0; | |||
| uint32_t ClusterConfig::server_num_ = 0; | |||
| uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval; | |||
| std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr; | |||
| uint16_t ClusterConfig::scheduler_port_ = 0; | |||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | |||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | |||
| worker_num_ = worker_num; | |||
| server_num_ = server_num; | |||
| if (!CommUtil::CheckIp(*scheduler_host.get())) { | |||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; | |||
| } | |||
| scheduler_host_ = std::move(scheduler_host); | |||
| scheduler_port_ = scheduler_port; | |||
| } | |||
| uint32_t ClusterConfig::worker_num() { return worker_num_; } | |||
| uint32_t ClusterConfig::server_num() { return server_num_; } | |||
| uint32_t ClusterConfig::heartbeat_interval() { return heartbeat_interval_; } | |||
| void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { | |||
| heartbeat_interval_ = heartbeat_interval; | |||
| } | |||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } | |||
| uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||
| #include <string> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/core/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| constexpr uint32_t kHeartbeatInterval = 3; | |||
| class ClusterConfig { | |||
| public: | |||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | |||
| const uint16_t &scheduler_port); | |||
| static uint32_t worker_num(); | |||
| static uint32_t server_num(); | |||
| static uint32_t heartbeat_interval(); | |||
| static void set_heartbeat_interval(const uint32_t &heartbeat_interval); | |||
| static std::string scheduler_host(); | |||
| static uint16_t scheduler_port(); | |||
| private: | |||
| static uint32_t worker_num_; | |||
| static uint32_t server_num_; | |||
| static uint32_t heartbeat_interval_; | |||
| static std::unique_ptr<std::string> scheduler_host_; | |||
| static uint16_t scheduler_port_; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/comm_util.h" | |||
| #include "ps/core/comm_util.h" | |||
| #include <arpa/inet.h> | |||
| #include <cstdio> | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| bool CommUtil::CheckIpWithRegex(const std::string &ip) { | |||
| std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | |||
| @@ -36,15 +36,47 @@ bool CommUtil::CheckIpWithRegex(const std::string &ip) { | |||
| return false; | |||
| } | |||
| void CommUtil::CheckIp(const std::string &ip) { | |||
| bool CommUtil::CheckIp(const std::string &ip) { | |||
| if (!CheckIpWithRegex(ip)) { | |||
| MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!"; | |||
| return false; | |||
| } | |||
| int64_t uAddr = inet_addr(ip.c_str()); | |||
| if (INADDR_NONE == uAddr) { | |||
| MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace comm | |||
| void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) { | |||
| MS_EXCEPTION_IF_NULL(interface); | |||
| MS_EXCEPTION_IF_NULL(ip); | |||
| struct ifaddrs *if_address = nullptr; | |||
| struct ifaddrs *ifa = nullptr; | |||
| interface->clear(); | |||
| ip->clear(); | |||
| getifaddrs(&if_address); | |||
| for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) { | |||
| if (ifa->ifa_addr == nullptr) { | |||
| continue; | |||
| } | |||
| if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) { | |||
| char address_buffer[INET_ADDRSTRLEN] = {0}; | |||
| void *sin_addr_ptr = &(reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr))->sin_addr; | |||
| MS_EXCEPTION_IF_NULL(sin_addr_ptr); | |||
| const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN); | |||
| MS_EXCEPTION_IF_NULL(net_ptr); | |||
| *ip = address_buffer; | |||
| *interface = ifa->ifa_name; | |||
| break; | |||
| } | |||
| } | |||
| MS_EXCEPTION_IF_NULL(if_address); | |||
| freeifaddrs(if_address); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,21 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||
| #include <unistd.h> | |||
| #ifdef _MSC_VER | |||
| #include <tchar.h> | |||
| #include <winsock2.h> | |||
| #include <windows.h> | |||
| #include <iphlpapi.h> | |||
| #else | |||
| #include <net/if.h> | |||
| #include <arpa/inet.h> | |||
| #include <ifaddrs.h> | |||
| #include <netinet/in.h> | |||
| #endif | |||
| #include <event2/buffer.h> | |||
| #include <event2/event.h> | |||
| @@ -35,15 +48,16 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class CommUtil { | |||
| public: | |||
| static bool CheckIpWithRegex(const std::string &ip); | |||
| static void CheckIp(const std::string &ip); | |||
| static bool CheckIp(const std::string &ip); | |||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/http_message_handler.h" | |||
| #include "ps/core/http_message_handler.h" | |||
| #include <event2/event.h> | |||
| #include <event2/buffer.h> | |||
| @@ -36,7 +36,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| void HttpMessageHandler::InitHttpMessage() { | |||
| MS_EXCEPTION_IF_NULL(event_request_); | |||
| @@ -202,6 +202,6 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { | |||
| evhttp_send_error(event_request_, nCode, message.c_str()); | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||
| #include <event2/buffer.h> | |||
| #include <event2/event.h> | |||
| @@ -36,7 +36,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| using HttpHeaders = std::map<std::string, std::list<std::string>>; | |||
| @@ -101,7 +101,7 @@ class HttpMessageHandler { | |||
| void ParsePostParam(); | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||
| @@ -14,9 +14,9 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/http_server.h" | |||
| #include "ps/comm/http_message_handler.h" | |||
| #include "ps/comm/comm_util.h" | |||
| #include "ps/core/http_server.h" | |||
| #include "ps/core/http_message_handler.h" | |||
| #include "ps/core/comm_util.h" | |||
| #ifdef WIN32 | |||
| #include <WinSock2.h> | |||
| @@ -40,12 +40,14 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| HttpServer::~HttpServer() { Stop(); } | |||
| bool HttpServer::InitServer() { | |||
| CommUtil::CheckIp(server_address_); | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| @@ -154,6 +156,6 @@ void HttpServer::Stop() { | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,10 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||
| #include "ps/comm/http_message_handler.h" | |||
| #include "ps/core/http_message_handler.h" | |||
| #include <event2/buffer.h> | |||
| #include <event2/event.h> | |||
| @@ -35,7 +35,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| typedef enum eHttpMethod { | |||
| HM_GET = 1 << 0, | |||
| @@ -86,8 +86,8 @@ class HttpServer { | |||
| bool is_init_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||
| @@ -16,9 +16,24 @@ | |||
| syntax = "proto3"; | |||
| import "google/protobuf/any.proto"; | |||
| package mindspore.ps; | |||
| package mindspore.ps.core; | |||
| option optimize_for = LITE_RUNTIME; | |||
| enum ClusterCommand { | |||
| TERMINATE = 0; | |||
| REGISTER = 1; | |||
| ACK = 2; | |||
| HEARTBEAT = 3; | |||
| FETCH_WORKERS = 4; | |||
| FETCH_SERVERS = 5; | |||
| } | |||
| enum Role { | |||
| SERVER = 0; | |||
| WORKER = 1; | |||
| SCHEDULER = 2; | |||
| } | |||
| message MessageMeta { | |||
| // hostname or ip | |||
| string hostname = 1; | |||
| @@ -14,6 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package mindspore.ps.core; | |||
| option optimize_for = LITE_RUNTIME; | |||
| message KVMessage { | |||
| repeated int32 keys = 1; | |||
| repeated float values = 2; | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_client.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include <arpa/inet.h> | |||
| #include <event2/buffer.h> | |||
| @@ -30,11 +30,11 @@ | |||
| #include <utility> | |||
| #include <string> | |||
| #include "ps/comm/comm_util.h" | |||
| #include "ps/core/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | |||
| : event_base_(nullptr), | |||
| @@ -65,7 +65,9 @@ void TcpClient::Init() { | |||
| if (buffer_event_) { | |||
| return; | |||
| } | |||
| CommUtil::CheckIp(server_address_); | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| event_base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(event_base_); | |||
| @@ -166,6 +168,23 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) { | |||
| message_handler_.ReceiveMessage(buf, num); | |||
| } | |||
| void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| MessageMeta meta; | |||
| meta.set_cmd(ClusterCommand::HEARTBEAT); | |||
| CommMessage message; | |||
| message.set_allocated_pb_meta(&meta); | |||
| tcp_client->SendMessage(message); | |||
| struct event *ev; | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = ClusterConfig::heartbeat_interval(); | |||
| timeout.tv_usec = 0; | |||
| ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| @@ -226,6 +245,16 @@ void TcpClient::SendMessage(const CommMessage &message) const { | |||
| } | |||
| } | |||
| } // namespace comm | |||
| void TcpClient::SendMessageWithTimer() { | |||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||
| struct event *ev = nullptr; | |||
| struct timeval timeout {}; | |||
| timeout.tv_sec = 0; | |||
| timeout.tv_usec = 0; | |||
| ev = evtimer_new(event_base_, SendHeartBeatCallback, this); | |||
| evtimer_add(ev, &timeout); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,10 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include <event2/event.h> | |||
| #include <event2/bufferevent.h> | |||
| @@ -27,10 +27,11 @@ | |||
| #include <vector> | |||
| #include "proto/comm.pb.h" | |||
| #include "ps/core/cluster_config.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class TcpClient { | |||
| public: | |||
| @@ -53,6 +54,7 @@ class TcpClient { | |||
| void StartWithNoBlock(); | |||
| void SetMessageCallback(const OnMessage &cb); | |||
| void SendMessage(const CommMessage &message) const; | |||
| void SendMessageWithTimer(); | |||
| protected: | |||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | |||
| @@ -60,6 +62,7 @@ class TcpClient { | |||
| static void ReadCallback(struct bufferevent *bev, void *ctx); | |||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | |||
| virtual void OnReadHandler(const void *buf, size_t num); | |||
| static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| private: | |||
| OnMessage message_callback_; | |||
| @@ -78,7 +81,7 @@ class TcpClient { | |||
| std::uint16_t server_port_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include <arpa/inet.h> | |||
| #include <iostream> | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | |||
| @@ -37,16 +37,15 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| --num; | |||
| if (header_index_ == 3) { | |||
| message_length_ = *reinterpret_cast<const uint32_t *>(header_); | |||
| message_length_ = ntohl(message_length_); | |||
| remaining_length_ = message_length_; | |||
| message_buffer_.reset(new unsigned char[remaining_length_]); | |||
| buffer_data += i; | |||
| buffer_data += (i + 1); | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| if (remaining_length_ > 0) { | |||
| if (remaining_length_ > 0 && num > 0) { | |||
| uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | |||
| remaining_length_ -= copy_len; | |||
| num -= copy_len; | |||
| @@ -60,19 +59,19 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| if (remaining_length_ == 0) { | |||
| CommMessage pb_message; | |||
| pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_); | |||
| pb_message.ParseFromArray(message_buffer_.get(), message_length_); | |||
| if (message_callback_) { | |||
| message_callback_(pb_message); | |||
| } | |||
| message_buffer_.reset(); | |||
| message_buffer_ = nullptr; | |||
| header_index_ = 0; | |||
| header_index_ = -1; | |||
| last_copy_len_ = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||
| #include <functional> | |||
| #include <iostream> | |||
| @@ -29,7 +29,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||
| @@ -57,8 +57,8 @@ class TcpMessageHandler { | |||
| int header_index_; | |||
| uint32_t last_copy_len_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_server.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include <arpa/inet.h> | |||
| #include <event2/buffer.h> | |||
| @@ -27,11 +27,11 @@ | |||
| #include <csignal> | |||
| #include <utility> | |||
| #include "ps/comm/comm_util.h" | |||
| #include "ps/core/comm_util.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| void TcpConnection::InitConnection() { | |||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | |||
| @@ -88,7 +88,9 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon | |||
| void TcpServer::Init() { | |||
| base_ = event_base_new(); | |||
| MS_EXCEPTION_IF_NULL(base_); | |||
| CommUtil::CheckIp(server_address_); | |||
| if (!CommUtil::CheckIp(server_address_)) { | |||
| MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!"; | |||
| } | |||
| struct sockaddr_in sin {}; | |||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | |||
| @@ -104,6 +106,18 @@ void TcpServer::Init() { | |||
| MS_EXCEPTION_IF_NULL(listener_); | |||
| if (server_port_ == 0) { | |||
| struct sockaddr_in sin_bound {}; | |||
| if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; | |||
| } | |||
| socklen_t addr_len = sizeof(struct sockaddr_in); | |||
| if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) { | |||
| MS_LOG(EXCEPTION) << "Get sock name failed!"; | |||
| } | |||
| server_port_ = htons(sin_bound.sin_port); | |||
| } | |||
| signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this)); | |||
| MS_EXCEPTION_IF_NULL(signal_event_); | |||
| if (event_add(signal_event_, nullptr) < 0) { | |||
| @@ -173,11 +187,13 @@ void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||
| connections_.erase(fd); | |||
| } | |||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) { | |||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | |||
| void *data) { | |||
| auto server = reinterpret_cast<class TcpServer *>(data); | |||
| auto base = reinterpret_cast<struct event_base *>(server->base_); | |||
| MS_EXCEPTION_IF_NULL(server); | |||
| MS_EXCEPTION_IF_NULL(base); | |||
| MS_EXCEPTION_IF_NULL(sockaddr); | |||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | |||
| if (!bev) { | |||
| @@ -279,8 +295,10 @@ void TcpServer::SendMessage(const CommMessage &message) { | |||
| } | |||
| } | |||
| uint16_t TcpServer::BoundPort() const { return server_port_; } | |||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||
| #include <event2/buffer.h> | |||
| #include <event2/bufferevent.h> | |||
| @@ -31,11 +31,11 @@ | |||
| #include <vector> | |||
| #include "utils/log_adapter.h" | |||
| #include "ps/comm/tcp_message_handler.h" | |||
| #include "ps/core/tcp_message_handler.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class TcpServer; | |||
| class TcpConnection { | |||
| @@ -83,6 +83,7 @@ class TcpServer { | |||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | |||
| static void SendMessage(const TcpConnection &conn, const CommMessage &message); | |||
| void SendMessage(const CommMessage &message); | |||
| uint16_t BoundPort() const; | |||
| protected: | |||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | |||
| @@ -106,7 +107,7 @@ class TcpServer { | |||
| OnServerReceiveMessage message_callback_; | |||
| }; | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "common/common_test.h" | |||
| #include "ps/core/cluster_config.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class TestClusterConfig : public UT::Common { | |||
| public: | |||
| TestClusterConfig() = default; | |||
| virtual ~TestClusterConfig() = default; | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestClusterConfig, HeartbeatInterval) { | |||
| ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080); | |||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); | |||
| ClusterConfig::set_heartbeat_interval(100); | |||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); | |||
| EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1"); | |||
| EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080); | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common_test.h" | |||
| #include "ps/core/comm_util.h" | |||
| #include <memory> | |||
| #include <thread> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class TestCommUtil : public UT::Common { | |||
| public: | |||
| TestCommUtil() = default; | |||
| virtual ~TestCommUtil() = default; | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||
| std::string interface; | |||
| std::string ip; | |||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &ip); | |||
| EXPECT_TRUE(!interface.empty()); | |||
| EXPECT_TRUE(!ip.empty()); | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/http_server.h" | |||
| #include "ps/core/http_server.h" | |||
| #include "common/common_test.h" | |||
| #include <gtest/gtest.h> | |||
| #include <algorithm> | |||
| @@ -28,7 +28,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class TestHttpServer : public UT::Common { | |||
| public: | |||
| @@ -17,11 +17,11 @@ | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #include "ps/comm/tcp_client.h" | |||
| #include "ps/core/tcp_client.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class TestTcpClient : public UT::Common { | |||
| public: | |||
| TestTcpClient() = default; | |||
| @@ -0,0 +1,163 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/core/tcp_message_handler.h" | |||
| #include "common/common_test.h" | |||
| #include <memory> | |||
| #include <thread> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| class TestTcpMessageHandler : public UT::Common { | |||
| public: | |||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||
| TestTcpMessageHandler() = default; | |||
| virtual ~TestTcpMessageHandler() = default; | |||
| void SetUp() override {} | |||
| void TearDown() override {} | |||
| }; | |||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[1007]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| handler.ReceiveMessage(result, buf_size + 4); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||
| std::string data(1000, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[2014]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 2 * buf_size + 4 * 2); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); }); | |||
| std::string data(4087, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[4096]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4096); | |||
| ret = memcpy_s(result, 2, &buf_size + 2, 2); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4092); | |||
| } | |||
| TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||
| TcpMessageHandler handler; | |||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); }); | |||
| std::string data(4085, 'a'); | |||
| CommMessage message; | |||
| message.set_data(data); | |||
| uint32_t buf_size = message.ByteSizeLong(); | |||
| char result[4096]; | |||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| std::vector<char> serialized(buf_size); | |||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4096); | |||
| ret = memcpy_s(result, buf_size, serialized.data(), buf_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| } | |||
| handler.ReceiveMessage(result, 4088); | |||
| } | |||
| } // namespace comm | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/comm/tcp_client.h" | |||
| #include "ps/comm/tcp_server.h" | |||
| #include "ps/core/tcp_client.h" | |||
| #include "ps/core/tcp_server.h" | |||
| #include "common/common_test.h" | |||
| #include <memory> | |||
| @@ -23,14 +23,14 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace comm { | |||
| namespace core { | |||
| class TestTcpServer : public UT::Common { | |||
| public: | |||
| TestTcpServer() : client_(nullptr), server_(nullptr) {} | |||
| virtual ~TestTcpServer() = default; | |||
| void SetUp() override { | |||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 9998); | |||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | |||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | |||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | |||
| server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | |||
| @@ -57,7 +57,7 @@ class TestTcpServer : public UT::Common { | |||
| }; | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", 9998); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | |||
| @@ -82,6 +82,6 @@ TEST_F(TestTcpServer, ServerSendMessage) { | |||
| }); | |||
| http_client_thread->detach(); | |||
| } | |||
| } // namespace comm | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||