GitOrigin-RevId: 90be2d5b4d
tags/v1.5.0
@@ -12,6 +12,7 @@ from typing import Optional
 from .core._imperative_rt.common import CompNode, DeviceType
 from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config
+from .core._imperative_rt.common import what_is_xpu as _what_is_xpu

 __all__ = [
     "is_cuda_available",
@@ -25,7 +26,7 @@ __all__ = [
 def _valid_device(inp):
-    if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp):
+    if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
     return False
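For reference, the widened pattern now accepts ROCm device strings alongside the existing cpu/gpu/xpu forms; a small self-contained sketch of what validates (device names are illustrative):

# Sketch of the widened _valid_device pattern; everything below except "tpu0" matches.
import re

pattern = r"^([cxg]pu|rocm)(\d+|\d+:\d+|x)$"
for name in ["gpu0", "cpu1", "xpux", "rocm0", "rocm0:1", "rocmx", "tpu0"]:
    print(name, bool(re.match(pattern, name)))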
@@ -40,21 +41,24 @@ def _str2device_type(type_str: str, allow_unspec: bool = True):
         return DeviceType.CAMBRICON
     elif type_str == "ATLAS":
         return DeviceType.ATLAS
+    elif type_str == "ROCM" or type_str == "AMDGPU":
+        return DeviceType.ROCM
     else:
         assert allow_unspec and str == "XPU", "device type can only be cpu, gpu or xpu"
         return DeviceType.UNSPEC


+_device_type_set = {"cpu", "gpu", "xpu", "rocm"}


 def get_device_count(device_type: str) -> int:
     """
     Gets number of devices installed on this system.
     :param device_type: device type, one of 'gpu' or 'cpu'
     """
-    device_type_set = ("cpu", "gpu")
-    assert device_type in device_type_set, "device must be one of {}".format(
-        device_type_set
+    assert device_type in _device_type_set, "device must be one of {}".format(
+        _device_type_set
     )
     device_type = _str2device_type(device_type)
     return CompNode._get_device_count(device_type, False)
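With the shared _device_type_set, the count helper can now be queried for ROCm devices as well; a hedged usage sketch (counts depend on the host and on whether MegEngine was built with ROCm support):

from megengine.device import get_device_count

# Returns 0 on machines without the corresponding devices.
print(get_device_count("gpu"))
print(get_device_count("rocm"))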
@@ -87,6 +91,14 @@ def is_atlas_available() -> bool:
     return CompNode._get_device_count(t, False) > 0


+def is_rocm_available() -> bool:
+    """Returns whether rocm device is available on this system.
+    """
+    t = _str2device_type("rocm")
+    return CompNode._get_device_count(t, False) > 0


 def set_default_device(device: str = "xpux"):
     r"""
     Sets default computing node.
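A short usage sketch for the new availability check, mirroring how the existing is_cuda_available helper is typically used; the chosen fallback device is illustrative:

from megengine.device import is_rocm_available, set_default_device

# Illustrative: prefer an AMD GPU when the ROCm backend is present, otherwise stay on CPU.
if is_rocm_available():
    set_default_device("rocm0")
else:
    set_default_device("cpu0")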
@@ -151,3 +163,7 @@ def set_prealloc_config(
     assert max_overhead >= 0
     assert growth_factor >= 1
     _set_prealloc_config(alignment, min_req, max_overhead, growth_factor, device_type)
+
+
+def what_is_xpu():
+    return _what_is_xpu().name.lower()
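The new what_is_xpu helper reports which physical device type the logical "xpu" resolves to on the current machine; a hedged sketch (the exact string comes from DeviceType.name.lower(), so the values below are illustrative):

from megengine.device import what_is_xpu

# Expected to print e.g. "cuda" on an NVIDIA host, "rocm" on an AMD host, "cpu" otherwise.
print(what_is_xpu())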
@@ -8,7 +8,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import List, Optional, Tuple

-from ..device import set_default_device
+from ..device import set_default_device, what_is_xpu
 from .server import Client, Server
@@ -23,6 +23,7 @@ class StaticData:
     device = None
     backend = None
     next_stream = None
+    device_type = None


 _sd = None
@@ -78,19 +79,29 @@ class Group:
     @property
     def comp_node(self):
         assert len(self.proc_ranks) > 0, "invalid group"
-        return "gpu{}:{}".format(_sd.device, self.stream)
+        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)


 WORLD = Group([])

+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+_backends = {"nccl", "rccl", "ucx"}
+

 def init_process_group(
     master_ip: str,
     port: int,
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = "nccl",
+    backend: Optional[str] = None,
+    device_type: str = "xpu",
 ) -> None:
     """
     Initialize the distributed process group and specify the device used in the current process
@@ -102,6 +113,8 @@ def init_process_group(
     :param device: the GPU device id to bind this process to.
     :param backend: communicator backend, currently support 'nccl' and 'ucx'.
     """
+    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
+    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
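When no backend is passed, it is now derived from the resolved physical device type; a minimal self-contained sketch of that lookup, with a hypothetical pick_backend helper standing in for the logic above (not part of the MegEngine API):

_device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}

def pick_backend(device_type, backend=None, physical="cuda"):
    # "physical" stands in for what_is_xpu() on this machine (assumption for the sketch).
    resolved = physical if device_type == "xpu" else device_type
    return _device2backend[resolved] if backend is None else backend

print(pick_backend("xpu"))                    # "nccl" on a CUDA machine
print(pick_backend("xpu", physical="rocm"))   # "rccl" on a ROCm machine
print(pick_backend("gpu", backend="ucx"))     # an explicit backend always wins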
@@ -112,8 +125,14 @@ def init_process_group(
         raise TypeError("Expect type int but got {}".format(type(rank)))
     if not isinstance(device, int):
         raise TypeError("Expect type int but got {}".format(type(backend)))
-    if not isinstance(backend, str):
-        raise TypeError("Expect type str but got {}".format(type(backend)))
+    if backend not in _backends:
+        raise ValueError(
+            "backend should be one of {} but got {}".format(_backends, backend)
+        )
+    if physical_device_type not in _device2backend:
+        raise ValueError(
+            "{} is not a valid distributed device type".format(device_type)
+        )

     global _sd
     assert _sd is None, "init_process_group should be called only once"
@@ -132,10 +151,11 @@ def init_process_group(
     _sd.device = device
     _sd.backend = backend
     _sd.next_stream = 1
+    _sd.device_type = device_type

     WORLD.reset(list(range(world_size)))

-    set_default_device("gpu{}".format(device))
+    set_default_device("{}{}".format(device_type, device))


 def is_distributed() -> bool:
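Taken together, a process can now be bound to a non-CUDA device by passing device_type; a hedged end-to-end sketch (addresses and sizes are placeholders, and this call is normally issued by the launcher rather than by hand):

import megengine.distributed as dist

# On a ROCm machine the backend defaults to "rccl" and rank 0's default device
# becomes "rocm0"; the call blocks until all ranks have connected to the server.
dist.init_process_group(
    master_ip="127.0.0.1",
    port=23456,
    world_size=2,
    rank=0,
    device=0,
    device_type="rocm",
)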
@@ -182,7 +202,7 @@ def new_group(proc_ranks: List[int]) -> Group:
     return Group(proc_ranks)


-def group_barrier(group: Optional[Group] = WORLD) -> None:
+def group_barrier(group: Group = WORLD) -> None:
     """Block until all ranks in the group reach this barrier."""
     # if running with single node, skip it
     if _sd is None:
@@ -29,13 +29,19 @@ def _run_wrapped(
     world_size,
     rank,
     dev,
+    device_type,
     args,
     kwargs,
     queue: mp.Queue,
 ):
     """Init distributed process group and run wrapped function."""
     init_process_group(
-        master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
+        master_ip=master_ip,
+        port=port,
+        world_size=world_size,
+        rank=rank,
+        device=dev,
+        device_type=device_type,
     )
     if is_multimachine:
         group_barrier()
@@ -70,13 +76,17 @@ class launcher:
         rank_start=0,
         master_ip="localhost",
         port=0,
+        device_type="xpu",
     ):
         self.func = func
-        self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
+        self.n_gpus = (
+            n_gpus if n_gpus is not None else get_device_count_by_fork(device_type)
+        )
         self.world_size = world_size if world_size is not None else self.n_gpus
         self.rank_start = rank_start
         self.master_ip = master_ip
         self.port = port
+        self.device_type = device_type
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -99,6 +109,7 @@ class launcher:
                     self.world_size,
                     dev + self.rank_start,
                     dev,
+                    self.device_type,
                     args,
                     kwargs,
                     queue,
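The launcher forwards the new device_type down to every worker process; a hedged usage sketch built on the __init__ signature above (the worker body is illustrative):

import megengine.distributed as dist

def worker():
    print(dist.get_rank(), dist.get_world_size())

# Illustrative: bind each spawned process to a ROCm device; omitting device_type
# keeps the default "xpu", which resolves to whatever the machine actually has.
train = dist.launcher(worker, device_type="rocm")
# train()  # spawns one process per detected device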
@@ -62,8 +62,8 @@ void init_common(py::module m) {
                 return cn.get_mem_status_bytes();
             })
             .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
-            .def("_set_default_device", &set_default_device)
-            .def("_get_default_device", &get_default_device)
+            .def_static("_set_default_device", &set_default_device)
+            .def_static("_get_default_device", &get_default_device)
             .def("__str__", &CompNode::to_string_logical)
             .def("__repr__", [](const CompNode& cn) {
                 return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\"");
@@ -179,6 +179,10 @@ void init_common(py::module m) {
     m.def("set_prealloc_config", &CompNode::set_prealloc_config,
           "specifies how to pre-allocate from raw dev allocator");

+    m.def("what_is_xpu", []{
+        return CompNode::Locator::parse("xpux").to_physical().type;
+    });
+
     init_npy_num_bfloat16(m);
     init_npy_num_intbx(m);
     init_dtypes(m);
@@ -16,6 +16,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
+from megengine.device import get_default_device
 from megengine.distributed.helper import (
     get_device_count_by_fork,
     param_pack_concat,
@@ -87,7 +88,8 @@ def test_new_group():
         assert group.size == 2
         assert group.key == "2,0"
         assert group.rank == ranks.index(rank)
-        assert group.comp_node == "gpu{}:2".format(rank)
+        dt = get_default_device()[:-1]
+        assert group.comp_node == "{}{}:2".format(dt, rank)

     worker()
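The trailing slice strips the device index from the default device string so only the type prefix is compared; a small illustration (the device name is an example, and the slice assumes a single-digit index, which holds in this test):

default_device = "gpu0"   # stand-in for get_default_device()
dt = default_device[:-1]  # -> "gpu"
rank = 1
assert "{}{}:2".format(dt, rank) == "gpu1:2"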
@@ -236,12 +236,12 @@ def test_io_remote(shape):
     def worker(val, shape):
         rank = dist.get_rank()
         if rank == 0:  # remote send
-            x = tensor(val, device="gpu0")
+            x = tensor(val, device="xpu0")
             remote_send(x, 1)
             sync()
         else:  # remote recv
             y = remote_recv(0, shape, np.float32)
-            assert y.device == "gpu1"
+            assert y.device == get_default_device()
             np.testing.assert_almost_equal(val, y.numpy())

     val = np.random.random_sample(shape).astype("float32")