GitOrigin-RevId: cc120cfb55
tags/v0.6.0
| @@ -11,10 +11,13 @@ from .functional import ( | |||
| all_reduce_max, | |||
| all_reduce_min, | |||
| all_reduce_sum, | |||
| all_to_all, | |||
| bcast_param, | |||
| broadcast, | |||
| gather, | |||
| reduce_scatter_sum, | |||
| reduce_sum, | |||
| scatter, | |||
| ) | |||
| from .util import ( | |||
| get_backend, | |||
| @@ -9,7 +9,7 @@ | |||
| from typing import Optional, Union | |||
| import megengine._internal as mgb | |||
| from megengine._internal.opr_param_defs import CollectiveComm as CollParam | |||
| from megengine._internal.opr_param_defs import CollectiveComm as Param | |||
| from ..core import Buffer, Parameter, Tensor, wrap_io_tensor | |||
| from ..functional import add_update | |||
| @@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs): | |||
| return collective_comm_symvar(*args, **kargs) | |||
| def _group_check(*args): | |||
| """Return True when arguments are all None or all not None | |||
| """ | |||
| l = [val is None for val in args] | |||
| return len(set(l)) <= 1 | |||
| def reduce_sum( | |||
| tensor: Tensor, | |||
| key: str, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> Tensor: | |||
| @@ -35,14 +42,17 @@ def reduce_sum( | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, | |||
| tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, | |||
| ) | |||
| def gather( | |||
| tensor: Tensor, | |||
| key: str, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| @@ -55,20 +65,17 @@ def gather( | |||
| :param is_root: whether this is a root node | |||
| :param rank: rank of this node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root, rank | |||
| ), "key, nr_ranks, is_root, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, | |||
| key, | |||
| CollParam.Mode.GATHER, | |||
| nr_ranks, | |||
| is_root, | |||
| rank, | |||
| device=tensor.device, | |||
| tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device, | |||
| ) | |||
| def broadcast( | |||
| tensor: Tensor, | |||
| key: str, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> Tensor: | |||
| @@ -79,11 +86,12 @@ def broadcast( | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this is a root node | |||
| """ | |||
| if key is None: | |||
| key = tensor._symvar.name | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| if is_root is None: | |||
| is_root = get_rank() == 0 | |||
| if is_root: | |||
| inp = tensor | |||
| else: | |||
| @@ -92,7 +100,7 @@ def broadcast( | |||
| return _collective_comm( | |||
| inp, | |||
| key, | |||
| CollParam.Mode.BROADCAST, | |||
| Param.Mode.BROADCAST, | |||
| nr_ranks, | |||
| is_root, | |||
| dtype=tensor.dtype, | |||
| @@ -102,7 +110,7 @@ def broadcast( | |||
| def scatter( | |||
| tensor: Tensor, | |||
| key: str, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| @@ -115,6 +123,9 @@ def scatter( | |||
| :param is_root: whether this is a root node | |||
| :param rank: rank of this node | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, is_root, rank | |||
| ), "key, nr_ranks, is_root, rank should be set at the same time" | |||
| if key is None: | |||
| key = tensor._symvar.name | |||
| if is_root is None: | |||
| @@ -128,7 +139,7 @@ def scatter( | |||
| return _collective_comm( | |||
| inp, | |||
| key, | |||
| CollParam.Mode.SCATTER, | |||
| Param.Mode.SCATTER, | |||
| nr_ranks, | |||
| is_root, | |||
| rank, | |||
| @@ -138,7 +149,11 @@ def scatter( | |||
| def all_to_all( | |||
| tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_to_all operator for collective communication | |||
| @@ -146,12 +161,22 @@ def all_to_all( | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank) | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad, | |||
| ) | |||
| def all_gather( | |||
| tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_gather operator for collective communication | |||
| @@ -159,12 +184,22 @@ def all_gather( | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank) | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad | |||
| ) | |||
| def reduce_scatter_sum( | |||
| tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create reduce_scatter_sum operator for collective communication | |||
| @@ -172,45 +207,81 @@ def reduce_scatter_sum( | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param rank: rank of this node | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| assert _group_check( | |||
| key, nr_ranks, rank | |||
| ), "key, nr_ranks, rank should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank, | |||
| tensor, | |||
| key, | |||
| Param.Mode.REDUCE_SCATTER_SUM, | |||
| nr_ranks, | |||
| rank=rank, | |||
| local_grad=local_grad, | |||
| ) | |||
| def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
| def all_reduce_sum( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_sum operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks) | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
| def all_reduce_max( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_max operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks) | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: | |||
| def all_reduce_min( | |||
| tensor: Tensor, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| ) -> Tensor: | |||
| """Create all_reduce_min operator for collective communication | |||
| :param tensor: input tensor | |||
| :param key: unique identifier for collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param local_grad: whether use local grad | |||
| """ | |||
| return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks) | |||
| assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" | |||
| return _collective_comm( | |||
| tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad | |||
| ) | |||
| def bcast_param( | |||
| inp: Union[Buffer, Parameter], | |||
| key: str, | |||
| key: Optional[str] = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| ) -> None: | |||
| @@ -223,6 +294,9 @@ def bcast_param( | |||
| """ | |||
| if not is_distributed(): | |||
| return | |||
| assert _group_check( | |||
| key, nr_ranks, is_root | |||
| ), "key, nr_ranks, is_root should be set at the same time" | |||
| assert isinstance(inp, (Buffer, Parameter)) | |||
| bcast_res = broadcast(inp, key, nr_ranks, is_root) | |||
| add_update(inp, bcast_res, alpha=0) | |||
| @@ -11,16 +11,24 @@ from typing import Optional, Union | |||
| import megengine._internal as mgb | |||
| from megengine._internal.opr_param_defs import CollectiveComm as CollParam | |||
| from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size | |||
| from .util import ( | |||
| get_backend, | |||
| get_group_id, | |||
| get_master_ip, | |||
| get_master_port, | |||
| get_rank, | |||
| get_world_size, | |||
| ) | |||
| def collective_comm_symvar( | |||
| inp: Union[mgb.SymbolVar, mgb.CompGraph], | |||
| key: str, | |||
| op: CollParam.Mode, | |||
| key: Optional[str] = None, | |||
| op: CollParam.Mode = None, | |||
| nr_ranks: Optional[int] = None, | |||
| is_root: Optional[bool] = None, | |||
| rank: Optional[int] = None, | |||
| local_grad: Optional[bool] = False, | |||
| dtype: Optional[type] = None, | |||
| device: Optional[mgb.CompNode] = None, | |||
| comp_graph: Optional[mgb.CompGraph] = None, | |||
| @@ -32,16 +40,19 @@ def collective_comm_symvar( | |||
| :param op: mode of collective communication | |||
| :param nr_ranks: number of ranks, use util.get_world_size() as default | |||
| :param is_root: whether this node is root node | |||
| :param rank: rank of this node | |||
| :param local_grad: whether use local grad | |||
| :param dtype: output data type, use dtype of inp as default | |||
| :param device: output comp node, use comp node of inp as default | |||
| :param comp_graph: output comp graph, use comp graph of inp as default | |||
| """ | |||
| return mgb.opr.collective_comm( | |||
| inp, | |||
| key=str(key), | |||
| key=key if key is not None else ("collective_comm_" + str(get_group_id())), | |||
| nr_devices=nr_ranks if nr_ranks is not None else get_world_size(), | |||
| is_root=is_root if is_root is not None else (get_rank() == 0), | |||
| rank=rank if rank is not None else -1, | |||
| rank=rank if rank is not None else get_rank(), | |||
| local_grad=local_grad, | |||
| server_addr=get_master_ip(), | |||
| port=get_master_port(), | |||
| param=CollParam(mode=op), | |||
| @@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta): | |||
| with opr_priority_scope(cg, -(2 ** 30)): | |||
| # always run all_reduce_mean first except add_update | |||
| grad = ( | |||
| all_reduce_sum(grad, "grad_" + str(get_group_id())) | |||
| all_reduce_sum( | |||
| grad, "grad_" + str(get_group_id()), get_world_size() | |||
| ) | |||
| / get_world_size() | |||
| ) | |||
| with opr_priority_scope(cg, -(2 ** 31)): | |||
| @@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| for group in self.param_groups: | |||
| for param in group["params"]: | |||
| bcast_param( | |||
| param, "bcast_param_" + str(key), is_root=(get_rank() == 0), | |||
| param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0, | |||
| ) | |||
| key += 1 | |||
| @@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
| SymbolVar _Opr::collective_comm_with_input( | |||
| SymbolVar inpvar, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const std::string& server_addr, | |||
| const int port, PyObject* params, PyObject* dtype, | |||
| const std::string& backend, SharedND* output_buf, | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
| SymbolVarArray inputs(1, inpvar); | |||
| ComputingGraph* graph = inpvar.node()->owner_graph(); | |||
| @@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input( | |||
| _dtype = npy::dtype_np2mgb(dtype); | |||
| } | |||
| return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank, | |||
| group_mgr, dev_buffer_arr, param, _dtype, | |||
| backend, config, disable.get_val())[0]; | |||
| local_grad, group_mgr, dev_buffer_arr, param, | |||
| _dtype, backend, config, disable.get_val())[0]; | |||
| } | |||
| SymbolVar _Opr::collective_comm_without_input( | |||
| CompGraph& cg, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const std::string& server_addr, | |||
| const int port, PyObject* params, PyObject* dtype, | |||
| const std::string& backend, SharedND* output_buf, | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
| SymbolVarArray inputs; | |||
| auto& graph = cg.get(); | |||
| @@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input( | |||
| _dtype = npy::dtype_np2mgb(dtype); | |||
| } | |||
| return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank, | |||
| group_mgr, dev_buffer_arr, param, _dtype, | |||
| backend, config, disable.get_val())[0]; | |||
| local_grad, group_mgr, dev_buffer_arr, param, | |||
| _dtype, backend, config, disable.get_val())[0]; | |||
| } | |||
| #else | |||
| @@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, | |||
| } | |||
| SymbolVar _Opr::collective_comm_with_input( | |||
| SymbolVar inpvar, const std::string& key, | |||
| const size_t nr_devices, const bool is_root, const int rank, | |||
| SymbolVar inpvar, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
| @@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input( | |||
| } | |||
| SymbolVar _Opr::collective_comm_without_input( | |||
| CompGraph& cg, const std::string& key, | |||
| const size_t nr_devices, const bool is_root, const int rank, | |||
| CompGraph& cg, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable) { | |||
| @@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port, | |||
| static SymbolVar collective_comm_with_input( | |||
| SymbolVar inpvar, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const std::string& server_addr, const int port, | |||
| PyObject* params, PyObject* dtype, const std::string& backend, | |||
| SharedND* output_buf, const OperatorNodeConfig& config, | |||
| const SharedScalar& disable); | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable); | |||
| static SymbolVar collective_comm_without_input( | |||
| CompGraph& graph, const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, const std::string& server_addr, const int port, | |||
| PyObject* params, PyObject* dtype, const std::string& backend, | |||
| SharedND* output_buf, const OperatorNodeConfig& config, | |||
| const SharedScalar& disable); | |||
| const bool is_root, const int rank, const bool local_grad, | |||
| const std::string& server_addr, const int port, PyObject* params, | |||
| PyObject* dtype, const std::string& backend, SharedND* output_buf, | |||
| const OperatorNodeConfig& config, const SharedScalar& disable); | |||
| // misc | |||
| static SymbolVarArray extern_c_opr_placeholder( | |||
| @@ -34,7 +34,7 @@ def test_reduce_sum(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.reduce_sum(inp, "x") | |||
| output = dist.functional.reduce_sum(inp) | |||
| if rank == 0: | |||
| assert np.allclose(output.numpy(), expect) | |||
| else: | |||
| @@ -70,7 +70,7 @@ def test_gather(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank) | |||
| output = dist.functional.gather(inp) | |||
| if rank == 0: | |||
| assert np.allclose(output.numpy(), expect) | |||
| else: | |||
| @@ -106,7 +106,7 @@ def test_broadcast(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.broadcast(inp, "x") | |||
| output = dist.functional.broadcast(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -138,7 +138,7 @@ def test_scatter(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank) | |||
| output = dist.functional.scatter(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -174,7 +174,7 @@ def test_all_to_all(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.all_to_all(inp, "x", rank=rank) | |||
| output = dist.functional.all_to_all(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -208,7 +208,7 @@ def test_all_gather(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.all_gather(inp, "x", rank=rank) | |||
| output = dist.functional.all_gather(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -241,7 +241,7 @@ def test_reduce_scatter_sum(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank) | |||
| output = dist.functional.reduce_scatter_sum(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -278,7 +278,7 @@ def test_all_reduce_sum(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.all_reduce_sum(inp, "x") | |||
| output = dist.functional.all_reduce_sum(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -311,7 +311,7 @@ def test_all_reduce_max(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.all_reduce_max(inp, "x") | |||
| output = dist.functional.all_reduce_max(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -344,7 +344,7 @@ def test_all_reduce_min(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = tensor(data) | |||
| output = dist.functional.all_reduce_min(inp, "x") | |||
| output = dist.functional.all_reduce_min(inp) | |||
| assert np.allclose(output.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -377,7 +377,7 @@ def test_bcast_param(): | |||
| return | |||
| _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) | |||
| inp = Parameter(data) | |||
| dist.functional.bcast_param(inp, "x") | |||
| dist.functional.bcast_param(inp) | |||
| assert np.allclose(inp.numpy(), expect) | |||
| def check(shape, backend): | |||
| @@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { | |||
| if (!opr->same_type<opr::CollectiveComm>()) return false; | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; | |||
| if (comm.local_grad()) return false; | |||
| if (comm.input().size() != 1) return false; | |||
| auto grad = comm.input(0)->owner_opr(); | |||
| @@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs( | |||
| std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
| auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, | |||
| key, info->nr_devices, info->is_root, info->rank, | |||
| key, info->nr_devices, info->is_root, info->rank, false, | |||
| info->group_client, param, info->dtype, info->backend)[0]; | |||
| // split according to recorded partition | |||
| @@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) { | |||
| auto grad3 = opr::VirtualGrad::make(y1, x1); | |||
| auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; | |||
| auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), | |||
| "grad0", 2, 0, 0, client, mode)[0]; | |||
| auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), | |||
| "grad1", 2, 0, 0, client, mode)[0]; | |||
| auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), | |||
| "grad2", 2, 0, 0, client, mode)[0]; | |||
| auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), | |||
| "grad3", 2, 0, 0, client, mode)[0]; | |||
| auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2, | |||
| false, 0, false, client, mode)[0]; | |||
| auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2, | |||
| false, 0, false, client, mode)[0]; | |||
| auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2, | |||
| false, 0, false, client, mode)[0]; | |||
| auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2, | |||
| false, 0, false, client, mode)[0]; | |||
| gopt::GraphOptimizer() | |||
| .add_pass<gopt::PackAllReduceScanPass>() | |||
| @@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) { | |||
| auto grad = opr::VirtualGrad::make(target, wrt); | |||
| auto comm = opr::CollectiveComm::make( | |||
| {grad}, graph.get(), "key", 2, 0, 0, client, | |||
| opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] | |||
| .node()->owner_opr(); | |||
| auto comm = | |||
| opr::CollectiveComm::make( | |||
| {grad}, graph.get(), "key", 2, false, 0, false, client, | |||
| opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash); | |||
| @@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) { | |||
| auto insert_opr = [&] (size_t size) { | |||
| auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)}); | |||
| auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
| auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
| "key", 2, 0, 0, client, mode)[0]; | |||
| auto symvar = opr::CollectiveComm::make( | |||
| {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0]; | |||
| auto opr = symvar.node()->owner_opr(); | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| comm.set_pack_hash(1); | |||
| @@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||
| size_t nr_devices = 2; | |||
| uint32_t rank = 0; | |||
| uint32_t root = 0; | |||
| using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; | |||
| ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info; | |||
| @@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||
| auto insert_opr = [&] (const TensorShape& shape) { | |||
| auto dev = std::make_shared<DeviceTensorND>(cn, shape); | |||
| auto sd = opr::SharedDeviceTensor::make(*graph, dev); | |||
| auto symvar = opr::CollectiveComm::make({sd}, graph.get(), | |||
| "key", nr_devices, rank, root, client, mode)[0]; | |||
| auto symvar = | |||
| opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices, | |||
| false, rank, false, client, mode)[0]; | |||
| auto opr = symvar.node()->owner_opr(); | |||
| auto& comm = opr->cast_final_safe<opr::CollectiveComm>(); | |||
| comm.set_pack_hash(1); | |||
| @@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { | |||
| auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); | |||
| std::string key = ssprintf("grad_pack_%zu", pack_id); | |||
| auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), | |||
| key, nr_devices, rank, root, client, mode)[0]; | |||
| auto allreduce = | |||
| opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices, | |||
| false, rank, false, client, mode)[0]; | |||
| std::vector<size_t> partition; | |||
| partition.push_back(shape_x.total_nr_elems()); | |||
| @@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) { | |||
| using Mode = opr::CollectiveComm::Param::Mode; | |||
| bool is_root = (rank == 0); | |||
| auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), | |||
| "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
| auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), | |||
| "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; | |||
| auto reduced_x = opr::CollectiveComm::make( | |||
| {grad_x}, graph.get(), "x", 2, is_root, rank, | |||
| false, client, Mode::ALL_REDUCE_SUM)[0] / | |||
| 2; | |||
| auto reduced_y = opr::CollectiveComm::make( | |||
| {grad_y}, graph.get(), "y", 2, is_root, rank, | |||
| false, client, Mode::ALL_REDUCE_SUM)[0] / | |||
| 2; | |||
| graph->options().allreduce_pack_max_size = 5000; | |||
| graph->options().allreduce_pack_ignore_first = 0; | |||
| @@ -14,6 +14,8 @@ | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/graph/event.h" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/megray_helper.h" | |||
| #include "megbrain/opr/group_manager.h" | |||
| @@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) { | |||
| } | |||
| } // anonymous namespace | |||
| /* ================= ModeTrait ================= */ | |||
| class CollectiveComm::ModeTrait { | |||
| class BROADCAST; | |||
| class REDUCE_SUM; | |||
| @@ -132,6 +136,42 @@ public: | |||
| return None; | |||
| } | |||
| VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { | |||
| auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); | |||
| SymbolVarArray og_syms; | |||
| og_syms.push_back(out_grad); | |||
| auto&& cn = opr->output(0)->comp_node(); | |||
| auto gvar = CollectiveComm::make( | |||
| og_syms, opr->owner_graph(), opr->key() + ":grad", | |||
| opr->nr_devices(), opr->is_root(), opr->rank(), false, | |||
| opr->group_client(), mode, opr->dtype(), opr->backend(), {cn}); | |||
| return gvar[0].node(); | |||
| } | |||
| virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const { | |||
| mgb_throw(MegBrainError, | |||
| "only all_reduce all_to_all all_gather reduce_scatter " | |||
| "support local_grad"); | |||
| } | |||
| virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const { | |||
| if (opr->local_grad()){ | |||
| return local_grad(out_grad, opr); | |||
| } else { | |||
| return full_grad(out_grad, opr); | |||
| } | |||
| } | |||
| VarNode* zeros(mgb::cg::ComputingGraph &graph, CompNode node, const SymbolVar& shape, | |||
| DType dtype) const { | |||
| auto zero = SymbolVar::make_scalar(0, graph, node); | |||
| auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape); | |||
| return zero_tensor.node(); | |||
| } | |||
| virtual void get_output_var_shape(const CollectiveComm* opr, | |||
| const TensorShapeArray& ishp, | |||
| TensorShapeArray& oshp) = 0; | |||
| @@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { | |||
| } | |||
| Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; } | |||
| VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { | |||
| auto nr_devices = opr->nr_devices(); | |||
| auto rank = opr->rank(); | |||
| opr::Subtensor::IndexDesc axis; | |||
| auto shape0 = opr::GetVarShape::make(out_grad, 0); | |||
| axis.push_back({0, shape0 * rank / (int)nr_devices, | |||
| shape0 * (rank + 1) / (int)nr_devices}); | |||
| auto grad = opr::Subtensor::make(out_grad, axis); | |||
| return grad.node(); | |||
| } | |||
| }; | |||
| class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { | |||
| @@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { | |||
| } | |||
| Mode grad_mode() override { return Mode::ALL_GATHER; } | |||
| }; | |||
| /* ================= ModeTrait impls ================= */ | |||
| VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { | |||
| VarNodeArray grads; | |||
| auto zeros_tensor = | |||
| zeros(*out_grad->owner_graph(), out_grad->comp_node(), | |||
| opr::GetVarShape::make(out_grad), out_grad->dtype()); | |||
| for (size_t i = 0;i < opr->nr_devices();i++) { | |||
| if (i == opr->rank()) { | |||
| grads.push_back(out_grad); | |||
| } else { | |||
| grads.push_back(zeros_tensor); | |||
| } | |||
| } | |||
| auto grad = opr::Concat::make(grads, 0); | |||
| return grad.node(); | |||
| } | |||
| }; | |||
| class CollectiveComm::ModeTrait::ReducedBasedTrait { | |||
| protected: | |||
| @@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait, | |||
| } | |||
| Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; } | |||
| public: | |||
| VarNode* local_grad(VarNode* out_grad, | |||
| const CollectiveComm* opr) const override { | |||
| return out_grad; | |||
| } | |||
| }; | |||
| class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase { | |||
| @@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase { | |||
| class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase { | |||
| MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; } | |||
| VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { | |||
| VarNode* grad; | |||
| if (opr->local_grad()) { | |||
| grad = local_grad(out_grad, opr); | |||
| } else { | |||
| grad = full_grad(out_grad, opr); | |||
| } | |||
| grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad}, | |||
| Elemwise::Mode::COND_LEQ_MOV) | |||
| .node(); | |||
| return grad; | |||
| } | |||
| }; | |||
| class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase { | |||
| MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; } | |||
| VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { | |||
| VarNode* grad; | |||
| if (opr->local_grad()) { | |||
| grad = local_grad(out_grad, opr); | |||
| } else { | |||
| grad = full_grad(out_grad, opr); | |||
| } | |||
| grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad}, | |||
| Elemwise::Mode::COND_LEQ_MOV) | |||
| .node(); | |||
| return grad; | |||
| } | |||
| }; | |||
| class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, | |||
| @@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait { | |||
| } | |||
| Mode grad_mode() override { return Mode::ALL_TO_ALL; } | |||
| VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { | |||
| VarNodeArray grads; | |||
| auto grad_shape = opr::GetVarShape::make(out_grad); | |||
| auto zeros_tensor = | |||
| zeros(*out_grad->owner_graph(), out_grad->comp_node(), | |||
| grad_shape, out_grad->dtype()); | |||
| auto nr_devices = opr->nr_devices(); | |||
| auto rank = opr->rank(); | |||
| opr::Subtensor::IndexDesc axis; | |||
| auto shape0 = opr::GetVarShape::make(out_grad, 0); | |||
| axis.push_back({0, shape0 * rank / (int)nr_devices, | |||
| shape0 * (rank + 1) / (int)nr_devices}); | |||
| auto sub_grad = opr::Subtensor::make(out_grad, axis); | |||
| return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node(); | |||
| } | |||
| }; | |||
| CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { | |||
| @@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { | |||
| CollectiveComm::CollectiveComm( | |||
| VarNodeArray inputs, ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, const bool is_root, | |||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||
| const Param& param, const DType& dtype, const std::string& backend, | |||
| const int rank, const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, const Param& param, | |||
| const DType& dtype, const std::string& backend, | |||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
| const OperatorNodeConfig& config, | |||
| const std::shared_ptr<DTypeScalar>& disable) | |||
| @@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm( | |||
| m_nr_devices(nr_devices), | |||
| m_is_root(is_root), | |||
| m_rank(rank), | |||
| m_local_grad(local_grad), | |||
| m_key(key), | |||
| m_dev_buffers(dev_buffer_arr), | |||
| m_disable{disable} { | |||
| @@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm( | |||
| SymbolVarArray CollectiveComm::make( | |||
| const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, const bool is_root, | |||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||
| const Param& param, const DType& dtype, const std::string& backend, | |||
| const int rank, const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, const Param& param, | |||
| const DType& dtype, const std::string& backend, | |||
| const OperatorNodeConfig& config, | |||
| const std::shared_ptr<DTypeScalar>& disable) { | |||
| SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices, | |||
| nullptr); | |||
| return make(inputs, graph, key, nr_devices, is_root, rank, group_client, | |||
| dev_buffer_arr, param, dtype, backend, config); | |||
| return make(inputs, graph, key, nr_devices, is_root, rank, local_grad, | |||
| group_client, dev_buffer_arr, param, dtype, backend, config); | |||
| } | |||
| SymbolVarArray CollectiveComm::make( | |||
| const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, const bool is_root, | |||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||
| const int rank, const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
| const Param& param, const DType& dtype, const std::string& backend, | |||
| const OperatorNodeConfig& config, | |||
| const std::shared_ptr<DTypeScalar>& disable) { | |||
| auto inpvars = cg::to_var_node_array(inputs); | |||
| auto opr = graph->insert_opr(std::make_unique<CollectiveComm>( | |||
| inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client), | |||
| param, dtype, backend, dev_buffer_arr, config, disable)); | |||
| inpvars, graph, key, nr_devices, is_root, rank, local_grad, | |||
| std::move(group_client), param, dtype, backend, dev_buffer_arr, | |||
| config, disable)); | |||
| mgb_assert(!opr->output().empty()); | |||
| return cg::to_symbol_var_array(opr->output()); | |||
| } | |||
| @@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) { | |||
| owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn); | |||
| trait.exec(this); | |||
| owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn); | |||
| #if CUDART_VERSION < 9000 | |||
| #pragma message "legacy CUDA; use sync to avoid blocking" | |||
| // nccl hangs occasionally without this sync() | |||
| cn.sync(); | |||
| #endif | |||
| }; | |||
| env.dispatch_on_comp_node(cn, runner); | |||
| } | |||
| void CollectiveComm::on_output_comp_node_stream_changed() {} | |||
| VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const { | |||
| auto mode = ModeTrait::from_mode(m_param.mode).grad_mode(); | |||
| SymbolVarArray og_syms; | |||
| if (m_param.mode == Param::Mode::REDUCE_SUM) { | |||
| for (size_t i = 0; i < output().size(); i++) { | |||
| if (out_grads[i]) | |||
| og_syms.push_back(out_grads[i]); | |||
| } | |||
| mgb_assert(og_syms.size() == 1); | |||
| } else { | |||
| for (size_t i = 0; i < output().size(); i++) { | |||
| if (!out_grads[i]) { | |||
| mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM, | |||
| "null out grad in CollctiveCommMM currently " | |||
| "unsupported when the forward mode is " | |||
| "Reduce_Scatter_Sum."); | |||
| DTypeScalar dval{output(i)->dtype()}; | |||
| dval.set_retain_dtype(0); | |||
| auto zeros = | |||
| SymbolVar::make_scalar(dval, *output(i)->owner_graph(), | |||
| output(i)->comp_node()) | |||
| .broadcast(SymbolVar(output(i)).symshape()); | |||
| og_syms.push_back(zeros); | |||
| } else { | |||
| og_syms.push_back(out_grads[i]); | |||
| } | |||
| } | |||
| } | |||
| OperatorNodeConfig::CompNodeArray cn_arr; | |||
| if (m_param.mode == Param::Mode::REDUCE_SUM) { | |||
| for (auto i : input()) { | |||
| cn_arr.push_back(i->comp_node()); | |||
| } | |||
| } else if (m_param.mode == Param::Mode::BROADCAST) { | |||
| if (!input().empty()) { | |||
| cn_arr.push_back(input(0)->comp_node()); | |||
| } | |||
| } | |||
| auto gvar = CollectiveComm::make( | |||
| og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root, | |||
| m_rank, m_group_client, mode, m_dtype, m_backend, | |||
| OperatorNodeConfig{}.comp_node_arr(cn_arr)); | |||
| if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) { | |||
| for (size_t i = 0; i < input().size(); ++i) { | |||
| gvar[i] = Elemwise::make({output(i), input(i), gvar[i]}, | |||
| Elemwise::Mode::COND_LEQ_MOV); | |||
| } | |||
| } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) { | |||
| for (size_t i = 0; i < input().size(); ++i) { | |||
| gvar[i] = Elemwise::make({input(i), output(i), gvar[i]}, | |||
| Elemwise::Mode::COND_LEQ_MOV); | |||
| } | |||
| } else if (m_param.mode == Param::Mode::BROADCAST) { | |||
| if (!input().empty()) { | |||
| CompNode&& master_out_cn = input(0)->comp_node(); | |||
| SymbolVarArray rst; | |||
| for (auto i : gvar) { | |||
| if (i.node()->comp_node() == master_out_cn) { | |||
| mgb_assert(rst.empty()); | |||
| rst.push_back(i); | |||
| } | |||
| } | |||
| gvar = rst; | |||
| } | |||
| } | |||
| return cg::to_var_node_array(gvar); | |||
| } | |||
| MGB_IMPL_OPR_GRAD(CollectiveComm) { | |||
| return opr.grad(out_grad); | |||
| } | |||
| void CollectiveComm::init_output_dtype() { | |||
| if (m_dtype.valid()) { | |||
| for (size_t i = 0; i < input().size(); ++i) { | |||
| @@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() { | |||
| } | |||
| } | |||
| VarNode* CollectiveComm::grad(VarNode* out_grad) const { | |||
| return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); | |||
| } | |||
| MGB_IMPL_OPR_GRAD(CollectiveComm) { | |||
| mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); | |||
| return opr.grad(out_grad[0]); | |||
| } | |||
| /* ===================== shallow copy ===================== */ | |||
| namespace mgb { | |||
| @@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||
| const OperatorNodeConfig& config) { | |||
| auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | |||
| auto new_opr = CollectiveComm::make( | |||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||
| opr.group_client(), opr.dev_buffers(), opr.param(), | |||
| opr.dtype(), opr.backend(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| auto new_opr = | |||
| CollectiveComm::make( | |||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||
| opr.local_grad(), opr.group_client(), opr.dev_buffers(), | |||
| opr.param(), opr.dtype(), opr.backend(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | |||
| return new_opr; | |||
| } | |||
| @@ -8,6 +8,7 @@ decl_raw_opr( | |||
| 'operation to which this operator belongs.', 'int'), | |||
| Doc('is_root', 'whether this node is root node', 'bool'), | |||
| Doc('rank', 'rank of this node, if is -1, generate one', 'int'), | |||
| Doc('local_grad', 'whether use local grad', 'bool'), | |||
| Doc('server_addr', 'rpc server ip address'), | |||
| Doc('port', 'server rpc listening port'), | |||
| Doc('param', 'The only component of *param* is *mode*, which refers to ' | |||
| @@ -28,12 +29,12 @@ decl_raw_opr( | |||
| body = [ | |||
| 'if isinstance(input, _mgb.SymbolVar):', | |||
| (' output = _mgb._Opr.collective_comm_with_input(input, key, ' | |||
| 'nr_devices, is_root, rank, server_addr, port, ' | |||
| 'nr_devices, is_root, rank, local_grad, server_addr, port, ' | |||
| '[param.serialize()], dtype, backend, output_buffer, config, disable)'), | |||
| 'else:', | |||
| ' assert isinstance(input, _mgb.CompGraph)', | |||
| (' output = _mgb._Opr.collective_comm_without_input(input, key, ' | |||
| 'nr_devices, is_root, rank, server_addr, port, ' | |||
| 'nr_devices, is_root, rank, local_grad, server_addr, port, ' | |||
| '[param.serialize()], dtype, backend, output_buffer, config, disable)') | |||
| ], | |||
| desc = ('collective communication between multiple CompNodes on multiple ' | |||
| @@ -29,8 +29,9 @@ public: | |||
| CollectiveComm( | |||
| VarNodeArray inputs, ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, const bool is_root, | |||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||
| const Param& param, const DType& dtype, const std::string& backend, | |||
| const int rank, const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, const Param& param, | |||
| const DType& dtype, const std::string& backend, | |||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
| const OperatorNodeConfig& config, | |||
| const std::shared_ptr<DTypeScalar>& disable); | |||
| @@ -38,7 +39,8 @@ public: | |||
| static SymbolVarArray make( | |||
| const SymbolVarArray& inputs, ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, const bool is_root, | |||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||
| const int rank, const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | |||
| const Param& param, const DType& dtype = {}, | |||
| const std::string& backend = "nccl", | |||
| @@ -50,6 +52,7 @@ public: | |||
| ComputingGraph* const graph, | |||
| const std::string& key, const size_t nr_devices, | |||
| const bool is_root, const int rank, | |||
| const bool local_grad, | |||
| std::shared_ptr<GroupClient> group_client, | |||
| const Param& param, const DType& dtype = {}, | |||
| const std::string& backend = "nccl", | |||
| @@ -72,6 +75,7 @@ public: | |||
| int rank() const { return m_rank; } | |||
| int root() const { return m_root; } | |||
| bool is_root() const { return m_is_root; } | |||
| bool local_grad() const { return m_local_grad; } | |||
| //! The key that identifies an NCCL clique. | |||
| //! Operators with same keys belong to the same clique. | |||
| @@ -89,7 +93,7 @@ public: | |||
| return m_megray_ctx; | |||
| } | |||
| VarNodeArray grad(const VarNodeArray& out_grad) const; | |||
| VarNode* grad(VarNode* out_grad) const; | |||
| private: | |||
| Barrier m_exec_barrier; | |||
| @@ -116,6 +120,7 @@ private: | |||
| size_t m_nr_devices = 0; | |||
| bool m_is_root; | |||
| int m_rank; | |||
| bool m_local_grad; | |||
| std::string m_key; | |||
| //! XXHash generated from m_key | |||
| size_t m_hash; | |||
| @@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields( | |||
| (pdef('CollectiveComm', 'collective communication between multiple computing ' | |||
| 'nodes on localhost') | |||
| .add_enum('Mode', | |||
| .add_enum(Doc('Mode', 'mode of collective communication'), | |||
| Doc('REDUCE_SUM', 'reduce by sum to output computing node'), | |||
| Doc('BROADCAST', 'copy input value to each output computing node'), | |||
| Doc('ALL_GATHER', 'each output comp node gets the concatenated ' | |||
| @@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields( | |||
| Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), | |||
| Doc('GATHER', 'concat inputs to one node'), | |||
| Doc('SCATTER', 'scatter input to each output computing node'), | |||
| Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'))) | |||
| Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'), | |||
| name_field='mode')) | |||
| (pdef('FakeSerializedDType', | |||
| 'HACK: The tag of this param def is actually used for another ' | |||