Browse Source

Serving, python agent

tags/v1.2.0
xuyongfei 5 years ago
parent
commit
90db9fba60
32 changed files with 797 additions and 218 deletions
  1. +0
    -3
      mindspore_serving/ccsrc/master/server.cc
  2. +63
    -0
      mindspore_serving/ccsrc/python/agent/agent_py.cc
  3. +20
    -11
      mindspore_serving/ccsrc/python/agent/agent_py.h
  4. +51
    -1
      mindspore_serving/ccsrc/python/serving_py.cc
  5. +14
    -13
      mindspore_serving/ccsrc/python/worker/worker_py.cc
  6. +3
    -2
      mindspore_serving/ccsrc/python/worker/worker_py.h
  7. +1
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc
  8. +21
    -6
      mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc
  9. +5
    -6
      mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h
  10. +16
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc
  11. +2
    -0
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h
  12. +3
    -4
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc
  13. +65
    -34
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h
  14. +71
    -38
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc
  15. +11
    -4
      mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h
  16. +17
    -1
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc
  17. +10
    -7
      mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h
  18. +68
    -9
      mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc
  19. +14
    -2
      mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h
  20. +17
    -22
      mindspore_serving/ccsrc/worker/grpc/worker_server.h
  21. +8
    -1
      mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc
  22. +1
    -0
      mindspore_serving/ccsrc/worker/local_servable/local_sevable.h
  23. +1
    -0
      mindspore_serving/ccsrc/worker/sevable_base.h
  24. +11
    -4
      mindspore_serving/ccsrc/worker/worker.cc
  25. +3
    -2
      mindspore_serving/ccsrc/worker/worker.h
  26. +2
    -0
      mindspore_serving/master/_master.py
  27. +7
    -0
      mindspore_serving/proto/ms_distributed.proto
  28. +1
    -0
      mindspore_serving/proto/ms_worker.proto
  29. +3
    -1
      mindspore_serving/worker/_worker.py
  30. +228
    -21
      mindspore_serving/worker/distributed/agent_startup.py
  31. +13
    -9
      mindspore_serving/worker/distributed/distributed_worker.py
  32. +47
    -15
      mindspore_serving/worker/distributed/worker_agent.py

+ 0
- 3
mindspore_serving/ccsrc/master/server.cc View File

@@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma
if (grpc_async_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running";
}
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
if (max_msg_mb_size > gRpcMaxMBMsgSize) {
MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size
<< "MB to 512MB";
@@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma
}

Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize,
"Master");
}

Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size,
int time_out_second) {
ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit
return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second);
}



+ 63
- 0
mindspore_serving/ccsrc/python/agent/agent_py.cc View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "python/agent/agent_py.h"
#include "common/exit_handle.h"
#include "worker/distributed_worker/agent_startup.h"
#include "worker/distributed_worker/worker_agent.h"

namespace mindspore::serving {

// Fetch the distributed-servable configuration from the worker at
// (worker_ip, worker_port). Raises a Python exception (via MSI_LOG_EXCEPTION)
// if either the remote fetch or the subsequent local retrieval fails.
DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) {
  auto &startup = WorkerAgentStartUp::Instance();
  Status status = startup.GetAgentsConfigsFromWorker(worker_ip, worker_port);
  if (status != SUCCESS) {
    MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  }
  DistributedServableConfig config;
  status = startup.GetDistributedServableConfig(&config);
  if (status != SUCCESS) {
    MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
  }
  return config;
}

// Best-effort notification to the distributed worker at
// (worker_ip, worker_port) that agent startup failed on this machine.
// The returned status is intentionally discarded.
void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
  (void)WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port);
}

void PyAgent::StartAgent(const AgentStartUpConfig &start_config) {
auto status = WorkerAgent::Instance().StartAgent(start_config);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
}

// Block until the agent is told to exit, then release agent resources.
// The GIL is dropped only for the duration of the blocking wait so other
// Python threads keep running; it is re-acquired before Clear().
void PyAgent::WaitAndClear() {
  {
    // Scope the GIL release to the blocking AgentWait() call.
    py::gil_scoped_release release_gil;
    ExitSignalHandle::Instance().AgentWait();
  }
  WorkerAgent::Instance().Clear();
  MSI_LOG_INFO << "Python agent end wait and clear";
}

// Signal the exit handler to stop, then immediately release agent resources.
void PyAgent::StopAndClear() {
  auto &exit_handle = ExitSignalHandle::Instance();
  exit_handle.Stop();
  WorkerAgent::Instance().Clear();
}

} // namespace mindspore::serving

mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h → mindspore_serving/ccsrc/python/agent/agent_py.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,25 +14,34 @@
* limitations under the License.
*/

#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H
#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H
#include <vector>
#ifndef MINDSPORE_SERVER_AGENT_PY_H
#define MINDSPORE_SERVER_AGENT_PY_H

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <string>
#include <memory>
#include "common/serving_common.h"
#include "common/servable.h"
#include "worker/distributed_worker/common.h"

namespace py = pybind11;

namespace mindspore {
namespace serving {

class MS_API BaseNotifyDistributeWorker {
class MS_API PyAgent {
public:
BaseNotifyDistributeWorker() = default;
virtual ~BaseNotifyDistributeWorker() = default;
virtual Status Register(const std::vector<WorkerAgentSpec> &worker_specs) = 0;
virtual Status Unregister() = 0;
static void StartAgent(const AgentStartUpConfig &start_config);

static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
static void WaitAndClear();
static void StopAndClear();
// from start up, not agent
static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port);
};

} // namespace serving
} // namespace mindspore

#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H
#endif // MINDSPORE_SERVER_AGENT_PY_H

+ 51
- 1
mindspore_serving/ccsrc/python/serving_py.cc View File

@@ -23,6 +23,9 @@
#include "common/servable.h"
#include "worker/context.h"
#include "python/master/master_py.h"
#include "python/agent/agent_py.h"
#include "common/exit_handle.h"
#include "worker/distributed_worker/worker_agent.h"

namespace mindspore::serving {

@@ -104,11 +107,23 @@ void PyRegServable(pybind11::module *m_ptr) {
.def_static("register_method", &PyServableStorage::RegisterMethod)
.def_static("declare_servable", &PyServableStorage::DeclareServable)
.def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable);

py::class_<OneRankConfig>(m, "OneRankConfig_")
.def(py::init<>())
.def_readwrite("device_id", &OneRankConfig::device_id)
.def_readwrite("ip", &OneRankConfig::ip);

py::class_<DistributedServableConfig>(m, "DistributedServableConfig_")
.def(py::init<>())
.def_readwrite("common_meta", &DistributedServableConfig::common_meta)
.def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta)
.def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content)
.def_readwrite("rank_list", &DistributedServableConfig::rank_list);
}

