Browse Source

FL, update iterator func_graph_ to weak_ptr && fix tcp message handler memory bug

feature/build-system-rewrite
xuyongfei 4 years ago
parent
commit
6bad80681b
12 changed files with 112 additions and 86 deletions
  1. +13
    -13
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h
  3. +3
    -5
      mindspore/ccsrc/fl/server/server.cc
  4. +1
    -2
      mindspore/ccsrc/fl/server/server.h
  5. +5
    -3
      mindspore/ccsrc/ps/core/abstract_node.cc
  6. +42
    -28
      mindspore/ccsrc/ps/core/communicator/tcp_client.cc
  7. +6
    -4
      mindspore/ccsrc/ps/core/communicator/tcp_client.h
  8. +26
    -15
      mindspore/ccsrc/ps/core/communicator/tcp_message_handler.cc
  9. +8
    -8
      mindspore/ccsrc/ps/core/communicator/tcp_message_handler.h
  10. +4
    -4
      mindspore/ccsrc/ps/core/scheduler_node.cc
  11. +2
    -2
      tests/ut/cpp/ps/core/tcp_client_tests.cc
  12. +1
    -1
      tests/ut/cpp/ps/core/tcp_pb_server_test.cc

+ 13
- 13
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc View File

@@ -172,19 +172,19 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
}

PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta();
std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
} else {
if (fl_id_to_meta.count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
std::vector<std::string> get_secrets_clients;
#ifdef ENABLE_ARMOUR
mindspore::armour::CipherMetaStorage cipher_meta_storage;
@@ -201,8 +201,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
}
}

size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
auto feature_map = ParseFeatureMap(update_model_req);
size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size();
const auto &feature_map = ParseFeatureMap(update_model_req);
if (feature_map.empty()) {
std::string reason = "Feature map is empty.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");


+ 1
- 1
mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h View File

@@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel {
~UpdateModelKernel() override = default;

void InitKernel(size_t threshold_count) override;
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message);
bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override;
bool Reset() override;

// In some cases, the last updateModel message means this server iteration is finished.


+ 3
- 5
mindspore/ccsrc/fl/server/server.cc View File

@@ -97,8 +97,6 @@ void Server::Run() {
MS_EXCEPTION_IF_NULL(communicator_with_server_);
communicator_with_server_->Join();
MsException::Instance().CheckException();
func_graph_ = nullptr;
return;
}

void Server::InitPkiCertificate() {
@@ -375,16 +373,16 @@ void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunic
}

void Server::InitExecutor() {
MS_EXCEPTION_IF_NULL(func_graph_);
if (executor_threshold_ == 0) {
MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
return;
}
auto func_graph = func_graph_.lock();
MS_EXCEPTION_IF_NULL(func_graph);
// The train engine instance is used in both push-type and pull-type kernels,
// so the required_cnt of these kernels must be the same as executor_threshold_.
MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_;
Executor::GetInstance().Initialize(func_graph_, executor_threshold_);
func_graph_ = nullptr;
Executor::GetInstance().Initialize(func_graph, executor_threshold_);
ModelStore::GetInstance().Initialize();
return;
}


+ 1
- 2
mindspore/ccsrc/fl/server/server.h View File

@@ -83,7 +83,6 @@ class Server {
use_tcp_(false),
use_http_(false),
http_port_(0),
func_graph_(nullptr),
executor_threshold_(0),
communicator_with_server_(nullptr),
communicators_with_worker_({}),
@@ -194,7 +193,7 @@ class Server {
CipherConfig cipher_config_;

// The graph passed by the frontend without backend optimizing.
FuncGraphPtr func_graph_;
FuncGraphWeakPtr func_graph_;

// The threshold count for executor to do aggregation or optimizing.
size_t executor_threshold_;


+ 5
- 3
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -861,7 +861,8 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {

void AbstractNode::InitClientToServer() {
// create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed
client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get());
client_to_server_ =
std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get(), node_info_.node_role_);
MS_EXCEPTION_IF_NULL(client_to_server_);
client_to_server_->Init();
MS_LOG(INFO) << "The node start a tcp client to this node!";
@@ -872,7 +873,8 @@ bool AbstractNode::InitClientToScheduler() {
MS_LOG(WARNING) << "The config is empty.";
return false;
}
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get());
client_to_scheduler_ =
std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get(), NodeRole::SCHEDULER);
MS_EXCEPTION_IF_NULL(client_to_scheduler_);
client_to_scheduler_->SetMessageCallback(
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
@@ -930,7 +932,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id;
std::string ip = nodes_address_[key].first;
uint16_t port = nodes_address_[key].second;
auto client = std::make_shared<TcpClient>(ip, port, config_.get());
auto client = std::make_shared<TcpClient>(ip, port, config_.get(), role);
MS_EXCEPTION_IF_NULL(client);
client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {


+ 42
- 28
mindspore/ccsrc/ps/core/communicator/tcp_client.cc View File

@@ -37,11 +37,12 @@ event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;

TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config)
TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role)
: event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port),
peer_role_(peer_role),
is_stop_(true),
is_connected_(false),
config_(config) {
@@ -70,6 +71,19 @@ void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) {

void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; }

