From 620bc494b47c023890d26fa8bfd6e66690bbfee1 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Thu, 28 Jan 2021 21:23:02 +0800 Subject: [PATCH] Serving, commbile distributed worker and local worker in grpc server process --- mindspore_serving/ccsrc/common/exit_handle.cc | 13 ++ mindspore_serving/ccsrc/common/exit_handle.h | 2 + .../ccsrc/python/worker/worker_py.cc | 27 +++- .../distributed_process.cc | 2 +- .../distributed_process.h | 3 +- .../grpc/distributed_server.cc | 38 +++++ .../grpc/distributed_server.h | 147 ++++++++++++++++++ .../notify_distributed/notify_worker.cc | 10 +- .../notify_distributed/notify_worker.h | 12 +- .../ccsrc/worker/grpc/worker_process.cc | 1 - .../ccsrc/worker/grpc/worker_process.h | 2 +- .../ccsrc/worker/grpc/worker_server.cc | 18 ++- .../ccsrc/worker/grpc/worker_server.h | 22 +-- .../ccsrc/worker/work_executor.h | 6 +- mindspore_serving/ccsrc/worker/worker.cc | 17 +- mindspore_serving/ccsrc/worker/worker.h | 5 +- mindspore_serving/proto/ms_distributed.proto | 5 - mindspore_serving/proto/ms_worker.proto | 5 + 18 files changed, 278 insertions(+), 57 deletions(-) rename mindspore_serving/ccsrc/worker/distributed_worker/{distributed_process => grpc}/distributed_process.cc (96%) rename mindspore_serving/ccsrc/worker/distributed_worker/{distributed_process => grpc}/distributed_process.h (92%) create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h diff --git a/mindspore_serving/ccsrc/common/exit_handle.cc b/mindspore_serving/ccsrc/common/exit_handle.cc index 3b97c21..88d9644 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.cc +++ b/mindspore_serving/ccsrc/common/exit_handle.cc @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { exit_future.wait(); } +// waiting ctrl+c or stop message to exit, +// if no server is running or server has exited, there is no need to wait +void ExitSignalHandle::AgentWait() { + if (!is_running_) { + MSI_LOG_INFO << "Exit Handle has not started or has exited"; + return; + } + auto exit_future = agent_exit_requested_.get_future(); + exit_future.wait(); +} + void ExitSignalHandle::Start() { if (is_running_) { return; @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { is_running_ = true; master_exit_requested_ = std::promise(); worker_exit_requested_ = std::promise(); + agent_exit_requested_ = std::promise(); has_exited_.clear(); InitSignalHandle(); } @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { if (!has_exited_.test_and_set()) { master_exit_requested_.set_value(); worker_exit_requested_.set_value(); + agent_exit_requested_.set_value(); is_running_ = false; } } diff --git a/mindspore_serving/ccsrc/common/exit_handle.h b/mindspore_serving/ccsrc/common/exit_handle.h index 66a2fd2..42654c6 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.h +++ b/mindspore_serving/ccsrc/common/exit_handle.h @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { void InitSignalHandle(); void MasterWait(); void WorkerWait(); + void AgentWait(); void Start(); void Stop(); bool HasStopped(); @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { private: std::promise master_exit_requested_; std::promise worker_exit_requested_; + std::promise agent_exit_requested_; std::atomic_flag has_exited_ = true; std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; std::atomic_bool is_running_ = false; diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index dce0b70..faa0f31 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -23,13 +23,15 @@ #include "worker/notfiy_master/local_notify.h" #include "worker/local_servable/local_sevable.h" #include "worker/distributed_worker/distributed_servable.h" +#include "worker/grpc/worker_server.h" +#include "worker/distributed_worker/grpc/distributed_server.h" namespace mindspore::serving { void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, - const std::string &master_ip, uint32_t master_port, const std::string &host_ip, - uint32_t host_port) { - auto notify_master = std::make_shared(master_ip, master_port, host_ip, host_port); + const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, + uint32_t worker_port) { + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); auto servable = std::make_shared(); auto status = servable->StartServable(model_directory, model_name, version_number); if (status != SUCCESS) { @@ -39,10 +41,14 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); + // start grpc server + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); + status = Worker::GetInstance().StartVersionController(); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -72,12 +78,15 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, uint32_t master_port) { Status status; - status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); - auto servable = std::make_shared(); status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -96,13 +105,15 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port) { Status status; - status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(); - auto servable = std::make_shared(); status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc similarity index 96% rename from mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc index 72442cc..0333434 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/distributed_worker/grpc/distributed_process.h" #include "common/proto_tensor.h" namespace mindspore { diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h similarity index 92% rename from mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h rename to mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h index 3ef02b2..b127ac7 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h @@ -27,12 +27,13 @@ #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" #include "worker/distributed_worker/distributed_servable.h" +#include "worker/grpc/worker_process.h" namespace mindspore { namespace serving { // Service Implement -class MSDistributedImpl final : public proto::MSDistributedWorker::Service { +class MSDistributedImpl final : public MSWorkerImpl { public: explicit MSDistributedImpl(std::shared_ptr servable) : servable_(servable) {} ~MSDistributedImpl() = default; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc new file mode 100644 index 0000000..79d4064 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/grpc/distributed_server.h" +#include +#include +#include +#include "common/grpc_server.h" + +namespace mindspore { +namespace serving { + +Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr servable, + const std::string &hostname, int32_t port) { + if (in_running_) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; + } + auto impl = std::make_unique(servable); + async_server_ = std::make_unique(hostname, port, impl.get()); + service_impl_ = std::move(impl); + return Init(); +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h new file mode 100644 index 0000000..2151a41 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H +#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H + +#include +#include +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" +#include "common/grpc_async_server.h" +#include "worker/grpc/worker_process.h" +#include "worker/grpc/worker_server.h" +#include "worker/distributed_worker/grpc/distributed_process.h" + +namespace mindspore { +namespace serving { + +// Service Implement +class MS_API MSDistributedWorkerServer : public MSWorkerServer { + public: + Status StartDistributedWorkerGrpcServer(std::shared_ptr servable, const std::string &hostname, + int32_t port); +}; + +// Service Implement +class WorkerAgentRegisterContext : public WorkerServiceContext { + public: + WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { + state_ = STATE::CREATE; + } + + ~WorkerAgentRegisterContext() = default; + + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + bool JudgeFinish() override { return state_ == STATE::FINISH; } + + private: + MSDistributedImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; + proto::PredictRequest request_; + proto::PredictReply response_; +}; + +class WorkerAgentExitContext : public WorkerServiceContext { + public: + WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { + state_ = STATE::CREATE; + } + + ~WorkerAgentExitContext() = default; + + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentExitContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + bool JudgeFinish() override { return state_ == STATE::FINISH; } + + private: + MSDistributedImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; + proto::ExitRequest request_; + proto::ExitReply response_; +}; + +class DistributedWorkerGrpcServer : public WorkerGrpcServer { + public: + DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl) + : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {} + + ~DistributedWorkerGrpcServer() = default; + + Status EnqueueRequest() { + WorkerGrpcServer::EnqueueRequest(); + WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + return SUCCESS; + } + + private: + MSDistributedImpl *distributed_service_impl_; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index 230e225..d9e6b73 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace serving { -GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip, +GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip, uint32_t distributed_worker_port, const std::string &host_ip, uint32_t host_port) : distributed_worker_ip_(distributed_worker_ip), @@ -35,12 +35,12 @@ GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distri distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); agent_address_ = host_ip_ + ":" + std::to_string(host_port_); auto channel = GrpcServer::CreateChannel(distributed_worker_address_); - stub_ = proto::MSDistributedWorker::NewStub(channel); + stub_ = proto::MSWorker::NewStub(channel); } -GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default; +GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; -Status GrpcNotfiyDistributeWorker::Register(const std::vector &worker_specs) { +Status GrpcNotifyDistributeWorker::Register(const std::vector &worker_specs) { const int32_t REGISTER_TIME_OUT = 60; const int32_t REGISTER_INTERVAL = 1; auto loop = REGISTER_TIME_OUT; @@ -67,7 +67,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector & return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; } -Status GrpcNotfiyDistributeWorker::Unregister() { +Status GrpcNotifyDistributeWorker::Unregister() { if (is_stoped_.load()) { return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index d618878..2c2724c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -20,18 +20,18 @@ #include #include #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" -#include "proto/ms_master.pb.h" -#include "proto/ms_master.grpc.pb.h" #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" namespace mindspore { namespace serving { -class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { +class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { public: - GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, + GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, uint32_t host_port); - ~GrpcNotfiyDistributeWorker() override; + ~GrpcNotifyDistributeWorker() override; Status Register(const std::vector &worker_specs) override; Status Unregister() override; @@ -42,7 +42,7 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { uint32_t host_port_; std::string agent_address_; std::string distributed_worker_address_; - std::unique_ptr stub_; + std::unique_ptr stub_; std::atomic is_stoped_{false}; }; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc index 73c38d1..2d41b03 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc @@ -15,7 +15,6 @@ */ #include "worker/grpc/worker_process.h" -#include "master/dispacther.h" #include "worker/worker.h" namespace mindspore { diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.h b/mindspore_serving/ccsrc/worker/grpc/worker_process.h index 450158e..ebdb3c5 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.h @@ -28,7 +28,7 @@ namespace mindspore { namespace serving { // Service Implement -class MSWorkerImpl final : public proto::MSWorker::Service { +class MSWorkerImpl : public proto::MSWorker::Service { public: grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply) override; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc index cc603ad..58880df 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc @@ -21,12 +21,20 @@ namespace mindspore { namespace serving { + MSWorkerServer::~MSWorkerServer() { Stop(); } -MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { +Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { + if (in_running_) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; + } service_impl_ = std::make_unique(); async_server_ = std::make_unique(hostname, port, service_impl_.get()); + return Init(); } + +MSWorkerServer::MSWorkerServer() = default; + Status MSWorkerServer::Init() { Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); if (status != SUCCESS) return status; @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { return status; } Status MSWorkerServer::Stop() { - if (in_running_) { + if (in_running_ && async_server_) { async_server_->Stop(); - grpc_thread_.join(); + if (grpc_thread_.joinable()) { + grpc_thread_.join(); + } } + async_server_ = nullptr; + service_impl_ = nullptr; in_running_ = false; return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 8bcc057..1452727 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -27,27 +27,29 @@ #include "proto/ms_worker.grpc.pb.h" #include "common/grpc_async_server.h" #include "worker/grpc/worker_process.h" +#include "worker/distributed_worker/distributed_servable.h" namespace mindspore { namespace serving { // Service Implement -class MSWorkerServer { +class MS_API MSWorkerServer { public: enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; - MSWorkerServer(const std::string &hostname, int32_t port); - ~MSWorkerServer(); - - Status Init(); + MSWorkerServer(); + virtual ~MSWorkerServer(); + Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); Status Stop(); - Status StartAsyncRpcService(); - + protected: bool in_running_ = false; std::thread grpc_thread_; - std::unique_ptr service_impl_; - std::unique_ptr async_server_; + std::unique_ptr service_impl_ = nullptr; + std::unique_ptr async_server_ = nullptr; + + Status StartAsyncRpcService(); + Status Init(); }; class WorkerServiceContext { @@ -174,7 +176,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { return SUCCESS; } - private: + protected: MSWorkerImpl *service_impl_; proto::MSWorker::AsyncService svc_; }; diff --git a/mindspore_serving/ccsrc/worker/work_executor.h b/mindspore_serving/ccsrc/worker/work_executor.h index 2c44e6a..d491843 100644 --- a/mindspore_serving/ccsrc/worker/work_executor.h +++ b/mindspore_serving/ccsrc/worker/work_executor.h @@ -39,10 +39,8 @@ using WorkCallBack = std::function py_preprocess_task_queue, - std::shared_ptr py_postprocess_task_queue, - std::shared_ptr cpp_preprocess_task_queue, - std::shared_ptr cpp_postprocess_task_queue); + WorkExecutor(std::shared_ptr py_preprocess, std::shared_ptr py_postprocess, + std::shared_ptr cpp_preprocess, std::shared_ptr cpp_postprocess); ~WorkExecutor(); Status Init(const ServableSignature &servable_declare, const std::shared_ptr &servable); diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index fed5d42..87167ef 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -34,21 +34,11 @@ namespace py = pybind11; namespace mindspore { namespace serving { -static std::unique_ptr grpc_async_worker_server_; - Worker &Worker::GetInstance() { static Worker instance; return instance; } -Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { - if (grpc_async_worker_server_ != nullptr) { - return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; - } - grpc_async_worker_server_ = std::make_unique(ip, grpc_port); - return grpc_async_worker_server_->Init(); -} - Status Worker::RegisterWorker() { std::vector worker_specs; for (auto &work : work_list_) { @@ -184,6 +174,11 @@ void Worker::Update() { */ } +Status Worker::AfterStartGrpcServer(const std::shared_ptr &grpc_server) { + worker_grpc_server_ = grpc_server; + return SUCCESS; +} + Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (servable_started_) { @@ -244,7 +239,7 @@ void Worker::StopServable(bool notify_master) { void Worker::Clear() { std::unique_lock lock(worker_shared_lock_); ServableStorage::Instance().Clear(); - grpc_async_worker_server_ = nullptr; + worker_grpc_server_ = nullptr; if (clear_flag_.test_and_set()) { return; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 122b7d4..50d95eb 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -33,6 +33,7 @@ #include "worker/version_control/version_controller.h" #include "common/grpc_async_server.h" #include "worker/sevable_base.h" +#include "worker/grpc/worker_server.h" namespace mindspore { namespace serving { @@ -74,10 +75,11 @@ class MS_API Worker { const std::vector &inputs); Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); + Status AfterStartGrpcServer(const std::shared_ptr &grpc_server); + void StopServable(bool notify_master = true); bool HasCleared(); Status RegisterWorker(); - Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); void Update(); Status StartVersionController(); Status AddWorker(const ServableWorkerContext &work); @@ -101,6 +103,7 @@ class MS_API Worker { std::atomic_bool servable_started_ = false; std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; std::shared_ptr notify_master_ = nullptr; + std::shared_ptr worker_grpc_server_ = nullptr; std::shared_mutex worker_shared_lock_; diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index c7be82f..13b13d5 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -45,8 +45,3 @@ message AgentExitRequest { message AgentExitReply { ErrorMsg error_msg = 1; } - -service MSDistributedWorker { - rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} - rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} -} \ No newline at end of file diff --git a/mindspore_serving/proto/ms_worker.proto b/mindspore_serving/proto/ms_worker.proto index c9ed051..436b52f 100644 --- a/mindspore_serving/proto/ms_worker.proto +++ b/mindspore_serving/proto/ms_worker.proto @@ -20,8 +20,13 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; import "mindspore_serving/proto/ms_master.proto"; +import "mindspore_serving/proto/ms_distributed.proto"; service MSWorker { + // for master rpc Predict(PredictRequest) returns (PredictReply) {} rpc Exit(ExitRequest) returns (ExitReply) {} + // for worker agent + rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} + rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} }