| @@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma | |||
| if (grpc_async_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; | |||
| } | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| if (max_msg_mb_size > gRpcMaxMBMsgSize) { | |||
| MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size | |||
| << "MB to 512MB"; | |||
| @@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma | |||
| } | |||
| Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| return grpc_manager_server_.Start(std::make_shared<MSMasterImpl>(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, | |||
| "Master"); | |||
| } | |||
| Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, | |||
| int time_out_second) { | |||
| ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit | |||
| return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second); | |||
| } | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "python/agent/agent_py.h" | |||
| #include "common/exit_handle.h" | |||
| #include "worker/distributed_worker/agent_startup.h" | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| namespace mindspore::serving { | |||
| DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { | |||
| auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| DistributedServableConfig config; | |||
| status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| return config; | |||
| } | |||
| void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { | |||
| WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port); | |||
| } | |||
| void PyAgent::StartAgent(const AgentStartUpConfig &start_config) { | |||
| auto status = WorkerAgent::Instance().StartAgent(start_config); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| } | |||
| void PyAgent::WaitAndClear() { | |||
| { | |||
| py::gil_scoped_release release; | |||
| ExitSignalHandle::Instance().AgentWait(); | |||
| } | |||
| WorkerAgent::Instance().Clear(); | |||
| MSI_LOG_INFO << "Python agent end wait and clear"; | |||
| } | |||
| void PyAgent::StopAndClear() { | |||
| ExitSignalHandle::Instance().Stop(); | |||
| WorkerAgent::Instance().Clear(); | |||
| } | |||
| } // namespace mindspore::serving | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,25 +14,34 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H | |||
| #define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H | |||
| #include <vector> | |||
| #ifndef MINDSPORE_SERVER_AGENT_PY_H | |||
| #define MINDSPORE_SERVER_AGENT_PY_H | |||
| #include <pybind11/pybind11.h> | |||
| #include <pybind11/numpy.h> | |||
| #include <pybind11/stl.h> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "common/serving_common.h" | |||
| #include "common/servable.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API BaseNotifyDistributeWorker { | |||
| class MS_API PyAgent { | |||
| public: | |||
| BaseNotifyDistributeWorker() = default; | |||
| virtual ~BaseNotifyDistributeWorker() = default; | |||
| virtual Status Register(const std::vector<WorkerAgentSpec> &worker_specs) = 0; | |||
| virtual Status Unregister() = 0; | |||
| static void StartAgent(const AgentStartUpConfig &start_config); | |||
| static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); | |||
| static void WaitAndClear(); | |||
| static void StopAndClear(); | |||
| // from start up, not agent | |||
| static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port); | |||
| }; | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H | |||
| #endif // MINDSPORE_SERVER_AGENT_PY_H | |||
| @@ -23,6 +23,9 @@ | |||
| #include "common/servable.h" | |||
| #include "worker/context.h" | |||
| #include "python/master/master_py.h" | |||
| #include "python/agent/agent_py.h" | |||
| #include "common/exit_handle.h" | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| namespace mindspore::serving { | |||
| @@ -104,11 +107,23 @@ void PyRegServable(pybind11::module *m_ptr) { | |||
| .def_static("register_method", &PyServableStorage::RegisterMethod) | |||
| .def_static("declare_servable", &PyServableStorage::DeclareServable) | |||
| .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); | |||
| py::class_<OneRankConfig>(m, "OneRankConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("device_id", &OneRankConfig::device_id) | |||
| .def_readwrite("ip", &OneRankConfig::ip); | |||
| py::class_<DistributedServableConfig>(m, "DistributedServableConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("common_meta", &DistributedServableConfig::common_meta) | |||
| .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) | |||
| .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) | |||
| .def_readwrite("rank_list", &DistributedServableConfig::rank_list); | |||
| } | |||
| void PyRegMaster(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| py::class_<PyMaster, std::shared_ptr<PyMaster>>(m, "Master_") | |||
| py::class_<PyMaster>(m, "Master_") | |||
| .def_static("start_grpc_server", &PyMaster::StartGrpcServer) | |||
| .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) | |||
| .def_static("start_restful_server", &PyMaster::StartRestfulServer) | |||
| @@ -163,15 +178,50 @@ void PyRegWorker(pybind11::module *m_ptr) { | |||
| .def("set_device_id", &ServableContext::SetDeviceId); | |||
| } | |||
| void PyRegWorkerAgent(pybind11::module *m_ptr) { | |||
| auto &m = *m_ptr; | |||
| py::class_<PyAgent>(m, "WorkerAgent_") | |||
| .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) | |||
| .def_static("wait_and_clear", &PyAgent::WaitAndClear) | |||
| .def_static("stop_and_clear", &PyAgent::StopAndClear) | |||
| .def_static("notify_failed", &PyAgent::NotifyFailed) | |||
| .def_static("start_agent", &PyAgent::StartAgent); | |||
| py::class_<AgentStartUpConfig>(m, "AgentStartUpConfig_") | |||
| .def(py::init<>()) | |||
| .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) | |||
| .def_readwrite("device_id", &AgentStartUpConfig::device_id) | |||
| .def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name) | |||
| .def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name) | |||
| .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) | |||
| .def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip) | |||
| .def_readwrite("agent_port", &AgentStartUpConfig::agent_port) | |||
| .def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip) | |||
| .def_readwrite("worker_port", &AgentStartUpConfig::worker_port) | |||
| .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); | |||
| } | |||
| class PyExitSignalHandle { | |||
| public: | |||
| static void Start() { ExitSignalHandle::Instance().Start(); } | |||
| static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } | |||
| }; | |||
| // cppcheck-suppress syntaxError | |||
| PYBIND11_MODULE(_mindspore_serving, m) { | |||
| PyRegServable(&m); | |||
| PyRegMaster(&m); | |||
| PyRegWorker(&m); | |||
| PyRegWorkerAgent(&m); | |||
| py::class_<PyExitSignalHandle>(m, "ExitSignalHandle_") | |||
| .def_static("start", &PyExitSignalHandle::Start) | |||
| .def_static("has_stopped", &PyExitSignalHandle::HasStopped); | |||
| (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { | |||
| Server::Instance().Clear(); | |||
| Worker::GetInstance().Clear(); | |||
| WorkerAgent::Instance().Clear(); | |||
| }}); | |||
| } | |||
| @@ -24,7 +24,7 @@ | |||
| #include "worker/local_servable/local_sevable.h" | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_server.h" | |||
| namespace mindspore::serving { | |||
| @@ -43,11 +43,10 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri | |||
| } | |||
| // start grpc server | |||
| auto grpc_sever = std::make_shared<MSWorkerServer>(); | |||
| status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| status = Worker::GetInstance().StartVersionController(); | |||
| if (status != SUCCESS) { | |||
| @@ -76,18 +75,19 @@ void PyWorker::StartServableInMaster(const std::string &model_directory, const s | |||
| void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| const std::string &master_ip, uint32_t master_port) { | |||
| const std::string &master_ip, uint32_t master_port, | |||
| uint32_t wait_agents_time_in_seconds) { | |||
| Status status; | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| auto notify_master = std::make_shared<GrpcNotfiyMaster>(master_ip, master_port, worker_ip, worker_port); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, | |||
| wait_agents_time_in_seconds); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| @@ -103,18 +103,19 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c | |||
| void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port) { | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| uint32_t wait_agents_time_in_seconds) { | |||
| Status status; | |||
| auto servable = std::make_shared<DistributedServable>(); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(); | |||
| status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); | |||
| auto grpc_sever = std::make_shared<MSDistributedWorkerServer>(servable); | |||
| status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| Worker::GetInstance().AfterStartGrpcServer(grpc_sever); | |||
| auto notify_master = std::make_shared<LocalNotifyMaster>(); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); | |||
| status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, | |||
| wait_agents_time_in_seconds); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); | |||
| } | |||
| @@ -37,11 +37,12 @@ class MS_API PyWorker { | |||
| static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, | |||
| uint32_t master_port); | |||
| uint32_t master_port, uint32_t wait_agents_time_in_seconds); | |||
| static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint32_t version_number, | |||
| const std::string &worker_ip, uint32_t worker_port); | |||
| const std::string &worker_ip, uint32_t worker_port, | |||
| uint32_t wait_agents_time_in_seconds); | |||
| static int GetBatchSize(); | |||
| static void WaitAndClear(); | |||
| @@ -22,7 +22,7 @@ namespace serving { | |||
| grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, | |||
| proto::DistributedExitReply *reply) { | |||
| MSI_LOG(INFO) << "Distributed Worker Exit"; | |||
| WorkerAgent::Instance().Clear(); | |||
| WorkerAgent::Instance().StopAgent(false); | |||
| return grpc::Status::OK; | |||
| } | |||
| @@ -14,17 +14,32 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/agent_startup.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| Status WorkerAgentStartUp::InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, | |||
| const std::string &group_file_dir, const std::string &group_file_prefix) { | |||
| return Status(); | |||
| WorkerAgentStartUp &WorkerAgentStartUp::Instance() { | |||
| static WorkerAgentStartUp instance; | |||
| return instance; | |||
| } | |||
| Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, | |||
| const std::string &worker_ip, uint32_t worker_port) { | |||
| Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { | |||
| return Status(); | |||
| } | |||
| Status WorkerAgentStartUp::GetCurrentMachineConfigs(std::vector<AgentStartUpConfig> *configs) { return Status(); } | |||
| Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { | |||
| MSI_EXCEPTION_IF_NULL(config); | |||
| if (config_.rank_list.empty()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; | |||
| } | |||
| *config = config_; | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { | |||
| return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port); | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -27,16 +27,15 @@ namespace serving { | |||
| class MS_API WorkerAgentStartUp { | |||
| public: | |||
| static WorkerAgentStartUp &Instance(); | |||
| // from python, worker_agent.py | |||
| // start_worker_agent | |||
| // step1, get agents config from worker | |||
| Status InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, | |||
| const std::string &group_file_dir, const std::string &group_file_prefix); | |||
| Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); | |||
| // step2, invoke from python | |||
| Status GetDistributedServableConfig(DistributedServableConfig *config); | |||
| Status GetAgentsConfigsFromWorker(const std::string &rank_start, uint32_t agent_start_port, | |||
| const std::string &worker_ip, uint32_t worker_port); | |||
| // step2, invoke from python, get current machine agents config | |||
| Status GetCurrentMachineConfigs(std::vector<AgentStartUpConfig> *configs); | |||
| Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); | |||
| private: | |||
| DistributedServableConfig config_; | |||
| @@ -14,7 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||
| #include "worker/worker.h" | |||
| #include "common/proto_tensor.h" | |||
| namespace mindspore { | |||
| @@ -51,6 +52,20 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr | |||
| MSI_LOG(ERROR) << "Agent Exit FAILED"; | |||
| } | |||
| } | |||
| if (Worker::GetInstance().IsRunning()) { | |||
| Worker::GetInstance().StopServable(); | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, | |||
| proto::AgentFailedReply *reply) { | |||
| if (Worker::GetInstance().IsRunning()) { | |||
| MSI_LOG_ERROR << "Expect worker should not be running"; | |||
| Worker::GetInstance().StopServable(); | |||
| } else { | |||
| servable_->OnAgentFailed(); | |||
| } | |||
| return grpc::Status::OK; | |||
| } | |||
| } // namespace serving | |||
| @@ -41,6 +41,8 @@ class MSDistributedImpl final : public MSWorkerImpl { | |||
| proto::AgentRegisterReply *reply) override; | |||
| grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, | |||
| proto::AgentExitReply *reply) override; | |||
| grpc::Status AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, | |||
| proto::AgentFailedReply *reply) override; | |||
| private: | |||
| std::shared_ptr<DistributedServable> servable_; | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/grpc/distributed_server.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_server.h" | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| @@ -23,12 +23,11 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, | |||
| const std::string &hostname, int32_t port) { | |||
| Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { | |||
| if (in_running_) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; | |||
| } | |||
| auto impl = std::make_unique<MSDistributedImpl>(servable); | |||
| auto impl = std::make_unique<MSDistributedImpl>(servable_); | |||
| async_server_ = std::make_unique<DistributedWorkerGrpcServer>(hostname, port, impl.get()); | |||
| service_impl_ = std::move(impl); | |||
| return Init(); | |||
| @@ -28,7 +28,7 @@ | |||
| #include "common/grpc_async_server.h" | |||
| #include "worker/grpc/worker_process.h" | |||
| #include "worker/grpc/worker_server.h" | |||
| #include "worker/distributed_worker/grpc/distributed_process.h" | |||
| #include "worker/distributed_worker/distributed_process/distributed_process.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -36,18 +36,30 @@ namespace serving { | |||
| // Service Implement | |||
| class MS_API MSDistributedWorkerServer : public MSWorkerServer { | |||
| public: | |||
| Status StartDistributedWorkerGrpcServer(std::shared_ptr<DistributedServable> servable, const std::string &hostname, | |||
| int32_t port); | |||
| explicit MSDistributedWorkerServer(std::shared_ptr<DistributedServable> servable) : servable_(servable) {} | |||
| ~MSDistributedWorkerServer() = default; | |||
| Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override; | |||
| private: | |||
| std::shared_ptr<DistributedServable> servable_; | |||
| }; | |||
| class DistributedServiceContext : public WorkerServiceContext { | |||
| public: | |||
| DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {} | |||
| protected: | |||
| MSDistributedImpl *dist_service_impl_ = nullptr; | |||
| }; | |||
| // Service Implement | |||
| class WorkerAgentRegisterContext : public WorkerServiceContext { | |||
| class WorkerAgentRegisterContext : public DistributedServiceContext { | |||
| public: | |||
| WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerAgentRegisterContext() = default; | |||
| @@ -60,35 +72,27 @@ class WorkerAgentRegisterContext : public WorkerServiceContext { | |||
| void StartEnqueueRequest() override { | |||
| state_ = STATE::PROCESS; | |||
| async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| EnqueueRequest(dist_service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); | |||
| grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSDistributedImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_; | |||
| proto::PredictRequest request_; | |||
| proto::PredictReply response_; | |||
| grpc::ServerAsyncResponseWriter<proto::AgentRegisterReply> responder_; | |||
| proto::AgentRegisterRequest request_; | |||
| proto::AgentRegisterReply response_; | |||
| }; | |||
| class WorkerAgentExitContext : public WorkerServiceContext { | |||
| class WorkerAgentExitContext : public DistributedServiceContext { | |||
| public: | |||
| WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerAgentExitContext() = default; | |||
| @@ -101,26 +105,52 @@ class WorkerAgentExitContext : public WorkerServiceContext { | |||
| void StartEnqueueRequest() override { | |||
| state_ = STATE::PROCESS; | |||
| async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } | |||
| void HandleRequest() override { | |||
| EnqueueRequest(service_impl_, async_service_, cq_); | |||
| EnqueueRequest(dist_service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); | |||
| grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| grpc::ServerAsyncResponseWriter<proto::AgentExitReply> responder_; | |||
| proto::AgentExitRequest request_; | |||
| proto::AgentExitReply response_; | |||
| }; | |||
| class WorkerAgentFailedContext : public DistributedServiceContext { | |||
| public: | |||
| WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerAgentFailedContext() = default; | |||
| static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) { | |||
| auto call = new WorkerAgentFailedContext(service_impl, async_service, cq); | |||
| call->StartEnqueueRequest(); | |||
| return SUCCESS; | |||
| } | |||
| void StartEnqueueRequest() override { | |||
| state_ = STATE::PROCESS; | |||
| async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this); | |||
| } | |||
| void HandleRequest() override { | |||
| EnqueueRequest(dist_service_impl_, async_service_, cq_); | |||
| state_ = STATE::FINISH; | |||
| grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_); | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| private: | |||
| MSDistributedImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_; | |||
| proto::ExitRequest request_; | |||
| proto::ExitReply response_; | |||
| grpc::ServerAsyncResponseWriter<proto::AgentFailedReply> responder_; | |||
| proto::AgentFailedRequest request_; | |||
| proto::AgentFailedReply response_; | |||
| }; | |||
| class DistributedWorkerGrpcServer : public WorkerGrpcServer { | |||
| @@ -134,6 +164,7 @@ class DistributedWorkerGrpcServer : public WorkerGrpcServer { | |||
| WorkerGrpcServer::EnqueueRequest(); | |||
| WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||
| WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||
| WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); | |||
| return SUCCESS; | |||
| } | |||
| @@ -17,13 +17,15 @@ | |||
| #include "worker/distributed_worker/distributed_servable.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include "worker/worker.h" | |||
| #include <set> | |||
| #include "worker/distributed_worker/notify_agent/notify_agent.h" | |||
| #include "common/exit_handle.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| DistributedServable::~DistributedServable() { Clear(); } | |||
| std::string DistributedServable::GetServableName() const { return servable_name_; } | |||
| uint64_t DistributedServable::GetServableVersion() const { return version_number_; } | |||
| @@ -60,7 +62,15 @@ Status DistributedServable::GetDistributedServableConfig(DistributedServableConf | |||
| return SUCCESS; | |||
| } | |||
| void DistributedServable::SetWaitAgentsPromise(bool flag) { | |||
| if (!promise_set_flag_.test_and_set()) { | |||
| agents_promise_.set_value(flag); | |||
| } | |||
| } | |||
| Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| if (agent_spec.rank_id < config_.distributed_meta.rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; | |||
| @@ -75,27 +85,24 @@ Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::shared_ptr<BaseNotifyAgent> notify_agent = std::make_shared<GrpcNotfiyAgent>(agent_spec.agent_address); | |||
| context.notify_agent_ = notify_agent; | |||
| agent_spec_map_[agent_spec.rank_id] = context; | |||
| if (config_.distributed_meta.rank_size == agent_spec_map_.size()) { | |||
| Status status = Worker::GetInstance().RegisterWorker(); | |||
| if (status != SUCCESS) { | |||
| Clear(); | |||
| return FAILED; | |||
| } | |||
| } | |||
| if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { | |||
| agents_promise_.set_value(); | |||
| SetWaitAgentsPromise(true); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void DistributedServable::Clear() { | |||
| for (auto agent : agent_spec_map_) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| for (auto &agent : agent_spec_map_) { | |||
| agent.second.notify_agent_->Exit(); | |||
| } | |||
| Worker::GetInstance().StopServable(false); | |||
| agent_spec_map_.clear(); | |||
| MSI_LOG_INFO << "End Clear servable"; | |||
| } | |||
| Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| std::unique_lock<std::mutex> lock{mutex_}; | |||
| for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { | |||
| if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { | |||
| iter = agent_spec_map_.erase(iter); | |||
| @@ -103,13 +110,13 @@ Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { | |||
| ++iter; | |||
| } | |||
| } | |||
| // todo: send exit message to agent, and then exit if split with master | |||
| Clear(); | |||
| SetWaitAgentsPromise(false); | |||
| return SUCCESS; | |||
| } | |||
| Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint64_t version_number) { | |||
| const std::string &rank_table_json_file, uint64_t version_number, | |||
| uint64_t wait_agents_time_in_seconds) { | |||
| if (model_loaded_) { | |||
| MSI_LOG_EXCEPTION << "Model has loaded"; | |||
| } | |||
| @@ -138,7 +145,7 @@ Status DistributedServable::StartServable(const std::string &servable_directory, | |||
| MSI_LOG_ERROR << "Check rank config failed"; | |||
| return status; | |||
| } | |||
| status = WaitAgentsReady(); | |||
| status = WaitAgentsReady(wait_agents_time_in_seconds); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "Waiting for ready of agents failed"; | |||
| return status; | |||
| @@ -154,16 +161,23 @@ Status DistributedServable::StartServable(const std::string &servable_directory, | |||
| Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } | |||
| Status DistributedServable::WaitAgentsReady() { | |||
| Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { | |||
| auto future = agents_promise_.get_future(); | |||
| const int kWaitMaxHundredMs = 100 * 10; // 100s | |||
| int i; | |||
| if (wait_agents_time_in_seconds == 0) { | |||
| wait_agents_time_in_seconds = UINT32_MAX; | |||
| } | |||
| const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10; | |||
| uint64_t i; | |||
| for (i = 0; i < kWaitMaxHundredMs; i++) { // | |||
| if (ExitSignalHandle::Instance().HasStopped()) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; | |||
| } | |||
| // waiting for 100ms | |||
| if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { | |||
| auto flag = future.get(); | |||
| if (!flag) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported"; | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| @@ -264,32 +278,49 @@ Status DistributedServable::CheckRankConfig() { | |||
| << "Rank size must be an integral multiple of stage size, rank size: " << rank_size | |||
| << ", stage size: " << stage_size; | |||
| } | |||
| auto parallel_count = rank_size / stage_size; | |||
| constexpr size_t card_count_per_machine = 8; | |||
| if (rank_size > card_count_per_machine && parallel_count % card_count_per_machine != 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Parallel count " << parallel_count << " in one stage must be an integral multiple of card count " | |||
| << card_count_per_machine << " in one machine, when rank size is greater than card count in one machine, " | |||
| << "rank size: " << rank_size << ", stage size: " << stage_size; | |||
| } | |||
| if (config_.rank_list.size() != rank_size) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " | |||
| << rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_; | |||
| } | |||
| for (size_t i = 0; i < rank_size; i++) { | |||
| const auto &first_item = config_.rank_list[i]; | |||
| for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { | |||
| auto rank_id = i + k; | |||
| const auto &item = config_.rank_list[rank_id]; | |||
| if (k != item.device_id) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; | |||
| auto parallel_count = rank_size / stage_size; | |||
| constexpr size_t card_count_per_machine = 8; | |||
| if (stage_size == 1) { | |||
| std::map<std::string, std::set<uint32_t>> device_map; | |||
| for (size_t i = 0; i < rank_size; i++) { | |||
| const auto &item = config_.rank_list[i]; | |||
| auto &device_id_list = device_map[item.ip]; | |||
| if (device_id_list.count(item.device_id) > 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " | |||
| << i << " in device ip " << item.ip; | |||
| } | |||
| if (first_item.ip != item.ip) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id | |||
| << " to be equal with device ip " << first_item.ip << " of rank " << i; | |||
| device_id_list.emplace(item.device_id); | |||
| } | |||
| } else { | |||
| if (rank_size < card_count_per_machine) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Rank size " << rank_size << "must >= card count " << card_count_per_machine | |||
| << " of one machine when stage size " << stage_size << " > 1"; | |||
| } | |||
| if (parallel_count % card_count_per_machine != 0) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine | |||
| << "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size; | |||
| } | |||
| for (size_t i = 0; i < rank_size; i += card_count_per_machine) { | |||
| const auto &first_item = config_.rank_list[i]; | |||
| for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { | |||
| auto rank_id = i + k; | |||
| const auto &item = config_.rank_list[rank_id]; | |||
| if (k != item.device_id) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; | |||
| } | |||
| if (first_item.ip != item.ip) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) | |||
| << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id | |||
| << " to be equal with device ip " << first_item.ip << " of rank " << i; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -298,5 +329,7 @@ Status DistributedServable::CheckRankConfig() { | |||
| return SUCCESS; | |||
| } | |||
| void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -35,9 +35,12 @@ struct DistributedAgentContext { | |||
| class MS_API DistributedServable : public ServableBase { | |||
| public: | |||
| DistributedServable() = default; | |||
| ~DistributedServable(); | |||
| // from python, worker.py | |||
| Status StartServable(const std::string &servable_directory, const std::string &servable_name, | |||
| const std::string &rank_table_json_file, uint64_t version_number); | |||
| const std::string &rank_table_json_file, uint64_t version_number, | |||
| uint64_t wait_agents_time_in_seconds); | |||
| // invoke from agent | |||
| Status GetDistributedServableConfig(DistributedServableConfig *config) const; | |||
| @@ -55,7 +58,8 @@ class MS_API DistributedServable : public ServableBase { | |||
| uint64_t GetBatchSize() const override; | |||
| std::string GetServableName() const override; | |||
| uint64_t GetServableVersion() const override; | |||
| void Clear(); | |||
| void Clear() override; | |||
| void OnAgentFailed(); | |||
| private: | |||
| DistributedServableConfig config_; | |||
| @@ -63,19 +67,22 @@ class MS_API DistributedServable : public ServableBase { | |||
| uint64_t version_number_ = 0; | |||
| bool model_loaded_ = false; | |||
| std::mutex mutex_; | |||
| std::map<uint32_t, DistributedAgentContext> agent_spec_map_; | |||
| std::string rank_table_json_file_; | |||
| std::vector<TensorInfo> input_infos_; | |||
| std::vector<TensorInfo> output_infos_; | |||
| uint64_t batch_size_ = 0; | |||
| std::promise<void> agents_promise_; | |||
| std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT; | |||
| std::promise<bool> agents_promise_; | |||
| Status InitConfigOnStartup(const std::string &rank_table_json_file); | |||
| Status WaitAgentsReady(); | |||
| Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds); | |||
| Status CheckAgentsInfosAndInitTensorInfos(); | |||
| Status CompareTensorInfos(const std::vector<TensorInfo> &lefts, const std::vector<TensorInfo> &rights); | |||
| Status CheckRankConfig(); | |||
| void SetWaitAgentsPromise(bool flag); | |||
| // agent stubs | |||
| }; | |||
| @@ -62,7 +62,7 @@ Status GrpcNotifyDistributeWorker::Register(const std::vector<WorkerAgentSpec> & | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); | |||
| } | |||
| if (ExitSignalHandle::Instance().HasStopped()) { | |||
| return INFER_STATUS_LOG_WARNING(FAILED) << "Worker exit, stop registration"; | |||
| return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration"; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; | |||
| } | |||
| @@ -87,5 +87,21 @@ Status GrpcNotifyDistributeWorker::Unregister() { | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; | |||
| } | |||
| Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { | |||
| auto address = worker_ip + ":" + std::to_string(worker_port); | |||
| auto channel = GrpcServer::CreateChannel(address); | |||
| auto stub = proto::MSWorker::NewStub(channel); | |||
| grpc::ClientContext context; | |||
| proto::AgentFailedRequest request; | |||
| proto::AgentFailedReply reply; | |||
| grpc::Status status = stub->AgentFailed(&context, request, &reply); | |||
| if (status.ok()) { | |||
| MSI_LOG(INFO) << "Success to notify failure of agent"; | |||
| return SUCCESS; | |||
| } | |||
| return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent"; | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -19,7 +19,8 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" | |||
| #include "common/serving_common.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "proto/ms_distributed.pb.h" | |||
| #include "proto/ms_distributed.grpc.pb.h" | |||
| #include "proto/ms_worker.pb.h" | |||
| @@ -27,13 +28,15 @@ | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { | |||
| class MS_API GrpcNotifyDistributeWorker { | |||
| public: | |||
| GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, | |||
| uint32_t host_port); | |||
| ~GrpcNotifyDistributeWorker() override; | |||
| Status Register(const std::vector<WorkerAgentSpec> &worker_specs) override; | |||
| Status Unregister() override; | |||
| GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip, | |||
| uint32_t agent_port); | |||
| ~GrpcNotifyDistributeWorker(); | |||
| Status Register(const std::vector<WorkerAgentSpec> &agent_specs); | |||
| Status Unregister(); | |||
| // from start up, not agent | |||
| static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); | |||
| private: | |||
| std::string distributed_worker_ip_; | |||
| @@ -14,6 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "worker/distributed_worker/worker_agent.h" | |||
| #include <memory> | |||
| #include "worker/distributed_worker/agent_process/agent_process.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| #include "common/exit_handle.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| @@ -23,15 +27,16 @@ WorkerAgent &WorkerAgent::Instance() { | |||
| return instance; | |||
| } | |||
| Status WorkerAgent::LoadModelFromFile(const AgentStartUpConfig &config) { | |||
| config_ = config; | |||
| return executor_.LoadModelFromFile(config); | |||
| } | |||
| Status WorkerAgent::Clear() { return executor_.UnloadModel(); } | |||
| Status WorkerAgent::ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply) { | |||
| return executor_.ExecuteModel(request, reply); | |||
| Status WorkerAgent::Clear() { | |||
| if (notify_worker_) { | |||
| if (exit_notify_worker_) { | |||
| notify_worker_->Unregister(); | |||
| } | |||
| notify_worker_ = nullptr; | |||
| } | |||
| grpc_server_.Stop(); | |||
| executor_.UnloadModel(); | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { | |||
| @@ -40,5 +45,59 @@ Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto:: | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { | |||
| Status status; | |||
| config_ = config; | |||
| status = executor_.LoadModelFromFile(config); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name | |||
| << ", rank_id: " << config.rank_id << ", device id: " << config.device_id | |||
| << ", model file: " << config.model_file_name | |||
| << ", rank table file: " << config.rank_table_json_file_name | |||
| << ", group config file: " << config.group_file_name; | |||
| return status; | |||
| } | |||
| status = StartGrpcServer(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip | |||
| << ", agent port: " << config.agent_port; | |||
| return status; | |||
| } | |||
| status = RegisterAgent(); | |||
| if (status != SUCCESS) { | |||
| MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port | |||
| << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port; | |||
| return status; | |||
| } | |||
| MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name | |||
| << ", rank_id: " << config.rank_id << ", device id: " << config.device_id | |||
| << ", model file: " << config.model_file_name | |||
| << ", rank table file: " << config.rank_table_json_file_name | |||
| << ", group config file: " << config.group_file_name; | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgent::StartGrpcServer() { | |||
| grpc_server_.Start(std::make_shared<MSAgentImpl>(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent"); | |||
| return SUCCESS; | |||
| } | |||
| Status WorkerAgent::RegisterAgent() { | |||
| notify_worker_ = std::make_shared<GrpcNotifyDistributeWorker>(config_.worker_ip, config_.agent_port, config_.agent_ip, | |||
| config_.agent_port); | |||
| WorkerAgentSpec spec; | |||
| spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); | |||
| spec.rank_id = config_.rank_id; | |||
| spec.batch_size = executor_.GetBatchSize(); | |||
| spec.input_infos = executor_.GetInputInfos(); | |||
| spec.output_infos = executor_.GetOutputInfos(); | |||
| return notify_worker_->Register({spec}); | |||
| } | |||
| void WorkerAgent::StopAgent(bool notify_worker) { | |||
| exit_notify_worker_ = notify_worker; | |||
| ExitSignalHandle::Instance().Stop(); | |||
| } | |||
| } // namespace serving | |||
| } // namespace mindspore | |||
| @@ -17,24 +17,36 @@ | |||
| #ifndef MINDSPORE_SERVING_WORKER_AGENT_H | |||
| #define MINDSPORE_SERVING_WORKER_AGENT_H | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "worker/distributed_worker/agent_executor.h" | |||
| #include "proto/ms_agent.pb.h" | |||
| #include "proto/ms_agent.grpc.pb.h" | |||
| #include "common/grpc_server.h" | |||
| #include "worker/distributed_worker/common.h" | |||
| #include "worker/distributed_worker/notify_distributed/notify_worker.h" | |||
| namespace mindspore { | |||
| namespace serving { | |||
| class MS_API WorkerAgent { | |||
| public: | |||
| static WorkerAgent &Instance(); | |||
| Status LoadModelFromFile(const AgentStartUpConfig &config); | |||
| Status Clear(); | |||
| Status ExecuteModel(const std::vector<TensorBasePtr> &request, std::vector<TensorBasePtr> *reply); | |||
| Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply); | |||
| Status StartAgent(const AgentStartUpConfig &config); | |||
| void StopAgent(bool notify_worker = true); | |||
| private: | |||
| AgentStartUpConfig config_; | |||
| WorkerAgentExecutor executor_; | |||
| GrpcServer grpc_server_; | |||
| bool exit_notify_worker_ = true; | |||
| std::shared_ptr<GrpcNotifyDistributeWorker> notify_worker_; | |||
| Status StartGrpcServer(); | |||
| Status RegisterAgent(); | |||
| }; | |||
| } // namespace serving | |||
| @@ -39,7 +39,7 @@ class MS_API MSWorkerServer { | |||
| MSWorkerServer(); | |||
| virtual ~MSWorkerServer(); | |||
| Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); | |||
| virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); | |||
| Status Stop(); | |||
| protected: | |||
| @@ -48,21 +48,32 @@ class MS_API MSWorkerServer { | |||
| std::unique_ptr<MSWorkerImpl> service_impl_ = nullptr; | |||
| std::unique_ptr<GrpcAsyncServer> async_server_ = nullptr; | |||
| Status StartAsyncRpcService(); | |||
| Status Init(); | |||
| Status StartAsyncRpcService(); | |||
| }; | |||
| class WorkerServiceContext { | |||
| public: | |||
| enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; | |||
| WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| virtual ~WorkerServiceContext() {} | |||
| bool JudgeFinish() { return state_ == STATE::FINISH; } | |||
| virtual void StartEnqueueRequest() = 0; | |||
| virtual void HandleRequest() = 0; | |||
| virtual bool JudgeFinish() = 0; | |||
| protected: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| public: | |||
| STATE state_; | |||
| }; | |||
| @@ -70,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerPredictContext() = default; | |||
| @@ -95,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext { | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::PredictReply> responder_; | |||
| proto::PredictRequest request_; | |||
| proto::PredictReply response_; | |||
| @@ -111,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext { | |||
| public: | |||
| WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, | |||
| grpc::ServerCompletionQueue *cq) | |||
| : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { | |||
| state_ = STATE::CREATE; | |||
| } | |||
| : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} | |||
| ~WorkerExitContext() = default; | |||
| @@ -136,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext { | |||
| responder_.Finish(response_, status, this); | |||
| } | |||
| bool JudgeFinish() override { return state_ == STATE::FINISH; } | |||
| private: | |||
| MSWorkerImpl *service_impl_; | |||
| proto::MSWorker::AsyncService *async_service_; | |||
| grpc::ServerCompletionQueue *cq_; | |||
| grpc::ServerContext ctx_; | |||
| grpc::ServerAsyncResponseWriter<proto::ExitReply> responder_; | |||
| proto::ExitRequest request_; | |||
| proto::ExitReply response_; | |||
| @@ -31,7 +31,7 @@ static const char *kVersionStrategySpecific = "specific"; | |||
| namespace mindspore::serving { | |||
| LocalModelServable::~LocalModelServable() { session_.UnloadModel(); } | |||
| LocalModelServable::~LocalModelServable() { Clear(); } | |||
| std::string LocalModelServable::GetServableName() const { return servable_name_; } | |||
| @@ -248,4 +248,11 @@ Status LocalModelServable::LoadModel(uint64_t version_number) { | |||
| return SUCCESS; | |||
| } | |||
| void LocalModelServable::Clear() { | |||
| if (model_loaded_) { | |||
| session_.UnloadModel(); | |||
| } | |||
| model_loaded_ = false; | |||
| } | |||
| } // namespace mindspore::serving | |||
| @@ -48,6 +48,7 @@ class MS_API LocalModelServable : public ServableBase { | |||
| Status InitDevice(ModelType model_type, const std::map<std::string, std::string> &other_options); | |||
| std::string GetServableName() const override; | |||
| uint64_t GetServableVersion() const override; | |||
| void Clear() override; | |||
| private: | |||
| LoadServableSpec base_spec_; | |||
| @@ -41,6 +41,7 @@ class ServableBase { | |||
| virtual uint64_t GetBatchSize() const = 0; | |||
| virtual std::string GetServableName() const = 0; | |||
| virtual uint64_t GetServableVersion() const = 0; | |||
| virtual void Clear() = 0; | |||
| }; | |||
| } // namespace mindspore::serving | |||
| @@ -174,9 +174,13 @@ void Worker::Update() { | |||
| */ | |||
| } | |||
| Status Worker::AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server) { | |||
| Status Worker::StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip, | |||
| int32_t port) { | |||
| if (worker_grpc_server_ != nullptr) { | |||
| return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; | |||
| } | |||
| worker_grpc_server_ = grpc_server; | |||
| return SUCCESS; | |||
| return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port); | |||
| } | |||
| Status Worker::StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master) { | |||
| @@ -248,6 +252,9 @@ void Worker::Clear() { | |||
| if (exit_notify_master_ && servable_started_) { | |||
| notify_master_->Unregister(); | |||
| } | |||
| for (auto &worker_item : work_list_) { | |||
| worker_item.servable->Clear(); | |||
| } | |||
| work_list_.clear(); | |||
| py_task_queue_group_.Stop(); | |||
| @@ -257,7 +264,7 @@ void Worker::Clear() { | |||
| MSI_LOG_INFO << "End clear worker session"; | |||
| } | |||
| bool Worker::HasCleared() { return !servable_started_; } | |||
| bool Worker::IsRunning() { return servable_started_; } | |||
| Worker::~Worker() { Clear(); } | |||
| @@ -318,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) { | |||
| const int kWaitMaxHundredMs = 100; | |||
| int i; | |||
| for (i = 0; i < kWaitMaxHundredMs; i++) { // | |||
| if (ExitSignalHandle::Instance().HasStopped() || Worker::GetInstance().HasCleared()) { | |||
| if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) { | |||
| instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped"); | |||
| return SYSTEM_ERROR; | |||
| } | |||
| @@ -75,10 +75,11 @@ class MS_API Worker { | |||
| const std::vector<InstanceData> &inputs); | |||
| Status StartServable(std::shared_ptr<ServableBase> servable, std::shared_ptr<BaseNotifyMaster> notify_master); | |||
| Status AfterStartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server); | |||
| Status StartGrpcServer(const std::shared_ptr<MSWorkerServer> &grpc_server, const std::string &worker_ip, | |||
| int32_t port); | |||
| void StopServable(bool notify_master = true); | |||
| bool HasCleared(); | |||
| bool IsRunning(); | |||
| Status RegisterWorker(); | |||
| void Update(); | |||
| Status StartVersionController(); | |||
| @@ -18,6 +18,7 @@ import threading | |||
| from functools import wraps | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import Master_ | |||
| _wait_and_clear_thread = None | |||
| @@ -59,6 +60,7 @@ def stop_on_except(func): | |||
| @wraps(func) | |||
| def handle_except(*args, **kwargs): | |||
| try: | |||
| ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message | |||
| func(*args, **kwargs) | |||
| except: | |||
| stop() | |||
| @@ -44,3 +44,10 @@ message AgentExitRequest { | |||
| message AgentExitReply { | |||
| ErrorMsg error_msg = 1; | |||
| } | |||
| message AgentFailedRequest { | |||
| } | |||
| message AgentFailedReply { | |||
| ErrorMsg error_msg = 1; | |||
| } | |||
| @@ -29,4 +29,5 @@ service MSWorker { | |||
| // for worker agent | |||
| rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} | |||
| rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} | |||
| rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} | |||
| } | |||
| @@ -12,11 +12,12 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Inferface for start up servable""" | |||
| """Interface for start up servable""" | |||
| import threading | |||
| from functools import wraps | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import Worker_ | |||
| from .register.preprocess import preprocess_storage | |||
| from .register.postprocess import postprocess_storage | |||
| @@ -77,6 +78,7 @@ def stop_on_except(func): | |||
| @wraps(func) | |||
| def handle_except(*args, **kwargs): | |||
| try: | |||
| ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message | |||
| func(*args, **kwargs) | |||
| except: | |||
| stop() | |||
| @@ -13,31 +13,238 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker agent startup""" | |||
| import inspect | |||
| import os | |||
| import time | |||
| from multiprocessing import Process, Pipe | |||
| from mindspore_serving._mindspore_serving import ExitSignalHandle_ | |||
| from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ | |||
| from mindspore_serving import log as logger | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker.distributed import worker_agent | |||
| def _get_local_ip(rank_list, port): | |||
| """Get the local ip from the rank table config""" | |||
| import socket | |||
| ip_list = [] | |||
| for item in rank_list: | |||
| ip_list.append(item.ip) | |||
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |||
| for ip in ip_list: | |||
| try: | |||
| s.bind((ip, port)) | |||
| logger.info(f"Get local machine ip success, ip {ip}") | |||
| return ip | |||
| # pylint: disable=bare-except | |||
| except: | |||
| pass | |||
| raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") | |||
| def _update_model_files_path(model_files, group_config_files): | |||
| """Check and return model files or group config files""" | |||
| script_dir = os.path.dirname(os.path.realpath(__file__)) | |||
| logger.info(f"input model files: {model_files}") | |||
| logger.info(f"input group config files: {group_config_files}") | |||
| model_files_temp = [] | |||
| for item in model_files: | |||
| file_name = os.path.join(script_dir, item) | |||
| if not os.access(file_name, os.R_OK): | |||
| raise RuntimeError(f"Cannot access model file '{file_name}'") | |||
| model_files_temp.append(file_name) | |||
| group_files_temp = [] | |||
| for item in group_config_files: | |||
| file_name = os.path.join(script_dir, item) | |||
| if not os.access(file_name, os.R_OK): | |||
| raise RuntimeError(f"Cannot access group config file '{file_name}'") | |||
| group_files_temp.append(file_name) | |||
| logger.info(f"absolute model files: {model_files_temp}") | |||
| logger.info(f"absolute group config files: {group_files_temp}") | |||
| return model_files_temp, group_files_temp | |||
| def _make_json_table_file(distributed_config): | |||
| """Make rank table json file""" | |||
| rank_size = len(distributed_config.rank_list) | |||
| runtime_dir = os.path.abspath(".") | |||
| time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) | |||
| rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") | |||
| with open(rank_table_file_name, "w") as fp: | |||
| fp.write(distributed_config.rank_table_content) | |||
| return rank_table_file_name | |||
| signal_success = "Success" | |||
| signal_exit = "Exit" | |||
| signal_heartbeat = "HeartBeat" | |||
| def _recv_parent(index, recv_pipe): | |||
| """Receive message from Start up process. | |||
| Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit. | |||
| Return True on receiving signal_success.""" | |||
| try: | |||
| while True: | |||
| heartbeat_count = 0 | |||
| while not recv_pipe.poll(0.1): | |||
| if ExitSignalHandle_.has_stopped(): | |||
| logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker") | |||
| return False | |||
| heartbeat_count += 1 | |||
| if heartbeat_count >= 30: # 3s | |||
| logger.warning(f"Child {index}: Exit on failure of receiving parent message") | |||
| return False | |||
| parent_signal = recv_pipe.recv() | |||
| if parent_signal != signal_heartbeat: | |||
| break | |||
| if parent_signal == signal_success: | |||
| logger.info(f"Child {index}: Receive success") | |||
| return True | |||
| if parent_signal == signal_exit: | |||
| logger.warning(f"Child {index}: Exit on receiving exit message") | |||
| else: | |||
| logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}") | |||
| # pylint: disable=broad-except | |||
| except Exception as e: | |||
| logger.warning(f"Child {index}: Exit on exception: {e}") | |||
| return False | |||
| def _agent_process(send_pipe, recv_pipe, index, start_config): | |||
| """Agent process""" | |||
| try: | |||
| # listening success or failed message from parent process | |||
| ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message | |||
| worker_agent.start_worker_agent(start_config=start_config) | |||
| send_pipe.send((index, signal_success)) | |||
| success_msg = _recv_parent(index, recv_pipe) | |||
| if not success_msg: | |||
| worker_agent.stop() | |||
| send_pipe.close() | |||
| recv_pipe.close() | |||
| # pylint: disable=broad-except | |||
| except Exception as e: | |||
| logger.error(f"Child {index}: Catch exception and notify exit of others") | |||
| send_pipe.send((index, e)) | |||
| worker_agent.stop() | |||
| raise | |||
| def startup_worker_agents(worker_ip, worker_port, | |||
| get_model_files_fun, get_group_configs_fun, | |||
| rank_start, agent_start_port=7000): | |||
| """Start up all needed worker agents on one machine | |||
| """ | |||
| def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list): | |||
| """Listening child process""" | |||
| def send_pipe_msg(send_pipe, msg): | |||
| try: | |||
| send_pipe.send(msg) | |||
| # pylint: disable=broad-except | |||
| except Exception as e: | |||
| logger.warning(f"Send pipe message exception happen: {e}") | |||
| count = len(send_pipe_list) | |||
| for _ in range(count): | |||
| while True: | |||
| if p_recv_pipe.poll(0.1): | |||
| break | |||
| for send_pipe, process in zip(send_pipe_list, subprocess_list): | |||
| if process.is_alive(): | |||
| continue | |||
| logger.warning("Fail to start agents because of death of one agent") | |||
| for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list): | |||
| if process_x.is_alive(): | |||
| send_pipe_msg(send_pipe_x, signal_exit) | |||
| return False | |||
| for send_pipe in send_pipe_list: | |||
| send_pipe_msg(send_pipe, signal_heartbeat) | |||
| _, msg = p_recv_pipe.recv() | |||
| if isinstance(msg, Exception): | |||
| logger.warning("Fail to start agents because of exception raise by one agent") | |||
| for send_pipe in send_pipe_list: | |||
| send_pipe_msg(send_pipe, signal_exit) | |||
| return False | |||
| for send_pipe in send_pipe_list: | |||
| send_pipe_msg(send_pipe, signal_success) | |||
| logger.info("Success to start agents") | |||
| return True | |||
def _startup_all_agents(common_meta, worker_ip, worker_port,
                        agent_ip, agent_start_port, device_id_list, rank_id_list,
                        model_files, group_config_files, rank_table_file):
    """Spawn one worker-agent subprocess per local device and wait until all of
    them report their startup status; on failure, notify the distributed worker."""
    servable_name = common_meta.servable_name
    send_pipe_list = []
    subprocess_list = []
    # One shared child->parent pipe for reports; one parent->child pipe per agent.
    c_send_pipe, p_recv_pipe = Pipe()
    per_agent = zip(device_id_list, rank_id_list, model_files, group_config_files)
    for index, (device_id, rank_id, model_file, group_file) in enumerate(per_agent):
        p_send_pipe, c_recv_pipe = Pipe()
        send_pipe_list.append(p_send_pipe)
        # Assemble the startup configuration handed to the agent process.
        config = AgentStartUpConfig_()
        config.rank_id = rank_id
        config.device_id = device_id
        config.model_file_name = model_file
        config.group_file_name = group_file
        config.rank_table_json_file_name = rank_table_file
        config.agent_ip = agent_ip
        config.agent_port = agent_start_port + index  # consecutive ports per agent
        config.worker_ip = worker_ip
        config.worker_port = worker_port
        config.common_meta = common_meta
        agent = Process(target=_agent_process,
                        args=(c_send_pipe, c_recv_pipe, index, config),
                        name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}")
        agent.start()
        subprocess_list.append(agent)
    if not _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list):
        # Startup failed: tell the worker so it stops waiting for the agents.
        WorkerAgent_.notify_failed(worker_ip, worker_port)
