| @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { | |||||
| exit_future.wait(); | exit_future.wait(); | ||||
| } | } | ||||
| // waiting ctrl+c or stop message to exit, | |||||
| // if no server is running or server has exited, there is no need to wait | |||||
| void ExitSignalHandle::AgentWait() { | |||||
| if (!is_running_) { | |||||
| MSI_LOG_INFO << "Exit Handle has not started or has exited"; | |||||
| return; | |||||
| } | |||||
| auto exit_future = agent_exit_requested_.get_future(); | |||||
| exit_future.wait(); | |||||
| } | |||||
| void ExitSignalHandle::Start() { | void ExitSignalHandle::Start() { | ||||
| if (is_running_) { | if (is_running_) { | ||||
| return; | return; | ||||
| @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { | |||||
| is_running_ = true; | is_running_ = true; | ||||
| master_exit_requested_ = std::promise<void>(); | master_exit_requested_ = std::promise<void>(); | ||||
| worker_exit_requested_ = std::promise<void>(); | worker_exit_requested_ = std::promise<void>(); | ||||
| agent_exit_requested_ = std::promise<void>(); | |||||
| has_exited_.clear(); | has_exited_.clear(); | ||||
| InitSignalHandle(); | InitSignalHandle(); | ||||
| } | } | ||||
| @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { | |||||
| if (!has_exited_.test_and_set()) { | if (!has_exited_.test_and_set()) { | ||||
| master_exit_requested_.set_value(); | master_exit_requested_.set_value(); | ||||
| worker_exit_requested_.set_value(); | worker_exit_requested_.set_value(); | ||||
| agent_exit_requested_.set_value(); | |||||
| is_running_ = false; | is_running_ = false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { | |||||
| void InitSignalHandle(); | void InitSignalHandle(); | ||||
| void MasterWait(); | void MasterWait(); | ||||
| void WorkerWait(); | void WorkerWait(); | ||||
| void AgentWait(); | |||||
| void Start(); | void Start(); | ||||
| void Stop(); | void Stop(); | ||||
| bool HasStopped(); | bool HasStopped(); | ||||
| @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { | |||||
| private: | private: | ||||
| std::promise<void> master_exit_requested_; | std::promise<void> master_exit_requested_; | ||||
| std::promise<void> worker_exit_requested_; | std::promise<void> worker_exit_requested_; | ||||
| std::promise<void> agent_exit_requested_; | |||||
| std::atomic_flag has_exited_ = true; | std::atomic_flag has_exited_ = true; | ||||
| std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; | std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; | ||||
| std::atomic_bool is_running_ = false; | std::atomic_bool is_running_ = false; | ||||
| @@ -23,13 +23,15 @@ | |||||
| #include "worker/notfiy_master/local_notify.h" | #include "worker/notfiy_master/local_notify.h" | ||||
| #include "worker/local_servable/local_sevable.h" | #include "worker/local_servable/local_sevable.h" | ||||
| #include "worker/distributed_worker/distributed_servable.h" | #include "worker/distributed_worker/distributed_servable.h" | ||||
| #include "worker/grpc/worker_server.h" | |||||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||||
| namespace mindspore::serving { | namespace mindspore::serving { | ||||
| void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, | void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, | ||||
| const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||||
| uint32_t host_port) { | |||||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, host_ip, host_port); | |||||
| const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, | |||||
| uint32_t worker_port) { | |||||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||||
| auto servable = std::make_shared<LocalModelServable>(); | auto servable = std::make_shared<LocalModelServable>(); | ||||
| auto status = servable->StartServable(model_directory, model_name, version_number); | auto status = servable->StartServable(model_directory, model_name, version_number); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| @@ -39,10 +41,14 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| } | } | ||||
| status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); | |||||
| // start grpc server | |||||
| auto grpc_sever = std::make_shared<MSWorkerServer>(); | |||||
| status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| } | } | ||||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||||
| status = Worker::GetInstance().StartVersionController(); | status = Worker::GetInstance().StartVersionController(); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| @@ -72,12 +78,15 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c | |||||
| const std::string &worker_ip, uint32_t worker_port, | const std::string &worker_ip, uint32_t worker_port, | ||||
| const std::string &master_ip, uint32_t master_port) { | const std::string &master_ip, uint32_t master_port) { | ||||
| Status status; | Status status; | ||||
| status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); | |||||
| auto servable = std::make_shared<DistributedServable>(); | |||||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| } | } | ||||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | ||||
| auto servable = std::make_shared<DistributedServable>(); | |||||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| @@ -96,13 +105,15 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire | |||||
| const std::string &rank_table_json_file, uint32_t version_number, | const std::string &rank_table_json_file, uint32_t version_number, | ||||
| const std::string &worker_ip, uint32_t worker_port) { | const std::string &worker_ip, uint32_t worker_port) { | ||||
| Status status; | Status status; | ||||
| status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); | |||||
| auto servable = std::make_shared<DistributedServable>(); | |||||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| } | } | ||||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | auto notify_master = std::make_shared<LocalNotifyMaster>(); | ||||
| auto servable = std::make_shared<DistributedServable>(); | |||||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||||
| #include "common/proto_tensor.h" | #include "common/proto_tensor.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -27,12 +27,13 @@ | |||||
| #include "proto/ms_distributed.pb.h" | #include "proto/ms_distributed.pb.h" | ||||
| #include "proto/ms_distributed.grpc.pb.h" | #include "proto/ms_distributed.grpc.pb.h" | ||||
| #include "worker/distributed_worker/distributed_servable.h" | #include "worker/distributed_worker/distributed_servable.h" | ||||
| #include "worker/grpc/worker_process.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| // Service Implement | // Service Implement | ||||
| class MSDistributedImpl final : public proto::MSDistributedWorker::Service { | |||||
| class MSDistributedImpl final : public MSWorkerImpl { | |||||
| public: | public: | ||||
| explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {} | explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {} | ||||
| ~MSDistributedImpl() = default; | ~MSDistributedImpl() = default; | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "common/grpc_server.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
| Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, | |||||
| const std::string &hostname, int32_t port) { | |||||
| if (in_running_) { | |||||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||||
| } | |||||
| auto impl = std::make_unique<MSDistributedImpl>(servable); | |||||
| async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get()); | |||||
| service_impl_ = std::move(impl); | |||||
| return Init(); | |||||
| } | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,147 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||||
| #define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||||
| #include <grpcpp/grpcpp.h> | |||||
| #include <grpcpp/health_check_service_interface.h> | |||||
| #include <grpcpp/ext/proto_server_reflection_plugin.h> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/serving_common.h" | |||||
| #include "proto/ms_worker.pb.h" | |||||
| #include "proto/ms_worker.grpc.pb.h" | |||||
| #include "common/grpc_async_server.h" | |||||
| #include "worker/grpc/worker_process.h" | |||||
| #include "worker/grpc/worker_server.h" | |||||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||||
| namespace mindspore { | |||||
| namespace serving { | |||||
| // Service Implement | |||||
| class MS_API MSDistributedWorkerServer : public MSWorkerServer { | |||||
| public: | |||||
| Status StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, const std::string &hostname, | |||||
| int32_t port); | |||||
| }; | |||||
| // Service Implement | |||||
| class WorkerAgentRegisterContext : public WorkerServiceContext { | |||||
| public: | |||||
| WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||||
| grpc::ServerCompletionQueue *cq) | |||||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||||
| state_ = STATE::CREATE; | |||||
| } | |||||
| ~WorkerAgentRegisterContext() = default; | |||||
| static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||||
| grpc::ServerCompletionQueue *cq) { | |||||
| auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq); | |||||
| call->StartEnqueueRequest(); | |||||
| return SUCCESS; | |||||
| } | |||||
| void StartEnqueueRequest() override { | |||||
| state_ = STATE::PROCESS; | |||||
| async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); | |||||
| } | |||||
| void HandleRequest() override { | |||||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||||
| state_ = STATE::FINISH; | |||||
| grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); | |||||
| responder_.Finish(response_, status, this); | |||||
| } | |||||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||||
| private: | |||||
| MSDistributedImpl *service_impl_; | |||||
| proto::MSWorker::AsyncService *async_service_; | |||||
| grpc::ServerCompletionQueue *cq_; | |||||
| grpc::ServerContext ctx_; | |||||
| grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_; | |||||
| proto::PredictRequest request_; | |||||
| proto::PredictReply response_; | |||||
| }; | |||||
| class WorkerAgentExitContext : public WorkerServiceContext { | |||||
| public: | |||||
| WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||||
| grpc::ServerCompletionQueue *cq) | |||||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||||
| state_ = STATE::CREATE; | |||||
| } | |||||
| ~WorkerAgentExitContext() = default; | |||||
| static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||||
| grpc::ServerCompletionQueue *cq) { | |||||
| auto call = new WorkerAgentExitContext(service_impl, async_service, cq); | |||||
| call->StartEnqueueRequest(); | |||||
| return SUCCESS; | |||||
| } | |||||
| void StartEnqueueRequest() override { | |||||
| state_ = STATE::PROCESS; | |||||
| async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); | |||||
| } | |||||
| void HandleRequest() override { | |||||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||||
| state_ = STATE::FINISH; | |||||
| grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); | |||||
| responder_.Finish(response_, status, this); | |||||
| } | |||||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||||
| private: | |||||
| MSDistributedImpl *service_impl_; | |||||
| proto::MSWorker::AsyncService *async_service_; | |||||
| grpc::ServerCompletionQueue *cq_; | |||||
| grpc::ServerContext ctx_; | |||||
| grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_; | |||||
| proto::ExitRequest request_; | |||||
| proto::ExitReply response_; | |||||
| }; | |||||
| class DistributedWorkerGrpcServer : public WorkerGrpcServer { | |||||
| public: | |||||
| DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl) | |||||
| : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {} | |||||
| ~DistributedWorkerGrpcServer() = default; | |||||
| Status EnqueueRequest() { | |||||
| WorkerGrpcServer::EnqueueRequest(); | |||||
| WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||||
| WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||||
| return SUCCESS; | |||||
| } | |||||
| private: | |||||
| MSDistributedImpl *distributed_service_impl_; | |||||
| }; | |||||
| } // namespace serving | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H | |||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip, | |||||
| GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip, | |||||
| uint32_t distributed_worker_port, const std::string &host_ip, | uint32_t distributed_worker_port, const std::string &host_ip, | ||||
| uint32_t host_port) | uint32_t host_port) | ||||
| : distributed_worker_ip_(distributed_worker_ip), | : distributed_worker_ip_(distributed_worker_ip), | ||||
| @@ -35,12 +35,12 @@ GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distri | |||||
| distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); | distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); | ||||
| agent_address_ = host_ip_ + ":" + std::to_string(host_port_); | agent_address_ = host_ip_ + ":" + std::to_string(host_port_); | ||||
| auto channel = GrpcServer::CreateChannel(distributed_worker_address_); | auto channel = GrpcServer::CreateChannel(distributed_worker_address_); | ||||
| stub_ = proto::MSDistributedWorker::NewStub(channel); | |||||
| stub_ = proto::MSWorker::NewStub(channel); | |||||
| } | } | ||||
| GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default; | |||||
| GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; | |||||
| Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) { | |||||
| Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) { | |||||
| const int32_t REGISTER_TIME_OUT = 60; | const int32_t REGISTER_TIME_OUT = 60; | ||||
| const int32_t REGISTER_INTERVAL = 1; | const int32_t REGISTER_INTERVAL = 1; | ||||
| auto loop = REGISTER_TIME_OUT; | auto loop = REGISTER_TIME_OUT; | ||||
| @@ -67,7 +67,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> & | |||||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; | return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; | ||||
| } | } | ||||
| Status GrpcNotfiyDistributeWorker::Unregister() { | |||||
| Status GrpcNotifyDistributeWorker::Unregister() { | |||||
| if (is_stoped_.load()) { | if (is_stoped_.load()) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -20,18 +20,18 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" | #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" | ||||
| #include "proto/ms_master.pb.h" | |||||
| #include "proto/ms_master.grpc.pb.h" | |||||
| #include "proto/ms_distributed.pb.h" | #include "proto/ms_distributed.pb.h" | ||||
| #include "proto/ms_distributed.grpc.pb.h" | #include "proto/ms_distributed.grpc.pb.h" | ||||
| #include "proto/ms_worker.pb.h" | |||||
| #include "proto/ms_worker.grpc.pb.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { | |||||
| class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { | |||||
| public: | public: | ||||
| GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||||
| GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||||
| uint32_t host_port); | uint32_t host_port); | ||||
| ~GrpcNotfiyDistributeWorker() override; | |||||
| ~GrpcNotifyDistributeWorker() override; | |||||
| Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override; | Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override; | ||||
| Status Unregister() override; | Status Unregister() override; | ||||
| @@ -42,7 +42,7 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { | |||||
| uint32_t host_port_; | uint32_t host_port_; | ||||
| std::string agent_address_; | std::string agent_address_; | ||||
| std::string distributed_worker_address_; | std::string distributed_worker_address_; | ||||
| std::unique_ptr<proto::MSDistributedWorker::Stub> stub_; | |||||
| std::unique_ptr<proto::MSWorker::Stub> stub_; | |||||
| std::atomic<bool> is_stoped_{false}; | std::atomic<bool> is_stoped_{false}; | ||||
| }; | }; | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "worker/grpc/worker_process.h" | #include "worker/grpc/worker_process.h" | ||||
| #include "master/dispacther.h" | |||||
| #include "worker/worker.h" | #include "worker/worker.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace serving { | namespace serving { | ||||
| // Service Implement | // Service Implement | ||||
| class MSWorkerImpl final : public proto::MSWorker::Service { | |||||
| class MSWorkerImpl : public proto::MSWorker::Service { | |||||
| public: | public: | ||||
| grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, | ||||
| proto::PredictReply *reply) override; | proto::PredictReply *reply) override; | ||||
| @@ -21,12 +21,20 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| MSWorkerServer::~MSWorkerServer() { Stop(); } | MSWorkerServer::~MSWorkerServer() { Stop(); } | ||||
| MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { | |||||
| Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { | |||||
| if (in_running_) { | |||||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||||
| } | |||||
| service_impl_ = std::make_unique<MSWorkerImpl>(); | service_impl_ = std::make_unique<MSWorkerImpl>(); | ||||
| async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get()); | ||||
| return Init(); | |||||
| } | } | ||||
| MSWorkerServer::MSWorkerServer() = default; | |||||
| Status MSWorkerServer::Init() { | Status MSWorkerServer::Init() { | ||||
| Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); | Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); | ||||
| if (status != SUCCESS) return status; | if (status != SUCCESS) return status; | ||||
| @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { | |||||
| return status; | return status; | ||||
| } | } | ||||
| Status MSWorkerServer::Stop() { | Status MSWorkerServer::Stop() { | ||||
| if (in_running_) { | |||||
| if (in_running_ && async_server_) { | |||||
| async_server_->Stop(); | async_server_->Stop(); | ||||
| grpc_thread_.join(); | |||||
| if (grpc_thread_.joinable()) { | |||||
| grpc_thread_.join(); | |||||
| } | |||||
| } | } | ||||
| async_server_ = nullptr; | |||||
| service_impl_ = nullptr; | |||||
| in_running_ = false; | in_running_ = false; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -27,27 +27,29 @@ | |||||
| #include "proto/ms_worker.grpc.pb.h" | #include "proto/ms_worker.grpc.pb.h" | ||||
| #include "common/grpc_async_server.h" | #include "common/grpc_async_server.h" | ||||
| #include "worker/grpc/worker_process.h" | #include "worker/grpc/worker_process.h" | ||||
| #include "worker/distributed_worker/distributed_servable.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| // Service Implement | // Service Implement | ||||
| class MSWorkerServer { | |||||
| class MS_API MSWorkerServer { | |||||
| public: | public: | ||||
| enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; | ||||
| MSWorkerServer(const std::string &hostname, int32_t port); | |||||
| ~MSWorkerServer(); | |||||
| Status Init(); | |||||
| MSWorkerServer(); | |||||
| virtual ~MSWorkerServer(); | |||||
| Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); | |||||
| Status Stop(); | Status Stop(); | ||||
| Status StartAsyncRpcService(); | |||||
| protected: | |||||
| bool in_running_ = false; | bool in_running_ = false; | ||||
| std::thread grpc_thread_; | std::thread grpc_thread_; | ||||
| std::unique_ptr<MSWorkerImpl> service_impl_; | |||||
| std::unique_ptr<GrpcAsyncServer> async_server_; | |||||
| std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr; | |||||
| std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr; | |||||
| Status StartAsyncRpcService(); | |||||
| Status Init(); | |||||
| }; | }; | ||||
| class WorkerServiceContext { | class WorkerServiceContext { | ||||
| @@ -174,7 +176,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| private: | |||||
| protected: | |||||
| MSWorkerImpl *service_impl_; | MSWorkerImpl *service_impl_; | ||||
| proto::MSWorker::AsyncService svc_; | proto::MSWorker::AsyncService svc_; | ||||
| }; | }; | ||||
| @@ -39,10 +39,8 @@ using WorkCallBack = std::function<void(const Instance &output, const Status &er | |||||
| class WorkExecutor { | class WorkExecutor { | ||||
| public: | public: | ||||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess_task_queue, | |||||
| std::shared_ptr<TaskQueue> py_postprocess_task_queue, | |||||
| std::shared_ptr<TaskQueue> cpp_preprocess_task_queue, | |||||
| std::shared_ptr<TaskQueue> cpp_postprocess_task_queue); | |||||
| WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess, std::shared_ptr<TaskQueue> py_postprocess, | |||||
| std::shared_ptr<TaskQueue> cpp_preprocess, std::shared_ptr<TaskQueue> cpp_postprocess); | |||||
| ~WorkExecutor(); | ~WorkExecutor(); | ||||
| Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable); | Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable); | ||||
| @@ -34,21 +34,11 @@ namespace py = pybind11; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| static std::unique_ptr<MSWorkerServer> grpc_async_worker_server_; | |||||
| Worker &Worker::GetInstance() { | Worker &Worker::GetInstance() { | ||||
| static Worker instance; | static Worker instance; | ||||
| return instance; | return instance; | ||||
| } | } | ||||
| Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { | |||||
| if (grpc_async_worker_server_ != nullptr) { | |||||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; | |||||
| } | |||||
| grpc_async_worker_server_ = std::make_unique<MSWorkerServer>(ip, grpc_port); | |||||
| return grpc_async_worker_server_->Init(); | |||||
| } | |||||
| Status Worker::RegisterWorker() { | Status Worker::RegisterWorker() { | ||||
| std::vector<WorkerSpec> worker_specs; | std::vector<WorkerSpec> worker_specs; | ||||
| for (auto &work : work_list_) { | for (auto &work : work_list_) { | ||||
| @@ -184,6 +174,11 @@ void Worker::Update() { | |||||
| */ | */ | ||||
| } | } | ||||
| Status Worker::AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server) { | |||||
| worker_grpc_server_ = grpc_server; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) { | Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) { | ||||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | ||||
| if (servable_started_) { | if (servable_started_) { | ||||
| @@ -244,7 +239,7 @@ void Worker::StopServable(bool notify_master) { | |||||
| void Worker::Clear() { | void Worker::Clear() { | ||||
| std::unique_lock<std::shared_mutex> lock(worker_shared_lock_); | std::unique_lock<std::shared_mutex> lock(worker_shared_lock_); | ||||
| ServableStorage::Instance().Clear(); | ServableStorage::Instance().Clear(); | ||||
| grpc_async_worker_server_ = nullptr; | |||||
| worker_grpc_server_ = nullptr; | |||||
| if (clear_flag_.test_and_set()) { | if (clear_flag_.test_and_set()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include "worker/version_control/version_controller.h" | #include "worker/version_control/version_controller.h" | ||||
| #include "common/grpc_async_server.h" | #include "common/grpc_async_server.h" | ||||
| #include "worker/sevable_base.h" | #include "worker/sevable_base.h" | ||||
| #include "worker/grpc/worker_server.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace serving { | namespace serving { | ||||
| @@ -74,10 +75,11 @@ class MS_API Worker { | |||||
| const std::vector<InstanceData> &inputs); | const std::vector<InstanceData> &inputs); | ||||
| Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | ||||
| Status AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server); | |||||
| void StopServable(bool notify_master = true); | void StopServable(bool notify_master = true); | ||||
| bool HasCleared(); | bool HasCleared(); | ||||
| Status RegisterWorker(); | Status RegisterWorker(); | ||||
| Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); | |||||
| void Update(); | void Update(); | ||||
| Status StartVersionController(); | Status StartVersionController(); | ||||
| Status AddWorker(const ServableWorkerContext &work); | Status AddWorker(const ServableWorkerContext &work); | ||||
| @@ -101,6 +103,7 @@ class MS_API Worker { | |||||
| std::atomic_bool servable_started_ = false; | std::atomic_bool servable_started_ = false; | ||||
| std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; | std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; | ||||
| std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr; | std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr; | ||||
| std::shared_ptr<MSWorkerServer> worker_grpc_server_ = nullptr; | |||||
| std::shared_mutex worker_shared_lock_; | std::shared_mutex worker_shared_lock_; | ||||
| @@ -45,8 +45,3 @@ message AgentExitRequest { | |||||
| message AgentExitReply { | message AgentExitReply { | ||||
| ErrorMsg error_msg = 1; | ErrorMsg error_msg = 1; | ||||
| } | } | ||||
| service MSDistributedWorker { | |||||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||||
| } | |||||
| @@ -20,8 +20,13 @@ syntax = "proto3"; | |||||
| package mindspore.serving.proto; | package mindspore.serving.proto; | ||||
| import "mindspore_serving/proto/ms_service.proto"; | import "mindspore_serving/proto/ms_service.proto"; | ||||
| import "mindspore_serving/proto/ms_master.proto"; | import "mindspore_serving/proto/ms_master.proto"; | ||||
| import "mindspore_serving/proto/ms_distributed.proto"; | |||||
| service MSWorker { | service MSWorker { | ||||
| // for master | |||||
| rpc Predict(PredictRequest) returns (PredictReply) {} | rpc Predict(PredictRequest) returns (PredictReply) {} | ||||
| rpc Exit(ExitRequest) returns (ExitReply) {} | rpc Exit(ExitRequest) returns (ExitReply) {} | ||||
| // for worker agent | |||||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||||
| } | } | ||||