std::string TcpClient::PeerRoleName() const {
switch (peer_role_) {
case SERVER:
return "Server";
case WORKER:
return "Worker";
case SCHEDULER:
return "Scheduler";
default:
return "RoleUndefined";
}
}

bool TcpClient::WaitConnected(const uint32_t &connected_timeout) {
std::unique_lock<std::mutex> lock(connection_mutex_);
bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout),
@@ -167,36 +181,32 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {

void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *const arg) {
try {
TimeoutCallbackInner(arg);
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
tcp_client->Init();
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Catch exception: " << e.what();
}
}

void TcpClient::TimeoutCallbackInner(void *const arg) {
MS_EXCEPTION_IF_NULL(arg);
auto tcp_client = reinterpret_cast<TcpClient *>(arg);
tcp_client->Init();
}

void TcpClient::ReadCallback(struct bufferevent *bev, void *const ctx) {
try {
ReadCallbackInner(bev, ctx);
MS_EXCEPTION_IF_NULL(ctx);
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
tcp_client->ReadCallbackInner(bev);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Catch exception: " << e.what();
}
}

void TcpClient::ReadCallbackInner(struct bufferevent *bev, void *const ctx) {
void TcpClient::ReadCallbackInner(struct bufferevent *bev) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ctx);
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);

char read_buffer[kMessageChunkLength];
size_t read = 0;

while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
tcp_client->OnReadHandler(read_buffer, read);
OnReadHandler(read_buffer, read);
}
}

@@ -217,7 +227,8 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) {
}

void TcpClient::NotifyConnected() {
MS_LOG(INFO) << "Client connected to the server!";
MS_LOG(INFO) << "Client connected to the server! Peer " << PeerRoleName() << " ip: " << server_address_
<< ", port: " << server_port_;
is_connected_ = true;
connection_cond_.notify_all();
}
@@ -236,34 +247,37 @@ bool TcpClient::EstablishSSL() {

void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *const ptr) {
try {
EventCallbackInner(bev, events, ptr);
MS_EXCEPTION_IF_NULL(ptr);
auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
tcp_client->EventCallbackInner(bev, events);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Catch exception: " << e.what();
}
}

void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *const ptr) {
void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ptr);
auto tcp_client = reinterpret_cast<TcpClient *>(ptr);
if (events & BEV_EVENT_CONNECTED) {
// Connected
if (tcp_client->connected_callback_) {
tcp_client->connected_callback_();
if (connected_callback_) {
connected_callback_();
}
tcp_client->NotifyConnected();
NotifyConnected();
evutil_socket_t fd = bufferevent_getfd(bev);
SetTcpNoDelay(fd);
MS_LOG(INFO) << "Client connected!";
MS_LOG(INFO) << "Client connected! Peer " << PeerRoleName() << " ip: " << server_address_
<< ", port: " << server_port_;
} else if (events & BEV_EVENT_ERROR) {
MS_LOG(WARNING) << "The client will retry to connect to the server!";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_();
MS_LOG(WARNING) << "The client will retry to connect to the server! Peer " << PeerRoleName()
<< " ip: " << server_address_ << ", port: " << server_port_;
if (disconnected_callback_) {
disconnected_callback_();
}
} else if (events & BEV_EVENT_EOF) {
MS_LOG(WARNING) << "Client connected end of file";
if (tcp_client->disconnected_callback_) {
tcp_client->disconnected_callback_();
MS_LOG(WARNING) << "Client connected end of file! Peer " << PeerRoleName() << " ip: " << server_address_
<< ", port: " << server_port_;
if (disconnected_callback_) {
disconnected_callback_();
}
}
}


+ 6
- 4
mindspore/ccsrc/ps/core/communicator/tcp_client.h View File

@@ -54,7 +54,7 @@ class TcpClient {
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;

explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config);
explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role);
virtual ~TcpClient();

