Browse Source

Serving, commbile distributed worker and local worker in grpc server process

tags/v1.2.0
xuyongfei 5 years ago
parent
commit
620bc494b4
18 changed files with 278 additions and 57 deletions
  1. +13
    -0
      mindspore_serving/ccsrc/common/exit_handle.cc
  2. +2
    -0
      mindspore_serving/ccsrc/common/exit_handle.h
  3. +19
    -8
      mindspore_serving/ccsrc/python/worker/worker_py.cc
  4. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc
  5. +2
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h
  6. +38
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc
  7. +147
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h
  8. +5
    -5
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
  9. +6
    -6
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
  10. +0
    -1
      mindspore_serving/ccsrc/worker/grpc/worker_process.cc
  11. +1
    -1
      mindspore_serving/ccsrc/worker/grpc/worker_process.h
  12. +15
    -3
      mindspore_serving/ccsrc/worker/grpc/worker_server.cc
  13. +12
    -10
      mindspore_serving/ccsrc/worker/grpc/worker_server.h
  14. +2
    -4
      mindspore_serving/ccsrc/worker/work_executor.h
  15. +6
    -11
      mindspore_serving/ccsrc/worker/worker.cc
  16. +4
    -1
      mindspore_serving/ccsrc/worker/worker.h
  17. +0
    -5
      mindspore_serving/proto/ms_distributed.proto
  18. +5
    -0
      mindspore_serving/proto/ms_worker.proto

+ 13
- 0
mindspore_serving/ccsrc/common/exit_handle.cc View File

@@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() {
exit_future.wait();
}

// waiting ctrl+c or stop message to exit,
// if no server is running or server has exited, there is no need to wait
void ExitSignalHandle::AgentWait() {
if (!is_running_) {
MSI_LOG_INFO << "Exit Handle has not started or has exited";
return;
}
auto exit_future = agent_exit_requested_.get_future();
exit_future.wait();
}