void PyRegMaster(pybind11::module *m_ptr) {
auto &m = *m_ptr;
py::class_<PyMaster, std::shared_ptr<PyMaster>>(m, "Master_")
py::class_<PyMaster>(m, "Master_")
.def_static("start_grpc_server", &PyMaster::StartGrpcServer)
.def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer)
.def_static("start_restful_server", &PyMaster::StartRestfulServer)
@@ -163,15 +178,50 @@ void PyRegWorker(pybind11::module *m_ptr) {
.def("set_device_id", &ServableContext::SetDeviceId);
}

// Register the worker-agent Python bindings on module *m_ptr:
// the static-method facade `WorkerAgent_` and the plain data class
// `AgentStartUpConfig_` (all fields exposed read/write for configuration
// from Python, e.g. agent_startup.py / worker_agent.py).
void PyRegWorkerAgent(pybind11::module *m_ptr) {
auto &m = *m_ptr;
// All PyAgent entry points are static; no instance is ever constructed
// from Python, so no py::init is exposed here.
py::class_<PyAgent>(m, "WorkerAgent_")
.def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker)
.def_static("wait_and_clear", &PyAgent::WaitAndClear)
.def_static("stop_and_clear", &PyAgent::StopAndClear)
.def_static("notify_failed", &PyAgent::NotifyFailed)
.def_static("start_agent", &PyAgent::StartAgent);

// Startup configuration for one agent process; default-constructed in
// Python and filled field by field before being passed to start_agent.
py::class_<AgentStartUpConfig>(m, "AgentStartUpConfig_")
.def(py::init<>())
.def_readwrite("rank_id", &AgentStartUpConfig::rank_id)
.def_readwrite("device_id", &AgentStartUpConfig::device_id)
.def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name)
.def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name)
.def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name)
.def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip)
.def_readwrite("agent_port", &AgentStartUpConfig::agent_port)
.def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip)
.def_readwrite("worker_port", &AgentStartUpConfig::worker_port)
.def_readwrite("common_meta", &AgentStartUpConfig::common_meta);
}

// Thin static facade over the ExitSignalHandle singleton, exposed to Python
// as `ExitSignalHandle_` (see PYBIND11_MODULE below) so Python code can
// start signal handling and poll for a requested shutdown.
class PyExitSignalHandle {
public:
static void Start() { ExitSignalHandle::Instance().Start(); }
static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); }
};

// cppcheck-suppress syntaxError
// Entry point of the _mindspore_serving extension module: registers the
// servable, master, worker and worker-agent bindings, the exit-signal
// facade, and an atexit hook that tears down all singletons so the process
// shuts down cleanly when the Python interpreter exits.
PYBIND11_MODULE(_mindspore_serving, m) {
  PyRegServable(&m);
  PyRegMaster(&m);
  PyRegWorker(&m);
  PyRegWorkerAgent(&m);

  py::class_<PyExitSignalHandle>(m, "ExitSignalHandle_")
      .def_static("start", &PyExitSignalHandle::Start)
      .def_static("has_stopped", &PyExitSignalHandle::HasStopped);

  // Capture-less lambda: it outlives this scope (runs at interpreter
  // shutdown via atexit), so a default by-reference capture [&] would be a
  // dangling-capture hazard; it references only global singletons anyway.
  (void)py::module::import("atexit").attr("register")(py::cpp_function{[]() -> void {
    Server::Instance().Clear();
    Worker::GetInstance().Clear();
    WorkerAgent::Instance().Clear();
  }});
}



+ 14
- 13
mindspore_serving/ccsrc/python/worker/worker_py.cc View File

@@ -24,7 +24,7 @@
#include "worker/local_servable/local_sevable.h"
#include "worker/distributed_worker/distributed_servable.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/grpc/distributed_server.h"
#include "worker/distributed_worker/distributed_process/distributed_server.h"

namespace mindspore::serving {

@@ -43,11 +43,10 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri
}
// start grpc server
auto grpc_sever = std::make_shared<MSWorkerServer>();
status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port);
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

status = Worker::GetInstance().StartVersionController();
if (status != SUCCESS) {
@@ -76,18 +75,19 @@ void PyWorker::StartServableInMaster(const std::string &model_directory, const s
void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port,
const std::string &master_ip, uint32_t master_port) {
const std::string &master_ip, uint32_t master_port,
uint32_t wait_agents_time_in_seconds) {
Status status;
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>();
status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port);
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable);
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port);
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number);
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number,
wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
@@ -103,18 +103,19 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c

void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port) {
const std::string &worker_ip, uint32_t worker_port,
uint32_t wait_agents_time_in_seconds) {
Status status;
auto servable = std::make_shared<DistributedServable>();
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>();
status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port);
auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable);
status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}
Worker::GetInstance().AfterStartGrpcServer(grpc_sever);

auto notify_master = std::make_shared<LocalNotifyMaster>();
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number);
status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number,
wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage();
}


+ 3
- 2
mindspore_serving/ccsrc/python/worker/worker_py.h View File

@@ -37,11 +37,12 @@ class MS_API PyWorker {
static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip,
uint32_t master_port);
uint32_t master_port, uint32_t wait_agents_time_in_seconds);

static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint32_t version_number,
const std::string &worker_ip, uint32_t worker_port);
const std::string &worker_ip, uint32_t worker_port,
uint32_t wait_agents_time_in_seconds);

static int GetBatchSize();
static void WaitAndClear();


+ 1
- 1
mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc View File

@@ -22,7 +22,7 @@ namespace serving {
grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request,
proto::DistributedExitReply *reply) {
MSI_LOG(INFO) << "Distributed Worker Exit";
WorkerAgent::Instance().Clear();
WorkerAgent::Instance().StopAgent(false);
return grpc::Status::OK;
}


+ 21
- 6
mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc View File

@@ -14,17 +14,32 @@
* limitations under the License.
*/
#include "worker/distributed_worker/agent_startup.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"

namespace mindspore {
namespace serving {

Status WorkerAgentStartUp::InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix,
const std::string &group_file_dir, const std::string &group_file_prefix) {
return Status();
WorkerAgentStartUp &WorkerAgentStartUp::Instance() {
static WorkerAgentStartUp instance;
return instance;
}
Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port,
const std::string &worker_ip, uint32_t worker_port) {
Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) {
return Status();
}
Status WorkerAgentStartUp::GetCurrentMachineConfigs(std::vector<AgentStartUpConfig> *configs) { return Status(); }

Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) {
MSI_EXCEPTION_IF_NULL(config);
if (config_.rank_list.empty()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready";
}
*config = config_;
return SUCCESS;
}

Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port);
}

} // namespace serving
} // namespace mindspore

+ 5
- 6
mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h View File

@@ -27,16 +27,15 @@ namespace serving {

class MS_API WorkerAgentStartUp {
public:
static WorkerAgentStartUp &Instance();
// from python, worker_agent.py
// start_worker_agent
// step1, get agents config from worker
Status InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix,
const std::string &group_file_dir, const std::string &group_file_prefix);
Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port);
// step2, invoke from python
Status GetDistributedServableConfig(DistributedServableConfig *config);

