|
|
|
@@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched |
|
|
|
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ |
|
|
|
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ |
|
|
|
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ |
|
|
|
_get_local_rank_helper, _get_local_size_helper |
|
|
|
_get_local_rank_helper, _get_local_size_helper, GlobalComm |
|
|
|
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective |
|
|
|
|
|
|
|
|
|
|
|
@@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", |
|
|
|
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] |
|
|
|
|
|
|
|
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP |
|
|
|
DEFAULT_BACKEND = Backend("hccl") |
|
|
|
|
|
|
|
|
|
|
|
def _get_group(group): |
|
|
|
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" |
|
|
|
@@ -38,11 +36,6 @@ def _get_group(group): |
|
|
|
return group |
|
|
|
|
|
|
|
|
|
|
|
class GlobalComm:
    """
    World communication information.

    All attributes are class-level, so they are read and updated directly on
    the class (``GlobalComm.BACKEND``) without creating an instance.
    """

    # Backend currently in use; starts at the module default and is
    # overwritten by `init` once a backend has been selected.
    BACKEND = DEFAULT_BACKEND

    # Name of the world communication group; overwritten by `init` to match
    # the selected backend.
    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP

    # `init` sets this to True after a backend is initialised. Declaring the
    # default here means reading `GlobalComm.INITED` before `init` has run
    # yields False instead of raising AttributeError.
    INITED = False
|
|
|
|
|
|
|
|
|
|
|
def init(backend_name=None): |
|
|
|
""" |
|
|
|
@@ -78,10 +71,12 @@ def init(backend_name=None): |
|
|
|
init_hccl() |
|
|
|
GlobalComm.BACKEND = Backend("hccl") |
|
|
|
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP |
|
|
|
GlobalComm.INITED = True |
|
|
|
elif backend_name == "nccl": |
|
|
|
init_gpu_collective() |
|
|
|
GlobalComm.BACKEND = Backend("nccl") |
|
|
|
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP |
|
|
|
GlobalComm.INITED = True |
|
|
|
else: |
|
|
|
raise RuntimeError("Backend name {} is not supported.".format(backend_name)) |
|
|
|
|
|
|
|
|