diff --git a/example/add/master_with_worker.py b/example/add/master_with_worker.py index c39f899..9428a47 100644 --- a/example/add/master_with_worker.py +++ b/example/add/master_with_worker.py @@ -15,12 +15,13 @@ """Start Servable add""" import os +import sys from mindspore_serving import master from mindspore_serving import worker def start(): - servable_dir = os.path.abspath(".") + servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) worker.start_servable_in_master(servable_dir, "add", device_id=0) master.start_grpc_server("127.0.0.1", 5500) diff --git a/example/resnet/master_with_worker.py b/example/resnet/master_with_worker.py index 4296528..3f91f9b 100644 --- a/example/resnet/master_with_worker.py +++ b/example/resnet/master_with_worker.py @@ -15,12 +15,13 @@ """Start Servable resnet50""" import os +import sys from mindspore_serving import master from mindspore_serving import worker def start(): - servable_dir = os.path.abspath(".") + servable_dir = os.path.dirname(os.path.realpath(sys.argv[0])) worker.start_servable_in_master(servable_dir, "resnet50", device_id=0) master.start_grpc_server("127.0.0.1", 5500) diff --git a/mindspore_serving/worker/distributed/agent_startup.py b/mindspore_serving/worker/distributed/agent_startup.py index c6a895b..80faa96 100644 --- a/mindspore_serving/worker/distributed/agent_startup.py +++ b/mindspore_serving/worker/distributed/agent_startup.py @@ -20,6 +20,7 @@ import sys import traceback import signal from multiprocessing import Process, Pipe +import threading import psutil from mindspore_serving._mindspore_serving import ExitSignalHandle_ @@ -49,6 +50,22 @@ def _get_local_ip(rank_list, port): raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}") +def _check_local_ip(agent_ip, port): + """Check the local ip""" + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + for i in range(8): + try: + 
s.bind((agent_ip, port + i)) + logger.info(f"Check local machine ip success, ip {agent_ip}") + return True + # pylint: disable=bare-except + except: + pass + return False + + def _update_model_files_path(model_files, group_config_files): """Check and return model files or group config files""" script_dir = os.path.dirname(os.path.realpath(sys.argv[0])) @@ -126,7 +143,6 @@ def _agent_process(send_pipe, recv_pipe, index, start_config): parent_process = psutil.Process(os.getppid()) try: # listening success or failed message from parent process - ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message worker_agent.start_worker_agent(start_config=start_config) send_pipe.send((index, signal_success)) success_msg = _recv_parent(parent_process, index, recv_pipe) @@ -232,6 +248,10 @@ def _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list) while True: if p_recv_pipe.poll(0.1): break + if ExitSignalHandle_.has_stopped(): + logger.warning("Fail to start agents because of Ctrl+C") + _send_exit_msg_to_children(send_pipe_list, subprocess_list) + return False for send_pipe, process in zip(send_pipe_list, subprocess_list): if process.is_alive(): continue @@ -251,14 +271,24 @@ def _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list) return True -def _listening_agents_after_startup(subprocess_list): +def _listening_agents_after_startup(subprocess_list, worker_ip, worker_port, agent_ip): """Listening agent status after success start up of agents""" - while not ExitSignalHandle_.has_stopped(): - for index, process in enumerate(subprocess_list): - if not process.is_alive(): - logger.warning(f"Child {index}, pid={process.pid} has exited") - return - time.sleep(0.1) + + def wait_child_exit(): + while not ExitSignalHandle_.has_stopped(): + for index, process in enumerate(subprocess_list): + if not process.is_alive(): + logger.warning(f"Child {index}, pid={process.pid} has exited") + return + time.sleep(0.1) + + def 
listening_thread_fun():
+        wait_child_exit()
+        WorkerAgent_.startup_notify_exit(worker_ip, worker_port, agent_ip)
+        _send_exit_signal_to_children(subprocess_list)
+
+    thread = threading.Thread(target=listening_thread_fun)
+    thread.start()
 
 
 def _startup_agents(common_meta, worker_ip, worker_port,
@@ -311,12 +341,11 @@ def _startup_agents(common_meta, worker_ip, worker_port,
     logger.info(f"Success to start agents, {msg}")
     print(f"Success to start agents, {msg}")
 
-    _listening_agents_after_startup(subprocess_list)
-    WorkerAgent_.startup_notify_exit(worker_ip, worker_port, agent_ip)
-    _send_exit_signal_to_children(subprocess_list)
+    _listening_agents_after_startup(subprocess_list, worker_ip, worker_port, agent_ip)
 
 
-def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files=None, agent_start_port=7000):
+def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files=None,
+                          agent_start_port=7000, agent_ip=None, rank_start=None):
     r"""
     Start up all needed worker agenton current machine.
 