Status GetAgentsConfigsFromWorker(const std::string &rank_start, uint32_t agent_start_port,
const std::string &worker_ip, uint32_t worker_port);
// step2, invoke from python, get current machine agents config
Status GetCurrentMachineConfigs(std::vector<AgentStartUpConfig> *configs);
Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);

private:
DistributedServableConfig config_;


mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc → mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc View File

@@ -14,7 +14,8 @@
* limitations under the License.
*/

#include "worker/distributed_worker/grpc/distributed_process.h"
#include "worker/distributed_worker/distributed_process/distributed_process.h"
#include "worker/worker.h"
#include "common/proto_tensor.h"

namespace mindspore {
@@ -51,6 +52,20 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr
MSI_LOG(ERROR) << "Agent Exit FAILED";
}
}
if (Worker::GetInstance().IsRunning()) {
Worker::GetInstance().StopServable();
}
return grpc::Status::OK;
}

grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
proto::AgentFailedReply *reply) {
if (Worker::GetInstance().IsRunning()) {
MSI_LOG_ERROR << "Expect worker should not be running";
Worker::GetInstance().StopServable();
} else {
servable_->OnAgentFailed();
}
return grpc::Status::OK;
}
} // namespace serving

mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h → mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h View File

@@ -41,6 +41,8 @@ class MSDistributedImpl final : public MSWorkerImpl {
proto::AgentRegisterReply *reply) override;
grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request,
proto::AgentExitReply *reply) override;
grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request,
proto::AgentFailedReply *reply) override;
private:
std::shared_ptr<DistributedServable> servable_;

mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc → mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "worker/distributed_worker/grpc/distributed_server.h"
#include "worker/distributed_worker/distributed_process/distributed_server.h"
#include <string>
#include <memory>
#include <utility>
@@ -23,12 +23,11 @@
namespace mindspore {
namespace serving {
Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable,
const std::string &hostname, int32_t port) {
Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) {
if (in_running_) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running";
}
auto impl = std::make_unique<MSDistributedImpl>(servable);
auto impl = std::make_unique<MSDistributedImpl>(servable_);
async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get());
service_impl_ = std::move(impl);
return Init();

mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h → mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h View File

@@ -28,7 +28,7 @@
#include "common/grpc_async_server.h"
#include "worker/grpc/worker_process.h"
#include "worker/grpc/worker_server.h"
#include "worker/distributed_worker/grpc/distributed_process.h"
#include "worker/distributed_worker/distributed_process/distributed_process.h"
namespace mindspore {
namespace serving {
@@ -36,18 +36,30 @@ namespace serving {
// Service Implement
class MS_API MSDistributedWorkerServer : public MSWorkerServer {
public:
Status StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, const std::string &hostname,
int32_t port);
explicit MSDistributedWorkerServer(std::shared_ptr<DistributedServable> servable) : servable_(servable) {}
~MSDistributedWorkerServer() = default;
Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override;
private:
std::shared_ptr<DistributedServable> servable_;
};
class DistributedServiceContext : public WorkerServiceContext {
public:
DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {}
protected:
MSDistributedImpl *dist_service_impl_ = nullptr;
};
// Service Implement
class WorkerAgentRegisterContext : public WorkerServiceContext {
class WorkerAgentRegisterContext : public DistributedServiceContext {
public:
WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentRegisterContext() = default;
@@ -60,35 +72,27 @@ class WorkerAgentRegisterContext : public WorkerServiceContext {
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this);
async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_);
grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSDistributedImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
proto::PredictRequest request_;
proto::PredictReply response_;
grpc::ServerAsyncResponseWriter<proto::AgentRegisterReply> responder_;
proto::AgentRegisterRequest request_;
proto::AgentRegisterReply response_;
};
class WorkerAgentExitContext : public WorkerServiceContext {
class WorkerAgentExitContext : public DistributedServiceContext {
public:
WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentExitContext() = default;
@@ -101,26 +105,52 @@ class WorkerAgentExitContext : public WorkerServiceContext {
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this);
async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(service_impl_, async_service_, cq_);
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_);
grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
grpc::ServerAsyncResponseWriter<proto::AgentExitReply> responder_;
proto::AgentExitRequest request_;
proto::AgentExitReply response_;
};
class WorkerAgentFailedContext : public DistributedServiceContext {
public:
WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerAgentFailedContext() = default;
static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq) {
auto call = new WorkerAgentFailedContext(service_impl, async_service, cq);
call->StartEnqueueRequest();
return SUCCESS;
}
void StartEnqueueRequest() override {
state_ = STATE::PROCESS;
async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this);
}
void HandleRequest() override {
EnqueueRequest(dist_service_impl_, async_service_, cq_);
state_ = STATE::FINISH;
grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_);
responder_.Finish(response_, status, this);
}
private:
MSDistributedImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
proto::ExitRequest request_;
proto::ExitReply response_;
grpc::ServerAsyncResponseWriter<proto::AgentFailedReply> responder_;
proto::AgentFailedRequest request_;
proto::AgentFailedReply response_;
};
class DistributedWorkerGrpcServer : public WorkerGrpcServer {
@@ -134,6 +164,7 @@ class DistributedWorkerGrpcServer : public WorkerGrpcServer {
WorkerGrpcServer::EnqueueRequest();
WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get());
return SUCCESS;
}

+ 71
- 38
mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc View File

@@ -17,13 +17,15 @@
#include "worker/distributed_worker/distributed_servable.h"
#include <vector>
#include <string>
#include "worker/worker.h"
#include <set>
#include "worker/distributed_worker/notify_agent/notify_agent.h"
#include "common/exit_handle.h"

namespace mindspore {
namespace serving {

DistributedServable::~DistributedServable() { Clear(); }

std::string DistributedServable::GetServableName() const { return servable_name_; }

uint64_t DistributedServable::GetServableVersion() const { return version_number_; }
@@ -60,7 +62,15 @@ Status DistributedServable::GetDistributedServableConfig(DistributedServableConf
return SUCCESS;
}

void DistributedServable::SetWaitAgentsPromise(bool flag) {
if (!promise_set_flag_.test_and_set()) {
agents_promise_.set_value(flag);
}
}

Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) {
std::unique_lock<std::mutex> lock{mutex_};

if (agent_spec.rank_id < config_.distributed_meta.rank_size) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size;
@@ -75,27 +85,24 @@ Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) {
std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotfiyAgent>(agent_spec.agent_address);
context.notify_agent_ = notify_agent;
agent_spec_map_[agent_spec.rank_id] = context;
if (config_.distributed_meta.rank_size == agent_spec_map_.size()) {
Status status = Worker::GetInstance().RegisterWorker();
if (status != SUCCESS) {
Clear();
return FAILED;
}
}

if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) {
agents_promise_.set_value();
SetWaitAgentsPromise(true);
}
return SUCCESS;
}

