Browse Source

!143 fix heartbeat

From: @zhangyinxia
Reviewed-by: @zhoufeng54,@xu-yfei
Signed-off-by: @xu-yfei
tags/v1.2.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d166b35862
3 changed files with 110 additions and 28 deletions
  1. +40
    -26
      mindspore_serving/ccsrc/common/heart_beat.h
  2. +3
    -2
      mindspore_serving/ccsrc/worker/grpc/worker_process.h
  3. +67
    -0
      mindspore_serving/ccsrc/worker/grpc/worker_server.h

+ 40
- 26
mindspore_serving/ccsrc/common/heart_beat.h View File

@@ -39,18 +39,9 @@ using TimerCallback = std::function<void()>;
class MS_API Timer {
public:
Timer() {}
~Timer() { StopTimer(); }
void StartTimer(int64_t millisecond, TimerCallback callback) {
auto timer_run = [this, millisecond, callback]() {
std::unique_lock<std::mutex> lk(cv_m_);
if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) {
callback();
}
};
thread_ = std::thread(timer_run);
}
void StopTimer() {
cv_.notify_one();
~Timer() {
is_stoped_.store(true);
cv_.notify_all();
if (thread_.joinable()) {
try {
thread_.join();
@@ -60,10 +51,27 @@ class MS_API Timer {
}
}

void StartTimer(int64_t millisecond, TimerCallback callback) {
auto timer_run = [this, millisecond, callback]() {
while (!is_stoped_.load()) {
std::unique_lock<std::mutex> lk(cv_m_);
if (cv_.wait_for(lk, std::chrono::milliseconds(millisecond)) == std::cv_status::timeout) {
callback();
}
}
};
thread_ = std::thread(timer_run);
}
// Requests the timer thread to stop: raises the stop flag, then wakes every
// waiter on the condition variable. Does not join the thread.
void StopTimer() {
  is_stoped_ = true;
  cv_.notify_all();
}

private:
std::mutex cv_m_;
std::thread thread_;
std::condition_variable cv_;
std::atomic<bool> is_stoped_ = false;
};

template <class SendStub, class RecvStub>
@@ -74,13 +82,14 @@ class MS_API Watcher {
auto it = watchee_map_.find(address);
if (it != watchee_map_.end()) {
MSI_LOG(INFO) << "watchee exist: " << address;
return;
watchee_map_[address].timer_ = std::make_shared<Timer>();
} else {
WatcheeContext context;
auto channel = GrpcServer::CreateChannel(address);
context.stub_ = SendStub::NewStub(channel);
context.timer_ = std::make_shared<Timer>();
watchee_map_.insert(make_pair(address, context));
}
WatcheeContext context;
auto channel = GrpcServer::CreateChannel(address);
context.stub_ = SendStub::NewStub(channel);
context.timer_ = std::make_shared<Timer>();
watchee_map_.insert(make_pair(address, context));
MSI_LOG(INFO) << "Begin to send ping to " << address;
// add timer
watchee_map_[address].timer_->StartTimer(max_time_out_ / max_ping_times_,
@@ -105,9 +114,11 @@ class MS_API Watcher {
}

void RecvPing(const std::string &address) {
std::unique_lock<std::mutex> lock{m_lock_};
// recv message
if (watcher_map_.count(address)) {
watcher_map_[address].timer_->StopTimer();
watcher_map_[address].timer_ = std::make_shared<Timer>();
} else {
WatcherContext context;
auto channel = GrpcServer::CreateChannel(address);
@@ -116,13 +127,14 @@ class MS_API Watcher {
watcher_map_.insert(make_pair(address, context));
MSI_LOG(INFO) << "Begin to send pong to " << address;
}
// add timer
watcher_map_[address].timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address));
// send async message
PongAsync(address);
// add timer
watcher_map_[address].timer_->StartTimer(max_time_out_, std::bind(&Watcher::RecvPingTimeOut, this, address));
}

void RecvPong(const std::string &address) {
std::unique_lock<std::mutex> lock{m_lock_};
// recv message
if (watchee_map_.count(address)) {
watchee_map_[address].timeouts_ = 0;
@@ -132,10 +144,12 @@ class MS_API Watcher {
}

void RecvPongTimeOut(const std::string &address) {
std::unique_lock<std::mutex> lock{m_lock_};
if (watchee_map_[address].timeouts_ >= max_ping_times_) {
// add exit handle
MSI_LOG(INFO) << "Recv Pong Time Out from " << address;
watchee_map_.erase(address);
watchee_map_[address].timer_->StopTimer();
// need erase map
return;
}
SendPing(address);
@@ -144,32 +158,31 @@ class MS_API Watcher {
void RecvPingTimeOut(const std::string &address) {
MSI_LOG(INFO) << "Recv Ping Time Out from " << address;
// add exit handle
watcher_map_.erase(address);
watcher_map_[address].timer_->StopTimer();
// need erase map
}
void PingAsync(const std::string &address) {
proto::PingRequest request;
proto::PingReply reply;
request.set_address(address);
request.set_address(host_address_);
grpc::ClientContext context;
const int32_t TIME_OUT = 100;
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT);
context.set_deadline(deadline);
(void)watchee_map_[address].stub_->Ping(&context, request, &reply);
MSI_LOG(INFO) << "Finish send ping";
}

void PongAsync(const std::string &address) {
proto::PongRequest request;
proto::PongReply reply;
request.set_address(address);
request.set_address(host_address_);
grpc::ClientContext context;
const int32_t TIME_OUT = 100;
std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::microseconds(TIME_OUT);
context.set_deadline(deadline);
(void)watcher_map_[address].stub_->Pong(&context, request, &reply);
MSI_LOG(INFO) << "Finish send pong";
}

private:
@@ -188,6 +201,7 @@ class MS_API Watcher {
uint64_t max_time_out_ = 10000; // 10s
std::unordered_map<std::string, WatcheeContext> watchee_map_;
std::unordered_map<std::string, WatcherContext> watcher_map_;
std::mutex m_lock_;
};
} // namespace mindspore::serving



+ 3
- 2
mindspore_serving/ccsrc/worker/grpc/worker_process.h View File

@@ -38,16 +38,17 @@ class MSWorkerImpl : public proto::MSWorker::Service {
public:
explicit MSWorkerImpl(const std::string server_address) {
if (!watcher_) {
watcher_ = std::make_shared<Watcher<proto::MSMaster, proto::MSAgent>>(server_address);
watcher_ = std::make_shared<Watcher<proto::MSAgent, proto::MSMaster>>(server_address);
}
}
grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply) override;
grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override;
grpc::Status Ping(grpc::ServerContext *context, const proto::PingRequest *request, proto::PingReply *reply) override;
grpc::Status Pong(grpc::ServerContext *context, const proto::PongRequest *request, proto::PongReply *reply) override;
std::shared_ptr<Watcher<proto::MSMaster, proto::MSAgent>> watcher_;
std::shared_ptr<Watcher<proto::MSAgent, proto::MSMaster>> watcher_;
};
} // namespace serving


+ 67
- 0
mindspore_serving/ccsrc/worker/grpc/worker_server.h View File

@@ -143,6 +143,71 @@ class WorkerExitContext : public WorkerServiceContext {
proto::ExitReply response_;
};
class WorkerPingContext : public WorkerServiceContext {
public:
WorkerPingContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerPingContext() = default;
static Status EnqueueRequest(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerPingContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestPing(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Ping(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
grpc::ServerAsyncResponseWriter<proto::PingReply> responder_;
proto::PingRequest request_;
proto::PingReply response_;
};
class WorkerPongContext : public WorkerServiceContext {
public:
WorkerPongContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerPongContext() = default;
static Status EnqueueRequest(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerPongContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestPong(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Pong(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
grpc::ServerAsyncResponseWriter<proto::PongReply> responder_;
proto::PongRequest request_;
proto::PongReply response_;
};
class WorkerGrpcServer : public GrpcAsyncServer {
public:
WorkerGrpcServer(const std::string &host, int32_t port, MSWorkerImpl *service_impl)
@@ -158,6 +223,8 @@ class WorkerGrpcServer : public GrpcAsyncServer {
// Registers one pending context for each RPC type handled by the worker
// (Predict, Exit, Ping, Pong) on the server's completion queue.
// The per-context results are not inspected; always returns SUCCESS.
Status EnqueueRequest() {
  auto queue = cq_.get();
  auto service = &svc_;
  WorkerPredictContext::EnqueueRequest(service_impl_, service, queue);
  WorkerExitContext::EnqueueRequest(service_impl_, service, queue);
  WorkerPingContext::EnqueueRequest(service_impl_, service, queue);
  WorkerPongContext::EnqueueRequest(service_impl_, service, queue);
  return SUCCESS;
}


Loading…
Cancel
Save