From 96d8c411e715cfe8e04eeac3988e06e66c04f389 Mon Sep 17 00:00:00 2001 From: anancds Date: Sun, 8 Nov 2020 19:47:28 +0800 Subject: [PATCH] added message handler unit test --- mindspore/ccsrc/CMakeLists.txt | 4 +- mindspore/ccsrc/ps/CMakeLists.txt | 13 +- mindspore/ccsrc/ps/core/cluster_config.cc | 58 +++++++ mindspore/ccsrc/ps/core/cluster_config.h | 56 ++++++ .../ccsrc/ps/{comm => core}/comm_util.cc | 44 ++++- mindspore/ccsrc/ps/{comm => core}/comm_util.h | 26 ++- .../ps/{comm => core}/http_message_handler.cc | 6 +- .../ps/{comm => core}/http_message_handler.h | 10 +- .../ccsrc/ps/{comm => core}/http_server.cc | 14 +- .../ccsrc/ps/{comm => core}/http_server.h | 12 +- .../ccsrc/ps/{comm => core}/protos/comm.proto | 17 +- .../ccsrc/ps/{comm => core}/protos/ps.proto | 4 + .../ccsrc/ps/{comm => core}/tcp_client.cc | 39 ++++- .../ccsrc/ps/{comm => core}/tcp_client.h | 15 +- .../ps/{comm => core}/tcp_message_handler.cc | 15 +- .../ps/{comm => core}/tcp_message_handler.h | 10 +- .../ccsrc/ps/{comm => core}/tcp_server.cc | 30 +++- .../ccsrc/ps/{comm => core}/tcp_server.h | 13 +- tests/ut/cpp/ps/core/cluster_config_test.cc | 46 +++++ tests/ut/cpp/ps/core/common_util_test.cc | 44 +++++ .../cpp/ps/{comm => core}/http_server_test.cc | 4 +- .../cpp/ps/{comm => core}/tcp_client_tests.cc | 4 +- .../cpp/ps/core/tcp_message_handler_test.cc | 163 ++++++++++++++++++ .../ps/{comm => core}/tcp_pb_server_test.cc | 12 +- 24 files changed, 572 insertions(+), 87 deletions(-) create mode 100644 mindspore/ccsrc/ps/core/cluster_config.cc create mode 100644 mindspore/ccsrc/ps/core/cluster_config.h rename mindspore/ccsrc/ps/{comm => core}/comm_util.cc (51%) rename mindspore/ccsrc/ps/{comm => core}/comm_util.h (67%) rename mindspore/ccsrc/ps/{comm => core}/http_message_handler.cc (98%) rename mindspore/ccsrc/ps/{comm => core}/http_message_handler.h (93%) rename mindspore/ccsrc/ps/{comm => core}/http_server.cc (94%) rename mindspore/ccsrc/ps/{comm => core}/http_server.h (91%) rename mindspore/ccsrc/ps/{comm => core}/protos/comm.proto (84%) rename mindspore/ccsrc/ps/{comm => core}/protos/ps.proto (90%) rename mindspore/ccsrc/ps/{comm => core}/tcp_client.cc (87%) rename mindspore/ccsrc/ps/{comm => core}/tcp_client.h (86%) rename mindspore/ccsrc/ps/{comm => core}/tcp_message_handler.cc (86%) rename mindspore/ccsrc/ps/{comm => core}/tcp_message_handler.h (88%) rename mindspore/ccsrc/ps/{comm => core}/tcp_server.cc (91%) rename mindspore/ccsrc/ps/{comm => core}/tcp_server.h (93%) create mode 100644 tests/ut/cpp/ps/core/cluster_config_test.cc create mode 100644 tests/ut/cpp/ps/core/common_util_test.cc rename tests/ut/cpp/ps/{comm => core}/http_server_test.cc (99%) rename tests/ut/cpp/ps/{comm => core}/tcp_client_tests.cc (94%) create mode 100644 tests/ut/cpp/ps/core/tcp_message_handler_test.cc rename tests/ut/cpp/ps/{comm => core}/tcp_pb_server_test.cc (90%) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 38c325a127..e46e47e633 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -100,8 +100,8 @@ message("onnx proto path is :" ${ONNX_PROTO}) ms_protobuf_generate(ONNX_PROTO_SRCS ONNX_PROTO_HDRS ${ONNX_PROTO}) list(APPEND MINDSPORE_PROTO_LIST ${ONNX_PROTO_SRCS}) -include_directories("${CMAKE_BINARY_DIR}/ps/comm") -file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/comm/protos/*.proto") +include_directories("${CMAKE_BINARY_DIR}/ps/core") +file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps/core/protos/*.proto") ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) list(APPEND MINDSPORE_PROTO_LIST ${COMM_PROTO_SRCS}) diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index e8e412734f..658546465a 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -5,12 +5,13 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") list(REMOVE_ITEM _PS_SRC_FILES "util.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/http_message_handler.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/http_server.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/comm_util.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_client.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_message_handler.cc") - list(REMOVE_ITEM _PS_SRC_FILES "comm/tcp_server.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/http_message_handler.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/http_server.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/comm_util.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_client.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_message_handler.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/tcp_server.cc") + list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") endif() set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) diff --git a/mindspore/ccsrc/ps/core/cluster_config.cc b/mindspore/ccsrc/ps/core/cluster_config.cc new file mode 100644 index 0000000000..0b8a00c89a --- /dev/null +++ b/mindspore/ccsrc/ps/core/cluster_config.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/cluster_config.h" + +#include + +namespace mindspore { +namespace ps { +namespace core { + +uint32_t ClusterConfig::worker_num_ = 0; +uint32_t ClusterConfig::server_num_ = 0; +uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval; +std::unique_ptr ClusterConfig::scheduler_host_ = nullptr; +uint16_t ClusterConfig::scheduler_port_ = 0; + +void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num, + std::unique_ptr scheduler_host, const uint16_t &scheduler_port) { + worker_num_ = worker_num; + server_num_ = server_num; + if (!CommUtil::CheckIp(*scheduler_host.get())) { + MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!"; + } + scheduler_host_ = std::move(scheduler_host); + scheduler_port_ = scheduler_port; +} + +uint32_t ClusterConfig::worker_num() { return worker_num_; } + +uint32_t ClusterConfig::server_num() { return server_num_; } + +uint32_t ClusterConfig::heartbeat_interval() { return heartbeat_interval_; } + +void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) { + heartbeat_interval_ = heartbeat_interval; +} + +std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); } + +uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; } + +} // namespace core +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h new file mode 100644 index 0000000000..8a8ace7fb4 --- /dev/null +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ +#define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ + +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#include "ps/core/comm_util.h" + +namespace mindspore { +namespace ps { +namespace core { + +constexpr uint32_t kHeartbeatInterval = 3; + +class ClusterConfig { + public: + static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr scheduler_host, + const uint16_t &scheduler_port); + static uint32_t worker_num(); + static uint32_t server_num(); + static uint32_t heartbeat_interval(); + static void set_heartbeat_interval(const uint32_t &heartbeat_interval); + static std::string scheduler_host(); + static uint16_t scheduler_port(); + + private: + static uint32_t worker_num_; + static uint32_t server_num_; + static uint32_t heartbeat_interval_; + static std::unique_ptr scheduler_host_; + static uint16_t scheduler_port_; +}; +} // namespace core +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_ diff --git a/mindspore/ccsrc/ps/comm/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc similarity index 51% rename from mindspore/ccsrc/ps/comm/comm_util.cc rename to mindspore/ccsrc/ps/core/comm_util.cc index 1b3be35edc..9d5be87ee0 100644 --- a/mindspore/ccsrc/ps/comm/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/comm_util.h" +#include "ps/core/comm_util.h" #include #include @@ -25,7 +25,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { bool CommUtil::CheckIpWithRegex(const std::string &ip) { std::regex pattern("((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?).){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"); @@ -36,15 +36,47 @@ bool CommUtil::CheckIpWithRegex(const std::string &ip) { return false; } -void CommUtil::CheckIp(const std::string &ip) { +bool CommUtil::CheckIp(const std::string &ip) { if (!CheckIpWithRegex(ip)) { - MS_LOG(EXCEPTION) << "Server address" << ip << " illegal!"; + return false; } int64_t uAddr = inet_addr(ip.c_str()); if (INADDR_NONE == uAddr) { - MS_LOG(EXCEPTION) << "Server address illegal, inet_addr converting failed!"; + return false; } + return true; } -} // namespace comm + +void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *ip) { + MS_EXCEPTION_IF_NULL(interface); + MS_EXCEPTION_IF_NULL(ip); + struct ifaddrs *if_address = nullptr; + struct ifaddrs *ifa = nullptr; + + interface->clear(); + ip->clear(); + getifaddrs(&if_address); + for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) { + continue; + } + + if (ifa->ifa_addr->sa_family == AF_INET && (ifa->ifa_flags & IFF_LOOPBACK) == 0) { + char address_buffer[INET_ADDRSTRLEN] = {0}; + void *sin_addr_ptr = &(reinterpret_cast(ifa->ifa_addr))->sin_addr; + MS_EXCEPTION_IF_NULL(sin_addr_ptr); + const char *net_ptr = inet_ntop(AF_INET, sin_addr_ptr, address_buffer, INET_ADDRSTRLEN); + MS_EXCEPTION_IF_NULL(net_ptr); + + *ip = address_buffer; + *interface = ifa->ifa_name; + break; + } + } + MS_EXCEPTION_IF_NULL(if_address); + freeifaddrs(if_address); +} + +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h similarity index 67% rename from mindspore/ccsrc/ps/comm/comm_util.h rename to mindspore/ccsrc/ps/core/comm_util.h index 46455671e4..3200a62415 100644 --- a/mindspore/ccsrc/ps/comm/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -14,8 +14,21 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ -#define MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ +#define MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ + +#include +#ifdef _MSC_VER +#include +#include +#include +#include +#else +#include +#include +#include +#include +#endif #include #include @@ -35,15 +48,16 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { class CommUtil { public: static bool CheckIpWithRegex(const std::string &ip); - static void CheckIp(const std::string &ip); + static bool CheckIp(const std::string &ip); + static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_COMM_UTIL_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_COMM_UTIL_H_ diff --git a/mindspore/ccsrc/ps/comm/http_message_handler.cc b/mindspore/ccsrc/ps/core/http_message_handler.cc similarity index 98% rename from mindspore/ccsrc/ps/comm/http_message_handler.cc rename to mindspore/ccsrc/ps/core/http_message_handler.cc index 9b226de319..93e43ccf0c 100644 --- a/mindspore/ccsrc/ps/comm/http_message_handler.cc +++ b/mindspore/ccsrc/ps/core/http_message_handler.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/http_message_handler.h" +#include "ps/core/http_message_handler.h" #include #include @@ -36,7 +36,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { void HttpMessageHandler::InitHttpMessage() { MS_EXCEPTION_IF_NULL(event_request_); @@ -202,6 +202,6 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { evhttp_send_error(event_request_, nCode, message.c_str()); } } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/http_message_handler.h b/mindspore/ccsrc/ps/core/http_message_handler.h similarity index 93% rename from mindspore/ccsrc/ps/comm/http_message_handler.h rename to mindspore/ccsrc/ps/core/http_message_handler.h index 2de7083c47..72f0322c97 100644 --- a/mindspore/ccsrc/ps/comm/http_message_handler.h +++ b/mindspore/ccsrc/ps/core/http_message_handler.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ -#define MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ +#define MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ #include #include @@ -36,7 +36,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { using HttpHeaders = std::map>; @@ -101,7 +101,7 @@ class HttpMessageHandler { void ParsePostParam(); }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_MESSAGE_HANDLER_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_HTTP_MESSAGE_HANDLER_H_ diff --git a/mindspore/ccsrc/ps/comm/http_server.cc b/mindspore/ccsrc/ps/core/http_server.cc similarity index 94% rename from mindspore/ccsrc/ps/comm/http_server.cc rename to mindspore/ccsrc/ps/core/http_server.cc index 061b84ae9d..548e6ec1c6 100644 --- a/mindspore/ccsrc/ps/comm/http_server.cc +++ b/mindspore/ccsrc/ps/core/http_server.cc @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "ps/comm/http_server.h" -#include "ps/comm/http_message_handler.h" -#include "ps/comm/comm_util.h" +#include "ps/core/http_server.h" +#include "ps/core/http_message_handler.h" +#include "ps/core/comm_util.h" #ifdef WIN32 #include @@ -40,12 +40,14 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { HttpServer::~HttpServer() { Stop(); } bool HttpServer::InitServer() { - CommUtil::CheckIp(server_address_); + if (!CommUtil::CheckIp(server_address_)) { + MS_LOG(EXCEPTION) << "The http server ip:" << server_address_ << " is illegal!"; + } event_base_ = event_base_new(); MS_EXCEPTION_IF_NULL(event_base_); @@ -154,6 +156,6 @@ void HttpServer::Stop() { } } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/http_server.h b/mindspore/ccsrc/ps/core/http_server.h similarity index 91% rename from mindspore/ccsrc/ps/comm/http_server.h rename to mindspore/ccsrc/ps/core/http_server.h index e8ba957866..acea23db65 100644 --- a/mindspore/ccsrc/ps/comm/http_server.h +++ b/mindspore/ccsrc/ps/core/http_server.h @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ -#define MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ +#define MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ -#include "ps/comm/http_message_handler.h" +#include "ps/core/http_message_handler.h" #include #include @@ -35,7 +35,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { typedef enum eHttpMethod { HM_GET = 1 << 0, @@ -86,8 +86,8 @@ class HttpServer { bool is_init_; }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_HTTP_SERVER_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_HTTP_SERVER_H_ diff --git a/mindspore/ccsrc/ps/comm/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto similarity index 84% rename from mindspore/ccsrc/ps/comm/protos/comm.proto rename to mindspore/ccsrc/ps/core/protos/comm.proto index 653af8edfe..9862ab998b 100644 --- a/mindspore/ccsrc/ps/comm/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -16,9 +16,24 @@ syntax = "proto3"; import "google/protobuf/any.proto"; -package mindspore.ps; +package mindspore.ps.core; option optimize_for = LITE_RUNTIME; +enum ClusterCommand { + TERMINATE = 0; + REGISTER = 1; + ACK = 2; + HEARTBEAT = 3; + FETCH_WORKERS = 4; + FETCH_SERVERS = 5; +} + +enum Role { + SERVER = 0; + WORKER = 1; + SCHEDULER = 2; +} + message MessageMeta { // hostname or ip string hostname = 1; diff --git a/mindspore/ccsrc/ps/comm/protos/ps.proto b/mindspore/ccsrc/ps/core/protos/ps.proto similarity index 90% rename from mindspore/ccsrc/ps/comm/protos/ps.proto rename to mindspore/ccsrc/ps/core/protos/ps.proto index 9cee1712bf..cd5835ed14 100644 --- a/mindspore/ccsrc/ps/comm/protos/ps.proto +++ b/mindspore/ccsrc/ps/core/protos/ps.proto @@ -14,6 +14,10 @@ * limitations under the License. */ +syntax = "proto3"; +package mindspore.ps.core; +option optimize_for = LITE_RUNTIME; + message KVMessage { repeated int32 keys = 1; repeated float values = 2; diff --git a/mindspore/ccsrc/ps/comm/tcp_client.cc b/mindspore/ccsrc/ps/core/tcp_client.cc similarity index 87% rename from mindspore/ccsrc/ps/comm/tcp_client.cc rename to mindspore/ccsrc/ps/core/tcp_client.cc index b55aa18af2..ad8db6aed4 100644 --- a/mindspore/ccsrc/ps/comm/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/tcp_client.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/tcp_client.h" +#include "ps/core/tcp_client.h" #include #include @@ -30,11 +30,11 @@ #include #include -#include "ps/comm/comm_util.h" +#include "ps/core/comm_util.h" namespace mindspore { namespace ps { -namespace comm { +namespace core { TcpClient::TcpClient(const std::string &address, std::uint16_t port) : event_base_(nullptr), @@ -65,7 +65,9 @@ void TcpClient::Init() { if (buffer_event_) { return; } - CommUtil::CheckIp(server_address_); + if (!CommUtil::CheckIp(server_address_)) { + MS_LOG(EXCEPTION) << "The tcp client ip:" << server_address_ << " is illegal!"; + } event_base_ = event_base_new(); MS_EXCEPTION_IF_NULL(event_base_); @@ -166,6 +168,23 @@ void TcpClient::OnReadHandler(const void *buf, size_t num) { message_handler_.ReceiveMessage(buf, num); } +void TcpClient::SendHeartBeatCallback(evutil_socket_t, int16_t, void *arg) { + MS_EXCEPTION_IF_NULL(arg); + auto tcp_client = reinterpret_cast(arg); + MessageMeta meta; + meta.set_cmd(ClusterCommand::HEARTBEAT); + CommMessage message; + message.set_allocated_pb_meta(&meta); + tcp_client->SendMessage(message); + + struct event *ev; + struct timeval timeout {}; + timeout.tv_sec = ClusterConfig::heartbeat_interval(); + timeout.tv_usec = 0; + ev = evtimer_new(tcp_client->event_base_, SendHeartBeatCallback, arg); + evtimer_add(ev, &timeout); +} + void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(ptr); @@ -226,6 +245,16 @@ void TcpClient::SendMessage(const CommMessage &message) const { } } -} // namespace comm +void TcpClient::SendMessageWithTimer() { + MS_EXCEPTION_IF_NULL(buffer_event_); + struct event *ev = nullptr; + struct timeval timeout {}; + timeout.tv_sec = 0; + timeout.tv_usec = 0; + ev = evtimer_new(event_base_, SendHeartBeatCallback, this); + evtimer_add(ev, &timeout); +} + +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_client.h b/mindspore/ccsrc/ps/core/tcp_client.h similarity index 86% rename from mindspore/ccsrc/ps/comm/tcp_client.h rename to mindspore/ccsrc/ps/core/tcp_client.h index 2108e1db85..734e9cdbdb 100644 --- a/mindspore/ccsrc/ps/comm/tcp_client.h +++ b/mindspore/ccsrc/ps/core/tcp_client.h @@ -14,10 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ -#define MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ +#define MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ -#include "ps/comm/tcp_message_handler.h" +#include "ps/core/tcp_message_handler.h" #include #include @@ -27,10 +27,11 @@ #include #include "proto/comm.pb.h" +#include "ps/core/cluster_config.h" namespace mindspore { namespace ps { -namespace comm { +namespace core { class TcpClient { public: @@ -53,6 +54,7 @@ class TcpClient { void StartWithNoBlock(); void SetMessageCallback(const OnMessage &cb); void SendMessage(const CommMessage &message) const; + void SendMessageWithTimer(); protected: static void SetTcpNoDelay(const evutil_socket_t &fd); @@ -60,6 +62,7 @@ class TcpClient { static void ReadCallback(struct bufferevent *bev, void *ctx); static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); virtual void OnReadHandler(const void *buf, size_t num); + static void SendHeartBeatCallback(evutil_socket_t fd, int16_t event, void *arg); private: OnMessage message_callback_; @@ -78,7 +81,7 @@ class TcpClient { std::uint16_t server_port_; }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_TCP_CLIENT_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_TCP_CLIENT_H_ diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.cc b/mindspore/ccsrc/ps/core/tcp_message_handler.cc similarity index 86% rename from mindspore/ccsrc/ps/comm/tcp_message_handler.cc rename to mindspore/ccsrc/ps/core/tcp_message_handler.cc index 97285bdcc9..a98b1352a2 100644 --- a/mindspore/ccsrc/ps/comm/tcp_message_handler.cc +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/tcp_message_handler.h" +#include "ps/core/tcp_message_handler.h" #include #include @@ -22,7 +22,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { message_callback_ = message_receive; } @@ -37,16 +37,15 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { --num; if (header_index_ == 3) { message_length_ = *reinterpret_cast(header_); - message_length_ = ntohl(message_length_); remaining_length_ = message_length_; message_buffer_.reset(new unsigned char[remaining_length_]); - buffer_data += i; + buffer_data += (i + 1); break; } } } - if (remaining_length_ > 0) { + if (remaining_length_ > 0 && num > 0) { uint32_t copy_len = remaining_length_ <= num ? remaining_length_ : num; remaining_length_ -= copy_len; num -= copy_len; @@ -60,19 +59,19 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { if (remaining_length_ == 0) { CommMessage pb_message; - pb_message.ParseFromArray(reinterpret_cast(message_buffer_.get()), message_length_); + pb_message.ParseFromArray(message_buffer_.get(), message_length_); if (message_callback_) { message_callback_(pb_message); } message_buffer_.reset(); message_buffer_ = nullptr; - header_index_ = 0; + header_index_ = -1; last_copy_len_ = 0; } } } } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_message_handler.h b/mindspore/ccsrc/ps/core/tcp_message_handler.h similarity index 88% rename from mindspore/ccsrc/ps/comm/tcp_message_handler.h rename to mindspore/ccsrc/ps/core/tcp_message_handler.h index 58686c781e..c13db9d703 100644 --- a/mindspore/ccsrc/ps/comm/tcp_message_handler.h +++ b/mindspore/ccsrc/ps/core/tcp_message_handler.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ -#define MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ +#define MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ #include #include @@ -29,7 +29,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { using messageReceive = std::function; @@ -57,8 +57,8 @@ class TcpMessageHandler { int header_index_; uint32_t last_copy_len_; }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_TCP_MESSAGE_HANDLER_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_TCP_MESSAGE_HANDLER_H_ diff --git a/mindspore/ccsrc/ps/comm/tcp_server.cc b/mindspore/ccsrc/ps/core/tcp_server.cc similarity index 91% rename from mindspore/ccsrc/ps/comm/tcp_server.cc rename to mindspore/ccsrc/ps/core/tcp_server.cc index 4cf70f3b2d..9a5c1b5987 100644 --- a/mindspore/ccsrc/ps/comm/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/tcp_server.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/tcp_server.h" +#include "ps/core/tcp_server.h" #include #include @@ -27,11 +27,11 @@ #include #include -#include "ps/comm/comm_util.h" +#include "ps/core/comm_util.h" namespace mindspore { namespace ps { -namespace comm { +namespace core { void TcpConnection::InitConnection() { tcp_message_handler_.SetCallback([&](const CommMessage &message) { @@ -88,7 +88,9 @@ void TcpServer::SetServerCallback(const OnConnected &client_conn, const OnDiscon void TcpServer::Init() { base_ = event_base_new(); MS_EXCEPTION_IF_NULL(base_); - CommUtil::CheckIp(server_address_); + if (!CommUtil::CheckIp(server_address_)) { + MS_LOG(EXCEPTION) << "The tcp server ip:" << server_address_ << " is illegal!"; + } struct sockaddr_in sin {}; if (memset_s(&sin, sizeof(sin), 0, sizeof(sin)) != EOK) { @@ -104,6 +106,18 @@ void TcpServer::Init() { MS_EXCEPTION_IF_NULL(listener_); + if (server_port_ == 0) { + struct sockaddr_in sin_bound {}; + if (memset_s(&sin, sizeof(sin_bound), 0, sizeof(sin_bound)) != EOK) { + MS_LOG(EXCEPTION) << "Initialize sockaddr_in failed!"; + } + socklen_t addr_len = sizeof(struct sockaddr_in); + if (getsockname(evconnlistener_get_fd(listener_), (struct sockaddr *)&sin_bound, &addr_len) != 0) { + MS_LOG(EXCEPTION) << "Get sock name failed!"; + } + server_port_ = htons(sin_bound.sin_port); + } + signal_event_ = evsignal_new(base_, SIGINT, SignalCallback, reinterpret_cast(this)); MS_EXCEPTION_IF_NULL(signal_event_); if (event_add(signal_event_, nullptr) < 0) { @@ -173,11 +187,13 @@ void TcpServer::RemoveConnection(const evutil_socket_t &fd) { connections_.erase(fd); } -void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *, int, void *data) { +void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, struct sockaddr *sockaddr, int, + void *data) { auto server = reinterpret_cast(data); auto base = reinterpret_cast(server->base_); MS_EXCEPTION_IF_NULL(server); MS_EXCEPTION_IF_NULL(base); + MS_EXCEPTION_IF_NULL(sockaddr); struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE); if (!bev) { @@ -279,8 +295,10 @@ void TcpServer::SendMessage(const CommMessage &message) { } } +uint16_t TcpServer::BoundPort() const { return server_port_; } + void TcpServer::SetMessageCallback(const OnServerReceiveMessage &cb) { message_callback_ = cb; } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/comm/tcp_server.h b/mindspore/ccsrc/ps/core/tcp_server.h similarity index 93% rename from mindspore/ccsrc/ps/comm/tcp_server.h rename to mindspore/ccsrc/ps/core/tcp_server.h index f88cc954ab..ba554e29a2 100644 --- a/mindspore/ccsrc/ps/comm/tcp_server.h +++ b/mindspore/ccsrc/ps/core/tcp_server.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ -#define MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ +#define MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ #include #include @@ -31,11 +31,11 @@ #include #include "utils/log_adapter.h" -#include "ps/comm/tcp_message_handler.h" +#include "ps/core/tcp_message_handler.h" namespace mindspore { namespace ps { -namespace comm { +namespace core { class TcpServer; class TcpConnection { @@ -83,6 +83,7 @@ class TcpServer { void SetMessageCallback(const OnServerReceiveMessage &cb); static void SendMessage(const TcpConnection &conn, const CommMessage &message); void SendMessage(const CommMessage &message); + uint16_t BoundPort() const; protected: static void ListenerCallback(struct evconnlistener *listener, evutil_socket_t socket, struct sockaddr *saddr, @@ -106,7 +107,7 @@ class TcpServer { OnServerReceiveMessage message_callback_; }; -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_COMM_TCP_SERVER_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_TCP_SERVER_H_ diff --git a/tests/ut/cpp/ps/core/cluster_config_test.cc b/tests/ut/cpp/ps/core/cluster_config_test.cc new file mode 100644 index 0000000000..2fafa70357 --- /dev/null +++ b/tests/ut/cpp/ps/core/cluster_config_test.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/common_test.h" +#include "ps/core/cluster_config.h" + +namespace mindspore { +namespace ps { +namespace core { +class TestClusterConfig : public UT::Common { + public: + TestClusterConfig() = default; + virtual ~TestClusterConfig() = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestClusterConfig, HeartbeatInterval) { + ClusterConfig::Init(2, 2, std::make_unique("127.0.0.1"), 8080); + EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3); + ClusterConfig::set_heartbeat_interval(100); + EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100); + EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1"); + EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080); +} + +} // namespace core +} // namespace ps +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/core/common_util_test.cc b/tests/ut/cpp/ps/core/common_util_test.cc new file mode 100644 index 0000000000..4b58469248 --- /dev/null +++ b/tests/ut/cpp/ps/core/common_util_test.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/common_test.h" +#include "ps/core/comm_util.h" + +#include +#include + +namespace mindspore { +namespace ps { +namespace core { +class TestCommUtil : public UT::Common { + public: + TestCommUtil() = default; + virtual ~TestCommUtil() = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) { + std::string interface; + std::string ip; + CommUtil::GetAvailableInterfaceAndIP(&interface, &ip); + EXPECT_TRUE(!interface.empty()); + EXPECT_TRUE(!ip.empty()); +} +} // namespace comm +} // namespace ps +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/comm/http_server_test.cc b/tests/ut/cpp/ps/core/http_server_test.cc similarity index 99% rename from tests/ut/cpp/ps/comm/http_server_test.cc rename to tests/ut/cpp/ps/core/http_server_test.cc index 7c2f5dc6bc..9646b235b7 100644 --- a/tests/ut/cpp/ps/comm/http_server_test.cc +++ b/tests/ut/cpp/ps/core/http_server_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ps/comm/http_server.h" +#include "ps/core/http_server.h" #include "common/common_test.h" #include #include @@ -28,7 +28,7 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { class TestHttpServer : public UT::Common { public: diff --git a/tests/ut/cpp/ps/comm/tcp_client_tests.cc b/tests/ut/cpp/ps/core/tcp_client_tests.cc similarity index 94% rename from tests/ut/cpp/ps/comm/tcp_client_tests.cc rename to tests/ut/cpp/ps/core/tcp_client_tests.cc index a8b2b2ef3a..badfd6b287 100644 --- a/tests/ut/cpp/ps/comm/tcp_client_tests.cc +++ b/tests/ut/cpp/ps/core/tcp_client_tests.cc @@ -17,11 +17,11 @@ #include #include "common/common_test.h" -#include "ps/comm/tcp_client.h" +#include "ps/core/tcp_client.h" namespace mindspore { namespace ps { -namespace comm { +namespace core { class TestTcpClient : public UT::Common { public: TestTcpClient() = default; diff --git a/tests/ut/cpp/ps/core/tcp_message_handler_test.cc b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc new file mode 100644 index 0000000000..65bc90ae73 --- /dev/null +++ b/tests/ut/cpp/ps/core/tcp_message_handler_test.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/core/tcp_message_handler.h" +#include "common/common_test.h" + +#include +#include + +namespace mindspore { +namespace ps { +namespace core { +class TestTcpMessageHandler : public UT::Common { + public: + using messageReceive = std::function; + TestTcpMessageHandler() = default; + virtual ~TestTcpMessageHandler() = default; + + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) { + TcpMessageHandler handler; + handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); + + std::string data(1000, 'a'); + CommMessage message; + message.set_data(data); + uint32_t buf_size = message.ByteSizeLong(); + char result[1007]; + int ret = memcpy_s(result, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + memcpy_s(result + 4, buf_size, serialized.data(), buf_size); + handler.ReceiveMessage(result, buf_size + 4); +} + +TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) { + TcpMessageHandler handler; + handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); }); + + std::string data(1000, 'a'); + CommMessage message; + message.set_data(data); + uint32_t buf_size = message.ByteSizeLong(); + char result[2014]; + int ret = memcpy_s(result, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + handler.ReceiveMessage(result, 2 * buf_size + 4 * 2); +} + +TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) { + TcpMessageHandler handler; + handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); }); + + std::string data(4087, 'a'); + CommMessage message; + message.set_data(data); + uint32_t buf_size = message.ByteSizeLong(); + char result[4096]; + int ret = memcpy_s(result, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + handler.ReceiveMessage(result, 4096); + + ret = memcpy_s(result, 2, &buf_size + 2, 2); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + handler.ReceiveMessage(result, 4092); +} + +TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) { + TcpMessageHandler handler; + handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); }); + + std::string data(4085, 'a'); + CommMessage message; + message.set_data(data); + uint32_t buf_size = message.ByteSizeLong(); + char result[4096]; + int ret = memcpy_s(result, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + std::vector serialized(buf_size); + message.SerializeToArray(serialized.data(), static_cast(buf_size)); + ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + handler.ReceiveMessage(result, 4096); + + ret = memcpy_s(result, buf_size, serialized.data(), buf_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; + } + + handler.ReceiveMessage(result, 4088); +} + +} // namespace comm +} // namespace ps +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/ps/comm/tcp_pb_server_test.cc b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc similarity index 90% rename from tests/ut/cpp/ps/comm/tcp_pb_server_test.cc rename to tests/ut/cpp/ps/core/tcp_pb_server_test.cc index e041e2046f..360c722abe 100644 --- a/tests/ut/cpp/ps/comm/tcp_pb_server_test.cc +++ b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "ps/comm/tcp_client.h" -#include "ps/comm/tcp_server.h" +#include "ps/core/tcp_client.h" +#include "ps/core/tcp_server.h" #include "common/common_test.h" #include @@ -23,14 +23,14 @@ namespace mindspore { namespace ps { -namespace comm { +namespace core { class TestTcpServer : public UT::Common { public: TestTcpServer() : client_(nullptr), server_(nullptr) {} virtual ~TestTcpServer() = default; void SetUp() override { - server_ = std::make_unique("127.0.0.1", 9998); + server_ = std::make_unique("127.0.0.1", 0); std::unique_ptr http_server_thread_(nullptr); http_server_thread_ = std::make_unique([&]() { server_->SetMessageCallback([](const TcpServer &server, const TcpConnection &conn, const CommMessage &message) { @@ -57,7 +57,7 @@ class TestTcpServer : public UT::Common { }; TEST_F(TestTcpServer, ServerSendMessage) { - client_ = std::make_unique("127.0.0.1", 9998); + client_ = std::make_unique("127.0.0.1", server_->BoundPort()); std::unique_ptr http_client_thread(nullptr); http_client_thread = std::make_unique([&]() { client_->SetMessageCallback([](const TcpClient &client, const CommMessage &message) { @@ -82,6 +82,6 @@ TEST_F(TestTcpServer, ServerSendMessage) { }); http_client_thread->detach(); } -} // namespace comm +} // namespace core } // namespace ps } // namespace mindspore \ No newline at end of file