def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000):
    """Start up all needed worker agents on one machine.

    Args:
        worker_ip (str): The ip of the distributed worker the agents link to.
        worker_port (int): The port of the distributed worker the agents link to.
        model_files (Union[str, list[str], tuple[str]]): Model file names, one per local device.
        group_config_files (Union[str, list[str], tuple[str]]): Group config file names, one per local device.
        agent_start_port (int): Base port for the agents; agent i listens on agent_start_port + i.
            Default: 7000.

    Raises:
        RuntimeError: Parameter check failed, or the local card count described in the rank table
            does not match the number of model files or group config files.
    """
    check_type.check_str("worker_ip", worker_ip)
    check_type.check_ip_port("worker_port", worker_port)
    check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7)
    # Model and group config files are file *names*: validate as str lists, not int lists.
    model_files = _check_as_str_tuple("model_files", model_files)
    group_config_files = _check_as_str_tuple("group_config_files", group_config_files)
    distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)
    # Get the ip of this machine as recorded in the rank table.
    rank_list = distributed_config.rank_list
    local_ip = _get_local_ip(rank_list, agent_start_port)
    # Collect device_id and rank_id of every device on this machine.
    local_device_id_list = []
    local_rank_id_list = []
    for rank_id, item in enumerate(rank_list):
        if item.ip == local_ip:
            local_device_id_list.append(item.device_id)
            local_rank_id_list.append(rank_id)
    # One model file and one group config file are required per local device.
    if len(local_device_id_list) != len(model_files):
        raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to "
                           f"model files size {len(model_files)}, model files: {model_files}")
    if len(local_device_id_list) != len(group_config_files):
        raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to "
                           f"group config files size {len(group_config_files)}, group config files: "
                           f"{group_config_files}")
    model_files, group_config_files = _update_model_files_path(model_files, group_config_files)
    # Make the rank table json file the agents read via env.
    rank_table_file = _make_json_table_file(distributed_config)
    _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port,
                        local_device_id_list, local_rank_id_list,
                        model_files, group_config_files, rank_table_file)


