Browse Source

!13857 check whether communication unit has been inited

From: @yao_yf
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/13857/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
b1c86b6a22
2 changed files with 12 additions and 8 deletions
  1. +9
    -0
      mindspore/communication/_comm_helper.py
  2. +3
    -8
      mindspore/communication/management.py

+ 9
- 0
mindspore/communication/_comm_helper.py View File

@@ -77,6 +77,13 @@ class Backend:
raise ValueError("Invalid backend: '{}'".format(name))
return value

# Backend assumed until init() in management.py selects one explicitly.
DEFAULT_BACKEND = Backend("hccl")

class GlobalComm:
    """
    World communication information.

    The attributes are class-level, so they are shared by every importer of
    this module. `init()` in management.py overwrites them after the chosen
    backend has been brought up.
    """
    # Active communication backend; starts out as the HCCL default.
    BACKEND = DEFAULT_BACKEND
    # World communication group name; starts out as the HCCL world group.
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
    # Stays False until init() has initialised a backend. While it is False,
    # the wrapper built by check_parameter_available raises RuntimeError
    # (parameter-server and scheduler roles bypass that check).
    INITED = False

def is_hccl_available():
"""
@@ -114,6 +121,8 @@ def check_parameter_available(func):
def wrapper(*args, **kargs):
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
group = None
if "group" in kargs.keys():
group = kargs.get("group")


+ 3
- 8
mindspore/communication/management.py View File

@@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper
_get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective


@@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]

# World group used when the caller does not name one; defaults to the HCCL group.
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
# Backend assumed before init() picks one explicitly.
DEFAULT_BACKEND = Backend("hccl")


def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
@@ -38,11 +36,6 @@ def _get_group(group):
return group


class GlobalComm:
    """
    World communication information.

    Both attributes are class-level and are reassigned by `init()` once a
    backend ("hccl" or "nccl") has been initialised.
    """
    # Active communication backend; starts out as the module default.
    BACKEND = DEFAULT_BACKEND
    # World communication group name; starts out as the module default.
    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP


def init(backend_name=None):
"""
@@ -78,10 +71,12 @@ def init(backend_name=None):
init_hccl()
GlobalComm.BACKEND = Backend("hccl")
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name))



Loading…
Cancel
Save