diff --git a/mindspore_serving/ccsrc/common/heart_beat.h b/mindspore_serving/ccsrc/common/heart_beat.h index 0eaf501..cf95b16 100644 --- a/mindspore_serving/ccsrc/common/heart_beat.h +++ b/mindspore_serving/ccsrc/common/heart_beat.h @@ -39,18 +39,9 @@ using TimerCallback = std::function; class MS_API Timer { public: Timer() {} - ~Timer() { StopTimer(); } - void StartTimer(int64_t millisecond, TimerCallback callback) { - auto timer_run = [this, millisecond, callback]() { - std::unique_lock lk(cv_m_); - if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) { - callback(); - } - }; - thread_ = std::thread(timer_run); - } - void StopTimer() { - cv_.notify_one(); + ~Timer() { + is_stoped_.store(true); + cv_.notify_all(); if (thread_.joinable()) { try { thread_.join(); @@ -60,10 +51,27 @@ class MS_API Timer { } } + void StartTimer(int64_t millisecond, TimerCallback callback) { + auto timer_run = [this, millisecond, callback]() { + while (!is_stoped_.load()) { + std::unique_lock lk(cv_m_); + if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) { + callback(); + } + } + }; + thread_ = std::thread(timer_run); + } + void StopTimer() { + is_stoped_.store(true); + cv_.notify_all(); + } + private: std::mutex cv_m_; std::thread thread_; std::condition_variable cv_; + std::atomic is_stoped_ = false; }; template @@ -74,13 +82,14 @@ class MS_API Watcher { auto it = watchee_map_.find(address); if (it != watchee_map_.end()) { MSI_LOG(INFO) << "watchee exist: " << address; - return; + watchee_map_[address].timer_ = std::make_shared(); + } else { + WatcheeContext context; + auto channel = GrpcServer::CreateChannel(address); + context.stub_ = SendStub::NewStub(channel); + context.timer_ = std::make_shared(); + watchee_map_.insert(make_pair(address, context)); } - WatcheeContext context; - auto channel = GrpcServer::CreateChannel(address); - context.stub_ = SendStub::NewStub(channel); - context.timer_ = std::make_shared(); - watchee_map_.insert(make_pair(address, context)); MSI_LOG(INFO) << "Begin to send ping to " << address; // add timer watchee_map_[address].timer_->StartTimer(max_time_out_ / max_ping_times_, @@ -105,9 +114,11 @@ class MS_API Watcher { } void RecvPing(const std::string &address) { + std::unique_lock lock{m_lock_}; // recv message if (watcher_map_.count(address)) { watcher_map_[address].timer_->StopTimer(); + watcher_map_[address].timer_ = std::make_shared(); } else { WatcherContext context; auto channel = GrpcServer::CreateChannel(address); @@ -116,13 +127,14 @@ class MS_API Watcher { watcher_map_.insert(make_pair(address, context)); MSI_LOG(INFO) << "Begin to send pong to " << address; } - // add timer - watcher_map_[address].timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address)); // send async message PongAsync(address); + // add timer + watcher_map_[address].timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address)); } void RecvPong(const std::string &address) { + std::unique_lock lock{m_lock_}; // recv message if (watchee_map_.count(address)) { watchee_map_[address].timeouts_ = 0; @@ -132,10 +144,12 @@ class MS_API Watcher { } void RecvPongTimeOut(const std::string &address) { + std::unique_lock lock{m_lock_}; if (watchee_map_[address].timeouts_ >= max_ping_times_) { // add exit handle MSI_LOG(INFO) << "Recv Pong Time Out from " << address; - watchee_map_.erase(address); + watchee_map_[address].timer_->StopTimer(); + // need erase map return; } SendPing(address); @@ -144,32 +158,31 @@ class MS_API Watcher { void RecvPingTimeOut(const std::string &address) { MSI_LOG(INFO) << "Recv Ping Time Out from " << address; // add exit handle - watcher_map_.erase(address); + watcher_map_[address].timer_->StopTimer(); + // need erase map } void PingAsync(const std::string &address) { proto::PingRequest request; proto::PingReply reply; - request.set_address(address); + request.set_address(host_address_); grpc::ClientContext context; const int32_t TIME_OUT = 100; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT); context.set_deadline(deadline); (void)watchee_map_[address].stub_->Ping(&context, request, &reply); - MSI_LOG(INFO) << "Finish send ping"; } void PongAsync(const std::string &address) { proto::PongRequest request; proto::PongReply reply; - request.set_address(address); + request.set_address(host_address_); grpc::ClientContext context; const int32_t TIME_OUT = 100; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT); context.set_deadline(deadline); (void)watcher_map_[address].stub_->Pong(&context, request, &reply); - MSI_LOG(INFO) << "Finish send pong"; } private: @@ -188,6 +201,7 @@ class MS_API Watcher { uint64_t max_time_out_ = 10000; // 10s std::unordered_map watchee_map_; std::unordered_map watcher_map_; + std::mutex m_lock_; }; } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.h b/mindspore_serving/ccsrc/worker/grpc/worker_process.h index 0a377b9..4d6129a 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.h @@ -38,16 +38,17 @@ class MSWorkerImpl : public proto::MSWorker::Service { public: explicit MSWorkerImpl(const std::string server_address) { if (!watcher_) { - watcher_ = std::make_shared>(server_address); + watcher_ = std::make_shared>(server_address); } } + grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply) override; grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override; grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override; - std::shared_ptr> watcher_; + std::shared_ptr> watcher_; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index d02d014..575a322 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -143,6 +143,71 @@ class WorkerExitContext : public WorkerServiceContext { proto::ExitReply response_; }; +class WorkerPingContext : public WorkerServiceContext { + public: + WorkerPingContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} + + ~WorkerPingContext() = default; + + static Status EnqueueRequest(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerPingContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestPing(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Ping(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + private: + grpc::ServerAsyncResponseWriter responder_; + proto::PingRequest request_; + proto::PingReply response_; +}; + +class WorkerPongContext : public WorkerServiceContext { + public: + WorkerPongContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} + + ~WorkerPongContext() = default; + + static Status EnqueueRequest(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerPongContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestPong(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Pong(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + private: + grpc::ServerAsyncResponseWriter responder_; + proto::PongRequest request_; + proto::PongReply response_; +}; class WorkerGrpcServer : public GrpcAsyncServer { public: WorkerGrpcServer(const std::string &host, int32_t port, MSWorkerImpl *service_impl) @@ -158,6 +223,8 @@ class WorkerGrpcServer : public GrpcAsyncServer { Status EnqueueRequest() { WorkerPredictContext::EnqueueRequest(service_impl_, &svc_, cq_.get()); WorkerExitContext::EnqueueRequest(service_impl_, &svc_, cq_.get()); + WorkerPingContext::EnqueueRequest(service_impl_, &svc_, cq_.get()); + WorkerPongContext::EnqueueRequest(service_impl_, &svc_, cq_.get()); return SUCCESS; }