GitOrigin-RevId: 1dd5a02a51
tags/v1.5.0
@@ -1018,6 +1018,7 @@ endif()
if(MGE_WITH_DISTRIBUTED)
  set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
  set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
  add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif()
@@ -6,6 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from mprop import mproperty

from . import group
from .group import (
    WORLD,
    Group,
@@ -19,7 +22,20 @@ from .group import (
    init_process_group,
    is_distributed,
    new_group,
    override_backend,
)
from .helper import bcast_list_, make_allreduce_cb, synchronized
from .launcher import launcher
from .server import Client, Server


@mproperty
def backend(mod):
    assert group._sd, "please call init_process_group first"
    return group._sd.backend


@backend.setter
def backend(mod, val):
    assert group._sd, "please call init_process_group first"
    group._sd.backend = val
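With the mproperty hooks above, `backend` behaves like a module-level attribute of `megengine.distributed`. A minimal usage sketch (hypothetical script; assumes `init_process_group` has already run in this process):

import megengine.distributed as dist

# Reading dist.backend goes through the mproperty getter and returns the
# value stored in the global distributed state (e.g. "auto" by default).
current = dist.backend

# Assigning goes through the setter and rewrites that state, so later
# collectives resolve against the new value.
dist.backend = "nccl"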
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend
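For orientation, a small self-contained mirror of the resolution rule implemented by `_backend()` above (the helper name `resolve_backend` is hypothetical; the mapping is the one from `_device2backend`):

def resolve_backend(configured: str, physical_device: str) -> str:
    # "auto" defers to the physical device; anything else is taken as-is.
    device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}
    if configured == "auto":
        return device2backend[physical_device]
    return configured

assert resolve_backend("auto", "cuda") == "nccl"
assert resolve_backend("auto", "rocm") == "rccl"
assert resolve_backend("ucx", "cuda") == "ucx"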
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
    if group is None:
        return inp
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
        g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=""):
    if device == "":
def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):

class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
@@ -139,7 +156,7 @@ class _ReduceSum(Function):

def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@ def reduce_sum(

class _Broadcast(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
@@ -175,7 +192,7 @@ class _Broadcast(Function):

def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@ def broadcast(

def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(

def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(

def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(

def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(

def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(

class _Gather(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device
@@ -291,7 +308,7 @@ class _Gather(Function):

def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(

class _Scatter(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device
@@ -328,7 +345,7 @@ class _Scatter(Function):

def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(

def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_to_all operator for collective communication.
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
        remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_send(inp: Tensor, dest_rank: int):
    """
    Send a Tensor to a remote process.
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a Tensor from a remote process.
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = get_backend()
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
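Taken together, these changes let the collectives and point-to-point helpers run without an explicit device or backend. A hedged two-process sketch (assumes two visible GPUs; `worker` is a hypothetical function run under the launcher described further below):

import megengine.distributed as dist
import megengine.distributed.functional as distF
from megengine import Tensor

@dist.launcher(n_gpus=2)
def worker():
    x = Tensor([float(dist.get_rank() + 1)])
    # device now defaults to None and falls back to the default device.
    total = distF.all_reduce_sum(x)             # Tensor([3.0]) on both ranks
    if dist.get_rank() == 0:
        distF.remote_send(total, dest_rank=1)   # returns None after this change
    else:
        received = distF.remote_recv(src_rank=0)

worker()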
@@ -7,8 +7,11 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
    backend = None
    next_stream = None
    device_type = None
    machine_ranks = None


_sd = None
@@ -55,6 +59,7 @@ class Group:
        self.proc_ranks = proc_ranks
        self.stream = _sd.next_stream
        _sd.next_stream += 1
        self.is_single_machine_cache = None

    def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
        assert len(self.proc_ranks) > 0, "invalid group"
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)


WORLD = Group([])

    @property
    def is_single_machine(self):
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        assert _sd is not None, "please call init_process_group first"
        for rank in self.proc_ranks:
            if rank not in _sd.machine_ranks:
                self.is_single_machine_cache = False
                return False
        self.is_single_machine_cache = True
        return True
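A small illustration of the membership test behind `is_single_machine` (the rank lists here are made up; `machine_ranks` is filled in by the launcher changes later in this diff):

machine_ranks = [0, 1, 2, 3]        # ranks spawned on this host
proc_ranks = [1, 3]                 # ranks of some sub-group
single_machine = all(r in machine_ranks for r in proc_ranks)
assert single_machine               # every group rank lives on this host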
_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}

WORLD = Group([])

_backends = {"nccl", "rccl", "ucx"}
_devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "ucx", "auto"}


def init_process_group(
@@ -102,7 +113,7 @@ def init_process_group(
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = None,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    """
@@ -113,10 +124,9 @@ def init_process_group(
    :param world_size: total number of processes participating in the job.
    :param rank: rank of the current process.
    :param device: the GPU device id to bind this process to.
    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
    :param backend: communicator backend, currently support 'nccl' and 'shm'.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    backend = _device2backend[physical_device_type] if backend is None else backend
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
@@ -131,7 +141,7 @@ def init_process_group(
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _device2backend:
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )
@@ -161,6 +171,30 @@ def init_process_group(
    seed(int(time.time()) + rank)


def _set_machine_ranks(ranks) -> None:
    global _sd
    assert _sd is not None

    _sd.machine_ranks = ranks


@contextmanager
def override_backend(new_backend: str):
    """
    Override distributed backend

    :param new_backend: communicator backend set in this context.
    """
    global _sd
    assert _sd, "please call init_process_group first"
    old_backend = _sd.backend
    _sd.backend = new_backend
    try:
        yield
    finally:
        _sd.backend = old_backend


def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized."""
    return _sd is not None
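A brief usage sketch of the `override_backend` context manager added above (assumes it runs inside an initialized worker process):

import megengine.distributed as dist
from megengine.distributed.functional import all_reduce_sum
from megengine import Tensor

with dist.override_backend("nccl"):
    # collectives issued in this block resolve to NCCL regardless of the
    # globally configured backend
    y = all_reduce_sum(Tensor([1.0]))
# on exit the previous backend (e.g. "auto") is restored, even if the
# block raised, thanks to the try/finally above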
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed
from .group import WORLD, Group, group_barrier, is_distributed, override_backend


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
    return offsets


_enable_p2p_cache = None


def _check_enable_p2p():
    global _enable_p2p_cache
    if _enable_p2p_cache is not None:
        return _enable_p2p_cache
    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
    import subprocess

    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    if output.count(b"OK") > 1:
        _enable_p2p_cache = True
        return True
    else:
        _enable_p2p_cache = False
        return False


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)

    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size
@@ -207,9 +228,10 @@ class AllreduceCallback:
    :param reduce_method: the method to reduce gradients.
    :param group: communication group.
    :param backend: override distributed backend in allreduce
    """

    def __init__(self, reduce_method: str, group: Group = WORLD):
    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

        if backend is None:
            assert _group._sd, "please call init_process_group first"
            backend = _group._sd.backend
        if backend == "auto":
            if group.is_single_machine and not _check_enable_p2p():
                backend = "shm"
            else:
                backend = "nccl"
        self._backend = backend

    def _reset(self):
        self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        with override_backend(self._backend):
            reduced_grads = pack_allreduce_split(
                grad_list, shapes, self._group, self._reduce_method
            )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
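A hedged sketch of wiring the callback into data-parallel training with its new `backend` argument (assumes it runs inside a launcher-spawned worker process; names follow the imports shown in this diff):

import megengine.distributed as dist
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.distributed.helper import AllreduceCallback

model = M.Linear(4, 4)
gm = GradManager()

# backend="auto" picks "shm" for a single-machine group without GPU
# peer-to-peer access and "nccl" otherwise, per the logic added above;
# backend=None falls back to the globally configured backend.
cb = AllreduceCallback("mean", dist.WORLD, backend="auto")
gm.attach(model.parameters(), callbacks=[cb])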
@@ -14,7 +14,7 @@ import queue
from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server
@@ -34,7 +34,9 @@ def _run_wrapped(
    device_type,
    args,
    kwargs,
    backend,
    queue: mp.Queue,
    machine_ranks: list,
):
    """Init distributed process group and run wrapped function."""
    _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
        world_size=world_size,
        rank=rank,
        device=dev,
        backend=backend,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    _set_machine_ranks(machine_ranks)
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
    :param rank_start: start number for rank.
    :param master_ip: ip address for master node (where the rank 0 is).
    :param port: server port for distributed server.
    :param backend: set default collective communication backend.
    """

    def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@ class launcher:
        master_ip="localhost",
        port=0,
        device_type="xpu",
        backend="auto",
    ):
        self.func = func
        self.n_gpus = (
@@ -93,6 +99,7 @@ class launcher:
        self.master_ip = master_ip
        self.port = port
        self.device_type = device_type
        self.backend = backend
        # master node create server
        if self.rank_start == 0:
            self.server = Server(self.port)
@@ -104,6 +111,7 @@ class launcher:
        procs = []
        queue = mp.Queue(self.n_gpus)
        results = [None] * self.n_gpus
        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
        for dev in range(self.n_gpus):
            p = mp.Process(
                target=_run_wrapped,
@@ -118,7 +126,9 @@ class launcher:
                    self.device_type,
                    args,
                    kwargs,
                    self.backend,
                    queue,
                    machine_ranks,
                ),
            )
            p.start()
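A short sketch of the decorator-level knob this adds: the launcher forwards `backend` to every worker's `init_process_group` call and registers the ranks spawned on this machine (two GPUs assumed; hypothetical entry point):

import megengine.distributed as dist

@dist.launcher(n_gpus=2, backend="auto")
def main():
    # each spawned worker has its process group initialized with
    # backend="auto" and machine_ranks already registered
    print("rank", dist.get_rank(), "backend", dist.backend)

main()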
@@ -11,6 +11,7 @@
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
#include "megray/common.h"

using namespace mgb;
using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
        return MegRay::MEGRAY_RCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    } else if (backend == "shm") {
        return MegRay::MEGRAY_SHM;
    } else {
        mgb_throw(MegBrainError, "bad CollectiveComm backend");
    }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
        if (rank == root) {
            char* c = MegRay::get_host_ip();
            master_ip = std::string(c);
            delete c;
            delete [] c;
            port = MegRay::get_free_port();
            auto ret = MegRay::create_server(size, port);
            mgb_assert(ret == MegRay::Status::MEGRAY_OK);