From 90db9fba60a669fc57dfdac84018c5368b949190 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Tue, 2 Feb 2021 19:58:11 +0800 Subject: [PATCH] Serving, python agent --- mindspore_serving/ccsrc/master/server.cc | 3 - .../ccsrc/python/agent/agent_py.cc | 63 +++++ .../agent/agent_py.h} | 31 ++- mindspore_serving/ccsrc/python/serving_py.cc | 52 +++- .../ccsrc/python/worker/worker_py.cc | 27 +- .../ccsrc/python/worker/worker_py.h | 5 +- .../agent_process/agent_process.cc | 2 +- .../distributed_worker/agent_startup.cc | 27 +- .../worker/distributed_worker/agent_startup.h | 11 +- .../distributed_process.cc | 17 +- .../distributed_process.h | 2 + .../distributed_server.cc | 7 +- .../distributed_server.h | 99 ++++--- .../distributed_servable.cc | 109 +++++--- .../distributed_worker/distributed_servable.h | 15 +- .../notify_distributed/notify_worker.cc | 18 +- .../notify_distributed/notify_worker.h | 17 +- .../worker/distributed_worker/worker_agent.cc | 77 +++++- .../worker/distributed_worker/worker_agent.h | 16 +- .../ccsrc/worker/grpc/worker_server.h | 39 ++- .../worker/local_servable/local_sevable.cc | 9 +- .../worker/local_servable/local_sevable.h | 1 + mindspore_serving/ccsrc/worker/sevable_base.h | 1 + mindspore_serving/ccsrc/worker/worker.cc | 15 +- mindspore_serving/ccsrc/worker/worker.h | 5 +- mindspore_serving/master/_master.py | 2 + mindspore_serving/proto/ms_distributed.proto | 7 + mindspore_serving/proto/ms_worker.proto | 1 + mindspore_serving/worker/_worker.py | 4 +- .../worker/distributed/agent_startup.py | 249 ++++++++++++++++-- .../worker/distributed/distributed_worker.py | 22 +- .../worker/distributed/worker_agent.py | 62 +++-- 32 files changed, 797 insertions(+), 218 deletions(-) create mode 100644 mindspore_serving/ccsrc/python/agent/agent_py.cc rename mindspore_serving/ccsrc/{worker/distributed_worker/notify_distributed/base_notify_worker.h => python/agent/agent_py.h} (52%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => 
distributed_process}/distributed_process.cc (76%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_process.h (89%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_server.cc (73%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_server.h (50%) diff --git a/mindspore_serving/ccsrc/master/server.cc b/mindspore_serving/ccsrc/master/server.cc index 5117bad..980daac 100644 --- a/mindspore_serving/ccsrc/master/server.cc +++ b/mindspore_serving/ccsrc/master/server.cc @@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma if (grpc_async_server_ != nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; } - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (max_msg_mb_size > gRpcMaxMBMsgSize) { MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size << "MB to 512MB"; @@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma } Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return grpc_manager_server_.Start(std::make_shared(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, "Master"); } Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, int time_out_second) { - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second); } diff --git a/mindspore_serving/ccsrc/python/agent/agent_py.cc b/mindspore_serving/ccsrc/python/agent/agent_py.cc new file mode 100644 index 0000000..c2c1465 --- /dev/null +++ b/mindspore_serving/ccsrc/python/agent/agent_py.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/worker_agent.h" + +namespace mindspore::serving { + +DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { + auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + DistributedServableConfig config; + status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + return config; +} + +void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port); +} + +void PyAgent::StartAgent(const AgentStartUpConfig &start_config) { + auto status = WorkerAgent::Instance().StartAgent(start_config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyAgent::WaitAndClear() { + { + py::gil_scoped_release release; + ExitSignalHandle::Instance().AgentWait(); + } + WorkerAgent::Instance().Clear(); + MSI_LOG_INFO << "Python agent end wait and clear"; +} + +void 
PyAgent::StopAndClear() { + ExitSignalHandle::Instance().Stop(); + WorkerAgent::Instance().Clear(); +} + +} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h b/mindspore_serving/ccsrc/python/agent/agent_py.h similarity index 52% rename from mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h rename to mindspore_serving/ccsrc/python/agent/agent_py.h index 8e5e690..708b673 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h +++ b/mindspore_serving/ccsrc/python/agent/agent_py.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,34 @@ * limitations under the License. */ -#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H -#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H -#include +#ifndef MINDSPORE_SERVER_AGENT_PY_H +#define MINDSPORE_SERVER_AGENT_PY_H + +#include +#include +#include +#include +#include #include "common/serving_common.h" -#include "common/servable.h" #include "worker/distributed_worker/common.h" +namespace py = pybind11; + namespace mindspore { namespace serving { -class MS_API BaseNotifyDistributeWorker { +class MS_API PyAgent { public: - BaseNotifyDistributeWorker() = default; - virtual ~BaseNotifyDistributeWorker() = default; - virtual Status Register(const std::vector &worker_specs) = 0; - virtual Status Unregister() = 0; + static void StartAgent(const AgentStartUpConfig &start_config); + + static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); + static void WaitAndClear(); + static void StopAndClear(); + // from start up, not agent + static void NotifyFailed(const std::string &worker_ip, uint32_t 
worker_port); }; } // namespace serving } // namespace mindspore -#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H +#endif // MINDSPORE_SERVER_AGENT_PY_H diff --git a/mindspore_serving/ccsrc/python/serving_py.cc b/mindspore_serving/ccsrc/python/serving_py.cc index 1dac040..adf29d3 100644 --- a/mindspore_serving/ccsrc/python/serving_py.cc +++ b/mindspore_serving/ccsrc/python/serving_py.cc @@ -23,6 +23,9 @@ #include "common/servable.h" #include "worker/context.h" #include "python/master/master_py.h" +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/worker_agent.h" namespace mindspore::serving { @@ -104,11 +107,23 @@ void PyRegServable(pybind11::module *m_ptr) { .def_static("register_method", &PyServableStorage::RegisterMethod) .def_static("declare_servable", &PyServableStorage::DeclareServable) .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); + + py::class_(m, "OneRankConfig_") + .def(py::init<>()) + .def_readwrite("device_id", &OneRankConfig::device_id) + .def_readwrite("ip", &OneRankConfig::ip); + + py::class_(m, "DistributedServableConfig_") + .def(py::init<>()) + .def_readwrite("common_meta", &DistributedServableConfig::common_meta) + .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) + .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) + .def_readwrite("rank_list", &DistributedServableConfig::rank_list); } void PyRegMaster(pybind11::module *m_ptr) { auto &m = *m_ptr; - py::class_>(m, "Master_") + py::class_(m, "Master_") .def_static("start_grpc_server", &PyMaster::StartGrpcServer) .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) .def_static("start_restful_server", &PyMaster::StartRestfulServer) @@ -163,15 +178,50 @@ void PyRegWorker(pybind11::module *m_ptr) { .def("set_device_id", &ServableContext::SetDeviceId); } +void PyRegWorkerAgent(pybind11::module *m_ptr) { + auto 
&m = *m_ptr; + py::class_(m, "WorkerAgent_") + .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) + .def_static("wait_and_clear", &PyAgent::WaitAndClear) + .def_static("stop_and_clear", &PyAgent::StopAndClear) + .def_static("notify_failed", &PyAgent::NotifyFailed) + .def_static("start_agent", &PyAgent::StartAgent); + + py::class_(m, "AgentStartUpConfig_") + .def(py::init<>()) + .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) + .def_readwrite("device_id", &AgentStartUpConfig::device_id) + .def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name) + .def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name) + .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) + .def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip) + .def_readwrite("agent_port", &AgentStartUpConfig::agent_port) + .def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip) + .def_readwrite("worker_port", &AgentStartUpConfig::worker_port) + .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); +} + +class PyExitSignalHandle { + public: + static void Start() { ExitSignalHandle::Instance().Start(); } + static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } +}; + // cppcheck-suppress syntaxError PYBIND11_MODULE(_mindspore_serving, m) { PyRegServable(&m); PyRegMaster(&m); PyRegWorker(&m); + PyRegWorkerAgent(&m); + + py::class_(m, "ExitSignalHandle_") + .def_static("start", &PyExitSignalHandle::Start) + .def_static("has_stopped", &PyExitSignalHandle::HasStopped); (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { Server::Instance().Clear(); Worker::GetInstance().Clear(); + WorkerAgent::Instance().Clear(); }}); } diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index faa0f31..c1b03a5 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ 
b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -24,7 +24,7 @@ #include "worker/local_servable/local_sevable.h" #include "worker/distributed_worker/distributed_servable.h" #include "worker/grpc/worker_server.h" -#include "worker/distributed_worker/grpc/distributed_server.h" +#include "worker/distributed_worker/distributed_process/distributed_server.h" namespace mindspore::serving { @@ -43,11 +43,10 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri } // start grpc server auto grpc_sever = std::make_shared(); - status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); status = Worker::GetInstance().StartVersionController(); if (status != SUCCESS) { @@ -76,18 +75,19 @@ void PyWorker::StartServableInMaster(const std::string &model_directory, const s void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port, - const std::string &master_ip, uint32_t master_port) { + const std::string &master_ip, uint32_t master_port, + uint32_t wait_agents_time_in_seconds) { Status status; auto servable = std::make_shared(); - auto grpc_sever = std::make_shared(); - status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); + auto grpc_sever = std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); - status = 
servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } @@ -103,18 +103,19 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, - const std::string &worker_ip, uint32_t worker_port) { + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds) { Status status; auto servable = std::make_shared(); - auto grpc_sever = std::make_shared(); - status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); + auto grpc_sever = std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(); - status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.h b/mindspore_serving/ccsrc/python/worker/worker_py.h index 01a53a8..e6b2c6d 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.h +++ b/mindspore_serving/ccsrc/python/worker/worker_py.h @@ -37,11 +37,12 @@ class MS_API PyWorker { static void StartDistributedServable(const std::string &servable_directory, 
const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, - uint32_t master_port); + uint32_t master_port, uint32_t wait_agents_time_in_seconds); static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, - const std::string &worker_ip, uint32_t worker_port); + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds); static int GetBatchSize(); static void WaitAndClear(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc index ff030e6..6e1750a 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -22,7 +22,7 @@ namespace serving { grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, proto::DistributedExitReply *reply) { MSI_LOG(INFO) << "Distributed Worker Exit"; - WorkerAgent::Instance().Clear(); + WorkerAgent::Instance().StopAgent(false); return grpc::Status::OK; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc index b4f5ee9..8ec9a39 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc @@ -14,17 +14,32 @@ * limitations under the License. 
*/ #include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" + namespace mindspore { namespace serving { -Status WorkerAgentStartUp::InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, - const std::string &group_file_dir, const std::string &group_file_prefix) { - return Status(); +WorkerAgentStartUp &WorkerAgentStartUp::Instance() { + static WorkerAgentStartUp instance; + return instance; } -Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, - const std::string &worker_ip, uint32_t worker_port) { + +Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { return Status(); } -Status WorkerAgentStartUp::GetCurrentMachineConfigs(std::vector *configs) { return Status(); } + +Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { + MSI_EXCEPTION_IF_NULL(config); + if (config_.rank_list.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; + } + *config = config_; + return SUCCESS; +} + +Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port); +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h index 5a7c25e..ad28e5c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -27,16 +27,15 @@ namespace serving { class MS_API WorkerAgentStartUp { public: + static WorkerAgentStartUp &Instance(); // from python, worker_agent.py // start_worker_agent // step1, get agents config from worker - Status InitAgentsConfig(const std::string &model_dir, const 
std::string &model_file_prefix, - const std::string &group_file_dir, const std::string &group_file_prefix); + Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); + // step2, invoke from python + Status GetDistributedServableConfig(DistributedServableConfig *config); - Status GetAgentsConfigsFromWorker(const std::string &rank_start, uint32_t agent_start_port, - const std::string &worker_ip, uint32_t worker_port); - // step2, invoke from python, get current machine agents config - Status GetCurrentMachineConfigs(std::vector *configs); + Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); private: DistributedServableConfig config_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc similarity index 76% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc index 0333434..48d1042 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -14,7 +14,8 @@ * limitations under the License. 
*/ -#include "worker/distributed_worker/grpc/distributed_process.h" +#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/worker.h" #include "common/proto_tensor.h" namespace mindspore { @@ -51,6 +52,20 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr MSI_LOG(ERROR) << "Agent Exit FAILED"; } } + if (Worker::GetInstance().IsRunning()) { + Worker::GetInstance().StopServable(); + } + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, + proto::AgentFailedReply *reply) { + if (Worker::GetInstance().IsRunning()) { + MSI_LOG_ERROR << "Expect worker should not be running"; + Worker::GetInstance().StopServable(); + } else { + servable_->OnAgentFailed(); + } return grpc::Status::OK; } } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h similarity index 89% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h index b127ac7..147e7c5 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h @@ -41,6 +41,8 @@ class MSDistributedImpl final : public MSWorkerImpl { proto::AgentRegisterReply *reply) override; grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, proto::AgentExitReply *reply) override; + grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, + proto::AgentFailedReply *reply) override; private: std::shared_ptr servable_; diff --git 
a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc similarity index 73% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc index 79d4064..d9de7cd 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "worker/distributed_worker/grpc/distributed_server.h" +#include "worker/distributed_worker/distributed_process/distributed_server.h" #include #include #include @@ -23,12 +23,11 @@ namespace mindspore { namespace serving { -Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr servable, - const std::string &hostname, int32_t port) { +Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { if (in_running_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; } - auto impl = std::make_unique(servable); + auto impl = std::make_unique(servable_); async_server_ = std::make_unique(hostname, port, impl.get()); service_impl_ = std::move(impl); return Init(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h similarity index 50% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h index 2151a41..ca6b967 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h +++ 
b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h @@ -28,7 +28,7 @@ #include "common/grpc_async_server.h" #include "worker/grpc/worker_process.h" #include "worker/grpc/worker_server.h" -#include "worker/distributed_worker/grpc/distributed_process.h" +#include "worker/distributed_worker/distributed_process/distributed_process.h" namespace mindspore { namespace serving { @@ -36,18 +36,30 @@ namespace serving { // Service Implement class MS_API MSDistributedWorkerServer : public MSWorkerServer { public: - Status StartDistributedWorkerGrpcServer(std::shared_ptr servable, const std::string &hostname, - int32_t port); + explicit MSDistributedWorkerServer(std::shared_ptr servable) : servable_(servable) {} + ~MSDistributedWorkerServer() = default; + Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override; + + private: + std::shared_ptr servable_; +}; + +class DistributedServiceContext : public WorkerServiceContext { + public: + DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {} + + protected: + MSDistributedImpl *dist_service_impl_ = nullptr; }; // Service Implement -class WorkerAgentRegisterContext : public WorkerServiceContext { +class WorkerAgentRegisterContext : public DistributedServiceContext { public: WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerAgentRegisterContext() = default; @@ -60,35 +72,27 @@ class WorkerAgentRegisterContext : public WorkerServiceContext { void StartEnqueueRequest() override { state_ = 
STATE::PROCESS; - async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); + async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this); } void HandleRequest() override { - EnqueueRequest(service_impl_, async_service_, cq_); + EnqueueRequest(dist_service_impl_, async_service_, cq_); state_ = STATE::FINISH; - grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); + grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_); responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSDistributedImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; - grpc::ServerAsyncResponseWriter responder_; - proto::PredictRequest request_; - proto::PredictReply response_; + grpc::ServerAsyncResponseWriter responder_; + proto::AgentRegisterRequest request_; + proto::AgentRegisterReply response_; }; -class WorkerAgentExitContext : public WorkerServiceContext { +class WorkerAgentExitContext : public DistributedServiceContext { public: WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerAgentExitContext() = default; @@ -101,26 +105,52 @@ class WorkerAgentExitContext : public WorkerServiceContext { void StartEnqueueRequest() override { state_ = STATE::PROCESS; - async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); + async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this); } void HandleRequest() override { - EnqueueRequest(service_impl_, async_service_, cq_); + EnqueueRequest(dist_service_impl_, async_service_, 
cq_); state_ = STATE::FINISH; - grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); + grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_); responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } + private: + grpc::ServerAsyncResponseWriter responder_; + proto::AgentExitRequest request_; + proto::AgentExitReply response_; +}; + +class WorkerAgentFailedContext : public DistributedServiceContext { + public: + WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} + + ~WorkerAgentFailedContext() = default; + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentFailedContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(dist_service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } private: - MSDistributedImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; - grpc::ServerAsyncResponseWriter responder_; - proto::ExitRequest request_; - proto::ExitReply response_; + grpc::ServerAsyncResponseWriter responder_; + proto::AgentFailedRequest request_; + proto::AgentFailedReply response_; }; class DistributedWorkerGrpcServer : public WorkerGrpcServer { @@ -134,6 +164,7 @@ class DistributedWorkerGrpcServer : public 
WorkerGrpcServer { WorkerGrpcServer::EnqueueRequest(); WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index b504355..ea83d1c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -17,13 +17,15 @@ #include "worker/distributed_worker/distributed_servable.h" #include #include -#include "worker/worker.h" +#include #include "worker/distributed_worker/notify_agent/notify_agent.h" #include "common/exit_handle.h" namespace mindspore { namespace serving { +DistributedServable::~DistributedServable() { Clear(); } + std::string DistributedServable::GetServableName() const { return servable_name_; } uint64_t DistributedServable::GetServableVersion() const { return version_number_; } @@ -60,7 +62,15 @@ Status DistributedServable::GetDistributedServableConfig(DistributedServableConf return SUCCESS; } +void DistributedServable::SetWaitAgentsPromise(bool flag) { + if (!promise_set_flag_.test_and_set()) { + agents_promise_.set_value(flag); + } +} + Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; + if (agent_spec.rank_id < config_.distributed_meta.rank_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; @@ -75,27 +85,24 @@ Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { std::shared_ptr notify_agent = std::make_shared(agent_spec.agent_address); context.notify_agent_ = notify_agent; 
agent_spec_map_[agent_spec.rank_id] = context; - if (config_.distributed_meta.rank_size == agent_spec_map_.size()) { - Status status = Worker::GetInstance().RegisterWorker(); - if (status != SUCCESS) { - Clear(); - return FAILED; - } - } + if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { - agents_promise_.set_value(); + SetWaitAgentsPromise(true); } return SUCCESS; } void DistributedServable::Clear() { - for (auto agent : agent_spec_map_) { + std::unique_lock lock{mutex_}; + for (auto &agent : agent_spec_map_) { agent.second.notify_agent_->Exit(); } - Worker::GetInstance().StopServable(false); + agent_spec_map_.clear(); + MSI_LOG_INFO << "End Clear servable"; } Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { iter = agent_spec_map_.erase(iter); @@ -103,13 +110,13 @@ Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { ++iter; } } - // todo: send exit message to agent, and then exit if split with master - Clear(); + SetWaitAgentsPromise(false); return SUCCESS; } Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name, - const std::string &rank_table_json_file, uint64_t version_number) { + const std::string &rank_table_json_file, uint64_t version_number, + uint64_t wait_agents_time_in_seconds) { if (model_loaded_) { MSI_LOG_EXCEPTION << "Model has loaded"; } @@ -138,7 +145,7 @@ Status DistributedServable::StartServable(const std::string &servable_directory, MSI_LOG_ERROR << "Check rank config failed"; return status; } - status = WaitAgentsReady(); + status = WaitAgentsReady(wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_ERROR << "Waiting for ready of agents failed"; return status; @@ -154,16 +161,23 @@ Status DistributedServable::StartServable(const 
std::string &servable_directory, Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } -Status DistributedServable::WaitAgentsReady() { +Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { auto future = agents_promise_.get_future(); - const int kWaitMaxHundredMs = 100 * 10; // 100s - int i; + if (wait_agents_time_in_seconds == 0) { + wait_agents_time_in_seconds = UINT32_MAX; + } + const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10; + uint64_t i; for (i = 0; i < kWaitMaxHundredMs; i++) { // if (ExitSignalHandle::Instance().HasStopped()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; } // waiting for 100ms if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { + auto flag = future.get(); + if (!flag) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported"; + } break; } } @@ -264,32 +278,49 @@ Status DistributedServable::CheckRankConfig() { << "Rank size must be an integral multiple of stage size, rank size: " << rank_size << ", stage size: " << stage_size; } - auto parallel_count = rank_size / stage_size; - constexpr size_t card_count_per_machine = 8; - if (rank_size > card_count_per_machine && parallel_count % card_count_per_machine != 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Parallel count " << parallel_count << " in one stage must be an integral multiple of card count " - << card_count_per_machine << " in one machine, when rank size is greater than card count in one machine, " - << "rank size: " << rank_size << ", stage size: " << stage_size; - } if (config_.rank_list.size() != rank_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " << rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_; } - for (size_t i = 0; i < 
rank_size; i++) { - const auto &first_item = config_.rank_list[i]; - for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { - auto rank_id = i + k; - const auto &item = config_.rank_list[rank_id]; - if (k != item.device_id) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + auto parallel_count = rank_size / stage_size; + constexpr size_t card_count_per_machine = 8; + if (stage_size == 1) { + std::map> device_map; + for (size_t i = 0; i < rank_size; i++) { + const auto &item = config_.rank_list[i]; + auto &device_id_list = device_map[item.ip]; + if (device_id_list.count(item.device_id) > 0) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " + << i << " in device ip " << item.ip; } - if (first_item.ip != item.ip) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id - << " to be equal with device ip " << first_item.ip << " of rank " << i; + device_id_list.emplace(item.device_id); + } + } else { + if (rank_size < card_count_per_machine) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size " << rank_size << "must >= card count " << card_count_per_machine + << " of one machine when stage size " << stage_size << " > 1"; + } + if (parallel_count % card_count_per_machine != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine + << "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size; + } + for (size_t i = 0; i < rank_size; i += card_count_per_machine) { + const auto &first_item = config_.rank_list[i]; + for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { + auto rank_id = i + k; + const auto &item = config_.rank_list[rank_id]; + if (k != item.device_id) { + return 
INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + } + if (first_item.ip != item.ip) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id + << " to be equal with device ip " << first_item.ip << " of rank " << i; + } } } } @@ -298,5 +329,7 @@ Status DistributedServable::CheckRankConfig() { return SUCCESS; } +void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); } + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h index 642a868..d810209 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -35,9 +35,12 @@ struct DistributedAgentContext { class MS_API DistributedServable : public ServableBase { public: + DistributedServable() = default; + ~DistributedServable(); // from python, worker.py Status StartServable(const std::string &servable_directory, const std::string &servable_name, - const std::string &rank_table_json_file, uint64_t version_number); + const std::string &rank_table_json_file, uint64_t version_number, + uint64_t wait_agents_time_in_seconds); // invoke from agent Status GetDistributedServableConfig(DistributedServableConfig *config) const; @@ -55,7 +58,8 @@ class MS_API DistributedServable : public ServableBase { uint64_t GetBatchSize() const override; std::string GetServableName() const override; uint64_t GetServableVersion() const override; - void Clear(); + void Clear() override; + void OnAgentFailed(); private: DistributedServableConfig config_; @@ -63,19 +67,22 @@ class MS_API DistributedServable : public ServableBase { uint64_t version_number_ = 0; bool model_loaded_ = false; + std::mutex mutex_; 
std::map agent_spec_map_; std::string rank_table_json_file_; std::vector input_infos_; std::vector output_infos_; uint64_t batch_size_ = 0; - std::promise agents_promise_; + std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT; + std::promise agents_promise_; Status InitConfigOnStartup(const std::string &rank_table_json_file); - Status WaitAgentsReady(); + Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds); Status CheckAgentsInfosAndInitTensorInfos(); Status CompareTensorInfos(const std::vector &lefts, const std::vector &rights); Status CheckRankConfig(); + void SetWaitAgentsPromise(bool flag); // agent stubs }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index d9e6b73..379eeff 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -62,7 +62,7 @@ Status GrpcNotifyDistributeWorker::Register(const std::vector & std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); } if (ExitSignalHandle::Instance().HasStopped()) { - return INFER_STATUS_LOG_WARNING(FAILED) << "Worker exit, stop registration"; + return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration"; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; } @@ -87,5 +87,21 @@ Status GrpcNotifyDistributeWorker::Unregister() { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; } +Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + auto address = worker_ip + ":" + std::to_string(worker_port); + auto channel = GrpcServer::CreateChannel(address); + auto stub = proto::MSWorker::NewStub(channel); + + grpc::ClientContext context; + proto::AgentFailedRequest request; + proto::AgentFailedReply reply; + grpc::Status 
status = stub->AgentFailed(&context, request, &reply); + if (status.ok()) { + MSI_LOG(INFO) << "Success to notify failure of agent"; + return SUCCESS; + } + return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent"; +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index 2c2724c..da509ff 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -19,7 +19,8 @@ #include #include #include -#include "worker/distributed_worker/notify_distributed/base_notify_worker.h" +#include "common/serving_common.h" +#include "worker/distributed_worker/common.h" #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" #include "proto/ms_worker.pb.h" @@ -27,13 +28,15 @@ namespace mindspore { namespace serving { -class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { +class MS_API GrpcNotifyDistributeWorker { public: - GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, - uint32_t host_port); - ~GrpcNotifyDistributeWorker() override; - Status Register(const std::vector &worker_specs) override; - Status Unregister() override; + GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip, + uint32_t agent_port); + ~GrpcNotifyDistributeWorker(); + Status Register(const std::vector &agent_specs); + Status Unregister(); + // from start up, not agent + static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); private: std::string distributed_worker_ip_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc 
b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc index c5e59df..a819b95 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -14,6 +14,10 @@ * limitations under the License. */ #include "worker/distributed_worker/worker_agent.h" +#include +#include "worker/distributed_worker/agent_process/agent_process.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" +#include "common/exit_handle.h" namespace mindspore { namespace serving { @@ -23,15 +27,16 @@ WorkerAgent &WorkerAgent::Instance() { return instance; } -Status WorkerAgent::LoadModelFromFile(const AgentStartUpConfig &config) { - config_ = config; - return executor_.LoadModelFromFile(config); -} - -Status WorkerAgent::Clear() { return executor_.UnloadModel(); } - -Status WorkerAgent::ExecuteModel(const std::vector &request, std::vector *reply) { - return executor_.ExecuteModel(request, reply); +Status WorkerAgent::Clear() { + if (notify_worker_) { + if (exit_notify_worker_) { + notify_worker_->Unregister(); + } + notify_worker_ = nullptr; + } + grpc_server_.Stop(); + executor_.UnloadModel(); + return SUCCESS; } Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { @@ -40,5 +45,59 @@ Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto:: return SUCCESS; } +Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { + Status status; + config_ = config; + status = executor_.LoadModelFromFile(config); + if (status != SUCCESS) { + MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name + << ", rank_id: " << config.rank_id << ", device id: " << config.device_id + << ", model file: " << config.model_file_name + << ", rank table file: " << config.rank_table_json_file_name + << ", group config file: " << config.group_file_name; + return status; + } + 
status = StartGrpcServer(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip + << ", agent port: " << config.agent_port; + return status; + } + status = RegisterAgent(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port + << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port; + return status; + } + MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name + << ", rank_id: " << config.rank_id << ", device id: " << config.device_id + << ", model file: " << config.model_file_name + << ", rank table file: " << config.rank_table_json_file_name + << ", group config file: " << config.group_file_name; + return SUCCESS; +} + +Status WorkerAgent::StartGrpcServer() { + grpc_server_.Start(std::make_shared(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent"); + return SUCCESS; +} + +Status WorkerAgent::RegisterAgent() { + notify_worker_ = std::make_shared(config_.worker_ip, config_.agent_port, config_.agent_ip, + config_.agent_port); + WorkerAgentSpec spec; + spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); + spec.rank_id = config_.rank_id; + spec.batch_size = executor_.GetBatchSize(); + spec.input_infos = executor_.GetInputInfos(); + spec.output_infos = executor_.GetOutputInfos(); + return notify_worker_->Register({spec}); +} + +void WorkerAgent::StopAgent(bool notify_worker) { + exit_notify_worker_ = notify_worker; + ExitSignalHandle::Instance().Stop(); +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h index 520c4db..702e791 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h @@ 
-17,24 +17,36 @@ #ifndef MINDSPORE_SERVING_WORKER_AGENT_H #define MINDSPORE_SERVING_WORKER_AGENT_H #include +#include #include "worker/distributed_worker/agent_executor.h" #include "proto/ms_agent.pb.h" #include "proto/ms_agent.grpc.pb.h" +#include "common/grpc_server.h" +#include "worker/distributed_worker/common.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" namespace mindspore { namespace serving { class MS_API WorkerAgent { public: static WorkerAgent &Instance(); - Status LoadModelFromFile(const AgentStartUpConfig &config); Status Clear(); - Status ExecuteModel(const std::vector &request, std::vector *reply); Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply); + Status StartAgent(const AgentStartUpConfig &config); + + void StopAgent(bool notify_worker = true); + private: AgentStartUpConfig config_; WorkerAgentExecutor executor_; + GrpcServer grpc_server_; + bool exit_notify_worker_ = true; + std::shared_ptr notify_worker_; + + Status StartGrpcServer(); + Status RegisterAgent(); }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 1452727..d02d014 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -39,7 +39,7 @@ class MS_API MSWorkerServer { MSWorkerServer(); virtual ~MSWorkerServer(); - Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); + virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); Status Stop(); protected: @@ -48,21 +48,32 @@ class MS_API MSWorkerServer { std::unique_ptr service_impl_ = nullptr; std::unique_ptr async_server_ = nullptr; - Status StartAsyncRpcService(); Status Init(); + Status StartAsyncRpcService(); }; class WorkerServiceContext { public: enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; + + WorkerServiceContext(MSWorkerImpl 
*service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq) { + state_ = STATE::CREATE; + } virtual ~WorkerServiceContext() {} + bool JudgeFinish() { return state_ == STATE::FINISH; } + virtual void StartEnqueueRequest() = 0; virtual void HandleRequest() = 0; - virtual bool JudgeFinish() = 0; + protected: + MSWorkerImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; - public: STATE state_; }; @@ -70,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext { public: WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerPredictContext() = default; @@ -95,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::PredictRequest request_; proto::PredictReply response_; @@ -111,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext { public: WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerExitContext() = default; @@ -136,13 +137,7 @@ class WorkerExitContext : public 
WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::ExitRequest request_; proto::ExitReply response_; diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc index 9680929..81452c1 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc @@ -31,7 +31,7 @@ static const char *kVersionStrategySpecific = "specific"; namespace mindspore::serving { -LocalModelServable::~LocalModelServable() { session_.UnloadModel(); } +LocalModelServable::~LocalModelServable() { Clear(); } std::string LocalModelServable::GetServableName() const { return servable_name_; } @@ -248,4 +248,11 @@ Status LocalModelServable::LoadModel(uint64_t version_number) { return SUCCESS; } +void LocalModelServable::Clear() { + if (model_loaded_) { + session_.UnloadModel(); + } + model_loaded_ = false; +} + } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h index d5b9a8c..227c9e9 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h @@ -48,6 +48,7 @@ class MS_API LocalModelServable : public ServableBase { Status InitDevice(ModelType model_type, const std::map &other_options); std::string GetServableName() const override; uint64_t GetServableVersion() const override; + void Clear() override; private: LoadServableSpec base_spec_; diff --git a/mindspore_serving/ccsrc/worker/sevable_base.h b/mindspore_serving/ccsrc/worker/sevable_base.h index 
8e9e800..c3acd00 100644 --- a/mindspore_serving/ccsrc/worker/sevable_base.h +++ b/mindspore_serving/ccsrc/worker/sevable_base.h @@ -41,6 +41,7 @@ class ServableBase { virtual uint64_t GetBatchSize() const = 0; virtual std::string GetServableName() const = 0; virtual uint64_t GetServableVersion() const = 0; + virtual void Clear() = 0; }; } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index 87167ef..10e25e1 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -174,9 +174,13 @@ void Worker::Update() { */ } -Status Worker::AfterStartGrpcServer(const std::shared_ptr &grpc_server) { +Status Worker::StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port) { + if (worker_grpc_server_ != nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; + } worker_grpc_server_ = grpc_server; - return SUCCESS; + return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port); } Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { @@ -248,6 +252,9 @@ void Worker::Clear() { if (exit_notify_master_ && servable_started_) { notify_master_->Unregister(); } + for (auto &worker_item : work_list_) { + worker_item.servable->Clear(); + } work_list_.clear(); py_task_queue_group_.Stop(); @@ -257,7 +264,7 @@ void Worker::Clear() { MSI_LOG_INFO << "End clear worker session"; } -bool Worker::HasCleared() { return !servable_started_; } +bool Worker::IsRunning() { return servable_started_; } Worker::~Worker() { Clear(); } @@ -318,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) { const int kWaitMaxHundredMs = 100; int i; for (i = 0; i < kWaitMaxHundredMs; i++) { // - if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) { + if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) 
{ instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped"); return SYSTEM_ERROR; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 50d95eb..ef66043 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -75,10 +75,11 @@ class MS_API Worker { const std::vector &inputs); Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); - Status AfterStartGrpcServer(const std::shared_ptr &grpc_server); + Status StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port); void StopServable(bool notify_master = true); - bool HasCleared(); + bool IsRunning(); Status RegisterWorker(); void Update(); Status StartVersionController(); diff --git a/mindspore_serving/master/_master.py b/mindspore_serving/master/_master.py index 78abb52..0d61459 100644 --- a/mindspore_serving/master/_master.py +++ b/mindspore_serving/master/_master.py @@ -18,6 +18,7 @@ import threading from functools import wraps from mindspore_serving.worker import check_type from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Master_ _wait_and_clear_thread = None @@ -59,6 +60,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index fb6c72a..27fa6c4 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -44,3 +44,10 @@ message AgentExitRequest { message AgentExitReply { ErrorMsg error_msg = 1; } + +message AgentFailedRequest { +} + +message AgentFailedReply { + ErrorMsg error_msg = 1; +} diff --git 
a/mindspore_serving/proto/ms_worker.proto b/mindspore_serving/proto/ms_worker.proto index 436b52f..7b2dbe0 100644 --- a/mindspore_serving/proto/ms_worker.proto +++ b/mindspore_serving/proto/ms_worker.proto @@ -29,4 +29,5 @@ service MSWorker { // for worker agent rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} + rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} } diff --git a/mindspore_serving/worker/_worker.py b/mindspore_serving/worker/_worker.py index 33ba9fb..b11de95 100644 --- a/mindspore_serving/worker/_worker.py +++ b/mindspore_serving/worker/_worker.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Inferface for start up servable""" +"""Interface for start up servable""" import threading from functools import wraps from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ from .register.preprocess import preprocess_storage from .register.postprocess import postprocess_storage @@ -77,6 +78,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py index 41d8218..8bf27d1 100644 --- a/mindspore_serving/worker/distributed/agent_startup.py +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -13,31 +13,238 @@ # limitations under the License. 
# ============================================================================ """Serving, distributed worker agent startup""" -import inspect +import os +import time +from multiprocessing import Process, Pipe + +from mindspore_serving._mindspore_serving import ExitSignalHandle_ +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ + +from mindspore_serving import log as logger from mindspore_serving.worker import check_type +from mindspore_serving.worker.distributed import worker_agent + + +def _get_local_ip(rank_list, port): + """Get the local ip from the rank table config""" + import socket + ip_list = [] + for item in rank_list: + ip_list.append(item.ip) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + for ip in ip_list: + try: + s.bind((ip, port)) + logger.info(f"Get local machine ip success, ip {ip}") + return ip + # pylint: disable=bare-except + except: + pass + raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") + + +def _update_model_files_path(model_files, group_config_files): + """Check and return model files or group config files""" + script_dir = os.path.dirname(os.path.realpath(__file__)) + logger.info(f"input model files: {model_files}") + logger.info(f"input group config files: {group_config_files}") + model_files_temp = [] + for item in model_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access model file '{file_name}'") + model_files_temp.append(file_name) + + group_files_temp = [] + for item in group_config_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access group config file '{file_name}'") + group_files_temp.append(file_name) + + logger.info(f"absolute model files: {model_files_temp}") + logger.info(f"absolute group config files: {group_files_temp}") + return model_files_temp, group_files_temp + + +def 
_make_json_table_file(distributed_config): + """Make rank table json file""" + rank_size = len(distributed_config.rank_list) + runtime_dir = os.path.abspath(".") + time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) + rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") + with open(rank_table_file_name, "w") as fp: + fp.write(distributed_config.rank_table_content) + return rank_table_file_name + + +signal_success = "Success" +signal_exit = "Exit" +signal_heartbeat = "HeartBeat" + + +def _recv_parent(index, recv_pipe): + """Receive message from Start up process. + Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit. + Return True on receiving signal_success.""" + try: + while True: + heartbeat_count = 0 + while not recv_pipe.poll(0.1): + if ExitSignalHandle_.has_stopped(): + logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker") + return False + heartbeat_count += 1 + if heartbeat_count >= 30: # 3s + logger.warning(f"Child {index}: Exit on failure of receiving parent message") + return False + parent_signal = recv_pipe.recv() + if parent_signal != signal_heartbeat: + break + if parent_signal == signal_success: + logger.info(f"Child {index}: Receive success") + return True + if parent_signal == signal_exit: + logger.warning(f"Child {index}: Exit on receiving exit message") + else: + logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}") + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Child {index}: Exit on exception: {e}") + return False + + +def _agent_process(send_pipe, recv_pipe, index, start_config): + """Agent process""" + try: + # listening success or failed message from parent process + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message + worker_agent.start_worker_agent(start_config=start_config) + send_pipe.send((index, signal_success)) + 
success_msg = _recv_parent(index, recv_pipe) + if not success_msg: + worker_agent.stop() + send_pipe.close() + recv_pipe.close() + # pylint: disable=broad-except + except Exception as e: + logger.error(f"Child {index}: Catch exception and notify exit of others") + send_pipe.send((index, e)) + worker_agent.stop() + raise -def startup_worker_agents(worker_ip, worker_port, - get_model_files_fun, get_group_configs_fun, - rank_start, agent_start_port=7000): - """Start up all needed worker agents on one machine - """ +def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list): + """Listening child process""" + def send_pipe_msg(send_pipe, msg): + try: + send_pipe.send(msg) + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Send pipe message exception happen: {e}") + + count = len(send_pipe_list) + for _ in range(count): + while True: + if p_recv_pipe.poll(0.1): + break + for send_pipe, process in zip(send_pipe_list, subprocess_list): + if process.is_alive(): + continue + logger.warning("Fail to start agents because of death of one agent") + for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list): + if process_x.is_alive(): + send_pipe_msg(send_pipe_x, signal_exit) + return False + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_heartbeat) + + _, msg = p_recv_pipe.recv() + if isinstance(msg, Exception): + logger.warning("Fail to start agents because of exception raise by one agent") + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_exit) + return False + + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_success) + logger.info("Success to start agents") + return True + + +def _startup_all_agents(common_meta, worker_ip, worker_port, + agent_ip, agent_start_port, device_id_list, rank_id_list, + model_files, group_config_files, rank_table_file): + """Start up all agents in one machine""" + servable_name = common_meta.servable_name + index = 0 + send_pipe_list 
= [] + subprocess_list = [] + c_send_pipe, p_recv_pipe = Pipe() + for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files, + group_config_files): + p_send_pipe, c_recv_pipe = Pipe() + send_pipe_list.append(p_send_pipe) + + agent_port = agent_start_port + index + + start_config = AgentStartUpConfig_() + start_config.rank_id = rank_id + start_config.device_id = device_id + start_config.model_file_name = model_file + start_config.group_file_name = group_file + start_config.rank_table_json_file_name = rank_table_file + start_config.agent_ip = agent_ip + start_config.agent_port = agent_port + start_config.worker_ip = worker_ip + start_config.worker_port = worker_port + start_config.common_meta = common_meta + + process = Process(target=_agent_process, + args=(c_send_pipe, c_recv_pipe, index, start_config), + name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}") + process.start() + subprocess_list.append(process) + index += 1 + ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list) + if not ret: + WorkerAgent_.notify_failed(worker_ip, worker_port) + + +def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000): + """Start up all needed worker agents on one machine""" check_type.check_str("worker_ip", worker_ip) check_type.check_ip_port("worker_port", worker_port) check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) - if inspect.isfunction(get_model_files_fun): - pass - else: - if not isinstance(get_model_files_fun, [list, tuple]): - raise RuntimeError(f"Check failed, get_model_files_fun first must be function or tuple/list of str, " - f"now is {type(get_model_files_fun)}") - if inspect.isfunction(get_group_configs_fun): - pass - else: - if not isinstance(get_group_configs_fun, [list, tuple]): - raise RuntimeError(f"Check failed, get_group_configs_fun first must be function or tuple/list of str, " - f"now is 
{type(get_group_configs_fun)}") - check_type.check_int("rank_start", rank_start, 0) - if rank_start % 8 != 0: - raise RuntimeError(f"Parameter 'rank_start' must be mulfiply of 8, now is {rank_start}") + model_files = check_type.check_and_as_str_tuple_list("model_files", model_files) + group_config_files = check_type.check_and_as_str_tuple_list("group_config_files", group_config_files) + distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port) + + # get machine ip + rank_list = distributed_config.rank_list + local_ip = _get_local_ip(rank_list, agent_start_port) + # get all device_id and rank_id + local_device_id_list = [] + local_rank_id_list = [] + for rank_id, item in enumerate(rank_list): + if item.ip == local_ip: + local_device_id_list.append(item.device_id) + local_rank_id_list.append(rank_id) + + # handle model files and group config files + if len(local_device_id_list) != len(model_files): + raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to model files size " + f"{len(model_files)}, model files: {model_files}") + + if len(local_device_id_list) != len(group_config_files): + raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to group config " + f"files size {len(group_config_files)}, group config files: {group_config_files}") + + model_files, group_config_files = _update_model_files_path(model_files, group_config_files) + + # make json table file and export env + rank_table_file = _make_json_table_file(distributed_config) + _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port, + local_device_id_list, local_rank_id_list, + model_files, group_config_files, rank_table_file) diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py index 5bee6b8..4235ee6 100644 --- a/mindspore_serving/worker/distributed/distributed_worker.py +++ 
b/mindspore_serving/worker/distributed/distributed_worker.py @@ -13,15 +13,17 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker startup""" -from mindspore_serving.worker._worker import stop_on_except, _load_servable_config -from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear -from mindspore_serving.worker import check_type from mindspore_serving._mindspore_serving import Worker_ +from mindspore_serving.worker import check_type +from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear +from mindspore_serving.worker._worker import stop_on_except, _load_servable_config + @stop_on_except def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, - worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100): + worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100, + wait_agents_time_in_seconds=300): r""" Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master through gRPC (master_ip, master_port). @@ -46,6 +48,7 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso master_port (int): The master port the worker linked to. worker_ip (str): The worker ip the master and agents linked to. worker_port (int): The worker port the master and agents linked to. + wait_agents_time_in_seconds(int): The maximum time in seconds that the worker waits for all agents to be ready. 
Examples: >>> import os @@ -70,15 +73,15 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) - _start_wait_and_clear() Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, - master_ip, master_port, worker_ip, worker_port) + master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds) _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() @stop_on_except def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, - worker_ip="0.0.0.0", worker_port=6200): + worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300): r""" Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in the process of the master. @@ -97,6 +100,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank rank_table_json_file (str): The ranke table json file name. worker_ip (str): The worker ip the agents linked to. worker_port (int): The worker port the agents linked to. + wait_agents_time_in_seconds(int): The maximum time in seconds that the worker waits for all agents to be ready. 
Examples: >>> import os @@ -121,7 +125,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) - _start_wait_and_clear() Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, - version_number, worker_ip, worker_port) + version_number, worker_ip, worker_port, wait_agents_time_in_seconds) _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() diff --git a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py index d1ebb99..ad32a53 100644 --- a/mindspore_serving/worker/distributed/worker_agent.py +++ b/mindspore_serving/worker/distributed/worker_agent.py @@ -13,22 +13,54 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker agent""" -from mindspore_serving.worker import check_type +import os +import threading +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ +from mindspore_serving import log as logger -def _start_worker_agent(agent_ip, agent_port, worker_ip, worker_port, - rank_id, device_id, model_file, group_config_file, rank_table_file, - with_bach_dim, without_batch_dim_inputs): + +def start_worker_agent(start_config): """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents """ - check_type.check_str("agent_ip", agent_ip) - check_type.check_ip_port("agent_port", agent_port) - check_type.check_str("worker_ip", worker_ip) - check_type.check_ip_port("worker_port", worker_port) - check_type.check_int("rank_id", rank_id, 0) - check_type.check_int("device_id", device_id, 0) - check_type.check_str("model_file", model_file) - check_type.check_str("group_config_file", group_config_file) - check_type.check_str("rank_table_file", rank_table_file) - 
check_type.check_bool("with_bach_dim", with_bach_dim) - check_type.check_and_as_int_tuple_list("without_batch_dim_inputs", without_batch_dim_inputs, 0) + if not isinstance(start_config, AgentStartUpConfig_): + raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") + + os.environ["RANK_ID"] = str(start_config.rank_id) + os.environ["DEVICE_ID"] = str(start_config.device_id) + os.environ["MS_ENABLE_HCCL"] = "1" + os.environ["PARA_GROUP_FILE"] = start_config.group_file_name + os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name + + for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE", + "LD_LIBRARY_PATH", "PYTHONPATH"): + logger.info(f"Env {item}: {os.getenv(item, '')}") + WorkerAgent_.start_agent(start_config) + + start_wait_and_clear() + + +_wait_and_clear_thread = None + + +def start_wait_and_clear(): + """Waiting for Ctrl+C, and clear up environment""" + + def thread_func(): + logger.info("Serving agent: wait for Ctrl+C to exit ------------------------------------") + print("Serving agent: wait for Ctrl+C to exit ------------------------------------") + WorkerAgent_.wait_and_clear() + logger.info("Serving agent: exited ------------------------------------") + print("Serving agent: exited ------------------------------------") + + global _wait_and_clear_thread + if not _wait_and_clear_thread: + _wait_and_clear_thread = threading.Thread(target=thread_func) + _wait_and_clear_thread.start() + + +def stop(): + r""" + Stop the running of agent. + """ + WorkerAgent_.stop_and_clear()