def _check_as_str_tuple(name, files):
    """Validate 'files' as a str or a list/tuple of str; return it as a tuple of str."""
    if isinstance(files, str):
        files = (files,)
    if not isinstance(files, (list, tuple)):
        raise RuntimeError(f"Check failed, parameter '{name}' must be str or tuple/list of str, "
                           f"now is {type(files)}")
    for item in files:
        check_type.check_str(f"{name} item", item)
    return tuple(files)
| @@ -13,15 +13,17 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker startup""" | |||
| from mindspore_serving.worker._worker import stop_on_except, _load_servable_config | |||
| from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving._mindspore_serving import Worker_ | |||
| from mindspore_serving.worker import check_type | |||
| from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear | |||
| from mindspore_serving.worker._worker import stop_on_except, _load_servable_config | |||
| @stop_on_except | |||
| def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, | |||
| worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100): | |||
| worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100, | |||
| wait_agents_time_in_seconds=300): | |||
| r""" | |||
| Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master | |||
| through gRPC (master_ip, master_port). | |||
| @@ -46,6 +48,7 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso | |||
| master_port (int): The master port the worker linked to. | |||
| worker_ip (str): The worker ip the master and agents linked to. | |||
| worker_port (int): The worker port the master and agents linked to. | |||
| wait_agents_time_in_seconds (int): The maximum time in seconds that the worker waits for all agents to become ready. | |||
| Examples: | |||
| >>> import os | |||
| @@ -70,15 +73,15 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso | |||
| check_type.check_ip_port('worker_port', worker_port) | |||
| _load_servable_config(servable_directory, servable_name) | |||
| _start_wait_and_clear() | |||
| Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, | |||
| master_ip, master_port, worker_ip, worker_port) | |||
| master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds) | |||
| _start_py_task(Worker_.get_batch_size()) | |||
| _start_wait_and_clear() | |||
| @stop_on_except | |||
| def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, | |||
| worker_ip="0.0.0.0", worker_port=6200): | |||
| worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300): | |||
| r""" | |||
| Start up the servable named 'servable_name' defined in 'servable_directory', and the worker will run in | |||
| the process of the master. | |||
| @@ -97,6 +100,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank | |||
| rank_table_json_file (str): The rank table json file name. | |||
| worker_ip (str): The worker ip the agents linked to. | |||
| worker_port (int): The worker port the agents linked to. | |||
| wait_agents_time_in_seconds (int): The maximum time in seconds that the worker waits for all agents to become ready. | |||
| Examples: | |||
| >>> import os | |||
| @@ -121,7 +125,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank | |||
| check_type.check_ip_port('worker_port', worker_port) | |||
| _load_servable_config(servable_directory, servable_name) | |||
| _start_wait_and_clear() | |||
| Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, | |||
| version_number, worker_ip, worker_port) | |||
| version_number, worker_ip, worker_port, wait_agents_time_in_seconds) | |||
| _start_py_task(Worker_.get_batch_size()) | |||
| _start_wait_and_clear() | |||
| @@ -13,22 +13,54 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Serving, distributed worker agent""" | |||
| from mindspore_serving.worker import check_type | |||
| import os | |||
| import threading | |||
| from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ | |||
| from mindspore_serving import log as logger | |||
def start_worker_agent(start_config):
    """Start up one worker agent on one device id, invoked by agent_startup.startup_worker_agents.

    Args:
        start_config (AgentStartUpConfig_): Startup configuration carrying rank_id, device_id,
            model/group/rank-table file names and the worker/agent addresses.

    Raises:
        RuntimeError: 'start_config' is not an AgentStartUpConfig_ instance.
    """
    if not isinstance(start_config, AgentStartUpConfig_):
        raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_")
    # Export the env vars the backend reads before the agent starts
    # (RANK_ID/DEVICE_ID/rank-table for HCCL-style distributed setup).
    os.environ["RANK_ID"] = str(start_config.rank_id)
    os.environ["DEVICE_ID"] = str(start_config.device_id)
    os.environ["MS_ENABLE_HCCL"] = "1"
    os.environ["PARA_GROUP_FILE"] = start_config.group_file_name
    os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name
    # Log the effective environment for troubleshooting agent startup.
    for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE",
                 "LD_LIBRARY_PATH", "PYTHONPATH"):
        logger.info(f"Env {item}: {os.getenv(item, '')}")
    WorkerAgent_.start_agent(start_config)
    # Wait in the background for Ctrl+C / exit notification, then clear up.
    start_wait_and_clear()
# Background thread waiting for exit; None until start_wait_and_clear() runs.
_wait_and_clear_thread = None


def start_wait_and_clear():
    """Wait on a background thread for Ctrl+C, then clear up the environment.

    Idempotent: a second call while the wait thread exists does nothing.
    """
    global _wait_and_clear_thread
    if _wait_and_clear_thread is not None:
        # Already waiting; do not spawn a second thread.
        return

    def _wait_then_clear():
        logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------")
        print("Serving worker: wait for Ctrl+C to exit ------------------------------------")
        WorkerAgent_.wait_and_clear()
        logger.info("Serving worker: exited ------------------------------------")
        print("Serving worker: exited ------------------------------------")

    _wait_and_clear_thread = threading.Thread(target=_wait_then_clear)
    _wait_and_clear_thread.start()
def stop():
    r"""
    Stop the running of the worker agent and clear up its resources.

    Delegates to the backend's WorkerAgent_.stop_and_clear.
    """
    WorkerAgent_.stop_and_clear()