std::string GetServerAddress() const;
@@ -76,16 +76,17 @@ class TcpClient {
protected:
static void SetTcpNoDelay(const evutil_socket_t &fd);
static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg);
static void TimeoutCallbackInner(void *arg);
static void ReadCallback(struct bufferevent *bev, void *ctx);
static void ReadCallbackInner(struct bufferevent *bev, void *ctx);
void ReadCallbackInner(struct bufferevent *bev);
static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr);
static void EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *ptr);
void EventCallbackInner(struct bufferevent *bev, std::int16_t events);
virtual void OnReadHandler(const void *buf, size_t num);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
void NotifyConnected();
bool EstablishSSL();

std::string PeerRoleName() const;

private:
OnMessage message_callback_;
TcpMessageHandler message_handler_;
@@ -107,6 +108,7 @@ class TcpClient {

std::string server_address_;
std::uint16_t server_port_;
NodeRole peer_role_;
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
// The Configuration file


+ 26
- 15
mindspore/ccsrc/ps/core/communicator/tcp_message_handler.cc View File

@@ -28,18 +28,20 @@ void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { mes

void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer);
auto buffer_data = reinterpret_cast<const unsigned char *>(buffer);
auto buffer_data = reinterpret_cast<const uint8_t *>(buffer);

while (num > 0) {
if (remaining_length_ == 0) {
for (int i = 0; i < kHeaderLen && num > 0; ++i) {
header_[++header_index_] = *(buffer_data + i);
for (size_t i = 0; cur_header_len_ < kHeaderLen && num > 0; ++i) {
header_[cur_header_len_] = buffer_data[i];
cur_header_len_ += 1;
--num;
if (header_index_ == kHeaderLen - 1) {
if (cur_header_len_ == kHeaderLen) {
message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_);
if (message_header_.message_proto_ != Protos::RAW && message_header_.message_proto_ != Protos::FLATBUFFERS &&
message_header_.message_proto_ != Protos::PROTOBUF) {
MS_LOG(WARNING) << "The proto:" << message_header_.message_proto_ << " is illegal!";
Reset();
return;
}
message_header_.message_meta_length_ =
@@ -48,15 +50,17 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_));
if (message_header_.message_length_ >= UINT32_MAX) {
MS_LOG(WARNING) << "The message len:" << message_header_.message_length_ << " is too long.";
Reset();
return;
}
if (message_header_.message_meta_length_ > message_header_.message_length_) {
MS_LOG(WARNING) << "The message meta len " << message_header_.message_meta_length_ << " > the message len "
<< message_header_.message_length_;
Reset();
return;
}
remaining_length_ = message_header_.message_length_;
message_buffer_ = std::make_unique<unsigned char[]>(remaining_length_);
MS_EXCEPTION_IF_NULL(message_buffer_);
message_buffer_.resize(remaining_length_);
buffer_data += (i + 1);
break;
}
@@ -68,9 +72,9 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
remaining_length_ -= copy_len;
num -= copy_len;

size_t dest_size = copy_len;
size_t dest_size = message_buffer_.size() - last_copy_len_;
size_t src_size = copy_len;
auto ret = memcpy_s(message_buffer_.get() + last_copy_len_, dest_size, buffer_data, src_size);
auto ret = memcpy_s(message_buffer_.data() + last_copy_len_, dest_size, buffer_data, src_size);
last_copy_len_ += copy_len;
buffer_data += copy_len;
if (ret != EOK) {
@@ -81,20 +85,27 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
if (message_callback_) {
std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>();
MS_EXCEPTION_IF_NULL(pb_message);
CHECK_RETURN_TYPE(
pb_message->ParseFromArray(message_buffer_.get(), UintToInt(message_header_.message_meta_length_)));
if (!pb_message->ParseFromArray(message_buffer_.data(), UintToInt(message_header_.message_meta_length_))) {
MS_LOG(ERROR) << "Parse protobuf MessageMeta failed";
Reset();
return;
}
message_callback_(pb_message, message_header_.message_proto_,
message_buffer_.get() + message_header_.message_meta_length_,
message_buffer_.data() + message_header_.message_meta_length_,
message_header_.message_length_ - message_header_.message_meta_length_);
}
message_buffer_.reset();
message_buffer_ = nullptr;
header_index_ = -1;
last_copy_len_ = 0;
Reset();
}
}
}
}