void ExitSignalHandle::Start() {
if (is_running_) {
return;
@@ -62,6 +73,7 @@ void ExitSignalHandle::Start() {
is_running_ = true;
master_exit_requested_ = std::promise<void>();
worker_exit_requested_ = std::promise<void>();
agent_exit_requested_ = std::promise<void>();
has_exited_.clear();
InitSignalHandle();
}
@@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() {
if (!has_exited_.test_and_set()) {
master_exit_requested_.set_value();
worker_exit_requested_.set_value();
agent_exit_requested_.set_value();
is_running_ = false;
}
}


+ 2
- 0
mindspore_serving/ccsrc/common/exit_handle.h View File

@@ -32,6 +32,7 @@ class MS_API ExitSignalHandle {
void InitSignalHandle();
void MasterWait();
void WorkerWait();
void AgentWait();
void Start();
void Stop();
bool HasStopped();
@@ -39,6 +40,7 @@ class MS_API ExitSignalHandle {
private:
std::promise<void> master_exit_requested_;
std::promise<void> worker_exit_requested_;
std::promise<void> agent_exit_requested_;
std::atomic_flag has_exited_ = true;
std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT;
std::atomic_bool is_running_ = false;


+ 19
- 8
mindspore_serving/ccsrc/python/worker/worker_py.cc View File

@@ -23,13 +23,15 @@
#include "worker/notfiy_master/local_notify.h"
#include "worker/local_servable/local_sevable.h"
#include "worker/distributed_worker/distributed_servable.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/grpc/distributed_server.h"

namespace mindspore::serving {

void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number,
const std::string &master_ip, uint32_t master_port, const std::string &host_ip,
uint32_t host_port) {
auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, host_ip, host_port);
const std::string &master_ip, uint32_t master_port, const std::string &worker_ip,
uint32_t worker_port) {
auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port);
auto servable = std::make_shared<LocalModelServable>();
auto status = servable->StartServable(model_directory, model_name, version_number);
if (status != SUCCESS) {
@@ -39,10 +41,14 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
status = Worker::GetInstance().StartGrpcServer(host_ip, host_port);
// start grpc server
auto grpc_sever = std::make_shared<MSWorkerServer>();
status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

status = Worker::GetInstance().StartVersionController();
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
@@ -72,12 +78,15 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c
const std::string &worker_ip, uint32_t worker_port,
const std::string &master_ip, uint32_t master_port) {
Status status;
status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port);
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>();
status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port);
auto servable = std::make_shared<DistributedServable>();
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
@@ -96,13 +105,15 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port) {
Status status;
status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port);
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>();
status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

auto notify_master = std::make_shared<LocalNotifyMaster>();
auto servable = std::make_shared<DistributedServable>();
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();


mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc → mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include "worker/distributed_worker/grpc/distributed_process.h"
#include "common/proto_tensor.h"

namespace mindspore {

mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h → mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h View File

@@ -27,12 +27,13 @@
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "worker/distributed_worker/distributed_servable.h"
#include "worker/grpc/worker_process.h"
namespace mindspore {
namespace serving {
// Service Implement
class MSDistributedImpl final : public proto::MSDistributedWorker::Service {
class MSDistributedImpl final : public MSWorkerImpl {
public:
explicit MSDistributedImpl(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
~MSDistributedImpl() = default;

+ 38
- 0
mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "worker/distributed_worker/grpc/distributed_server.h"
#include <string>
#include <memory>
#include <utility>
#include "common/grpc_server.h"
namespace mindspore {
namespace serving {
Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable,
const std::string &hostname, int32_t port) {
if (in_running_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
}
auto impl = std::make_unique<MSDistributedImpl>(servable);
async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get());
service_impl_ = std::move(impl);
return Init();
}
} // namespace serving
} // namespace mindspore

+ 147
- 0
mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h View File

@@ -0,0 +1,147 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <memory>
#include <string>
#include "common/serving_common.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/grpc/distributed_process.h"
namespace mindspore {
namespace serving {
// Service Implement
class MS_API MSDistributedWorkerServer : public MSWorkerServer {
public:
Status StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, const std::string &hostname,
int32_t port);
};
// Service Implement
class WorkerAgentRegisterContext : public WorkerServiceContext {
public:
WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
~WorkerAgentRegisterContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSDistributedImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
proto::PredictRequest request_;
proto::PredictReply response_;
};
class WorkerAgentExitContext : public WorkerServiceContext {
public:
WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
~WorkerAgentExitContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentExitContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSDistributedImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
proto::ExitRequest request_;
proto::ExitReply response_;
};
class DistributedWorkerGrpcServer : public WorkerGrpcServer {
public:
DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl)
: WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {}
~DistributedWorkerGrpcServer() = default;
Status EnqueueRequest() {
WorkerGrpcServer::EnqueueRequest();
WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
return SUCCESS;
}
private:
MSDistributedImpl *distributed_service_impl_;
};
} // namespace serving
} // namespace mindspore
#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H

+ 5
- 5
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc View File

@@ -25,7 +25,7 @@
namespace mindspore {
namespace serving {

GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip,
GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip,
uint32_t distributed_worker_port, const std::string &host_ip,
uint32_t host_port)
: distributed_worker_ip_(distributed_worker_ip),
@@ -35,12 +35,12 @@ GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distri
distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port);
agent_address_ = host_ip_ + ":" + std::to_string(host_port_);
auto channel = GrpcServer::CreateChannel(distributed_worker_address_);
stub_ = proto::MSDistributedWorker::NewStub(channel);
stub_ = proto::MSWorker::NewStub(channel);
}

GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default;
GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default;

Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) {
Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &worker_specs) {
const int32_t REGISTER_TIME_OUT = 60;
const int32_t REGISTER_INTERVAL = 1;
auto loop = REGISTER_TIME_OUT;
@@ -67,7 +67,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut";
}

Status GrpcNotfiyDistributeWorker::Unregister() {
Status GrpcNotifyDistributeWorker::Unregister() {
if (is_stoped_.load()) {
return SUCCESS;
}


+ 6
- 6
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h View File

@@ -20,18 +20,18 @@
#include <string>
#include <memory>
#include "worker/distributed_worker/notify_distributed/base_notify_worker.h"
#include "proto/ms_master.pb.h"
#include "proto/ms_master.grpc.pb.h"
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "proto/ms_worker.pb.h"
#include "proto/ms_worker.grpc.pb.h"
namespace mindspore {
namespace serving {

class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker {
class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker {
public:
GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip,
GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip,
uint32_t host_port);
~GrpcNotfiyDistributeWorker() override;
~GrpcNotifyDistributeWorker() override;
Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override;
Status Unregister() override;

@@ -42,7 +42,7 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker {
uint32_t host_port_;
std::string agent_address_;
std::string distributed_worker_address_;
std::unique_ptr<proto::MSDistributedWorker::Stub> stub_;
std::unique_ptr<proto::MSWorker::Stub> stub_;
std::atomic<bool> is_stoped_{false};
};



+ 0
- 1
mindspore_serving/ccsrc/worker/grpc/worker_process.cc View File

@@ -15,7 +15,6 @@
*/
#include "worker/grpc/worker_process.h"
#include "master/dispacther.h"
#include "worker/worker.h"
namespace mindspore {


+ 1
- 1
mindspore_serving/ccsrc/worker/grpc/worker_process.h View File

@@ -28,7 +28,7 @@ namespace mindspore {
namespace serving {
// Service Implement
class MSWorkerImpl final : public proto::MSWorker::Service {
class MSWorkerImpl : public proto::MSWorker::Service {
public:
grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request,
proto::PredictReply *reply) override;


+ 15
- 3
mindspore_serving/ccsrc/worker/grpc/worker_server.cc View File

@@ -21,12 +21,20 @@
namespace mindspore {
namespace serving {
MSWorkerServer::~MSWorkerServer() { Stop(); }
MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) {
Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
if (in_running_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
}
service_impl_ = std::make_unique<MSWorkerImpl>();
async_server_ = std::make_unique<WorkerGrpcServer>(hostname, port, service_impl_.get());
return Init();
}
MSWorkerServer::MSWorkerServer() = default;
Status MSWorkerServer::Init() {
Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize);
if (status != SUCCESS) return status;
@@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() {
return status;
}
Status MSWorkerServer::Stop() {
if (in_running_) {
if (in_running_ && async_server_) {
async_server_->Stop();
grpc_thread_.join();
if (grpc_thread_.joinable()) {
grpc_thread_.join();
}
}
async_server_ = nullptr;
service_impl_ = nullptr;
in_running_ = false;
return SUCCESS;
}


+ 12
- 10
mindspore_serving/ccsrc/worker/grpc/worker_server.h View File

@@ -27,27 +27,29 @@
#include "proto/ms_worker.grpc.pb.h"
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/distributed_worker/distributed_servable.h"
namespace mindspore {
namespace serving {
// Service Implement
class MSWorkerServer {
class MS_API MSWorkerServer {
public:
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
MSWorkerServer(const std::string &hostname, int32_t port);
~MSWorkerServer();
Status Init();
MSWorkerServer();
virtual ~MSWorkerServer();
Status StartWorkerGrpcServer(const std::string &hostname, int32_t port);
Status Stop();
Status StartAsyncRpcService();
protected:
bool in_running_ = false;
std::thread grpc_thread_;
std::unique_ptr<MSWorkerImpl> service_impl_;
std::unique_ptr<GrpcAsyncServer> async_server_;
std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr;
std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr;
Status StartAsyncRpcService();
Status Init();
};
class WorkerServiceContext {
@@ -174,7 +176,7 @@ class WorkerGrpcServer : public GrpcAsyncServer {
return SUCCESS;
}
private:
protected:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService svc_;
};


+ 2
- 4
mindspore_serving/ccsrc/worker/work_executor.h View File

@@ -39,10 +39,8 @@ using WorkCallBack = std::function<void(const Instance &output, const Status &er

class WorkExecutor {
public:
WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess_task_queue,
std::shared_ptr<TaskQueue> py_postprocess_task_queue,
std::shared_ptr<TaskQueue> cpp_preprocess_task_queue,
std::shared_ptr<TaskQueue> cpp_postprocess_task_queue);
WorkExecutor(std::shared_ptr<TaskQueue> py_preprocess, std::shared_ptr<TaskQueue> py_postprocess,
std::shared_ptr<TaskQueue> cpp_preprocess, std::shared_ptr<TaskQueue> cpp_postprocess);
~WorkExecutor();

Status Init(const ServableSignature &servable_declare, const std::shared_ptr<ServableBase> &servable);


+ 6
- 11
mindspore_serving/ccsrc/worker/worker.cc View File

@@ -34,21 +34,11 @@ namespace py = pybind11;
namespace mindspore {
namespace serving {

static std::unique_ptr<MSWorkerServer> grpc_async_worker_server_;

Worker &Worker::GetInstance() {
static Worker instance;
return instance;
}

Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) {
if (grpc_async_worker_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running";
}
grpc_async_worker_server_ = std::make_unique<MSWorkerServer>(ip, grpc_port);
return grpc_async_worker_server_->Init();
}

Status Worker::RegisterWorker() {
std::vector<WorkerSpec> worker_specs;
for (auto &work : work_list_) {
@@ -184,6 +174,11 @@ void Worker::Update() {
*/
}

Status Worker::AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server) {
worker_grpc_server_ = grpc_server;
return SUCCESS;
}

Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
if (servable_started_) {
@@ -244,7 +239,7 @@ void Worker::StopServable(bool notify_master) {
void Worker::Clear() {
std::unique_lock<std::shared_mutex> lock(worker_shared_lock_);
ServableStorage::Instance().Clear();
grpc_async_worker_server_ = nullptr;
worker_grpc_server_ = nullptr;
if (clear_flag_.test_and_set()) {
return;
}


+ 4
- 1
mindspore_serving/ccsrc/worker/worker.h View File

@@ -33,6 +33,7 @@
#include "worker/version_control/version_controller.h"
#include "common/grpc_async_server.h"
#include "worker/sevable_base.h"
#include "worker/grpc/worker_server.h"

namespace mindspore {
namespace serving {
@@ -74,10 +75,11 @@ class MS_API Worker {
const std::vector<InstanceData> &inputs);
Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master);

Status AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server);

void StopServable(bool notify_master = true);
bool HasCleared();
Status RegisterWorker();
Status StartGrpcServer(const std::string &ip, uint32_t grpc_port);
void Update();
Status StartVersionController();
Status AddWorker(const ServableWorkerContext &work);
@@ -101,6 +103,7 @@ class MS_API Worker {
std::atomic_bool servable_started_ = false;
std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT;
std::shared_ptr<BaseNotifyMaster> notify_master_ = nullptr;
std::shared_ptr<MSWorkerServer> worker_grpc_server_ = nullptr;

std::shared_mutex worker_shared_lock_;



+ 0
- 5
mindspore_serving/proto/ms_distributed.proto View File

@@ -45,8 +45,3 @@ message AgentExitRequest {
message AgentExitReply {
ErrorMsg error_msg = 1;
}

service MSDistributedWorker {
rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {}
rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {}
}

+ 5
- 0
mindspore_serving/proto/ms_worker.proto View File

@@ -20,8 +20,13 @@ syntax = "proto3";
package mindspore.serving.proto;
import "mindspore_serving/proto/ms_service.proto";
import "mindspore_serving/proto/ms_master.proto";
import "mindspore_serving/proto/ms_distributed.proto";

service MSWorker {
// for master
rpc Predict(PredictRequest) returns (PredictReply) {}
rpc Exit(ExitRequest) returns (ExitReply) {}
// for worker agent
rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {}
rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {}
}

Loading…
Cancel
Save