| @@ -172,19 +172,19 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| } | |||
| PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); | |||
| FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas(); | |||
| const auto &fl_id_to_meta = device_metas.device_metas().fl_id_to_meta(); | |||
| std::string update_model_fl_id = update_model_req->fl_id()->str(); | |||
| MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id; | |||
| if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) { | |||
| if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { | |||
| std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later."; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| } else { | |||
| if (fl_id_to_meta.count(update_model_fl_id) == 0) { | |||
| std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later."; | |||
| BuildUpdateModelRsp( | |||
| fbb, schema::ResponseCode_OutOfTime, reason, | |||
| std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); | |||
| MS_LOG(ERROR) << reason; | |||
| return ResultCode::kSuccessAndReturn; | |||
| } | |||
| if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) { | |||
| std::vector<std::string> get_secrets_clients; | |||
| #ifdef ENABLE_ARMOUR | |||
| mindspore::armour::CipherMetaStorage cipher_meta_storage; | |||
| @@ -201,8 +201,8 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda | |||
| } | |||
| } | |||
| size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); | |||
| auto feature_map = ParseFeatureMap(update_model_req); | |||
| size_t data_size = fl_id_to_meta.at(update_model_fl_id).data_size(); | |||
| const auto &feature_map = ParseFeatureMap(update_model_req); | |||
| if (feature_map.empty()) { | |||
| std::string reason = "Feature map is empty."; | |||
| BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); | |||
| @@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel { | |||
| ~UpdateModelKernel() override = default; | |||
| void InitKernel(size_t threshold_count) override; | |||
| bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message); | |||
| bool Launch(const uint8_t *req_data, size_t len, const std::shared_ptr<ps::core::MessageHandler> &message) override; | |||
| bool Reset() override; | |||
| // In some cases, the last updateModel message means this server iteration is finished. | |||
| @@ -97,8 +97,6 @@ void Server::Run() { | |||
| MS_EXCEPTION_IF_NULL(communicator_with_server_); | |||
| communicator_with_server_->Join(); | |||
| MsException::Instance().CheckException(); | |||
| func_graph_ = nullptr; | |||
| return; | |||
| } | |||
| void Server::InitPkiCertificate() { | |||
| @@ -375,16 +373,16 @@ void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunic | |||
| } | |||
| void Server::InitExecutor() { | |||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||
| if (executor_threshold_ == 0) { | |||
| MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0."; | |||
| return; | |||
| } | |||
| auto func_graph = func_graph_.lock(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // The train engine instance is used in both push-type and pull-type kernels, | |||
| // so the required_cnt of these kernels must be the same as executor_threshold_. | |||
| MS_LOG(INFO) << "Required count for push-type and pull-type kernels is " << executor_threshold_; | |||
| Executor::GetInstance().Initialize(func_graph_, executor_threshold_); | |||
| func_graph_ = nullptr; | |||
| Executor::GetInstance().Initialize(func_graph, executor_threshold_); | |||
| ModelStore::GetInstance().Initialize(); | |||
| return; | |||
| } | |||
| @@ -83,7 +83,6 @@ class Server { | |||
| use_tcp_(false), | |||
| use_http_(false), | |||
| http_port_(0), | |||
| func_graph_(nullptr), | |||
| executor_threshold_(0), | |||
| communicator_with_server_(nullptr), | |||
| communicators_with_worker_({}), | |||
| @@ -194,7 +193,7 @@ class Server { | |||
| CipherConfig cipher_config_; | |||
| // The graph passed by the frontend without backend optimizing. | |||
| FuncGraphPtr func_graph_; | |||
| FuncGraphWeakPtr func_graph_; | |||
| // The threshold count for executor to do aggregation or optimizing. | |||
| size_t executor_threshold_; | |||
| @@ -861,7 +861,8 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { | |||
| void AbstractNode::InitClientToServer() { | |||
| // create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed | |||
| client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get()); | |||
| client_to_server_ = | |||
| std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get(), node_info_.node_role_); | |||
| MS_EXCEPTION_IF_NULL(client_to_server_); | |||
| client_to_server_->Init(); | |||
| MS_LOG(INFO) << "The node start a tcp client to this node!"; | |||
| @@ -872,7 +873,8 @@ bool AbstractNode::InitClientToScheduler() { | |||
| MS_LOG(WARNING) << "The config is empty."; | |||
| return false; | |||
| } | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get()); | |||
| client_to_scheduler_ = | |||
| std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get(), NodeRole::SCHEDULER); | |||
| MS_EXCEPTION_IF_NULL(client_to_scheduler_); | |||
| client_to_scheduler_->SetMessageCallback( | |||
| [&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) { | |||
| @@ -930,7 +932,7 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3 | |||
| MS_LOG(INFO) << "Create tcp client for role: " << role << ", rank: " << rank_id; | |||
| std::string ip = nodes_address_[key].first; | |||
| uint16_t port = nodes_address_[key].second; | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get()); | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get(), role); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, | |||
| size_t size) { | |||
| @@ -37,11 +37,12 @@ event_base *TcpClient::event_base_ = nullptr; | |||
| std::mutex TcpClient::event_base_mutex_; | |||
| bool TcpClient::is_started_ = false; | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config) | |||
| TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role) | |||
| : event_timeout_(nullptr), | |||
| buffer_event_(nullptr), | |||
| server_address_(std::move(address)), | |||
| server_port_(port), | |||
| peer_role_(peer_role), | |||
| is_stop_(true), | |||
| is_connected_(false), | |||
| config_(config) { | |||
| @@ -70,6 +71,19 @@ void TcpClient::set_disconnected_callback(const OnDisconnected &disconnected) { | |||
| void TcpClient::set_connected_callback(const OnConnected &connected) { connected_callback_ = connected; } | |||
| std::string TcpClient::PeerRoleName() const { | |||
| switch (peer_role_) { | |||
| case SERVER: | |||
| return "Server"; | |||
| case WORKER: | |||
| return "Worker"; | |||
| case SCHEDULER: | |||
| return "Scheduler"; | |||
| default: | |||
| return "RoleUndefined"; | |||
| } | |||
| } | |||
| bool TcpClient::WaitConnected(const uint32_t &connected_timeout) { | |||
| std::unique_lock<std::mutex> lock(connection_mutex_); | |||
| bool res = connection_cond_.wait_for(lock, std::chrono::seconds(connected_timeout), | |||
| @@ -167,36 +181,32 @@ void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { | |||
| void TcpClient::TimeoutCallback(evutil_socket_t, std::int16_t, void *const arg) { | |||
| try { | |||
| TimeoutCallbackInner(arg); | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| tcp_client->Init(); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Catch exception: " << e.what(); | |||
| } | |||
| } | |||
| void TcpClient::TimeoutCallbackInner(void *const arg) { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(arg); | |||
| tcp_client->Init(); | |||
| } | |||
| void TcpClient::ReadCallback(struct bufferevent *bev, void *const ctx) { | |||
| try { | |||
| ReadCallbackInner(bev, ctx); | |||
| MS_EXCEPTION_IF_NULL(ctx); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ctx); | |||
| tcp_client->ReadCallbackInner(bev); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Catch exception: " << e.what(); | |||
| } | |||
| } | |||
| void TcpClient::ReadCallbackInner(struct bufferevent *bev, void *const ctx) { | |||
| void TcpClient::ReadCallbackInner(struct bufferevent *bev) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ctx); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ctx); | |||
| char read_buffer[kMessageChunkLength]; | |||
| size_t read = 0; | |||
| while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) { | |||
| tcp_client->OnReadHandler(read_buffer, read); | |||
| OnReadHandler(read_buffer, read); | |||
| } | |||
| } | |||
| @@ -217,7 +227,8 @@ void TcpClient::TimerCallback(evutil_socket_t, int16_t, void *arg) { | |||
| } | |||
| void TcpClient::NotifyConnected() { | |||
| MS_LOG(INFO) << "Client connected to the server!"; | |||
| MS_LOG(INFO) << "Client connected to the server! Peer " << PeerRoleName() << " ip: " << server_address_ | |||
| << ", port: " << server_port_; | |||
| is_connected_ = true; | |||
| connection_cond_.notify_all(); | |||
| } | |||
| @@ -236,34 +247,37 @@ bool TcpClient::EstablishSSL() { | |||
| void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *const ptr) { | |||
| try { | |||
| EventCallbackInner(bev, events, ptr); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ptr); | |||
| tcp_client->EventCallbackInner(bev, events); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Catch exception: " << e.what(); | |||
| } | |||
| } | |||
| void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *const ptr) { | |||
| void TcpClient::EventCallbackInner(struct bufferevent *bev, std::int16_t events) { | |||
| MS_EXCEPTION_IF_NULL(bev); | |||
| MS_EXCEPTION_IF_NULL(ptr); | |||
| auto tcp_client = reinterpret_cast<TcpClient *>(ptr); | |||
| if (events & BEV_EVENT_CONNECTED) { | |||
| // Connected | |||
| if (tcp_client->connected_callback_) { | |||
| tcp_client->connected_callback_(); | |||
| if (connected_callback_) { | |||
| connected_callback_(); | |||
| } | |||
| tcp_client->NotifyConnected(); | |||
| NotifyConnected(); | |||
| evutil_socket_t fd = bufferevent_getfd(bev); | |||
| SetTcpNoDelay(fd); | |||
| MS_LOG(INFO) << "Client connected!"; | |||
| MS_LOG(INFO) << "Client connected! Peer " << PeerRoleName() << " ip: " << server_address_ | |||
| << ", port: " << server_port_; | |||
| } else if (events & BEV_EVENT_ERROR) { | |||
| MS_LOG(WARNING) << "The client will retry to connect to the server!"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(); | |||
| MS_LOG(WARNING) << "The client will retry to connect to the server! Peer " << PeerRoleName() | |||
| << " ip: " << server_address_ << ", port: " << server_port_; | |||
| if (disconnected_callback_) { | |||
| disconnected_callback_(); | |||
| } | |||
| } else if (events & BEV_EVENT_EOF) { | |||
| MS_LOG(WARNING) << "Client connected end of file"; | |||
| if (tcp_client->disconnected_callback_) { | |||
| tcp_client->disconnected_callback_(); | |||
| MS_LOG(WARNING) << "Client connected end of file! Peer " << PeerRoleName() << " ip: " << server_address_ | |||
| << ", port: " << server_port_; | |||
| if (disconnected_callback_) { | |||
| disconnected_callback_(); | |||
| } | |||
| } | |||
| } | |||
| @@ -54,7 +54,7 @@ class TcpClient { | |||
| std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>; | |||
| using OnTimer = std::function<void()>; | |||
| explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config); | |||
| explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *const config, NodeRole peer_role); | |||
| virtual ~TcpClient(); | |||
| std::string GetServerAddress() const; | |||
| @@ -76,16 +76,17 @@ class TcpClient { | |||
| protected: | |||
| static void SetTcpNoDelay(const evutil_socket_t &fd); | |||
| static void TimeoutCallback(evutil_socket_t fd, std::int16_t what, void *arg); | |||
| static void TimeoutCallbackInner(void *arg); | |||
| static void ReadCallback(struct bufferevent *bev, void *ctx); | |||
| static void ReadCallbackInner(struct bufferevent *bev, void *ctx); | |||
| void ReadCallbackInner(struct bufferevent *bev); | |||
| static void EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr); | |||
| static void EventCallbackInner(struct bufferevent *bev, std::int16_t events, void *ptr); | |||
| void EventCallbackInner(struct bufferevent *bev, std::int16_t events); | |||
| virtual void OnReadHandler(const void *buf, size_t num); | |||
| static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); | |||
| void NotifyConnected(); | |||
| bool EstablishSSL(); | |||
| std::string PeerRoleName() const; | |||
| private: | |||
| OnMessage message_callback_; | |||
| TcpMessageHandler message_handler_; | |||
| @@ -107,6 +108,7 @@ class TcpClient { | |||
| std::string server_address_; | |||
| std::uint16_t server_port_; | |||
| NodeRole peer_role_; | |||
| std::atomic<bool> is_stop_; | |||
| std::atomic<bool> is_connected_; | |||
| // The Configuration file | |||
| @@ -28,18 +28,20 @@ void TcpMessageHandler::SetCallback(const messageReceive &message_receive) { mes | |||
| void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| MS_EXCEPTION_IF_NULL(buffer); | |||
| auto buffer_data = reinterpret_cast<const unsigned char *>(buffer); | |||
| auto buffer_data = reinterpret_cast<const uint8_t *>(buffer); | |||
| while (num > 0) { | |||
| if (remaining_length_ == 0) { | |||
| for (int i = 0; i < kHeaderLen && num > 0; ++i) { | |||
| header_[++header_index_] = *(buffer_data + i); | |||
| for (size_t i = 0; cur_header_len_ < kHeaderLen && num > 0; ++i) { | |||
| header_[cur_header_len_] = buffer_data[i]; | |||
| cur_header_len_ += 1; | |||
| --num; | |||
| if (header_index_ == kHeaderLen - 1) { | |||
| if (cur_header_len_ == kHeaderLen) { | |||
| message_header_.message_proto_ = *reinterpret_cast<const Protos *>(header_); | |||
| if (message_header_.message_proto_ != Protos::RAW && message_header_.message_proto_ != Protos::FLATBUFFERS && | |||
| message_header_.message_proto_ != Protos::PROTOBUF) { | |||
| MS_LOG(WARNING) << "The proto:" << message_header_.message_proto_ << " is illegal!"; | |||
| Reset(); | |||
| return; | |||
| } | |||
| message_header_.message_meta_length_ = | |||
| @@ -48,15 +50,17 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| header_ + sizeof(message_header_.message_proto_) + sizeof(message_header_.message_meta_length_)); | |||
| if (message_header_.message_length_ >= UINT32_MAX) { | |||
| MS_LOG(WARNING) << "The message len:" << message_header_.message_length_ << " is too long."; | |||
| Reset(); | |||
| return; | |||
| } | |||
| if (message_header_.message_meta_length_ > message_header_.message_length_) { | |||
| MS_LOG(WARNING) << "The message meta len " << message_header_.message_meta_length_ << " > the message len " | |||
| << message_header_.message_length_; | |||
| Reset(); | |||
| return; | |||
| } | |||
| remaining_length_ = message_header_.message_length_; | |||
| message_buffer_ = std::make_unique<unsigned char[]>(remaining_length_); | |||
| MS_EXCEPTION_IF_NULL(message_buffer_); | |||
| message_buffer_.resize(remaining_length_); | |||
| buffer_data += (i + 1); | |||
| break; | |||
| } | |||
| @@ -68,9 +72,9 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| remaining_length_ -= copy_len; | |||
| num -= copy_len; | |||
| size_t dest_size = copy_len; | |||
| size_t dest_size = message_buffer_.size() - last_copy_len_; | |||
| size_t src_size = copy_len; | |||
| auto ret = memcpy_s(message_buffer_.get() + last_copy_len_, dest_size, buffer_data, src_size); | |||
| auto ret = memcpy_s(message_buffer_.data() + last_copy_len_, dest_size, buffer_data, src_size); | |||
| last_copy_len_ += copy_len; | |||
| buffer_data += copy_len; | |||
| if (ret != EOK) { | |||
| @@ -81,20 +85,27 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||
| if (message_callback_) { | |||
| std::shared_ptr<MessageMeta> pb_message = std::make_shared<MessageMeta>(); | |||
| MS_EXCEPTION_IF_NULL(pb_message); | |||
| CHECK_RETURN_TYPE( | |||
| pb_message->ParseFromArray(message_buffer_.get(), UintToInt(message_header_.message_meta_length_))); | |||
| if (!pb_message->ParseFromArray(message_buffer_.data(), UintToInt(message_header_.message_meta_length_))) { | |||
| MS_LOG(ERROR) << "Parse protobuf MessageMeta failed"; | |||
| Reset(); | |||
| return; | |||
| } | |||
| message_callback_(pb_message, message_header_.message_proto_, | |||
| message_buffer_.get() + message_header_.message_meta_length_, | |||
| message_buffer_.data() + message_header_.message_meta_length_, | |||
| message_header_.message_length_ - message_header_.message_meta_length_); | |||
| } | |||
| message_buffer_.reset(); | |||
| message_buffer_ = nullptr; | |||
| header_index_ = -1; | |||
| last_copy_len_ = 0; | |||
| Reset(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void TcpMessageHandler::Reset() { | |||
| message_buffer_.clear(); | |||
| cur_header_len_ = 0; | |||
| last_copy_len_ = 0; | |||
| remaining_length_ = 0; | |||
| } | |||
| } // namespace core | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -35,27 +35,27 @@ namespace ps { | |||
| namespace core { | |||
| using messageReceive = | |||
| std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>; | |||
| constexpr int kHeaderLen = 16; | |||
| constexpr size_t kHeaderLen = sizeof(MessageHeader); | |||
| class TcpMessageHandler { | |||
| public: | |||
| TcpMessageHandler() | |||
| : is_parsed_(false), message_buffer_(nullptr), remaining_length_(0), header_index_(-1), last_copy_len_(0) {} | |||
| TcpMessageHandler() : remaining_length_(0), cur_header_len_(0), last_copy_len_(0) {} | |||
| virtual ~TcpMessageHandler() = default; | |||
| void SetCallback(const messageReceive &cb); | |||
| void ReceiveMessage(const void *buffer, size_t num); | |||
| void Reset(); | |||
| private: | |||
| messageReceive message_callback_; | |||
| bool is_parsed_; | |||
| std::unique_ptr<unsigned char[]> message_buffer_; | |||
| std::vector<uint8_t> message_buffer_; | |||
| uint8_t header_[kHeaderLen]{0}; | |||
| size_t remaining_length_; | |||
| unsigned char header_[16]{0}; | |||
| int header_index_; | |||
| size_t cur_header_len_ = 0; | |||
| size_t last_copy_len_; | |||
| MessageHeader message_header_; | |||
| std::string mBuffer; | |||
| }; | |||
| } // namespace core | |||
| } // namespace ps | |||
| @@ -67,8 +67,8 @@ bool SchedulerNode::Start(const uint32_t &timeout) { | |||
| void SchedulerNode::RunRecovery() { | |||
| core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); | |||
| // create tcp client to myself in case of event dispatch failed when Send reconnect msg to server failed | |||
| client_to_scheduler_ = | |||
| std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, config_.get()); | |||
| client_to_scheduler_ = std::make_shared<TcpClient>(clusterConfig.scheduler_host, clusterConfig.scheduler_port, | |||
| config_.get(), NodeRole::SCHEDULER); | |||
| MS_EXCEPTION_IF_NULL(client_to_scheduler_); | |||
| client_to_scheduler_->Init(); | |||
| client_thread_ = std::make_unique<std::thread>([&]() { | |||
| @@ -95,7 +95,7 @@ void SchedulerNode::RunRecovery() { | |||
| for (const auto &kvs : initial_node_infos) { | |||
| auto &node_id = kvs.first; | |||
| auto &node_info = kvs.second; | |||
| auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get()); | |||
| auto client = std::make_shared<TcpClient>(node_info.ip_, node_info.port_, config_.get(), node_info.node_role_); | |||
| client->SetMessageCallback([this](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *, size_t) { | |||
| MS_LOG(INFO) << "received the response. "; | |||
| NotifyMessageArrival(meta); | |||
| @@ -648,7 +648,7 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf | |||
| std::string ip = node_info.ip_; | |||
| uint16_t port = node_info.port_; | |||
| MS_LOG(INFO) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_; | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get()); | |||
| auto client = std::make_shared<TcpClient>(ip, port, config_.get(), node_info.node_role_); | |||
| MS_EXCEPTION_IF_NULL(client); | |||
| client->SetMessageCallback( | |||
| [&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) { | |||
| @@ -29,7 +29,7 @@ class TestTcpClient : public UT::Common { | |||
| TEST_F(TestTcpClient, InitClientIPError) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get()); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get(), NodeRole::SERVER); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| @@ -43,7 +43,7 @@ TEST_F(TestTcpClient, InitClientIPError) { | |||
| TEST_F(TestTcpClient, InitClientPortErrorNoException) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get()); | |||
| auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get(), NodeRole::SERVER); | |||
| client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) { | |||
| CommMessage message; | |||
| @@ -60,7 +60,7 @@ class TestTcpServer : public UT::Common { | |||
| TEST_F(TestTcpServer, ServerSendMessage) { | |||
| std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>(""); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get()); | |||
| client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get(), NodeRole::SERVER); | |||
| std::cout << server_->BoundPort() << std::endl; | |||
| std::unique_ptr<std::thread> http_client_thread(nullptr); | |||
| http_client_thread = std::make_unique<std::thread>([&]() { | |||