void DistributedServable::Clear() {
for (auto agent : agent_spec_map_) {
std::unique_lock<std::mutex> lock{mutex_};
for (auto &agent : agent_spec_map_) {
agent.second.notify_agent_->Exit();
}
Worker::GetInstance().StopServable(false);
agent_spec_map_.clear();
MSI_LOG_INFO << "End Clear servable";
}

Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) {
std::unique_lock<std::mutex> lock{mutex_};
for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) {
if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) {
iter = agent_spec_map_.erase(iter);
@@ -103,13 +110,13 @@ Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) {
++iter;
}
}
// todo: send exit message to agent, and then exit if split with master
Clear();
SetWaitAgentsPromise(false);
return SUCCESS;
}

Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint64_t version_number) {
const std::string &rank_table_json_file, uint64_t version_number,
uint64_t wait_agents_time_in_seconds) {
if (model_loaded_) {
MSI_LOG_EXCEPTION << "Model has loaded";
}
@@ -138,7 +145,7 @@ Status DistributedServable::StartServable(const std::string &servable_directory,
MSI_LOG_ERROR << "Check rank config failed";
return status;
}
status = WaitAgentsReady();
status = WaitAgentsReady(wait_agents_time_in_seconds);
if (status != SUCCESS) {
MSI_LOG_ERROR << "Waiting for ready of agents failed";
return status;
@@ -154,16 +161,23 @@ Status DistributedServable::StartServable(const std::string &servable_directory,

Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; }

Status DistributedServable::WaitAgentsReady() {
Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) {
auto future = agents_promise_.get_future();
const int kWaitMaxHundredMs = 100 * 10; // 100s
int i;
if (wait_agents_time_in_seconds == 0) {
wait_agents_time_in_seconds = UINT32_MAX;
}
const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10;
uint64_t i;
for (i = 0; i < kWaitMaxHundredMs; i++) { //
if (ExitSignalHandle::Instance().HasStopped()) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped";
}
// waiting for 100ms
if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
auto flag = future.get();
if (!flag) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported";
}
break;
}
}
@@ -264,32 +278,49 @@ Status DistributedServable::CheckRankConfig() {
<< "Rank size must be an integral multiple of stage size, rank size: " << rank_size
<< ", stage size: " << stage_size;
}
auto parallel_count = rank_size / stage_size;
constexpr size_t card_count_per_machine = 8;
if (rank_size > card_count_per_machine && parallel_count % card_count_per_machine != 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Parallel count " << parallel_count << " in one stage must be an integral multiple of card count "
<< card_count_per_machine << " in one machine, when rank size is greater than card count in one machine, "
<< "rank size: " << rank_size << ", stage size: " << stage_size;
}
if (config_.rank_list.size() != rank_size) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size "
<< rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_;
}
for (size_t i = 0; i < rank_size; i++) {
const auto &first_item = config_.rank_list[i];
for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) {
auto rank_id = i + k;
const auto &item = config_.rank_list[rank_id];
if (k != item.device_id) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k;
auto parallel_count = rank_size / stage_size;
constexpr size_t card_count_per_machine = 8;
if (stage_size == 1) {
std::map<std::string, std::set<uint32_t>> device_map;
for (size_t i = 0; i < rank_size; i++) {
const auto &item = config_.rank_list[i];
auto &device_id_list = device_map[item.ip];
if (device_id_list.count(item.device_id) > 0) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank "
<< i << " in device ip " << item.ip;
}
if (first_item.ip != item.ip) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id
<< " to be equal with device ip " << first_item.ip << " of rank " << i;
device_id_list.emplace(item.device_id);
}
} else {
if (rank_size < card_count_per_machine) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Rank size " << rank_size << "must >= card count " << card_count_per_machine
<< " of one machine when stage size " << stage_size << " > 1";
}
if (parallel_count % card_count_per_machine != 0) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine
<< "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size;
}
for (size_t i = 0; i < rank_size; i += card_count_per_machine) {
const auto &first_item = config_.rank_list[i];
for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) {
auto rank_id = i + k;
const auto &item = config_.rank_list[rank_id];
if (k != item.device_id) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k;
}
if (first_item.ip != item.ip) {
return INFER_STATUS_LOG_ERROR(FAILED)
<< "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id
<< " to be equal with device ip " << first_item.ip << " of rank " << i;
}
}
}
}
@@ -298,5 +329,7 @@ Status DistributedServable::CheckRankConfig() {
return SUCCESS;
}

void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); }

} // namespace serving
} // namespace mindspore

+ 11
- 4
mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h View File

@@ -35,9 +35,12 @@ struct DistributedAgentContext {

class MS_API DistributedServable : public ServableBase {
public:
DistributedServable() = default;
~DistributedServable();
// from python, worker.py
Status StartServable(const std::string &servable_directory, const std::string &servable_name,
const std::string &rank_table_json_file, uint64_t version_number);
const std::string &rank_table_json_file, uint64_t version_number,
uint64_t wait_agents_time_in_seconds);

// invoke from agent
Status GetDistributedServableConfig(DistributedServableConfig *config) const;
@@ -55,7 +58,8 @@ class MS_API DistributedServable : public ServableBase {
uint64_t GetBatchSize() const override;
std::string GetServableName() const override;
uint64_t GetServableVersion() const override;
void Clear();
void Clear() override;
void OnAgentFailed();

private:
DistributedServableConfig config_;
@@ -63,19 +67,22 @@ class MS_API DistributedServable : public ServableBase {
uint64_t version_number_ = 0;
bool model_loaded_ = false;

std::mutex mutex_;
std::map<uint32_t, DistributedAgentContext> agent_spec_map_;
std::string rank_table_json_file_;

std::vector<TensorInfo> input_infos_;
std::vector<TensorInfo> output_infos_;
uint64_t batch_size_ = 0;
std::promise<void> agents_promise_;
std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT;
std::promise<bool> agents_promise_;

Status InitConfigOnStartup(const std::string &rank_table_json_file);
Status WaitAgentsReady();
Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds);
Status CheckAgentsInfosAndInitTensorInfos();
Status CompareTensorInfos(const std::vector<TensorInfo> &lefts, const std::vector<TensorInfo> &rights);
Status CheckRankConfig();
void SetWaitAgentsPromise(bool flag);
// agent stubs
};



+ 17
- 1
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc View File

@@ -62,7 +62,7 @@ Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> &
std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000));
}
if (ExitSignalHandle::Instance().HasStopped()) {
return INFER_STATUS_LOG_WARNING(FAILED) << "Worker exit, stop registration";
return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration";
}
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut";
}
@@ -87,5 +87,21 @@ Status GrpcNotifyDistributeWorker::Unregister() {
return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed";
}

// One-shot RPC telling the distributed worker at (worker_ip, worker_port)
// that agent startup has failed, so the worker can abort waiting for agents.
// Static helper: used by the startup path before any agent object exists.
Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) {
  const std::string worker_address = worker_ip + ":" + std::to_string(worker_port);
  auto stub = proto::MSWorker::NewStub(GrpcServer::CreateChannel(worker_address));

  proto::AgentFailedRequest request;
  proto::AgentFailedReply reply;
  grpc::ClientContext context;
  const grpc::Status rpc_status = stub->AgentFailed(&context, request, &reply);
  if (!rpc_status.ok()) {
    return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent";
  }
  MSI_LOG(INFO) << "Success to notify failure of agent";
  return SUCCESS;
}

} // namespace serving
} // namespace mindspore

