From 8e221f938ab0df5045d443ac7d5818abc6b4fe59 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Mon, 18 Jan 2021 11:19:06 +0800 Subject: [PATCH 01/10] Serving, gpt3 framework --- .../distributed_worker/agent_executor.cc | 34 ++++++ .../distributed_worker/agent_executor.h | 48 ++++++++ .../distributed_worker/agent_startup.cc | 30 +++++ .../worker/distributed_worker/agent_startup.h | 49 ++++++++ .../ccsrc/worker/distributed_worker/common.h | 77 ++++++++++++ .../distributed_servable.cc | 39 ++++++ .../distributed_worker/distributed_servable.h | 60 ++++++++++ .../worker/distributed_worker/worker_agent.cc | 38 ++++++ .../worker/distributed_worker/worker_agent.h | 40 +++++++ .../worker/distributed/agent_startup.py | 22 ++++ .../worker/distributed/distributed_worker.py | 113 ++++++++++++++++++ .../worker/distributed/register.py | 20 ++++ .../worker/distributed/worker_agent.py | 23 ++++ 13 files changed, 593 insertions(+) create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/common.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h create mode 100644 mindspore_serving/worker/distributed/agent_startup.py create mode 100644 mindspore_serving/worker/distributed/distributed_worker.py create mode 100644 mindspore_serving/worker/distributed/register.py create mode 100644 mindspore_serving/worker/distributed/worker_agent.py diff 
--git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc new file mode 100644 index 0000000..d2c4a1b --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "worker/distributed_worker/agent_executor.h" + +namespace mindspore { +namespace serving { + +Status WorkerAgentExecutor::LoadModelFromFile(const AgentStartUpConfig &config) { return Status(); } +Status WorkerAgentExecutor::UnloadModel() { return Status(); } +Status WorkerAgentExecutor::ExecuteModel(const std::vector &request, std::vector *reply) { + return Status(); +} +std::vector WorkerAgentExecutor::GetInputInfos() const { + return std::vector(); +} +std::vector WorkerAgentExecutor::GetOutputInfos() const { + return std::vector(); +} +ssize_t WorkerAgentExecutor::GetBatchSize() const { return 0; } +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h new file mode 100644 index 0000000..dd5d16a --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H +#define MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H + +#include +#include "common/serving_common.h" +#include "worker/inference/inference.h" +#include "worker/distributed_worker/common.h" + +namespace mindspore { +namespace serving { +class MS_API WorkerAgentExecutor { + public: + // from python + Status LoadModelFromFile(const AgentStartUpConfig &config); + // ctrl+c, worker exit + Status UnloadModel(); + + // from worker + Status ExecuteModel(const std::vector &request, std::vector *reply); + + // for register + std::vector GetInputInfos() const; + + std::vector GetOutputInfos() const; + + ssize_t GetBatchSize() const; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_EXECUTOR_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc new file mode 100644 index 0000000..9f766a9 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "worker/distributed_worker/agent_startup.h" +namespace mindspore { +namespace serving { + +Status WorkerAgentStartUp::InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, + const std::string &group_file_dir, const std::string &group_file_prefix) { + return Status(); +} +Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, + const std::string &worker_ip, uint32_t worker_port) { + return Status(); +} +Status WorkerAgentStartUp::GetCurrentMachineConfigs(std::vector *configs) { return Status(); } +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h new file mode 100644 index 0000000..37df7ce --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H +#define MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H +#include +#include +#include "common/serving_common.h" +#include "worker/distributed_worker/common.h" +#include "worker/inference/inference.h" + +namespace mindspore { +namespace serving { + +class MS_API WorkerAgentStartUp { + public: + // from python, worker_agent.py + // start_worker_agent + // step1, get agents config from worker + Status InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, + const std::string &group_file_dir, const std::string &group_file_prefix); + + Status GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, + const std::string &worker_ip, uint32_t worker_port); + // step2, invoke from python, get current machine agents config + Status GetCurrentMachineConfigs(std::vector *configs); + + private: + DistributedServableConfig config_; + std::string worker_address_; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_STARTUP_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/common.h b/mindspore_serving/ccsrc/worker/distributed_worker/common.h new file mode 100644 index 0000000..4a8dbb2 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/common.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H +#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H + +#include +#include +#include +#include "common/serving_common.h" +#include "worker/inference/inference.h" + +namespace mindspore { +namespace serving { + +struct OneRankConfig { + std::string ip; + uint32_t device_id = 0; +}; + +struct DistributedServableCommonConfig { + bool with_batch_dim; + std::vector without_batch_dim_inputs; +}; + +struct DistributedServableConfig { + uint32_t rank_size = 0; + uint32_t stage_size = 0; + const std::string models_dir; + const std::string groups_dir; + std::string rank_table_content; + std::vector rank_list; + DistributedServableCommonConfig common_config; +}; + +struct WorkerAgentSpec { + std::string ip; + uint32_t port = 0; + uint32_t rank_id = 0; + std::vector input_infos; + std::vector output_infos; + uint32_t batch_size = 0; +}; + +struct AgentStartUpConfig { + uint32_t rank_id; + uint32_t device_id; + std::string model_file_name; + std::string group_file_name; + std::string rank_table_json_file_name; + + std::string agent_ip; + uint32_t agent_port; + std::string worker_ip; + uint32_t worker_port; + + DistributedServableCommonConfig common_config; + std::map other_options; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_COMMON_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc new file mode 100644 index 0000000..1cfd4ba --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/distributed_servable.h" +#include +#include + +namespace mindspore { +namespace serving { + +Status DistributedServable::Predict(const std::vector &input, std::vector *output) { + return Status(); +} +std::vector DistributedServable::GetInputInfos() const { return std::vector(); } +std::vector DistributedServable::GetOutputInfos() const { return std::vector(); } +uint64_t DistributedServable::GetBatchSize() const { return 0; } +Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) { return Status(); } +Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { return Status(); } +Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { return Status(); } +Status DistributedServable::SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, + const std::vector &without_batch_dim_inputs) { + return Status(); +} +Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return Status(); } +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h new file mode 100644 index 0000000..fad293c --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H +#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H + +#include +#include +#include +#include "worker/model.h" +#include "worker/distributed_worker/common.h" + +namespace mindspore { +namespace serving { + +class MS_API DistributedServable : public ServableBase { + public: + // from python, servable_config.py + Status SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, + const std::vector &without_batch_dim_inputs); + // from python, worker.py + Status InitConfigOnStartup(const std::string &rank_table_json_file); + // invoke from agent + Status GetDistributedServableConfig(DistributedServableConfig *config); + // send model and group + + // register and unregister agent, agent_spec_list_ + Status RegisterAgent(const WorkerAgentSpec &agent_spec); + Status UnregisterAgent(const WorkerAgentSpec &agent_spec); + + // predict, use config_ and agent_spec_list_ + Status Predict(const std::vector &input, std::vector *output) override; + + std::vector GetInputInfos() const override; + std::vector GetOutputInfos() const override; + uint64_t GetBatchSize() const override; + + private: + DistributedServableConfig config_; + std::map agent_spec_list_; + // agent stubs +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_SERVABLE_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc 
b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc new file mode 100644 index 0000000..8d4d1b5 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "worker/distributed_worker/worker_agent.h" + +namespace mindspore { +namespace serving { + +WorkerAgent &WorkerAgent::Instance() { + static WorkerAgent instance; + return instance; +} + +Status WorkerAgent::LoadModelFromFile(const AgentStartUpConfig &config) { + config_ = config; + return executor_.LoadModelFromFile(config); +} + +Status WorkerAgent::Clear() { return executor_.UnloadModel(); } + +Status WorkerAgent::ExecuteModel(const std::vector &request, std::vector *reply) { + return executor_.ExecuteModel(request, reply); +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h new file mode 100644 index 0000000..a160d55 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_H +#define MINDSPORE_SERVING_WORKER_AGENT_H +#include +#include "worker/distributed_worker/agent_executor.h" + +namespace mindspore { +namespace serving { +class MS_API WorkerAgent { + public: + static WorkerAgent &Instance(); + Status LoadModelFromFile(const AgentStartUpConfig &config); + Status Clear(); + + Status ExecuteModel(const std::vector &request, std::vector *reply); + + private: + AgentStartUpConfig config_; + WorkerAgentExecutor executor_; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_H diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py new file mode 100644 index 0000000..83603eb --- /dev/null +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Serving, distributed worker agent startup""" + + +def startup_worker_agents(agent_ip, agent_start_port, worker_ip, worker_port, + model_dir, model_file_prefix, group_config_dir, group_file_prefix): + """Start up all needed worker agents on one machine + """ + pass diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py new file mode 100644 index 0000000..402a377 --- /dev/null +++ b/mindspore_serving/worker/distributed/distributed_worker.py @@ -0,0 +1,113 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Serving, distributed worker startup""" +from .._worker import stop_on_except, _load_servable_config +from .. import check_type + + +@stop_on_except +def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=0, + master_ip="0.0.0.0", master_port=6100, worker_ip="0.0.0.0", worker_port=6200): + r""" + Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master + through gRPC (master_ip, master_port). + + Serving has two running modes. One is running in a single process, providing the Serving service of a single model. + The other includes a master and multiple workers. 
This interface is for the second scenario. + + The master is responsible for providing the Serving access interface for clients, + while the worker is responsible for providing the inference service of the specific model. The communications + between the master and workers through gPRC are defined as (master_ip, master_port) and (worker_ip, worker_port). + + Args: + servable_directory (str): The directory where the servable is located in. There expects to has a directory + named `servable_name`. For more detail: + `How to config Servable `_ . + + servable_name (str): The servable name. + version_number (int): Servable version number to be loaded. The version number should be a positive integer, + starting from 1, and 0 means to load the latest version. Default: 0. + device_type (str): Currently only supports "Ascend", "Davinci" and None, Default: None. + "Ascend" means the device type can be Ascend910 or Ascend310, etc. + "Davinci" has the same meaning as "Ascend". + None means the device type is determined by the MindSpore environment. + device_id (int): The id of the device the model loads into and runs in. + master_ip (str): The master ip the worker linked to. + master_port (int): The master port the worker linked to. + worker_ip (str): The worker ip the master linked to. + worker_port (int): The worker port the master linked to. + + Examples: + >>> import os + >>> from mindspore_serving import worker + >>> + >>> servable_dir = os.path.abspath(".") + >>> worker.start_servable(servable_dir, "lenet", device_id=0, \ + ... master_ip="127.0.0.1", master_port=6500, \ + ... 
host_ip="127.0.0.1", host_port=6600) + """ + check_type.check_str('servable_directory', servable_directory) + check_type.check_str('servable_name', servable_name) + check_type.check_int('version_number', version_number, 0) + check_type.check_str('rank_table_json_file', rank_table_json_file) + + check_type.check_str('master_ip', master_ip) + check_type.check_ip_port('master_port', master_port) + + check_type.check_str('worker_ip', worker_ip) + check_type.check_ip_port('worker_port', worker_port) + + _load_servable_config(servable_directory, servable_name) + + +@stop_on_except +def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=0): + r""" + Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in + the process of the master. + + Serving has two running modes. One is running in a single process, providing the Serving service of a single model. + The other includes a master and multiple workers. This interface is for the first scenario. + + Args: + servable_directory (str): The directory where the servable is located in. There expects to has a directory named + `servable_name`. For more detail: + `How to config Servable `_ . + + servable_name (str): The servable name. + version_number (int): Servable version number to be loaded. The version number should be a positive integer, + starting from 1, and 0 means to load the latest version. Default: 0. + device_type (str): Currently only supports "Ascend", "Davinci" and None, Default: None. + "Ascend" means the device type can be Ascend910 or Ascend310, etc. + "Davinci" has the same meaning as "Ascend". + None means the device type is determined by the MindSpore environment. 
+ + Examples: + >>> import os + >>> from mindspore_serving import worker + >>> from mindspore_serving import master + >>> + >>> servable_dir = os.path.abspath(".") + >>> worker.start_servable_in_master(servable_dir, "lenet", device_id=0) + >>> + >>> master.start_grpc_server("0.0.0.0", 5500) + >>> master.start_restful_server("0.0.0.0", 1500) + """ + check_type.check_str('servable_directory', servable_directory) + check_type.check_str('servable_name', servable_name) + check_type.check_int('version_number', version_number, 0) + check_type.check_str('rank_table_json_file', rank_table_json_file) + + _load_servable_config(servable_directory, servable_name) diff --git a/mindspore_serving/worker/distributed/register.py b/mindspore_serving/worker/distributed/register.py new file mode 100644 index 0000000..59ce2d8 --- /dev/null +++ b/mindspore_serving/worker/distributed/register.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Serving, distributed worker register""" + + +def declare_distributed_servable(rank_size, stage_size, with_bach_dim, without_batch_dim_inputs): + """declare distributed servable in servable_config.py""" + pass diff --git a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py new file mode 100644 index 0000000..4a27c75 --- /dev/null +++ b/mindspore_serving/worker/distributed/worker_agent.py @@ -0,0 +1,23 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Serving, distributed worker agent""" + + +def _start_worker_agent(agent_ip, agent_start_port, worker_ip, worker_port, + rank_id, device_id, model_file, group_config_file, rank_table_file, + with_bach_dim, without_batch_dim_inputs): + """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents + """ + pass From a802c7686e54d8cbe15b82e803691b51a5d42c9a Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Mon, 25 Jan 2021 21:37:03 +0800 Subject: [PATCH 02/10] Serving, gpt3, distributed servable --- .../ccsrc/python/worker/worker_py.cc | 15 +- .../worker/ascend_servable/ascend_sevable.cc | 253 +++++++++++++++ .../worker/ascend_servable/ascend_sevable.h | 66 ++++ .../distributed_worker/agent_executor.cc | 2 +- .../distributed_worker/agent_executor.h | 2 +- .../distributed_worker/agent_startup.cc | 2 +- .../worker/distributed_worker/agent_startup.h | 2 +- .../ccsrc/worker/distributed_worker/common.h | 2 +- .../distributed_servable.cc | 2 +- .../distributed_worker/distributed_servable.h | 4 +- .../worker/distributed_worker/worker_agent.cc | 2 +- .../worker/distributed_worker/worker_agent.h | 2 +- .../ccsrc/worker/inference/inference.h | 126 -------- .../worker/inference/mindspore_model_wrap.cc | 146 ++++----- .../worker/inference/mindspore_model_wrap.h | 38 +-- mindspore_serving/ccsrc/worker/model.cc | 33 -- .../ccsrc/worker/{model.h => sevable_base.h} | 24 +- .../ccsrc/worker/work_executor.h | 2 +- mindspore_serving/ccsrc/worker/worker.cc | 292 +++--------------- mindspore_serving/ccsrc/worker/worker.h | 23 +- .../worker/distributed/agent_startup.py | 29 +- .../worker/distributed/distributed_worker.py | 38 ++- .../worker/distributed/register.py | 8 +- .../worker/distributed/worker_agent.py | 17 +- tests/ut/cpp/common/test_servable_common.h | 8 +- tests/ut/cpp/tests/test_start_worker.cc | 80 +---- tests/ut/runtest.sh | 29 +- 27 files changed, 571 insertions(+), 
676 deletions(-) create mode 100644 mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc create mode 100644 mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h delete mode 100644 mindspore_serving/ccsrc/worker/model.cc rename mindspore_serving/ccsrc/worker/{model.h => sevable_base.h} (63%) diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index 2980aae..b72cb5f 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -21,6 +21,7 @@ #include "common/exit_handle.h" #include "worker/notfiy_master/grpc_notify.h" #include "worker/notfiy_master/local_notify.h" +#include "worker/ascend_servable/ascend_sevable.h" namespace mindspore::serving { @@ -28,7 +29,12 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri const std::string &master_ip, uint32_t master_port, const std::string &host_ip, uint32_t host_port) { auto notify_master = std::make_shared(master_ip, master_port, host_ip, host_port); - auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); + auto servable = std::make_shared(); + auto status = servable->StartServable(model_directory, model_name, version_number); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } @@ -45,7 +51,12 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name, uint32_t version_number) { auto notify_master = std::make_shared(); - auto status = Worker::GetInstance().StartServable(model_directory, model_name, version_number, notify_master); + auto servable = 
std::make_shared(); + auto status = servable->StartServable(model_directory, model_name, version_number); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } diff --git a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc b/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc new file mode 100644 index 0000000..ed84469 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "worker/ascend_servable/ascend_sevable.h" +#include +#include +#include +#include +#include +#include "common/tensor.h" +#include "common/file_system_operation.h" +#include "worker/context.h" + +namespace { +static const char *kVersionStrategyLatest = "latest"; +static const char *kVersionStrategySpecific = "specific"; +} // namespace + +namespace mindspore::serving { + +AscendModelServable::~AscendModelServable() { session_.UnloadModel(); } + +Status AscendModelServable::Predict(const std::vector &input, std::vector *output) { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.ExecuteModel(input, output); +} + +std::vector AscendModelServable::GetInputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetInputInfos(); +} + +std::vector AscendModelServable::GetOutputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetOutputInfos(); +} + +uint64_t AscendModelServable::GetBatchSize() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return session_.GetBatchSize(); +} + +TensorBasePtr AscendModelServable::MakeInferenceTensor(DataType data_type, const std::vector &shape) const { + return std::make_shared(data_type, shape); +} + +Status AscendModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name, + uint32_t version_number) { + if (model_loaded_) { + MSI_LOG_EXCEPTION << "Model has loaded"; + } + base_spec_.servable_directory = servable_directory; + base_spec_.servable_name = servable_name; + base_spec_.version_number = version_number; + + std::string version_strategy; + if (version_number == 0) { + version_strategy = kVersionStrategyLatest; + } else { + version_strategy = kVersionStrategySpecific; + } + Status status; + ServableSignature signature; + if 
(!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; + } + status = InitDevice(signature.servable_meta.model_format, {}); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Init env failed"; + return status; + } + + std::vector real_versions; + status = LoadServableConfig(base_spec_, version_strategy, &real_versions); + if (status != SUCCESS) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Start servable failed, there is no servable of the specified version number, specified version number: " + << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" + << base_spec_.servable_name + << "'. version number is a positive integer(started from 1) and 0 represents the maximum version number."; + } + auto real_version_number = real_versions[0]; + status = LoadModel(real_version_number); + if (status != SUCCESS) { + return status; + } + worker_spec_.servable_name = base_spec_.servable_name; + worker_spec_.version_number = real_version_number; + for (auto &method : signature.methods) { + WorkerMethodInfo worker_method_info; + worker_method_info.name = method.method_name; + for (auto &name : method.inputs) { + worker_method_info.input_names.push_back(name); + } + worker_spec_.methods.push_back(worker_method_info); + } + model_loaded_ = true; + MSI_LOG_INFO << status.StatusMessage(); + std::cout << status.StatusMessage() << std::endl; + return SUCCESS; +} + +void AscendModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { + MSI_EXCEPTION_IF_NULL(real_versions); + // define version_strategy:"specific","latest","multi" + if (version_strategy_ == kVersionStrategySpecific) { + real_versions->push_back(servable_spec.version_number); + return; + } + auto trans_to_integer = [](const std::string &str) -> uint32_t { + uint32_t parsed_value = 0; + for (auto c : str) { + if (c < 
'0' || c > '9') { + return 0; + } + parsed_value = parsed_value * 10 + c - '0'; + } + if (std::to_string(parsed_value) != str) { + return 0; + } + return parsed_value; + }; + uint64_t newest_version = 0; + std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name; + auto sub_dir = GetAllSubDirsNotFullPath(model_path); + static std::set ignore_dir; + for (const auto &dir : sub_dir) { + if (dir == "__pycache__") continue; + auto version_parse = trans_to_integer(dir); + if (version_parse == 0) { + if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { + MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory + << ", model_name " << servable_spec.servable_name; + } + continue; + } + real_versions->push_back(version_parse); + if (version_parse > newest_version) { + newest_version = version_parse; + } + } + if (version_strategy_ == kVersionStrategyLatest) { + real_versions->clear(); + if (newest_version != 0) { + real_versions->push_back(newest_version); + } + } +} + +Status AscendModelServable::LoadServableConfig(const LoadServableSpec &servable_spec, + const std::string &version_strategy, + std::vector *real_versions) { + MSI_EXCEPTION_IF_NULL(real_versions); + auto model_directory = servable_spec.servable_directory; + auto model_name = servable_spec.servable_name; + + if (!DirOrFileExist(model_directory + "/" + model_name)) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model not found, model_directory " << model_directory << ", model_name " << model_name; + } + std::string model_path = model_directory + "/" + model_name; + auto version_directory = [model_path](int64_t version_number) { + return model_path + "/" + std::to_string(version_number); + }; + version_strategy_ = version_strategy; + // version_strategy:"specific","latest","multi" + GetVersions(servable_spec, real_versions); + if (real_versions->size() == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Not 
found invalid model version , model_directory " << model_directory << ", model_name " << model_name; + } + for (auto real_version_number : *real_versions) { + if (!DirOrFileExist(version_directory(real_version_number))) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory " + << model_directory << ", model_name " << model_name; + } + } + return SUCCESS; +} + +Status AscendModelServable::InitDevice(ModelType model_type, const std::map &other_options) { + Status status; + auto context = ServableContext::Instance(); + DeviceType device_type = ServableContext::Instance()->GetDeviceType(); + auto get_support_device_type = [this, device_type, model_type]() { + std::vector support_device_list; + if (device_type == kDeviceTypeNotSpecified || device_type == kDeviceTypeAscend) { + auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; + for (auto item : ascend_list) { + if (session_.CheckModelSupport(item, model_type)) { + return item; + } + } + } else if (device_type == kDeviceTypeAscendCL || device_type == kDeviceTypeAscendMS) { + if (session_.CheckModelSupport(device_type, model_type)) { + return device_type; + } + } + return kDeviceTypeNotSpecified; + }; + auto support_device_type = get_support_device_type(); + if (support_device_type == kDeviceTypeNotSpecified) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Not support device type " << device_type << " and model type " << model_type + << ". 
Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; + } + context->SetDeviceType(support_device_type); + return SUCCESS; +} + +Status AscendModelServable::LoadModel(uint64_t version_number) { + ServableSignature signature; + if (!ServableStorage::Instance().GetServableDef(base_spec_.servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << base_spec_.servable_name << " has not been registered"; + } + const auto &servable_meta = signature.servable_meta; + std::string model_file_name = base_spec_.servable_directory + "/" + base_spec_.servable_name + "/" + + std::to_string(version_number) + "/" + servable_meta.servable_file; + auto context = ServableContext::Instance(); + Status status = session_.LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, + servable_meta.model_format, servable_meta.with_batch_dim, + servable_meta.without_batch_dim_inputs, servable_meta.load_options); + if (status != SUCCESS) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '" + << base_spec_.servable_name << "', servable file: '" << servable_meta.servable_file << "', version number " + << version_number << ", options " << servable_meta.load_options; + } + return SUCCESS; +} + +} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h b/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h new file mode 100644 index 0000000..ad29706 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H +#define MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H + +#include +#include +#include +#include + +#include "common/serving_common.h" +#include "common/instance.h" +#include "common/servable.h" +#include "worker/sevable_base.h" +#include "worker/inference/inference.h" +#include "worker/inference/mindspore_model_wrap.h" + +namespace mindspore::serving { + +class MS_API AscendModelServable : public ServableBase { + public: + AscendModelServable() = default; + ~AscendModelServable() override; + + Status Predict(const std::vector &input, std::vector *output) override; + + std::vector GetInputInfos() const override; + std::vector GetOutputInfos() const override; + uint64_t GetBatchSize() const override; + TensorBasePtr MakeInferenceTensor(DataType data_type, const std::vector &shape) const override; + + Status StartServable(const std::string &servable_directory, const std::string &servable_name, + uint32_t version_number); + Status InitDevice(ModelType model_type, const std::map &other_options); + WorkerSpec GetWorkerSpec() const override { return worker_spec_; } + + private: + LoadServableSpec base_spec_; + WorkerSpec worker_spec_; + MindSporeModelWrap session_; + std::string version_strategy_; + bool model_loaded_ = false; + + void GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions); + Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, + std::vector *real_version_number); + Status LoadModel(uint64_t version); +}; + +} 
// namespace mindspore::serving + +#endif // MINDSPORE_SERVING_WORKER_ASCEND_SERVABLE_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc index d2c4a1b..e55509b 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h index dd5d16a..7a00fb0 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_executor.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc index 9f766a9..b4f5ee9 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h index 37df7ce..916fd39 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/common.h b/mindspore_serving/ccsrc/worker/distributed_worker/common.h index 4a8dbb2..801894a 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/common.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/common.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index 1cfd4ba..9327d24 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h index fad293c..2808943 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include #include #include -#include "worker/model.h" +#include "worker/sevable_base.h" #include "worker/distributed_worker/common.h" namespace mindspore { diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc index 8d4d1b5..4e21583 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h index a160d55..9119e63 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/mindspore_serving/ccsrc/worker/inference/inference.h b/mindspore_serving/ccsrc/worker/inference/inference.h index 72df989..337155c 100644 --- a/mindspore_serving/ccsrc/worker/inference/inference.h +++ b/mindspore_serving/ccsrc/worker/inference/inference.h @@ -52,132 +52,6 @@ enum DeviceType { kDeviceTypeCpu, }; -class MS_API InferSession { - public: - InferSession() = default; - virtual ~InferSession() = default; - virtual Status InitEnv(DeviceType device_type, uint32_t device_id, - const std::map &other_options) = 0; - virtual Status FinalizeEnv() = 0; - - virtual Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name, - ModelType model_type, const std::vector &without_batch_dim_inputs, - const std::map &other_options, uint32_t *model_id) = 0; - - virtual Status UnloadModel(uint32_t model_id) = 0; - // override this method to avoid request/reply data copy - virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) = 0; - virtual Status ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) { - VectorTensorPtrWrapRequest wrap_request(request); - VectorTensorPtrWrapReply wrap_reply(reply, []() { return std::make_shared(); }); - return ExecuteModel(model_id, wrap_request, &wrap_reply); - } - - virtual std::vector GetInputInfos(uint32_t model_id) const = 0; - virtual std::vector GetOutputInfos(uint32_t model_id) const = 0; - virtual ssize_t GetBatchSize(uint32_t model_id) const = 0; - virtual bool CheckModelSupport(DeviceType device_type, ModelType model_type) const { return true; } -}; - -struct InferSessionRegInfo { - std::shared_ptr session; - ModelType model_type; - int priority; -}; - -class MS_API InferSessionStorage { - public: - void Register(DeviceType device_type, ModelType model_type, const std::shared_ptr &session, - int priority) { - auto &list = session_map_[device_type]; - InferSessionRegInfo info{session, model_type, priority}; 
- list.push_back(info); - } - - std::shared_ptr Get(DeviceType device_type, ModelType model_type, DeviceType *specified_device_type) { - MSI_EXCEPTION_IF_NULL(specified_device_type); - if (device_type == kDeviceTypeNotSpecified) { - for (auto &item_device : session_map_) { - std::shared_ptr ret_session = GetSession(item_device.second, item_device.first, model_type); - if (ret_session) { - *specified_device_type = item_device.first; - return ret_session; - } - } - return nullptr; - } else if (device_type == kDeviceTypeAscend) { - auto ascend_list = {kDeviceTypeAscendCL, kDeviceTypeAscendMS}; - for (auto ascend_type : ascend_list) { - auto it = session_map_.find(ascend_type); - if (it == session_map_.end()) { - continue; - } - auto session_ret = GetSession(it->second, ascend_type, model_type); - if (session_ret != nullptr) { - *specified_device_type = ascend_type; - return session_ret; - } - } - return nullptr; - } - auto it = session_map_.find(device_type); - if (it == session_map_.end()) { - return nullptr; - } - std::shared_ptr session_ret; - session_ret = GetSession(it->second, device_type, model_type); - *specified_device_type = device_type; - return session_ret; - } - - static InferSessionStorage &Instance() { - static InferSessionStorage instance; - return instance; - } - - private: - std::unordered_map> session_map_; - - std::shared_ptr GetSession(const std::vector &session_list, DeviceType device_type, - ModelType model_type) { - std::shared_ptr session_ret = nullptr; - int cur_priority = INT32_MIN; - for (auto &item : session_list) { - if (item.model_type != model_type) { - continue; - } - if (session_ret == nullptr || cur_priority < item.priority) { - if (!item.session->CheckModelSupport(device_type, model_type)) { - MSI_LOG_INFO << "CheckModelSupport for " << device_type << " " << model_type << " failed, skipped"; - continue; - } - cur_priority = item.priority; - session_ret = item.session; - } - } - return session_ret; - } -}; - -class MS_API 
InferSessionRegister { - public: - InferSessionRegister(DeviceType device_type, ModelType model_type, const std::shared_ptr &session, - int priority) { - InferSessionStorage::Instance().Register(device_type, model_type, session, priority); - } -}; - -#define REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) \ - static mindspore::serving::InferSessionRegister g_register_session_##cls_name##_##index( \ - device_type, model_type, std::make_shared(), priority); - -#define REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, index) \ - REGISTER_INFER_SEESION_UNIQUE(device_type, model_type, cls_name, priority, index) - -#define REGISTER_INFER_SEESION(device_type, model_type, cls_name, priority) \ - REGISTER_INFER_SEESION_HELPER(device_type, model_type, cls_name, priority, __COUNTER__); - static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) { switch (device_type) { case kDeviceTypeAscend: diff --git a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc index 40f8c81..1affc0e 100644 --- a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc +++ b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.cc @@ -26,16 +26,6 @@ namespace mindspore { namespace serving { -Status MindSporeModelWrap::InitEnv(serving::DeviceType device_type, uint32_t device_id, - const std::map &other_options) { - return SUCCESS; -} - -Status MindSporeModelWrap::FinalizeEnv() { - model_map_.clear(); - return SUCCESS; -} - mindspore::DataType TransInferDataType2ApiTypeId(DataType data_type) { const std::map type2id_map{ {serving::kMSI_Unknown, mindspore::DataType::kTypeUnknown}, @@ -81,11 +71,9 @@ DataType TransTypeId2InferDataType(mindspore::DataType type_id) { } Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, - const std::string &file_name, ModelType model_type, + const 
std::string &file_name, ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, - const std::map &other_options, - uint32_t *model_id) { - MSI_EXCEPTION_IF_NULL(model_id); + const std::map &other_options) { std::string device_type_str; if (device_type == kDeviceTypeAscendMS) { device_type_str = mindspore::kDeviceTypeAscend910; @@ -113,18 +101,18 @@ Status MindSporeModelWrap::LoadModelFromFile(serving::DeviceType device_type, ui << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; return Status(FAILED, status.ToString()); } - model_index_++; - *model_id = model_index_; ApiModelInfo api_model_info; api_model_info.model = model; api_model_info.device_type = device_type_str; api_model_info.device_id = device_id; + api_model_info.with_batch_dim = with_batch_dim; api_model_info.without_batch_dim_inputs = without_batch_dim_inputs; auto st = GetModelInfos(&api_model_info); if (st != SUCCESS) { return st; } - model_map_[*model_id] = api_model_info; + GetModelBatchSize(&api_model_info); + model_ = api_model_info; MSI_LOG_INFO << "Load model from file success, model file: " << file_name << ", device_type: '" << device_type_str << "', device_id: " << device_id << ", model type: " << model_type << ", options: " << other_options; return SUCCESS; @@ -169,20 +157,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { MSI_EXCEPTION_IF_NULL(api_model_info); auto model = api_model_info->model; - bool first_dim_same = true; - auto find_batch_size = [&first_dim_same, api_model_info](const std::vector &shape) { - if (first_dim_same) { - if (shape.empty()) { - first_dim_same = false; - } else if (api_model_info->batch_size != 0) { - if (api_model_info->batch_size != shape[0]) { - first_dim_same = false; - } - } else { - api_model_info->batch_size = shape[0]; - } - } - }; auto get_tensor_info_from_tensor = [](const mindspore::MSTensor &ms_tensor) { serving::TensorInfo tensor_info; 
tensor_info.shape = ms_tensor.Shape(); @@ -204,10 +178,6 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { return INFER_STATUS_LOG_ERROR(FAILED) << "Unknown input mindspore data type " << static_cast(info.DataType()); } - const auto &list = api_model_info->without_batch_dim_inputs; - if (std::find(list.begin(), list.end(), i) == list.end()) { - find_batch_size(tensor_info.shape); - } api_model_info->input_tensor_infos.push_back(tensor_info); api_model_info->input_names.push_back(info.Name()); } @@ -220,27 +190,59 @@ Status MindSporeModelWrap::GetModelInfos(ApiModelInfo *api_model_info) { return INFER_STATUS_LOG_ERROR(FAILED) << "Unknown output mindspore data type " << static_cast(info.DataType()); } - find_batch_size(tensor_info.shape); api_model_info->output_tensor_infos.push_back(tensor_info); api_model_info->output_names.push_back(info.Name()); } } + return SUCCESS; +} + +void MindSporeModelWrap::GetModelBatchSize(ApiModelInfo *api_model_info) { + MSI_EXCEPTION_IF_NULL(api_model_info); + bool first_dim_same = true; + auto find_batch_size = [&first_dim_same, api_model_info](const std::vector &shape) { + if (!api_model_info->with_batch_dim) { + first_dim_same = false; + return; + } + if (!first_dim_same) { + return; + } + if (shape.empty()) { + first_dim_same = false; + return; + } + if (api_model_info->batch_size != 0) { + if (api_model_info->batch_size != shape[0]) { + first_dim_same = false; + } + } else { + api_model_info->batch_size = shape[0]; + } + }; + + auto list = api_model_info->without_batch_dim_inputs; + auto size = api_model_info->input_tensor_infos.size(); + for (size_t i = 0; i < size; i++) { + if (std::find(list.begin(), list.end(), i) == list.end()) { + auto &info = api_model_info->input_tensor_infos[i]; + find_batch_size(info.shape); + } + } + for (auto &info : api_model_info->output_tensor_infos) { + find_batch_size(info.shape); + } if (!first_dim_same) { api_model_info->batch_size = 0; } - return SUCCESS; } -Status 
MindSporeModelWrap::UnloadModel(uint32_t model_id) { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid model id " << model_id; - } - model_map_.erase(it); +Status MindSporeModelWrap::UnloadModel() { + model_.model = nullptr; return SUCCESS; } -Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &request, serving::ReplyBase *reply) { +Status MindSporeModelWrap::ExecuteModel(const RequestBase &request, serving::ReplyBase *reply) { MSI_EXCEPTION_IF_NULL(reply); FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { auto input_tensor = request[index]; @@ -260,11 +262,10 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const RequestBase &re tensor->set_data_type(data_type); tensor->set_shape(shape); }; - return ExecuteModelCommon(model_id, request.size(), func_in, func_out); + return ExecuteModelCommon(request.size(), func_in, func_out); } -Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) { +Status MindSporeModelWrap::ExecuteModel(const std::vector &request, std::vector *reply) { MSI_EXCEPTION_IF_NULL(reply); FuncMakeInBuffer func_in = [&request](size_t index, const std::string &name) { auto &input_tensor = request[index]; @@ -282,16 +283,15 @@ Status MindSporeModelWrap::ExecuteModel(uint32_t model_id, const std::vectorset_shape(shape); reply->push_back(tensor); }; - return ExecuteModelCommon(model_id, request.size(), func_in, func_out); + return ExecuteModelCommon(request.size(), func_in, func_out); } -Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, +Status MindSporeModelWrap::ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func) { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid 
model id " << model_id; + if (model_.model == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model is not loaded"; } - auto &model_info = it->second; + auto &model_info = model_; auto model = model_info.model; auto &input_names = model_info.input_names; auto &output_names = model_info.output_names; @@ -327,43 +327,25 @@ Status MindSporeModelWrap::ExecuteModelCommon(uint32_t model_id, size_t request_ return SUCCESS; } -std::vector MindSporeModelWrap::GetInputInfos(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - return {}; - } - auto &model_info = it->second; - return model_info.input_tensor_infos; -} +std::vector MindSporeModelWrap::GetInputInfos() const { return model_.input_tensor_infos; } -std::vector MindSporeModelWrap::GetOutputInfos(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - return {}; - } - auto &model_info = it->second; - return model_info.output_tensor_infos; -} +std::vector MindSporeModelWrap::GetOutputInfos() const { return model_.output_tensor_infos; } -ssize_t MindSporeModelWrap::GetBatchSize(uint32_t model_id) const { - auto it = model_map_.find(model_id); - if (it == model_map_.end()) { - MSI_LOG_ERROR << "Invalid model id " << model_id; - return {}; - } - auto &model_info = it->second; - return model_info.batch_size; -} +ssize_t MindSporeModelWrap::GetBatchSize() const { return model_.batch_size; } bool MindSporeModelWrap::CheckModelSupport(DeviceType device_type, ModelType model_type) const { std::string device_type_str; switch (device_type) { case kDeviceTypeAscendMS: + if (model_type != kMindIR) { + return false; + } device_type_str = mindspore::kDeviceTypeAscend910; break; case kDeviceTypeAscendCL: + if (model_type != kMindIR && model_type != kOM) { + return false; + } device_type_str = mindspore::kDeviceTypeAscend310; break; 
default: @@ -378,9 +360,5 @@ ApiBufferTensorWrap::ApiBufferTensorWrap(const mindspore::MSTensor &tensor) : te ApiBufferTensorWrap::~ApiBufferTensorWrap() = default; -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kOM, MindSporeModelWrap, 1); -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendCL, kMindIR, MindSporeModelWrap, 1); -REGISTER_INFER_SEESION(serving::kDeviceTypeAscendMS, kMindIR, MindSporeModelWrap, 1); - } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h index 14432ec..02f6f18 100644 --- a/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h +++ b/mindspore_serving/ccsrc/worker/inference/mindspore_model_wrap.h @@ -34,54 +34,46 @@ struct ApiModelInfo { std::vector input_tensor_infos; std::vector output_names; std::vector output_tensor_infos; - std::shared_ptr model; + std::shared_ptr model = nullptr; uint32_t batch_size = 0; std::string device_type; uint32_t device_id = 0; + bool with_batch_dim = false; std::vector without_batch_dim_inputs; }; -class MindSporeModelWrap : public InferSession { +class MindSporeModelWrap { public: MindSporeModelWrap() = default; ~MindSporeModelWrap() = default; - Status InitEnv(serving::DeviceType device_type, uint32_t device_id, - const std::map &other_options) override; - - Status FinalizeEnv() override; - Status LoadModelFromFile(serving::DeviceType device_type, uint32_t device_id, const std::string &file_name, - ModelType model_type, const std::vector &without_batch_dim_inputs, - const std::map &other_options, uint32_t *model_id) override; - - Status UnloadModel(uint32_t model_id) override; + ModelType model_type, bool with_batch_dim, const std::vector &without_batch_dim_inputs, + const std::map &other_options); - // override this method to avoid request/reply data copy - Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase *reply) override; - 
Status ExecuteModel(uint32_t model_id, const std::vector &request, - std::vector *reply) override; + Status UnloadModel(); + Status ExecuteModel(const RequestBase &request, ReplyBase *reply); + Status ExecuteModel(const std::vector &request, std::vector *reply); - std::vector GetInputInfos(uint32_t model_id) const override; + std::vector GetInputInfos() const; - std::vector GetOutputInfos(uint32_t model_id) const override; + std::vector GetOutputInfos() const; - ssize_t GetBatchSize(uint32_t model_id) const override; + ssize_t GetBatchSize() const; - bool CheckModelSupport(DeviceType device_type, ModelType model_type) const override; + bool CheckModelSupport(DeviceType device_type, ModelType model_type) const; private: - std::unordered_map model_map_; - uint32_t model_index_ = 0; + ApiModelInfo model_; using FuncMakeInBuffer = std::function; using FuncMakeOutTensor = std::function &shape)>; - Status ExecuteModelCommon(uint32_t model_id, size_t request_size, const FuncMakeInBuffer &in_func, - const FuncMakeOutTensor &out_func); + Status ExecuteModelCommon(size_t request_size, const FuncMakeInBuffer &in_func, const FuncMakeOutTensor &out_func); Status GetModelInfos(ApiModelInfo *model_info); std::shared_ptr TransformModelContext(const std::map &other_options); + void GetModelBatchSize(ApiModelInfo *model_info); }; class ApiBufferTensorWrap : public TensorBase { diff --git a/mindspore_serving/ccsrc/worker/model.cc b/mindspore_serving/ccsrc/worker/model.cc deleted file mode 100644 index 814bd02..0000000 --- a/mindspore_serving/ccsrc/worker/model.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "worker/model.h" -#include -#include "mindspore_serving/ccsrc/common/tensor.h" - -namespace mindspore::serving { - -Status AscendModelServable::Predict(const std::vector &input, std::vector *output) { - return session_->ExecuteModel(model_id_, input, output); -} - -std::vector AscendModelServable::GetInputInfos() const { return session_->GetInputInfos(model_id_); } - -std::vector AscendModelServable::GetOutputInfos() const { return session_->GetOutputInfos(model_id_); } - -uint64_t AscendModelServable::GetBatchSize() const { return session_->GetBatchSize(model_id_); } - -} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/model.h b/mindspore_serving/ccsrc/worker/sevable_base.h similarity index 63% rename from mindspore_serving/ccsrc/worker/model.h rename to mindspore_serving/ccsrc/worker/sevable_base.h index dcb23e0..4185b13 100644 --- a/mindspore_serving/ccsrc/worker/model.h +++ b/mindspore_serving/ccsrc/worker/sevable_base.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_SERVING_WORKER_MODEL_H -#define MINDSPORE_SERVING_WORKER_MODEL_H +#ifndef MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H +#define MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H #include #include @@ -39,25 +39,9 @@ class ServableBase { virtual std::vector GetInputInfos() const = 0; virtual std::vector GetOutputInfos() const = 0; virtual uint64_t GetBatchSize() const = 0; -}; - -class AscendModelServable : public ServableBase { - public: - AscendModelServable(const std::shared_ptr &session, uint32_t model_id) - : session_(session), model_id_(model_id) {} - ~AscendModelServable() = default; - - Status Predict(const std::vector &input, std::vector *output) override; - - std::vector GetInputInfos() const override; - std::vector GetOutputInfos() const override; - uint64_t GetBatchSize() const override; - - private: - std::shared_ptr session_{nullptr}; - uint32_t model_id_ = 0; + virtual WorkerSpec GetWorkerSpec() const = 0; }; } // namespace mindspore::serving -#endif // MINDSPORE_SERVING_WORKER_MODEL_H +#endif // MINDSPORE_SERVING_WORKER_SERVABLE_BASE_H diff --git a/mindspore_serving/ccsrc/worker/work_executor.h b/mindspore_serving/ccsrc/worker/work_executor.h index 7ca9c3b..2c44e6a 100644 --- a/mindspore_serving/ccsrc/worker/work_executor.h +++ b/mindspore_serving/ccsrc/worker/work_executor.h @@ -28,7 +28,7 @@ #include "common/serving_common.h" #include "common/instance.h" #include "common/servable.h" -#include "worker/model.h" +#include "worker/sevable_base.h" #include "worker/predict_thread.h" #include "worker/task_queue.h" diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index aad5339..6dc4531 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -34,8 +34,6 @@ namespace py = pybind11; namespace mindspore { namespace serving { -static const char *kVersionStrategyLastest = "lastest"; -static const char *kVersionStrategySpecific = "specific"; static 
std::unique_ptr grpc_async_worker_server_; Worker &Worker::GetInstance() { @@ -52,28 +50,10 @@ Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { } Status Worker::RegisterWorker() { - std::vector specs; - std::vector signatures; - for (auto &work : work_list_) { - specs.push_back(work.servable_spec); - signatures.push_back(work.servable_signature); - } std::vector worker_specs; - for (size_t i = 0; i < specs.size(); i++) { - auto &spec = specs[i]; - auto &servable_signature = signatures[i]; - WorkerSpec worker_spec; - worker_spec.servable_name = spec.servable_name; - worker_spec.version_number = spec.version_number; - for (auto &method : servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - worker_specs.push_back(worker_spec); + for (auto &work : work_list_) { + // cppcheck-suppress useStlAlgorithm + worker_specs.push_back(work.worker_spec); } auto status = notify_master_->Register(worker_specs); return status; @@ -84,34 +64,10 @@ Status Worker::StartVersionController() { return SUCCESS; } -Status Worker::AddWorker(const ServableWorkerContext &work) { - WorkerSpec worker_spec; - worker_spec.servable_name = work.servable_spec.servable_name; - worker_spec.version_number = work.servable_spec.version_number; - for (auto &method : work.servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - return notify_master_->AddWorker(worker_spec); -} +Status Worker::AddWorker(const ServableWorkerContext &work) { return notify_master_->AddWorker(work.worker_spec); } Status Worker::RemoveWorker(const ServableWorkerContext &work) { - 
WorkerSpec worker_spec; - worker_spec.servable_name = work.servable_spec.servable_name; - worker_spec.version_number = work.servable_spec.version_number; - for (auto &method : work.servable_signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec.methods.push_back(worker_method_info); - } - return notify_master_->RemoveWorker(worker_spec); + return notify_master_->RemoveWorker(work.worker_spec); } Status Worker::Run(const proto::PredictRequest &request, proto::PredictReply *reply) { @@ -189,74 +145,8 @@ std::pair> Worker::RunAsync(const RequestSp return {SUCCESS, result}; } -Status Worker::InitEnv(ModelType model_type, const std::map &other_options) { - Status status; - if (session_) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Session has been inited"; - } - auto context = ServableContext::Instance(); - DeviceType device_type = kDeviceTypeNotSpecified; - session_ = InferSessionStorage::Instance().Get(context->GetDeviceType(), model_type, &device_type); - if (session_ == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Cannot find session registered for device type " << context->GetDeviceType() << " and model type " - << model_type << ". 
Ascend 910 supports MindIR model and Ascend 310 supports OM, MindIR model"; - } - if (device_type != kDeviceTypeNotSpecified) { - context->SetDeviceType(device_type); - } - status = session_->InitEnv(context->GetDeviceType(), context->GetDeviceId(), other_options); - if (status != SUCCESS) { - session_ = nullptr; - return INFER_STATUS_LOG_ERROR(FAILED) - << "Init env failed, device type " << context->GetDeviceType() << ", device id " << context->GetDeviceId(); - } - return SUCCESS; -} - -Status Worker::FinalizeEnv() { - if (session_ != nullptr) { - return session_->FinalizeEnv(); - } - return SUCCESS; -} -Status Worker::LoadModel(LoadServableSpec *servable_spec, uint64_t version_number, ServableWorkerContext *work) { - MSI_EXCEPTION_IF_NULL(servable_spec); - MSI_EXCEPTION_IF_NULL(work); - servable_spec->version_number = version_number; - ServableSignature signature; - if (!ServableStorage::Instance().GetServableDef(servable_spec->servable_name, &signature)) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_spec->servable_name << " has not been registerd"; - } - const auto &servable_meta = signature.servable_meta; - std::string model_file_name = servable_spec->servable_directory + "/" + servable_spec->servable_name + "/" + - std::to_string(version_number) + "/" + servable_meta.servable_file; - uint32_t model_id; - auto context = ServableContext::Instance(); - Status status = session_->LoadModelFromFile(context->GetDeviceType(), context->GetDeviceId(), model_file_name, - servable_meta.model_format, servable_meta.without_batch_dim_inputs, - servable_meta.load_options, &model_id); - if (status != SUCCESS) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Load model failed, servable directory: '" << servable_spec->servable_directory << "', servable name: '" - << servable_spec->servable_name << "', servable file: '" << servable_meta.servable_file - << "', version number " << version_number << ", options " << servable_meta.load_options; - } - auto service 
= std::make_shared(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), - GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); - status = service->Init(signature, std::make_shared(session_, model_id)); - if (status != SUCCESS) { - return status; - } - work->servable_spec = *servable_spec; - work->servable_signature = signature; - work->worker_service = service; - work->model_id = model_id; - work->model_file_name = model_file_name; - return SUCCESS; -} - void Worker::Update() { + /* if (version_strategy_ == kVersionStrategySpecific) { return; } @@ -291,10 +181,10 @@ void Worker::Update() { MSI_LOG_INFO << "UnLoad Model version " << iter->servable_spec.version_number << " success"; work_list_.erase(iter); } + */ } -Status Worker::StartServable(const std::string &servable_directory, const std::string &servable_name, - uint32_t version_number, std::shared_ptr notify_master) { +Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (servable_started_) { MSI_LOG_EXCEPTION << "A servable has been started, only one servable can run in a process currently."; @@ -307,58 +197,31 @@ Status Worker::StartServable(const std::string &servable_directory, const std::s cpp_postprocess_.Start(2); notify_master_ = std::move(notify_master); - base_spec_.servable_directory = servable_directory; - base_spec_.servable_name = servable_name; - base_spec_.version_number = version_number; - - std::string version_strategy; - if (version_number == 0) { - version_strategy = kVersionStrategyLastest; - } else { - version_strategy = kVersionStrategySpecific; - } - Status status; + auto worker_spec = servable->GetWorkerSpec(); ServableSignature signature; - if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; + if 
(!ServableStorage::Instance().GetServableDef(worker_spec.servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << worker_spec.servable_name << " has not been registered"; } - if (session_ == nullptr) { - status = InitEnv(signature.servable_meta.model_format, {}); - if (status != SUCCESS) { - MSI_LOG_ERROR << "Init env failed"; - return status; - } - } - - std::vector real_versions; - status = LoadServableConfig(base_spec_, version_strategy, &real_versions); + auto service = std::make_shared(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), + GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); + auto status = service->Init(signature, servable); if (status != SUCCESS) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Start servable failed, there is no servable of the specified version number, specified version number: " - << version_number << ", servable directory: '" << base_spec_.servable_directory << "', servable name: '" - << base_spec_.servable_name - << "'. 
version number is a positive integer(started from 1) and 0 represents the maximum version number."; - } - for (auto real_version_number : real_versions) { - ServableWorkerContext work; - status = LoadModel(&base_spec_, real_version_number, &work); - if (status != SUCCESS) { - return status; - } - work_list_.push_back(work); + return status; } + ServableWorkerContext work; + work.worker_spec = worker_spec; + work.servable_signature = signature; + work.worker_service = service; + work.servable = servable; + + work_list_.push_back(work); + status = RegisterWorker(); if (status != SUCCESS) { MSI_LOG_ERROR << "Register worker failed"; return status; } servable_started_ = true; - status = INFER_STATUS(SUCCESS) << "Serving: Start servable success, servable directory: '" << servable_directory - << "', servable name: '" << servable_name - << "', specified version number: " << version_number - << ", started version numbers: " << real_versions; - MSI_LOG_INFO << status.StatusMessage(); - std::cout << status.StatusMessage() << std::endl; return SUCCESS; } @@ -368,29 +231,22 @@ void Worker::StopServable(bool notify_master) { } void Worker::Clear() { + std::unique_lock lock(worker_shared_lock_); + ServableStorage::Instance().Clear(); + grpc_async_worker_server_ = nullptr; if (clear_flag_.test_and_set()) { return; } - std::unique_lock lock(worker_shared_lock_); MSI_LOG_INFO << "Start clear worker session"; version_controller_.StopPollModelPeriodic(); if (exit_notify_master_ && servable_started_) { notify_master_->Unregister(); } - if (session_ != nullptr) { - for (auto &it : work_list_) { - session_->UnloadModel(it.model_id); - } - } work_list_.clear(); - FinalizeEnv(); - session_ = nullptr; py_task_queue_group_.Stop(); cpp_preprocess_.Stop(); cpp_postprocess_.Stop(); - ServableStorage::Instance().Clear(); - grpc_async_worker_server_ = nullptr; servable_started_ = false; MSI_LOG_INFO << "End clear worker session"; } @@ -399,88 +255,12 @@ bool Worker::HasCleared() { return 
!servable_started_; } Worker::~Worker() { Clear(); } -void Worker::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { - MSI_EXCEPTION_IF_NULL(real_versions); - // define version_strategy:"specific","lastest","multi" - if (version_strategy_ == kVersionStrategySpecific) { - real_versions->push_back(servable_spec.version_number); - return; - } - auto trans_to_integer = [](const std::string &str) -> uint32_t { - uint32_t parsed_value = 0; - for (auto c : str) { - if (c < '0' || c > '9') { - return 0; - } - parsed_value = parsed_value * 10 + c - '0'; - } - if (std::to_string(parsed_value) != str) { - return 0; - } - return parsed_value; - }; - uint64_t newest_version = 0; - std::string model_path = servable_spec.servable_directory + "/" + servable_spec.servable_name; - auto sub_dir = GetAllSubDirsNotFullPath(model_path); - static std::set ignore_dir; - for (const auto &dir : sub_dir) { - if (dir == "__pycache__") continue; - auto version_parse = trans_to_integer(dir); - if (version_parse == 0) { - if (ignore_dir.emplace(servable_spec.servable_directory + dir).second) { - MSI_LOG_INFO << "Ignore directory " << dir << ", model_directory " << servable_spec.servable_directory - << ", model_name " << servable_spec.servable_name; - } - continue; - } - real_versions->push_back(version_parse); - if (version_parse > newest_version) { - newest_version = version_parse; - } - } - if (version_strategy_ == kVersionStrategyLastest) { - real_versions->clear(); - if (newest_version != 0) { - real_versions->push_back(newest_version); - } - } -} -Status Worker::LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, - std::vector *real_versions) { - MSI_EXCEPTION_IF_NULL(real_versions); - auto model_directory = servable_spec.servable_directory; - auto model_name = servable_spec.servable_name; - - if (!DirOrFileExist(model_directory + "/" + model_name)) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model not found, 
model_directory " << model_directory << ", model_name " << model_name; - } - std::string model_path = model_directory + "/" + model_name; - auto version_directory = [model_path](int64_t version_number) { - return model_path + "/" + std::to_string(version_number); - }; - version_strategy_ = version_strategy; - // version_strategy:"specific","lastest","multi" - GetVersions(servable_spec, real_versions); - if (real_versions->size() == 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Not found invalid model version , model_directory " << model_directory << ", model_name " << model_name; - } - for (auto real_version_number : *real_versions) { - if (!DirOrFileExist(version_directory(real_version_number))) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Open failed for version " << real_version_number << ", model_directory " - << model_directory << ", model_name " << model_name; - } - } - return SUCCESS; -} - ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) { ServableWorkerContext context; if (request_spec.version_number != 0) { auto item = find_if(work_list_.begin(), work_list_.end(), [&](const ServableWorkerContext &v) { - return v.servable_spec.servable_name == request_spec.servable_name && - v.servable_spec.version_number == request_spec.version_number; + return v.worker_spec.servable_name == request_spec.servable_name && + v.worker_spec.version_number == request_spec.version_number; }); if (item != work_list_.end()) { context = *item; @@ -488,10 +268,10 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) } else { uint64_t max_version = 0; for (auto &item : work_list_) { - if (item.servable_spec.servable_name == request_spec.servable_name && - item.servable_spec.version_number > max_version) { + if (item.worker_spec.servable_name == request_spec.servable_name && + item.worker_spec.version_number > max_version) { context = item; - max_version = item.servable_spec.version_number; + max_version = 
item.worker_spec.version_number; } } } @@ -500,11 +280,11 @@ ServableWorkerContext Worker::GetServableWorker(const RequestSpec &request_spec) Worker::Worker() {} -ssize_t Worker::GetBatchSize() const { - ssize_t batch_size_ret = -1; - for (auto service : work_list_) { - auto batch_size = session_->GetBatchSize(service.model_id); - if (batch_size != -1) { +size_t Worker::GetBatchSize() const { + size_t batch_size_ret = 1; + for (const auto &service : work_list_) { + auto batch_size = service.servable->GetBatchSize(); + if (batch_size != 0) { batch_size_ret = batch_size; break; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 007a5cd..122b7d4 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -32,6 +32,7 @@ #include "worker/task_queue.h" #include "worker/version_control/version_controller.h" #include "common/grpc_async_server.h" +#include "worker/sevable_base.h" namespace mindspore { namespace serving { @@ -53,11 +54,10 @@ class AsyncResult { }; struct ServableWorkerContext { - LoadServableSpec servable_spec; + WorkerSpec worker_spec; ServableSignature servable_signature; std::shared_ptr worker_service = nullptr; - uint32_t model_id = 0; - std::string model_file_name; + std::shared_ptr servable = nullptr; }; class MS_API Worker { @@ -72,17 +72,12 @@ class MS_API Worker { Status Run(const RequestSpec &request_spec, const std::vector &inputs, std::vector *outputs); std::pair> RunAsync(const RequestSpec &request_spec, const std::vector &inputs); + Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); - Status InitEnv(ModelType model_type, const std::map &other_options); - Status FinalizeEnv(); - - Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint32_t version_number, - std::shared_ptr notify_master); void StopServable(bool notify_master = true); bool HasCleared(); Status RegisterWorker(); Status 
StartGrpcServer(const std::string &ip, uint32_t grpc_port); - Status LoadModel(LoadServableSpec *servable_spec, uint64_t version, ServableWorkerContext *work); void Update(); Status StartVersionController(); Status AddWorker(const ServableWorkerContext &work); @@ -93,20 +88,15 @@ class MS_API Worker { std::shared_ptr GetPyTaskQueuePostprocess() { return py_task_queue_group_.GetPostprocessTaskQueue(); } std::shared_ptr GetCppTaskQueuePreprocess() { return cpp_preprocess_.GetTaskQueue(); } std::shared_ptr GetCppTaskQueuePostprocess() { return cpp_postprocess_.GetTaskQueue(); } - ssize_t GetBatchSize() const; + size_t GetBatchSize() const; private: - static std::shared_ptr global_worker_; - std::vector work_list_; - std::shared_ptr session_ = nullptr; - std::string version_strategy_; PyTaskQueueGroup py_task_queue_group_; PreprocessThreadPool cpp_preprocess_; PostprocessThreadPool cpp_postprocess_; VersionController version_controller_; - LoadServableSpec base_spec_; std::atomic_bool exit_notify_master_ = true; std::atomic_bool servable_started_ = false; std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; @@ -115,9 +105,6 @@ class MS_API Worker { std::shared_mutex worker_shared_lock_; ServableWorkerContext GetServableWorker(const RequestSpec &request_spec); - Status LoadServableConfig(const LoadServableSpec &servable_spec, const std::string &version_strategy, - std::vector *real_version_number); - void GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions); }; } // namespace serving diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py index 83603eb..41d8218 100644 --- a/mindspore_serving/worker/distributed/agent_startup.py +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this 
file except in compliance with the License. @@ -13,10 +13,31 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker agent startup""" +import inspect +from mindspore_serving.worker import check_type -def startup_worker_agents(agent_ip, agent_start_port, worker_ip, worker_port, - model_dir, model_file_prefix, group_config_dir, group_file_prefix): + +def startup_worker_agents(worker_ip, worker_port, + get_model_files_fun, get_group_configs_fun, + rank_start, agent_start_port=7000): """Start up all needed worker agents on one machine """ - pass + check_type.check_str("worker_ip", worker_ip) + check_type.check_ip_port("worker_port", worker_port) + check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) + if inspect.isfunction(get_model_files_fun): + pass + else: + if not isinstance(get_model_files_fun, (list, tuple)): + raise RuntimeError(f"Check failed, get_model_files_fun first must be function or tuple/list of str, " + f"now is {type(get_model_files_fun)}") + if inspect.isfunction(get_group_configs_fun): + pass + else: + if not isinstance(get_group_configs_fun, (list, tuple)): + raise RuntimeError(f"Check failed, get_group_configs_fun first must be function or tuple/list of str, " + f"now is {type(get_group_configs_fun)}") + check_type.check_int("rank_start", rank_start, 0) + if rank_start % 8 != 0: + raise RuntimeError(f"Parameter 'rank_start' must be a multiple of 8, now is {rank_start}") diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py index 402a377..d8ff512 100644 --- a/mindspore_serving/worker/distributed/distributed_worker.py +++ b/mindspore_serving/worker/distributed/distributed_worker.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may
not use this file except in compliance with the License. @@ -13,13 +13,13 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker startup""" -from .._worker import stop_on_except, _load_servable_config -from .. import check_type +from mindspore_serving.worker._worker import stop_on_except, _load_servable_config +from mindspore_serving.worker import check_type @stop_on_except -def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=0, - master_ip="0.0.0.0", master_port=6100, worker_ip="0.0.0.0", worker_port=6200): +def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, + worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100): r""" Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master through gRPC (master_ip, master_port). @@ -39,15 +39,11 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso servable_name (str): The servable name. version_number (int): Servable version number to be loaded. The version number should be a positive integer, starting from 1, and 0 means to load the latest version. Default: 0. - device_type (str): Currently only supports "Ascend", "Davinci" and None, Default: None. - "Ascend" means the device type can be Ascend910 or Ascend310, etc. - "Davinci" has the same meaning as "Ascend". - None means the device type is determined by the MindSpore environment. - device_id (int): The id of the device the model loads into and runs in. + rank_table_json_file (str): The rank table json file name. master_ip (str): The master ip the worker linked to. master_port (int): The master port the worker linked to. - worker_ip (str): The worker ip the master linked to. - worker_port (int): The worker port the master linked to.
+ worker_ip (str): The worker ip the master and agents linked to. + worker_port (int): The worker port the master and agents linked to. Examples: >>> import os @@ -61,6 +57,8 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) + if version_number == 0: + version_number = 1 check_type.check_str('rank_table_json_file', rank_table_json_file) check_type.check_str('master_ip', master_ip) @@ -73,7 +71,8 @@ @stop_on_except -def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=0): +def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, + worker_ip="0.0.0.0", worker_port=6200): r""" Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in the process of the master. @@ -89,10 +88,9 @@ servable_name (str): The servable name. version_number (int): Servable version number to be loaded. The version number should be a positive integer, starting from 1, and 0 means to load the latest version. Default: 0. - device_type (str): Currently only supports "Ascend", "Davinci" and None, Default: None. - "Ascend" means the device type can be Ascend910 or Ascend310, etc. - "Davinci" has the same meaning as "Ascend". - None means the device type is determined by the MindSpore environment. + rank_table_json_file (str): The rank table json file name. + worker_ip (str): The worker ip the agents linked to. + worker_port (int): The worker port the agents linked to.
Examples: >>> import os @@ -108,6 +106,12 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank check_type.check_str('servable_directory', servable_directory) check_type.check_str('servable_name', servable_name) check_type.check_int('version_number', version_number, 0) + if version_number == 0: + version_number = 1 + check_type.check_str('rank_table_json_file', rank_table_json_file) + check_type.check_str('worker_ip', worker_ip) + check_type.check_ip_port('worker_port', worker_port) + _load_servable_config(servable_directory, servable_name) diff --git a/mindspore_serving/worker/distributed/register.py b/mindspore_serving/worker/distributed/register.py index 59ce2d8..c060624 100644 --- a/mindspore_serving/worker/distributed/register.py +++ b/mindspore_serving/worker/distributed/register.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,12 @@ # limitations under the License. 
# ============================================================================ """Serving, distributed worker register""" +from mindspore_serving.worker import check_type def declare_distributed_servable(rank_size, stage_size, with_bach_dim, without_batch_dim_inputs): """declare distributed servable in servable_config.py""" - pass + check_type.check_int("rank_size", rank_size, 0) + check_type.check_int("stage_size", stage_size, 0) + check_type.check_bool("with_bach_dim", with_bach_dim) + check_type.check_and_as_int_tuple_list("without_batch_dim_inputs", without_batch_dim_inputs, 0) diff --git a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py index 4a27c75..d1ebb99 100644 --- a/mindspore_serving/worker/distributed/worker_agent.py +++ b/mindspore_serving/worker/distributed/worker_agent.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +13,22 @@ # limitations under the License. 
# ============================================================================ """Serving, distributed worker agent""" +from mindspore_serving.worker import check_type -def _start_worker_agent(agent_ip, agent_start_port, worker_ip, worker_port, +def _start_worker_agent(agent_ip, agent_port, worker_ip, worker_port, rank_id, device_id, model_file, group_config_file, rank_table_file, with_bach_dim, without_batch_dim_inputs): """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents """ - pass + check_type.check_str("agent_ip", agent_ip) + check_type.check_ip_port("agent_port", agent_port) + check_type.check_str("worker_ip", worker_ip) + check_type.check_ip_port("worker_port", worker_port) + check_type.check_int("rank_id", rank_id, 0) + check_type.check_int("device_id", device_id, 0) + check_type.check_str("model_file", model_file) + check_type.check_str("group_config_file", group_config_file) + check_type.check_str("rank_table_file", rank_table_file) + check_type.check_bool("with_bach_dim", with_bach_dim) + check_type.check_and_as_int_tuple_list("without_batch_dim_inputs", without_batch_dim_inputs, 0) diff --git a/tests/ut/cpp/common/test_servable_common.h b/tests/ut/cpp/common/test_servable_common.h index 91df994..bb7eeed 100644 --- a/tests/ut/cpp/common/test_servable_common.h +++ b/tests/ut/cpp/common/test_servable_common.h @@ -27,6 +27,7 @@ #include "worker/worker.h" #include "worker/notfiy_master/local_notify.h" #include "worker/context.h" +#include "worker/ascend_servable/ascend_sevable.h" #include "master/grpc/grpc_process.h" #include "mindspore_serving/proto/ms_service.pb.h" @@ -102,7 +103,12 @@ class TestMasterWorker : public UT::Common { auto notify_master = std::make_shared(); ServableContext::Instance()->SetDeviceId(0); ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable(servable_dir, servable_name, version_number, notify_master); + auto servable = 
std::make_shared(); + auto status = servable->StartServable(servable_dir, servable_name, version_number); + if (status != SUCCESS) { + return status; + } + status = Worker::GetInstance().StartServable(servable, notify_master); return status; } static void DeclareServable(const std::string &servable_name, const std::string &servable_file, diff --git a/tests/ut/cpp/tests/test_start_worker.cc b/tests/ut/cpp/tests/test_start_worker.cc index 63e0838..8ab955b 100644 --- a/tests/ut/cpp/tests/test_start_worker.cc +++ b/tests/ut/cpp/tests/test_start_worker.cc @@ -30,10 +30,7 @@ TEST_F(TestStartWorker, test_worker_start_success) { DeclareServable("test_servable", "test_add.mindir", "mindir", true); RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } @@ -43,10 +40,7 @@ TEST_F(TestStartWorker, test_worker_start_error_model_file_name) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + auto status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Load model failed, servable directory: "); } @@ -57,12 +51,8 @@ TEST_F(TestStartWorker, test_worker_start_error_version_number) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = 
std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); int error_version_number = 2; - Status status = - Worker::GetInstance().StartServable("test_servable_dir", "test_servable", error_version_number, notify_master); + auto status = StartServable("test_servable_dir", "test_servable", error_version_number); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start servable failed, there is no servable of" @@ -78,11 +68,8 @@ TEST_F(TestStartWorker, test_worker_start_multi_version_number) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); int version_number = 0; - Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", version_number, notify_master); + Status status = StartServable(servable_dir, "test_servable", version_number); EXPECT_TRUE(status.IsSuccess()); } @@ -96,10 +83,7 @@ TEST_F(TestStartWorker, test_worker_start_version_number_no_valid) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable(servable_dir, "test_servable", 0, notify_master); + Status status = StartServable(servable_dir, "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start servable failed, there is no servable of" @@ -112,11 +96,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_dir) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - 
ServableContext::Instance()->SetDeviceTypeStr("Ascend"); std::string error_servable_dir = "test_servable_dir_error"; - Status status = Worker::GetInstance().StartServable(error_servable_dir, "test_servable", 0, notify_master); + Status status = StartServable(error_servable_dir, "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "Start servable failed, there is no servable of" @@ -129,11 +110,8 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_name) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); std::string error_servable_name = "test_servable_error"; - Status status = Worker::GetInstance().StartServable("test_servable_dir", error_servable_name, 0, notify_master); + Status status = StartServable("test_servable_dir", error_servable_name, 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "'test_servable_error' has not been registered"); } @@ -144,24 +122,18 @@ TEST_F(TestStartWorker, test_worker_start_error_servable_format) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); - ExpectContainMsg(status.StatusMessage(), "Cannot find session registered for device type Ascend and model type OM"); + ExpectContainMsg(status.StatusMessage(), "Not support device type Ascend and model type OM. 
"); } TEST_F(TestStartWorker, test_worker_start_no_registered_method) { - Init("test_servable_dir", "test_servable", 1, "test_add.mindir"); + Init("test_servable_dir", "test_servable", 2, "test_add.mindir"); DeclareServable("test_servable", "test_add.mindir", "mindir", true); // no registered method // RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 2); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "There is no method registered for servable"); } @@ -181,10 +153,7 @@ TEST_F(TestStartWorker, test_worker_start_multi_method) { RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, 1); RegisterMethod("test_servable", "add_common2", {"x1", "x2"}, {"y"}, 2, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } @@ -194,10 +163,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_input_count_not_match) size_t servable_input_count = 1; RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, servable_input_count, 1); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = 
StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The inputs count 1 registered in method not equal to " @@ -210,10 +176,7 @@ TEST_F(TestStartWorker, test_worker_start_method_servable_output_count_not_match size_t servable_output_count = 2; RegisterMethod("test_servable", "add_common", {"x1", "x2"}, {"y"}, 2, servable_output_count); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), "The outputs count 2 registered in method not equal to " @@ -241,10 +204,7 @@ TEST_F(TestStartWorker, test_worker_start_preprocess_not_found) { ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), " preprocess preprocess_fake_fun not defined") } @@ -269,10 +229,7 @@ TEST_F(TestStartWorker, test_worker_start_postprocess_not_found) { ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = 
StartServable("test_servable_dir", "test_servable", 0); EXPECT_FALSE(status.IsSuccess()); ExpectContainMsg(status.StatusMessage(), " postprocess postprocess_fake_fun not defined") } @@ -300,10 +257,7 @@ TEST_F(TestStartWorker, test_worker_start_with_preproces_and_postprocess_success ServableStorage::Instance().RegisterMethod(method_signature); // start_servable - auto notify_master = std::make_shared(); - ServableContext::Instance()->SetDeviceId(0); - ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - Status status = Worker::GetInstance().StartServable("test_servable_dir", "test_servable", 0, notify_master); + Status status = StartServable("test_servable_dir", "test_servable", 0); EXPECT_TRUE(status.IsSuccess()); } diff --git a/tests/ut/runtest.sh b/tests/ut/runtest.sh index 9a3dfa7..da92b72 100755 --- a/tests/ut/runtest.sh +++ b/tests/ut/runtest.sh @@ -16,21 +16,24 @@ set -e -CURRPATH=$(cd "$(dirname $0)" || exit; pwd) +CURRPATH=$( + cd "$(dirname $0)" || exit + pwd +) if [ $# -gt 0 ]; then - if [ $1 == "python" ]; then - echo "run python ut" - bash ${CURRPATH}/python/runtest.sh $2 - elif [ $1 == "cpp" ]; then - echo "run cpp ut" - bash ${CURRPATH}/cpp/runtest.sh - fi -else - echo "run all ut" - # 1.run python testcases + if [ $1 == "python" ]; then + echo "run python ut" bash ${CURRPATH}/python/runtest.sh $2 - - # 2.run c++ ut testcases + elif [ $1 == "cpp" ]; then + echo "run cpp ut" bash ${CURRPATH}/cpp/runtest.sh + fi +else + echo "run all ut" + # 1.run python testcases + bash ${CURRPATH}/python/runtest.sh $2 + + # 2.run c++ ut testcases + bash ${CURRPATH}/cpp/runtest.sh fi From bcb30821d576be3cd9bf2ed08b9eb20b6daa5e6f Mon Sep 17 00:00:00 2001 From: zhangyinxia Date: Tue, 26 Jan 2021 18:17:28 +0800 Subject: [PATCH 03/10] add message --- .../{master/grpc => common}/grpc_client.cc | 5 +- .../{master/grpc => common}/grpc_client.h | 5 +- mindspore_serving/ccsrc/master/dispacther.h | 2 +- .../ccsrc/master/notify_worker/base_notify.h | 3 +- 
.../ccsrc/master/notify_worker/grpc_notify.cc | 1 - .../agent_process/agent_process.cc | 36 ++++++++ .../agent_process/agent_process.h | 41 +++++++++ .../distributed_process.cc | 37 ++++++++ .../distributed_process/distributed_process.h | 49 ++++++++++ .../notify_agent/base_notify_agent.h | 43 +++++++++ .../notify_agent/notify_agent.cc | 39 ++++++++ .../notify_agent/notify_agent.h | 48 ++++++++++ .../notify_distributed/base_notify_worker.h | 38 ++++++++ .../notify_distributed/notify_worker.cc | 91 +++++++++++++++++++ .../notify_distributed/notify_worker.h | 53 +++++++++++ mindspore_serving/proto/ms_agent.proto | 27 ++++++ mindspore_serving/proto/ms_distributed.proto | 28 ++++++ 17 files changed, 537 insertions(+), 9 deletions(-) rename mindspore_serving/ccsrc/{master/grpc => common}/grpc_client.cc (91%) rename mindspore_serving/ccsrc/{master/grpc => common}/grpc_client.h (91%) create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h create mode 100644 mindspore_serving/proto/ms_agent.proto create mode 
100644 mindspore_serving/proto/ms_distributed.proto diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_client.cc b/mindspore_serving/ccsrc/common/grpc_client.cc similarity index 91% rename from mindspore_serving/ccsrc/master/grpc/grpc_client.cc rename to mindspore_serving/ccsrc/common/grpc_client.cc index 85ad129..508da4e 100644 --- a/mindspore_serving/ccsrc/master/grpc/grpc_client.cc +++ b/mindspore_serving/ccsrc/common/grpc_client.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,9 @@ * limitations under the License. */ -#include "master/grpc/grpc_client.h" +#include "common/grpc_client.h" #include #include -#include "master/grpc/grpc_server.h" namespace mindspore { namespace serving { diff --git a/mindspore_serving/ccsrc/master/grpc/grpc_client.h b/mindspore_serving/ccsrc/common/grpc_client.h similarity index 91% rename from mindspore_serving/ccsrc/master/grpc/grpc_client.h rename to mindspore_serving/ccsrc/common/grpc_client.h index ca39a48..afd9a1e 100644 --- a/mindspore_serving/ccsrc/master/grpc/grpc_client.h +++ b/mindspore_serving/ccsrc/common/grpc_client.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,7 +24,6 @@ #include #include #include "common/serving_common.h" -#include "master/notify_worker/base_notify.h" #include "proto/ms_service.pb.h" #include "proto/ms_service.grpc.pb.h" #include "proto/ms_master.pb.h" @@ -38,6 +37,8 @@ extern std::unique_ptr client_; using PredictOnFinish = std::function; +using DispatchCallback = std::function; + class MSServiceClient { public: MSServiceClient() = default; diff --git a/mindspore_serving/ccsrc/master/dispacther.h b/mindspore_serving/ccsrc/master/dispacther.h index 95f632b..4e5e406 100644 --- a/mindspore_serving/ccsrc/master/dispacther.h +++ b/mindspore_serving/ccsrc/master/dispacther.h @@ -27,7 +27,7 @@ #include "common/instance.h" #include "common/servable.h" #include "master/notify_worker/base_notify.h" -#include "master/grpc/grpc_client.h" +#include "common/grpc_client.h" namespace mindspore::serving { diff --git a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h index 5d8cb11..5ccb0c3 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/base_notify.h +++ b/mindspore_serving/ccsrc/master/notify_worker/base_notify.h @@ -22,12 +22,11 @@ #include "common/serving_common.h" #include "common/servable.h" #include "proto/ms_service.pb.h" +#include "common/grpc_client.h" namespace mindspore { namespace serving { -using DispatchCallback = std::function; - class MS_API BaseNotifyWorker { public: BaseNotifyWorker() = default; diff --git a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc index 4420d44..86ca8e7 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc +++ b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc @@ -20,7 +20,6 @@ #include #include "common/exit_handle.h" #include "common/grpc_server.h" -#include "master/grpc/grpc_client.h" namespace mindspore { namespace serving { diff --git 
a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc new file mode 100644 index 0000000..474fe3a --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/agent_process/agent_process.h" + +namespace mindspore { +namespace serving { +grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::ExitRequest *request, + proto::ExitReply *reply) { + MSI_LOG(INFO) << "Distributed Worker Exit"; + // to do : need WorkerAgent support stop funcition + return grpc::Status::OK; +} + +grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::PredictRequest *request, + proto::PredictReply *reply) { + MSI_LOG(INFO) << "Begin call service Eval"; + // to do : need WorkerAgent support run funcition + return grpc::Status::OK; +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h new file mode 100644 index 0000000..b04affe --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H +#define MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H + +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" + +namespace mindspore { +namespace serving { + +// Service Implement +class MSAgentImpl final : public proto::MSWorker::Service { + public: + grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, + proto::PredictReply *reply) override; + grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_AGENT_PROCESS_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc new file mode 100644 index 0000000..11b86d3 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/distributed_process/distributed_process.h" + +namespace mindspore { +namespace serving { + +grpc::Status MSDistributedImpl::Register(grpc::ServerContext *context, const proto::RegisterRequest *request, + proto::RegisterReply *reply) { + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::Predict(grpc::ServerContext *context, const proto::PredictRequest *request, + proto::PredictReply *reply) { + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::Exit(grpc::ServerContext *context, const proto::ExitRequest *request, + proto::ExitReply *reply) { + return grpc::Status::OK; +} +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h new file mode 100644 index 0000000..0ba4f08 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H +#define MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H + +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" +#include "proto/ms_service.pb.h" +#include "proto/ms_service.grpc.pb.h" +#include "proto/ms_master.pb.h" +#include "proto/ms_master.grpc.pb.h" + +namespace mindspore { +namespace serving { + +// Service Implement +class MSDistributedImpl final : public proto::MSMaster::Service, public proto::MSWorker::Service { + public: + MSDistributedImpl() {} + ~MSDistributedImpl() = default; + grpc::Status Register(grpc::ServerContext *context, const proto::RegisterRequest *request, + proto::RegisterReply *reply) override; + grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, + proto::PredictReply *reply) override; + grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_DISTRIBUTED_WORKER_WORKER_PROCESS_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h new file mode 100644 index 0000000..861ea0d --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H +#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H +#include +#include +#include +#include "common/serving_common.h" +#include "common/servable.h" +#include "proto/ms_service.pb.h" + +namespace mindspore { +namespace serving { + +using DistributeCallback = std::function; + +class MS_API BaseNotifyAgent { + public: + BaseNotifyAgent() = default; + virtual ~BaseNotifyAgent() = default; + virtual Status Exit() = 0; + virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, + DistributeCallback callback) = 0; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_AGENT_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc new file mode 100644 index 0000000..cbc3e50 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "worker/distributed_worker/notify_agent/notify_agent.h" +#include +#include +#include +#include +#include "common/exit_handle.h" +#include "common/grpc_server.h" + +namespace mindspore { +namespace serving { + +GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &worker_address) {} + +GrpcNotfiyAgent::~GrpcNotfiyAgent() = default; + +Status GrpcNotfiyAgent::Exit() { return SUCCESS; } + +Status GrpcNotfiyAgent::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, + DistributeCallback callback) { + return SUCCESS; +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h new file mode 100644 index 0000000..974c54a --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H +#define MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H +#include +#include +#include +#include +#include "worker/distributed_worker/notify_agent/base_notify_agent.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" + +namespace mindspore { +namespace serving { + +class MS_API GrpcNotfiyAgent : public BaseNotifyAgent { + public: + explicit GrpcNotfiyAgent(const std::string &worker_address); + ~GrpcNotfiyAgent() override; + + Status Exit() override; + + Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, + DistributeCallback callback) override; + + private: + std::string worker_address_; + std::shared_ptr stub_ = nullptr; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_NOTIFY_AGENT_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h new file mode 100644 index 0000000..8e5e690 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H +#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H +#include +#include "common/serving_common.h" +#include "common/servable.h" +#include "worker/distributed_worker/common.h" + +namespace mindspore { +namespace serving { + +class MS_API BaseNotifyDistributeWorker { + public: + BaseNotifyDistributeWorker() = default; + virtual ~BaseNotifyDistributeWorker() = default; + virtual Status Register(const std::vector &worker_specs) = 0; + virtual Status Unregister() = 0; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc new file mode 100644 index 0000000..d0014d5 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "worker/distributed_worker/notify_distributed/notify_worker.h" +#include +#include +#include +#include +#include "common/exit_handle.h" +#include "common/grpc_server.h" + +namespace mindspore { +namespace serving { + +GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip, + uint32_t distributed_worker_port, const std::string &host_ip, + uint32_t host_port) + : distributed_worker_ip_(distributed_worker_ip), + distributed_worker_port_(distributed_worker_port), + host_ip_(host_ip), + host_port_(host_port) { + distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); + agent_address_ = host_ip_ + ":" + std::to_string(host_port_); + auto channel = GrpcServer::CreateChannel(distributed_worker_address_); + stub_ = proto::MSDistributedWorker::NewStub(channel); +} + +GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default; + +Status GrpcNotfiyDistributeWorker::Register(const std::vector &worker_specs) { + const int32_t REGISTER_TIME_OUT = 60; + const int32_t REGISTER_INTERVAL = 1; + auto loop = REGISTER_TIME_OUT; + while (loop-- && !ExitSignalHandle::Instance().HasStopped()) { + MSI_LOG(INFO) << "Register to " << distributed_worker_address_; + proto::RegisterRequest request; + request.set_address(agent_address_); + // to do set RegisterRequest message + proto::RegisterReply reply; + grpc::ClientContext context; + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL); + context.set_deadline(deadline); + grpc::Status status = stub_->Register(&context, request, &reply); + if (status.ok()) { + MSI_LOG(INFO) << "Register SUCCESS "; + return SUCCESS; + } + MSI_LOG_INFO << "Grpc message: " << status.error_code() << ", " << status.error_message(); + std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); + } + if (ExitSignalHandle::Instance().HasStopped()) { + return 
INFER_STATUS_LOG_WARNING(FAILED) << "Worker exit, stop registration"; + } + return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; +} + +Status GrpcNotfiyDistributeWorker::Unregister() { + if (is_stoped_.load()) { + return SUCCESS; + } + is_stoped_ = true; + proto::ExitRequest request; + request.set_address(agent_address_); + proto::ExitReply reply; + grpc::ClientContext context; + const int32_t TIME_OUT = 1; + std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); + context.set_deadline(deadline); + grpc::Status status = stub_->Exit(&context, request, &reply); + if (status.ok()) { + MSI_LOG(INFO) << "Exit SUCCESS "; + return SUCCESS; + } + return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h new file mode 100644 index 0000000..e698c56 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H +#define MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H +#include +#include +#include +#include "worker/distributed_worker/notify_distributed/base_notify_worker.h" +#include "proto/ms_master.pb.h" +#include "proto/ms_master.grpc.pb.h" +#include "proto/ms_distributed.pb.h" +#include "proto/ms_distributed.grpc.pb.h" +namespace mindspore { +namespace serving { + +class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { + public: + GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, + uint32_t host_port); + ~GrpcNotfiyDistributeWorker() override; + Status Register(const std::vector &worker_specs) override; + Status Unregister() override; + + private: + std::string distributed_worker_ip_; + uint32_t distributed_worker_port_; + std::string host_ip_; + uint32_t host_port_; + std::string agent_address_; + std::string distributed_worker_address_; + + std::unique_ptr stub_; + std::atomic is_stoped_{false}; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_NOTIFY_WORKER_H diff --git a/mindspore_serving/proto/ms_agent.proto b/mindspore_serving/proto/ms_agent.proto new file mode 100644 index 0000000..eb44638 --- /dev/null +++ b/mindspore_serving/proto/ms_agent.proto @@ -0,0 +1,27 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+// ms_agent.proto
+syntax = "proto3";
+
+package mindspore.serving.proto;
+import "mindspore_serving/proto/ms_service.proto";
+import "mindspore_serving/proto/ms_master.proto";
+
+service MSAgent {
+  rpc Predict(PredictRequest) returns (PredictReply) {}
+  rpc Exit(ExitRequest) returns (ExitReply) {}
+}
diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto
new file mode 100644
index 0000000..a53f126
--- /dev/null
+++ b/mindspore_serving/proto/ms_distributed.proto
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +// ms_manager.proto +syntax = "proto3"; + +package mindspore.serving.proto; +import "mindspore_serving/proto/ms_service.proto"; +import "mindspore_serving/proto/ms_master.proto"; + +service MSDistributedWorker { + rpc Predict(PredictRequest) returns (PredictReply) {} + rpc Exit(ExitRequest) returns (ExitReply) {} + rpc Register(RegisterRequest) returns (RegisterReply) {} +} \ No newline at end of file From 88e909071c519ed82293e907a4edb8127468e9e1 Mon Sep 17 00:00:00 2001 From: zhangyinxia Date: Wed, 27 Jan 2021 17:18:11 +0800 Subject: [PATCH 04/10] add register and exit message --- .../agent_process/agent_process.cc | 13 ++++---- .../agent_process/agent_process.h | 13 ++++---- .../distributed_process.cc | 31 ++++++++++++----- .../distributed_process/distributed_process.h | 24 +++++++------- .../distributed_servable.cc | 16 +++++++-- .../notify_agent/notify_agent.cc | 26 +++++++++++++-- .../notify_agent/notify_agent.h | 2 +- .../notify_distributed/notify_worker.cc | 12 +++---- mindspore_serving/proto/ms_agent.proto | 33 +++++++++++++++++-- mindspore_serving/proto/ms_distributed.proto | 27 ++++++++++++--- 10 files changed, 147 insertions(+), 50 deletions(-) diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc index 474fe3a..37f6bff 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -18,17 +18,18 @@ namespace mindspore { namespace serving { -grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::ExitRequest *request, - proto::ExitReply *reply) { +grpc::Status MSAgentImpl::DistributedExit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) { MSI_LOG(INFO) << "Distributed Worker Exit"; - // to do : need WorkerAgent support stop 
funcition + // WorkerAgent::GetInstance().StopServable(false); return grpc::Status::OK; } -grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply) { +grpc::Status MSAgentImpl::DistributedPredict(grpc::ServerContext *context, + const proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) { MSI_LOG(INFO) << "Begin call service Eval"; - // to do : need WorkerAgent support run funcition + // WorkerAgent::GetInstance().Run(*request, reply); return grpc::Status::OK; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h index b04affe..7ea69ab 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h @@ -21,18 +21,19 @@ #include #include #include "common/serving_common.h" -#include "proto/ms_worker.pb.h" -#include "proto/ms_worker.grpc.pb.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" namespace mindspore { namespace serving { // Service Implement -class MSAgentImpl final : public proto::MSWorker::Service { +class MSAgentImpl final : public proto::MSAgent::Service { public: - grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply) override; - grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; + grpc::Status DistributedPredict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) override; + grpc::Status DistributedExit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) override; }; } // namespace serving diff --git 
a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc index 11b86d3..8fa3fe6 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -19,18 +19,31 @@ namespace mindspore { namespace serving { -grpc::Status MSDistributedImpl::Register(grpc::ServerContext *context, const proto::RegisterRequest *request, - proto::RegisterReply *reply) { +grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, + proto::AgentRegisterReply *reply) { + MSI_EXCEPTION_IF_NULL(request); + MSI_EXCEPTION_IF_NULL(reply); + WorkerAgentSpec agent_spec; + // todo request->agent_spec + Status status(FAILED); + status = servable_->RegisterAgent(agent_spec); + if (status != SUCCESS) { + MSI_LOG(ERROR) << "Agent Register FAILED"; + } return grpc::Status::OK; } -grpc::Status MSDistributedImpl::Predict(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply) { - return grpc::Status::OK; -} - -grpc::Status MSDistributedImpl::Exit(grpc::ServerContext *context, const proto::ExitRequest *request, - proto::ExitReply *reply) { +grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, + proto::AgentExitReply *reply) { + MSI_EXCEPTION_IF_NULL(request); + MSI_EXCEPTION_IF_NULL(reply); + WorkerAgentSpec agent_spec; + // todo request->agent_spec + Status status(FAILED); + status = servable_->UnregisterAgent(agent_spec); + if (status != SUCCESS) { + MSI_LOG(ERROR) << "Agent Exit FAILED"; + } return grpc::Status::OK; } } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h 
b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h index 0ba4f08..3ef02b2 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h @@ -20,27 +20,29 @@ #include #include #include +#include #include "common/serving_common.h" -#include "proto/ms_worker.pb.h" -#include "proto/ms_worker.grpc.pb.h" #include "proto/ms_service.pb.h" #include "proto/ms_service.grpc.pb.h" -#include "proto/ms_master.pb.h" -#include "proto/ms_master.grpc.pb.h" +#include "proto/ms_distributed.pb.h" +#include "proto/ms_distributed.grpc.pb.h" +#include "worker/distributed_worker/distributed_servable.h" namespace mindspore { namespace serving { // Service Implement -class MSDistributedImpl final : public proto::MSMaster::Service, public proto::MSWorker::Service { +class MSDistributedImpl final : public proto::MSDistributedWorker::Service { public: - MSDistributedImpl() {} + explicit MSDistributedImpl(std::shared_ptr servable) : servable_(servable) {} ~MSDistributedImpl() = default; - grpc::Status Register(grpc::ServerContext *context, const proto::RegisterRequest *request, - proto::RegisterReply *reply) override; - grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, - proto::PredictReply *reply) override; - grpc::Status Exit(grpc::ServerContext *context, const proto::ExitRequest *request, proto::ExitReply *reply) override; + grpc::Status AgentRegister(grpc::ServerContext *context, const proto::AgentRegisterRequest *request, + proto::AgentRegisterReply *reply) override; + grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, + proto::AgentExitReply *reply) override; + + private: + std::shared_ptr servable_; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc 
b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index 9327d24..ee76d5f 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -28,8 +28,20 @@ std::vector DistributedServable::GetInputInfos() const { return std: std::vector DistributedServable::GetOutputInfos() const { return std::vector(); } uint64_t DistributedServable::GetBatchSize() const { return 0; } Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) { return Status(); } -Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { return Status(); } -Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { return Status(); } +Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { + agent_spec_list_[agent_spec.rank_id] = agent_spec; + return Status(); +} +Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { + for (auto iter = agent_spec_list_.begin(); iter != agent_spec_list_.end();) { + if (agent_spec.rank_id == iter->second.rank_id) { + iter = agent_spec_list_.erase(iter); + } else { + ++iter; + } + } + return Status(); +} Status DistributedServable::SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, const std::vector &without_batch_dim_inputs) { return Status(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc index cbc3e50..2c810d2 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc @@ -24,14 +24,36 @@ namespace mindspore { namespace serving { -GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string &worker_address) {} +GrpcNotfiyAgent::GrpcNotfiyAgent(const std::string 
&agent_address) { + agent_address_ = agent_address; + std::shared_ptr channel = GrpcServer::CreateChannel(agent_address_); + stub_ = proto::MSAgent::NewStub(channel); +} GrpcNotfiyAgent::~GrpcNotfiyAgent() = default; -Status GrpcNotfiyAgent::Exit() { return SUCCESS; } +Status GrpcNotfiyAgent::Exit() { + if (stub_) { + proto::DistributedExitRequest request; + request.set_address(agent_address_); + proto::DistributedExitReply reply; + grpc::ClientContext context; + const int32_t TIME_OUT = 1; + std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); + context.set_deadline(deadline); + + (void)stub_->DistributedExit(&context, request, &reply); + } + return SUCCESS; +} Status GrpcNotfiyAgent::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, DistributeCallback callback) { + if (!stub_) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_; + } + // todo send async message return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h index 974c54a..cf984f1 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h @@ -38,7 +38,7 @@ class MS_API GrpcNotfiyAgent : public BaseNotifyAgent { DistributeCallback callback) override; private: - std::string worker_address_; + std::string agent_address_; std::shared_ptr stub_ = nullptr; }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index d0014d5..50d3f38 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ 
b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -45,15 +45,15 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector & auto loop = REGISTER_TIME_OUT; while (loop-- && !ExitSignalHandle::Instance().HasStopped()) { MSI_LOG(INFO) << "Register to " << distributed_worker_address_; - proto::RegisterRequest request; + proto::AgentRegisterRequest request; request.set_address(agent_address_); // to do set RegisterRequest message - proto::RegisterReply reply; + proto::AgentRegisterReply reply; grpc::ClientContext context; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(REGISTER_INTERVAL); context.set_deadline(deadline); - grpc::Status status = stub_->Register(&context, request, &reply); + grpc::Status status = stub_->AgentRegister(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Register SUCCESS "; return SUCCESS; @@ -72,14 +72,14 @@ Status GrpcNotfiyDistributeWorker::Unregister() { return SUCCESS; } is_stoped_ = true; - proto::ExitRequest request; + proto::AgentExitRequest request; request.set_address(agent_address_); - proto::ExitReply reply; + proto::AgentExitReply reply; grpc::ClientContext context; const int32_t TIME_OUT = 1; std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); context.set_deadline(deadline); - grpc::Status status = stub_->Exit(&context, request, &reply); + grpc::Status status = stub_->AgentExit(&context, request, &reply); if (status.ok()) { MSI_LOG(INFO) << "Exit SUCCESS "; return SUCCESS; diff --git a/mindspore_serving/proto/ms_agent.proto b/mindspore_serving/proto/ms_agent.proto index eb44638..4428810 100644 --- a/mindspore_serving/proto/ms_agent.proto +++ b/mindspore_serving/proto/ms_agent.proto @@ -19,9 +19,36 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; -import 
"mindspore_serving/proto/ms_master.proto"; + +message DistributedServableSpec { + // servable name + string name = 1; + + // optional. If unspecified, the latest version servable will be used. + int64 version_number = 3; + + // Specifies the method name in the servable. + string method_name = 2; +} + +message DistributedPredictRequest { + DistributedServableSpec servable_spec = 1; +} + +message DistributedPredictReply { + DistributedServableSpec servable_spec = 1; + repeated ErrorMsg error_msg = 2; +} + +message DistributedExitRequest { + string address = 1; +} + +message DistributedExitReply { + ErrorMsg error_msg = 1; +} service MSAgent { - rpc Predict(PredictRequest) returns (PredictReply) {} - rpc Exit(ExitRequest) returns (ExitReply) {} + rpc DistributedPredict(DistributedPredictRequest) returns (DistributedPredictReply) {} + rpc DistributedExit(DistributedExitRequest) returns (DistributedExitReply) {} } diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index a53f126..936b245 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -19,10 +19,29 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; -import "mindspore_serving/proto/ms_master.proto"; +message AgentSpec { + string name = 1; + int64 version_number = 2; +} + +message AgentRegisterRequest { + repeated AgentSpec worker_spec = 1; + string address = 2; +} + +message AgentRegisterReply { + ErrorMsg error_msg = 1; +} + +message AgentExitRequest { + string address = 1; +} + +message AgentExitReply { + ErrorMsg error_msg = 1; +} service MSDistributedWorker { - rpc Predict(PredictRequest) returns (PredictReply) {} - rpc Exit(ExitRequest) returns (ExitReply) {} - rpc Register(RegisterRequest) returns (RegisterReply) {} + rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} + rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) 
{} } \ No newline at end of file From 3be4b90098c9f62128f703f20f8757e77572fe48 Mon Sep 17 00:00:00 2001 From: zhangyinxia Date: Thu, 28 Jan 2021 14:02:08 +0800 Subject: [PATCH 05/10] add message --- .../ccsrc/common/proto_tensor.cc | 49 +++++++++++++++++++ mindspore_serving/ccsrc/common/proto_tensor.h | 4 ++ mindspore_serving/ccsrc/common/servable.h | 9 ++++ .../agent_process/agent_process.cc | 5 +- .../ccsrc/worker/distributed_worker/common.h | 9 ---- .../distributed_process.cc | 31 +++++++----- .../distributed_servable.cc | 36 ++++++++++++-- .../distributed_worker/distributed_servable.h | 10 +++- .../notify_distributed/notify_worker.cc | 4 +- .../notify_distributed/notify_worker.h | 1 - .../worker/distributed_worker/worker_agent.cc | 5 ++ .../worker/distributed_worker/worker_agent.h | 3 ++ mindspore_serving/proto/ms_agent.proto | 6 +-- mindspore_serving/proto/ms_distributed.proto | 13 +++-- 14 files changed, 146 insertions(+), 39 deletions(-) diff --git a/mindspore_serving/ccsrc/common/proto_tensor.cc b/mindspore_serving/ccsrc/common/proto_tensor.cc index 2b8fd9e..f97fe9c 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.cc +++ b/mindspore_serving/ccsrc/common/proto_tensor.cc @@ -341,6 +341,55 @@ Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::Predict return SUCCESS; } +void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) { + worker_specs->rank_id = specs.rank_id(); + worker_specs->batch_size = specs.batch_size(); + worker_specs->input_size = specs.input_size(); + for (auto &in : specs.inputs()) { + TensorInfo info; + info.data_type = ProtoTensor::TransDataType2Inference(in.dtype()); + for (auto &dim : in.shape().dims()) { + info.shape.push_back(dim); + } + worker_specs->input_infos.push_back(info); + } + for (auto &out : specs.outputs()) { + TensorInfo info; + info.data_type = ProtoTensor::TransDataType2Inference(out.dtype()); + for (auto &dim : out.shape().dims()) { + 
info.shape.push_back(dim); + } + worker_specs->output_infos.push_back(info); + } +} + +void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vector &worker_specs, + proto::AgentRegisterRequest *request) { + for (size_t i = 0; i < worker_specs.size(); i++) { + auto &spec = worker_specs[i]; + auto worker_spec = request->add_agent_spec(); + worker_spec->set_rank_id(spec.rank_id); + worker_spec->set_batch_size(spec.batch_size); + worker_spec->set_input_size(spec.input_size); + for (auto &method : spec.input_infos) { + auto proto_method = worker_spec->add_inputs(); + proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + auto proto_shape = proto_method->mutable_shape(); + for (auto &dim : method.shape) { + proto_shape->add_dims(dim); + } + } + for (auto &method : spec.output_infos) { + auto proto_method = worker_spec->add_outputs(); + proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + auto proto_shape = proto_method->mutable_shape(); + for (auto &dim : method.shape) { + proto_shape->add_dims(dim); + } + } + } +} + Status GrpcTensorHelper::CheckRequestTensor(const proto::Tensor &tensor) { Status status; ProtoTensor tensor_input(const_cast(&tensor)); diff --git a/mindspore_serving/ccsrc/common/proto_tensor.h b/mindspore_serving/ccsrc/common/proto_tensor.h index 554da02..d80b982 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.h +++ b/mindspore_serving/ccsrc/common/proto_tensor.h @@ -24,6 +24,7 @@ #include "common/serving_common.h" #include "proto/ms_service.pb.h" #include "proto/ms_master.pb.h" +#include "proto/ms_distributed.pb.h" #include "common/instance.h" #include "common/servable.h" @@ -68,6 +69,9 @@ class MS_API GrpcTensorHelper { std::vector *results); static Status CreateReplyFromInstances(const proto::PredictRequest &request, const std::vector &inputs, proto::PredictReply *reply); + static void CopyFromAgentSpec(const proto::AgentSpec &request, WorkerAgentSpec *worker_specs); + static void 
CopyFromWorkerAgentSpec(const std::vector &worker_specs, + proto::AgentRegisterRequest *request); private: static Status CreateInstanceFromRequestInstances(const proto::PredictRequest &request, diff --git a/mindspore_serving/ccsrc/common/servable.h b/mindspore_serving/ccsrc/common/servable.h index 5402bc0..01dd977 100644 --- a/mindspore_serving/ccsrc/common/servable.h +++ b/mindspore_serving/ccsrc/common/servable.h @@ -144,6 +144,15 @@ static inline LogStream &operator<<(LogStream &stream, PredictPhaseTag data_type return stream; } +struct WorkerAgentSpec { + std::string agent_address; + uint32_t rank_id = 0; + std::vector input_infos; + std::vector output_infos; + uint32_t batch_size = 0; + uint32_t input_size = 0; +}; + } // namespace mindspore::serving #endif // MINDSPORE_SERVING_SERVABLE_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc index 37f6bff..97a7a77 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -15,13 +15,14 @@ */ #include "worker/distributed_worker/agent_process/agent_process.h" +#include "worker/distributed_worker/worker_agent.h" namespace mindspore { namespace serving { grpc::Status MSAgentImpl::DistributedExit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, proto::DistributedExitReply *reply) { MSI_LOG(INFO) << "Distributed Worker Exit"; - // WorkerAgent::GetInstance().StopServable(false); + WorkerAgent::Instance().Clear(); return grpc::Status::OK; } @@ -29,7 +30,7 @@ grpc::Status MSAgentImpl::DistributedPredict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, proto::DistributedPredictReply *reply) { MSI_LOG(INFO) << "Begin call service Eval"; - // WorkerAgent::GetInstance().Run(*request, reply); + 
WorkerAgent::Instance().Run(*request, reply); return grpc::Status::OK; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/common.h b/mindspore_serving/ccsrc/worker/distributed_worker/common.h index 801894a..2bda8ec 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/common.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/common.h @@ -46,15 +46,6 @@ struct DistributedServableConfig { DistributedServableCommonConfig common_config; }; -struct WorkerAgentSpec { - std::string ip; - uint32_t port = 0; - uint32_t rank_id = 0; - std::vector input_infos; - std::vector output_infos; - uint32_t batch_size = 0; -}; - struct AgentStartUpConfig { uint32_t rank_id; uint32_t device_id; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc index 8fa3fe6..72442cc 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -15,6 +15,7 @@ */ #include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "common/proto_tensor.h" namespace mindspore { namespace serving { @@ -23,12 +24,15 @@ grpc::Status MSDistributedImpl::AgentRegister(grpc::ServerContext *context, cons proto::AgentRegisterReply *reply) { MSI_EXCEPTION_IF_NULL(request); MSI_EXCEPTION_IF_NULL(reply); - WorkerAgentSpec agent_spec; - // todo request->agent_spec - Status status(FAILED); - status = servable_->RegisterAgent(agent_spec); - if (status != SUCCESS) { - MSI_LOG(ERROR) << "Agent Register FAILED"; + for (auto &spec : request->agent_spec()) { + WorkerAgentSpec agent_spec; + agent_spec.agent_address = request->address(); + GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); + Status status(FAILED); + status = servable_->RegisterAgent(agent_spec); + if (status != 
SUCCESS) { + MSI_LOG(ERROR) << "Agent Register FAILED"; + } } return grpc::Status::OK; } @@ -37,12 +41,15 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr proto::AgentExitReply *reply) { MSI_EXCEPTION_IF_NULL(request); MSI_EXCEPTION_IF_NULL(reply); - WorkerAgentSpec agent_spec; - // todo request->agent_spec - Status status(FAILED); - status = servable_->UnregisterAgent(agent_spec); - if (status != SUCCESS) { - MSI_LOG(ERROR) << "Agent Exit FAILED"; + for (auto &spec : request->agent_spec()) { + WorkerAgentSpec agent_spec; + agent_spec.agent_address = request->address(); + GrpcTensorHelper::CopyFromAgentSpec(spec, &agent_spec); + Status status(FAILED); + status = servable_->UnregisterAgent(agent_spec); + if (status != SUCCESS) { + MSI_LOG(ERROR) << "Agent Exit FAILED"; + } } return grpc::Status::OK; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index ee76d5f..9bc44f1 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -17,7 +17,8 @@ #include "worker/distributed_worker/distributed_servable.h" #include #include - +#include "worker/worker.h" +#include "worker/distributed_worker/notify_agent/notify_agent.h" namespace mindspore { namespace serving { @@ -29,19 +30,46 @@ std::vector DistributedServable::GetOutputInfos() const { return std uint64_t DistributedServable::GetBatchSize() const { return 0; } Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) { return Status(); } Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { - agent_spec_list_[agent_spec.rank_id] = agent_spec; - return Status(); + DistributedAgentContext context; + auto it = agent_spec_list_.find(agent_spec.rank_id); + if (it != agent_spec_list_.end()) { + MSI_LOG_WARNING 
<< "rank_id " << agent_spec.rank_id << " has been registered"; + return SUCCESS; + } + context.agent_spec_ = agent_spec; + std::shared_ptr notify_agent = std::make_shared(agent_spec.agent_address); + context.notify_agent_ = notify_agent; + agent_spec_list_[agent_spec.rank_id] = context; + if (config_.rank_size == agent_spec_list_.size()) { + Status status = Worker::GetInstance().RegisterWorker(); + if (status != SUCCESS) { + Clear(); + return FAILED; + } + } + return SUCCESS; } + +void DistributedServable::Clear() { + for (auto agent : agent_spec_list_) { + agent.second.notify_agent_->Exit(); + } + Worker::GetInstance().StopServable(false); +} + Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { for (auto iter = agent_spec_list_.begin(); iter != agent_spec_list_.end();) { - if (agent_spec.rank_id == iter->second.rank_id) { + if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { iter = agent_spec_list_.erase(iter); } else { ++iter; } } + // todo: send exit message to agent, and then exit if split with master + Clear(); return Status(); } + Status DistributedServable::SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, const std::vector &without_batch_dim_inputs) { return Status(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h index 2808943..507a283 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -20,12 +20,19 @@ #include #include #include +#include #include "worker/sevable_base.h" #include "worker/distributed_worker/common.h" +#include "worker/distributed_worker/notify_agent/base_notify_agent.h" namespace mindspore { namespace serving { +struct DistributedAgentContext { + WorkerAgentSpec agent_spec_; + std::shared_ptr notify_agent_ = nullptr; +}; + class MS_API 
DistributedServable : public ServableBase { public: // from python, servable_config.py @@ -47,10 +54,11 @@ class MS_API DistributedServable : public ServableBase { std::vector GetInputInfos() const override; std::vector GetOutputInfos() const override; uint64_t GetBatchSize() const override; + void Clear(); private: DistributedServableConfig config_; - std::map agent_spec_list_; + std::map agent_spec_list_; // agent stubs }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index 50d3f38..230e225 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -20,6 +20,7 @@ #include #include "common/exit_handle.h" #include "common/grpc_server.h" +#include "common/proto_tensor.h" namespace mindspore { namespace serving { @@ -46,8 +47,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector & while (loop-- && !ExitSignalHandle::Instance().HasStopped()) { MSI_LOG(INFO) << "Register to " << distributed_worker_address_; proto::AgentRegisterRequest request; - request.set_address(agent_address_); - // to do set RegisterRequest message + GrpcTensorHelper::CopyFromWorkerAgentSpec(worker_specs, &request); proto::AgentRegisterReply reply; grpc::ClientContext context; std::chrono::system_clock::time_point deadline = diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index e698c56..d618878 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -42,7 +42,6 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { 
uint32_t host_port_; std::string agent_address_; std::string distributed_worker_address_; - std::unique_ptr stub_; std::atomic is_stoped_{false}; }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc index 4e21583..2e497e0 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -34,5 +34,10 @@ Status WorkerAgent::ExecuteModel(const std::vector &request, std: return executor_.ExecuteModel(request, reply); } +Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { + // todo :call ExecuteModel + return SUCCESS; +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h index 9119e63..520c4db 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h @@ -18,6 +18,8 @@ #define MINDSPORE_SERVING_WORKER_AGENT_H #include #include "worker/distributed_worker/agent_executor.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" namespace mindspore { namespace serving { @@ -28,6 +30,7 @@ class MS_API WorkerAgent { Status Clear(); Status ExecuteModel(const std::vector &request, std::vector *reply); + Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply); private: AgentStartUpConfig config_; diff --git a/mindspore_serving/proto/ms_agent.proto b/mindspore_serving/proto/ms_agent.proto index 4428810..143f0de 100644 --- a/mindspore_serving/proto/ms_agent.proto +++ b/mindspore_serving/proto/ms_agent.proto @@ -23,12 +23,10 @@ import "mindspore_serving/proto/ms_service.proto"; message DistributedServableSpec { // servable name string name = 1; - // 
optional. If unspecified, the latest version servable will be used. - int64 version_number = 3; - + int64 version_number = 2; // Specifies the method name in the servable. - string method_name = 2; + string method_name = 3; } message DistributedPredictRequest { diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index 936b245..fc48ead 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -21,12 +21,15 @@ package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; message AgentSpec { - string name = 1; - int64 version_number = 2; + int64 rank_id = 1; + int64 batch_size = 2; + int64 input_size = 3; + repeated Tensor inputs =4; + repeated Tensor outputs = 5; } message AgentRegisterRequest { - repeated AgentSpec worker_spec = 1; + repeated AgentSpec agent_spec = 1; string address = 2; } @@ -35,12 +38,14 @@ message AgentRegisterReply { } message AgentExitRequest { - string address = 1; + repeated AgentSpec agent_spec = 1; + string address = 2; } message AgentExitReply { ErrorMsg error_msg = 1; } + service MSDistributedWorker { rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} From 47851e0081b0feeb082bec8c7d5198dd468e1733 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Thu, 28 Jan 2021 19:31:02 +0800 Subject: [PATCH 06/10] Serving, python gpt3 --- mindspore_serving/ccsrc/common/servable.cc | 340 +++++++++++------- mindspore_serving/ccsrc/common/servable.h | 39 +- mindspore_serving/ccsrc/python/serving_py.cc | 63 +++- .../ccsrc/python/worker/servable_py.cc | 11 +- .../ccsrc/python/worker/servable_py.h | 1 + .../ccsrc/python/worker/worker_py.cc | 57 ++- .../ccsrc/python/worker/worker_py.h | 9 + .../worker/distributed_worker/agent_startup.h | 2 +- .../ccsrc/worker/distributed_worker/common.h | 17 +- .../distributed_servable.cc | 255 ++++++++++++- 
.../distributed_worker/distributed_servable.h | 29 +- .../local_sevable.cc} | 60 ++-- .../local_sevable.h} | 15 +- mindspore_serving/ccsrc/worker/sevable_base.h | 3 +- .../ccsrc/worker/work_executor.cc | 22 +- mindspore_serving/ccsrc/worker/worker.cc | 17 +- mindspore_serving/proto/ms_distributed.proto | 4 +- .../worker/distributed/distributed_worker.py | 10 + .../worker/distributed/register.py | 29 +- mindspore_serving/worker/register/method.py | 26 +- mindspore_serving/worker/register/servable.py | 33 +- tests/ut/cpp/common/test_servable_common.h | 12 +- 22 files changed, 766 insertions(+), 288 deletions(-) rename mindspore_serving/ccsrc/worker/{ascend_servable/ascend_sevable.cc => local_servable/local_sevable.cc} (79%) rename mindspore_serving/ccsrc/worker/{ascend_servable/ascend_sevable.h => local_servable/local_sevable.h} (86%) diff --git a/mindspore_serving/ccsrc/common/servable.cc b/mindspore_serving/ccsrc/common/servable.cc index d36bacf..9f90fbf 100644 --- a/mindspore_serving/ccsrc/common/servable.cc +++ b/mindspore_serving/ccsrc/common/servable.cc @@ -25,11 +25,23 @@ namespace mindspore::serving { std::string ServableMeta::Repr() const { std::ostringstream stream; - stream << "path(" << servable_name << ") file(" << servable_file + ")"; + switch (servable_type) { + case kServableTypeUnknown: + stream << "undeclared servable, servable name: '" << common_meta.servable_name << "'"; + break; + case kServableTypeLocal: + stream << "local servable, servable name: '" << common_meta.servable_name << "', file: '" + << local_meta.servable_file + "'"; + break; + case kServableTypeDistributed: + stream << "distributed servable, servable name: '" << common_meta.servable_name + << "', rank size: " << distributed_meta.rank_size << ", stage size " << distributed_meta.stage_size; + break; + } return stream.str(); } -void ServableMeta::SetModelFormat(const std::string &format) { +void LocalServableMeta::SetModelFormat(const std::string &format) { if (format == "om") { 
model_format = kOM; } else if (format == "mindir") { @@ -63,142 +75,181 @@ std::string RequestSpec::Repr() const { return "servable(" + servable_name + ") " + "method(" + method_name + ") " + version; } -Status ServableSignature::Check() const { - std::set method_set; +Status ServableSignature::CheckPreprocessInput(const MethodSignature &method, size_t *preprocess_outputs_count) const { std::string model_str = servable_meta.Repr(); + const auto &preprocess_name = method.preprocess_name; + if (!preprocess_name.empty()) { + auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); + if (preprocess == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name + << " preprocess " << preprocess_name << " not defined"; + } + *preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); - for (auto &method : methods) { - if (method_set.count(method.method_name) > 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " " << method.method_name << " has been defined repeatly"; + for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { + auto &input = method.preprocess_inputs[i]; + if (input.first != kPredictPhaseTag_Input) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i + << "th input cannot not come from '" << input.first << "'"; + } + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i + << "th input uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); + } } - method_set.emplace(method.method_name); + } + return SUCCESS; +} - size_t preprocess_outputs_count = 0; - size_t postprocess_outputs_count = 0; +Status ServableSignature::CheckPredictInput(const MethodSignature &method, size_t 
preprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); - const auto &preprocess_name = method.preprocess_name; - if (!preprocess_name.empty()) { - auto preprocess = PreprocessStorage::Instance().GetPreprocess(preprocess_name); - if (preprocess == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name - << " preprocess " << preprocess_name << " not defined"; + for (size_t i = 0; i < method.servable_inputs.size(); i++) { + auto &input = method.servable_inputs[i]; + if (input.first == kPredictPhaseTag_Input) { + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "th input uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); } - preprocess_outputs_count = preprocess->GetOutputsCount(preprocess_name); - - for (size_t i = 0; i < method.preprocess_inputs.size(); i++) { - auto &input = method.preprocess_inputs[i]; - if (input.first != kPredictPhaseTag_Input) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of preprocess " << i - << "th input cannot not come from '" << input.first << "'"; - } - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the preprocess " << i - << "th input uses method " << input.second << "th input, that is greater than the method inputs size " - << method.inputs.size(); - } + } else if (input.first == kPredictPhaseTag_Preproces) { + if (input.second >= preprocess_outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "th input uses preprocess " << input.second + << "th output, that is greater than the preprocess 
outputs size " << preprocess_outputs_count; } + } else { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i + << "th input cannot not come from '" << input.first << "'"; + } + } + return SUCCESS; +} + +Status ServableSignature::CheckPostprocessInput(const MethodSignature &method, size_t preprocess_outputs_count, + size_t *postprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); + const auto &common_meta = servable_meta.common_meta; + + const auto &postprocess_name = method.postprocess_name; + if (!method.postprocess_name.empty()) { + auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); + if (postprocess == nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name + << " postprocess " << postprocess_name << " not defined"; } + *postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); - for (size_t i = 0; i < method.servable_inputs.size(); i++) { - auto &input = method.servable_inputs[i]; + for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { + auto &input = method.postprocess_inputs[i]; if (input.first == kPredictPhaseTag_Input) { if (input.second >= method.inputs.size()) { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th input uses method " << input.second << "th input, that is greater than the method inputs size " << method.inputs.size(); } } else if (input.first == kPredictPhaseTag_Preproces) { if (input.second >= preprocess_outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the servable " << i + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i << "th 
input uses preprocess " << input.second << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; } + } else if (input.first == kPredictPhaseTag_Predict) { + if (input.second >= common_meta.outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i + << "th input uses servable " << input.second + << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; + } } else { return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of servable " << i + << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i << "th input cannot not come from '" << input.first << "'"; } } + } + return SUCCESS; +} - const auto &postprocess_name = method.postprocess_name; - if (!method.postprocess_name.empty()) { - auto postprocess = PostprocessStorage::Instance().GetPostprocess(postprocess_name); - if (postprocess == nullptr) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Model " << model_str << " method " << method.method_name - << " postprocess " << postprocess_name << " not defined"; - } - postprocess_outputs_count = postprocess->GetOutputsCount(postprocess_name); +Status ServableSignature::CheckReturn(const MethodSignature &method, size_t preprocess_outputs_count, + size_t postprocess_outputs_count) const { + std::string model_str = servable_meta.Repr(); + const auto &common_meta = servable_meta.common_meta; - for (size_t i = 0; i < method.postprocess_inputs.size(); i++) { - auto &input = method.postprocess_inputs[i]; - if (input.first == kPredictPhaseTag_Input) { - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses method " << input.second - << "th input, that is greater than the method 
inputs size " << method.inputs.size(); - } - } else if (input.first == kPredictPhaseTag_Preproces) { - if (input.second >= preprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses preprocess " << input.second - << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; - } - } else if (input.first == kPredictPhaseTag_Predict) { - if (input.second >= servable_meta.outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the postprocess " << i - << "th input uses servable " << input.second - << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; - } - } else { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of postprocess " << i - << "th input cannot not come from '" << input.first << "'"; - } + for (size_t i = 0; i < method.returns.size(); i++) { + auto &input = method.returns[i]; + if (input.first == kPredictPhaseTag_Input) { + if (input.second >= method.inputs.size()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses method " << input.second << "th input, that is greater than the method inputs size " + << method.inputs.size(); } - } - for (size_t i = 0; i < method.returns.size(); i++) { - auto &input = method.returns[i]; - if (input.first == kPredictPhaseTag_Input) { - if (input.second >= method.inputs.size()) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses method " << input.second << "th input, that is greater than the method inputs size " - << method.inputs.size(); - } - } else if (input.first == kPredictPhaseTag_Preproces) { - if 
(input.second >= preprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses preprocess " << input.second - << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; - } - } else if (input.first == kPredictPhaseTag_Predict) { - if (input.second >= servable_meta.outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses servable " << input.second - << "th output, that is greater than the servable outputs size " << servable_meta.outputs_count; - } - } else if (input.first == kPredictPhaseTag_Postprocess) { - if (input.second >= postprocess_outputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the method " << i - << "th output uses postprocess " << input.second - << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; - } - } else { + } else if (input.first == kPredictPhaseTag_Preproces) { + if (input.second >= preprocess_outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses preprocess " << input.second + << "th output, that is greater than the preprocess outputs size " << preprocess_outputs_count; + } + } else if (input.first == kPredictPhaseTag_Predict) { + if (input.second >= common_meta.outputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses servable " << input.second + << "th output, that is greater than the servable outputs size " << common_meta.outputs_count; + } + } else if (input.first == kPredictPhaseTag_Postprocess) { + if (input.second >= postprocess_outputs_count) { return 
INFER_STATUS_LOG_ERROR(FAILED) - << "Model " << model_str << " method " << method.method_name << ", the data of method " << i - << "th output cannot not come from '" << input.first << "'"; + << "Model " << model_str << " method " << method.method_name << ", the method " << i + << "th output uses postprocess " << input.second + << "th output, that is greater than the postprocess outputs size " << postprocess_outputs_count; } + } else { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << model_str << " method " << method.method_name << ", the data of method " << i + << "th output cannot not come from '" << input.first << "'"; + } + } + return SUCCESS; +} + +Status ServableSignature::Check() const { + std::set method_set; + Status status; + for (auto &method : methods) { + if (method_set.count(method.method_name) > 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Model " << servable_meta.Repr() << " " << method.method_name << " has been defined repeatedly"; + } + method_set.emplace(method.method_name); + + size_t preprocess_outputs_count = 0; + size_t postprocess_outputs_count = 0; + status = CheckPreprocessInput(method, &preprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckPredictInput(method, preprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckPostprocessInput(method, preprocess_outputs_count, &postprocess_outputs_count); + if (status != SUCCESS) { + return status; + } + status = CheckReturn(method, preprocess_outputs_count, postprocess_outputs_count); + if (status != SUCCESS) { + return status; } } return SUCCESS; @@ -216,7 +267,7 @@ bool ServableSignature::GetMethodDeclare(const std::string &method_name, MethodS } void ServableStorage::Register(const ServableSignature &def) { - auto model_name = def.servable_meta.servable_name; + auto model_name = def.servable_meta.common_meta.servable_name; if (servable_signatures_map_.find(model_name) == servable_signatures_map_.end()) { 
MSI_LOG_WARNING << "Servable " << model_name << " has already been defined"; } @@ -258,16 +309,60 @@ Status ServableStorage::RegisterMethod(const MethodSignature &method) { return SUCCESS; } -void ServableStorage::DeclareServable(const mindspore::serving::ServableMeta &servable) { - MSI_LOG_INFO << "Declare servable " << servable.servable_name; - auto it = servable_signatures_map_.find(servable.servable_name); +Status ServableStorage::DeclareServable(ServableMeta servable) { + auto &common_meta = servable.common_meta; + MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; + servable.servable_type = kServableTypeLocal; + if (servable.local_meta.servable_file.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty"; + } + if (servable.local_meta.model_format == api::kUnknownType) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare servable " << common_meta.servable_name << " failed, model_format is not inited"; + } + auto it = servable_signatures_map_.find(common_meta.servable_name); if (it == servable_signatures_map_.end()) { ServableSignature signature; signature.servable_meta = servable; - servable_signatures_map_[servable.servable_name] = signature; - return; + servable_signatures_map_[common_meta.servable_name] = signature; + return SUCCESS; } - it->second.servable_meta = servable; + auto &org_servable_meta = it->second.servable_meta; + if (org_servable_meta.servable_type != kServableTypeUnknown) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); + } + org_servable_meta = servable; + return SUCCESS; +} + +Status ServableStorage::DeclareDistributedServable(ServableMeta servable) { + auto &common_meta = servable.common_meta; + MSI_LOG_INFO << "Declare servable " << common_meta.servable_name; + servable.servable_type = kServableTypeDistributed; + if 
(servable.distributed_meta.rank_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare distributed servable " << common_meta.servable_name << " failed, rank_size cannot be 0"; + } + if (servable.distributed_meta.stage_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Declare distributed servable " << common_meta.servable_name << " failed, stage_size cannot be 0"; + } + auto it = servable_signatures_map_.find(common_meta.servable_name); + if (it == servable_signatures_map_.end()) { + ServableSignature signature; + signature.servable_meta = servable; + servable_signatures_map_[common_meta.servable_name] = signature; + return SUCCESS; + } + auto &org_servable_meta = it->second.servable_meta; + if (org_servable_meta.servable_type != kServableTypeUnknown) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable " << common_meta.servable_name << " has already been declared as: " << servable.Repr(); + } + org_servable_meta = servable; + return SUCCESS; } Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, @@ -277,18 +372,19 @@ Status ServableStorage::RegisterInputOutputInfo(const std::string &servable_name return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, cannot find servable " << servable_name; } auto &servable_meta = it->second.servable_meta; - if (servable_meta.inputs_count != 0 && servable_meta.inputs_count != inputs_count) { + auto &common_meta = servable_meta.common_meta; + if (common_meta.inputs_count != 0 && common_meta.inputs_count != inputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, inputs count " << inputs_count << " not match old count " - << servable_meta.inputs_count << ",servable name " << servable_name; + << common_meta.inputs_count << ",servable name " << servable_name; } - if (servable_meta.outputs_count != 0 && servable_meta.outputs_count != outputs_count) { + if (common_meta.outputs_count != 0 && 
common_meta.outputs_count != outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) << "RegisterInputOutputInfo failed, outputs count " << outputs_count << " not match old count " - << servable_meta.outputs_count << ",servable name " << servable_name; + << common_meta.outputs_count << ",servable name " << servable_name; } - servable_meta.inputs_count = inputs_count; - servable_meta.outputs_count = outputs_count; + common_meta.inputs_count = inputs_count; + common_meta.outputs_count = outputs_count; return SUCCESS; } @@ -298,8 +394,8 @@ std::vector ServableStorage::GetInputOutputInfo(const std::string &serva if (it == servable_signatures_map_.end()) { return result; } - result.push_back(it->second.servable_meta.inputs_count); - result.push_back(it->second.servable_meta.outputs_count); + result.push_back(it->second.servable_meta.common_meta.inputs_count); + result.push_back(it->second.servable_meta.common_meta.outputs_count); return result; } diff --git a/mindspore_serving/ccsrc/common/servable.h b/mindspore_serving/ccsrc/common/servable.h index 01dd977..0de2afe 100644 --- a/mindspore_serving/ccsrc/common/servable.h +++ b/mindspore_serving/ccsrc/common/servable.h @@ -81,19 +81,39 @@ struct RequestSpec { std::string Repr() const; }; -struct MS_API ServableMeta { +enum ServableType { + kServableTypeUnknown = 0, + kServableTypeLocal = 1, + kServableTypeDistributed = 2, +}; + +struct CommonServableMeta { std::string servable_name; - std::string servable_file; // file name - ModelType model_format; // OM, MindIR bool with_batch_dim = true; // whether there is batch dim in model's inputs/outputs + std::vector without_batch_dim_inputs; size_t inputs_count = 0; size_t outputs_count = 0; +}; +struct MS_API LocalServableMeta { + std::string servable_file; // file name + ModelType model_format = api::kUnknownType; // OM, MindIR std::map load_options; // Acl options - std::vector without_batch_dim_inputs; + void SetModelFormat(const std::string &format); +}; + +struct 
DistributedServableMeta { + size_t rank_size = 0; + size_t stage_size = 0; +}; + +struct MS_API ServableMeta { + ServableType servable_type = kServableTypeUnknown; + CommonServableMeta common_meta; + LocalServableMeta local_meta; + DistributedServableMeta distributed_meta; std::string Repr() const; - void SetModelFormat(const std::string &format); }; struct ServableSignature { @@ -102,6 +122,12 @@ struct ServableSignature { Status Check() const; bool GetMethodDeclare(const std::string &method_name, MethodSignature *method); + + private: + Status CheckPreprocessInput(const MethodSignature &method, size_t *pre) const; + Status CheckPredictInput(const MethodSignature &method, size_t pre) const; + Status CheckPostprocessInput(const MethodSignature &method, size_t pre, size_t *post) const; + Status CheckReturn(const MethodSignature &method, size_t pre, size_t post) const; }; class MS_API ServableStorage { @@ -111,7 +137,8 @@ class MS_API ServableStorage { bool GetServableDef(const std::string &model_name, ServableSignature *def) const; - void DeclareServable(const ServableMeta &servable); + Status DeclareServable(ServableMeta servable); + Status DeclareDistributedServable(ServableMeta servable); Status RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); std::vector GetInputOutputInfo(const std::string &servable_name) const; diff --git a/mindspore_serving/ccsrc/python/serving_py.cc b/mindspore_serving/ccsrc/python/serving_py.cc index 7de691c..1dac040 100644 --- a/mindspore_serving/ccsrc/python/serving_py.cc +++ b/mindspore_serving/ccsrc/python/serving_py.cc @@ -26,7 +26,8 @@ namespace mindspore::serving { -PYBIND11_MODULE(_mindspore_serving, m) { +void PyRegServable(pybind11::module *m_ptr) { + auto &m = *m_ptr; // avoid as numpy object memory copy in PyTensor::AsPythonData py::class_(m, "Tensor_"); @@ -68,16 +69,30 @@ PYBIND11_MODULE(_mindspore_serving, m) { .def_readwrite("version_number", &RequestSpec::version_number) 
.def_readwrite("method_name", &RequestSpec::method_name); + py::class_(m, "CommonServableMeta_") + .def(py::init<>()) + .def_readwrite("servable_name", &CommonServableMeta::servable_name) + .def_readwrite("inputs_count", &CommonServableMeta::inputs_count) + .def_readwrite("outputs_count", &CommonServableMeta::outputs_count) + .def_readwrite("with_batch_dim", &CommonServableMeta::with_batch_dim) + .def_readwrite("without_batch_dim_inputs", &CommonServableMeta::without_batch_dim_inputs); + + py::class_(m, "LocalServableMeta_") + .def(py::init<>()) + .def_readwrite("servable_file", &LocalServableMeta::servable_file) + .def_readwrite("options", &LocalServableMeta::load_options) + .def("set_model_format", &LocalServableMeta::SetModelFormat); + + py::class_(m, "DistributedServableMeta_") + .def(py::init<>()) + .def_readwrite("rank_size", &DistributedServableMeta::rank_size) + .def_readwrite("stage_size", &DistributedServableMeta::stage_size); + py::class_(m, "ServableMeta_") .def(py::init<>()) - .def_readwrite("servable_name", &ServableMeta::servable_name) - .def_readwrite("inputs_count", &ServableMeta::inputs_count) - .def_readwrite("outputs_count", &ServableMeta::outputs_count) - .def_readwrite("servable_file", &ServableMeta::servable_file) - .def_readwrite("with_batch_dim", &ServableMeta::with_batch_dim) - .def_readwrite("options", &ServableMeta::load_options) - .def_readwrite("without_batch_dim_inputs", &ServableMeta::without_batch_dim_inputs) - .def("set_model_format", &ServableMeta::SetModelFormat); + .def_readwrite("common_meta", &ServableMeta::common_meta) + .def_readwrite("local_meta", &ServableMeta::local_meta) + .def_readwrite("distributed_meta", &ServableMeta::distributed_meta); py::class_(m, "ServableSignature_") .def(py::init<>()) @@ -87,8 +102,22 @@ PYBIND11_MODULE(_mindspore_serving, m) { py::class_(m, "ServableStorage_") .def_static("register_servable_input_output_info", &PyServableStorage::RegisterInputOutputInfo) .def_static("register_method", 
&PyServableStorage::RegisterMethod) - .def_static("declare_servable", &PyServableStorage::DeclareServable); + .def_static("declare_servable", &PyServableStorage::DeclareServable) + .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); +} + +void PyRegMaster(pybind11::module *m_ptr) { + auto &m = *m_ptr; + py::class_>(m, "Master_") + .def_static("start_grpc_server", &PyMaster::StartGrpcServer) + .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) + .def_static("start_restful_server", &PyMaster::StartRestfulServer) + .def_static("wait_and_clear", &PyMaster::WaitAndClear) + .def_static("stop_and_clear", &PyMaster::StopAndClear); +} +void PyRegWorker(pybind11::module *m_ptr) { + auto &m = *m_ptr; py::class_(m, "TaskContext_").def(py::init<>()); py::class_(m, "TaskItem_") @@ -108,6 +137,8 @@ PYBIND11_MODULE(_mindspore_serving, m) { py::class_(m, "Worker_") .def_static("start_servable", &PyWorker::StartServable) .def_static("start_servable_in_master", &PyWorker::StartServableInMaster) + .def_static("start_distributed_servable", &PyWorker::StartDistributedServable) + .def_static("start_distributed_servable_in_master", &PyWorker::StartDistributedServableInMaster) .def_static("get_batch_size", &PyWorker::GetBatchSize) .def_static("wait_and_clear", &PyWorker::WaitAndClear) .def_static("stop_and_clear", PyWorker::StopAndClear) @@ -130,13 +161,13 @@ PYBIND11_MODULE(_mindspore_serving, m) { } }) .def("set_device_id", &ServableContext::SetDeviceId); +} - py::class_>(m, "Master_") - .def_static("start_grpc_server", &PyMaster::StartGrpcServer) - .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) - .def_static("start_restful_server", &PyMaster::StartRestfulServer) - .def_static("wait_and_clear", &PyMaster::WaitAndClear) - .def_static("stop_and_clear", &PyMaster::StopAndClear); +// cppcheck-suppress syntaxError +PYBIND11_MODULE(_mindspore_serving, m) { + PyRegServable(&m); + PyRegMaster(&m); + 
PyRegWorker(&m); (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { Server::Instance().Clear(); diff --git a/mindspore_serving/ccsrc/python/worker/servable_py.cc b/mindspore_serving/ccsrc/python/worker/servable_py.cc index 8fa565e..b320722 100644 --- a/mindspore_serving/ccsrc/python/worker/servable_py.cc +++ b/mindspore_serving/ccsrc/python/worker/servable_py.cc @@ -25,7 +25,16 @@ void PyServableStorage::RegisterMethod(const MethodSignature &method) { } } void PyServableStorage::DeclareServable(const ServableMeta &servable) { - ServableStorage::Instance().DeclareServable(servable); + auto status = ServableStorage::Instance().DeclareServable(servable); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} +void PyServableStorage::DeclareDistributedServable(const ServableMeta &servable) { + auto status = ServableStorage::Instance().DeclareDistributedServable(servable); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } } void PyServableStorage::RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count) { diff --git a/mindspore_serving/ccsrc/python/worker/servable_py.h b/mindspore_serving/ccsrc/python/worker/servable_py.h index af9b26f..759289e 100644 --- a/mindspore_serving/ccsrc/python/worker/servable_py.h +++ b/mindspore_serving/ccsrc/python/worker/servable_py.h @@ -27,6 +27,7 @@ class MS_API PyServableStorage { static void RegisterMethod(const MethodSignature &method); static void DeclareServable(const ServableMeta &servable); + static void DeclareDistributedServable(const ServableMeta &servable); static void RegisterInputOutputInfo(const std::string &servable_name, size_t inputs_count, size_t outputs_count); static void Clear(); diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index b72cb5f..dce0b70 100644 --- 
a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -21,7 +21,8 @@ #include "common/exit_handle.h" #include "worker/notfiy_master/grpc_notify.h" #include "worker/notfiy_master/local_notify.h" -#include "worker/ascend_servable/ascend_sevable.h" +#include "worker/local_servable/local_sevable.h" +#include "worker/distributed_worker/distributed_servable.h" namespace mindspore::serving { @@ -29,7 +30,7 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri const std::string &master_ip, uint32_t master_port, const std::string &host_ip, uint32_t host_port) { auto notify_master = std::make_shared(master_ip, master_port, host_ip, host_port); - auto servable = std::make_shared(); + auto servable = std::make_shared(); auto status = servable->StartServable(model_directory, model_name, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -51,7 +52,7 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri void PyWorker::StartServableInMaster(const std::string &model_directory, const std::string &model_name, uint32_t version_number) { auto notify_master = std::make_shared(); - auto servable = std::make_shared(); + auto servable = std::make_shared(); auto status = servable->StartServable(model_directory, model_name, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -66,6 +67,56 @@ void PyWorker::StartServableInMaster(const std::string &model_directory, const s } } +void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port, + const std::string &master_ip, uint32_t master_port) { + Status status; + status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + if 
(status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); + auto servable = std::make_shared(); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartVersionController(); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port) { + Status status; + status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + auto notify_master = std::make_shared(); + auto servable = std::make_shared(); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartServable(servable, notify_master); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + status = Worker::GetInstance().StartVersionController(); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + TaskItem PyWorker::GetPyTask() { TaskItem item; Worker::GetInstance().GetPyTaskQueueGroup().PopPyTask(&item); diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.h 
b/mindspore_serving/ccsrc/python/worker/worker_py.h index cf595e3..01a53a8 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.h +++ b/mindspore_serving/ccsrc/python/worker/worker_py.h @@ -34,6 +34,15 @@ class MS_API PyWorker { static void StartServableInMaster(const std::string &model_directory, const std::string &model_name, uint32_t version_number); + static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, + uint32_t master_port); + + static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint32_t version_number, + const std::string &worker_ip, uint32_t worker_port); + static int GetBatchSize(); static void WaitAndClear(); static void StopAndClear(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h index 916fd39..5a7c25e 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -33,7 +33,7 @@ class MS_API WorkerAgentStartUp { Status InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, const std::string &group_file_dir, const std::string &group_file_prefix); - Status GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, + Status GetAgentsConfigsFromWorker(const std::string &rank_start, uint32_t agent_start_port, const std::string &worker_ip, uint32_t worker_port); // step2, invoke from python, get current machine agents config Status GetCurrentMachineConfigs(std::vector *configs); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/common.h 
b/mindspore_serving/ccsrc/worker/distributed_worker/common.h index 2bda8ec..c145bcd 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/common.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/common.h @@ -22,6 +22,7 @@ #include #include "common/serving_common.h" #include "worker/inference/inference.h" +#include "common/servable.h" namespace mindspore { namespace serving { @@ -31,19 +32,12 @@ struct OneRankConfig { uint32_t device_id = 0; }; -struct DistributedServableCommonConfig { - bool with_batch_dim; - std::vector without_batch_dim_inputs; -}; - struct DistributedServableConfig { - uint32_t rank_size = 0; - uint32_t stage_size = 0; - const std::string models_dir; - const std::string groups_dir; std::string rank_table_content; std::vector rank_list; - DistributedServableCommonConfig common_config; + + CommonServableMeta common_meta; + DistributedServableMeta distributed_meta; }; struct AgentStartUpConfig { @@ -58,8 +52,7 @@ struct AgentStartUpConfig { std::string worker_ip; uint32_t worker_port; - DistributedServableCommonConfig common_config; - std::map other_options; + CommonServableMeta common_meta; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index 9bc44f1..b504355 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -19,61 +19,284 @@ #include #include "worker/worker.h" #include "worker/distributed_worker/notify_agent/notify_agent.h" +#include "common/exit_handle.h" + namespace mindspore { namespace serving { +std::string DistributedServable::GetServableName() const { return servable_name_; } + +uint64_t DistributedServable::GetServableVersion() const { return version_number_; } + Status DistributedServable::Predict(const std::vector &input, std::vector *output) { + if 
(!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } return Status(); } -std::vector DistributedServable::GetInputInfos() const { return std::vector(); } -std::vector DistributedServable::GetOutputInfos() const { return std::vector(); } -uint64_t DistributedServable::GetBatchSize() const { return 0; } -Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) { return Status(); } +std::vector DistributedServable::GetInputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return input_infos_; +} + +std::vector DistributedServable::GetOutputInfos() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return output_infos_; +} + +uint64_t DistributedServable::GetBatchSize() const { + if (!model_loaded_) { + MSI_LOG_EXCEPTION << "Model has not been loaded"; + } + return batch_size_; +} + +Status DistributedServable::GetDistributedServableConfig(DistributedServableConfig *config) const { + *config = config_; + return SUCCESS; +} + Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { + if (agent_spec.rank_id < config_.distributed_meta.rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; + } DistributedAgentContext context; - auto it = agent_spec_list_.find(agent_spec.rank_id); - if (it != agent_spec_list_.end()) { + auto it = agent_spec_map_.find(agent_spec.rank_id); + if (it != agent_spec_map_.end()) { MSI_LOG_WARNING << "rank_id " << agent_spec.rank_id << " has been registered"; return SUCCESS; } context.agent_spec_ = agent_spec; std::shared_ptr notify_agent = std::make_shared(agent_spec.agent_address); context.notify_agent_ = notify_agent; - agent_spec_list_[agent_spec.rank_id] = context; - if (config_.rank_size == agent_spec_list_.size()) { + agent_spec_map_[agent_spec.rank_id] = context; + if 
(config_.distributed_meta.rank_size == agent_spec_map_.size()) { Status status = Worker::GetInstance().RegisterWorker(); if (status != SUCCESS) { Clear(); return FAILED; } } + if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { + agents_promise_.set_value(); + } return SUCCESS; } void DistributedServable::Clear() { - for (auto agent : agent_spec_list_) { + for (auto agent : agent_spec_map_) { agent.second.notify_agent_->Exit(); } Worker::GetInstance().StopServable(false); } Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { - for (auto iter = agent_spec_list_.begin(); iter != agent_spec_list_.end();) { + for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { - iter = agent_spec_list_.erase(iter); + iter = agent_spec_map_.erase(iter); } else { ++iter; } } // todo: send exit message to agent, and then exit if split with master Clear(); - return Status(); + return SUCCESS; } -Status DistributedServable::SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, - const std::vector &without_batch_dim_inputs) { - return Status(); +Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint64_t version_number) { + if (model_loaded_) { + MSI_LOG_EXCEPTION << "Model has loaded"; + } + version_number_ = version_number; + servable_name_ = servable_name; + rank_table_json_file_ = rank_table_json_file; + ServableSignature signature; + if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; + } + auto &meta = signature.servable_meta; + if (meta.servable_type != kServableTypeDistributed) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Servable '" << servable_name << "' is not registered as distributed servable, " 
<< meta.Repr(); + } + config_.common_meta = meta.common_meta; + config_.distributed_meta = meta.distributed_meta; + + auto status = InitConfigOnStartup(rank_table_json_file_); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Init with rank table on start up failed"; + return status; + } + status = CheckRankConfig(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Check rank config failed"; + return status; + } + status = WaitAgentsReady(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Waiting for ready of agents failed"; + return status; + } + status = CheckAgentsInfosAndInitTensorInfos(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Check agents infos failed"; + return status; + } + model_loaded_ = true; + return SUCCESS; +} + +Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } + +Status DistributedServable::WaitAgentsReady() { + auto future = agents_promise_.get_future(); + const int kWaitMaxHundredMs = 100 * 10; // 100s + int i; + for (i = 0; i < kWaitMaxHundredMs; i++) { // + if (ExitSignalHandle::Instance().HasStopped()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; + } + // waiting for 100ms + if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { + break; + } + } + if (i >= kWaitMaxHundredMs) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Failed to wait for ready of all agents, current agents count: " << agent_spec_map_.size() + << ", rank size: " << config_.distributed_meta.rank_size; + } + return SUCCESS; +} + +Status DistributedServable::CompareTensorInfos(const std::vector &lefts, + const std::vector &rights) { + if (lefts.size() != rights.size()) { + return INFER_STATUS(FAILED) << "Size not match, left: " << lefts.size() << ", right: " << rights.size(); + } + auto tensor_info_as_str = [](const TensorInfo &tensor_info) { + Status status = INFER_STATUS(SUCCESS) << "size: " << tensor_info.size << ", data type: " << tensor_info.data_type + << ", 
shape: " << tensor_info.shape; + return status.StatusMessage(); + }; + for (size_t k = 0; k < lefts.size(); k++) { + auto &left = lefts[k]; + auto &right = rights[k]; + if (left.size != right.size || left.shape != right.shape || left.data_type != right.data_type) { + return INFER_STATUS(FAILED) << "Index " << k << " tensor not match, left- " << tensor_info_as_str(left) + << "; right- " << tensor_info_as_str(right); + } + } + return SUCCESS; +} + +Status DistributedServable::CheckAgentsInfosAndInitTensorInfos() { + auto rank_size = config_.distributed_meta.rank_size; + auto stage_size = config_.distributed_meta.stage_size; + auto parallel_count = rank_size / stage_size; + MSI_LOG_INFO << "Check agents infos, rank size :" << rank_size << ", stage size: " << stage_size + << ", parallel count: " << parallel_count; + if (agent_spec_map_.size() != rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Registered agents size " << agent_spec_map_.size() << " not match rank size " << rank_size; + } + + input_infos_ = agent_spec_map_[0].agent_spec_.input_infos; + output_infos_ = agent_spec_map_[rank_size - 1].agent_spec_.output_infos; + batch_size_ = agent_spec_map_[0].agent_spec_.batch_size; + if (input_infos_.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << 0 << " input count cannot be 0"; + } + if (output_infos_.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size - 1 << " output count cannot be 0"; + } + Status status; + for (size_t i = 0; i < parallel_count; i++) { + auto &agent_spec = agent_spec_map_[i]; + status = CompareTensorInfos(agent_spec.agent_spec_.input_infos, input_infos_); + if (status != SUCCESS) { + status = INFER_STATUS_LOG_ERROR(FAILED) + << "Rank " << i << " input infos not match rank 0, details: " << status.StatusMessage(); + return status; + } + } + for (size_t i = parallel_count; i < rank_size; i++) { + auto &agent_spec = agent_spec_map_[i]; + if (!agent_spec.agent_spec_.input_infos.empty()) { + return 
INFER_STATUS_LOG_ERROR(FAILED) << "Expect rank " << i << " input count equal to 0"; + } + } + for (size_t i = 0; i < rank_size; i++) { + auto &first_item = agent_spec_map_[i]; + for (size_t k = 0; k < parallel_count && i + k < rank_size; k++) { + auto rank_id = i + k; + auto &agent_spec = agent_spec_map_[i + k]; + status = CompareTensorInfos(agent_spec.agent_spec_.output_infos, first_item.agent_spec_.output_infos); + if (status != SUCCESS) { + status = INFER_STATUS_LOG_ERROR(FAILED) << "Rank " << rank_size << " output infos not match rank " << i + << ", details: " << status.StatusMessage(); + return status; + } + if (agent_spec.agent_spec_.batch_size != 0 && agent_spec.agent_spec_.batch_size != batch_size_) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Expect rank " << rank_id << " batch size equal to 0 or rank 0 batch size " << batch_size_; + } + } + } + return SUCCESS; } -Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return Status(); } + +Status DistributedServable::CheckRankConfig() { + auto rank_size = config_.distributed_meta.rank_size; + auto stage_size = config_.distributed_meta.stage_size; + if (stage_size == 0 || rank_size == 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size or stage size cannot be 0, rank size: " << rank_size << ", stage size: " << stage_size; + } + if (rank_size % stage_size != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size must be an integral multiple of stage size, rank size: " << rank_size + << ", stage size: " << stage_size; + } + auto parallel_count = rank_size / stage_size; + constexpr size_t card_count_per_machine = 8; + if (rank_size > card_count_per_machine && parallel_count % card_count_per_machine != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Parallel count " << parallel_count << " in one stage must be an integral multiple of card count " + << card_count_per_machine << " in one machine, when rank size is greater than card count in one machine, " + 
<< "rank size: " << rank_size << ", stage size: " << stage_size; + } + if (config_.rank_list.size() != rank_size) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " + << rank_size << " declared in servable_config, rank json config file: " << rank_table_json_file_; + } + for (size_t i = 0; i < rank_size; i++) { + const auto &first_item = config_.rank_list[i]; + for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { + auto rank_id = i + k; + const auto &item = config_.rank_list[rank_id]; + if (k != item.device_id) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + } + if (first_item.ip != item.ip) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id + << " to be equal with device ip " << first_item.ip << " of rank " << i; + } + } + } + MSI_LOG_INFO << "Check rank table success, rank size: " << rank_size << ", stage size: " << stage_size + << ", parallel count in one stage: " << parallel_count; + return SUCCESS; +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h index 507a283..642a868 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -35,13 +35,12 @@ struct DistributedAgentContext { class MS_API DistributedServable : public ServableBase { public: - // from python, servable_config.py - Status SetProperty(uint32_t rank_size, uint32_t stage_size, bool with_bach_dim, - const std::vector &without_batch_dim_inputs); // from python, worker.py - Status InitConfigOnStartup(const std::string &rank_table_json_file); + 
Status StartServable(const std::string &servable_directory, const std::string &servable_name, + const std::string &rank_table_json_file, uint64_t version_number); + // invoke from agent - Status GetDistributedServableConfig(DistributedServableConfig *config); + Status GetDistributedServableConfig(DistributedServableConfig *config) const; // send model and group // register and unregister agent, agent_spec_list_ @@ -54,11 +53,29 @@ class MS_API DistributedServable : public ServableBase { std::vector GetInputInfos() const override; std::vector GetOutputInfos() const override; uint64_t GetBatchSize() const override; + std::string GetServableName() const override; + uint64_t GetServableVersion() const override; void Clear(); private: DistributedServableConfig config_; - std::map agent_spec_list_; + std::string servable_name_; + uint64_t version_number_ = 0; + bool model_loaded_ = false; + + std::map agent_spec_map_; + std::string rank_table_json_file_; + + std::vector input_infos_; + std::vector output_infos_; + uint64_t batch_size_ = 0; + std::promise agents_promise_; + + Status InitConfigOnStartup(const std::string &rank_table_json_file); + Status WaitAgentsReady(); + Status CheckAgentsInfosAndInitTensorInfos(); + Status CompareTensorInfos(const std::vector &lefts, const std::vector &rights); + Status CheckRankConfig(); // agent stubs }; diff --git a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc similarity index 79% rename from mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc rename to mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc index ed84469..9680929 100644 --- a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.cc +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "worker/ascend_servable/ascend_sevable.h" +#include "worker/local_servable/local_sevable.h" #include #include #include @@ -31,42 +31,46 @@ static const char *kVersionStrategySpecific = "specific"; namespace mindspore::serving { -AscendModelServable::~AscendModelServable() { session_.UnloadModel(); } +LocalModelServable::~LocalModelServable() { session_.UnloadModel(); } -Status AscendModelServable::Predict(const std::vector &input, std::vector *output) { +std::string LocalModelServable::GetServableName() const { return servable_name_; } + +uint64_t LocalModelServable::GetServableVersion() const { return version_number_; } + +Status LocalModelServable::Predict(const std::vector &input, std::vector *output) { if (!model_loaded_) { MSI_LOG_EXCEPTION << "Model has not been loaded"; } return session_.ExecuteModel(input, output); } -std::vector AscendModelServable::GetInputInfos() const { +std::vector LocalModelServable::GetInputInfos() const { if (!model_loaded_) { MSI_LOG_EXCEPTION << "Model has not been loaded"; } return session_.GetInputInfos(); } -std::vector AscendModelServable::GetOutputInfos() const { +std::vector LocalModelServable::GetOutputInfos() const { if (!model_loaded_) { MSI_LOG_EXCEPTION << "Model has not been loaded"; } return session_.GetOutputInfos(); } -uint64_t AscendModelServable::GetBatchSize() const { +uint64_t LocalModelServable::GetBatchSize() const { if (!model_loaded_) { MSI_LOG_EXCEPTION << "Model has not been loaded"; } return session_.GetBatchSize(); } -TensorBasePtr AscendModelServable::MakeInferenceTensor(DataType data_type, const std::vector &shape) const { +TensorBasePtr LocalModelServable::MakeInferenceTensor(DataType data_type, const std::vector &shape) const { return std::make_shared(data_type, shape); } -Status AscendModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name, - uint32_t version_number) { +Status LocalModelServable::StartServable(const std::string 
&servable_directory, const std::string &servable_name, + uint64_t version_number) { if (model_loaded_) { MSI_LOG_EXCEPTION << "Model has loaded"; } @@ -85,7 +89,7 @@ Status AscendModelServable::StartServable(const std::string &servable_directory, if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable '" << servable_name << "' has not been registered"; } - status = InitDevice(signature.servable_meta.model_format, {}); + status = InitDevice(signature.servable_meta.local_meta.model_format, {}); if (status != SUCCESS) { MSI_LOG_ERROR << "Init env failed"; return status; @@ -105,23 +109,15 @@ Status AscendModelServable::StartServable(const std::string &servable_directory, if (status != SUCCESS) { return status; } - worker_spec_.servable_name = base_spec_.servable_name; - worker_spec_.version_number = real_version_number; - for (auto &method : signature.methods) { - WorkerMethodInfo worker_method_info; - worker_method_info.name = method.method_name; - for (auto &name : method.inputs) { - worker_method_info.input_names.push_back(name); - } - worker_spec_.methods.push_back(worker_method_info); - } + servable_name_ = base_spec_.servable_name; + version_number_ = real_version_number; model_loaded_ = true; MSI_LOG_INFO << status.StatusMessage(); std::cout << status.StatusMessage() << std::endl; return SUCCESS; } -void AscendModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { +void LocalModelServable::GetVersions(const LoadServableSpec &servable_spec, std::vector *real_versions) { MSI_EXCEPTION_IF_NULL(real_versions); // define version_strategy:"specific","latest","multi" if (version_strategy_ == kVersionStrategySpecific) { @@ -168,9 +164,9 @@ void AscendModelServable::GetVersions(const LoadServableSpec &servable_spec, std } } -Status AscendModelServable::LoadServableConfig(const LoadServableSpec &servable_spec, - const std::string &version_strategy, - 
std::vector *real_versions) { +Status LocalModelServable::LoadServableConfig(const LoadServableSpec &servable_spec, + const std::string &version_strategy, + std::vector *real_versions) { MSI_EXCEPTION_IF_NULL(real_versions); auto model_directory = servable_spec.servable_directory; auto model_name = servable_spec.servable_name; @@ -199,7 +195,7 @@ Status AscendModelServable::LoadServableConfig(const LoadServableSpec &servable_ return SUCCESS; } -Status AscendModelServable::InitDevice(ModelType model_type, const std::map &other_options) { +Status LocalModelServable::InitDevice(ModelType model_type, const std::map &other_options) { Status status; auto context = ServableContext::Instance(); DeviceType device_type = ServableContext::Instance()->GetDeviceType(); @@ -229,23 +225,25 @@ Status AscendModelServable::InitDevice(ModelType model_type, const std::mapGetDeviceType(), context->GetDeviceId(), model_file_name, - servable_meta.model_format, servable_meta.with_batch_dim, - servable_meta.without_batch_dim_inputs, servable_meta.load_options); + local_meta.model_format, common_meta.with_batch_dim, + common_meta.without_batch_dim_inputs, local_meta.load_options); if (status != SUCCESS) { return INFER_STATUS_LOG_ERROR(FAILED) << "Load model failed, servable directory: '" << base_spec_.servable_directory << "', servable name: '" - << base_spec_.servable_name << "', servable file: '" << servable_meta.servable_file << "', version number " - << version_number << ", options " << servable_meta.load_options; + << base_spec_.servable_name << "', servable file: '" << local_meta.servable_file << "', version number " + << version_number << ", options " << local_meta.load_options; } return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h similarity index 86% rename from mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h rename to 
mindspore_serving/ccsrc/worker/local_servable/local_sevable.h index ad29706..d5b9a8c 100644 --- a/mindspore_serving/ccsrc/worker/ascend_servable/ascend_sevable.h +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h @@ -31,10 +31,10 @@ namespace mindspore::serving { -class MS_API AscendModelServable : public ServableBase { +class MS_API LocalModelServable : public ServableBase { public: - AscendModelServable() = default; - ~AscendModelServable() override; + LocalModelServable() = default; + ~LocalModelServable() override; Status Predict(const std::vector &input, std::vector *output) override; @@ -44,13 +44,16 @@ class MS_API AscendModelServable : public ServableBase { TensorBasePtr MakeInferenceTensor(DataType data_type, const std::vector &shape) const override; Status StartServable(const std::string &servable_directory, const std::string &servable_name, - uint32_t version_number); + uint64_t version_number); Status InitDevice(ModelType model_type, const std::map &other_options); - WorkerSpec GetWorkerSpec() const override { return worker_spec_; } + std::string GetServableName() const override; + uint64_t GetServableVersion() const override; private: LoadServableSpec base_spec_; - WorkerSpec worker_spec_; + std::string servable_name_; + uint64_t version_number_ = 0; + MindSporeModelWrap session_; std::string version_strategy_; bool model_loaded_ = false; diff --git a/mindspore_serving/ccsrc/worker/sevable_base.h b/mindspore_serving/ccsrc/worker/sevable_base.h index 4185b13..8e9e800 100644 --- a/mindspore_serving/ccsrc/worker/sevable_base.h +++ b/mindspore_serving/ccsrc/worker/sevable_base.h @@ -39,7 +39,8 @@ class ServableBase { virtual std::vector GetInputInfos() const = 0; virtual std::vector GetOutputInfos() const = 0; virtual uint64_t GetBatchSize() const = 0; - virtual WorkerSpec GetWorkerSpec() const = 0; + virtual std::string GetServableName() const = 0; + virtual uint64_t GetServableVersion() const = 0; }; } // namespace mindspore::serving 
diff --git a/mindspore_serving/ccsrc/worker/work_executor.cc b/mindspore_serving/ccsrc/worker/work_executor.cc index 0453fb6..2ec760e 100644 --- a/mindspore_serving/ccsrc/worker/work_executor.cc +++ b/mindspore_serving/ccsrc/worker/work_executor.cc @@ -49,15 +49,15 @@ Status WorkExecutor::CheckSevableSignature() { if (servable_declare_.methods.empty()) { return INFER_STATUS_LOG_ERROR(FAILED) << "There is no method registered for servable"; } - if (input_infos.size() != servable_declare_.servable_meta.inputs_count) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "The inputs count " << servable_declare_.servable_meta.inputs_count << " registered in method " - << "not equal to the count " << input_infos.size() << " defined in servable"; + const auto &common_meta = servable_declare_.servable_meta.common_meta; + if (input_infos.size() != common_meta.inputs_count) { + return INFER_STATUS_LOG_ERROR(FAILED) << "The inputs count " << common_meta.inputs_count << " registered in method " + << "not equal to the count " << input_infos.size() << " defined in servable"; } const auto &output_infos = output_infos_; - if (output_infos.size() != servable_declare_.servable_meta.outputs_count) { + if (output_infos.size() != common_meta.outputs_count) { return INFER_STATUS_LOG_ERROR(FAILED) - << "The outputs count " << servable_declare_.servable_meta.outputs_count << " registered in method " + << "The outputs count " << common_meta.outputs_count << " registered in method " << "not equal to the count " << output_infos.size() << " defined in servable"; } MSI_LOG_INFO << "Model input infos: count " << input_infos.size(); @@ -68,7 +68,7 @@ Status WorkExecutor::CheckSevableSignature() { for (auto &item : output_infos) { MSI_LOG_INFO << item.shape << ", " << item.data_type << ", " << item.size; } - if (servable_declare_.servable_meta.with_batch_dim) { + if (common_meta.with_batch_dim) { if (model_batch_size_ == 0) { return INFER_STATUS_LOG_ERROR(FAILED) << "Servable batch size cannot be " << 
model_batch_size_; } @@ -104,7 +104,7 @@ Status WorkExecutor::Init(const ServableSignature &servable_declare, const std:: servable_ = servable; input_infos_ = servable_->GetInputInfos(); output_infos_ = servable_->GetOutputInfos(); - if (servable_declare_.servable_meta.with_batch_dim) { + if (servable_declare_.servable_meta.common_meta.with_batch_dim) { model_batch_size_ = servable_->GetBatchSize(); } else { model_batch_size_ = 1; @@ -389,7 +389,7 @@ Status WorkExecutor::PostPredict(const std::vector &inputs, const std: MSI_LOG_EXCEPTION << "Output result data size cannot be 0"; } auto shape = item->shape(); - if (servable_declare_.servable_meta.with_batch_dim) { + if (servable_declare_.servable_meta.common_meta.with_batch_dim) { if (shape.empty() || shape[0] != model_batch_size) { MSI_LOG_EXCEPTION << "Output shape " << shape << " not match batch size " << model_batch_size; } @@ -429,9 +429,9 @@ Status WorkExecutor::Predict(const std::vector &inputs, std::vector servable, std::shared cpp_postprocess_.Start(2); notify_master_ = std::move(notify_master); - auto worker_spec = servable->GetWorkerSpec(); + auto servable_name = servable->GetServableName(); ServableSignature signature; - if (!ServableStorage::Instance().GetServableDef(worker_spec.servable_name, &signature)) { - return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << worker_spec.servable_name << " has not been registered"; + if (!ServableStorage::Instance().GetServableDef(servable_name, &signature)) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Servable " << servable_name << " has not been registered"; } auto service = std::make_shared(GetPyTaskQueuePreprocess(), GetPyTaskQueuePostprocess(), GetCppTaskQueuePreprocess(), GetCppTaskQueuePostprocess()); @@ -209,6 +209,17 @@ Status Worker::StartServable(std::shared_ptr servable, std::shared return status; } ServableWorkerContext work; + WorkerSpec worker_spec; + worker_spec.servable_name = servable_name; + worker_spec.version_number = 
servable->GetServableVersion(); + for (auto &method : signature.methods) { + WorkerMethodInfo worker_method_info; + worker_method_info.name = method.method_name; + for (auto &name : method.inputs) { + worker_method_info.input_names.push_back(name); + } + worker_spec.methods.push_back(worker_method_info); + } work.worker_spec = worker_spec; work.servable_signature = signature; work.worker_service = service; diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index fc48ead..c7be82f 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -24,8 +24,8 @@ message AgentSpec { int64 rank_id = 1; int64 batch_size = 2; int64 input_size = 3; - repeated Tensor inputs =4; - repeated Tensor outputs = 5; + repeated Tensor inputs = 4; + repeated Tensor outputs = 5; } message AgentRegisterRequest { diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py index d8ff512..5bee6b8 100644 --- a/mindspore_serving/worker/distributed/distributed_worker.py +++ b/mindspore_serving/worker/distributed/distributed_worker.py @@ -14,7 +14,9 @@ # ============================================================================ """Serving, distributed worker startup""" from mindspore_serving.worker._worker import stop_on_except, _load_servable_config +from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear from mindspore_serving.worker import check_type +from mindspore_serving._mindspore_serving import Worker_ @stop_on_except @@ -68,6 +70,10 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) + _start_wait_and_clear() + Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, + master_ip, master_port, 
worker_ip, worker_port) + _start_py_task(Worker_.get_batch_size()) @stop_on_except @@ -115,3 +121,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) + _start_wait_and_clear() + Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, + version_number, worker_ip, worker_port) + _start_py_task(Worker_.get_batch_size()) diff --git a/mindspore_serving/worker/distributed/register.py b/mindspore_serving/worker/distributed/register.py index c060624..bac0f35 100644 --- a/mindspore_serving/worker/distributed/register.py +++ b/mindspore_serving/worker/distributed/register.py @@ -13,12 +13,31 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker register""" + +from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ from mindspore_serving.worker import check_type +from mindspore_serving.worker.common import get_servable_dir +from mindspore_serving import log as logger -def declare_distributed_servable(rank_size, stage_size, with_bach_dim, without_batch_dim_inputs): +def declare_distributed_servable(rank_size, stage_size, with_batch_dim, without_batch_dim_inputs): """declare distributed servable in servable_config.py""" - check_type.check_int("rank_size", rank_size, 0) - check_type.check_int("stage_size", stage_size, 0) - check_type.check_bool("with_bach_dim", with_bach_dim) - check_type.check_and_as_int_tuple_list("without_batch_dim_inputs", without_batch_dim_inputs, 0) + check_type.check_bool('with_batch_dim', with_batch_dim) + + meta = ServableMeta_() + meta.common_meta.servable_name = get_servable_dir() + meta.common_meta.with_batch_dim = with_batch_dim + if without_batch_dim_inputs: + without_batch_dim_inputs = 
check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', + without_batch_dim_inputs, 0) + meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs + + # init distributed servable meta info + check_type.check_int("rank_size", rank_size, 1) + check_type.check_int("stage_size", stage_size, 1) + meta.distributed_meta.rank_size = rank_size + meta.distributed_meta.stage_size = stage_size + ServableStorage_.declare_distributed_servable(meta) + logger.info(f"Declare distributed servable, servable_name: {meta.common_meta.servable_name} " + f", rank_size: {rank_size} , stage_size: {stage_size}, with_batch_dim: {with_batch_dim} " + f", without_batch_dim_inputs: {without_batch_dim_inputs}") diff --git a/mindspore_serving/worker/register/method.py b/mindspore_serving/worker/register/method.py index 224796c..ccde72c 100644 --- a/mindspore_serving/worker/register/method.py +++ b/mindspore_serving/worker/register/method.py @@ -35,28 +35,6 @@ method_tag_predict = PredictPhaseTag_.kPredictPhaseTag_Predict method_tag_postprocess = PredictPhaseTag_.kPredictPhaseTag_Postprocess -class _ServableStorage: - """Declare servable info""" - - def __init__(self): - pass - - @staticmethod - def declare_servable(servable_meta): - """Declare servable info excluding method, input and output count""" - ServableStorage_.declare_servable(servable_meta) - - @staticmethod - def declare_servable_input_output(servable_name, inputs_count, outputs_count): - """Declare input and output count of servable""" - ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) - - @staticmethod - def register_method(method_signature): - """Declare method of servable""" - ServableStorage_.register_method(method_signature) - - class _TensorDef: """Data flow item, for definitions of data flow in a method""" @@ -251,7 +229,7 @@ def call_servable(*args): servable_name = get_servable_dir() inputs_count, outputs_count = method_def_ast_meta_[_call_servable_name] - 
_ServableStorage.declare_servable_input_output(servable_name, inputs_count, outputs_count) + ServableStorage_.register_servable_input_output_info(servable_name, inputs_count, outputs_count) if inputs_count != len(args): raise RuntimeError(f"Check failed in method '{method_def_context_.method_name}', given servable input " f"size {len(args)} not match '{servable_name}' ast parse size {inputs_count}") @@ -467,7 +445,7 @@ def register_method(output_names): f", servable_name {method_def_context_.servable_name}, inputs: {input_names}, outputs: " f"{output_names}") - _ServableStorage.register_method(method_def_context_) + ServableStorage_.register_method(method_def_context_) return func return register diff --git a/mindspore_serving/worker/register/servable.py b/mindspore_serving/worker/register/servable.py index 97b4829..97eb23d 100644 --- a/mindspore_serving/worker/register/servable.py +++ b/mindspore_serving/worker/register/servable.py @@ -14,11 +14,10 @@ # ============================================================================ """Servable declaration interface""" -from mindspore_serving._mindspore_serving import ServableMeta_ +from mindspore_serving._mindspore_serving import ServableMeta_, ServableStorage_ from mindspore_serving.worker import check_type from mindspore_serving.worker.common import get_servable_dir from mindspore_serving import log as logger -from .method import _ServableStorage def declare_servable(servable_file, model_format, with_batch_dim=True, options=None, without_batch_dim_inputs=None): @@ -37,19 +36,25 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N RuntimeError: The type or value of the parameters is invalid. 
""" - check_type.check_str('servable_file', servable_file) - check_type.check_str('model_format', model_format) check_type.check_bool('with_batch_dim', with_batch_dim) + meta = ServableMeta_() + meta.common_meta.servable_name = get_servable_dir() + meta.common_meta.with_batch_dim = with_batch_dim + if without_batch_dim_inputs: + without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', + without_batch_dim_inputs, 0) + meta.common_meta.without_batch_dim_inputs = without_batch_dim_inputs + + # init local servable meta info + check_type.check_str('servable_file', servable_file) + check_type.check_str('model_format', model_format) model_format = model_format.lower() if model_format not in ("om", "mindir"): raise RuntimeError("model format can only be OM or MindIR") - meta = ServableMeta_() - meta.servable_name = get_servable_dir() - meta.servable_file = servable_file - meta.set_model_format(model_format) - meta.with_batch_dim = with_batch_dim + meta.local_meta.servable_file = servable_file + meta.local_meta.set_model_format(model_format) if isinstance(options, dict): for k, w in options.items(): check_type.check_str("options key", k) @@ -61,14 +66,10 @@ def declare_servable(servable_file, model_format, with_batch_dim=True, options=N raise RuntimeError(f"Parameter 'options' should be None, dict of or AclOptions, but " f"gotten {type(options)}") if options: - meta.options = options - if without_batch_dim_inputs: - without_batch_dim_inputs = check_type.check_and_as_int_tuple_list('without_batch_dim_inputs', - without_batch_dim_inputs, 0) - meta.without_batch_dim_inputs = without_batch_dim_inputs + meta.local_meta.options = options - _ServableStorage.declare_servable(meta) - logger.info(f"Declare servable, servable_name: {meta.servable_name} " + ServableStorage_.declare_servable(meta) + logger.info(f"Declare servable, servable_name: {meta.common_meta.servable_name} " f", servable_file: {servable_file} , model_format: {model_format}, 
with_batch_dim: {with_batch_dim} " f", options: {options}, without_batch_dim_inputs: {without_batch_dim_inputs}") diff --git a/tests/ut/cpp/common/test_servable_common.h b/tests/ut/cpp/common/test_servable_common.h index bb7eeed..5f565ac 100644 --- a/tests/ut/cpp/common/test_servable_common.h +++ b/tests/ut/cpp/common/test_servable_common.h @@ -27,7 +27,7 @@ #include "worker/worker.h" #include "worker/notfiy_master/local_notify.h" #include "worker/context.h" -#include "worker/ascend_servable/ascend_sevable.h" +#include "worker/local_servable/local_sevable.h" #include "master/grpc/grpc_process.h" #include "mindspore_serving/proto/ms_service.pb.h" @@ -103,7 +103,7 @@ class TestMasterWorker : public UT::Common { auto notify_master = std::make_shared(); ServableContext::Instance()->SetDeviceId(0); ServableContext::Instance()->SetDeviceTypeStr("Ascend"); - auto servable = std::make_shared(); + auto servable = std::make_shared(); auto status = servable->StartServable(servable_dir, servable_name, version_number); if (status != SUCCESS) { return status; @@ -114,10 +114,10 @@ class TestMasterWorker : public UT::Common { static void DeclareServable(const std::string &servable_name, const std::string &servable_file, const std::string &model_type, bool with_batch_dim = false) { ServableMeta servable_meta; - servable_meta.servable_name = servable_name; - servable_meta.servable_file = servable_file; - servable_meta.SetModelFormat(model_type); - servable_meta.with_batch_dim = with_batch_dim; + servable_meta.common_meta.servable_name = servable_name; + servable_meta.common_meta.with_batch_dim = with_batch_dim; + servable_meta.local_meta.servable_file = servable_file; + servable_meta.local_meta.SetModelFormat(model_type); // declare_servable ServableStorage::Instance().DeclareServable(servable_meta); } From 620bc494b47c023890d26fa8bfd6e66690bbfee1 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Thu, 28 Jan 2021 21:23:02 +0800 Subject: [PATCH 07/10] Serving, commbile distributed 
worker and local worker in grpc server process --- mindspore_serving/ccsrc/common/exit_handle.cc | 13 ++ mindspore_serving/ccsrc/common/exit_handle.h | 2 + .../ccsrc/python/worker/worker_py.cc | 27 +++- .../distributed_process.cc | 2 +- .../distributed_process.h | 3 +- .../grpc/distributed_server.cc | 38 +++++ .../grpc/distributed_server.h | 147 ++++++++++++++++++ .../notify_distributed/notify_worker.cc | 10 +- .../notify_distributed/notify_worker.h | 12 +- .../ccsrc/worker/grpc/worker_process.cc | 1 - .../ccsrc/worker/grpc/worker_process.h | 2 +- .../ccsrc/worker/grpc/worker_server.cc | 18 ++- .../ccsrc/worker/grpc/worker_server.h | 22 +-- .../ccsrc/worker/work_executor.h | 6 +- mindspore_serving/ccsrc/worker/worker.cc | 17 +- mindspore_serving/ccsrc/worker/worker.h | 5 +- mindspore_serving/proto/ms_distributed.proto | 5 - mindspore_serving/proto/ms_worker.proto | 5 + 18 files changed, 278 insertions(+), 57 deletions(-) rename mindspore_serving/ccsrc/worker/distributed_worker/{distributed_process => grpc}/distributed_process.cc (96%) rename mindspore_serving/ccsrc/worker/distributed_worker/{distributed_process => grpc}/distributed_process.h (92%) create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc create mode 100644 mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h diff --git a/mindspore_serving/ccsrc/common/exit_handle.cc b/mindspore_serving/ccsrc/common/exit_handle.cc index 3b97c21..88d9644 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.cc +++ b/mindspore_serving/ccsrc/common/exit_handle.cc @@ -55,6 +55,17 @@ void ExitSignalHandle::WorkerWait() { exit_future.wait(); } +// waiting ctrl+c or stop message to exit, +// if no server is running or server has exited, there is no need to wait +void ExitSignalHandle::AgentWait() { + if (!is_running_) { + MSI_LOG_INFO << "Exit Handle has not started or has exited"; + return; + } + auto exit_future = agent_exit_requested_.get_future(); + 
exit_future.wait(); +} + void ExitSignalHandle::Start() { if (is_running_) { return; @@ -62,6 +73,7 @@ void ExitSignalHandle::Start() { is_running_ = true; master_exit_requested_ = std::promise(); worker_exit_requested_ = std::promise(); + agent_exit_requested_ = std::promise(); has_exited_.clear(); InitSignalHandle(); } @@ -79,6 +91,7 @@ void ExitSignalHandle::HandleSignalInner() { if (!has_exited_.test_and_set()) { master_exit_requested_.set_value(); worker_exit_requested_.set_value(); + agent_exit_requested_.set_value(); is_running_ = false; } } diff --git a/mindspore_serving/ccsrc/common/exit_handle.h b/mindspore_serving/ccsrc/common/exit_handle.h index 66a2fd2..42654c6 100644 --- a/mindspore_serving/ccsrc/common/exit_handle.h +++ b/mindspore_serving/ccsrc/common/exit_handle.h @@ -32,6 +32,7 @@ class MS_API ExitSignalHandle { void InitSignalHandle(); void MasterWait(); void WorkerWait(); + void AgentWait(); void Start(); void Stop(); bool HasStopped(); @@ -39,6 +40,7 @@ class MS_API ExitSignalHandle { private: std::promise master_exit_requested_; std::promise worker_exit_requested_; + std::promise agent_exit_requested_; std::atomic_flag has_exited_ = true; std::atomic_flag has_inited_ = ATOMIC_FLAG_INIT; std::atomic_bool is_running_ = false; diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index dce0b70..faa0f31 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -23,13 +23,15 @@ #include "worker/notfiy_master/local_notify.h" #include "worker/local_servable/local_sevable.h" #include "worker/distributed_worker/distributed_servable.h" +#include "worker/grpc/worker_server.h" +#include "worker/distributed_worker/grpc/distributed_server.h" namespace mindspore::serving { void PyWorker::StartServable(const std::string &model_directory, const std::string &model_name, uint32_t version_number, - const std::string &master_ip, uint32_t 
master_port, const std::string &host_ip, - uint32_t host_port) { - auto notify_master = std::make_shared(master_ip, master_port, host_ip, host_port); + const std::string &master_ip, uint32_t master_port, const std::string &worker_ip, + uint32_t worker_port) { + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); auto servable = std::make_shared(); auto status = servable->StartServable(model_directory, model_name, version_number); if (status != SUCCESS) { @@ -39,10 +41,14 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - status = Worker::GetInstance().StartGrpcServer(host_ip, host_port); + // start grpc server + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); + status = Worker::GetInstance().StartVersionController(); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -72,12 +78,15 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, uint32_t master_port) { Status status; - status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); + auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); - auto servable = std::make_shared(); status = servable->StartServable(servable_directory, servable_name, 
rank_table_json_file, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); @@ -96,13 +105,15 @@ void PyWorker::StartDistributedServableInMaster(const std::string &servable_dire const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port) { Status status; - status = Worker::GetInstance().StartGrpcServer(worker_ip, worker_port); + auto servable = std::make_shared(); + auto grpc_sever = std::make_shared(); + status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } + Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(); - auto servable = std::make_shared(); status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc similarity index 96% rename from mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc index 72442cc..0333434 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/distributed_worker/grpc/distributed_process.h" #include "common/proto_tensor.h" namespace mindspore { diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h similarity index 92% rename from mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h rename to mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h index 3ef02b2..b127ac7 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h @@ -27,12 +27,13 @@ #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" #include "worker/distributed_worker/distributed_servable.h" +#include "worker/grpc/worker_process.h" namespace mindspore { namespace serving { // Service Implement -class MSDistributedImpl final : public proto::MSDistributedWorker::Service { +class MSDistributedImpl final : public MSWorkerImpl { public: explicit MSDistributedImpl(std::shared_ptr servable) : servable_(servable) {} ~MSDistributedImpl() = default; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc new file mode 100644 index 0000000..79d4064 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "worker/distributed_worker/grpc/distributed_server.h" +#include +#include +#include +#include "common/grpc_server.h" + +namespace mindspore { +namespace serving { + +Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr servable, + const std::string &hostname, int32_t port) { + if (in_running_) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; + } + auto impl = std::make_unique(servable); + async_server_ = std::make_unique(hostname, port, impl.get()); + service_impl_ = std::move(impl); + return Init(); +} + +} // namespace serving +} // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h new file mode 100644 index 0000000..2151a41 --- /dev/null +++ b/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H +#define MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H + +#include +#include +#include +#include +#include +#include "common/serving_common.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" +#include "common/grpc_async_server.h" +#include "worker/grpc/worker_process.h" +#include "worker/grpc/worker_server.h" +#include "worker/distributed_worker/grpc/distributed_process.h" + +namespace mindspore { +namespace serving { + +// Service Implement +class MS_API MSDistributedWorkerServer : public MSWorkerServer { + public: + Status StartDistributedWorkerGrpcServer(std::shared_ptr servable, const std::string &hostname, + int32_t port); +}; + +// Service Implement +class WorkerAgentRegisterContext : public WorkerServiceContext { + public: + WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { + state_ = STATE::CREATE; + } + + ~WorkerAgentRegisterContext() = default; + + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentRegisterContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + bool JudgeFinish() override { return state_ == 
STATE::FINISH; } + + private: + MSDistributedImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; + proto::PredictRequest request_; + proto::PredictReply response_; +}; + +class WorkerAgentExitContext : public WorkerServiceContext { + public: + WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { + state_ = STATE::CREATE; + } + + ~WorkerAgentExitContext() = default; + + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentExitContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } + + bool JudgeFinish() override { return state_ == STATE::FINISH; } + + private: + MSDistributedImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; + proto::ExitRequest request_; + proto::ExitReply response_; +}; + +class DistributedWorkerGrpcServer : public WorkerGrpcServer { + public: + DistributedWorkerGrpcServer(const std::string &host, int32_t port, MSDistributedImpl *service_impl) + : WorkerGrpcServer(host, port, service_impl), distributed_service_impl_(service_impl) {} + + ~DistributedWorkerGrpcServer() = default; + 
+ Status EnqueueRequest() { + WorkerGrpcServer::EnqueueRequest(); + WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + return SUCCESS; + } + + private: + MSDistributedImpl *distributed_service_impl_; +}; + +} // namespace serving +} // namespace mindspore + +#endif // MINDSPORE_SERVING_WORKER_DISTRIBUTED_WORKER_SERVER_H diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index 230e225..d9e6b73 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace serving { -GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distributed_worker_ip, +GrpcNotifyDistributeWorker::GrpcNotifyDistributeWorker(const std::string &distributed_worker_ip, uint32_t distributed_worker_port, const std::string &host_ip, uint32_t host_port) : distributed_worker_ip_(distributed_worker_ip), @@ -35,12 +35,12 @@ GrpcNotfiyDistributeWorker::GrpcNotfiyDistributeWorker(const std::string &distri distributed_worker_address_ = distributed_worker_ip + ":" + std::to_string(distributed_worker_port); agent_address_ = host_ip_ + ":" + std::to_string(host_port_); auto channel = GrpcServer::CreateChannel(distributed_worker_address_); - stub_ = proto::MSDistributedWorker::NewStub(channel); + stub_ = proto::MSWorker::NewStub(channel); } -GrpcNotfiyDistributeWorker::~GrpcNotfiyDistributeWorker() = default; +GrpcNotifyDistributeWorker::~GrpcNotifyDistributeWorker() = default; -Status GrpcNotfiyDistributeWorker::Register(const std::vector &worker_specs) { +Status GrpcNotifyDistributeWorker::Register(const std::vector &worker_specs) { const int32_t 
REGISTER_TIME_OUT = 60; const int32_t REGISTER_INTERVAL = 1; auto loop = REGISTER_TIME_OUT; @@ -67,7 +67,7 @@ Status GrpcNotfiyDistributeWorker::Register(const std::vector & return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; } -Status GrpcNotfiyDistributeWorker::Unregister() { +Status GrpcNotifyDistributeWorker::Unregister() { if (is_stoped_.load()) { return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index d618878..2c2724c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -20,18 +20,18 @@ #include #include #include "worker/distributed_worker/notify_distributed/base_notify_worker.h" -#include "proto/ms_master.pb.h" -#include "proto/ms_master.grpc.pb.h" #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" +#include "proto/ms_worker.pb.h" +#include "proto/ms_worker.grpc.pb.h" namespace mindspore { namespace serving { -class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { +class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { public: - GrpcNotfiyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, + GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, uint32_t host_port); - ~GrpcNotfiyDistributeWorker() override; + ~GrpcNotifyDistributeWorker() override; Status Register(const std::vector &worker_specs) override; Status Unregister() override; @@ -42,7 +42,7 @@ class MS_API GrpcNotfiyDistributeWorker : public BaseNotifyDistributeWorker { uint32_t host_port_; std::string agent_address_; std::string distributed_worker_address_; - std::unique_ptr stub_; + std::unique_ptr stub_; std::atomic 
is_stoped_{false}; }; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc index 73c38d1..2d41b03 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.cc @@ -15,7 +15,6 @@ */ #include "worker/grpc/worker_process.h" -#include "master/dispacther.h" #include "worker/worker.h" namespace mindspore { diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_process.h b/mindspore_serving/ccsrc/worker/grpc/worker_process.h index 450158e..ebdb3c5 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_process.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_process.h @@ -28,7 +28,7 @@ namespace mindspore { namespace serving { // Service Implement -class MSWorkerImpl final : public proto::MSWorker::Service { +class MSWorkerImpl : public proto::MSWorker::Service { public: grpc::Status Predict(grpc::ServerContext *context, const proto::PredictRequest *request, proto::PredictReply *reply) override; diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc index cc603ad..58880df 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.cc +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.cc @@ -21,12 +21,20 @@ namespace mindspore { namespace serving { + MSWorkerServer::~MSWorkerServer() { Stop(); } -MSWorkerServer::MSWorkerServer(const std::string &hostname, int32_t port) { +Status MSWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { + if (in_running_) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; + } service_impl_ = std::make_unique(); async_server_ = std::make_unique(hostname, port, service_impl_.get()); + return Init(); } + +MSWorkerServer::MSWorkerServer() = default; + Status MSWorkerServer::Init() { Status status = async_server_->Run("Worker gRPC", gRpcMaxMBMsgSize); if (status != SUCCESS) 
return status; @@ -40,10 +48,14 @@ Status MSWorkerServer::StartAsyncRpcService() { return status; } Status MSWorkerServer::Stop() { - if (in_running_) { + if (in_running_ && async_server_) { async_server_->Stop(); - grpc_thread_.join(); + if (grpc_thread_.joinable()) { + grpc_thread_.join(); + } } + async_server_ = nullptr; + service_impl_ = nullptr; in_running_ = false; return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 8bcc057..1452727 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -27,27 +27,29 @@ #include "proto/ms_worker.grpc.pb.h" #include "common/grpc_async_server.h" #include "worker/grpc/worker_process.h" +#include "worker/distributed_worker/distributed_servable.h" namespace mindspore { namespace serving { // Service Implement -class MSWorkerServer { +class MS_API MSWorkerServer { public: enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; - MSWorkerServer(const std::string &hostname, int32_t port); - ~MSWorkerServer(); - - Status Init(); + MSWorkerServer(); + virtual ~MSWorkerServer(); + Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); Status Stop(); - Status StartAsyncRpcService(); - + protected: bool in_running_ = false; std::thread grpc_thread_; - std::unique_ptr service_impl_; - std::unique_ptr async_server_; + std::unique_ptr service_impl_ = nullptr; + std::unique_ptr async_server_ = nullptr; + + Status StartAsyncRpcService(); + Status Init(); }; class WorkerServiceContext { @@ -174,7 +176,7 @@ class WorkerGrpcServer : public GrpcAsyncServer { return SUCCESS; } - private: + protected: MSWorkerImpl *service_impl_; proto::MSWorker::AsyncService svc_; }; diff --git a/mindspore_serving/ccsrc/worker/work_executor.h b/mindspore_serving/ccsrc/worker/work_executor.h index 2c44e6a..d491843 100644 --- 
a/mindspore_serving/ccsrc/worker/work_executor.h +++ b/mindspore_serving/ccsrc/worker/work_executor.h @@ -39,10 +39,8 @@ using WorkCallBack = std::function py_preprocess_task_queue, - std::shared_ptr py_postprocess_task_queue, - std::shared_ptr cpp_preprocess_task_queue, - std::shared_ptr cpp_postprocess_task_queue); + WorkExecutor(std::shared_ptr py_preprocess, std::shared_ptr py_postprocess, + std::shared_ptr cpp_preprocess, std::shared_ptr cpp_postprocess); ~WorkExecutor(); Status Init(const ServableSignature &servable_declare, const std::shared_ptr &servable); diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index fed5d42..87167ef 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -34,21 +34,11 @@ namespace py = pybind11; namespace mindspore { namespace serving { -static std::unique_ptr grpc_async_worker_server_; - Worker &Worker::GetInstance() { static Worker instance; return instance; } -Status Worker::StartGrpcServer(const std::string &ip, uint32_t grpc_port) { - if (grpc_async_worker_server_ != nullptr) { - return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Worker gRPC server is already running"; - } - grpc_async_worker_server_ = std::make_unique(ip, grpc_port); - return grpc_async_worker_server_->Init(); -} - Status Worker::RegisterWorker() { std::vector worker_specs; for (auto &work : work_list_) { @@ -184,6 +174,11 @@ void Worker::Update() { */ } +Status Worker::AfterStartGrpcServer(const std::shared_ptr &grpc_server) { + worker_grpc_server_ = grpc_server; + return SUCCESS; +} + Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (servable_started_) { @@ -244,7 +239,7 @@ void Worker::StopServable(bool notify_master) { void Worker::Clear() { std::unique_lock lock(worker_shared_lock_); ServableStorage::Instance().Clear(); - 
grpc_async_worker_server_ = nullptr; + worker_grpc_server_ = nullptr; if (clear_flag_.test_and_set()) { return; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 122b7d4..50d95eb 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -33,6 +33,7 @@ #include "worker/version_control/version_controller.h" #include "common/grpc_async_server.h" #include "worker/sevable_base.h" +#include "worker/grpc/worker_server.h" namespace mindspore { namespace serving { @@ -74,10 +75,11 @@ class MS_API Worker { const std::vector &inputs); Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); + Status AfterStartGrpcServer(const std::shared_ptr &grpc_server); + void StopServable(bool notify_master = true); bool HasCleared(); Status RegisterWorker(); - Status StartGrpcServer(const std::string &ip, uint32_t grpc_port); void Update(); Status StartVersionController(); Status AddWorker(const ServableWorkerContext &work); @@ -101,6 +103,7 @@ class MS_API Worker { std::atomic_bool servable_started_ = false; std::atomic_flag clear_flag_ = ATOMIC_FLAG_INIT; std::shared_ptr notify_master_ = nullptr; + std::shared_ptr worker_grpc_server_ = nullptr; std::shared_mutex worker_shared_lock_; diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index c7be82f..13b13d5 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -45,8 +45,3 @@ message AgentExitRequest { message AgentExitReply { ErrorMsg error_msg = 1; } - -service MSDistributedWorker { - rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} - rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} -} \ No newline at end of file diff --git a/mindspore_serving/proto/ms_worker.proto b/mindspore_serving/proto/ms_worker.proto index c9ed051..436b52f 100644 --- 
a/mindspore_serving/proto/ms_worker.proto +++ b/mindspore_serving/proto/ms_worker.proto @@ -20,8 +20,13 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; import "mindspore_serving/proto/ms_master.proto"; +import "mindspore_serving/proto/ms_distributed.proto"; service MSWorker { + // for master rpc Predict(PredictRequest) returns (PredictReply) {} rpc Exit(ExitRequest) returns (ExitReply) {} + // for worker agent + rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} + rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} } From 9fb423c3e39b7350534c531f8e9feccf1ef334ff Mon Sep 17 00:00:00 2001 From: zhangyinxia Date: Fri, 29 Jan 2021 10:52:34 +0800 Subject: [PATCH 08/10] add predict code --- mindspore_serving/ccsrc/common/grpc_client.cc | 51 +-------------- mindspore_serving/ccsrc/common/grpc_client.h | 64 ++++++++++++++++--- .../ccsrc/common/proto_tensor.cc | 5 +- mindspore_serving/ccsrc/common/servable.h | 1 - .../ccsrc/master/notify_worker/grpc_notify.cc | 4 +- .../agent_process/agent_process.cc | 9 ++- .../agent_process/agent_process.h | 8 +-- .../notify_agent/base_notify_agent.h | 9 ++- .../notify_agent/notify_agent.cc | 15 +++-- .../notify_agent/notify_agent.h | 4 +- .../worker/distributed_worker/worker_agent.cc | 3 +- mindspore_serving/proto/ms_agent.proto | 19 ++---- mindspore_serving/proto/ms_distributed.proto | 5 +- mindspore_serving/proto/ms_service.proto | 2 + 14 files changed, 97 insertions(+), 102 deletions(-) diff --git a/mindspore_serving/ccsrc/common/grpc_client.cc b/mindspore_serving/ccsrc/common/grpc_client.cc index 508da4e..d4ccb8c 100644 --- a/mindspore_serving/ccsrc/common/grpc_client.cc +++ b/mindspore_serving/ccsrc/common/grpc_client.cc @@ -15,58 +15,11 @@ */ #include "common/grpc_client.h" -#include -#include namespace mindspore { namespace serving { -std::unique_ptr client_; - -MSServiceClient::~MSServiceClient() { - if (in_running_) { - cq_.Shutdown(); - if 
(client_thread_.joinable()) { - try { - client_thread_.join(); - } catch (const std::system_error &) { - } catch (...) { - } - } - } - in_running_ = false; -} - -void MSServiceClient::PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - std::shared_ptr stub, DispatchCallback callback) { - AsyncClientCall *call = new AsyncClientCall; - call->reply = reply; - call->callback = std::move(callback); - call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); - call->response_reader->StartCall(); - call->response_reader->Finish(call->reply, &call->status, call); - MSI_LOG(INFO) << "Finish send Predict"; -} - -void MSServiceClient::AsyncCompleteRpc() { - void *got_tag; - bool ok = false; - - while (cq_.Next(&got_tag, &ok)) { - AsyncClientCall *call = static_cast(got_tag); - if (call->status.ok()) { - call->callback(SUCCESS); - } else { - MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); - call->callback(Status(FAILED, call->status.error_message())); - } - delete call; - } -} - -void MSServiceClient::Start() { - client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); - in_running_ = true; -} +std::unique_ptr client_; +std::unique_ptr distributed_client_; } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/common/grpc_client.h b/mindspore_serving/ccsrc/common/grpc_client.h index afd9a1e..c784713 100644 --- a/mindspore_serving/ccsrc/common/grpc_client.h +++ b/mindspore_serving/ccsrc/common/grpc_client.h @@ -23,39 +23,80 @@ #include #include #include +#include +#include #include "common/serving_common.h" #include "proto/ms_service.pb.h" #include "proto/ms_service.grpc.pb.h" #include "proto/ms_master.pb.h" #include "proto/ms_master.grpc.pb.h" #include "proto/ms_worker.grpc.pb.h" +#include "proto/ms_agent.pb.h" +#include "proto/ms_agent.grpc.pb.h" namespace mindspore { namespace serving { -class MSServiceClient; -extern 
std::unique_ptr client_; using PredictOnFinish = std::function; using DispatchCallback = std::function; +template class MSServiceClient { public: MSServiceClient() = default; - ~MSServiceClient(); - void AsyncCompleteRpc(); - void Start(); + ~MSServiceClient() { + if (in_running_) { + cq_.Shutdown(); + if (client_thread_.joinable()) { + try { + client_thread_.join(); + } catch (const std::system_error &) { + } catch (...) { + } + } + } + in_running_ = false; + } - void PredictAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - std::shared_ptr stub, DispatchCallback callback); + void Start() { + client_thread_ = std::thread(&MSServiceClient::AsyncCompleteRpc, this); + in_running_ = true; + } + + void AsyncCompleteRpc() { + void *got_tag; + bool ok = false; + + while (cq_.Next(&got_tag, &ok)) { + AsyncClientCall *call = static_cast(got_tag); + if (call->status.ok()) { + call->callback(SUCCESS); + } else { + MSI_LOG_ERROR << "RPC failed: " << call->status.error_code() << ", " << call->status.error_message(); + call->callback(Status(FAILED, call->status.error_message())); + } + delete call; + } + } + + void PredictAsync(const Request &request, Reply *reply, MSStub *stub, DispatchCallback callback) { + AsyncClientCall *call = new AsyncClientCall; + call->reply = reply; + call->callback = std::move(callback); + call->response_reader = stub->PrepareAsyncPredict(&call->context, request, &cq_); + call->response_reader->StartCall(); + call->response_reader->Finish(call->reply, &call->status, call); + MSI_LOG(INFO) << "Finish send Predict"; + } private: struct AsyncClientCall { grpc::ClientContext context; grpc::Status status; - proto::PredictReply *reply; + Reply *reply; DispatchCallback callback; - std::shared_ptr> response_reader; + std::shared_ptr> response_reader; }; grpc::CompletionQueue cq_; @@ -63,6 +104,11 @@ class MSServiceClient { bool in_running_ = false; }; +using MSPredictClient = MSServiceClient; +using MSDistributedClient = + 
MSServiceClient; +extern std::unique_ptr client_; +extern std::unique_ptr distributed_client_; } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/common/proto_tensor.cc b/mindspore_serving/ccsrc/common/proto_tensor.cc index f97fe9c..c4f8d16 100644 --- a/mindspore_serving/ccsrc/common/proto_tensor.cc +++ b/mindspore_serving/ccsrc/common/proto_tensor.cc @@ -344,10 +344,10 @@ Status GrpcTensorHelper::CreateInstanceFromRequestInstances(const proto::Predict void GrpcTensorHelper::CopyFromAgentSpec(const proto::AgentSpec &specs, WorkerAgentSpec *worker_specs) { worker_specs->rank_id = specs.rank_id(); worker_specs->batch_size = specs.batch_size(); - worker_specs->input_size = specs.input_size(); for (auto &in : specs.inputs()) { TensorInfo info; info.data_type = ProtoTensor::TransDataType2Inference(in.dtype()); + info.size = in.size(); for (auto &dim : in.shape().dims()) { info.shape.push_back(dim); } @@ -370,10 +370,10 @@ void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vectoradd_agent_spec(); worker_spec->set_rank_id(spec.rank_id); worker_spec->set_batch_size(spec.batch_size); - worker_spec->set_input_size(spec.input_size); for (auto &method : spec.input_infos) { auto proto_method = worker_spec->add_inputs(); proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + proto_method->set_size(method.size); auto proto_shape = proto_method->mutable_shape(); for (auto &dim : method.shape) { proto_shape->add_dims(dim); @@ -382,6 +382,7 @@ void GrpcTensorHelper::CopyFromWorkerAgentSpec(const std::vectoradd_outputs(); proto_method->set_dtype(ProtoTensor::TransDataType2Proto(method.data_type)); + proto_method->set_size(method.size); auto proto_shape = proto_method->mutable_shape(); for (auto &dim : method.shape) { proto_shape->add_dims(dim); diff --git a/mindspore_serving/ccsrc/common/servable.h b/mindspore_serving/ccsrc/common/servable.h index 0de2afe..3458748 100644 --- a/mindspore_serving/ccsrc/common/servable.h 
+++ b/mindspore_serving/ccsrc/common/servable.h @@ -177,7 +177,6 @@ struct WorkerAgentSpec { std::vector input_infos; std::vector output_infos; uint32_t batch_size = 0; - uint32_t input_size = 0; }; } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc index 86ca8e7..b60a86d 100644 --- a/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc +++ b/mindspore_serving/ccsrc/master/notify_worker/grpc_notify.cc @@ -55,10 +55,10 @@ Status GrpcNotfiyWorker::DispatchAsync(const proto::PredictRequest &request, pro << worker_address_; } if (!client_) { - client_ = std::make_unique(); + client_ = std::make_unique(); client_->Start(); } - client_->PredictAsync(request, reply, stub_, callback); + client_->PredictAsync(request, reply, stub_.get(), callback); return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc index 97a7a77..ff030e6 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -19,16 +19,15 @@ namespace mindspore { namespace serving { -grpc::Status MSAgentImpl::DistributedExit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, - proto::DistributedExitReply *reply) { +grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) { MSI_LOG(INFO) << "Distributed Worker Exit"; WorkerAgent::Instance().Clear(); return grpc::Status::OK; } -grpc::Status MSAgentImpl::DistributedPredict(grpc::ServerContext *context, - const proto::DistributedPredictRequest *request, - proto::DistributedPredictReply *reply) { +grpc::Status MSAgentImpl::Predict(grpc::ServerContext *context, const 
proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) { MSI_LOG(INFO) << "Begin call service Eval"; WorkerAgent::Instance().Run(*request, reply); return grpc::Status::OK; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h index 7ea69ab..d0ea12c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.h @@ -30,10 +30,10 @@ namespace serving { // Service Implement class MSAgentImpl final : public proto::MSAgent::Service { public: - grpc::Status DistributedPredict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, - proto::DistributedPredictReply *reply) override; - grpc::Status DistributedExit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, - proto::DistributedExitReply *reply) override; + grpc::Status Predict(grpc::ServerContext *context, const proto::DistributedPredictRequest *request, + proto::DistributedPredictReply *reply) override; + grpc::Status Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, + proto::DistributedExitReply *reply) override; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h index 861ea0d..ac4d5c7 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/base_notify_agent.h @@ -21,20 +21,19 @@ #include #include "common/serving_common.h" #include "common/servable.h" -#include "proto/ms_service.pb.h" +#include "proto/ms_agent.pb.h" +#include "common/grpc_client.h" namespace mindspore { namespace serving { -using DistributeCallback = 
std::function; - class MS_API BaseNotifyAgent { public: BaseNotifyAgent() = default; virtual ~BaseNotifyAgent() = default; virtual Status Exit() = 0; - virtual Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DistributeCallback callback) = 0; + virtual Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, + DispatchCallback callback) = 0; }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc index 2c810d2..3220a6c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.cc @@ -20,6 +20,7 @@ #include #include "common/exit_handle.h" #include "common/grpc_server.h" +#include "common/grpc_client.h" namespace mindspore { namespace serving { @@ -42,20 +43,24 @@ Status GrpcNotfiyAgent::Exit() { std::chrono::system_clock::time_point deadline = std::chrono::system_clock::now() + std::chrono::seconds(TIME_OUT); context.set_deadline(deadline); - (void)stub_->DistributedExit(&context, request, &reply); + (void)stub_->Exit(&context, request, &reply); } return SUCCESS; } -Status GrpcNotfiyAgent::DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DistributeCallback callback) { +Status GrpcNotfiyAgent::DispatchAsync(const proto::DistributedPredictRequest &request, + proto::DistributedPredictReply *reply, DispatchCallback callback) { if (!stub_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Predict failed, agent gRPC has not been inited or has already exited, agent address " << agent_address_; } - // todo send async message + if (!distributed_client_) { + distributed_client_ = std::make_unique(); + distributed_client_->Start(); + } + distributed_client_->PredictAsync(request, reply, stub_.get(), 
callback); return SUCCESS; -} +} // namespace serving } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h index cf984f1..53fd39f 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_agent/notify_agent.h @@ -34,8 +34,8 @@ class MS_API GrpcNotfiyAgent : public BaseNotifyAgent { Status Exit() override; - Status DispatchAsync(const proto::PredictRequest &request, proto::PredictReply *reply, - DistributeCallback callback) override; + Status DispatchAsync(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply, + DispatchCallback callback) override; private: std::string agent_address_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc index 2e497e0..c5e59df 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -35,7 +35,8 @@ Status WorkerAgent::ExecuteModel(const std::vector &request, std: } Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { - // todo :call ExecuteModel + // todo : DistributedPredictRequest->RequestBase + // todo : DistributedPredictReply->ReplyBase return SUCCESS; } diff --git a/mindspore_serving/proto/ms_agent.proto b/mindspore_serving/proto/ms_agent.proto index 143f0de..1642993 100644 --- a/mindspore_serving/proto/ms_agent.proto +++ b/mindspore_serving/proto/ms_agent.proto @@ -20,22 +20,13 @@ syntax = "proto3"; package mindspore.serving.proto; import "mindspore_serving/proto/ms_service.proto"; -message DistributedServableSpec { - // servable name - string name = 1; - // optional. 
If unspecified, the latest version servable will be used. - int64 version_number = 2; - // Specifies the method name in the servable. - string method_name = 3; -} - message DistributedPredictRequest { - DistributedServableSpec servable_spec = 1; + repeated Tensor inputs = 1; } message DistributedPredictReply { - DistributedServableSpec servable_spec = 1; - repeated ErrorMsg error_msg = 2; + repeated Tensor outputs = 1; + ErrorMsg error_msg = 2; } message DistributedExitRequest { @@ -47,6 +38,6 @@ message DistributedExitReply { } service MSAgent { - rpc DistributedPredict(DistributedPredictRequest) returns (DistributedPredictReply) {} - rpc DistributedExit(DistributedExitRequest) returns (DistributedExitReply) {} + rpc Predict(DistributedPredictRequest) returns (DistributedPredictReply) {} + rpc Exit(DistributedExitRequest) returns (DistributedExitReply) {} } diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index 13b13d5..fb6c72a 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -23,9 +23,8 @@ import "mindspore_serving/proto/ms_service.proto"; message AgentSpec { int64 rank_id = 1; int64 batch_size = 2; - int64 input_size = 3; - repeated Tensor inputs = 4; - repeated Tensor outputs = 5; + repeated Tensor inputs = 3; + repeated Tensor outputs = 4; } message AgentRegisterRequest { diff --git a/mindspore_serving/proto/ms_service.proto b/mindspore_serving/proto/ms_service.proto index 908c1dd..ddcccd4 100644 --- a/mindspore_serving/proto/ms_service.proto +++ b/mindspore_serving/proto/ms_service.proto @@ -80,6 +80,8 @@ message Tensor { // for string type and images, the dtype is MS_BYTES. 
repeated bytes bytes_val = 4; + + int64 size = 5; } message ServableSpec { From 90db9fba60a669fc57dfdac84018c5368b949190 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Tue, 2 Feb 2021 19:58:11 +0800 Subject: [PATCH 09/10] Serving, python agent --- mindspore_serving/ccsrc/master/server.cc | 3 - .../ccsrc/python/agent/agent_py.cc | 63 +++++ .../agent/agent_py.h} | 31 ++- mindspore_serving/ccsrc/python/serving_py.cc | 52 +++- .../ccsrc/python/worker/worker_py.cc | 27 +- .../ccsrc/python/worker/worker_py.h | 5 +- .../agent_process/agent_process.cc | 2 +- .../distributed_worker/agent_startup.cc | 27 +- .../worker/distributed_worker/agent_startup.h | 11 +- .../distributed_process.cc | 17 +- .../distributed_process.h | 2 + .../distributed_server.cc | 7 +- .../distributed_server.h | 99 ++++--- .../distributed_servable.cc | 109 +++++--- .../distributed_worker/distributed_servable.h | 15 +- .../notify_distributed/notify_worker.cc | 18 +- .../notify_distributed/notify_worker.h | 17 +- .../worker/distributed_worker/worker_agent.cc | 77 +++++- .../worker/distributed_worker/worker_agent.h | 16 +- .../ccsrc/worker/grpc/worker_server.h | 39 ++- .../worker/local_servable/local_sevable.cc | 9 +- .../worker/local_servable/local_sevable.h | 1 + mindspore_serving/ccsrc/worker/sevable_base.h | 1 + mindspore_serving/ccsrc/worker/worker.cc | 15 +- mindspore_serving/ccsrc/worker/worker.h | 5 +- mindspore_serving/master/_master.py | 2 + mindspore_serving/proto/ms_distributed.proto | 7 + mindspore_serving/proto/ms_worker.proto | 1 + mindspore_serving/worker/_worker.py | 4 +- .../worker/distributed/agent_startup.py | 249 ++++++++++++++++-- .../worker/distributed/distributed_worker.py | 22 +- .../worker/distributed/worker_agent.py | 62 +++-- 32 files changed, 797 insertions(+), 218 deletions(-) create mode 100644 mindspore_serving/ccsrc/python/agent/agent_py.cc rename mindspore_serving/ccsrc/{worker/distributed_worker/notify_distributed/base_notify_worker.h => python/agent/agent_py.h} (52%) 
rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_process.cc (76%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_process.h (89%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_server.cc (73%) rename mindspore_serving/ccsrc/worker/distributed_worker/{grpc => distributed_process}/distributed_server.h (50%) diff --git a/mindspore_serving/ccsrc/master/server.cc b/mindspore_serving/ccsrc/master/server.cc index 5117bad..980daac 100644 --- a/mindspore_serving/ccsrc/master/server.cc +++ b/mindspore_serving/ccsrc/master/server.cc @@ -39,7 +39,6 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma if (grpc_async_server_ != nullptr) { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Serving Error: Serving gRPC server is already running"; } - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit if (max_msg_mb_size > gRpcMaxMBMsgSize) { MSI_LOG_WARNING << "The maximum Serving gRPC message size is 512MB and will be updated from " << max_msg_mb_size << "MB to 512MB"; @@ -50,14 +49,12 @@ Status Server::StartGrpcServer(const std::string &ip, uint32_t grpc_port, int ma } Status Server::StartGrpcMasterServer(const std::string &ip, uint32_t grpc_port) { - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return grpc_manager_server_.Start(std::make_shared(dispatcher_), ip, grpc_port, gRpcMaxMBMsgSize, "Master"); } Status Server::StartRestfulServer(const std::string &ip, uint32_t restful_port, int max_msg_mb_size, int time_out_second) { - ExitSignalHandle::Instance().Start(); // handle ctrl+c to exit return restful_server_.Start(ip, restful_port, max_msg_mb_size, time_out_second); } diff --git a/mindspore_serving/ccsrc/python/agent/agent_py.cc b/mindspore_serving/ccsrc/python/agent/agent_py.cc new file mode 100644 index 0000000..c2c1465 --- /dev/null +++ 
b/mindspore_serving/ccsrc/python/agent/agent_py.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/worker_agent.h" + +namespace mindspore::serving { + +DistributedServableConfig PyAgent::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { + auto status = WorkerAgentStartUp::Instance().GetAgentsConfigsFromWorker(worker_ip, worker_port); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + + DistributedServableConfig config; + status = WorkerAgentStartUp::Instance().GetDistributedServableConfig(&config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } + return config; +} + +void PyAgent::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + WorkerAgentStartUp::Instance().NotifyFailed(worker_ip, worker_port); +} + +void PyAgent::StartAgent(const AgentStartUpConfig &start_config) { + auto status = WorkerAgent::Instance().StartAgent(start_config); + if (status != SUCCESS) { + MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); + } +} + +void PyAgent::WaitAndClear() { + { + py::gil_scoped_release release; + ExitSignalHandle::Instance().AgentWait(); + } + 
WorkerAgent::Instance().Clear(); + MSI_LOG_INFO << "Python agent end wait and clear"; +} + +void PyAgent::StopAndClear() { + ExitSignalHandle::Instance().Stop(); + WorkerAgent::Instance().Clear(); +} + +} // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h b/mindspore_serving/ccsrc/python/agent/agent_py.h similarity index 52% rename from mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h rename to mindspore_serving/ccsrc/python/agent/agent_py.h index 8e5e690..708b673 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/base_notify_worker.h +++ b/mindspore_serving/ccsrc/python/agent/agent_py.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,34 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H -#define MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H -#include +#ifndef MINDSPORE_SERVER_AGENT_PY_H +#define MINDSPORE_SERVER_AGENT_PY_H + +#include +#include +#include +#include +#include #include "common/serving_common.h" -#include "common/servable.h" #include "worker/distributed_worker/common.h" +namespace py = pybind11; + namespace mindspore { namespace serving { -class MS_API BaseNotifyDistributeWorker { +class MS_API PyAgent { public: - BaseNotifyDistributeWorker() = default; - virtual ~BaseNotifyDistributeWorker() = default; - virtual Status Register(const std::vector &worker_specs) = 0; - virtual Status Unregister() = 0; + static void StartAgent(const AgentStartUpConfig &start_config); + + static DistributedServableConfig GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); + static void WaitAndClear(); + static void StopAndClear(); + // from start up, not agent + static void NotifyFailed(const std::string &worker_ip, uint32_t worker_port); }; } // namespace serving } // namespace mindspore -#endif // MINDSPORE_SERVING_WORKER_BASE_NOTIFY_WORKER_H +#endif // MINDSPORE_SERVER_AGENT_PY_H diff --git a/mindspore_serving/ccsrc/python/serving_py.cc b/mindspore_serving/ccsrc/python/serving_py.cc index 1dac040..adf29d3 100644 --- a/mindspore_serving/ccsrc/python/serving_py.cc +++ b/mindspore_serving/ccsrc/python/serving_py.cc @@ -23,6 +23,9 @@ #include "common/servable.h" #include "worker/context.h" #include "python/master/master_py.h" +#include "python/agent/agent_py.h" +#include "common/exit_handle.h" +#include "worker/distributed_worker/worker_agent.h" namespace mindspore::serving { @@ -104,11 +107,23 @@ void PyRegServable(pybind11::module *m_ptr) { .def_static("register_method", &PyServableStorage::RegisterMethod) .def_static("declare_servable", &PyServableStorage::DeclareServable) .def_static("declare_distributed_servable", &PyServableStorage::DeclareDistributedServable); + + 
py::class_(m, "OneRankConfig_") + .def(py::init<>()) + .def_readwrite("device_id", &OneRankConfig::device_id) + .def_readwrite("ip", &OneRankConfig::ip); + + py::class_(m, "DistributedServableConfig_") + .def(py::init<>()) + .def_readwrite("common_meta", &DistributedServableConfig::common_meta) + .def_readwrite("distributed_meta", &DistributedServableConfig::distributed_meta) + .def_readwrite("rank_table_content", &DistributedServableConfig::rank_table_content) + .def_readwrite("rank_list", &DistributedServableConfig::rank_list); } void PyRegMaster(pybind11::module *m_ptr) { auto &m = *m_ptr; - py::class_>(m, "Master_") + py::class_(m, "Master_") .def_static("start_grpc_server", &PyMaster::StartGrpcServer) .def_static("start_grpc_master_server", &PyMaster::StartGrpcMasterServer) .def_static("start_restful_server", &PyMaster::StartRestfulServer) @@ -163,15 +178,50 @@ void PyRegWorker(pybind11::module *m_ptr) { .def("set_device_id", &ServableContext::SetDeviceId); } +void PyRegWorkerAgent(pybind11::module *m_ptr) { + auto &m = *m_ptr; + py::class_(m, "WorkerAgent_") + .def_static("get_agents_config_from_worker", &PyAgent::GetAgentsConfigsFromWorker) + .def_static("wait_and_clear", &PyAgent::WaitAndClear) + .def_static("stop_and_clear", &PyAgent::StopAndClear) + .def_static("notify_failed", &PyAgent::NotifyFailed) + .def_static("start_agent", &PyAgent::StartAgent); + + py::class_(m, "AgentStartUpConfig_") + .def(py::init<>()) + .def_readwrite("rank_id", &AgentStartUpConfig::rank_id) + .def_readwrite("device_id", &AgentStartUpConfig::device_id) + .def_readwrite("model_file_name", &AgentStartUpConfig::model_file_name) + .def_readwrite("group_file_name", &AgentStartUpConfig::group_file_name) + .def_readwrite("rank_table_json_file_name", &AgentStartUpConfig::rank_table_json_file_name) + .def_readwrite("agent_ip", &AgentStartUpConfig::agent_ip) + .def_readwrite("agent_port", &AgentStartUpConfig::agent_port) + .def_readwrite("worker_ip", &AgentStartUpConfig::worker_ip) + 
.def_readwrite("worker_port", &AgentStartUpConfig::worker_port) + .def_readwrite("common_meta", &AgentStartUpConfig::common_meta); +} + +class PyExitSignalHandle { + public: + static void Start() { ExitSignalHandle::Instance().Start(); } + static bool HasStopped() { return ExitSignalHandle::Instance().HasStopped(); } +}; + // cppcheck-suppress syntaxError PYBIND11_MODULE(_mindspore_serving, m) { PyRegServable(&m); PyRegMaster(&m); PyRegWorker(&m); + PyRegWorkerAgent(&m); + + py::class_(m, "ExitSignalHandle_") + .def_static("start", &PyExitSignalHandle::Start) + .def_static("has_stopped", &PyExitSignalHandle::HasStopped); (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { Server::Instance().Clear(); Worker::GetInstance().Clear(); + WorkerAgent::Instance().Clear(); }}); } diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.cc b/mindspore_serving/ccsrc/python/worker/worker_py.cc index faa0f31..c1b03a5 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.cc +++ b/mindspore_serving/ccsrc/python/worker/worker_py.cc @@ -24,7 +24,7 @@ #include "worker/local_servable/local_sevable.h" #include "worker/distributed_worker/distributed_servable.h" #include "worker/grpc/worker_server.h" -#include "worker/distributed_worker/grpc/distributed_server.h" +#include "worker/distributed_worker/distributed_process/distributed_server.h" namespace mindspore::serving { @@ -43,11 +43,10 @@ void PyWorker::StartServable(const std::string &model_directory, const std::stri } // start grpc server auto grpc_sever = std::make_shared(); - status = grpc_sever->StartWorkerGrpcServer(worker_ip, worker_port); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); status = Worker::GetInstance().StartVersionController(); if (status != SUCCESS) { @@ -76,18 +75,19 @@ void 
PyWorker::StartServableInMaster(const std::string &model_directory, const s void PyWorker::StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port, - const std::string &master_ip, uint32_t master_port) { + const std::string &master_ip, uint32_t master_port, + uint32_t wait_agents_time_in_seconds) { Status status; auto servable = std::make_shared(); - auto grpc_sever = std::make_shared(); - status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); + auto grpc_sever = std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(master_ip, master_port, worker_ip, worker_port); - status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } @@ -103,18 +103,19 @@ void PyWorker::StartDistributedServable(const std::string &servable_directory, c void PyWorker::StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, - const std::string &worker_ip, uint32_t worker_port) { + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds) { Status status; auto servable = std::make_shared(); - auto grpc_sever = std::make_shared(); - status = grpc_sever->StartDistributedWorkerGrpcServer(servable, worker_ip, worker_port); + auto grpc_sever = 
std::make_shared(servable); + status = Worker::GetInstance().StartGrpcServer(grpc_sever, worker_ip, worker_port); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } - Worker::GetInstance().AfterStartGrpcServer(grpc_sever); auto notify_master = std::make_shared(); - status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number); + status = servable->StartServable(servable_directory, servable_name, rank_table_json_file, version_number, + wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_EXCEPTION << "Raise failed: " << status.StatusMessage(); } diff --git a/mindspore_serving/ccsrc/python/worker/worker_py.h b/mindspore_serving/ccsrc/python/worker/worker_py.h index 01a53a8..e6b2c6d 100644 --- a/mindspore_serving/ccsrc/python/worker/worker_py.h +++ b/mindspore_serving/ccsrc/python/worker/worker_py.h @@ -37,11 +37,12 @@ class MS_API PyWorker { static void StartDistributedServable(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, const std::string &worker_ip, uint32_t worker_port, const std::string &master_ip, - uint32_t master_port); + uint32_t master_port, uint32_t wait_agents_time_in_seconds); static void StartDistributedServableInMaster(const std::string &servable_directory, const std::string &servable_name, const std::string &rank_table_json_file, uint32_t version_number, - const std::string &worker_ip, uint32_t worker_port); + const std::string &worker_ip, uint32_t worker_port, + uint32_t wait_agents_time_in_seconds); static int GetBatchSize(); static void WaitAndClear(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc index ff030e6..6e1750a 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc +++ 
b/mindspore_serving/ccsrc/worker/distributed_worker/agent_process/agent_process.cc @@ -22,7 +22,7 @@ namespace serving { grpc::Status MSAgentImpl::Exit(grpc::ServerContext *context, const proto::DistributedExitRequest *request, proto::DistributedExitReply *reply) { MSI_LOG(INFO) << "Distributed Worker Exit"; - WorkerAgent::Instance().Clear(); + WorkerAgent::Instance().StopAgent(false); return grpc::Status::OK; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc index b4f5ee9..8ec9a39 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.cc @@ -14,17 +14,32 @@ * limitations under the License. */ #include "worker/distributed_worker/agent_startup.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" + namespace mindspore { namespace serving { -Status WorkerAgentStartUp::InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, - const std::string &group_file_dir, const std::string &group_file_prefix) { - return Status(); +WorkerAgentStartUp &WorkerAgentStartUp::Instance() { + static WorkerAgentStartUp instance; + return instance; } -Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &agent_ip, uint32_t agent_start_port, - const std::string &worker_ip, uint32_t worker_port) { + +Status WorkerAgentStartUp::GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port) { return Status(); } -Status WorkerAgentStartUp::GetCurrentMachineConfigs(std::vector *configs) { return Status(); } + +Status WorkerAgentStartUp::GetDistributedServableConfig(DistributedServableConfig *config) { + MSI_EXCEPTION_IF_NULL(config); + if (config_.rank_list.empty()) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Rank table config is not ready"; + } + *config = config_; + return SUCCESS; +} + +Status 
WorkerAgentStartUp::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + return GrpcNotifyDistributeWorker::NotifyFailed(worker_ip, worker_port); +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h index 5a7c25e..ad28e5c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/agent_startup.h @@ -27,16 +27,15 @@ namespace serving { class MS_API WorkerAgentStartUp { public: + static WorkerAgentStartUp &Instance(); // from python, worker_agent.py // start_worker_agent // step1, get agents config from worker - Status InitAgentsConfig(const std::string &model_dir, const std::string &model_file_prefix, - const std::string &group_file_dir, const std::string &group_file_prefix); + Status GetAgentsConfigsFromWorker(const std::string &worker_ip, uint32_t worker_port); + // step2, invoke from python + Status GetDistributedServableConfig(DistributedServableConfig *config); - Status GetAgentsConfigsFromWorker(const std::string &rank_start, uint32_t agent_start_port, - const std::string &worker_ip, uint32_t worker_port); - // step2, invoke from python, get current machine agents config - Status GetCurrentMachineConfigs(std::vector *configs); + Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); private: DistributedServableConfig config_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc similarity index 76% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc index 0333434..48d1042 100644 --- 
a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.cc @@ -14,7 +14,8 @@ * limitations under the License. */ -#include "worker/distributed_worker/grpc/distributed_process.h" +#include "worker/distributed_worker/distributed_process/distributed_process.h" +#include "worker/worker.h" #include "common/proto_tensor.h" namespace mindspore { @@ -51,6 +52,20 @@ grpc::Status MSDistributedImpl::AgentExit(grpc::ServerContext *context, const pr MSI_LOG(ERROR) << "Agent Exit FAILED"; } } + if (Worker::GetInstance().IsRunning()) { + Worker::GetInstance().StopServable(); + } + return grpc::Status::OK; +} + +grpc::Status MSDistributedImpl::AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, + proto::AgentFailedReply *reply) { + if (Worker::GetInstance().IsRunning()) { + MSI_LOG_ERROR << "Expect worker should not be running"; + Worker::GetInstance().StopServable(); + } else { + servable_->OnAgentFailed(); + } return grpc::Status::OK; } } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h similarity index 89% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h index b127ac7..147e7c5 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_process.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_process.h @@ -41,6 +41,8 @@ class MSDistributedImpl final : public MSWorkerImpl { proto::AgentRegisterReply *reply) override; grpc::Status AgentExit(grpc::ServerContext *context, const proto::AgentExitRequest *request, proto::AgentExitReply *reply) override; + grpc::Status 
AgentFailed(grpc::ServerContext *context, const proto::AgentFailedRequest *request, + proto::AgentFailedReply *reply) override; private: std::shared_ptr servable_; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc similarity index 73% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc index 79d4064..d9de7cd 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "worker/distributed_worker/grpc/distributed_server.h" +#include "worker/distributed_worker/distributed_process/distributed_server.h" #include #include #include @@ -23,12 +23,11 @@ namespace mindspore { namespace serving { -Status MSDistributedWorkerServer::StartDistributedWorkerGrpcServer(std::shared_ptr servable, - const std::string &hostname, int32_t port) { +Status MSDistributedWorkerServer::StartWorkerGrpcServer(const std::string &hostname, int32_t port) { if (in_running_) { return INFER_STATUS_LOG_ERROR(FAILED) << "Worker grpc server is already running"; } - auto impl = std::make_unique(servable); + auto impl = std::make_unique(servable_); async_server_ = std::make_unique(hostname, port, impl.get()); service_impl_ = std::move(impl); return Init(); diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h similarity index 50% rename from mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h rename to mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h index 
2151a41..ca6b967 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/grpc/distributed_server.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_process/distributed_server.h @@ -28,7 +28,7 @@ #include "common/grpc_async_server.h" #include "worker/grpc/worker_process.h" #include "worker/grpc/worker_server.h" -#include "worker/distributed_worker/grpc/distributed_process.h" +#include "worker/distributed_worker/distributed_process/distributed_process.h" namespace mindspore { namespace serving { @@ -36,18 +36,30 @@ namespace serving { // Service Implement class MS_API MSDistributedWorkerServer : public MSWorkerServer { public: - Status StartDistributedWorkerGrpcServer(std::shared_ptr servable, const std::string &hostname, - int32_t port); + explicit MSDistributedWorkerServer(std::shared_ptr servable) : servable_(servable) {} + ~MSDistributedWorkerServer() = default; + Status StartWorkerGrpcServer(const std::string &hostname, int32_t port) override; + + private: + std::shared_ptr servable_; +}; + +class DistributedServiceContext : public WorkerServiceContext { + public: + DistributedServiceContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : WorkerServiceContext(service_impl, async_service, cq), dist_service_impl_(service_impl) {} + + protected: + MSDistributedImpl *dist_service_impl_ = nullptr; }; // Service Implement -class WorkerAgentRegisterContext : public WorkerServiceContext { +class WorkerAgentRegisterContext : public DistributedServiceContext { public: WorkerAgentRegisterContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerAgentRegisterContext() = default; @@ -60,35 +72,27 @@ class 
WorkerAgentRegisterContext : public WorkerServiceContext { void StartEnqueueRequest() override { state_ = STATE::PROCESS; - async_service_->RequestPredict(&ctx_, &request_, &responder_, cq_, cq_, this); + async_service_->RequestAgentRegister(&ctx_, &request_, &responder_, cq_, cq_, this); } void HandleRequest() override { - EnqueueRequest(service_impl_, async_service_, cq_); + EnqueueRequest(dist_service_impl_, async_service_, cq_); state_ = STATE::FINISH; - grpc::Status status = service_impl_->Predict(&ctx_, &request_, &response_); + grpc::Status status = dist_service_impl_->AgentRegister(&ctx_, &request_, &response_); responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSDistributedImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; - grpc::ServerAsyncResponseWriter responder_; - proto::PredictRequest request_; - proto::PredictReply response_; + grpc::ServerAsyncResponseWriter responder_; + proto::AgentRegisterRequest request_; + proto::AgentRegisterReply response_; }; -class WorkerAgentExitContext : public WorkerServiceContext { +class WorkerAgentExitContext : public DistributedServiceContext { public: WorkerAgentExitContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerAgentExitContext() = default; @@ -101,26 +105,52 @@ class WorkerAgentExitContext : public WorkerServiceContext { void StartEnqueueRequest() override { state_ = STATE::PROCESS; - async_service_->RequestExit(&ctx_, &request_, &responder_, cq_, cq_, this); + async_service_->RequestAgentExit(&ctx_, &request_, &responder_, cq_, cq_, this); } void HandleRequest() override { 
- EnqueueRequest(service_impl_, async_service_, cq_); + EnqueueRequest(dist_service_impl_, async_service_, cq_); state_ = STATE::FINISH; - grpc::Status status = service_impl_->Exit(&ctx_, &request_, &response_); + grpc::Status status = dist_service_impl_->AgentExit(&ctx_, &request_, &response_); responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } + private: + grpc::ServerAsyncResponseWriter responder_; + proto::AgentExitRequest request_; + proto::AgentExitReply response_; +}; + +class WorkerAgentFailedContext : public DistributedServiceContext { + public: + WorkerAgentFailedContext(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : DistributedServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} + + ~WorkerAgentFailedContext() = default; + static Status EnqueueRequest(MSDistributedImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) { + auto call = new WorkerAgentFailedContext(service_impl, async_service, cq); + call->StartEnqueueRequest(); + return SUCCESS; + } + + void StartEnqueueRequest() override { + state_ = STATE::PROCESS; + async_service_->RequestAgentFailed(&ctx_, &request_, &responder_, cq_, cq_, this); + } + + void HandleRequest() override { + EnqueueRequest(dist_service_impl_, async_service_, cq_); + state_ = STATE::FINISH; + grpc::Status status = dist_service_impl_->AgentFailed(&ctx_, &request_, &response_); + responder_.Finish(response_, status, this); + } private: - MSDistributedImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; - grpc::ServerAsyncResponseWriter responder_; - proto::ExitRequest request_; - proto::ExitReply response_; + grpc::ServerAsyncResponseWriter responder_; + proto::AgentFailedRequest request_; + proto::AgentFailedReply response_; }; class 
DistributedWorkerGrpcServer : public WorkerGrpcServer { @@ -134,6 +164,7 @@ class DistributedWorkerGrpcServer : public WorkerGrpcServer { WorkerGrpcServer::EnqueueRequest(); WorkerAgentRegisterContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); WorkerAgentExitContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); + WorkerAgentFailedContext::EnqueueRequest(distributed_service_impl_, &svc_, cq_.get()); return SUCCESS; } diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc index b504355..ea83d1c 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.cc @@ -17,13 +17,15 @@ #include "worker/distributed_worker/distributed_servable.h" #include #include -#include "worker/worker.h" +#include #include "worker/distributed_worker/notify_agent/notify_agent.h" #include "common/exit_handle.h" namespace mindspore { namespace serving { +DistributedServable::~DistributedServable() { Clear(); } + std::string DistributedServable::GetServableName() const { return servable_name_; } uint64_t DistributedServable::GetServableVersion() const { return version_number_; } @@ -60,7 +62,15 @@ Status DistributedServable::GetDistributedServableConfig(DistributedServableConf return SUCCESS; } +void DistributedServable::SetWaitAgentsPromise(bool flag) { + if (!promise_set_flag_.test_and_set()) { + agents_promise_.set_value(flag); + } +} + Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; + if (agent_spec.rank_id < config_.distributed_meta.rank_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "Invalid rank id " << agent_spec.rank_id << ", rank size " << config_.distributed_meta.rank_size; @@ -75,27 +85,24 @@ Status DistributedServable::RegisterAgent(const WorkerAgentSpec &agent_spec) { 
std::shared_ptr notify_agent = std::make_shared(agent_spec.agent_address); context.notify_agent_ = notify_agent; agent_spec_map_[agent_spec.rank_id] = context; - if (config_.distributed_meta.rank_size == agent_spec_map_.size()) { - Status status = Worker::GetInstance().RegisterWorker(); - if (status != SUCCESS) { - Clear(); - return FAILED; - } - } + if (agent_spec_map_.size() >= config_.distributed_meta.rank_size) { - agents_promise_.set_value(); + SetWaitAgentsPromise(true); } return SUCCESS; } void DistributedServable::Clear() { - for (auto agent : agent_spec_map_) { + std::unique_lock lock{mutex_}; + for (auto &agent : agent_spec_map_) { agent.second.notify_agent_->Exit(); } - Worker::GetInstance().StopServable(false); + agent_spec_map_.clear(); + MSI_LOG_INFO << "End Clear servable"; } Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { + std::unique_lock lock{mutex_}; for (auto iter = agent_spec_map_.begin(); iter != agent_spec_map_.end();) { if (agent_spec.rank_id == iter->second.agent_spec_.rank_id) { iter = agent_spec_map_.erase(iter); @@ -103,13 +110,13 @@ Status DistributedServable::UnregisterAgent(const WorkerAgentSpec &agent_spec) { ++iter; } } - // todo: send exit message to agent, and then exit if split with master - Clear(); + SetWaitAgentsPromise(false); return SUCCESS; } Status DistributedServable::StartServable(const std::string &servable_directory, const std::string &servable_name, - const std::string &rank_table_json_file, uint64_t version_number) { + const std::string &rank_table_json_file, uint64_t version_number, + uint64_t wait_agents_time_in_seconds) { if (model_loaded_) { MSI_LOG_EXCEPTION << "Model has loaded"; } @@ -138,7 +145,7 @@ Status DistributedServable::StartServable(const std::string &servable_directory, MSI_LOG_ERROR << "Check rank config failed"; return status; } - status = WaitAgentsReady(); + status = WaitAgentsReady(wait_agents_time_in_seconds); if (status != SUCCESS) { MSI_LOG_ERROR << "Waiting 
for ready of agents failed"; return status; @@ -154,16 +161,23 @@ Status DistributedServable::StartServable(const std::string &servable_directory, Status DistributedServable::InitConfigOnStartup(const std::string &rank_table_json_file) { return FAILED; } -Status DistributedServable::WaitAgentsReady() { +Status DistributedServable::WaitAgentsReady(uint64_t wait_agents_time_in_seconds) { auto future = agents_promise_.get_future(); - const int kWaitMaxHundredMs = 100 * 10; // 100s - int i; + if (wait_agents_time_in_seconds == 0) { + wait_agents_time_in_seconds = UINT32_MAX; + } + const uint64_t kWaitMaxHundredMs = wait_agents_time_in_seconds * 10; + uint64_t i; for (i = 0; i < kWaitMaxHundredMs; i++) { // if (ExitSignalHandle::Instance().HasStopped()) { return INFER_STATUS_LOG_ERROR(FAILED) << "Agents has stopped"; } // waiting for 100ms if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) { + auto flag = future.get(); + if (!flag) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Failed to starting all agents, maybe some error reported"; + } break; } } @@ -264,32 +278,49 @@ Status DistributedServable::CheckRankConfig() { << "Rank size must be an integral multiple of stage size, rank size: " << rank_size << ", stage size: " << stage_size; } - auto parallel_count = rank_size / stage_size; - constexpr size_t card_count_per_machine = 8; - if (rank_size > card_count_per_machine && parallel_count % card_count_per_machine != 0) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Parallel count " << parallel_count << " in one stage must be an integral multiple of card count " - << card_count_per_machine << " in one machine, when rank size is greater than card count in one machine, " - << "rank size: " << rank_size << ", stage size: " << stage_size; - } if (config_.rank_list.size() != rank_size) { return INFER_STATUS_LOG_ERROR(FAILED) << "Rank size " << config_.rank_list.size() << " declared in rank table file not equal to rank size " << rank_size << " 
declared in servable_config, rank json config file: " << rank_table_json_file_; } - for (size_t i = 0; i < rank_size; i++) { - const auto &first_item = config_.rank_list[i]; - for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { - auto rank_id = i + k; - const auto &item = config_.rank_list[rank_id]; - if (k != item.device_id) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + auto parallel_count = rank_size / stage_size; + constexpr size_t card_count_per_machine = 8; + if (stage_size == 1) { + std::map> device_map; + for (size_t i = 0; i < rank_size; i++) { + const auto &item = config_.rank_list[i]; + auto &device_id_list = device_map[item.ip]; + if (device_id_list.count(item.device_id) > 0) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Check rank table config failed, device id repeatedly used by rank " + << i << " in device ip " << item.ip; } - if (first_item.ip != item.ip) { - return INFER_STATUS_LOG_ERROR(FAILED) - << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id - << " to be equal with device ip " << first_item.ip << " of rank " << i; + device_id_list.emplace(item.device_id); + } + } else { + if (rank_size < card_count_per_machine) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Rank size " << rank_size << "must >= card count " << card_count_per_machine + << " of one machine when stage size " << stage_size << " > 1"; + } + if (parallel_count % card_count_per_machine != 0) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Parallel count " << parallel_count << " in one stage must be N * " << card_count_per_machine + << "(card count of one machine), rank size: " << rank_size << ", stage size: " << stage_size; + } + for (size_t i = 0; i < rank_size; i += card_count_per_machine) { + const auto &first_item = config_.rank_list[i]; + for (size_t k = 0; i + k < rank_size && k < card_count_per_machine; k++) { + auto 
rank_id = i + k; + const auto &item = config_.rank_list[rank_id]; + if (k != item.device_id) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device id of rank " << rank_id << " to be " << k; + } + if (first_item.ip != item.ip) { + return INFER_STATUS_LOG_ERROR(FAILED) + << "Check rank table config failed, expected device ip " << item.ip << " of rank " << rank_id + << " to be equal with device ip " << first_item.ip << " of rank " << i; + } } } } @@ -298,5 +329,7 @@ Status DistributedServable::CheckRankConfig() { return SUCCESS; } +void DistributedServable::OnAgentFailed() { SetWaitAgentsPromise(false); } + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h index 642a868..d810209 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/distributed_servable.h @@ -35,9 +35,12 @@ struct DistributedAgentContext { class MS_API DistributedServable : public ServableBase { public: + DistributedServable() = default; + ~DistributedServable(); // from python, worker.py Status StartServable(const std::string &servable_directory, const std::string &servable_name, - const std::string &rank_table_json_file, uint64_t version_number); + const std::string &rank_table_json_file, uint64_t version_number, + uint64_t wait_agents_time_in_seconds); // invoke from agent Status GetDistributedServableConfig(DistributedServableConfig *config) const; @@ -55,7 +58,8 @@ class MS_API DistributedServable : public ServableBase { uint64_t GetBatchSize() const override; std::string GetServableName() const override; uint64_t GetServableVersion() const override; - void Clear(); + void Clear() override; + void OnAgentFailed(); private: DistributedServableConfig config_; @@ -63,19 +67,22 @@ class MS_API DistributedServable : 
public ServableBase { uint64_t version_number_ = 0; bool model_loaded_ = false; + std::mutex mutex_; std::map agent_spec_map_; std::string rank_table_json_file_; std::vector input_infos_; std::vector output_infos_; uint64_t batch_size_ = 0; - std::promise agents_promise_; + std::atomic_flag promise_set_flag_ = ATOMIC_FLAG_INIT; + std::promise agents_promise_; Status InitConfigOnStartup(const std::string &rank_table_json_file); - Status WaitAgentsReady(); + Status WaitAgentsReady(uint64_t wait_agents_time_in_seconds); Status CheckAgentsInfosAndInitTensorInfos(); Status CompareTensorInfos(const std::vector &lefts, const std::vector &rights); Status CheckRankConfig(); + void SetWaitAgentsPromise(bool flag); // agent stubs }; diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc index d9e6b73..379eeff 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.cc @@ -62,7 +62,7 @@ Status GrpcNotifyDistributeWorker::Register(const std::vector & std::this_thread::sleep_for(std::chrono::milliseconds(REGISTER_INTERVAL * 1000)); } if (ExitSignalHandle::Instance().HasStopped()) { - return INFER_STATUS_LOG_WARNING(FAILED) << "Worker exit, stop registration"; + return INFER_STATUS_LOG_WARNING(FAILED) << "Agent exit, stop registration"; } return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Register TimeOut"; } @@ -87,5 +87,21 @@ Status GrpcNotifyDistributeWorker::Unregister() { return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Exit Failed"; } +Status GrpcNotifyDistributeWorker::NotifyFailed(const std::string &worker_ip, uint32_t worker_port) { + auto address = worker_ip + ":" + std::to_string(worker_port); + auto channel = GrpcServer::CreateChannel(address); + auto stub = proto::MSWorker::NewStub(channel); + + 
grpc::ClientContext context; + proto::AgentFailedRequest request; + proto::AgentFailedReply reply; + grpc::Status status = stub->AgentFailed(&context, request, &reply); + if (status.ok()) { + MSI_LOG(INFO) << "Success to notify failure of agent"; + return SUCCESS; + } + return INFER_STATUS_LOG_ERROR(SYSTEM_ERROR) << "Failed to notify failure of agent"; +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h index 2c2724c..da509ff 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h +++ b/mindspore_serving/ccsrc/worker/distributed_worker/notify_distributed/notify_worker.h @@ -19,7 +19,8 @@ #include #include #include -#include "worker/distributed_worker/notify_distributed/base_notify_worker.h" +#include "common/serving_common.h" +#include "worker/distributed_worker/common.h" #include "proto/ms_distributed.pb.h" #include "proto/ms_distributed.grpc.pb.h" #include "proto/ms_worker.pb.h" @@ -27,13 +28,15 @@ namespace mindspore { namespace serving { -class MS_API GrpcNotifyDistributeWorker : public BaseNotifyDistributeWorker { +class MS_API GrpcNotifyDistributeWorker { public: - GrpcNotifyDistributeWorker(const std::string &master_ip, uint32_t master_port, const std::string &host_ip, - uint32_t host_port); - ~GrpcNotifyDistributeWorker() override; - Status Register(const std::vector &worker_specs) override; - Status Unregister() override; + GrpcNotifyDistributeWorker(const std::string &worker_ip, uint32_t worker_port, const std::string &agent_ip, + uint32_t agent_port); + ~GrpcNotifyDistributeWorker(); + Status Register(const std::vector &agent_specs); + Status Unregister(); + // from start up, not agent + static Status NotifyFailed(const std::string &worker_ip, uint32_t worker_port); private: std::string distributed_worker_ip_; diff --git 
a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc index c5e59df..a819b95 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc +++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.cc @@ -14,6 +14,10 @@ * limitations under the License. */ #include "worker/distributed_worker/worker_agent.h" +#include +#include "worker/distributed_worker/agent_process/agent_process.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" +#include "common/exit_handle.h" namespace mindspore { namespace serving { @@ -23,15 +27,16 @@ WorkerAgent &WorkerAgent::Instance() { return instance; } -Status WorkerAgent::LoadModelFromFile(const AgentStartUpConfig &config) { - config_ = config; - return executor_.LoadModelFromFile(config); -} - -Status WorkerAgent::Clear() { return executor_.UnloadModel(); } - -Status WorkerAgent::ExecuteModel(const std::vector &request, std::vector *reply) { - return executor_.ExecuteModel(request, reply); +Status WorkerAgent::Clear() { + if (notify_worker_) { + if (exit_notify_worker_) { + notify_worker_->Unregister(); + } + notify_worker_ = nullptr; + } + grpc_server_.Stop(); + executor_.UnloadModel(); + return SUCCESS; } Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply) { @@ -40,5 +45,59 @@ Status WorkerAgent::Run(const proto::DistributedPredictRequest &request, proto:: return SUCCESS; } +Status WorkerAgent::StartAgent(const AgentStartUpConfig &config) { + Status status; + config_ = config; + status = executor_.LoadModelFromFile(config); + if (status != SUCCESS) { + MSI_LOG_ERROR << "LoadModelFromFile failed, servable name: " << config.common_meta.servable_name + << ", rank_id: " << config.rank_id << ", device id: " << config.device_id + << ", model file: " << config.model_file_name + << ", rank table file: " << config.rank_table_json_file_name + << ", 
group config file: " << config.group_file_name; + return status; + } + status = StartGrpcServer(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Start agent grpc server failed, agent ip: " << config.agent_ip + << ", agent port: " << config.agent_port; + return status; + } + status = RegisterAgent(); + if (status != SUCCESS) { + MSI_LOG_ERROR << "Register agent failed, agent ip: " << config.agent_ip << ", agent port: " << config.agent_port + << ", worker ip: " << config.worker_ip << ", worker port: " << config.worker_port; + return status; + } + MSI_LOG_INFO << "Start agent success, servable name: " << config.common_meta.servable_name + << ", rank_id: " << config.rank_id << ", device id: " << config.device_id + << ", model file: " << config.model_file_name + << ", rank table file: " << config.rank_table_json_file_name + << ", group config file: " << config.group_file_name; + return SUCCESS; +} + +Status WorkerAgent::StartGrpcServer() { + grpc_server_.Start(std::make_shared(), config_.agent_ip, config_.agent_port, gRpcMaxMBMsgSize, "Agent"); + return SUCCESS; +} + +Status WorkerAgent::RegisterAgent() { + notify_worker_ = std::make_shared(config_.worker_ip, config_.agent_port, config_.agent_ip, + config_.agent_port); + WorkerAgentSpec spec; + spec.agent_address = config_.agent_ip + ":" + std::to_string(config_.agent_port); + spec.rank_id = config_.rank_id; + spec.batch_size = executor_.GetBatchSize(); + spec.input_infos = executor_.GetInputInfos(); + spec.output_infos = executor_.GetOutputInfos(); + return notify_worker_->Register({spec}); +} + +void WorkerAgent::StopAgent(bool notify_worker) { + exit_notify_worker_ = notify_worker; + ExitSignalHandle::Instance().Stop(); +} + } // namespace serving } // namespace mindspore diff --git a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h index 520c4db..702e791 100644 --- a/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h 
+++ b/mindspore_serving/ccsrc/worker/distributed_worker/worker_agent.h @@ -17,24 +17,36 @@ #ifndef MINDSPORE_SERVING_WORKER_AGENT_H #define MINDSPORE_SERVING_WORKER_AGENT_H #include +#include #include "worker/distributed_worker/agent_executor.h" #include "proto/ms_agent.pb.h" #include "proto/ms_agent.grpc.pb.h" +#include "common/grpc_server.h" +#include "worker/distributed_worker/common.h" +#include "worker/distributed_worker/notify_distributed/notify_worker.h" namespace mindspore { namespace serving { class MS_API WorkerAgent { public: static WorkerAgent &Instance(); - Status LoadModelFromFile(const AgentStartUpConfig &config); Status Clear(); - Status ExecuteModel(const std::vector &request, std::vector *reply); Status Run(const proto::DistributedPredictRequest &request, proto::DistributedPredictReply *reply); + Status StartAgent(const AgentStartUpConfig &config); + + void StopAgent(bool notify_worker = true); + private: AgentStartUpConfig config_; WorkerAgentExecutor executor_; + GrpcServer grpc_server_; + bool exit_notify_worker_ = true; + std::shared_ptr notify_worker_; + + Status StartGrpcServer(); + Status RegisterAgent(); }; } // namespace serving diff --git a/mindspore_serving/ccsrc/worker/grpc/worker_server.h b/mindspore_serving/ccsrc/worker/grpc/worker_server.h index 1452727..d02d014 100644 --- a/mindspore_serving/ccsrc/worker/grpc/worker_server.h +++ b/mindspore_serving/ccsrc/worker/grpc/worker_server.h @@ -39,7 +39,7 @@ class MS_API MSWorkerServer { MSWorkerServer(); virtual ~MSWorkerServer(); - Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); + virtual Status StartWorkerGrpcServer(const std::string &hostname, int32_t port); Status Stop(); protected: @@ -48,21 +48,32 @@ class MS_API MSWorkerServer { std::unique_ptr service_impl_ = nullptr; std::unique_ptr async_server_ = nullptr; - Status StartAsyncRpcService(); Status Init(); + Status StartAsyncRpcService(); }; class WorkerServiceContext { public: enum class STATE : int8_t { 
CREATE = 1, PROCESS = 2, FINISH = 3 }; + + WorkerServiceContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, + grpc::ServerCompletionQueue *cq) + : service_impl_(service_impl), async_service_(async_service), cq_(cq) { + state_ = STATE::CREATE; + } virtual ~WorkerServiceContext() {} + bool JudgeFinish() { return state_ == STATE::FINISH; } + virtual void StartEnqueueRequest() = 0; virtual void HandleRequest() = 0; - virtual bool JudgeFinish() = 0; + protected: + MSWorkerImpl *service_impl_; + proto::MSWorker::AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + grpc::ServerContext ctx_; - public: STATE state_; }; @@ -70,9 +81,7 @@ class WorkerPredictContext : public WorkerServiceContext { public: WorkerPredictContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerPredictContext() = default; @@ -95,13 +104,7 @@ class WorkerPredictContext : public WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::PredictRequest request_; proto::PredictReply response_; @@ -111,9 +114,7 @@ class WorkerExitContext : public WorkerServiceContext { public: WorkerExitContext(MSWorkerImpl *service_impl, proto::MSWorker::AsyncService *async_service, grpc::ServerCompletionQueue *cq) - : service_impl_(service_impl), async_service_(async_service), cq_(cq), responder_(&ctx_) { - state_ = STATE::CREATE; - } + : WorkerServiceContext(service_impl, async_service, cq), responder_(&ctx_) {} ~WorkerExitContext() 
= default; @@ -136,13 +137,7 @@ class WorkerExitContext : public WorkerServiceContext { responder_.Finish(response_, status, this); } - bool JudgeFinish() override { return state_ == STATE::FINISH; } - private: - MSWorkerImpl *service_impl_; - proto::MSWorker::AsyncService *async_service_; - grpc::ServerCompletionQueue *cq_; - grpc::ServerContext ctx_; grpc::ServerAsyncResponseWriter responder_; proto::ExitRequest request_; proto::ExitReply response_; diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc index 9680929..81452c1 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc @@ -31,7 +31,7 @@ static const char *kVersionStrategySpecific = "specific"; namespace mindspore::serving { -LocalModelServable::~LocalModelServable() { session_.UnloadModel(); } +LocalModelServable::~LocalModelServable() { Clear(); } std::string LocalModelServable::GetServableName() const { return servable_name_; } @@ -248,4 +248,11 @@ Status LocalModelServable::LoadModel(uint64_t version_number) { return SUCCESS; } +void LocalModelServable::Clear() { + if (model_loaded_) { + session_.UnloadModel(); + } + model_loaded_ = false; +} + } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h index d5b9a8c..227c9e9 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h @@ -48,6 +48,7 @@ class MS_API LocalModelServable : public ServableBase { Status InitDevice(ModelType model_type, const std::map &other_options); std::string GetServableName() const override; uint64_t GetServableVersion() const override; + void Clear() override; private: LoadServableSpec base_spec_; diff --git 
a/mindspore_serving/ccsrc/worker/sevable_base.h b/mindspore_serving/ccsrc/worker/sevable_base.h index 8e9e800..c3acd00 100644 --- a/mindspore_serving/ccsrc/worker/sevable_base.h +++ b/mindspore_serving/ccsrc/worker/sevable_base.h @@ -41,6 +41,7 @@ class ServableBase { virtual uint64_t GetBatchSize() const = 0; virtual std::string GetServableName() const = 0; virtual uint64_t GetServableVersion() const = 0; + virtual void Clear() = 0; }; } // namespace mindspore::serving diff --git a/mindspore_serving/ccsrc/worker/worker.cc b/mindspore_serving/ccsrc/worker/worker.cc index 87167ef..10e25e1 100644 --- a/mindspore_serving/ccsrc/worker/worker.cc +++ b/mindspore_serving/ccsrc/worker/worker.cc @@ -174,9 +174,13 @@ void Worker::Update() { */ } -Status Worker::AfterStartGrpcServer(const std::shared_ptr &grpc_server) { +Status Worker::StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port) { + if (worker_grpc_server_ != nullptr) { + return INFER_STATUS_LOG_ERROR(FAILED) << "Worker gRPC server is already running"; + } worker_grpc_server_ = grpc_server; - return SUCCESS; + return worker_grpc_server_->StartWorkerGrpcServer(worker_ip, port); } Status Worker::StartServable(std::shared_ptr servable, std::shared_ptr notify_master) { @@ -248,6 +252,9 @@ void Worker::Clear() { if (exit_notify_master_ && servable_started_) { notify_master_->Unregister(); } + for (auto &worker_item : work_list_) { + worker_item.servable->Clear(); + } work_list_.clear(); py_task_queue_group_.Stop(); @@ -257,7 +264,7 @@ void Worker::Clear() { MSI_LOG_INFO << "End clear worker session"; } -bool Worker::HasCleared() { return !servable_started_; } +bool Worker::IsRunning() { return servable_started_; } Worker::~Worker() { Clear(); } @@ -318,7 +325,7 @@ Status AsyncResult::GetNext(Instance *instance_result) { const int kWaitMaxHundredMs = 100; int i; for (i = 0; i < kWaitMaxHundredMs; i++) { // - if (ExitSignalHandle::Instance().HasStopped() || 
Worker::GetInstance().HasCleared()) { + if (ExitSignalHandle::Instance().HasStopped() || !Worker::GetInstance().IsRunning()) { instance_result->error_msg = Status(SYSTEM_ERROR, "Servable stopped"); return SYSTEM_ERROR; } diff --git a/mindspore_serving/ccsrc/worker/worker.h b/mindspore_serving/ccsrc/worker/worker.h index 50d95eb..ef66043 100644 --- a/mindspore_serving/ccsrc/worker/worker.h +++ b/mindspore_serving/ccsrc/worker/worker.h @@ -75,10 +75,11 @@ class MS_API Worker { const std::vector &inputs); Status StartServable(std::shared_ptr servable, std::shared_ptr notify_master); - Status AfterStartGrpcServer(const std::shared_ptr &grpc_server); + Status StartGrpcServer(const std::shared_ptr &grpc_server, const std::string &worker_ip, + int32_t port); void StopServable(bool notify_master = true); - bool HasCleared(); + bool IsRunning(); Status RegisterWorker(); void Update(); Status StartVersionController(); diff --git a/mindspore_serving/master/_master.py b/mindspore_serving/master/_master.py index 78abb52..0d61459 100644 --- a/mindspore_serving/master/_master.py +++ b/mindspore_serving/master/_master.py @@ -18,6 +18,7 @@ import threading from functools import wraps from mindspore_serving.worker import check_type from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Master_ _wait_and_clear_thread = None @@ -59,6 +60,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/proto/ms_distributed.proto b/mindspore_serving/proto/ms_distributed.proto index fb6c72a..27fa6c4 100644 --- a/mindspore_serving/proto/ms_distributed.proto +++ b/mindspore_serving/proto/ms_distributed.proto @@ -44,3 +44,10 @@ message AgentExitRequest { message AgentExitReply { ErrorMsg error_msg = 1; } + +message 
AgentFailedRequest { +} + +message AgentFailedReply { + ErrorMsg error_msg = 1; +} diff --git a/mindspore_serving/proto/ms_worker.proto b/mindspore_serving/proto/ms_worker.proto index 436b52f..7b2dbe0 100644 --- a/mindspore_serving/proto/ms_worker.proto +++ b/mindspore_serving/proto/ms_worker.proto @@ -29,4 +29,5 @@ service MSWorker { // for worker agent rpc AgentExit(AgentExitRequest) returns (AgentExitReply) {} rpc AgentRegister(AgentRegisterRequest) returns (AgentRegisterReply) {} + rpc AgentFailed(AgentFailedRequest) returns (AgentFailedReply) {} } diff --git a/mindspore_serving/worker/_worker.py b/mindspore_serving/worker/_worker.py index 33ba9fb..b11de95 100644 --- a/mindspore_serving/worker/_worker.py +++ b/mindspore_serving/worker/_worker.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Inferface for start up servable""" +"""Interface for start up servable""" import threading from functools import wraps from mindspore_serving import log as logger +from mindspore_serving._mindspore_serving import ExitSignalHandle_ from mindspore_serving._mindspore_serving import Worker_ from .register.preprocess import preprocess_storage from .register.postprocess import postprocess_storage @@ -77,6 +78,7 @@ def stop_on_except(func): @wraps(func) def handle_except(*args, **kwargs): try: + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message func(*args, **kwargs) except: stop() diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py index 41d8218..8bf27d1 100644 --- a/mindspore_serving/worker/distributed/agent_startup.py +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -13,31 +13,238 @@ # limitations under the License. 
# ============================================================================ """Serving, distributed worker agent startup""" -import inspect +import os +import time +from multiprocessing import Process, Pipe + +from mindspore_serving._mindspore_serving import ExitSignalHandle_ +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ + +from mindspore_serving import log as logger from mindspore_serving.worker import check_type +from mindspore_serving.worker.distributed import worker_agent + + +def _get_local_ip(rank_list, port): + """Get the local ip from the rank table config""" + import socket + ip_list = [] + for item in rank_list: + ip_list.append(item.ip) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + for ip in ip_list: + try: + s.bind((ip, port)) + logger.info(f"Get local machine ip success, ip {ip}") + return ip + # pylint: disable=bare-except + except: + pass + raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") + + +def _update_model_files_path(model_files, group_config_files): + """Check and return model files or group config files""" + script_dir = os.path.dirname(os.path.realpath(__file__)) + logger.info(f"input model files: {model_files}") + logger.info(f"input group config files: {group_config_files}") + model_files_temp = [] + for item in model_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access model file '{file_name}'") + model_files_temp.append(file_name) + + group_files_temp = [] + for item in group_config_files: + file_name = os.path.join(script_dir, item) + if not os.access(file_name, os.R_OK): + raise RuntimeError(f"Cannot access group config file '{file_name}'") + group_files_temp.append(file_name) + + logger.info(f"absolute model files: {model_files_temp}") + logger.info(f"absolute group config files: {group_files_temp}") + return model_files_temp, group_files_temp + + +def 
_make_json_table_file(distributed_config): + """Make rank table json file""" + rank_size = len(distributed_config.rank_list) + runtime_dir = os.path.abspath(".") + time_stamp = str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time()))) + rank_table_file_name = os.path.join(runtime_dir, f"hccl_rank_table_{time_stamp}_{rank_size}.json") + with open(rank_table_file_name, "w") as fp: + fp.write(distributed_config.rank_table_content) + return rank_table_file_name + + +signal_success = "Success" +signal_exit = "Exit" +signal_heartbeat = "HeartBeat" + + +def _recv_parent(index, recv_pipe): + """Receive message from Start up process. + Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit. + Return True on receiving signal_success.""" + try: + while True: + heartbeat_count = 0 + while not recv_pipe.poll(0.1): + if ExitSignalHandle_.has_stopped(): + logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker") + return False + heartbeat_count += 1 + if heartbeat_count >= 30: # 3s + logger.warning(f"Child {index}: Exit on failure of receiving parent message") + return False + parent_signal = recv_pipe.recv() + if parent_signal != signal_heartbeat: + break + if parent_signal == signal_success: + logger.info(f"Child {index}: Receive success") + return True + if parent_signal == signal_exit: + logger.warning(f"Child {index}: Exit on receiving exit message") + else: + logger.warning(f"Child {index}: Exit on receiving unknown message {parent_signal}") + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Child {index}: Exit on exception: {e}") + return False + + +def _agent_process(send_pipe, recv_pipe, index, start_config): + """Agent process""" + try: + # listening success or failed message from parent process + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message + worker_agent.start_worker_agent(start_config=start_config) + send_pipe.send((index, signal_success)) + 
success_msg = _recv_parent(index, recv_pipe) + if not success_msg: + worker_agent.stop() + send_pipe.close() + recv_pipe.close() + # pylint: disable=broad-except + except Exception as e: + logger.error(f"Child {index}: Catch exception and notify exit of others") + send_pipe.send((index, e)) + worker_agent.stop() + raise -def startup_worker_agents(worker_ip, worker_port, - get_model_files_fun, get_group_configs_fun, - rank_start, agent_start_port=7000): - """Start up all needed worker agents on one machine - """ +def _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list): + """Listening child process""" + def send_pipe_msg(send_pipe, msg): + try: + send_pipe.send(msg) + # pylint: disable=broad-except + except Exception as e: + logger.warning(f"Send pipe message exception happen: {e}") + + count = len(send_pipe_list) + for _ in range(count): + while True: + if p_recv_pipe.poll(0.1): + break + for send_pipe, process in zip(send_pipe_list, subprocess_list): + if process.is_alive(): + continue + logger.warning("Fail to start agents because of death of one agent") + for send_pipe_x, process_x in zip(send_pipe_list, subprocess_list): + if process_x.is_alive(): + send_pipe_msg(send_pipe_x, signal_exit) + return False + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_heartbeat) + + _, msg = p_recv_pipe.recv() + if isinstance(msg, Exception): + logger.warning("Fail to start agents because of exception raise by one agent") + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_exit) + return False + + for send_pipe in send_pipe_list: + send_pipe_msg(send_pipe, signal_success) + logger.info("Success to start agents") + return True + + +def _startup_all_agents(common_meta, worker_ip, worker_port, + agent_ip, agent_start_port, device_id_list, rank_id_list, + model_files, group_config_files, rank_table_file): + """Start up all agents in one machine""" + servable_name = common_meta.servable_name + index = 0 + send_pipe_list 
= [] + subprocess_list = [] + c_send_pipe, p_recv_pipe = Pipe() + for device_id, rank_id, model_file, group_file in zip(device_id_list, rank_id_list, model_files, + group_config_files): + p_send_pipe, c_recv_pipe = Pipe() + send_pipe_list.append(p_send_pipe) + + agent_port = agent_start_port + index + + start_config = AgentStartUpConfig_() + start_config.rank_id = rank_id + start_config.device_id = device_id + start_config.model_file_name = model_file + start_config.group_file_name = group_file + start_config.rank_table_json_file_name = rank_table_file + start_config.agent_ip = agent_ip + start_config.agent_port = agent_port + start_config.worker_ip = worker_ip + start_config.worker_port = worker_port + start_config.common_meta = common_meta + + process = Process(target=_agent_process, + args=(c_send_pipe, c_recv_pipe, index, start_config), + name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}") + process.start() + subprocess_list.append(process) + index += 1 + ret = _start_listening_child_processes(p_recv_pipe, send_pipe_list, subprocess_list) + if not ret: + WorkerAgent_.notify_failed(worker_ip, worker_port) + + +def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files, agent_start_port=7000): + """Start up all needed worker agents on one machine""" check_type.check_str("worker_ip", worker_ip) check_type.check_ip_port("worker_port", worker_port) check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7) - if inspect.isfunction(get_model_files_fun): - pass - else: - if not isinstance(get_model_files_fun, [list, tuple]): - raise RuntimeError(f"Check failed, get_model_files_fun first must be function or tuple/list of str, " - f"now is {type(get_model_files_fun)}") - if inspect.isfunction(get_group_configs_fun): - pass - else: - if not isinstance(get_group_configs_fun, [list, tuple]): - raise RuntimeError(f"Check failed, get_group_configs_fun first must be function or tuple/list of str, " - f"now is 
{type(get_group_configs_fun)}") - check_type.check_int("rank_start", rank_start, 0) - if rank_start % 8 != 0: - raise RuntimeError(f"Parameter 'rank_start' must be mulfiply of 8, now is {rank_start}") + model_files = check_type.check_and_as_int_tuple_list("model_files", model_files) + group_config_files = check_type.check_and_as_int_tuple_list("group_config_files", group_config_files) + distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port) + + # get machine ip + rank_list = distributed_config.rank_list + local_ip = _get_local_ip(rank_list, agent_start_port) + # get all device_id and rank_id + local_device_id_list = [] + local_rank_id_list = [] + for rank_id, item in enumerate(rank_list): + if item.ip == local_ip: + local_device_id_list.append(item.device_id) + local_rank_id_list.append(rank_id) + + # handle model files and group config files + if len(local_device_id_list) != len(model_files): + raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to model " + f"files size {len(model_files)}, model files: {model_files}") + + if len(local_device_id_list) != len(group_config_files): + raise RuntimeError(f"Card count {len(local_device_id_list)} described in rank table does not equal to group " + f"config files size {len(group_config_files)}, group config files: {group_config_files}") + + model_files, group_config_files = _update_model_files_path(model_files, group_config_files) + + # make json table file and export env + rank_table_file = _make_json_table_file(distributed_config) + _startup_all_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port, + local_device_id_list, local_rank_id_list, + model_files, group_config_files, rank_table_file) diff --git a/mindspore_serving/worker/distributed/distributed_worker.py b/mindspore_serving/worker/distributed/distributed_worker.py index 5bee6b8..4235ee6 100644 --- a/mindspore_serving/worker/distributed/distributed_worker.py +++
b/mindspore_serving/worker/distributed/distributed_worker.py @@ -13,15 +13,17 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker startup""" -from mindspore_serving.worker._worker import stop_on_except, _load_servable_config -from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear -from mindspore_serving.worker import check_type from mindspore_serving._mindspore_serving import Worker_ +from mindspore_serving.worker import check_type +from mindspore_serving.worker._worker import _start_py_task, _start_wait_and_clear +from mindspore_serving.worker._worker import stop_on_except, _load_servable_config + @stop_on_except def start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number=1, - worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100): + worker_ip="0.0.0.0", worker_port=6200, master_ip="0.0.0.0", master_port=6100, + wait_agents_time_in_seconds=300): r""" Start up the servable named 'servable_name' defined in 'servable_directory', and link the worker to the master through gRPC (master_ip, master_port). @@ -46,6 +48,7 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso master_port (int): The master port the worker linked to. worker_ip (str): The worker ip the master and agents linked to. worker_port (int): The worker port the master and agents linked to. + wait_agents_time_in_seconds (int): The maximum time in seconds the worker waits for all agents to be ready.
Examples: >>> import os @@ -70,15 +73,15 @@ def start_distributed_servable(servable_directory, servable_name, rank_table_jso check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) - _start_wait_and_clear() Worker_.start_distributed_servable(servable_directory, servable_name, rank_table_json_file, version_number, - master_ip, master_port, worker_ip, worker_port) + master_ip, master_port, worker_ip, worker_port, wait_agents_time_in_seconds) _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() @stop_on_except def start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, version_number=1, - worker_ip="0.0.0.0", worker_port=6200): + worker_ip="0.0.0.0", worker_port=6200, wait_agents_time_in_seconds=300): r""" Start up the servable named 'servable_name' defined in 'svable_directory', and the worker will run in the process of the master. @@ -97,6 +100,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank rank_table_json_file (str): The ranke table json file name. worker_ip (str): The worker ip the agents linked to. worker_port (int): The worker port the agents linked to. + wait_agents_time_in_seconds (int): The maximum time in seconds the worker waits for all agents to be ready.
Examples: >>> import os @@ -121,7 +125,7 @@ def start_distributed_servable_in_master(servable_directory, servable_name, rank check_type.check_ip_port('worker_port', worker_port) _load_servable_config(servable_directory, servable_name) - _start_wait_and_clear() Worker_.start_distributed_servable_in_master(servable_directory, servable_name, rank_table_json_file, - version_number, worker_ip, worker_port) + version_number, worker_ip, worker_port, wait_agents_time_in_seconds) _start_py_task(Worker_.get_batch_size()) + _start_wait_and_clear() diff --git a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py index d1ebb99..ad32a53 100644 --- a/mindspore_serving/worker/distributed/worker_agent.py +++ b/mindspore_serving/worker/distributed/worker_agent.py @@ -13,22 +13,54 @@ # limitations under the License. # ============================================================================ """Serving, distributed worker agent""" -from mindspore_serving.worker import check_type +import os +import threading +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ +from mindspore_serving import log as logger -def _start_worker_agent(agent_ip, agent_port, worker_ip, worker_port, - rank_id, device_id, model_file, group_config_file, rank_table_file, - with_bach_dim, without_batch_dim_inputs): + +def start_worker_agent(start_config): """Start up one worker agent on one device id, invoke by agent_startup.startup_worker_agents """ - check_type.check_str("agent_ip", agent_ip) - check_type.check_ip_port("agent_port", agent_port) - check_type.check_str("worker_ip", worker_ip) - check_type.check_ip_port("worker_port", worker_port) - check_type.check_int("rank_id", rank_id, 0) - check_type.check_int("device_id", device_id, 0) - check_type.check_str("model_file", model_file) - check_type.check_str("group_config_file", group_config_file) - check_type.check_str("rank_table_file", rank_table_file) - 
check_type.check_bool("with_bach_dim", with_bach_dim) - check_type.check_and_as_int_tuple_list("without_batch_dim_inputs", without_batch_dim_inputs, 0) + if not isinstance(start_config, AgentStartUpConfig_): + raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") + + os.environ["RANK_ID"] = str(start_config.rank_id) + os.environ["DEVICE_ID"] = str(start_config.device_id) + os.environ["MS_ENABLE_HCCL"] = "1" + os.environ["PARA_GROUP_FILE"] = start_config.group_file_name + os.environ["RANK_TABLE_FILE"] = start_config.rank_table_json_file_name + + for item in ("RANK_ID", "DEVICE_ID", "MS_ENABLE_HCCL", "PARA_GROUP_FILE", "RANK_TABLE_FILE", + "LD_LIBRARY_PATH", "PYTHONPATH"): + logger.info(f"Env {item}: {os.getenv(item, '')}") + WorkerAgent_.start_agent(start_config) + + start_wait_and_clear() + + +_wait_and_clear_thread = None + + +def start_wait_and_clear(): + """Waiting for Ctrl+C, and clear up environment""" + + def thread_func(): + logger.info("Serving worker: wait for Ctrl+C to exit ------------------------------------") + print("Serving worker: wait for Ctrl+C to exit ------------------------------------") + WorkerAgent_.wait_and_clear() + logger.info("Serving worker: exited ------------------------------------") + print("Serving worker: exited ------------------------------------") + + global _wait_and_clear_thread + if not _wait_and_clear_thread: + _wait_and_clear_thread = threading.Thread(target=thread_func) + _wait_and_clear_thread.start() + + +def stop(): + r""" + Stop the running of agent. 
+ """ + WorkerAgent_.stop_and_clear() From 746f531de95253122e6833893d71d620d1b1b520 Mon Sep 17 00:00:00 2001 From: xuyongfei Date: Tue, 2 Feb 2021 21:39:23 +0800 Subject: [PATCH 10/10] Serving, gpt3 merge with master --- mindspore_serving/ccsrc/common/servable.cc | 2 +- mindspore_serving/ccsrc/common/servable.h | 6 +++--- .../ccsrc/worker/local_servable/local_sevable.cc | 4 ---- .../ccsrc/worker/local_servable/local_sevable.h | 1 - 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/mindspore_serving/ccsrc/common/servable.cc b/mindspore_serving/ccsrc/common/servable.cc index 9f90fbf..26f5957 100644 --- a/mindspore_serving/ccsrc/common/servable.cc +++ b/mindspore_serving/ccsrc/common/servable.cc @@ -317,7 +317,7 @@ Status ServableStorage::DeclareServable(ServableMeta servable) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare servable " << common_meta.servable_name << " failed, servable_file cannot be empty"; } - if (servable.local_meta.model_format == api::kUnknownType) { + if (servable.local_meta.model_format == ModelType::kUnknownType) { return INFER_STATUS_LOG_ERROR(FAILED) << "Declare servable " << common_meta.servable_name << " failed, model_format is not inited"; } diff --git a/mindspore_serving/ccsrc/common/servable.h b/mindspore_serving/ccsrc/common/servable.h index 3458748..b64d99b 100644 --- a/mindspore_serving/ccsrc/common/servable.h +++ b/mindspore_serving/ccsrc/common/servable.h @@ -96,9 +96,9 @@ struct CommonServableMeta { }; struct MS_API LocalServableMeta { - std::string servable_file; // file name - ModelType model_format = api::kUnknownType; // OM, MindIR - std::map load_options; // Acl options + std::string servable_file; // file name + ModelType model_format = ModelType::kUnknownType; // OM, MindIR + std::map load_options; // Acl options void SetModelFormat(const std::string &format); }; diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc index 
81452c1..6f73444 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.cc @@ -65,10 +65,6 @@ uint64_t LocalModelServable::GetBatchSize() const { return session_.GetBatchSize(); } -TensorBasePtr LocalModelServable::MakeInferenceTensor(DataType data_type, const std::vector &shape) const { - return std::make_shared(data_type, shape); -} - Status LocalModelServable::StartServable(const std::string &servable_directory, const std::string &servable_name, uint64_t version_number) { if (model_loaded_) { diff --git a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h index 227c9e9..eb43356 100644 --- a/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h +++ b/mindspore_serving/ccsrc/worker/local_servable/local_sevable.h @@ -41,7 +41,6 @@ class MS_API LocalModelServable : public ServableBase { std::vector GetInputInfos() const override; std::vector GetOutputInfos() const override; uint64_t GetBatchSize() const override; - TensorBasePtr MakeInferenceTensor(DataType data_type, const std::vector &shape) const override; Status StartServable(const std::string &servable_directory, const std::string &servable_name, uint64_t version_number);