@@ -334,6 +363,13 @@ def startup_worker_agents(worker_ip, worker_port, model_files, group_config_file
             this startup python script.
         group_config_files (None, list or tuple of str): All group config files need in current machine, absolute path
             or path relative to this startup python script, default None, which means there are no configuration files.
+        agent_start_port (int): The starting agent port of the agents linked to the worker.
+        agent_ip (str or None): The local agent ip, if it's None, the agent ip will be obtained from rank table file.
+            Default None. Parameter agent_ip and parameter rank_start must have values at the same time,
+            or both None at the same time.
+        rank_start (int or None): The starting rank id of this machine, if it's None, the rank id will be obtained from
+            rank table file. Default None. Parameter agent_ip and parameter rank_start must have values at the same
+            time, or both None at the same time.
Examples: >>> import os @@ -355,14 +391,34 @@ def startup_worker_agents(worker_ip, worker_port, model_files, group_config_file # get machine ip rank_list = distributed_config.rank_list - local_ip = _get_local_ip(rank_list, agent_start_port) - # get all device_id and rank_id local_device_id_list = [] local_rank_id_list = [] - for rank_id, item in enumerate(rank_list): - if item.ip == local_ip: - local_device_id_list.append(item.device_id) - local_rank_id_list.append(rank_id) + if agent_ip is None: + if rank_start is not None: + raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must have values at the same time, " + "or both None at the same time.") + local_ip = _get_local_ip(rank_list, agent_start_port) + # get all device_id and rank_id + for rank_id, item in enumerate(rank_list): + if item.ip == local_ip: + local_device_id_list.append(item.device_id) + local_rank_id_list.append(rank_id) + else: + if rank_start is None: + raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must have values at the same time, " + "or both None at the same time.") + check_type.check_str("agent_ip", agent_ip) + check_type.check_int("rank_start", rank_start, 0) + if rank_start >= len(rank_list): + raise RuntimeError(f"Parameter 'rank_start' cannot equal or larger than rank size {len(rank_list)}.") + if not _check_local_ip(agent_ip, agent_start_port): + raise RuntimeError(f"Check ip 'agent_ip' valid failed, agent_ip: {agent_ip}") + local_ip = agent_ip + rank_table_ip = rank_list[rank_start].ip + for rank_id, item in enumerate(rank_list): + if item.ip == rank_table_ip: + local_device_id_list.append(item.device_id) + local_rank_id_list.append(rank_id) # handle model files and group config files if len(local_device_id_list) != len(model_files): diff --git a/mindspore_serving/worker/distributed/worker_agent.py b/mindspore_serving/worker/distributed/worker_agent.py index 3043e9f..37a3f37 100644 --- a/mindspore_serving/worker/distributed/worker_agent.py +++ 
b/mindspore_serving/worker/distributed/worker_agent.py @@ -16,9 +16,11 @@ import os import threading -from mindspore_serving.worker import init_mindspore -from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_ + +from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_, ExitSignalHandle_ + from mindspore_serving import log as logger +from mindspore_serving.worker import init_mindspore def start_worker_agent(start_config): @@ -26,6 +28,15 @@ def start_worker_agent(start_config): """ if not isinstance(start_config, AgentStartUpConfig_): raise RuntimeError("Parameter 'start_config' should be instance of AgentStartUpConfig_") + logger.info(f"rank_id={start_config.rank_id}, device_id={start_config.device_id}, " + f"model_file='{start_config.model_file_name}', group_file='{start_config.group_file_name}', " + f"rank_table_file='{start_config.rank_table_json_file_name}'," + f"agent_ip='{start_config.agent_ip}', agent_port={start_config.agent_port}, " + f"worker_ip='{start_config.worker_ip}', worker_port={start_config.worker_port}," + f"with_batch_dim={start_config.common_meta.with_batch_dim}, " + f"without_batch_dim_inputs={start_config.common_meta.without_batch_dim_inputs}") + + ExitSignalHandle_.start() # Set flag to running and receive Ctrl+C message init_mindspore.init_mindspore_cxx_env() os.environ["RANK_ID"] = str(start_config.rank_id) diff --git a/tests/ut/python/tests/common.py b/tests/ut/python/tests/common.py index e78c70c..a9f1b09 100644 --- a/tests/ut/python/tests/common.py +++ b/tests/ut/python/tests/common.py @@ -72,7 +72,55 @@ class ServingTestBase: with open(config_file, "w") as fp: fp.write(servable_config_content) + def init_distributed_servable(self, servable_config_content, rank_size, rank_table_content): + global servable_index + self.servable_name = "add_" + str(servable_index) + servable_index += 1 + self.version_number = 1 + self.servable_name_path = os.path.join(self.servable_dir, 
self.servable_name) + self.model_dir = os.path.join(self.servable_dir, "model_"+self.servable_name) + self.rank_table_content_path = os.path.join(self.servable_dir, self.servable_name + "_hccl.json") + try: + os.mkdir(self.servable_dir) + except FileExistsError: + pass + try: + os.mkdir(self.servable_name_path) + except FileExistsError: + pass + try: + os.mkdir(self.model_dir) + except FileExistsError: + pass + self.model_file_list = [] + for i in range(rank_size): + model_file_path = os.path.join(self.model_dir, f"model{i}.mindir") + self.model_file_list.append(model_file_path) + with open(model_file_path, "w") as fp: + print("model content", file=fp) + self.group_config_list = [] + for i in range(rank_size): + group_config = os.path.join(self.model_dir, f"group{i}.pb") + self.group_config_list.append(group_config) + with open(group_config, "w") as fp: + print("group config content", file=fp) + + if servable_config_content is not None: + config_file = os.path.join(self.servable_name_path, "servable_config.py") + with open(config_file, "w") as fp: + fp.write(servable_config_content) + + if rank_table_content is not None: + with open(self.rank_table_content_path, "w") as fp: + fp.write(rank_table_content) + + @staticmethod + def add_on_exit(fun): + global exit_fun_list + exit_fun_list.append(fun) + +exit_fun_list = [] client_create_list = [] @@ -91,6 +139,10 @@ def serving_test(func): del client.stub client.stub = None client_create_list = [] + global exit_fun_list + for fun in exit_fun_list: + fun() + exit_fun_list = [] return wrap_test diff --git a/tests/ut/python/tests/test_distributed_worker.py b/tests/ut/python/tests/test_distributed_worker.py new file mode 100644 index 0000000..5f965c7 --- /dev/null +++ b/tests/ut/python/tests/test_distributed_worker.py @@ -0,0 +1,291 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test distributed worker""" + +import os +import signal +import time +from multiprocessing import Process, Pipe +import numpy as np +import psutil + +from common import serving_test, create_client, ServingTestBase +from mindspore_serving.worker import distributed +from mindspore_serving import master + +distributed_import = r""" +import numpy as np +from mindspore_serving.worker import distributed +from mindspore_serving.worker import register +""" + +distributed_declare_servable = r""" +distributed.declare_distributed_servable(rank_size=8, stage_size=1, with_batch_dim=False) +""" + +rank_table_content = r""" +{ + "version": "1.0", "server_count": "1", + "server_list": [ + { + "server_id": "127.0.0.1", + "device": [ + { "device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0" }, + { "device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1" }, + { "device_id": "2", "device_ip": "192.3.27.6", "rank_id": "2" }, + { "device_id": "3", "device_ip": "192.4.27.6", "rank_id": "3" }, + { "device_id": "4", "device_ip": "192.1.27.7", "rank_id": "4" }, + { "device_id": "5", "device_ip": "192.2.27.7", "rank_id": "5" }, + { "device_id": "6", "device_ip": "192.3.27.7", "rank_id": "6" }, + { "device_id": "7", "device_ip": "192.4.27.7", "rank_id": "7" } + ], + "host_nic_ip": "reserve" + } + ], + "status": "completed" +} +""" + + +def init_distributed_servable(): + base = ServingTestBase() + servable_content = distributed_import + servable_content += distributed_declare_servable + 
servable_content += r""" +@register.register_method(output_names=["y"]) +def predict(x1, x2): + y = register.call_servable(x1, x2) + return y +""" + base.init_distributed_servable(servable_content, 8, rank_table_content) + return base + + +def start_distributed_grpc_server(): + base = init_distributed_servable() + return base + + +def start_distributed_worker(base): + def worker_process(): + distributed.start_distributed_servable_in_master(base.servable_dir, base.servable_name, + rank_table_json_file=base.rank_table_content_path, + worker_ip="127.0.0.1", worker_port=6200) + master.start_grpc_server("0.0.0.0", 5500) + + worker = Process(target=worker_process) + worker.start() + time.sleep(0.5) # wait parse rank table ready + assert worker.is_alive() + return worker + + +def start_agents(model_file_list, group_config_list): + send_pipe, recv_pipe = Pipe() + + def agent_process(send_pipe): + distributed.startup_worker_agents(worker_ip="127.0.0.1", worker_port=6200, model_files=model_file_list, + group_config_files=group_config_list) + send_pipe.send("Success") + send_pipe.close() + + agent = Process(target=agent_process, args=(send_pipe,)) + agent.start() + index = 0 + while index < 30 and agent.is_alive(): # wait max 3 s + index += 1 + if recv_pipe.poll(0.1): + break + assert index < 30 + assert agent.is_alive() + return agent + + +def send_exit(process): + if not process.is_alive(): + return + parent_process = psutil.Process(process.pid) + child_processes = parent_process.children(recursive=True) + + def children_alive(): + return any([item.is_running() for item in child_processes]) + os.kill(process.pid, signal.SIGINT) + for _ in range(50): # 50*0.1s + if not process.is_alive() and not children_alive(): + break + time.sleep(0.1) + for item in child_processes: + if item.is_running(): + os.kill(item.pid, signal.SIGKILL) + if process.is_alive(): + os.kill(process.pid, signal.SIGKILL) + + +@serving_test +def test_distributed_worker_worker_exit_success(): + base = 
start_distributed_grpc_server() + worker_process = start_distributed_worker(base) + base.add_on_exit(lambda: send_exit(worker_process)) + agent_process = start_agents(base.model_file_list, base.group_config_list) + base.add_on_exit(lambda: send_exit(agent_process)) + + client = create_client("localhost", 5500, base.servable_name, "predict") + instances = [{}, {}, {}] + y_data_list = [] + for index, instance in enumerate(instances): + instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1) + instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1) + y_data_list.append((instance["x1"] + instance["x2"]).tolist()) + + result = client.infer(instances) + print(result) + assert len(result) == 3 + assert result[0]["y"].dtype == np.float32 + assert result[1]["y"].dtype == np.float32 + assert result[2]["y"].dtype == np.float32 + assert result[0]["y"].tolist() == y_data_list[0] + assert result[1]["y"].tolist() == y_data_list[1] + assert result[2]["y"].tolist() == y_data_list[2] + + # send SIGINT to worker, expect worker and all agents exit + agents = psutil.Process(agent_process.pid).children() + + def agents_alive(): + return any([item.is_running() for item in agents]) + os.kill(worker_process.pid, signal.SIGINT) + for _ in range(50): # 50*0.1s + if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive(): + break + time.sleep(0.1) + assert not worker_process.is_alive() + assert not agent_process.is_alive() + assert not agents_alive() + + +@serving_test +def test_distributed_worker_agent_exit_success(): + base = start_distributed_grpc_server() + worker_process = start_distributed_worker(base) + base.add_on_exit(lambda: send_exit(worker_process)) + agent_process = start_agents(base.model_file_list, base.group_config_list) + base.add_on_exit(lambda: send_exit(agent_process)) + + client = create_client("localhost", 5500, base.servable_name, "predict") + instances = [{}, {}, {}] + y_data_list = [] + for 
index, instance in enumerate(instances): + instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1) + instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1) + y_data_list.append((instance["x1"] + instance["x2"]).tolist()) + + result = client.infer(instances) + print(result) + assert len(result) == 3 + assert result[0]["y"].tolist() == y_data_list[0] + assert result[1]["y"].tolist() == y_data_list[1] + assert result[2]["y"].tolist() == y_data_list[2] + + # send SIGINT to worker, expect worker and all agents exit + agents = psutil.Process(agent_process.pid).children() + + def agents_alive(): + return any([item.is_running() for item in agents]) + os.kill(agent_process.pid, signal.SIGINT) + for _ in range(50): # 50*0.1s + if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive(): + break + time.sleep(0.1) + assert not worker_process.is_alive() + assert not agent_process.is_alive() + assert not agents_alive() + + +@serving_test +def test_distributed_worker_agent_startup_killed_exit_success(): + base = start_distributed_grpc_server() + worker_process = start_distributed_worker(base) + base.add_on_exit(lambda: send_exit(worker_process)) + agent_process = start_agents(base.model_file_list, base.group_config_list) + base.add_on_exit(lambda: send_exit(agent_process)) + + client = create_client("localhost", 5500, base.servable_name, "predict") + instances = [{}, {}, {}] + y_data_list = [] + for index, instance in enumerate(instances): + instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1) + instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1) + y_data_list.append((instance["x1"] + instance["x2"]).tolist()) + + result = client.infer(instances) + print(result) + assert len(result) == 3 + assert result[0]["y"].tolist() == y_data_list[0] + assert result[1]["y"].tolist() == y_data_list[1] + assert result[2]["y"].tolist() == y_data_list[2] + + # send SIGINT 
to worker, expect worker and all agents exit + agents = psutil.Process(agent_process.pid).children() + + def agents_alive(): + return any([item.is_running() for item in agents]) + os.kill(agent_process.pid, signal.SIGKILL) # kill msg + for _ in range(50): # 50*0.1s + # test agent_process.is_alive() first, it will make agents(children) notify exit of their parent + if not agent_process.is_alive() and not worker_process.is_alive() and not agents_alive(): + break + time.sleep(0.1) + assert not worker_process.is_alive() + assert not agent_process.is_alive() + assert not agents_alive() + + +@serving_test +def test_distributed_worker_agent_killed_exit_success(): + base = start_distributed_grpc_server() + worker_process = start_distributed_worker(base) + base.add_on_exit(lambda: send_exit(worker_process)) + agent_process = start_agents(base.model_file_list, base.group_config_list) + base.add_on_exit(lambda: send_exit(agent_process)) + + client = create_client("localhost", 5500, base.servable_name, "predict") + instances = [{}, {}, {}] + y_data_list = [] + for index, instance in enumerate(instances): + instance["x1"] = np.array([[1.1, 1.2], [2.2, 2.3]], np.float32) * (index + 1) + instance["x2"] = np.array([[3.3, 3.4], [4.4, 4.5]], np.float32) * (index + 1) + y_data_list.append((instance["x1"] + instance["x2"]).tolist()) + + result = client.infer(instances) + print(result) + assert len(result) == 3 + assert result[0]["y"].tolist() == y_data_list[0] + assert result[1]["y"].tolist() == y_data_list[1] + assert result[2]["y"].tolist() == y_data_list[2] + + # send SIGINT to worker, expect worker and all agents exit + agents = psutil.Process(agent_process.pid).children() + assert agents + + def agents_alive(): + return any([item.is_running() for item in agents]) + os.kill(agents[0].pid, signal.SIGKILL) # kill msg + for _ in range(50): # 50*0.1s + if not worker_process.is_alive() and not agent_process.is_alive() and not agents_alive(): + break + time.sleep(0.1) + + assert not 
worker_process.is_alive() + assert not agent_process.is_alive() + assert not agents_alive()