+ 10
- 7
mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h View File

@@ -19,7 +19,8 @@
#include <vector>
#include <string>
#include <memory>
#include "worker/distributed_worker/notify_distributed/base_notify_worker.h"
#include "common/serving_common.h"
#include "worker/distributed_worker/common.h"
#include "proto/ms_distributed.pb.h"
#include "proto/ms_distributed.grpc.pb.h"
#include "proto/ms_worker.pb.h"
@@ -27,13 +28,15 @@
namespace mindspore {
namespace serving {

class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker {
class MS_API GrpcNotifyDistributeWorker {
public:
GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip,
uint32_t host_port);
~GrpcNotifyDistributeWorker() override;
Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override;
Status Unregister() override;
GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip,
uint32_t agent_port);
~GrpcNotifyDistributeWorker();
Status Register(const std::vector<WorkerAgentSpec> &agent_specs);
Status Unregister();
// from start up, not agent
static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port);

private:
std::string distributed_worker_ip_;


+ 68
- 9
mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc View File

@@ -14,6 +14,10 @@
* limitations under the License.
*/
#include "worker/distributed_worker/worker_agent.h"
#include <memory>
#include "worker/distributed_worker/agent_process/agent_process.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"
#include "common/exit_handle.h"

namespace mindspore {
namespace serving {
@@ -23,15 +27,16 @@ WorkerAgent &WorkerAgent::Instance() {
return instance;
}

Status WorkerAgent::LoadModelFromFile(const AgentStartUpConfig &config) {
config_ = config;
return executor_.LoadModelFromFile(config);
}

Status WorkerAgent::Clear() { return executor_.UnloadModel(); }

Status WorkerAgent::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) {
return executor_.ExecuteModel(request, reply);
// Release all agent-side resources: optionally unregister from the worker
// (skipped when the worker itself requested the exit), drop the notifier,
// stop the agent gRPC server, and unload the model shard.
Status WorkerAgent::Clear() {
  if (notify_worker_ != nullptr) {
    if (exit_notify_worker_) {
      // Tell the worker this agent is going away.
      notify_worker_->Unregister();
    }
    notify_worker_.reset();
  }
  grpc_server_.Stop();
  executor_.UnloadModel();
  return SUCCESS;
}

Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) {
@@ -40,5 +45,59 @@ Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::
return SUCCESS;
}

// Bring this agent online in three steps: (1) load the model shard for this
// rank onto the configured device, (2) start the agent gRPC service, and
// (3) register the agent with the distributed worker. Returns the status of
// the first failing step; each failure is logged with enough context
// (rank, device, files) to locate the broken configuration.
Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) {
  Status status;
  config_ = config;
  // Step 1: load this rank's model shard.
  status = executor_.LoadModelFromFile(config);
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name
                  << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
                  << ", model file: " << config.model_file_name
                  << ", rank table file: " << config.rank_table_json_file_name
                  << ", group config file: " << config.group_file_name;
    return status;
  }
  // Step 2: expose the agent service the worker will call for inference.
  status = StartGrpcServer();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip
                  << ", agent port: " << config.agent_port;
    return status;
  }
  // Step 3: announce this agent (address, rank, tensor infos) to the worker.
  status = RegisterAgent();
  if (status != SUCCESS) {
    MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port
                  << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port;
    return status;
  }
  MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name
               << ", rank_id: " << config.rank_id << ", device id: " << config.device_id
               << ", model file: " << config.model_file_name
               << ", rank table file: " << config.rank_table_json_file_name
               << ", group config file: " << config.group_file_name;
  return SUCCESS;
}

// Start the agent-side gRPC service on (agent_ip, agent_port).
// NOTE(review): the outcome of grpc_server_.Start is not checked here, yet
// SUCCESS is always returned — confirm Start signals failure another way.
Status WorkerAgent::StartGrpcServer() {
  grpc_server_.Start(std::make_shared<MSAgentImpl>(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent");
  return SUCCESS;
}

// Register this agent with the distributed worker: build the notifier stub,
// describe this agent (listen address, rank, batch size, tensor infos) and
// send the registration request.
// Fix: GrpcNotifyDistributeWorker's constructor takes
// (worker_ip, worker_port, agent_ip, agent_port); the second argument was
// previously config_.agent_port, pairing the worker ip with the agent port
// so registration targeted the wrong worker endpoint.
Status WorkerAgent::RegisterAgent() {
  notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.worker_port,
                                                                config_.agent_ip, config_.agent_port);
  WorkerAgentSpec spec;
  spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port);
  spec.rank_id = config_.rank_id;
  spec.batch_size = executor_.GetBatchSize();
  spec.input_infos = executor_.GetInputInfos();
  spec.output_infos = executor_.GetOutputInfos();
  return notify_worker_->Register({spec});
}

// Request agent shutdown. `notify_worker` records whether the eventual
// Clear() should unregister from the worker (false when the worker itself
// ordered the exit). Stop() raises the process-wide exit flag that the
// serving loops observe.
void WorkerAgent::StopAgent(bool notify_worker) {
  exit_notify_worker_ = notify_worker;
  ExitSignalHandle::Instance().Stop();
}

} // namespace serving
} // namespace mindspore

+ 14
- 2
mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h View File

@@ -17,24 +17,36 @@
#ifndef MINDSPORE_SERVING_WORKER_AGENT_H
#define MINDSPORE_SERVING_WORKER_AGENT_H
#include <vector>
#include <memory>
#include "worker/distributed_worker/agent_executor.h"
#include "proto/ms_agent.pb.h"
#include "proto/ms_agent.grpc.pb.h"
#include "common/grpc_server.h"
#include "worker/distributed_worker/common.h"
#include "worker/distributed_worker/notify_distributed/notify_worker.h"

namespace mindspore {
namespace serving {
class MS_API WorkerAgent {
public:
static WorkerAgent &Instance();
Status LoadModelFromFile(const AgentStartUpConfig &config);
Status Clear();

Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply);
Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply);

Status StartAgent(const AgentStartUpConfig &config);

void StopAgent(bool notify_worker = true);

private:
AgentStartUpConfig config_;
WorkerAgentExecutor executor_;
GrpcServer grpc_server_;
bool exit_notify_worker_ = true;
std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_;

Status StartGrpcServer();
Status RegisterAgent();
};

} // namespace serving


+ 17
- 22
mindspore_serving/ccsrc/worker/grpc/worker_server.h View File

