| @@ -0,0 +1,18 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/heart_beat.h" | |||||
| namespace mindspore::serving {} // namespace mindspore::serving | |||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVING_HEART_BEAT_H | |||||
| #define MINDSPORE_SERVING_HEART_BEAT_H | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <grpcpp/health_check_service_interface.h> | |||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <condition_variable> | |||||
| #include <thread> | |||||
| #include <functional> | |||||
| #include <chrono> | |||||
| #include <utility> | |||||
| #include "common/serving_common.h" | |||||
| #include "common/grpc_server.h" | |||||
| #include "proto/ms_service.pb.h" | |||||
| #include "proto/ms_service.grpc.pb.h" | |||||
| namespace mindspore::serving { | |||||
| using TimerCallback = std::function<void()>; | |||||
| class MS_API Timer { | |||||
| public: | |||||
| Timer() {} | |||||
| ~Timer() { StopTimer(); } | |||||
| void StartTimer(int64_t millisecond, TimerCallback callback) { | |||||
| auto timer_run = [this, millisecond, callback]() { | |||||
| std::unique_lock<std::mutex> lk(cv_m_); | |||||
| if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) { | |||||
| callback(); | |||||
| } | |||||
| }; | |||||
| thread_ = std::thread(timer_run); | |||||
| } | |||||
| void StopTimer() { | |||||
| cv_.notify_one(); | |||||
| if (thread_.joinable()) { | |||||
| try { | |||||
| thread_.join(); | |||||
| } catch (const std::system_error &) { | |||||
| } catch (...) { | |||||
| } | |||||
| } | |||||
| } | |||||
| private: | |||||
| std::mutex cv_m_; | |||||
| std::thread thread_; | |||||
| std::condition_variable cv_; | |||||
| }; | |||||
| template <class SendStub, class RecvStub> | |||||
| class MS_API Watcher { | |||||
| public: | |||||
| explicit Watcher(const std::string host_address) { host_address_ = host_address; } | |||||
| void StartWatch(const std::string &address) { | |||||
| auto it = watchee_map_.find(address); | |||||
| if (it != watchee_map_.end()) { | |||||
| MSI_LOG(INFO) << "watchee exist: " << address; | |||||
| return; | |||||
| } | |||||
| WatcheeContext context; | |||||
| auto channel = GrpcServer::CreateChannel(address); | |||||
| context.stub_ = SendStub::NewStub(channel); | |||||
| context.timer_ = std::make_shared<Timer>(); | |||||
| watchee_map_.insert(make_pair(address, context)); | |||||
| MSI_LOG(INFO) << "Begin to send ping to " << address; | |||||
| // add timer | |||||
| watchee_map_[address].timer_->StartTimer(max_time_out_ / max_ping_times_, | |||||
| std::bind(&Watcher::RecvPongTimeOut, this, address)); | |||||
| SendPing(address); | |||||
| } | |||||
| void StopWatch(const std::string &address) { | |||||
| // clear map and timer | |||||
| auto it = watchee_map_.find(address); | |||||
| if (it == watchee_map_.end()) { | |||||
| MSI_LOG(INFO) << "watchee not exist: " << address; | |||||
| return; | |||||
| } | |||||
| watchee_map_[address].timer_->StopTimer(); | |||||
| watchee_map_.erase(address); | |||||
| } | |||||
| void SendPing(const std::string &address) { | |||||
| watchee_map_[address].timeouts_ += 1; | |||||
| // send async message | |||||
| PingAsync(address); | |||||
| } | |||||
| void RecvPing(const std::string &address) { | |||||
| // recv message | |||||
| if (watcher_map_.count(address)) { | |||||
| watcher_map_[address].timer_->StopTimer(); | |||||
| } else { | |||||
| WatcherContext context; | |||||
| auto channel = GrpcServer::CreateChannel(address); | |||||
| context.stub_ = RecvStub::NewStub(channel); | |||||
| context.timer_ = std::make_shared<Timer>(); | |||||
| watcher_map_.insert(make_pair(address, context)); | |||||
| MSI_LOG(INFO) << "Begin to send pong to " << address; | |||||
| } | |||||
| // add timer | |||||
| watcher_map_[address].timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address)); | |||||
| // send async message | |||||
| PongAsync(address); | |||||
| } | |||||
| void RecvPong(const std::string &address) { | |||||
| // recv message | |||||
| if (watchee_map_.count(address)) { | |||||
| watchee_map_[address].timeouts_ = 0; | |||||
| } else { | |||||
| MSI_LOG(INFO) << "Recv Pong after timeout or stop"; | |||||
| } | |||||
| } | |||||
| void RecvPongTimeOut(const std::string &address) { | |||||
| if (watchee_map_[address].timeouts_ >= max_ping_times_) { | |||||
| // add exit handle | |||||
| MSI_LOG(INFO) << "Recv Pong Time Out from " << address; | |||||
| watchee_map_.erase(address); | |||||
| return; | |||||
| } | |||||
| SendPing(address); | |||||
| } | |||||
| void RecvPingTimeOut(const std::string &address) { | |||||
| MSI_LOG(INFO) << "Recv Ping Time Out from " << address; | |||||
| // add exit handle | |||||
| watcher_map_.erase(address); | |||||
| } | |||||
| void PingAsync(const std::string &address) { | |||||
| proto::PingRequest request; | |||||
| proto::PingReply reply; | |||||
| request.set_address(address); | |||||
| grpc::ClientContext context; | |||||
| const int32_t TIME_OUT = 100; | |||||
| std::chrono::system_clock::time_point deadline = | |||||
| std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT); | |||||
| context.set_deadline(deadline); | |||||
| (void)watchee_map_[address].stub_->Ping(&context, request, &reply); | |||||
| MSI_LOG(INFO) << "Finish send ping"; | |||||
| } | |||||
| void PongAsync(const std::string &address) { | |||||
| proto::PongRequest request; | |||||
| proto::PongReply reply; | |||||
| request.set_address(address); | |||||
| grpc::ClientContext context; | |||||
| const int32_t TIME_OUT = 100; | |||||
| std::chrono::system_clock::time_point deadline = | |||||
| std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT); | |||||
| context.set_deadline(deadline); | |||||
| (void)watcher_map_[address].stub_->Pong(&context, request, &reply); | |||||
| MSI_LOG(INFO) << "Finish send pong"; | |||||
| } | |||||
| private: | |||||
| struct WatcheeContext { | |||||
| uint64_t timeouts_ = 0; | |||||
| std::shared_ptr<Timer> timer_ = nullptr; | |||||
| std::shared_ptr<typename SendStub::Stub> stub_ = nullptr; | |||||
| }; | |||||
| struct WatcherContext { | |||||
| uint64_t timeouts_ = 0; | |||||
| std::shared_ptr<Timer> timer_ = nullptr; | |||||
| std::shared_ptr<typename RecvStub::Stub> stub_ = nullptr; | |||||
| }; | |||||
| std::string host_address_; | |||||
| uint64_t max_ping_times_ = 10; | |||||
| uint64_t max_time_out_ = 10000; // 10s | |||||
| std::unordered_map<std::string, WatcheeContext> watchee_map_; | |||||
| std::unordered_map<std::string, WatcherContext> watcher_map_; | |||||
| }; | |||||
| } // namespace mindspore::serving | |||||
| #endif // MINDSPORE_SERVING_HEART_BEAT_H | |||||
| @@ -106,6 +106,7 @@ grpc::Status MSMasterImpl::Register(grpc::ServerContext *context, const proto::R | |||||
| MSI_LOG_ERROR << "Register servable failed, " << worker_sig(); | MSI_LOG_ERROR << "Register servable failed, " << worker_sig(); | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| watcher_->StartWatch(request->address()); | |||||
| MSI_LOG(INFO) << "Register success: " << worker_sig(); | MSI_LOG(INFO) << "Register success: " << worker_sig(); | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| @@ -126,6 +127,7 @@ grpc::Status MSMasterImpl::AddWorker(grpc::ServerContext *context, const proto:: | |||||
| MSI_LOG_ERROR << "Add servable failed, " << worker_sig(); | MSI_LOG_ERROR << "Add servable failed, " << worker_sig(); | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| watcher_->StartWatch(request->address()); | |||||
| MSI_LOG(INFO) << "Add success, " << worker_sig(); | MSI_LOG(INFO) << "Add success, " << worker_sig(); | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| @@ -141,6 +143,7 @@ grpc::Status MSMasterImpl::RemoveWorker(grpc::ServerContext *context, const prot | |||||
| return str.str(); | return str.str(); | ||||
| }; | }; | ||||
| Status status(FAILED); | Status status(FAILED); | ||||
| watcher_->StopWatch(request->address()); | |||||
| status = dispatcher_->RemoveServable(*request, reply); | status = dispatcher_->RemoveServable(*request, reply); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_ERROR << "Add servable failed, " << worker_sig(); | MSI_LOG_ERROR << "Add servable failed, " << worker_sig(); | ||||
| @@ -162,6 +165,7 @@ grpc::Status MSMasterImpl::Exit(grpc::ServerContext *context, const proto::ExitR | |||||
| MSI_LOG(INFO) << "Worker Exit, " << worker_sig(); | MSI_LOG(INFO) << "Worker Exit, " << worker_sig(); | ||||
| Status status(FAILED); | Status status(FAILED); | ||||
| watcher_->StopWatch(request->address()); | |||||
| status = dispatcher_->UnregisterServable(*request, reply); | status = dispatcher_->UnregisterServable(*request, reply); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_ERROR << "UnRegister servable failed, " << worker_sig(); | MSI_LOG_ERROR << "UnRegister servable failed, " << worker_sig(); | ||||
| @@ -169,6 +173,20 @@ grpc::Status MSMasterImpl::Exit(grpc::ServerContext *context, const proto::ExitR | |||||
| } | } | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| grpc::Status MSMasterImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request, | |||||
| proto::PingReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPing(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| grpc::Status MSMasterImpl::Pong(grpc::ServerContext *context, const proto::PongRequest *request, | |||||
| proto::PongReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPong(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| } // namespace serving | } // namespace serving | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,15 @@ | |||||
| #include <grpcpp/health_check_service_interface.h> | #include <grpcpp/health_check_service_interface.h> | ||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | #include <grpcpp/ext/proto_server_reflection_plugin.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "common/serving_common.h" | #include "common/serving_common.h" | ||||
| #include "common/heart_beat.h" | |||||
| #include "proto/ms_service.pb.h" | #include "proto/ms_service.pb.h" | ||||
| #include "proto/ms_service.grpc.pb.h" | #include "proto/ms_service.grpc.pb.h" | ||||
| #include "proto/ms_master.pb.h" | #include "proto/ms_master.pb.h" | ||||
| #include "proto/ms_master.grpc.pb.h" | #include "proto/ms_master.grpc.pb.h" | ||||
| #include "proto/ms_worker.pb.h" | |||||
| #include "proto/ms_worker.grpc.pb.h" | |||||
| #include "master/dispacther.h" | #include "master/dispacther.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -46,7 +50,12 @@ class MSServiceImpl { | |||||
| // Service Implement | // Service Implement | ||||
| class MSMasterImpl final : public proto::MSMaster::Service { | class MSMasterImpl final : public proto::MSMaster::Service { | ||||
| public: | public: | ||||
| explicit MSMasterImpl(std::shared_ptr<Dispatcher> dispatcher) : dispatcher_(dispatcher) {} | |||||
| explicit MSMasterImpl(std::shared_ptr<Dispatcher> dispatcher, const std::string server_address) | |||||
| : dispatcher_(dispatcher) { | |||||
| if (!watcher_) { | |||||
| watcher_ = std::make_shared<Watcher<proto::MSWorker, proto::MSWorker>>(server_address); | |||||
| } | |||||
| } | |||||
| ~MSMasterImpl() = default; | ~MSMasterImpl() = default; | ||||
| grpc::Status Register(grpc::ServerContext *context, const proto::RegisterRequest *request, | grpc::Status Register(grpc::ServerContext *context, const proto::RegisterRequest *request, | ||||
| @@ -56,9 +65,12 @@ class MSMasterImpl final : public proto::MSMaster::Service { | |||||
| proto::AddWorkerReply *reply) override; | proto::AddWorkerReply *reply) override; | ||||
| grpc::Status RemoveWorker(grpc::ServerContext *context, const proto::RemoveWorkerRequest *request, | grpc::Status RemoveWorker(grpc::ServerContext *context, const proto::RemoveWorkerRequest *request, | ||||
| proto::RemoveWorkerReply *reply) override; | proto::RemoveWorkerReply *reply) override; | ||||
| grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; | |||||
| grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; | |||||
| private: | private: | ||||
| std::shared_ptr<Dispatcher> dispatcher_; | std::shared_ptr<Dispatcher> dispatcher_; | ||||
| std::shared_ptr<Watcher<proto::MSWorker, proto::MSWorker>> watcher_; | |||||
| }; | }; | ||||
| } // namespace serving | } // namespace serving | ||||
| @@ -49,8 +49,9 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma | |||||
| } | } | ||||
| Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { | Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { | ||||
| return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, | |||||
| "Master"); | |||||
| std::string server_address = ip + ":" + std::to_string(grpc_port); | |||||
| return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_, server_address), ip, grpc_port, | |||||
| gRpcMaxMBMsgSize, "Master"); | |||||
| } | } | ||||
| Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, | Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, | ||||
| @@ -33,6 +33,20 @@ grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::Dis | |||||
| MSI_LOG(INFO) << "End call service Eval"; | MSI_LOG(INFO) << "End call service Eval"; | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| grpc::Status MSAgentImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request, | |||||
| proto::PingReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPing(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| grpc::Status MSAgentImpl::Pong(grpc::ServerContext *context, const proto::PongRequest *request, | |||||
| proto::PongReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPong(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| } // namespace serving | } // namespace serving | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,9 +20,14 @@ | |||||
| #include <grpcpp/grpcpp.h> | #include <grpcpp/grpcpp.h> | ||||
| #include <grpcpp/health_check_service_interface.h> | #include <grpcpp/health_check_service_interface.h> | ||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | #include <grpcpp/ext/proto_server_reflection_plugin.h> | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/serving_common.h" | #include "common/serving_common.h" | ||||
| #include "common/heart_beat.h" | |||||
| #include "proto/ms_agent.pb.h" | #include "proto/ms_agent.pb.h" | ||||
| #include "proto/ms_agent.grpc.pb.h" | #include "proto/ms_agent.grpc.pb.h" | ||||
| #include "proto/ms_worker.pb.h" | |||||
| #include "proto/ms_worker.grpc.pb.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| @@ -30,10 +35,20 @@ namespace serving { | |||||
| // Service Implement | // Service Implement | ||||
| class MSAgentImpl final : public proto::MSAgent::Service { | class MSAgentImpl final : public proto::MSAgent::Service { | ||||
| public: | public: | ||||
| explicit MSAgentImpl(const std::string server_address) { | |||||
| if (!watcher_) { | |||||
| watcher_ = std::make_shared<Watcher<proto::MSWorker, proto::MSWorker>>(server_address); | |||||
| } | |||||
| } | |||||
| grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, | grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, | ||||
| proto::DistributedPredictReply *reply) override; | proto::DistributedPredictReply *reply) override; | ||||
| grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, | grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, | ||||
| proto::DistributedExitReply *reply) override; | proto::DistributedExitReply *reply) override; | ||||
| grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; | |||||
| grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; | |||||
| private: | |||||
| std::shared_ptr<Watcher<proto::MSWorker, proto::MSWorker>> watcher_; | |||||
| }; | }; | ||||
| } // namespace serving | } // namespace serving | ||||
| @@ -34,6 +34,7 @@ grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, cons | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG(ERROR) << "Agent Register FAILED"; | MSI_LOG(ERROR) << "Agent Register FAILED"; | ||||
| } | } | ||||
| watcher_->StartWatch(request->address()); | |||||
| } | } | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| @@ -42,6 +43,7 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr | |||||
| proto::AgentExitReply *reply) { | proto::AgentExitReply *reply) { | ||||
| MSI_EXCEPTION_IF_NULL(request); | MSI_EXCEPTION_IF_NULL(request); | ||||
| MSI_EXCEPTION_IF_NULL(reply); | MSI_EXCEPTION_IF_NULL(reply); | ||||
| watcher_->StopWatch(request->address()); | |||||
| servable_->OnAgentExit(); | servable_->OnAgentExit(); | ||||
| if (Worker::GetInstance().IsRunning()) { | if (Worker::GetInstance().IsRunning()) { | ||||
| Worker::GetInstance().StopServable(); | Worker::GetInstance().StopServable(); | ||||
| @@ -21,7 +21,9 @@ | |||||
| #include <grpcpp/health_check_service_interface.h> | #include <grpcpp/health_check_service_interface.h> | ||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | #include <grpcpp/ext/proto_server_reflection_plugin.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include "common/serving_common.h" | #include "common/serving_common.h" | ||||
| #include "common/heart_beat.h" | |||||
| #include "proto/ms_service.pb.h" | #include "proto/ms_service.pb.h" | ||||
| #include "proto/ms_service.grpc.pb.h" | #include "proto/ms_service.grpc.pb.h" | ||||
| #include "proto/ms_distributed.pb.h" | #include "proto/ms_distributed.pb.h" | ||||
| @@ -35,7 +37,8 @@ namespace serving { | |||||
| // Service Implement | // Service Implement | ||||
| class MSDistributedImpl final : public MSWorkerImpl { | class MSDistributedImpl final : public MSWorkerImpl { | ||||
| public: | public: | ||||
| explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {} | |||||
| explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable, const std::string server_address) | |||||
| : MSWorkerImpl(server_address), servable_(servable) {} | |||||
| ~MSDistributedImpl() = default; | ~MSDistributedImpl() = default; | ||||
| grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, | grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, | ||||
| proto::AgentRegisterReply *reply) override; | proto::AgentRegisterReply *reply) override; | ||||
| @@ -27,7 +27,8 @@ Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostn | |||||
| if (in_running_) { | if (in_running_) { | ||||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | ||||
| } | } | ||||
| auto impl = std::make_unique<MSDistributedImpl>(servable_); | |||||
| std::string server_address = hostname + ":" + std::to_string(port); | |||||
| auto impl = std::make_unique<MSDistributedImpl>(servable_, server_address); | |||||
| async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get()); | async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get()); | ||||
| service_impl_ = std::move(impl); | service_impl_ = std::move(impl); | ||||
| return Init(); | return Init(); | ||||
| @@ -101,7 +101,9 @@ Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { | |||||
| } | } | ||||
| Status WorkerAgent::StartGrpcServer() { | Status WorkerAgent::StartGrpcServer() { | ||||
| grpc_server_.Start(std::make_shared<MSAgentImpl>(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent"); | |||||
| std::string server_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); | |||||
| grpc_server_.Start(std::make_shared<MSAgentImpl>(server_address), config_.agent_ip, config_.agent_port, | |||||
| gRpcMaxMBMsgSize, "Agent"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -62,6 +62,20 @@ grpc::Status MSWorkerImpl::Predict(grpc::ServerContext *context, const proto::Pr | |||||
| } | } | ||||
| return grpc::Status::OK; | return grpc::Status::OK; | ||||
| } | } | ||||
| grpc::Status MSWorkerImpl::Ping(grpc::ServerContext *context, const proto::PingRequest *request, | |||||
| proto::PingReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPing(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| grpc::Status MSWorkerImpl::Pong(grpc::ServerContext *context, const proto::PongRequest *request, | |||||
| proto::PongReply *reply) { | |||||
| MSI_EXCEPTION_IF_NULL(request); | |||||
| MSI_EXCEPTION_IF_NULL(reply); | |||||
| watcher_->RecvPong(request->address()); | |||||
| return grpc::Status::OK; | |||||
| } | |||||
| } // namespace serving | } // namespace serving | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,19 +20,34 @@ | |||||
| #include <grpcpp/grpcpp.h> | #include <grpcpp/grpcpp.h> | ||||
| #include <grpcpp/health_check_service_interface.h> | #include <grpcpp/health_check_service_interface.h> | ||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | #include <grpcpp/ext/proto_server_reflection_plugin.h> | ||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/serving_common.h" | #include "common/serving_common.h" | ||||
| #include "common/heart_beat.h" | |||||
| #include "proto/ms_worker.pb.h" | #include "proto/ms_worker.pb.h" | ||||
| #include "proto/ms_worker.grpc.pb.h" | #include "proto/ms_worker.grpc.pb.h" | ||||
| #include "proto/ms_master.pb.h" | |||||
| #include "proto/ms_master.grpc.pb.h" | |||||
| #include "proto/ms_agent.pb.h" | |||||
| #include "proto/ms_agent.grpc.pb.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| // Service Implement | // Service Implement | ||||
| class MSWorkerImpl : public proto::MSWorker::Service { | class MSWorkerImpl : public proto::MSWorker::Service { | ||||
| public: | public: | ||||
| explicit MSWorkerImpl(const std::string server_address) { | |||||
| if (!watcher_) { | |||||
| watcher_ = std::make_shared<Watcher<proto::MSMaster, proto::MSAgent>>(server_address); | |||||
| } | |||||
| } | |||||
| grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | ||||
| proto::PredictReply *reply) override; | proto::PredictReply *reply) override; | ||||
| grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; | grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; | ||||
| grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; | |||||
| grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; | |||||
| std::shared_ptr<Watcher<proto::MSMaster, proto::MSAgent>> watcher_; | |||||
| }; | }; | ||||
| } // namespace serving | } // namespace serving | ||||
| @@ -28,7 +28,8 @@ Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_ | |||||
| if (in_running_) { | if (in_running_) { | ||||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | ||||
| } | } | ||||
| service_impl_ = std::make_unique<MSWorkerImpl>(); | |||||
| std::string server_address = hostname + ":" + std::to_string(port); | |||||
| service_impl_ = std::make_unique<MSWorkerImpl>(server_address); | |||||
| async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | ||||
| return Init(); | return Init(); | ||||
| } | } | ||||
| @@ -40,4 +40,6 @@ message DistributedExitReply { | |||||
| service MSAgent { | service MSAgent { | ||||
| rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {} | rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {} | ||||
| rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {} | rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {} | ||||
| rpc Ping(PingRequest) returns (PingReply) {} | |||||
| rpc Pong(PongRequest) returns (PongReply) {} | |||||
| } | } | ||||
| @@ -25,6 +25,8 @@ service MSMaster { | |||||
| rpc Exit(ExitRequest) returns (ExitReply) {} | rpc Exit(ExitRequest) returns (ExitReply) {} | ||||
| rpc AddWorker(AddWorkerRequest) returns (AddWorkerReply) {} | rpc AddWorker(AddWorkerRequest) returns (AddWorkerReply) {} | ||||
| rpc RemoveWorker(RemoveWorkerRequest) returns (RemoveWorkerReply) {} | rpc RemoveWorker(RemoveWorkerRequest) returns (RemoveWorkerReply) {} | ||||
| rpc Ping(PingRequest) returns (PingReply) {} | |||||
| rpc Pong(PongRequest) returns (PongReply) {} | |||||
| } | } | ||||
| message WorkerSpec { | message WorkerSpec { | ||||
| @@ -94,3 +94,18 @@ message ServableSpec { | |||||
| // Specifies the method name in the servable. | // Specifies the method name in the servable. | ||||
| string method_name = 2; | string method_name = 2; | ||||
| } | } | ||||
| message PingRequest { | |||||
| string address = 1; | |||||
| } | |||||
| message PingReply { | |||||
| string address = 1; | |||||
| } | |||||
| message PongRequest { | |||||
| string address = 1; | |||||
| } | |||||
| message PongReply { | |||||
| string address = 1; | |||||
| } | |||||
| @@ -31,4 +31,6 @@ service MSWorker { | |||||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | ||||
| rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} | rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} | ||||
| rpc AgentConfigAcquire(AgentConfigAcquireRequest) returns (AgentConfigAcquireReply) {} | rpc AgentConfigAcquire(AgentConfigAcquireRequest) returns (AgentConfigAcquireReply) {} | ||||
| rpc Ping(PingRequest) returns (PingReply) {} | |||||
| rpc Pong(PongRequest) returns (PongReply) {} | |||||
| } | } | ||||