| @@ -100,8 +100,8 @@ message("onnx proto path is :" ${ONNX_PROTO}) | |||||
| ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) | ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) | ||||
| list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) | list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) | ||||
| include_directories("${CMAKE_BINARY_DIR}/ps/comm") | |||||
| file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto") | |||||
| include_directories("${CMAKE_BINARY_DIR}/ps/core") | |||||
| file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/core/protos/*.proto") | |||||
| ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) | ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) | ||||
| list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) | list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) | ||||
| @@ -5,12 +5,13 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "util.cc") | list(REMOVE_ITEM _PS_SRC_FILES "util.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_client.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | |||||
| endif() | endif() | ||||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | ||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/core/cluster_config.h" | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| uint32_t ClusterConfig::worker_num_ = 0; | |||||
| uint32_t ClusterConfig::server_num_ = 0; | |||||
| uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval; | |||||
| std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr; | |||||
| uint16_t ClusterConfig::scheduler_port_ = 0; | |||||
| void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, | |||||
| std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) { | |||||
| worker_num_ = worker_num; | |||||
| server_num_ = server_num; | |||||
| if (!CommUtil::CheckIp(*scheduler_host.get())) { | |||||
| MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; | |||||
| } | |||||
| scheduler_host_ = std::move(scheduler_host); | |||||
| scheduler_port_ = scheduler_port; | |||||
| } | |||||
| uint32_t ClusterConfig::worker_num() { return worker_num_; } | |||||
| uint32_t ClusterConfig::server_num() { return server_num_; } | |||||
| uint32_t ClusterConfig::heartbeat_interval() { return heartbeat_interval_; } | |||||
| void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { | |||||
| heartbeat_interval_ = heartbeat_interval; | |||||
| } | |||||
| std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } | |||||
| uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||||
| #include <string> | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| constexpr uint32_t kHeartbeatInterval = 3; | |||||
| class ClusterConfig { | |||||
| public: | |||||
| static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host, | |||||
| const uint16_t &scheduler_port); | |||||
| static uint32_t worker_num(); | |||||
| static uint32_t server_num(); | |||||
| static uint32_t heartbeat_interval(); | |||||
| static void set_heartbeat_interval(const uint32_t &heartbeat_interval); | |||||
| static std::string scheduler_host(); | |||||
| static uint16_t scheduler_port(); | |||||
| private: | |||||
| static uint32_t worker_num_; | |||||
| static uint32_t server_num_; | |||||
| static uint32_t heartbeat_interval_; | |||||
| static std::unique_ptr<std::string> scheduler_host_; | |||||
| static uint16_t scheduler_port_; | |||||
| }; | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/comm_util.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <cstdio> | #include <cstdio> | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| bool CommUtil::CheckIpWithRegex(const std::string &ip) { | bool CommUtil::CheckIpWithRegex(const std::string &ip) { | ||||
| std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); | ||||
| @@ -36,15 +36,47 @@ bool CommUtil::CheckIpWithRegex(const std::string &ip) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void CommUtil::CheckIp(const std::string &ip) { | |||||
| bool CommUtil::CheckIp(const std::string &ip) { | |||||
| if (!CheckIpWithRegex(ip)) { | if (!CheckIpWithRegex(ip)) { | ||||
| MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!"; | |||||
| return false; | |||||
| } | } | ||||
| int64_t uAddr = inet_addr(ip.c_str()); | int64_t uAddr = inet_addr(ip.c_str()); | ||||
| if (INADDR_NONE == uAddr) { | if (INADDR_NONE == uAddr) { | ||||
| MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| } // namespace comm | |||||
| void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) { | |||||
| MS_EXCEPTION_IF_NULL(interface); | |||||
| MS_EXCEPTION_IF_NULL(ip); | |||||
| struct ifaddrs *if_address = nullptr; | |||||
| struct ifaddrs *ifa = nullptr; | |||||
| interface->clear(); | |||||
| ip->clear(); | |||||
| getifaddrs(&if_address); | |||||
| for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) { | |||||
| if (ifa->ifa_addr == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) { | |||||
| char address_buffer[INET_ADDRSTRLEN] = {0}; | |||||
| void *sin_addr_ptr = &(reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr))->sin_addr; | |||||
| MS_EXCEPTION_IF_NULL(sin_addr_ptr); | |||||
| const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN); | |||||
| MS_EXCEPTION_IF_NULL(net_ptr); | |||||
| *ip = address_buffer; | |||||
| *interface = ifa->ifa_name; | |||||
| break; | |||||
| } | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(if_address); | |||||
| freeifaddrs(if_address); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,21 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||||
| #include <unistd.h> | |||||
| #ifdef _MSC_VER | |||||
| #include <tchar.h> | |||||
| #include <winsock2.h> | |||||
| #include <windows.h> | |||||
| #include <iphlpapi.h> | |||||
| #else | |||||
| #include <net/if.h> | |||||
| #include <arpa/inet.h> | |||||
| #include <ifaddrs.h> | |||||
| #include <netinet/in.h> | |||||
| #endif | |||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| @@ -35,15 +48,16 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class CommUtil { | class CommUtil { | ||||
| public: | public: | ||||
| static bool CheckIpWithRegex(const std::string &ip); | static bool CheckIpWithRegex(const std::string &ip); | ||||
| static void CheckIp(const std::string &ip); | |||||
| static bool CheckIp(const std::string &ip); | |||||
| static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); | |||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/http_message_handler.h" | |||||
| #include "ps/core/http_message_handler.h" | |||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| @@ -36,7 +36,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| void HttpMessageHandler::InitHttpMessage() { | void HttpMessageHandler::InitHttpMessage() { | ||||
| MS_EXCEPTION_IF_NULL(event_request_); | MS_EXCEPTION_IF_NULL(event_request_); | ||||
| @@ -202,6 +202,6 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { | |||||
| evhttp_send_error(event_request_, nCode, message.c_str()); | evhttp_send_error(event_request_, nCode, message.c_str()); | ||||
| } | } | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| @@ -36,7 +36,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| using HttpHeaders = std::map<std::string, std::list<std::string>>; | using HttpHeaders = std::map<std::string, std::list<std::string>>; | ||||
| @@ -101,7 +101,7 @@ class HttpMessageHandler { | |||||
| void ParsePostParam(); | void ParsePostParam(); | ||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ | |||||
| @@ -14,9 +14,9 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/http_server.h" | |||||
| #include "ps/comm/http_message_handler.h" | |||||
| #include "ps/comm/comm_util.h" | |||||
| #include "ps/core/http_server.h" | |||||
| #include "ps/core/http_message_handler.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| #ifdef WIN32 | #ifdef WIN32 | ||||
| #include <WinSock2.h> | #include <WinSock2.h> | ||||
| @@ -40,12 +40,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| HttpServer::~HttpServer() { Stop(); } | HttpServer::~HttpServer() { Stop(); } | ||||
| bool HttpServer::InitServer() { | bool HttpServer::InitServer() { | ||||
| CommUtil::CheckIp(server_address_); | |||||
| if (!CommUtil::CheckIp(server_address_)) { | |||||
| MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; | |||||
| } | |||||
| event_base_ = event_base_new(); | event_base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| @@ -154,6 +156,6 @@ void HttpServer::Stop() { | |||||
| } | } | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,10 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||||
| #include "ps/comm/http_message_handler.h" | |||||
| #include "ps/core/http_message_handler.h" | |||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| @@ -35,7 +35,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| typedef enum eHttpMethod { | typedef enum eHttpMethod { | ||||
| HM_GET = 1 << 0, | HM_GET = 1 << 0, | ||||
| @@ -86,8 +86,8 @@ class HttpServer { | |||||
| bool is_init_; | bool is_init_; | ||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ | |||||
| @@ -16,9 +16,24 @@ | |||||
| syntax = "proto3"; | syntax = "proto3"; | ||||
| import "google/protobuf/any.proto"; | import "google/protobuf/any.proto"; | ||||
| package mindspore.ps; | |||||
| package mindspore.ps.core; | |||||
| option optimize_for = LITE_RUNTIME; | option optimize_for = LITE_RUNTIME; | ||||
| enum ClusterCommand { | |||||
| TERMINATE = 0; | |||||
| REGISTER = 1; | |||||
| ACK = 2; | |||||
| HEARTBEAT = 3; | |||||
| FETCH_WORKERS = 4; | |||||
| FETCH_SERVERS = 5; | |||||
| } | |||||
| enum Role { | |||||
| SERVER = 0; | |||||
| WORKER = 1; | |||||
| SCHEDULER = 2; | |||||
| } | |||||
| message MessageMeta { | message MessageMeta { | ||||
| // hostname or ip | // hostname or ip | ||||
| string hostname = 1; | string hostname = 1; | ||||
| @@ -14,6 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| syntax = "proto3"; | |||||
| package mindspore.ps.core; | |||||
| option optimize_for = LITE_RUNTIME; | |||||
| message KVMessage { | message KVMessage { | ||||
| repeated int32 keys = 1; | repeated int32 keys = 1; | ||||
| repeated float values = 2; | repeated float values = 2; | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/tcp_client.h" | |||||
| #include "ps/core/tcp_client.h" | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| @@ -30,11 +30,11 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <string> | #include <string> | ||||
| #include "ps/comm/comm_util.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port) | TcpClient::TcpClient(const std::string &address, std::uint16_t port) | ||||
| : event_base_(nullptr), | : event_base_(nullptr), | ||||
| @@ -65,7 +65,9 @@ void TcpClient::Init() { | |||||
| if (buffer_event_) { | if (buffer_event_) { | ||||
| return; | return; | ||||
| } | } | ||||
| CommUtil::CheckIp(server_address_); | |||||
| if (!CommUtil::CheckIp(server_address_)) { | |||||
| MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; | |||||
| } | |||||
| event_base_ = event_base_new(); | event_base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| @@ -166,6 +168,23 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) { | |||||
| message_handler_.ReceiveMessage(buf, num); | message_handler_.ReceiveMessage(buf, num); | ||||
| } | } | ||||
| void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) { | |||||
| MS_EXCEPTION_IF_NULL(arg); | |||||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||||
| MessageMeta meta; | |||||
| meta.set_cmd(ClusterCommand::HEARTBEAT); | |||||
| CommMessage message; | |||||
| message.set_allocated_pb_meta(&meta); | |||||
| tcp_client->SendMessage(message); | |||||
| struct event *ev; | |||||
| struct timeval timeout {}; | |||||
| timeout.tv_sec = ClusterConfig::heartbeat_interval(); | |||||
| timeout.tv_usec = 0; | |||||
| ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg); | |||||
| evtimer_add(ev, &timeout); | |||||
| } | |||||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { | ||||
| MS_EXCEPTION_IF_NULL(bev); | MS_EXCEPTION_IF_NULL(bev); | ||||
| MS_EXCEPTION_IF_NULL(ptr); | MS_EXCEPTION_IF_NULL(ptr); | ||||
| @@ -226,6 +245,16 @@ void TcpClient::SendMessage(const CommMessage &message) const { | |||||
| } | } | ||||
| } | } | ||||
| } // namespace comm | |||||
| void TcpClient::SendMessageWithTimer() { | |||||
| MS_EXCEPTION_IF_NULL(buffer_event_); | |||||
| struct event *ev = nullptr; | |||||
| struct timeval timeout {}; | |||||
| timeout.tv_sec = 0; | |||||
| timeout.tv_usec = 0; | |||||
| ev = evtimer_new(event_base_, SendHeartBeatCallback, this); | |||||
| evtimer_add(ev, &timeout); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,10 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||||
| #include "ps/comm/tcp_message_handler.h" | |||||
| #include "ps/core/tcp_message_handler.h" | |||||
| #include <event2/event.h> | #include <event2/event.h> | ||||
| #include <event2/bufferevent.h> | #include <event2/bufferevent.h> | ||||
| @@ -27,10 +27,11 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "ps/core/cluster_config.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class TcpClient { | class TcpClient { | ||||
| public: | public: | ||||
| @@ -53,6 +54,7 @@ class TcpClient { | |||||
| void StartWithNoBlock(); | void StartWithNoBlock(); | ||||
| void SetMessageCallback(const OnMessage &cb); | void SetMessageCallback(const OnMessage &cb); | ||||
| void SendMessage(const CommMessage &message) const; | void SendMessage(const CommMessage &message) const; | ||||
| void SendMessageWithTimer(); | |||||
| protected: | protected: | ||||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | static void SetTcpNoDelay(const evutil_socket_t &fd); | ||||
| @@ -60,6 +62,7 @@ class TcpClient { | |||||
| static void ReadCallback(struct bufferevent *bev, void *ctx); | static void ReadCallback(struct bufferevent *bev, void *ctx); | ||||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | ||||
| virtual void OnReadHandler(const void *buf, size_t num); | virtual void OnReadHandler(const void *buf, size_t num); | ||||
| static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg); | |||||
| private: | private: | ||||
| OnMessage message_callback_; | OnMessage message_callback_; | ||||
| @@ -78,7 +81,7 @@ class TcpClient { | |||||
| std::uint16_t server_port_; | std::uint16_t server_port_; | ||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/tcp_message_handler.h" | |||||
| #include "ps/core/tcp_message_handler.h" | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <iostream> | #include <iostream> | ||||
| @@ -22,7 +22,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } | ||||
| @@ -37,16 +37,15 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| --num; | --num; | ||||
| if (header_index_ == 3) { | if (header_index_ == 3) { | ||||
| message_length_ = *reinterpret_cast<const uint32_t *>(header_); | message_length_ = *reinterpret_cast<const uint32_t *>(header_); | ||||
| message_length_ = ntohl(message_length_); | |||||
| remaining_length_ = message_length_; | remaining_length_ = message_length_; | ||||
| message_buffer_.reset(new unsigned char[remaining_length_]); | message_buffer_.reset(new unsigned char[remaining_length_]); | ||||
| buffer_data += i; | |||||
| buffer_data += (i + 1); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (remaining_length_ > 0) { | |||||
| if (remaining_length_ > 0 && num > 0) { | |||||
| uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; | ||||
| remaining_length_ -= copy_len; | remaining_length_ -= copy_len; | ||||
| num -= copy_len; | num -= copy_len; | ||||
| @@ -60,19 +59,19 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| if (remaining_length_ == 0) { | if (remaining_length_ == 0) { | ||||
| CommMessage pb_message; | CommMessage pb_message; | ||||
| pb_message.ParseFromArray(reinterpret_cast<const void *>(message_buffer_.get()), message_length_); | |||||
| pb_message.ParseFromArray(message_buffer_.get(), message_length_); | |||||
| if (message_callback_) { | if (message_callback_) { | ||||
| message_callback_(pb_message); | message_callback_(pb_message); | ||||
| } | } | ||||
| message_buffer_.reset(); | message_buffer_.reset(); | ||||
| message_buffer_ = nullptr; | message_buffer_ = nullptr; | ||||
| header_index_ = 0; | |||||
| header_index_ = -1; | |||||
| last_copy_len_ = 0; | last_copy_len_ = 0; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||||
| #include <functional> | #include <functional> | ||||
| #include <iostream> | #include <iostream> | ||||
| @@ -29,7 +29,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| using messageReceive = std::function<void(const CommMessage &message)>; | using messageReceive = std::function<void(const CommMessage &message)>; | ||||
| @@ -57,8 +57,8 @@ class TcpMessageHandler { | |||||
| int header_index_; | int header_index_; | ||||
| uint32_t last_copy_len_; | uint32_t last_copy_len_; | ||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/tcp_server.h" | |||||
| #include "ps/core/tcp_server.h" | |||||
| #include <arpa/inet.h> | #include <arpa/inet.h> | ||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| @@ -27,11 +27,11 @@ | |||||
| #include <csignal> | #include <csignal> | ||||
| #include <utility> | #include <utility> | ||||
| #include "ps/comm/comm_util.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| void TcpConnection::InitConnection() { | void TcpConnection::InitConnection() { | ||||
| tcp_message_handler_.SetCallback([&](const CommMessage &message) { | tcp_message_handler_.SetCallback([&](const CommMessage &message) { | ||||
| @@ -88,7 +88,9 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon | |||||
| void TcpServer::Init() { | void TcpServer::Init() { | ||||
| base_ = event_base_new(); | base_ = event_base_new(); | ||||
| MS_EXCEPTION_IF_NULL(base_); | MS_EXCEPTION_IF_NULL(base_); | ||||
| CommUtil::CheckIp(server_address_); | |||||
| if (!CommUtil::CheckIp(server_address_)) { | |||||
| MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!"; | |||||
| } | |||||
| struct sockaddr_in sin {}; | struct sockaddr_in sin {}; | ||||
| if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { | ||||
| @@ -104,6 +106,18 @@ void TcpServer::Init() { | |||||
| MS_EXCEPTION_IF_NULL(listener_); | MS_EXCEPTION_IF_NULL(listener_); | ||||
| if (server_port_ == 0) { | |||||
| struct sockaddr_in sin_bound {}; | |||||
| if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; | |||||
| } | |||||
| socklen_t addr_len = sizeof(struct sockaddr_in); | |||||
| if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) { | |||||
| MS_LOG(EXCEPTION) << "Get sock name failed!"; | |||||
| } | |||||
| server_port_ = htons(sin_bound.sin_port); | |||||
| } | |||||
| signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this)); | signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast<void *>(this)); | ||||
| MS_EXCEPTION_IF_NULL(signal_event_); | MS_EXCEPTION_IF_NULL(signal_event_); | ||||
| if (event_add(signal_event_, nullptr) < 0) { | if (event_add(signal_event_, nullptr) < 0) { | ||||
| @@ -173,11 +187,13 @@ void TcpServer::RemoveConnection(const evutil_socket_t &fd) { | |||||
| connections_.erase(fd); | connections_.erase(fd); | ||||
| } | } | ||||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) { | |||||
| void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, | |||||
| void *data) { | |||||
| auto server = reinterpret_cast<class TcpServer *>(data); | auto server = reinterpret_cast<class TcpServer *>(data); | ||||
| auto base = reinterpret_cast<struct event_base *>(server->base_); | auto base = reinterpret_cast<struct event_base *>(server->base_); | ||||
| MS_EXCEPTION_IF_NULL(server); | MS_EXCEPTION_IF_NULL(server); | ||||
| MS_EXCEPTION_IF_NULL(base); | MS_EXCEPTION_IF_NULL(base); | ||||
| MS_EXCEPTION_IF_NULL(sockaddr); | |||||
| struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); | ||||
| if (!bev) { | if (!bev) { | ||||
| @@ -279,8 +295,10 @@ void TcpServer::SendMessage(const CommMessage &message) { | |||||
| } | } | ||||
| } | } | ||||
| uint16_t TcpServer::BoundPort() const { return server_port_; } | |||||
| void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||||
| #define MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||||
| #include <event2/buffer.h> | #include <event2/buffer.h> | ||||
| #include <event2/bufferevent.h> | #include <event2/bufferevent.h> | ||||
| @@ -31,11 +31,11 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "ps/comm/tcp_message_handler.h" | |||||
| #include "ps/core/tcp_message_handler.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class TcpServer; | class TcpServer; | ||||
| class TcpConnection { | class TcpConnection { | ||||
| @@ -83,6 +83,7 @@ class TcpServer { | |||||
| void SetMessageCallback(const OnServerReceiveMessage &cb); | void SetMessageCallback(const OnServerReceiveMessage &cb); | ||||
| static void SendMessage(const TcpConnection &conn, const CommMessage &message); | static void SendMessage(const TcpConnection &conn, const CommMessage &message); | ||||
| void SendMessage(const CommMessage &message); | void SendMessage(const CommMessage &message); | ||||
| uint16_t BoundPort() const; | |||||
| protected: | protected: | ||||
| static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, | ||||
| @@ -106,7 +107,7 @@ class TcpServer { | |||||
| OnServerReceiveMessage message_callback_; | OnServerReceiveMessage message_callback_; | ||||
| }; | }; | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/common_test.h" | |||||
| #include "ps/core/cluster_config.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestClusterConfig : public UT::Common { | |||||
| public: | |||||
| TestClusterConfig() = default; | |||||
| virtual ~TestClusterConfig() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestClusterConfig, HeartbeatInterval) { | |||||
| ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080); | |||||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); | |||||
| ClusterConfig::set_heartbeat_interval(100); | |||||
| EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); | |||||
| EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1"); | |||||
| EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080); | |||||
| } | |||||
| } // namespace core | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #include "ps/core/comm_util.h" | |||||
| #include <memory> | |||||
| #include <thread> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestCommUtil : public UT::Common { | |||||
| public: | |||||
| TestCommUtil() = default; | |||||
| virtual ~TestCommUtil() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { | |||||
| std::string interface; | |||||
| std::string ip; | |||||
| CommUtil::GetAvailableInterfaceAndIP(&interface, &ip); | |||||
| EXPECT_TRUE(!interface.empty()); | |||||
| EXPECT_TRUE(!ip.empty()); | |||||
| } | |||||
| } // namespace comm | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/http_server.h" | |||||
| #include "ps/core/http_server.h" | |||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| @@ -28,7 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class TestHttpServer : public UT::Common { | class TestHttpServer : public UT::Common { | ||||
| public: | public: | ||||
| @@ -17,11 +17,11 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "ps/comm/tcp_client.h" | |||||
| #include "ps/core/tcp_client.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class TestTcpClient : public UT::Common { | class TestTcpClient : public UT::Common { | ||||
| public: | public: | ||||
| TestTcpClient() = default; | TestTcpClient() = default; | ||||
| @@ -0,0 +1,163 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/core/tcp_message_handler.h" | |||||
| #include "common/common_test.h" | |||||
| #include <memory> | |||||
| #include <thread> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace core { | |||||
| class TestTcpMessageHandler : public UT::Common { | |||||
| public: | |||||
| using messageReceive = std::function<void(const CommMessage &message)>; | |||||
| TestTcpMessageHandler() = default; | |||||
| virtual ~TestTcpMessageHandler() = default; | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) { | |||||
| TcpMessageHandler handler; | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||||
| std::string data(1000, 'a'); | |||||
| CommMessage message; | |||||
| message.set_data(data); | |||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[1007]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| std::vector<char> serialized(buf_size); | |||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| handler.ReceiveMessage(result, buf_size + 4); | |||||
| } | |||||
| TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) { | |||||
| TcpMessageHandler handler; | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); | |||||
| std::string data(1000, 'a'); | |||||
| CommMessage message; | |||||
| message.set_data(data); | |||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[2014]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| std::vector<char> serialized(buf_size); | |||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| handler.ReceiveMessage(result, 2 * buf_size + 4 * 2); | |||||
| } | |||||
| TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) { | |||||
| TcpMessageHandler handler; | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); }); | |||||
| std::string data(4087, 'a'); | |||||
| CommMessage message; | |||||
| message.set_data(data); | |||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| std::vector<char> serialized(buf_size); | |||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| handler.ReceiveMessage(result, 4096); | |||||
| ret = memcpy_s(result, 2, &buf_size + 2, 2); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| handler.ReceiveMessage(result, 4092); | |||||
| } | |||||
| TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { | |||||
| TcpMessageHandler handler; | |||||
| handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); }); | |||||
| std::string data(4085, 'a'); | |||||
| CommMessage message; | |||||
| message.set_data(data); | |||||
| uint32_t buf_size = message.ByteSizeLong(); | |||||
| char result[4096]; | |||||
| int ret = memcpy_s(result, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| std::vector<char> serialized(buf_size); | |||||
| message.SerializeToArray(serialized.data(), static_cast<int>(buf_size)); | |||||
| ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| handler.ReceiveMessage(result, 4096); | |||||
| ret = memcpy_s(result, buf_size, serialized.data(), buf_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||||
| } | |||||
| handler.ReceiveMessage(result, 4088); | |||||
| } | |||||
| } // namespace comm | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/comm/tcp_client.h" | |||||
| #include "ps/comm/tcp_server.h" | |||||
| #include "ps/core/tcp_client.h" | |||||
| #include "ps/core/tcp_server.h" | |||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include <memory> | #include <memory> | ||||
| @@ -23,14 +23,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace comm { | |||||
| namespace core { | |||||
| class TestTcpServer : public UT::Common { | class TestTcpServer : public UT::Common { | ||||
| public: | public: | ||||
| TestTcpServer() : client_(nullptr), server_(nullptr) {} | TestTcpServer() : client_(nullptr), server_(nullptr) {} | ||||
| virtual ~TestTcpServer() = default; | virtual ~TestTcpServer() = default; | ||||
| void SetUp() override { | void SetUp() override { | ||||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 9998); | |||||
| server_ = std::make_unique<TcpServer>("127.0.0.1", 0); | |||||
| std::unique_ptr<std::thread> http_server_thread_(nullptr); | std::unique_ptr<std::thread> http_server_thread_(nullptr); | ||||
| http_server_thread_ = std::make_unique<std::thread>([&]() { | http_server_thread_ = std::make_unique<std::thread>([&]() { | ||||
| server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { | ||||
| @@ -57,7 +57,7 @@ class TestTcpServer : public UT::Common { | |||||
| }; | }; | ||||
| TEST_F(TestTcpServer, ServerSendMessage) { | TEST_F(TestTcpServer, ServerSendMessage) { | ||||
| client_ = std::make_unique<TcpClient>("127.0.0.1", 9998); | |||||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort()); | |||||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | std::unique_ptr<std::thread> http_client_thread(nullptr); | ||||
| http_client_thread = std::make_unique<std::thread>([&]() { | http_client_thread = std::make_unique<std::thread>([&]() { | ||||
| client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { | ||||
| @@ -82,6 +82,6 @@ TEST_F(TestTcpServer, ServerSendMessage) { | |||||
| }); | }); | ||||
| http_client_thread->detach(); | http_client_thread->detach(); | ||||
| } | } | ||||
| } // namespace comm | |||||
| } // namespace core | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||