@@ -39,7 +39,7 @@ class MS_API MSWorkerServer {
MSWorkerServer();
virtual ~MSWorkerServer();
Status StartWorkerGrpcServer(const std::string &hostname, int32_t port);
virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port);
Status Stop();
protected:
@@ -48,21 +48,32 @@ class MS_API MSWorkerServer {
std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr;
std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr;
Status StartAsyncRpcService();
Status Init();
Status StartAsyncRpcService();
};
class WorkerServiceContext {
public:
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq) {
state_ = STATE::CREATE;
}
virtual ~WorkerServiceContext() {}
bool JudgeFinish() { return state_ == STATE::FINISH; }
virtual void StartEnqueueRequest() = 0;
virtual void HandleRequest() = 0;
virtual bool JudgeFinish() = 0;
protected:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
public:
STATE state_;
};
@@ -70,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext {
public:
WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerPredictContext() = default;
@@ -95,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext {
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_;
proto::PredictRequest request_;
proto::PredictReply response_;
@@ -111,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext {
public:
WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service,
grpc::ServerCompletionQueue *cq)
: service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) {
state_ = STATE::CREATE;
}
: WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {}
~WorkerExitContext() = default;
@@ -136,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext {
responder_.Finish(response_, status, this);
}
bool JudgeFinish() override { return state_ == STATE::FINISH; }
private:
MSWorkerImpl *service_impl_;
proto::MSWorker::AsyncService *async_service_;
grpc::ServerCompletionQueue *cq_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_;
proto::ExitRequest request_;
proto::ExitReply response_;


+ 8
- 1
mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc View File

@@ -31,7 +31,7 @@ static const char *kVersionStrategySpecific = "specific";

namespace mindspore::serving {

LocalModelServable::~LocalModelServable() { session_.UnloadModel(); }
LocalModelServable::~LocalModelServable() { Clear(); }

std::string LocalModelServable::GetServableName() const { return servable_name_; }

@@ -248,4 +248,11 @@ Status LocalModelServable::LoadModel(uint64_t version_number) {
return SUCCESS;
}

// Release the loaded model, if any, and mark this servable as not loaded.
// Safe to call repeatedly: UnloadModel runs only while model_loaded_ is set.
void LocalModelServable::Clear() {
  if (model_loaded_) {
    session_.UnloadModel();
  }
  model_loaded_ = false;
}

} // namespace mindspore::serving

+ 1
- 0
mindspore_serving/ccsrc/worker/local_servable/local_sevable.h View File

@@ -48,6 +48,7 @@ class MS_API LocalModelServable : public ServableBase {
Status InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options);
std::string GetServableName() const override;
uint64_t GetServableVersion() const override;
void Clear() override;

private:
LoadServableSpec base_spec_;


+ 1
- 0
mindspore_serving/ccsrc/worker/sevable_base.h View File

@@ -41,6 +41,7 @@ class ServableBase {
virtual uint64_t GetBatchSize() const = 0;
virtual std::string GetServableName() const = 0;
virtual uint64_t GetServableVersion() const = 0;
virtual void Clear() = 0;
};

} // namespace mindspore::serving


+ 11
- 4
mindspore_serving/ccsrc/worker/worker.cc View File

@@ -174,9 +174,13 @@ void Worker::Update() {
*/
}

Status Worker::AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server) {
Status Worker::StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip,
int32_t port) {
if (worker_grpc_server_ != nullptr) {
return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running";
}
worker_grpc_server_ = grpc_server;
return SUCCESS;
return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port);
}

Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) {
@@ -248,6 +252,9 @@ void Worker::Clear() {
if (exit_notify_master_ && servable_started_) {
notify_master_->Unregister();
}
for (auto &worker_item : work_list_) {
worker_item.servable->Clear();
}
work_list_.clear();

py_task_queue_group_.Stop();
@@ -257,7 +264,7 @@ void Worker::Clear() {
MSI_LOG_INFO << "End clear worker session";
}

bool Worker::HasCleared() { return !servable_started_; }
bool Worker::IsRunning() { return servable_started_; }

Worker::~Worker() { Clear(); }

@@ -318,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) {
const int kWaitMaxHundredMs = 100;
int i;
for (i = 0; i < kWaitMaxHundredMs; i++) { //
if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) {
if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) {
instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped");
return SYSTEM_ERROR;
}


+ 3
- 2
mindspore_serving/ccsrc/worker/worker.h View File

@@ -75,10 +75,11 @@ class MS_API Worker {
const std::vector<InstanceData> &inputs);
Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master);

Status AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server);
Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip,
int32_t port);

void StopServable(bool notify_master = true);
bool HasCleared();
bool IsRunning();
Status RegisterWorker();
void Update();
Status StartVersionController();


+ 2
- 0
mindspore_serving/master/_master.py View File

@@ -18,6 +18,7 @@ import threading
from functools import wraps
from mindspore_serving.worker import check_type
from mindspore_serving import log as logger
from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import Master_

_wait_and_clear_thread = None
@@ -59,6 +60,7 @@ def stop_on_except(func):
@wraps(func)
def handle_except(*args, **kwargs):
try:
ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message
func(*args, **kwargs)
except:
stop()


+ 7
- 0
mindspore_serving/proto/ms_distributed.proto View File

@@ -44,3 +44,10 @@ message AgentExitRequest {
message AgentExitReply {
ErrorMsg error_msg = 1;
}

message AgentFailedRequest {
}

message AgentFailedReply {
ErrorMsg error_msg = 1;
}

+ 1
- 0
mindspore_serving/proto/ms_worker.proto View File

@@ -29,4 +29,5 @@ service MSWorker {
// for worker agent
rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {}
rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {}
rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {}
}

+ 3
- 1
mindspore_serving/worker/_worker.py View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inferface for start up servable"""
"""Interface for start up servable"""

import threading
from functools import wraps
from mindspore_serving import log as logger
from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import Worker_
from .register.preprocess import preprocess_storage
from .register.postprocess import postprocess_storage
@@ -77,6 +78,7 @@ def stop_on_except(func):
@wraps(func)
def handle_except(*args, **kwargs):
try:
ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message
func(*args, **kwargs)
except:
stop()


+ 228
- 21
mindspore_serving/worker/distributed/agent_startup.py View File

@@ -13,31 +13,238 @@
# limitations under the License.
# ============================================================================
"""Serving, distributed worker agent startup"""
import inspect

import os
import time
from multiprocessing import Process, Pipe

from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_

from mindspore_serving import log as logger
from mindspore_serving.worker import check_type
from mindspore_serving.worker.distributed import worker_agent


def _get_local_ip(rank_list, port):
"""Get the local ip from the rank table config"""
import socket
ip_list = []
for item in rank_list:
ip_list.append(item.ip)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
for ip in ip_list:
try:
s.bind((ip, port))
logger.info(f"Get local machine ip success, ip {ip}")
return ip
# pylint: disable=bare-except
except:
pass
raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}")


def _update_model_files_path(model_files, group_config_files):
    """Check and return model files or group config files.

    Each entry is joined onto this script's directory and must be readable;
    returns the two lists of resolved paths.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    logger.info(f"input model files: {model_files}")
    logger.info(f"input group config files: {group_config_files}")

    def _resolve(files, label):
        # Resolve every entry against script_dir and fail fast on unreadable files.
        resolved = []
        for item in files:
            file_name = os.path.join(script_dir, item)
            if not os.access(file_name, os.R_OK):
                raise RuntimeError(f"Cannot access {label} '{file_name}'")
            resolved.append(file_name)
        return resolved

    model_files_temp = _resolve(model_files, "model file")
    group_files_temp = _resolve(group_config_files, "group config file")

    logger.info(f"absolute model files: {model_files_temp}")
    logger.info(f"absolute group config files: {group_files_temp}")
    return model_files_temp, group_files_temp


def _make_json_table_file(distributed_config):
"""Make rank table json file"""
rank_size = len(distributed_config.rank_list)
runtime_dir = os.path.abspath(".")
time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time())))
rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json")
with open(rank_table_file_name, "w") as fp:
fp.write(distributed_config.rank_table_content)
return rank_table_file_name


