GitOrigin-RevId: 085fd1dcfd
tags/v1.2.0
| @@ -45,9 +45,15 @@ def launcher(func): | |||||
| while len(ranks) > 0: | while len(ranks) > 0: | ||||
| left = [] | left = [] | ||||
| # check all processes in one second | |||||
| time_to_wait = 1.0 / len(ranks) | |||||
| for rank in ranks: | for rank in ranks: | ||||
| procs[rank].join(1) | |||||
| procs[rank].join(time_to_wait) | |||||
| code = procs[rank].exitcode | code = procs[rank].exitcode | ||||
| # terminate processes if one of them has failed | |||||
| if code != 0 and code != None: | |||||
| for i in ranks: | |||||
| procs[i].terminate() | |||||
| assert ( | assert ( | ||||
| code == 0 or code == None | code == 0 or code == None | ||||
| ), "subprocess {} exit with code {}".format(rank, code) | ), "subprocess {} exit with code {}".format(rank, code) | ||||
| @@ -133,18 +133,22 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | |||||
| pass | pass | ||||
| def start_server(py_server_port, mm_server_port, queue): | |||||
| def _start_server(py_server_port, mm_server_port, queue): | |||||
| """ | """ | ||||
| Start python distributed server and multiple machine server. | Start python distributed server and multiple machine server. | ||||
| :param py_server_port: python server port. | :param py_server_port: python server port. | ||||
| :param mm_server_port: multiple machine server port. | :param mm_server_port: multiple machine server port. | ||||
| :param queue: server port will put in this queue, puts exception when process fails. | |||||
| """ | """ | ||||
| server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||||
| server.register_instance(Methods(mm_server_port)) | |||||
| _, port = server.server_address | |||||
| queue.put(port) | |||||
| server.serve_forever() | |||||
| try: | |||||
| server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||||
| server.register_instance(Methods(mm_server_port)) | |||||
| _, port = server.server_address | |||||
| queue.put(port) | |||||
| server.serve_forever() | |||||
| except Exception as e: | |||||
| queue.put(e) | |||||
| class Server: | class Server: | ||||
| @@ -159,10 +163,14 @@ class Server: | |||||
| self.mm_server_port = create_mm_server("0.0.0.0", 0) | self.mm_server_port = create_mm_server("0.0.0.0", 0) | ||||
| q = Queue() | q = Queue() | ||||
| self.proc = threading.Thread( | self.proc = threading.Thread( | ||||
| target=start_server, args=(port, self.mm_server_port, q), daemon=True, | |||||
| target=_start_server, args=(port, self.mm_server_port, q), daemon=True, | |||||
| ) | ) | ||||
| self.proc.start() | self.proc.start() | ||||
| self.py_server_port = q.get() | |||||
| ret = q.get() | |||||
| if isinstance(ret, Exception): | |||||
| raise ret | |||||
| else: | |||||
| self.py_server_port = ret | |||||
| class Client: | class Client: | ||||
| @@ -159,11 +159,9 @@ def run_test( | |||||
| checkpoint = mge.load(model_path) | checkpoint = mge.load(model_path) | ||||
| data = checkpoint["data"] | data = checkpoint["data"] | ||||
| label = checkpoint["label"] | label = checkpoint["label"] | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| def worker(rank, max_err): | |||||
| dist.init_process_group("localhost", port, p_num, rank, rank) | |||||
| @dist.launcher | |||||
| def worker(max_err): | |||||
| net = MnistNet(has_bn=True) | net = MnistNet(has_bn=True) | ||||
| net.load_state_dict(checkpoint["net_init"]) | net.load_state_dict(checkpoint["net_init"]) | ||||
| lr = checkpoint["sgd_lr"] | lr = checkpoint["sgd_lr"] | ||||
| @@ -194,15 +192,7 @@ def run_test( | |||||
| else: | else: | ||||
| np.testing.assert_allclose(param[1], param_ref[1], atol=max_err) | np.testing.assert_allclose(param[1], param_ref[1], atol=max_err) | ||||
| procs = [] | |||||
| for rank in range(p_num): | |||||
| p = mp.Process(target=worker, args=(rank, max_err,)) | |||||
| p.start() | |||||
| procs.append(p) | |||||
| for p in procs: | |||||
| p.join(20) | |||||
| assert p.exitcode == 0 | |||||
| worker(max_err) | |||||
| @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device") | @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device") | ||||
| @@ -23,6 +23,7 @@ from megengine.core.ops.builtin import Elemwise | |||||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
| from megengine.core.tensor.tensor import Tensor, apply | from megengine.core.tensor.tensor import Tensor, apply | ||||
| from megengine.core.tensor.tensor_wrapper import TensorWrapper | from megengine.core.tensor.tensor_wrapper import TensorWrapper | ||||
| from megengine.distributed.helper import get_device_count_by_fork | |||||
| from megengine.functional.distributed import remote_recv, remote_send | from megengine.functional.distributed import remote_recv, remote_send | ||||
| @@ -53,15 +54,19 @@ def save_to(self, name="grad"): | |||||
| return callback | return callback | ||||
| @pytest.mark.isolated_distributed | |||||
| @pytest.mark.skipif( | |||||
| platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||||
| ) | |||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||
| platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" | platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" | ||||
| ) | ) | ||||
| @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") | |||||
| @pytest.mark.isolated_distributed | |||||
| def test_dist_grad(): | def test_dist_grad(): | ||||
| world_size = 2 | world_size = 2 | ||||
| x_np = np.random.rand(10).astype("float32") | x_np = np.random.rand(10).astype("float32") | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker0(): | def worker0(): | ||||
| dist.init_process_group("localhost", port, world_size, 0, 0) | dist.init_process_group("localhost", port, world_size, 0, 0) | ||||
| @@ -47,8 +47,8 @@ def _assert_q_val(q, val): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_init_process_group(): | def test_init_process_group(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, backend): | def worker(rank, backend): | ||||
| dist.init_process_group("localhost", port, world_size, rank, rank, backend) | dist.init_process_group("localhost", port, world_size, rank, rank, backend) | ||||
| @@ -92,11 +92,10 @@ def test_init_process_group(): | |||||
| def test_new_group(): | def test_new_group(): | ||||
| world_size = 3 | world_size = 3 | ||||
| ranks = [2, 0] | ranks = [2, 0] | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| def worker(rank): | |||||
| dist.init_process_group("localhost", port, world_size, rank, rank) | |||||
| @dist.launcher | |||||
| def worker(): | |||||
| rank = dist.get_rank() | |||||
| if rank in ranks: | if rank in ranks: | ||||
| group = dist.new_group(ranks) | group = dist.new_group(ranks) | ||||
| assert group.size == 2 | assert group.size == 2 | ||||
| @@ -104,15 +103,7 @@ def test_new_group(): | |||||
| assert group.rank == ranks.index(rank) | assert group.rank == ranks.index(rank) | ||||
| assert group.comp_node == "gpu{}:2".format(rank) | assert group.comp_node == "gpu{}:2".format(rank) | ||||
| procs = [] | |||||
| for rank in range(world_size): | |||||
| p = mp.Process(target=worker, args=(rank,)) | |||||
| p.start() | |||||
| procs.append(p) | |||||
| for p in procs: | |||||
| p.join(20) | |||||
| assert p.exitcode == 0 | |||||
| worker() | |||||
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||
| @@ -125,8 +116,8 @@ def test_new_group(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_group_barrier(): | def test_group_barrier(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, q): | def worker(rank, q): | ||||
| dist.init_process_group("localhost", port, world_size, rank, rank) | dist.init_process_group("localhost", port, world_size, rank, rank) | ||||
| @@ -161,8 +152,8 @@ def test_group_barrier(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_synchronized(): | def test_synchronized(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| @dist.synchronized | @dist.synchronized | ||||
| def func(rank, q): | def func(rank, q): | ||||
| @@ -205,26 +196,16 @@ def test_synchronized(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_user_set_get(): | def test_user_set_get(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| def worker(rank): | |||||
| dist.init_process_group("localhost", port, world_size, rank, rank) | |||||
| @dist.launcher | |||||
| def worker(): | |||||
| # set in race condition | # set in race condition | ||||
| dist.get_client().user_set("foo", 1) | dist.get_client().user_set("foo", 1) | ||||
| # get in race condition | # get in race condition | ||||
| ret = dist.get_client().user_get("foo") | ret = dist.get_client().user_get("foo") | ||||
| assert ret == 1 | assert ret == 1 | ||||
| procs = [] | |||||
| for rank in range(world_size): | |||||
| p = mp.Process(target=worker, args=(rank,)) | |||||
| p.start() | |||||
| procs.append(p) | |||||
| for p in procs: | |||||
| p.join(20) | |||||
| assert p.exitcode == 0 | |||||
| worker() | |||||
| def test_oprmm_hashable(): | def test_oprmm_hashable(): | ||||
| @@ -41,8 +41,8 @@ from megengine.functional.distributed import ( | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_reduce_sum(): | def test_reduce_sum(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -83,8 +83,8 @@ def test_reduce_sum(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_broadcast(): | def test_broadcast(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -121,8 +121,8 @@ def test_broadcast(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_all_gather(): | def test_all_gather(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -160,8 +160,8 @@ def test_all_gather(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_reduce_scatter_sum(): | def test_reduce_scatter_sum(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -199,8 +199,8 @@ def test_reduce_scatter_sum(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_all_reduce_sum(): | def test_all_reduce_sum(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -238,8 +238,8 @@ def test_all_reduce_sum(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_all_reduce_max(): | def test_all_reduce_max(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -277,8 +277,8 @@ def test_all_reduce_max(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_all_reduce_min(): | def test_all_reduce_min(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -316,8 +316,8 @@ def test_all_reduce_min(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_gather(): | def test_gather(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -358,8 +358,8 @@ def test_gather(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_scatter(): | def test_scatter(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -396,8 +396,8 @@ def test_scatter(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_all_to_all(): | def test_all_to_all(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| def worker(rank, data, expect, port): | def worker(rank, data, expect, port): | ||||
| if mge.get_device_count("gpu") < world_size: | if mge.get_device_count("gpu") < world_size: | ||||
| @@ -436,8 +436,8 @@ def test_all_to_all(): | |||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_io_remote(): | def test_io_remote(): | ||||
| world_size = 2 | world_size = 2 | ||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | |||||
| val = np.random.rand(4, 5).astype(np.float32) | val = np.random.rand(4, 5).astype(np.float32) | ||||
| def worker(rank): | def worker(rank): | ||||
| @@ -38,7 +38,7 @@ def test_syncbn(): | |||||
| running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) | ||||
| steps = 4 | steps = 4 | ||||
| nr_ranks = 2 | nr_ranks = 2 | ||||
| server = dist.Server(0) | |||||
| server = dist.Server() | |||||
| port = server.py_server_port | port = server.py_server_port | ||||
| def worker(rank, data, yv_expect, running_mean, running_var): | def worker(rank, data, yv_expect, running_mean, running_var): | ||||
| @@ -28,25 +28,16 @@ def test_min_max_observer(): | |||||
| @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") | @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") | ||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| def test_sync_min_max_observer(): | def test_sync_min_max_observer(): | ||||
| x = np.random.rand(6, 3, 3, 3).astype("float32") | |||||
| word_size = get_device_count_by_fork("gpu") | |||||
| x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32") | |||||
| np_min, np_max = x.min(), x.max() | np_min, np_max = x.min(), x.max() | ||||
| world_size = 2 | |||||
| port = dist.get_free_ports(1)[0] | |||||
| server = dist.Server(port) | |||||
| def worker(rank, slc): | |||||
| dist.init_process_group("localhost", port, world_size, rank, rank) | |||||
| @dist.launcher | |||||
| def worker(): | |||||
| rank = dist.get_rank() | |||||
| m = ob.SyncMinMaxObserver() | m = ob.SyncMinMaxObserver() | ||||
| y = mge.tensor(x[slc]) | |||||
| y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) | |||||
| m(y) | m(y) | ||||
| assert m.min_val == np_min and m.max_val == np_max | assert m.min_val == np_min and m.max_val == np_max | ||||
| procs = [] | |||||
| for rank in range(world_size): | |||||
| slc = slice(rank * 3, (rank + 1) * 3) | |||||
| p = mp.Process(target=worker, args=(rank, slc,), daemon=True) | |||||
| p.start() | |||||
| procs.append(p) | |||||
| for p in procs: | |||||
| p.join(20) | |||||
| assert p.exitcode == 0 | |||||
| worker() | |||||