void TcpMessageHandler::Reset() {
message_buffer_.clear();
cur_header_len_ = 0;
last_copy_len_ = 0;
remaining_length_ = 0;
}
} // namespace core
} // namespace ps
} // namespace mindspore

+ 8
- 8
mindspore/ccsrc/ps/core/communicator/tcp_message_handler.h View File

@@ -35,27 +35,27 @@ namespace ps {
namespace core {
using messageReceive =
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
constexpr int kHeaderLen = 16;

constexpr size_t kHeaderLen = sizeof(MessageHeader);

class TcpMessageHandler {
public:
TcpMessageHandler()
: is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {}
TcpMessageHandler() : remaining_length_(0), cur_header_len_(0), last_copy_len_(0) {}
virtual ~TcpMessageHandler() = default;

void SetCallback(const messageReceive &cb);
void ReceiveMessage(const void *buffer, size_t num);

void Reset();

private:
messageReceive message_callback_;
bool is_parsed_;
std::unique_ptr<unsigned char[]> message_buffer_;
std::vector<uint8_t> message_buffer_;
uint8_t header_[kHeaderLen]{0};
size_t remaining_length_;
unsigned char header_[16]{0};
int header_index_;
size_t cur_header_len_ = 0;
size_t last_copy_len_;
MessageHeader message_header_;
std::string mBuffer;
};
} // namespace core
} // namespace ps


+ 4
- 4
mindspore/ccsrc/ps/core/scheduler_node.cc View File

@@ -67,8 +67,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) {
void SchedulerNode::RunRecovery() {
core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
// create tcp client to myself in case of event dispatch failed when Send reconnect msg to server failed
client_to_scheduler_ =
std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, config_.get());
client_to_scheduler_ = std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port,
config_.get(), NodeRole::SCHEDULER);
MS_EXCEPTION_IF_NULL(client_to_scheduler_);
client_to_scheduler_->Init();
client_thread_ = std::make_unique<std::thread>([&]() {
@@ -95,7 +95,7 @@ void SchedulerNode::RunRecovery() {
for (const auto &kvs : initial_node_infos) {
auto &node_id = kvs.first;
auto &node_info = kvs.second;
auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get());
auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get(), node_info.node_role_);
client->SetMessageCallback([this](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *, size_t) {
MS_LOG(INFO) << "received the response. ";
NotifyMessageArrival(meta);
@@ -648,7 +648,7 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
MS_LOG(INFO) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_;
auto client = std::make_shared<TcpClient>(ip, port, config_.get());
auto client = std::make_shared<TcpClient>(ip, port, config_.get(), node_info.node_role_);
MS_EXCEPTION_IF_NULL(client);
client->SetMessageCallback(
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {


+ 2
- 2
tests/ut/cpp/ps/core/tcp_client_tests.cc View File

@@ -29,7 +29,7 @@ class TestTcpClient : public UT::Common {

TEST_F(TestTcpClient, InitClientIPError) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get());
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get(), NodeRole::SERVER);

client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
@@ -43,7 +43,7 @@ TEST_F(TestTcpClient, InitClientIPError) {

TEST_F(TestTcpClient, InitClientPortErrorNoException) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get());
auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get(), NodeRole::SERVER);

client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;


+ 1
- 1
tests/ut/cpp/ps/core/tcp_pb_server_test.cc View File

@@ -60,7 +60,7 @@ class TestTcpServer : public UT::Common {

TEST_F(TestTcpServer, ServerSendMessage) {
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get());
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get(), NodeRole::SERVER);
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {


Loading…
Cancel
Save