# Pipe-protocol tokens exchanged between the startup (parent) process and
# each agent (child) process: final success/exit decisions plus a periodic
# parent heartbeat.
signal_success = "Success"
signal_exit = "Exit"
signal_heartbeat = "HeartBeat"


def _recv_parent(index, recv_pipe):
    """Receive message from Start up process.
    Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit.
    Return True on receiving signal_success."""
    try:
        while True:
            heartbeat_count = 0
            # Poll in 0.1s slices so the exit flag is noticed promptly.
            while not recv_pipe.poll(0.1):
                if ExitSignalHandle_.has_stopped():
                    logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker")
                    return False
                heartbeat_count += 1
                if heartbeat_count >= 30:  # 3s
                    # No heartbeat from the parent for ~3s: assume it is gone.
                    logger.warning(f"Child {index}: Exit on failure of receiving parent message")
                    return False
            parent_signal = recv_pipe.recv()
            # Heartbeats keep the loop alive; any other token is a decision.
            if parent_signal != signal_heartbeat:
                break
        if parent_signal == signal_success:
            logger.info(f"Child {index}: Receive success")
            return True
        if parent_signal == signal_exit:
            logger.warning(f"Child {index}: Exit on receiving exit message")
        else:
            logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}")
    # pylint: disable=broad-except
    except Exception as e:
        logger.warning(f"Child {index}: Exit on exception: {e}")
    # Fall-through for exit/unknown/exception paths.
    return False


def _agent_process(send_pipe, recv_pipe, index, start_config):
    """Agent process.

    Child-process entry: starts one worker agent, reports readiness to the
    parent over `send_pipe`, then waits on `recv_pipe` for the parent's
    collective success/exit decision. On any exception the exception object
    itself is sent to the parent so it can stop the sibling agents.
    """
    try:
        # listening success or failed message from parent process
        ExitSignalHandle_.start()  # Set flag to running and receive Ctrl+C message
        worker_agent.start_worker_agent(start_config=start_config)
        send_pipe.send((index, signal_success))
        success_msg = _recv_parent(index, recv_pipe)
        if not success_msg:
            worker_agent.stop()
        send_pipe.close()
        recv_pipe.close()
    # pylint: disable=broad-except
    except Exception as e:
        logger.error(f"Child {index}: Catch exception and notify exit of others")
        send_pipe.send((index, e))
        worker_agent.stop()
        raise


def startup_worker_agents(worker_ip, worker_port,
get_model_files_fun, get_group_configs_fun,
rank_start, agent_start_port=7000):
"""Start up all needed worker agents on one machine
"""
def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list):
    """Listening child process.

    Parent-side rendezvous: waits for one ready/exception message per child
    on the shared `p_recv_pipe`, sending heartbeats while waiting. If any
    child dies or reports an exception, the surviving children are told to
    exit and False is returned; otherwise every child gets signal_success
    and True is returned.
    """
    def send_pipe_msg(send_pipe, msg):
        # Best-effort send: a dead child's pipe must not abort the parent.
        try:
            send_pipe.send(msg)
        # pylint: disable=broad-except
        except Exception as e:
            logger.warning(f"Send pipe message exception happen: {e}")

    count = len(send_pipe_list)
    for _ in range(count):
        while True:
            if p_recv_pipe.poll(0.1):
                break
            # No message yet: make sure every child is still alive.
            for send_pipe, process in zip(send_pipe_list, subprocess_list):
                if process.is_alive():
                    continue
                logger.warning("Fail to start agents because of death of one agent")
                for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list):
                    if process_x.is_alive():
                        send_pipe_msg(send_pipe_x, signal_exit)
                return False
            # Let the children know the parent is still here.
            for send_pipe in send_pipe_list:
                send_pipe_msg(send_pipe, signal_heartbeat)

        _, msg = p_recv_pipe.recv()
        if isinstance(msg, Exception):
            # One child failed during startup: tell all children to exit.
            logger.warning("Fail to start agents because of exception raise by one agent")
            for send_pipe in send_pipe_list:
                send_pipe_msg(send_pipe, signal_exit)
            return False

    # All children reported ready: broadcast the success decision.
    for send_pipe in send_pipe_list:
        send_pipe_msg(send_pipe, signal_success)
    logger.info("Success to start agents")
    return True


def _startup_all_agents(common_meta, worker_ip, worker_port,
                        agent_ip, agent_start_port, device_id_list, rank_id_list,
                        model_files, group_config_files, rank_table_file):
    """Start up all agents in one machine.

    Spawns one child process per local (device_id, rank_id) pair, each
    serving one model shard on port agent_start_port + index. All children
    share one child->parent pipe; each child gets its own parent->child
    pipe. If the collective startup fails, the worker is notified so it can
    abort waiting for agents.
    """
    servable_name = common_meta.servable_name
    index = 0
    send_pipe_list = []
    subprocess_list = []
    # Single shared child->parent pipe for ready/exception reports.
    c_send_pipe, p_recv_pipe = Pipe()
    for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files,
                                                          group_config_files):
        # Dedicated parent->child pipe for heartbeat/success/exit signals.
        p_send_pipe, c_recv_pipe = Pipe()
        send_pipe_list.append(p_send_pipe)

        agent_port = agent_start_port + index

        start_config = AgentStartUpConfig_()
        start_config.rank_id = rank_id
        start_config.device_id = device_id
        start_config.model_file_name = model_file
        start_config.group_file_name = group_file
        start_config.rank_table_json_file_name = rank_table_file
        start_config.agent_ip = agent_ip
        start_config.agent_port = agent_port
        start_config.worker_ip = worker_ip
        start_config.worker_port = worker_port
        start_config.common_meta = common_meta

        process = Process(target=_agent_process,
                          args=(c_send_pipe, c_recv_pipe, index, start_config),
                          name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}")
        process.start()
        subprocess_list.append(process)
        index += 1
    ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list)
    if not ret:
        # Startup failed: tell the worker so it stops waiting for agents.
        WorkerAgent_.notify_failed(worker_ip, worker_port)


