GitOrigin-RevId: cc120cfb55
tags/v0.6.0
@@ -11,10 +11,13 @@ from .functional import (
     all_reduce_max,
     all_reduce_min,
     all_reduce_sum,
+    all_to_all,
     bcast_param,
     broadcast,
+    gather,
     reduce_scatter_sum,
     reduce_sum,
+    scatter,
 )
 from .util import (
     get_backend,
@@ -9,7 +9,7 @@
 from typing import Optional, Union
 import megengine._internal as mgb
-from megengine._internal.opr_param_defs import CollectiveComm as CollParam
+from megengine._internal.opr_param_defs import CollectiveComm as Param
 from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
 from ..functional import add_update
@@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs):
     return collective_comm_symvar(*args, **kargs)
+def _group_check(*args):
+    """Return True when arguments are all None or all not None
+    """
+    l = [val is None for val in args]
+    return len(set(l)) <= 1
 def reduce_sum(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:
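The `_group_check` helper added above enforces an all-or-none rule: either every group argument is left as `None` (so the defaults from the process group are used) or every one is passed explicitly. A minimal standalone sketch of that behaviour, not part of the patch:

```python
# Standalone sketch of the all-or-none check introduced above.
def _group_check(*args):
    """Return True when the arguments are either all None or all not None."""
    flags = [val is None for val in args]
    return len(set(flags)) <= 1

assert _group_check() is True                    # vacuously true
assert _group_check(None, None, None) is True    # nothing set: fall back to defaults
assert _group_check("key", 2, True) is True      # everything set explicitly
assert _group_check("key", None, True) is False  # mixed: rejected by the callers' assert
```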
@@ -35,14 +42,17 @@ def reduce_sum(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
+        tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
     )
 def gather(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
@@ -55,20 +65,17 @@ def gather(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     return _collective_comm(
-        tensor,
-        key,
-        CollParam.Mode.GATHER,
-        nr_ranks,
-        is_root,
-        rank,
-        device=tensor.device,
+        tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
     )
 def broadcast(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:
@@ -79,11 +86,12 @@ def broadcast(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
-    if key is None:
-        key = tensor._symvar.name
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     if is_root is None:
         is_root = get_rank() == 0
     if is_root:
         inp = tensor
     else:
@@ -92,7 +100,7 @@ def broadcast(
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.BROADCAST,
+        Param.Mode.BROADCAST,
         nr_ranks,
         is_root,
         dtype=tensor.dtype,
@@ -102,7 +110,7 @@
 def scatter(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
@@ -115,6 +123,9 @@ def scatter(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     if key is None:
         key = tensor._symvar.name
     if is_root is None:
@@ -128,7 +139,7 @@ def scatter(
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.SCATTER,
+        Param.Mode.SCATTER,
         nr_ranks,
         is_root,
         rank,
@@ -138,7 +149,11 @@ def scatter(
 def all_to_all(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_to_all operator for collective communication
@@ -146,12 +161,22 @@ def all_to_all(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
+    )
 def all_gather(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_gather operator for collective communication
@@ -159,12 +184,22 @@ def all_gather(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
+    )
 def reduce_scatter_sum(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create reduce_scatter_sum operator for collective communication
@@ -172,45 +207,81 @@ def reduce_scatter_sum(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
+        tensor,
+        key,
+        Param.Mode.REDUCE_SCATTER_SUM,
+        nr_ranks,
+        rank=rank,
+        local_grad=local_grad,
     )
-def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_sum(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_sum operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
+    )
-def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_max(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_max operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
+    )
-def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_min(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_min operator for collective communication
     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
+    )
 def bcast_param(
     inp: Union[Buffer, Parameter],
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> None:
@@ -223,6 +294,9 @@ def bcast_param(
     """
     if not is_distributed():
         return
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     assert isinstance(inp, (Buffer, Parameter))
     bcast_res = broadcast(inp, key, nr_ranks, is_root)
     add_update(inp, bcast_res, alpha=0)
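Taken together, the functional-API changes above make `key`, `nr_ranks`, `rank` and `is_root` optional (they are filled in from the initialized process group), and add a `local_grad` flag (default `False`) that selects the cheaper local-gradient path implemented in the C++ changes further down. A hypothetical call site, assuming the process group has already been initialized on every rank (as in the tests below) and that `tensor` is importable from the top-level `megengine` package:

```python
import numpy as np
import megengine.distributed as dist
from megengine import tensor

# Assumes an initialized process group on every participating rank
# (see _init_process_group_wrapper in the tests below).
inp = tensor(np.ones((4,), dtype="float32"))
out = dist.functional.all_gather(inp)                         # key/nr_ranks/rank all defaulted
out_local = dist.functional.all_gather(inp, local_grad=True)  # backward keeps only this rank's slice
# dist.functional.all_gather(inp, key="x")                    # mixed explicit/default args fail _group_check
```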
@@ -11,16 +11,24 @@ from typing import Optional, Union
 import megengine._internal as mgb
 from megengine._internal.opr_param_defs import CollectiveComm as CollParam
-from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
+from .util import (
+    get_backend,
+    get_group_id,
+    get_master_ip,
+    get_master_port,
+    get_rank,
+    get_world_size,
+)
 def collective_comm_symvar(
     inp: Union[mgb.SymbolVar, mgb.CompGraph],
-    key: str,
-    op: CollParam.Mode,
+    key: Optional[str] = None,
+    op: CollParam.Mode = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
     dtype: Optional[type] = None,
     device: Optional[mgb.CompNode] = None,
     comp_graph: Optional[mgb.CompGraph] = None,
@@ -32,16 +40,19 @@ def collective_comm_symvar(
     :param op: mode of collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this node is root node
+    :param rank: rank of this node
+    :param local_grad: whether use local grad
     :param dtype: output data type, use dtype of inp as default
     :param device: output comp node, use comp node of inp as default
     :param comp_graph: output comp graph, use comp graph of inp as default
     """
     return mgb.opr.collective_comm(
         inp,
-        key=str(key),
+        key=key if key is not None else ("collective_comm_" + str(get_group_id())),
         nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
         is_root=is_root if is_root is not None else (get_rank() == 0),
-        rank=rank if rank is not None else -1,
+        rank=rank if rank is not None else get_rank(),
+        local_grad=local_grad,
         server_addr=get_master_ip(),
         port=get_master_port(),
         param=CollParam(mode=op),
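For reference, the defaulting the hunk above ends up with, restated as a plain function; note that an omitted `rank` now resolves to `get_rank()` rather than the old `-1` placeholder. The module path in the import is an assumption based on the relative `from .util import (...)` above:

```python
# Illustrative restatement of the argument defaulting above (not the actual helper).
from megengine.distributed.util import (  # module path assumed from ".util" above
    get_group_id,
    get_rank,
    get_world_size,
)

def _resolved_comm_args(key=None, nr_ranks=None, is_root=None, rank=None):
    return dict(
        key=key if key is not None else "collective_comm_" + str(get_group_id()),
        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
        is_root=is_root if is_root is not None else (get_rank() == 0),
        rank=rank if rank is not None else get_rank(),
    )
```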
@@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta):
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
-                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
+                        all_reduce_sum(
+                            grad, "grad_" + str(get_group_id()), get_world_size()
+                        )
                         / get_world_size()
                     )
                 with opr_priority_scope(cg, -(2 ** 31)):
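The optimizer hunk above now passes `get_world_size()` explicitly as `nr_ranks`; the arithmetic is unchanged: an all-reduce sum over ranks followed by division by the world size is the mean gradient. A plain NumPy check of that identity, purely illustrative:

```python
import numpy as np

# all_reduce_sum followed by division by world_size is just a mean over ranks.
world_size = 4
per_rank_grads = [np.full((2, 3), float(r)) for r in range(world_size)]
summed = np.sum(per_rank_grads, axis=0)   # what every rank sees after all_reduce_sum
averaged = summed / world_size            # what the update above actually applies
assert np.allclose(averaged, np.mean(per_rank_grads, axis=0))
```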
@@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta):
         for group in self.param_groups:
             for param in group["params"]:
                 bcast_param(
-                    param, "bcast_param_" + str(key), is_root=(get_rank() == 0),
+                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
                 )
                 key += 1
@@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 SymbolVar _Opr::collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs(1, inpvar);
     ComputingGraph* graph = inpvar.node()->owner_graph();
@@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
-            group_mgr, dev_buffer_arr, param, _dtype,
-            backend, config, disable.get_val())[0];
+            local_grad, group_mgr, dev_buffer_arr, param,
+            _dtype, backend, config, disable.get_val())[0];
 }
 SymbolVar _Opr::collective_comm_without_input(
         CompGraph& cg, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs;
     auto& graph = cg.get();
@@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
-            group_mgr, dev_buffer_arr, param, _dtype,
-            backend, config, disable.get_val())[0];
+            local_grad, group_mgr, dev_buffer_arr, param,
+            _dtype, backend, config, disable.get_val())[0];
 }
 #else
@@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 }
 SymbolVar _Opr::collective_comm_with_input(
-        SymbolVar inpvar, const std::string& key,
-        const size_t nr_devices, const bool is_root, const int rank,
+        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input(
 }
 SymbolVar _Opr::collective_comm_without_input(
-        CompGraph& cg, const std::string& key,
-        const size_t nr_devices, const bool is_root, const int rank,
+        CompGraph& cg, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,
 static SymbolVar collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype, const std::string& backend,
-        SharedND* output_buf, const OperatorNodeConfig& config,
-        const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);
 static SymbolVar collective_comm_without_input(
         CompGraph& graph, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype, const std::string& backend,
-        SharedND* output_buf, const OperatorNodeConfig& config,
-        const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);
 // misc
 static SymbolVarArray extern_c_opr_placeholder(
@@ -34,7 +34,7 @@ def test_reduce_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_sum(inp, "x")
+        output = dist.functional.reduce_sum(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:
@@ -70,7 +70,7 @@ def test_gather():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.gather(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:
@@ -106,7 +106,7 @@ def test_broadcast():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.broadcast(inp, "x")
+        output = dist.functional.broadcast(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -138,7 +138,7 @@ def test_scatter():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.scatter(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -174,7 +174,7 @@ def test_all_to_all():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_to_all(inp, "x", rank=rank)
+        output = dist.functional.all_to_all(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -208,7 +208,7 @@ def test_all_gather():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_gather(inp, "x", rank=rank)
+        output = dist.functional.all_gather(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -241,7 +241,7 @@ def test_reduce_scatter_sum():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
+        output = dist.functional.reduce_scatter_sum(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -278,7 +278,7 @@ def test_all_reduce_sum():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_sum(inp, "x")
+        output = dist.functional.all_reduce_sum(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -311,7 +311,7 @@ def test_all_reduce_max():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_max(inp, "x")
+        output = dist.functional.all_reduce_max(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -344,7 +344,7 @@ def test_all_reduce_min():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_min(inp, "x")
+        output = dist.functional.all_reduce_min(inp)
         assert np.allclose(output.numpy(), expect)
     def check(shape, backend):
@@ -377,7 +377,7 @@ def test_bcast_param():
            return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = Parameter(data)
-        dist.functional.bcast_param(inp, "x")
+        dist.functional.bcast_param(inp)
         assert np.allclose(inp.numpy(), expect)
     def check(shape, backend):
@@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
     if (!opr->same_type<opr::CollectiveComm>()) return false;
     auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
     if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false;
+    if (comm.local_grad()) return false;
     if (comm.input().size() != 1) return false;
     auto grad = comm.input(0)->owner_opr();
@@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
     std::string key = ssprintf("grad_pack_%zu", pack_id);
     auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
     SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph,
-            key, info->nr_devices, info->is_root, info->rank,
+            key, info->nr_devices, info->is_root, info->rank, false,
             info->group_client, param, info->dtype, info->backend)[0];
     // split according to recorded partition
@@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) {
     auto grad3 = opr::VirtualGrad::make(y1, x1);
     auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
-    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(),
-            "grad0", 2, 0, 0, client, mode)[0];
-    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(),
-            "grad1", 2, 0, 0, client, mode)[0];
-    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(),
-            "grad2", 2, 0, 0, client, mode)[0];
-    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(),
-            "grad3", 2, 0, 0, client, mode)[0];
+    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2,
+            false, 0, false, client, mode)[0];
+    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2,
+            false, 0, false, client, mode)[0];
+    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2,
+            false, 0, false, client, mode)[0];
+    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2,
+            false, 0, false, client, mode)[0];
     gopt::GraphOptimizer()
         .add_pass<gopt::PackAllReduceScanPass>()
@@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
     auto grad = opr::VirtualGrad::make(target, wrt);
-    auto comm = opr::CollectiveComm::make(
-            {grad}, graph.get(), "key", 2, 0, 0, client,
-            opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
-            .node()->owner_opr();
+    auto comm =
+            opr::CollectiveComm::make(
+                    {grad}, graph.get(), "key", 2, false, 0, false, client,
+                    opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
+                    .node()
+                    ->owner_opr();
     comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
@@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) {
     auto insert_opr = [&] (size_t size) {
         auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
-                "key", 2, 0, 0, client, mode)[0];
+        auto symvar = opr::CollectiveComm::make(
+                {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);
@@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     size_t nr_devices = 2;
     uint32_t rank = 0;
-    uint32_t root = 0;
     using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
     ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
@@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto insert_opr = [&] (const TensorShape& shape) {
         auto dev = std::make_shared<DeviceTensorND>(cn, shape);
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
-                "key", nr_devices, rank, root, client, mode)[0];
+        auto symvar =
+                opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices,
+                        false, rank, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);
@@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);
     std::string key = ssprintf("grad_pack_%zu", pack_id);
-    auto allreduce = opr::CollectiveComm::make({concat}, graph.get(),
-            key, nr_devices, rank, root, client, mode)[0];
+    auto allreduce =
+            opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices,
+                    false, rank, false, client, mode)[0];
     std::vector<size_t> partition;
     partition.push_back(shape_x.total_nr_elems());
@@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) {
     using Mode = opr::CollectiveComm::Param::Mode;
     bool is_root = (rank == 0);
-    auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(),
-            "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
-    auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(),
-            "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
+    auto reduced_x = opr::CollectiveComm::make(
+            {grad_x}, graph.get(), "x", 2, is_root, rank,
+            false, client, Mode::ALL_REDUCE_SUM)[0] /
+            2;
+    auto reduced_y = opr::CollectiveComm::make(
+            {grad_y}, graph.get(), "y", 2, is_root, rank,
+            false, client, Mode::ALL_REDUCE_SUM)[0] /
+            2;
     graph->options().allreduce_pack_max_size = 5000;
     graph->options().allreduce_pack_ignore_first = 0;
@@ -14,6 +14,8 @@
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/event.h"
 #include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/opr/group_manager.h"
@@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) {
     }
 } // anonymous namespace
+/* ================= ModeTrait ================= */
 class CollectiveComm::ModeTrait {
     class BROADCAST;
     class REDUCE_SUM;
@@ -132,6 +136,42 @@ public:
         return None;
     }
+    VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
+        SymbolVarArray og_syms;
+        og_syms.push_back(out_grad);
+        auto&& cn = opr->output(0)->comp_node();
+        auto gvar = CollectiveComm::make(
+                og_syms, opr->owner_graph(), opr->key() + ":grad",
+                opr->nr_devices(), opr->is_root(), opr->rank(), false,
+                opr->group_client(), mode, opr->dtype(), opr->backend(), {cn});
+        return gvar[0].node();
+    }
+    virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        mgb_throw(MegBrainError,
+                  "only all_reduce all_to_all all_gather reduce_scatter "
+                  "support local_grad");
+    }
+    virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        if (opr->local_grad()) {
+            return local_grad(out_grad, opr);
+        } else {
+            return full_grad(out_grad, opr);
+        }
+    }
+    VarNode* zeros(mgb::cg::ComputingGraph& graph, CompNode node, const SymbolVar& shape,
+                   DType dtype) const {
+        auto zero = SymbolVar::make_scalar(0, graph, node);
+        auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape);
+        return zero_tensor.node();
+    }
     virtual void get_output_var_shape(const CollectiveComm* opr,
                                       const TensorShapeArray& ishp,
                                       TensorShapeArray& oshp) = 0;
@@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
     }
     Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto grad = opr::Subtensor::make(out_grad, axis);
+        return grad.node();
+    }
 };
 class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
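The `local_grad` override added to `ALL_GATHER` above slices the incoming gradient down to this rank's own segment instead of launching the usual `REDUCE_SCATTER_SUM` collective. A NumPy stand-in for the Subtensor index arithmetic (indices follow `shape0 * rank / nr_devices`), illustrative only:

```python
import numpy as np

# NumPy stand-in for the ALL_GATHER local_grad slice above (not MegBrain code).
nr_devices, rank = 4, 1
out_grad = np.arange(8.0)   # gradient of the gathered (concatenated) output
n = out_grad.shape[0]
local = out_grad[n * rank // nr_devices : n * (rank + 1) // nr_devices]
assert np.array_equal(local, np.array([2.0, 3.0]))
```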
@@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
     }
     Mode grad_mode() override { return Mode::ALL_GATHER; }
-};
-/* ================= ModeTrait impls ================= */
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      opr::GetVarShape::make(out_grad), out_grad->dtype());
+        for (size_t i = 0; i < opr->nr_devices(); i++) {
+            if (i == opr->rank()) {
+                grads.push_back(out_grad);
+            } else {
+                grads.push_back(zeros_tensor);
+            }
+        }
+        auto grad = opr::Concat::make(grads, 0);
+        return grad.node();
+    }
+};
 class CollectiveComm::ModeTrait::ReducedBasedTrait {
 protected:
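Conversely, the `REDUCE_SCATTER_SUM` local gradient above rebuilds a full-length input gradient by placing the incoming shard at this rank's position and zeros everywhere else, with no `ALL_GATHER` collective. The same idea in NumPy, illustrative only:

```python
import numpy as np

# NumPy stand-in for the REDUCE_SCATTER_SUM local_grad above (not MegBrain code).
nr_devices, rank = 4, 1
out_grad = np.array([1.0, 2.0])   # gradient of this rank's scattered shard
pieces = [out_grad if i == rank else np.zeros_like(out_grad) for i in range(nr_devices)]
full_grad = np.concatenate(pieces, axis=0)   # zeros except at this rank's offset
assert np.array_equal(full_grad, np.array([0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]))
```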
@@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
     }
     Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }
+public:
+    VarNode* local_grad(VarNode* out_grad,
+                        const CollectiveComm* opr) const override {
+        return out_grad;
+    }
 };
 class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
@@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
 class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };
 class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };
 class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
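For `ALL_REDUCE_MAX`/`ALL_REDUCE_MIN`, the grad overrides above wrap either the full or the local gradient in a `COND_LEQ_MOV` elemwise, which forwards gradient only where the local input actually attained the reduced value. A rough NumPy analogue of the max case, under the assumption that `COND_LEQ_MOV(a, b, g)` passes `g` through where `a <= b`:

```python
import numpy as np

# Rough NumPy analogue of masking the ALL_REDUCE_MAX gradient (assumption:
# COND_LEQ_MOV forwards the third operand where the first operand <= the second).
inp = np.array([1.0, 3.0, 2.0])                # this rank's input
out = np.array([2.0, 3.0, 5.0])                # all_reduce_max result, identical on every rank
grad_out = np.array([0.1, 0.2, 0.3])
grad_in = np.where(out <= inp, grad_out, 0.0)  # gradient only where this rank held the max
assert np.array_equal(grad_in, np.array([0.0, 0.2, 0.0]))
```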
@@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
     }
     Mode grad_mode() override { return Mode::ALL_TO_ALL; }
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto grad_shape = opr::GetVarShape::make(out_grad);
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      grad_shape, out_grad->dtype());
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto sub_grad = opr::Subtensor::make(out_grad, axis);
+        return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node();
+    }
 };
 CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
@@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
 CollectiveComm::CollectiveComm(
         VarNodeArray inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable)
@@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm(
           m_nr_devices(nr_devices),
           m_is_root(is_root),
           m_rank(rank),
+          m_local_grad(local_grad),
          m_key(key),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {
@@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm(
 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
                                                                 nullptr);
-    return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
-                dev_buffer_arr, param, dtype, backend, config);
+    return make(inputs, graph, key, nr_devices, is_root, rank, local_grad,
+                group_client, dev_buffer_arr, param, dtype, backend, config);
 }
 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const Param& param, const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     auto inpvars = cg::to_var_node_array(inputs);
     auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
-            inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
-            param, dtype, backend, dev_buffer_arr, config, disable));
+            inpvars, graph, key, nr_devices, is_root, rank, local_grad,
+            std::move(group_client), param, dtype, backend, dev_buffer_arr,
+            config, disable));
     mgb_assert(!opr->output().empty());
     return cg::to_symbol_var_array(opr->output());
 }
| @@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) { | |||||
| owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn); | owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn); | ||||
| trait.exec(this); | trait.exec(this); | ||||
| owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn); | owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn); | ||||
| #if CUDART_VERSION < 9000 | |||||
| #pragma message "legacy CUDA; use sync to avoid blocking" | |||||
| // nccl hangs occasionally without this sync() | |||||
| cn.sync(); | |||||
| #endif | |||||
| }; | }; | ||||
| env.dispatch_on_comp_node(cn, runner); | env.dispatch_on_comp_node(cn, runner); | ||||
| } | } | ||||
| void CollectiveComm::on_output_comp_node_stream_changed() {} | void CollectiveComm::on_output_comp_node_stream_changed() {} | ||||
| VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const { | |||||
| auto mode = ModeTrait::from_mode(m_param.mode).grad_mode(); | |||||
| SymbolVarArray og_syms; | |||||
| if (m_param.mode == Param::Mode::REDUCE_SUM) { | |||||
| for (size_t i = 0; i < output().size(); i++) { | |||||
| if (out_grads[i]) | |||||
| og_syms.push_back(out_grads[i]); | |||||
| } | |||||
| mgb_assert(og_syms.size() == 1); | |||||
| } else { | |||||
| for (size_t i = 0; i < output().size(); i++) { | |||||
| if (!out_grads[i]) { | |||||
| mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM, | |||||
| "null out grad in CollctiveCommMM currently " | |||||
| "unsupported when the forward mode is " | |||||
| "Reduce_Scatter_Sum."); | |||||
| DTypeScalar dval{output(i)->dtype()}; | |||||
| dval.set_retain_dtype(0); | |||||
| auto zeros = | |||||
| SymbolVar::make_scalar(dval, *output(i)->owner_graph(), | |||||
| output(i)->comp_node()) | |||||
| .broadcast(SymbolVar(output(i)).symshape()); | |||||
| og_syms.push_back(zeros); | |||||
| } else { | |||||
| og_syms.push_back(out_grads[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| OperatorNodeConfig::CompNodeArray cn_arr; | |||||
| if (m_param.mode == Param::Mode::REDUCE_SUM) { | |||||
| for (auto i : input()) { | |||||
| cn_arr.push_back(i->comp_node()); | |||||
| } | |||||
| } else if (m_param.mode == Param::Mode::BROADCAST) { | |||||
| if (!input().empty()) { | |||||
| cn_arr.push_back(input(0)->comp_node()); | |||||
| } | |||||
| } | |||||
| auto gvar = CollectiveComm::make( | |||||
| og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root, | |||||
| m_rank, m_group_client, mode, m_dtype, m_backend, | |||||
| OperatorNodeConfig{}.comp_node_arr(cn_arr)); | |||||
| if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) { | |||||
| for (size_t i = 0; i < input().size(); ++i) { | |||||
| gvar[i] = Elemwise::make({output(i), input(i), gvar[i]}, | |||||
| Elemwise::Mode::COND_LEQ_MOV); | |||||
| } | |||||
| } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) { | |||||
| for (size_t i = 0; i < input().size(); ++i) { | |||||
| gvar[i] = Elemwise::make({input(i), output(i), gvar[i]}, | |||||
| Elemwise::Mode::COND_LEQ_MOV); | |||||
| } | |||||
| } else if (m_param.mode == Param::Mode::BROADCAST) { | |||||
| if (!input().empty()) { | |||||
| CompNode&& master_out_cn = input(0)->comp_node(); | |||||
| SymbolVarArray rst; | |||||
| for (auto i : gvar) { | |||||
| if (i.node()->comp_node() == master_out_cn) { | |||||
| mgb_assert(rst.empty()); | |||||
| rst.push_back(i); | |||||
| } | |||||
| } | |||||
| gvar = rst; | |||||
| } | |||||
| } | |||||
| return cg::to_var_node_array(gvar); | |||||
| } | |||||
| MGB_IMPL_OPR_GRAD(CollectiveComm) { | |||||
| return opr.grad(out_grad); | |||||
| } | |||||
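
The removed multi-output grad() above zero-fills missing output gradients and, for ALL_REDUCE_MAX / ALL_REDUCE_MIN, uses the COND_LEQ_MOV elemwise to pass the incoming gradient through only where the local input attained the reduced value. A toy Python sketch of that routing rule for ALL_REDUCE_MAX follows; it assumes exact equality at the attaining elements and is not the MegEngine Elemwise API.

# Hedged sketch of the COND_LEQ_MOV gradient routing used for ALL_REDUCE_MAX:
# the gradient passes through only where the local input reached the reduced
# (max) value and is zeroed elsewhere; ties receive the gradient on all ranks.
def all_reduce_max_grad(inputs, reduced, out_grad):
    """inputs: per-rank values, reduced: element-wise max over ranks,
    out_grad: gradient w.r.t. the reduced output (same on every rank)."""
    grads = []
    for x in inputs:
        grads.append([g if xi >= r else 0.0      # xi == r in exact arithmetic
                      for xi, r, g in zip(x, reduced, out_grad)])
    return grads

# toy check with two ranks
a, b = [1.0, 5.0, 2.0], [3.0, 4.0, 2.0]
m = [max(x, y) for x, y in zip(a, b)]            # [3.0, 5.0, 2.0]
print(all_reduce_max_grad([a, b], m, [1.0, 1.0, 1.0]))
# [[0.0, 1.0, 1.0], [1.0, 0.0, 1.0]]
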
| void CollectiveComm::init_output_dtype() { | void CollectiveComm::init_output_dtype() { | ||||
| if (m_dtype.valid()) { | if (m_dtype.valid()) { | ||||
| for (size_t i = 0; i < input().size(); ++i) { | for (size_t i = 0; i < input().size(); ++i) { | ||||
| @@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() { | |||||
| } | } | ||||
| } | } | ||||
| VarNode* CollectiveComm::grad(VarNode* out_grad) const { | |||||
| return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); | |||||
| } | |||||
| MGB_IMPL_OPR_GRAD(CollectiveComm) { | |||||
| mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); | |||||
| return opr.grad(out_grad[0]); | |||||
| } | |||||
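
The replacement grad() delegates to a per-mode trait, and the opr-grad hook now expects exactly one output gradient. A hedged Python sketch of such a mode-to-grad-rule dispatch table follows; the pairings shown (REDUCE_SUM -> BROADCAST, ALL_GATHER -> REDUCE_SCATTER_SUM) are the usual reverse collectives and are given for illustration only.

# Hedged sketch of per-mode grad dispatch replacing the old monolithic grad():
# each mode registers its own rule and the operator forwards the single
# out_grad to it.
from typing import Callable, Dict

GRAD_RULES: Dict[str, Callable] = {}

def register_grad(mode: str):
    def deco(fn):
        GRAD_RULES[mode] = fn
        return fn
    return deco

@register_grad("REDUCE_SUM")
def _reduce_sum_grad(out_grad, opr):
    # the reverse of a sum-reduce is a broadcast of the gradient
    return ("BROADCAST", out_grad)

@register_grad("ALL_GATHER")
def _all_gather_grad(out_grad, opr):
    # the reverse of all-gather is a reduce-scatter of the gradient
    return ("REDUCE_SCATTER_SUM", out_grad)

def collective_grad(mode: str, out_grads, opr=None):
    assert len(out_grads) == 1, "CollectiveComm should only have one grad"
    return GRAD_RULES[mode](out_grads[0], opr)

print(collective_grad("REDUCE_SUM", ["dy"]))     # ('BROADCAST', 'dy')
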
| /* ===================== shallow copy ===================== */ | /* ===================== shallow copy ===================== */ | ||||
| namespace mgb { | namespace mgb { | ||||
| @@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( | |||||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>(); | ||||
| auto new_opr = CollectiveComm::make( | |||||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||||
| opr.group_client(), opr.dev_buffers(), opr.param(), | |||||
| opr.dtype(), opr.backend(), config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| auto new_opr = | |||||
| CollectiveComm::make( | |||||
| to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), | |||||
| opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), | |||||
| opr.local_grad(), opr.group_client(), opr.dev_buffers(), | |||||
| opr.param(), opr.dtype(), opr.backend(), config)[0] | |||||
| .node() | |||||
| ->owner_opr(); | |||||
| new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash()); | ||||
| return new_opr; | return new_opr; | ||||
| } | } | ||||
| @@ -8,6 +8,7 @@ decl_raw_opr( | |||||
| 'operation to which this operator belongs.', 'int'), | 'operation to which this operator belongs.', 'int'), | ||||
| Doc('is_root', 'whether this node is root node', 'bool'), | Doc('is_root', 'whether this node is root node', 'bool'), | ||||
| Doc('rank', 'rank of this node, if is -1, generate one', 'int'), | Doc('rank', 'rank of this node, if is -1, generate one', 'int'), | ||||
| Doc('local_grad', 'whether to use local grad', 'bool'), |||||
| Doc('server_addr', 'rpc server ip address'), | Doc('server_addr', 'rpc server ip address'), | ||||
| Doc('port', 'server rpc listening port'), | Doc('port', 'server rpc listening port'), | ||||
| Doc('param', 'The only component of *param* is *mode*, which refers to ' | Doc('param', 'The only component of *param* is *mode*, which refers to ' | ||||
| @@ -28,12 +29,12 @@ decl_raw_opr( | |||||
| body = [ | body = [ | ||||
| 'if isinstance(input, _mgb.SymbolVar):', | 'if isinstance(input, _mgb.SymbolVar):', | ||||
| (' output = _mgb._Opr.collective_comm_with_input(input, key, ' | (' output = _mgb._Opr.collective_comm_with_input(input, key, ' | ||||
| 'nr_devices, is_root, rank, server_addr, port, ' | |||||
| 'nr_devices, is_root, rank, local_grad, server_addr, port, ' | |||||
| '[param.serialize()], dtype, backend, output_buffer, config, disable)'), | '[param.serialize()], dtype, backend, output_buffer, config, disable)'), | ||||
| 'else:', | 'else:', | ||||
| ' assert isinstance(input, _mgb.CompGraph)', | ' assert isinstance(input, _mgb.CompGraph)', | ||||
| (' output = _mgb._Opr.collective_comm_without_input(input, key, ' | (' output = _mgb._Opr.collective_comm_without_input(input, key, ' | ||||
| 'nr_devices, is_root, rank, server_addr, port, ' | |||||
| 'nr_devices, is_root, rank, local_grad, server_addr, port, ' | |||||
| '[param.serialize()], dtype, backend, output_buffer, config, disable)') | '[param.serialize()], dtype, backend, output_buffer, config, disable)') | ||||
| ], | ], | ||||
| desc = ('collective communication between multiple CompNodes on multiple ' | desc = ('collective communication between multiple CompNodes on multiple ' | ||||
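
The generated wrapper body dispatches on the type of `input`: a SymbolVar goes through collective_comm_with_input, a CompGraph through collective_comm_without_input, and local_grad is now forwarded in both branches. A hedged, self-contained sketch of that dispatch shape follows; the classes and return values are placeholders, not the _mgb bindings.

# Hedged sketch of the wrapper body declared above: branch on the type of
# `input`, threading the new local_grad argument through both paths.
class SymbolVar: ...
class CompGraph: ...

def collective_comm(input, key, nr_devices, is_root, rank, local_grad,
                    server_addr, port, param, dtype=None, backend="nccl"):
    if isinstance(input, SymbolVar):
        # corresponds to _mgb._Opr.collective_comm_with_input(...)
        return ("with_input", key, local_grad)
    assert isinstance(input, CompGraph)
    # corresponds to _mgb._Opr.collective_comm_without_input(...)
    return ("without_input", key, local_grad)

print(collective_comm(SymbolVar(), "grp", 2, True, 0, False,
                      "127.0.0.1", 9000, None))
# ('with_input', 'grp', False)
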
| @@ -29,8 +29,9 @@ public: | |||||
| CollectiveComm( | CollectiveComm( | ||||
| VarNodeArray inputs, ComputingGraph* const graph, | VarNodeArray inputs, ComputingGraph* const graph, | ||||
| const std::string& key, const size_t nr_devices, const bool is_root, | const std::string& key, const size_t nr_devices, const bool is_root, | ||||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||||
| const Param& param, const DType& dtype, const std::string& backend, | |||||
| const int rank, const bool local_grad, | |||||
| std::shared_ptr<GroupClient> group_client, const Param& param, | |||||
| const DType& dtype, const std::string& backend, | |||||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | ||||
| const OperatorNodeConfig& config, | const OperatorNodeConfig& config, | ||||
| const std::shared_ptr<DTypeScalar>& disable); | const std::shared_ptr<DTypeScalar>& disable); | ||||
| @@ -38,7 +39,8 @@ public: | |||||
| static SymbolVarArray make( | static SymbolVarArray make( | ||||
| const SymbolVarArray& inputs, ComputingGraph* const graph, | const SymbolVarArray& inputs, ComputingGraph* const graph, | ||||
| const std::string& key, const size_t nr_devices, const bool is_root, | const std::string& key, const size_t nr_devices, const bool is_root, | ||||
| const int rank, std::shared_ptr<GroupClient> group_client, | |||||
| const int rank, const bool local_grad, | |||||
| std::shared_ptr<GroupClient> group_client, | |||||
| const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr, | ||||
| const Param& param, const DType& dtype = {}, | const Param& param, const DType& dtype = {}, | ||||
| const std::string& backend = "nccl", | const std::string& backend = "nccl", | ||||
| @@ -50,6 +52,7 @@ public: | |||||
| ComputingGraph* const graph, | ComputingGraph* const graph, | ||||
| const std::string& key, const size_t nr_devices, | const std::string& key, const size_t nr_devices, | ||||
| const bool is_root, const int rank, | const bool is_root, const int rank, | ||||
| const bool local_grad, | |||||
| std::shared_ptr<GroupClient> group_client, | std::shared_ptr<GroupClient> group_client, | ||||
| const Param& param, const DType& dtype = {}, | const Param& param, const DType& dtype = {}, | ||||
| const std::string& backend = "nccl", | const std::string& backend = "nccl", | ||||
| @@ -72,6 +75,7 @@ public: | |||||
| int rank() const { return m_rank; } | int rank() const { return m_rank; } | ||||
| int root() const { return m_root; } | int root() const { return m_root; } | ||||
| bool is_root() const { return m_is_root; } | bool is_root() const { return m_is_root; } | ||||
| bool local_grad() const { return m_local_grad; } | |||||
| //! The key that identifies an NCCL clique. | //! The key that identifies an NCCL clique. | ||||
| //! Operators with same keys belong to the same clique. | //! Operators with same keys belong to the same clique. | ||||
| @@ -89,7 +93,7 @@ public: | |||||
| return m_megray_ctx; | return m_megray_ctx; | ||||
| } | } | ||||
| VarNodeArray grad(const VarNodeArray& out_grad) const; | |||||
| VarNode* grad(VarNode* out_grad) const; | |||||
| private: | private: | ||||
| Barrier m_exec_barrier; | Barrier m_exec_barrier; | ||||
| @@ -116,6 +120,7 @@ private: | |||||
| size_t m_nr_devices = 0; | size_t m_nr_devices = 0; | ||||
| bool m_is_root; | bool m_is_root; | ||||
| int m_rank; | int m_rank; | ||||
| bool m_local_grad; | |||||
| std::string m_key; | std::string m_key; | ||||
| //! XXHash generated from m_key | //! XXHash generated from m_key | ||||
| size_t m_hash; | size_t m_hash; | ||||
| @@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields( | |||||
| (pdef('CollectiveComm', 'collective communication between multiple computing ' | (pdef('CollectiveComm', 'collective communication between multiple computing ' | ||||
| 'nodes on localhost') | 'nodes on localhost') | ||||
| .add_enum('Mode', | |||||
| .add_enum(Doc('Mode', 'mode of collective communication'), | |||||
| Doc('REDUCE_SUM', 'reduce by sum to output computing node'), | Doc('REDUCE_SUM', 'reduce by sum to output computing node'), | ||||
| Doc('BROADCAST', 'copy input value to each output computing node'), | Doc('BROADCAST', 'copy input value to each output computing node'), | ||||
| Doc('ALL_GATHER', 'each output comp node gets the concatenated ' | Doc('ALL_GATHER', 'each output comp node gets the concatenated ' | ||||
| @@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields( | |||||
| Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), | Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), | ||||
| Doc('GATHER', 'concat inputs to one node'), | Doc('GATHER', 'concat inputs to one node'), | ||||
| Doc('SCATTER', 'scatter input to each output computing node'), | Doc('SCATTER', 'scatter input to each output computing node'), | ||||
| Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'))) | |||||
| Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'), | |||||
| name_field='mode')) | |||||
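
The Mode docs above summarize what each collective produces. A toy, hedged Python sketch of the main modes' semantics on per-rank lists follows, assuming every rank contributes an equally sized chunk; it is illustrative only, not the NCCL-backed implementation.

# Toy semantics of the documented modes; data[r] is rank r's input and each
# function returns the per-rank outputs.
def reduce_sum(data, root=0):        # REDUCE_SUM: sum lands on the root only
    total = [sum(col) for col in zip(*data)]
    return [total if r == root else None for r in range(len(data))]

def broadcast(data, root=0):         # BROADCAST: every rank gets the root's value
    return [data[root] for _ in data]

def all_gather(data):                # ALL_GATHER: every rank gets the concatenation
    cat = [x for part in data for x in part]
    return [cat for _ in data]

def reduce_scatter_sum(data):        # REDUCE_SCATTER_SUM: sum, then split across ranks
    total = [sum(col) for col in zip(*data)]
    n = len(data)
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]

def all_to_all(data):                # ALL_TO_ALL: rank r receives chunk r from every rank
    n = len(data)
    chunk = len(data[0]) // n
    return [[x for part in data for x in part[r * chunk:(r + 1) * chunk]]
            for r in range(n)]

ranks = [[1, 2], [3, 4]]
print(all_gather(ranks))             # [[1, 2, 3, 4], [1, 2, 3, 4]]
print(reduce_scatter_sum(ranks))     # [[4], [6]]
print(all_to_all(ranks))             # [[1, 3], [2, 4]]
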
| (pdef('FakeSerializedDType', | (pdef('FakeSerializedDType', | ||||
| 'HACK: The tag of this param def is actually used for another ' | 'HACK: The tag of this param def is actually used for another ' | ||||