GitOrigin-RevId: 1dd5a02a51
tags/v1.5.0
@@ -1018,6 +1018,7 @@ endif()
 if(MGE_WITH_DISTRIBUTED)
   set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
   set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
@@ -6,6 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from mprop import mproperty
+
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +22,20 @@ from .group import (
     init_process_group,
     is_distributed,
     new_group,
+    override_backend,
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
+
+
+@mproperty
+def backend(mod):
+    assert group._sd, "please call init_process_group first"
+    return group._sd.backend
+
+
+@backend.setter
+def backend(mod, val):
+    assert group._sd, "please call init_process_group first"
+    group._sd.backend = val
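For orientation, a minimal usage sketch (illustrative, not part of the patch) of the module-level backend property that mprop exposes; it assumes a 2-GPU host and code running inside a launcher-spawned worker, where init_process_group has already been called:

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        # the mproperty-backed attribute proxies _sd.backend
        print(dist.backend)     # "auto" unless set otherwise
        dist.backend = "nccl"   # later collectives resolve to NCCL

    worker()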
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import Function, _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
-from ..device import get_default_device
+from ..device import get_default_device, what_is_xpu
 from ..tensor import Tensor
-from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
+from . import group
+from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

 __all__ = [
     "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
 ]

+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+
+def _backend():
+    if group._sd.backend == "auto":
+        return _device2backend[what_is_xpu()]
+    else:
+        return group._sd.backend
+
+
 def collective_comm(inp, mode, group, device):
     """Helper function for applying collective communication functions."""
     assert isinstance(group, Group)
     if group is None:
         return inp
+    if device is None:
+        device = ""
     addr, port = get_mm_server_addr()
     op = CollectiveComm(
-        key=group.key,
+        key=group.key + _backend(),
         nr_devices=group.size,
         rank=group.rank,
         is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
         port=port,
         mode=mode,
         dtype=inp.dtype,
-        backend=get_backend(),
+        backend=_backend(),
         comp_node=device,
     )
     (result,) = apply(op, inp)
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
         g._refkeeper.append(inp)


-def _dummy_input(shape, dtype, device=""):
-    if device == "":
+def _dummy_input(shape, dtype, device=None):
+    if device is None:
         device = get_default_device()
     inp = Tensor(0, dtype=dtype, device=device)
     if len(shape) > 0:
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):


 class _ReduceSum(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
         )

     def backward(self, grad):
@@ -139,7 +156,7 @@ class _ReduceSum(Function):
 def reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@ def reduce_sum(
 class _Broadcast(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
         )

     def backward(self, grad):
@@ -175,7 +192,7 @@ class _Broadcast(Function):
 def broadcast(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@ def broadcast(
 def _bcast_param(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
 ) -> Tensor:
     mode = CollectiveComm.Mode.BROADCAST
     return collective_comm(inp, mode, group, device)


 def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(
 def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(
 def all_reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(
 def all_reduce_max(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(
 def all_reduce_min(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(
 class _Gather(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
@@ -291,7 +308,7 @@ class _Gather(Function):
 def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(
 class _Scatter(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device
@@ -328,7 +345,7 @@ class _Scatter(Function):
 def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(
 def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_to_all operator for collective communication.
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
         remote_send(grad, self.op.rank_from)


-def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
+def remote_send(inp: Tensor, dest_rank: int):
     """
     Send a Tensor to a remote process.
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (out,) = apply(_RemoteSend(op), inp)

     _save_output_for_autodiff(inp, out)


-def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
+def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (ret,) = apply(_RemoteRecv(op), inp)
     if _isscalar:
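A small point-to-point sketch using the updated signatures above (remote_send no longer advertises a return value); the 2-GPU launcher and rank layout are assumed for illustration:

    import megengine.distributed as dist
    from megengine import Tensor
    from megengine.distributed.functional import remote_recv, remote_send

    @dist.launcher(n_gpus=2)
    def worker():
        if dist.get_rank() == 0:
            x = Tensor([1.0, 2.0, 3.0])
            remote_send(x, 1)       # sends to rank 1, returns nothing
        else:
            y = remote_recv(0)      # receives from rank 0 on the default device
            print(y.numpy())

    worker()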
@@ -7,8 +7,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

+from mprop import mproperty
+
 from ..device import set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
     backend = None
     next_stream = None
     device_type = None
+    machine_ranks = None


 _sd = None
@@ -55,6 +59,7 @@ class Group:
         self.proc_ranks = proc_ranks
         self.stream = _sd.next_stream
         _sd.next_stream += 1
+        self.is_single_machine_cache = None

     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
         assert len(self.proc_ranks) > 0, "invalid group"
         return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

-WORLD = Group([])
+    @property
+    def is_single_machine(self):
+        if self.is_single_machine_cache is not None:
+            return self.is_single_machine_cache
+        assert _sd is not None, "please call init_process_group first"
+        for rank in self.proc_ranks:
+            if rank not in _sd.machine_ranks:
+                self.is_single_machine_cache = False
+                return False
+        self.is_single_machine_cache = True
+        return True

-_device2backend = {
-    "gpu": "nccl",
-    "cuda": "nccl",
-    "rocm": "rccl",
-}
-
-_backends = {"nccl", "rccl", "ucx"}
+WORLD = Group([])
+
+_devices = {"gpu", "cuda", "rocm"}
+_backends = {"nccl", "rccl", "ucx", "auto"}


 def init_process_group(
@@ -102,7 +113,7 @@ def init_process_group(
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = None,
+    backend: Optional[str] = "auto",
     device_type: str = "xpu",
 ) -> None:
     """
@@ -113,10 +124,9 @@ def init_process_group(
     :param world_size: total number of processes participating in the job.
     :param rank: rank of the current process.
     :param device: the GPU device id to bind this process to.
-    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
+    :param backend: communicator backend, currently support 'nccl' and 'shm'.
     """
     physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
-    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
@@ -131,7 +141,7 @@ def init_process_group(
         raise ValueError(
             "backend should be one of {} but got {}".format(_backends, backend)
         )
-    if physical_device_type not in _device2backend:
+    if physical_device_type not in _devices:
         raise ValueError(
             "{} is not a valid distributed device type".format(device_type)
         )
@@ -161,6 +171,30 @@ def init_process_group(
     seed(int(time.time()) + rank)


+def _set_machine_ranks(ranks) -> None:
+    global _sd
+    assert _sd is not None
+
+    _sd.machine_ranks = ranks
+
+
+@contextmanager
+def override_backend(new_backend: str):
+    """
+    Override distributed backend
+
+    :param new_backend: communicator backend set in this context.
+    """
+    global _sd
+    assert _sd, "please call init_process_group first"
+    old_backend = _sd.backend
+    _sd.backend = new_backend
+    try:
+        yield
+    finally:
+        _sd.backend = old_backend
+
+
 def is_distributed() -> bool:
     """Return True if the distributed process group has been initialized."""
     return _sd is not None
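A minimal sketch of the new override_backend context manager in use (assumes an initialized process group, e.g. inside a launcher-spawned worker on a 2-GPU host):

    import megengine.distributed as dist
    from megengine import Tensor
    from megengine.distributed.functional import all_reduce_sum

    @dist.launcher(n_gpus=2)
    def worker():
        x = Tensor([dist.get_rank() + 1.0])
        # temporarily force a specific MegRay backend for the collectives inside
        with dist.override_backend("nccl"):
            y = all_reduce_sum(x)
        # _sd.backend is restored here (back to "auto" by default)
        print(y.numpy())

    worker()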
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
 from ..functional.tensor import copy
 from ..tensor import Tensor
 from ..utils.future import Future
+from . import group as _group
 from .functional import _bcast_param, all_reduce_sum, broadcast
-from .group import WORLD, Group, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed, override_backend


 def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
     return offsets


+_enable_p2p_cache = None
+
+
+def _check_enable_p2p():
+    global _enable_p2p_cache
+    if _enable_p2p_cache is not None:
+        return _enable_p2p_cache
+    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
+    import subprocess
+
+    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
+    if output.count(b"OK") > 1:
+        _enable_p2p_cache = True
+        return True
+    else:
+        _enable_p2p_cache = False
+        return False
+
+
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
         packed_grads /= group.size
@@ -207,9 +228,10 @@ class AllreduceCallback:
     :param reduce_method: the method to reduce gradiants.
     :param group: communication group.
+    :param backend: override distributed backend in allreduce
     """

-    def __init__(self, reduce_method: str, group: Group = WORLD):
+    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
         reduce_method = reduce_method.lower()
         assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
         self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
+        if backend is None:
+            assert _group._sd, "please call init_process_group first"
+            backend = _group._sd.backend
+        if backend == "auto":
+            if group.is_single_machine and not _check_enable_p2p():
+                backend = "shm"
+            else:
+                backend = "nccl"
+        self._backend = backend

     def _reset(self):
         self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
             return
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
         shapes = [p._tuple_shape for p in self._packing_list[dtype]]
-        reduced_grads = pack_allreduce_split(
-            grad_list, shapes, self._group, self._reduce_method
-        )
+        with override_backend(self._backend):
+            reduced_grads = pack_allreduce_split(
+                grad_list, shapes, self._group, self._reduce_method
+            )
         for param, grad in zip(self._packing_list[dtype], reduced_grads):
             self._gradients_dict[param] = grad
         self._packing_list[dtype] = []
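As a usage sketch of the new backend argument (assuming make_allreduce_cb aliases AllreduceCallback, as elsewhere in this module; the GradManager/model setup is illustrative and not part of the patch):

    import megengine.distributed as dist
    from megengine.autodiff import GradManager

    gm = GradManager()
    # pass an explicit backend, or omit it to use the "auto" resolution above
    # ("shm" on a single machine without P2P links, otherwise "nccl")
    cb = dist.make_allreduce_cb("mean", dist.WORLD, backend="nccl")
    gm.attach(model.parameters(), callbacks=[cb])  # `model` is assumed to exist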
@@ -14,7 +14,7 @@ import queue
 from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
-from .group import group_barrier, init_process_group
+from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, get_device_count_by_fork
 from .server import Client, Server
@@ -34,7 +34,9 @@ def _run_wrapped(
     device_type,
     args,
     kwargs,
+    backend,
     queue: mp.Queue,
+    machine_ranks: list,
 ):
     """Init distributed process group and run wrapped function."""
     _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
         world_size=world_size,
         rank=rank,
         device=dev,
+        backend=backend,
         device_type=device_type,
     )
     # set NCCL_LAUNCH_MODE to avoid deadlock
     os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    _set_machine_ranks(machine_ranks)
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
     :param rank_start: start number for rank.
     :param master_ip: ip address for master node (where the rank 0 is).
     :param port: server port for distributed server.
+    :param backend: set default collective communication backend.
     """

     def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@ class launcher:
         master_ip="localhost",
         port=0,
         device_type="xpu",
+        backend="auto",
     ):
         self.func = func
         self.n_gpus = (
@@ -93,6 +99,7 @@ class launcher:
         self.master_ip = master_ip
         self.port = port
         self.device_type = device_type
+        self.backend = backend
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -104,6 +111,7 @@ class launcher:
         procs = []
         queue = mp.Queue(self.n_gpus)
         results = [None] * self.n_gpus
+        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -118,7 +126,9 @@ class launcher:
                     self.device_type,
                     args,
                     kwargs,
+                    self.backend,
                     queue,
+                    machine_ranks,
                 ),
             )
             p.start()
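A short sketch of passing the new backend argument through the launcher (the worker body is illustrative):

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2, backend="auto")
    def main():
        # each spawned process runs init_process_group(..., backend="auto")
        # and has its machine_ranks registered via _set_machine_ranks
        print(dist.get_rank(), dist.backend)

    main()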
@@ -11,6 +11,7 @@

 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/comp_node_env.h"
+#include "megray/common.h"

 using namespace mgb;
 using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
         return MegRay::MEGRAY_RCCL;
     } else if (backend == "ucx") {
         return MegRay::MEGRAY_UCX;
+    } else if (backend == "shm") {
+        return MegRay::MEGRAY_SHM;
     } else {
         mgb_throw(MegBrainError, "bad CollectiveComm backend");
     }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
     if (rank == root) {
         char* c = MegRay::get_host_ip();
         master_ip = std::string(c);
-        delete c;
+        delete [] c;
         port = MegRay::get_free_port();
         auto ret = MegRay::create_server(size, port);
         mgb_assert(ret == MegRay::Status::MEGRAY_OK);