def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000):
    """Start up all needed worker agents on one machine.

    Fetches the distributed config from the worker, determines which ranks
    live on this machine (by binding a local ip from the rank table),
    validates the per-card model/group-config file lists, materializes the
    rank table json, and launches one agent process per local card.

    Fix: the body previously referenced parameters removed from this
    signature (get_model_files_fun, get_group_configs_fun, rank_start ->
    NameError) and validated the string file lists with an int-tuple check
    plus isinstance(x, [list, tuple]) (TypeError: isinstance needs a tuple).
    """
    check_type.check_str("worker_ip", worker_ip)
    check_type.check_ip_port("worker_port", worker_port)
    check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7)
    # Validate the file arguments as sequences of str.
    for arg_name, files in (("model_files", model_files), ("group_config_files", group_config_files)):
        if not isinstance(files, (list, tuple)) or not all(isinstance(f, str) for f in files):
            raise RuntimeError(f"Check failed, {arg_name} must be tuple/list of str, now is {type(files)}")
    model_files = list(model_files)
    group_config_files = list(group_config_files)

    distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)

    # get machine ip
    rank_list = distributed_config.rank_list
    local_ip = _get_local_ip(rank_list, agent_start_port)
    # get all device_id and rank_id
    local_device_id_list = []
    local_rank_id_list = []
    for rank_id, item in enumerate(rank_list):
        if item.ip == local_ip:
            local_device_id_list.append(item.device_id)
            local_rank_id_list.append(rank_id)

    # handle model files and group config files
    if len(local_device_id_list) != len(model_files):
        raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to model files size "
                           f"{len(model_files)}, model files: {model_files}")

    if len(local_device_id_list) != len(group_config_files):
        raise RuntimeError(f"Card count {local_device_id_list} described rank table does not equal to group config "
                           f"files size {len(group_config_files)}, group config files: {group_config_files}")

    model_files, group_config_files = _update_model_files_path(model_files, group_config_files)

    # make json table file and export env
    rank_table_file = _make_json_table_file(distributed_config)
    _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port,
                        local_device_id_list, local_rank_id_list,
                        model_files, group_config_files, rank_table_file)

+ 13
- 9
mindspore_serving/worker/distributed/distributed_worker.py View File

@@ -13,15 +13,17 @@
# limitations under the License.
# ============================================================================
"""Serving, distributed worker startup"""
from mindspore_serving.worker._worker import stop_on_except, _load_servable_config
from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear
from mindspore_serving.worker import check_type
from mindspore_serving._mindspore_serving import Worker_

from mindspore_serving.worker import check_type
from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear
from mindspore_serving.worker._worker import stop_on_except, _load_servable_config


@stop_on_except
def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1,
worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100):
worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100,
wait_agents_time_in_seconds=300):
r"""
Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master
through gRPC (master_ip, master_port).
@@ -46,6 +48,7 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso
master_port (int): The master port the worker linked to.
worker_ip (str): The worker ip the master and agents linked to.
worker_port (int): The worker port the master and agents linked to.
        wait_agents_time_in_seconds (int): The maximum time in seconds the worker waits for all agents to be ready.

Examples:
>>> import os
@@ -70,15 +73,15 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso
check_type.check_ip_port('worker_port', worker_port)

_load_servable_config(servable_directory, servable_name)
_start_wait_and_clear()
Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number,
master_ip, master_port, worker_ip, worker_port)
master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds)
_start_py_task(Worker_.get_batch_size())
_start_wait_and_clear()


@stop_on_except
def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1,
worker_ip="0.0.0.0", worker_port=6200):
worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300):
r"""
    Start up the servable named 'servable_name' defined in 'servable_directory', and the worker will run in
the process of the master.
@@ -97,6 +100,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank
        rank_table_json_file (str): The rank table json file name.
worker_ip (str): The worker ip the agents linked to.
worker_port (int): The worker port the agents linked to.
        wait_agents_time_in_seconds (int): The maximum time in seconds the worker waits for all agents to be ready.

Examples:
>>> import os
@@ -121,7 +125,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank
check_type.check_ip_port('worker_port', worker_port)

_load_servable_config(servable_directory, servable_name)
_start_wait_and_clear()
Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file,
version_number, worker_ip, worker_port)
version_number, worker_ip, worker_port, wait_agents_time_in_seconds)
_start_py_task(Worker_.get_batch_size())
_start_wait_and_clear()

+ 47
- 15
mindspore_serving/worker/distributed/worker_agent.py View File

@@ -13,22 +13,54 @@
# limitations under the License.
# ============================================================================
"""Serving, distributed worker agent"""
from mindspore_serving.worker import check_type

import os
import threading
from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_
from mindspore_serving import log as logger

def start_worker_agent(start_config):
    """Start up one worker agent on one device id, invoked by agent_startup.startup_worker_agents.

    Removes the leftover validation lines of the deleted `_start_worker_agent`
    signature, which referenced undefined names (`agent_ip`, `worker_port`, ...)
    and would raise NameError on every call.

    Args:
        start_config (AgentStartUpConfig_): Startup configuration; this function reads its
            `rank_id`, `device_id`, `group_file_name` and `rank_table_json_file_name` fields
            and passes the whole object to the C++ backend.

    Raises:
        RuntimeError: If `start_config` is not an instance of AgentStartUpConfig_.
    """
    if not isinstance(start_config, AgentStartUpConfig_):
        raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_")

    # Export the environment variables the MindSpore distributed (HCCL) runtime reads.
    os.environ["RANK_ID"] = str(start_config.rank_id)
    os.environ["DEVICE_ID"] = str(start_config.device_id)
    os.environ["MS_ENABLE_HCCL"] = "1"
    os.environ["PARA_GROUP_FILE"] = start_config.group_file_name
    os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name

    # Log the effective environment so agent startup problems can be diagnosed from the log.
    for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE",
                 "LD_LIBRARY_PATH", "PYTHONPATH"):
        logger.info(f"Env {item}: {os.getenv(item, '')}")
    WorkerAgent_.start_agent(start_config)

    # Non-blocking: a background thread waits for Ctrl+C and then clears the agent.
    start_wait_and_clear()


# Handle of the single background wait-and-clear thread; None until first started.
_wait_and_clear_thread = None


def start_wait_and_clear():
    """Spawn (at most once) a background thread that waits for Ctrl+C, then clears up the agent environment."""
    global _wait_and_clear_thread
    if _wait_and_clear_thread:
        # Already running — starting a second waiter would be redundant.
        return

    def _wait_for_exit():
        # Block inside the C++ backend until the process is interrupted.
        logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------")
        print("Serving worker: wait for Ctrl+C to exit ------------------------------------")
        WorkerAgent_.wait_and_clear()
        logger.info("Serving worker: exited ------------------------------------")
        print("Serving worker: exited ------------------------------------")

    _wait_and_clear_thread = threading.Thread(target=_wait_for_exit)
    _wait_and_clear_thread.start()


def stop():
    r"""
    Stop the running of agent.

    Delegates to the C++ backend (``WorkerAgent_.stop_and_clear``) to stop the
    agent and release its resources.
    """
    WorkerAgent_.stop_and_clear()

Loading…
Cancel
Save