| @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { | |||
| exit_future.wait(); | |||
| } | |||
| // waiting ctrl+c or stop message to exit, | |||
| // if no server is running or server has exited, there is no need to wait | |||
| void ExitSignalHandle::AgentWait() { | |||
| if (!is_running_) { | |||
| MSI_LOG_INFO << "Exit Handle has not started or has exited"; | |||
| return; | |||
| } | |||
| auto exit_future = agent_exit_requested_.get_future(); | |||
| exit_future.wait(); | |||
| } | |||
| void ExitSignalHandle::Start() { | |||
| if (is_running_) { | |||
| return; | |||
| @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { | |||
| is_running_ = true; | |||
| master_exit_requested_ = std::promise<void>(); | |||
| worker_exit_requested_ = std::promise<void>(); | |||
| agent_exit_requested_ = std::promise<void>(); | |||
| has_exited_.clear(); | |||
| InitSignalHandle(); | |||
| } | |||
| @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { | |||
| if (!has_exited_.test_and_set()) { | |||
| master_exit_requested_.set_value(); | |||
| worker_exit_requested_.set_value(); | |||
| agent_exit_requested_.set_value(); | |||
| is_running_ = false; | |||
| } | |||
| } | |||
| @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { | |||
| void InitSignalHandle(); | |||
| void MasterWait(); | |||
| void WorkerWait(); | |||
| void AgentWait(); | |||
| void Start(); | |||
| void Stop(); | |||
| bool HasStopped(); | |||
| @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { | |||
| private: | |||
| std::promise<void> master_exit_requested_; | |||
| std::promise<void> worker_exit_requested_; | |||
| std::promise<void> agent_exit_requested_; | |||
| std::atomic_flag has_exited_ = true; | |||
| std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; | |||
| std::atomic_bool is_running_ = false; | |||
| @@ -23,13 +23,15 @@ | |||
| #include "worker/notfiy_master/local_notify.h" | |||
| #include "worker/local_servable/local_sevable.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||
| namespace mindspore::serving { | |||
| void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, | |||
| const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||
| uint32_t host_port) { | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, host_ip, host_port); | |||
| const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, | |||
| uint32_t worker_port) { | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||
| auto servable = std::make_shared<LocalModelServable>(); | |||
| auto status = servable->StartServable(model_directory, model_name, version_number); | |||
| if (status != SUCCESS) { | |||
| @@ -39,10 +41,14 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); | |||
| // start grpc server | |||
| auto grpc_sever = std::make_shared<MSWorkerServer>(); | |||
| status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| status = Worker::GetInstance().StartVersionController(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| @@ -72,12 +78,15 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| const std::string &master_ip, uint32_t master_port) { | |||
| Status status; | |||
| status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| @@ -96,13 +105,15 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port) { | |||
| Status status; | |||
| status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| @@ -27,12 +27,13 @@ | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "proto/ms_distributed.grpc.pb.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSDistributedImpl final : public proto::MSDistributedWorker::Service { | |||
| class MSDistributedImpl final : public MSWorkerImpl { | |||
| public: | |||
| explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {} | |||
| ~MSDistributedImpl() = default; | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include "common/grpc_server.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, | |||
| const std::string &hostname, int32_t port) { | |||
| if (in_running_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||
| } | |||
| auto impl = std::make_unique<MSDistributedImpl>(servable); | |||
| async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get()); | |||
| service_impl_ = std::move(impl); | |||
| return Init(); | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,147 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| #define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| #include <grpcpp/grpcpp.h> | |||
| #include <grpcpp/health_check_service_interface.h> | |||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "common/serving_common.h" | |||
| #include "proto/ms_worker.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MS_API MSDistributedWorkerServer : public MSWorkerServer { | |||
| public: | |||
| Status StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, const std::string &hostname, | |||
| int32_t port); | |||
| }; | |||
| // Service Implement | |||
| class WorkerAgentRegisterContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| ~WorkerAgentRegisterContext() = default; | |||
| static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) { | |||
| auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq); | |||
| call->StartEnqueueRequest(); | |||
| return SUCCESS; | |||
| } | |||
| void StartEnqueueRequest() override { | |||
| state_ = STATE::PROCESS; | |||
| async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSDistributedImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_; | |||
| proto::PredictRequest request_; | |||
| proto::PredictReply response_; | |||
| }; | |||
| class WorkerAgentExitContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| ~WorkerAgentExitContext() = default; | |||
| static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) { | |||
| auto call = new WorkerAgentExitContext(service_impl, async_service, cq); | |||
| call->StartEnqueueRequest(); | |||
| return SUCCESS; | |||
| } | |||
| void StartEnqueueRequest() override { | |||
| state_ = STATE::PROCESS; | |||
| async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSDistributedImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_; | |||
| proto::ExitRequest request_; | |||
| proto::ExitReply response_; | |||
| }; | |||
| class DistributedWorkerGrpcServer : public WorkerGrpcServer { | |||
| public: | |||
| DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl) | |||
| : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {} | |||
| ~DistributedWorkerGrpcServer() = default; | |||
| Status EnqueueRequest() { | |||
| WorkerGrpcServer::EnqueueRequest(); | |||
| WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||
| WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||
| return SUCCESS; | |||
| } | |||
| private: | |||
| MSDistributedImpl *distributed_service_impl_; | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip, | |||
| GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip, | |||
| uint32_t distributed_worker_port, const std::string &host_ip, | |||
| uint32_t host_port) | |||
| : distributed_worker_ip_(distributed_worker_ip), | |||
| @@ -35,12 +35,12 @@ GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distri | |||
| distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); | |||
| agent_address_ = host_ip_ + ":" + std::to_string(host_port_); | |||
| auto channel = GrpcServer::CreateChannel(distributed_worker_address_); | |||
| stub_ = proto::MSDistributedWorker::NewStub(channel); | |||
| stub_ = proto::MSWorker::NewStub(channel); | |||
| } | |||
| GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default; | |||
| GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; | |||
| Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) { | |||
| Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) { | |||
| const int32_t REGISTER_TIME_OUT = 60; | |||
| const int32_t REGISTER_INTERVAL = 1; | |||
| auto loop = REGISTER_TIME_OUT; | |||
| @@ -67,7 +67,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> & | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; | |||
| } | |||
| Status GrpcNotfiyDistributeWorker::Unregister() { | |||
| Status GrpcNotifyDistributeWorker::Unregister() { | |||
| if (is_stoped_.load()) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -20,18 +20,18 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" | |||
| #include "proto/ms_master.pb.h" | |||
| #include "proto/ms_master.grpc.pb.h" | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "proto/ms_distributed.grpc.pb.h" | |||
| #include "proto/ms_worker.pb.h" | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { | |||
| class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { | |||
| public: | |||
| GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||
| GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||
| uint32_t host_port); | |||
| ~GrpcNotfiyDistributeWorker() override; | |||
| ~GrpcNotifyDistributeWorker() override; | |||
| Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override; | |||
| Status Unregister() override; | |||
| @@ -42,7 +42,7 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { | |||
| uint32_t host_port_; | |||
| std::string agent_address_; | |||
| std::string distributed_worker_address_; | |||
| std::unique_ptr<proto::MSDistributedWorker::Stub> stub_; | |||
| std::unique_ptr<proto::MSWorker::Stub> stub_; | |||
| std::atomic<bool> is_stoped_{false}; | |||
| }; | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "master/dispacther.h" | |||
| #include "worker/worker.h" | |||
| namespace mindspore { | |||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSWorkerImpl final : public proto::MSWorker::Service { | |||
| class MSWorkerImpl : public proto::MSWorker::Service { | |||
| public: | |||
| grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | |||
| proto::PredictReply *reply) override; | |||
| @@ -21,12 +21,20 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| MSWorkerServer::~MSWorkerServer() { Stop(); } | |||
| MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { | |||
| Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { | |||
| if (in_running_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||
| } | |||
| service_impl_ = std::make_unique<MSWorkerImpl>(); | |||
| async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | |||
| return Init(); | |||
| } | |||
| MSWorkerServer::MSWorkerServer() = default; | |||
| Status MSWorkerServer::Init() { | |||
| Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); | |||
| if (status != SUCCESS) return status; | |||
| @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { | |||
| return status; | |||
| } | |||
| Status MSWorkerServer::Stop() { | |||
| if (in_running_) { | |||
| if (in_running_ && async_server_) { | |||
| async_server_->Stop(); | |||
| grpc_thread_.join(); | |||
| if (grpc_thread_.joinable()) { | |||
| grpc_thread_.join(); | |||
| } | |||
| } | |||
| async_server_ = nullptr; | |||
| service_impl_ = nullptr; | |||
| in_running_ = false; | |||
| return SUCCESS; | |||
| } | |||
| @@ -27,27 +27,29 @@ | |||
| #include "proto/ms_worker.grpc.pb.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| // Service Implement | |||
| class MSWorkerServer { | |||
| class MS_API MSWorkerServer { | |||
| public: | |||
| enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | |||
| MSWorkerServer(const std::string &hostname, int32_t port); | |||
| ~MSWorkerServer(); | |||
| Status Init(); | |||
| MSWorkerServer(); | |||
| virtual ~MSWorkerServer(); | |||
| Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); | |||
| Status Stop(); | |||
| Status StartAsyncRpcService(); | |||
| protected: | |||
| bool in_running_ = false; | |||
| std::thread grpc_thread_; | |||
| std::unique_ptr<MSWorkerImpl> service_impl_; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_; | |||
| std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr; | |||
| Status StartAsyncRpcService(); | |||
| Status Init(); | |||
| }; | |||
| class WorkerServiceContext { | |||
| @@ -174,7 +176,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { | |||
| return SUCCESS; | |||
| } | |||
| private: | |||
| protected: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService svc_; | |||
| }; | |||
| @@ -39,10 +39,8 @@ using WorkCallBack = std::function<void(const Instance &output, const Status &er | |||
| class WorkExecutor { | |||
| public: | |||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> py_postprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> cpp_preprocess_task_queue, | |||
| std::shared_ptr<TaskQueue> cpp_postprocess_task_queue); | |||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess, std::shared_ptr<TaskQueue> py_postprocess, | |||
| std::shared_ptr<TaskQueue> cpp_preprocess, std::shared_ptr<TaskQueue> cpp_postprocess); | |||
| ~WorkExecutor(); | |||
| Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable); | |||
| @@ -34,21 +34,11 @@ namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace serving { | |||
| static std::unique_ptr<MSWorkerServer> grpc_async_worker_server_; | |||
| Worker &Worker::GetInstance() { | |||
| static Worker instance; | |||
| return instance; | |||
| } | |||
| Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { | |||
| if (grpc_async_worker_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; | |||
| } | |||
| grpc_async_worker_server_ = std::make_unique<MSWorkerServer>(ip, grpc_port); | |||
| return grpc_async_worker_server_->Init(); | |||
| } | |||
| Status Worker::RegisterWorker() { | |||
| std::vector<WorkerSpec> worker_specs; | |||
| for (auto &work : work_list_) { | |||
| @@ -184,6 +174,11 @@ void Worker::Update() { | |||
| */ | |||
| } | |||
| Status Worker::AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server) { | |||
| worker_grpc_server_ = grpc_server; | |||
| return SUCCESS; | |||
| } | |||
| Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| if (servable_started_) { | |||
| @@ -244,7 +239,7 @@ void Worker::StopServable(bool notify_master) { | |||
| void Worker::Clear() { | |||
| std::unique_lock<std::shared_mutex> lock(worker_shared_lock_); | |||
| ServableStorage::Instance().Clear(); | |||
| grpc_async_worker_server_ = nullptr; | |||
| worker_grpc_server_ = nullptr; | |||
| if (clear_flag_.test_and_set()) { | |||
| return; | |||
| } | |||
| @@ -33,6 +33,7 @@ | |||
| #include "worker/version_control/version_controller.h" | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/sevable_base.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -74,10 +75,11 @@ class MS_API Worker { | |||
| const std::vector<InstanceData> &inputs); | |||
| Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | |||
| Status AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server); | |||
| void StopServable(bool notify_master = true); | |||
| bool HasCleared(); | |||
| Status RegisterWorker(); | |||
| Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); | |||
| void Update(); | |||
| Status StartVersionController(); | |||
| Status AddWorker(const ServableWorkerContext &work); | |||
| @@ -101,6 +103,7 @@ class MS_API Worker { | |||
| std::atomic_bool servable_started_ = false; | |||
| std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; | |||
| std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr; | |||
| std::shared_ptr<MSWorkerServer> worker_grpc_server_ = nullptr; | |||
| std::shared_mutex worker_shared_lock_; | |||
| @@ -45,8 +45,3 @@ message AgentExitRequest { | |||
| message AgentExitReply { | |||
| ErrorMsg error_msg = 1; | |||
| } | |||
| service MSDistributedWorker { | |||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||
| } | |||
| @@ -20,8 +20,13 @@ syntax = "proto3"; | |||
| package mindspore.serving.proto; | |||
| import "mindspore_serving/proto/ms_service.proto"; | |||
| import "mindspore_serving/proto/ms_master.proto"; | |||
| import "mindspore_serving/proto/ms_distributed.proto"; | |||
| service MSWorker { | |||
| // for master | |||
| rpc Predict(PredictRequest) returns (PredictReply) {} | |||
| rpc Exit(ExitRequest) returns (ExitReply) {} | |||
| // for worker agent | |||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||
| } | |||