From ba3a1f4ffef9746eaa5d5bae06424fc2ff93c739 Mon Sep 17 00:00:00 2001 From: lichenever Date: Fri, 24 Apr 2020 15:39:38 +0800 Subject: [PATCH] change get_group to internal interface --- mindspore/communication/__init__.py | 4 ++-- mindspore/communication/management.py | 12 ++++++------ mindspore/ops/operations/comm_ops.py | 28 +++++++++++++-------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/mindspore/communication/__init__.py b/mindspore/communication/__init__.py index 65078f6820..26acc53d91 100644 --- a/mindspore/communication/__init__.py +++ b/mindspore/communication/__init__.py @@ -17,12 +17,12 @@ Collective communication interface. """ from .management import GlobalComm, init, release, get_rank, get_group_size, get_world_rank_from_group_rank, \ - get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, get_group, \ + get_group_rank_from_world_rank, create_group, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ get_local_rank, get_local_rank_size, destroy_group __all__ = [ "GlobalComm", "init", "release", "get_rank", "get_group_size", "get_world_rank_from_group_rank", - "get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP", "get_group", + "get_group_rank_from_world_rank", "create_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP", "get_local_rank", "get_local_rank_size", "destroy_group" ] diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index 7208538a07..1cd60fe2e5 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -21,7 +21,7 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective -__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_group", +__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_local_rank_size", "get_world_rank_from_group_rank", "get_group_rank_from_world_rank", "create_group", "destroy_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] @@ -30,7 +30,7 @@ DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_BACKEND = Backend("hccl") -def get_group(group): +def _get_group(group): """Get the global world group if the group is default world comm group.""" if group == DEFAULT_WORLD_COMM_GROUP: return GlobalComm.WORLD_COMM_GROUP @@ -100,7 +100,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): ValueError: If backend is invalid. RuntimeError: If hccl/nccl is not available or nccl not supports. """ - return _get_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND) + return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): @@ -121,7 +121,7 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): ValueError: If backend is invalid. RuntimeError: If hccl/nccl is not available or nccl not supports. """ - return _get_local_rank_helper(group=get_group(group), backend=GlobalComm.BACKEND) + return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): @@ -139,7 +139,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): ValueError: If backend is invalid. RuntimeError: If hccl/nccl is not available or nccl not supports. """ - return _get_size_helper(group=get_group(group), backend=GlobalComm.BACKEND) + return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): @@ -160,7 +160,7 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): ValueError: If backend is invalid. RuntimeError: If hccl/nccl is not available or nccl not supports. """ - return _get_local_size_helper(group=get_group(group), backend=GlobalComm.BACKEND) + return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) def get_world_rank_from_group_rank(group, group_rank_id): diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index fbad5b49d3..969091de97 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -17,7 +17,7 @@ from ..._checkparam import Validator as validator from ..._checkparam import Rel -from ...communication.management import get_rank, get_group_size, GlobalComm, get_group +from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer): raise TypeError("The operation of AllReduce should be str.") if op == ReduceOp.PROD: raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.") - if not isinstance(get_group(group), str): + if not isinstance(_get_group(group), str): raise TypeError("The group of AllReduce should be str.") self.op = op - self.add_prim_attr('group', get_group(group)) + self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) def vm_impl(self, x): @@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer): @prim_attr_register def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): - validator.check_value_type('group', get_group(group), (str,), self.name) - self.rank = get_rank(get_group(group)) - self.rank_size = get_group_size(get_group(group)) + validator.check_value_type('group', _get_group(group), (str,), self.name) + self.rank = get_rank(_get_group(group)) + self.rank_size = get_group_size(_get_group(group)) validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) self.add_prim_attr('rank_size', self.rank_size) - self.add_prim_attr('group', get_group(group)) + self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape): x_shape[0] = x_shape[0] * self.rank_size @@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer): @prim_attr_register def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) - validator.check_value_type('group', get_group(group), (str,), self.name) + validator.check_value_type('group', _get_group(group), (str,), self.name) self.op = op - self.rank_size = get_group_size(get_group(group)) + self.rank_size = get_group_size(_get_group(group)) self.add_prim_attr('rank_size', self.rank_size) - self.add_prim_attr('group', get_group(group)) + self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape): if x_shape[0] % self.rank_size != 0: @@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer): @prim_attr_register def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): validator.check_value_type('root_rank', root_rank, (int,), self.name) - validator.check_value_type('group', get_group(group), (str,), self.name) - self.add_prim_attr('group', get_group(group)) + validator.check_value_type('group', _get_group(group), (str,), self.name) + self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape): return x_shape @@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer): @prim_attr_register def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): """init AlltoAll""" - validator.check_value_type('group', get_group(group), (str,), self.name) + validator.check_value_type('group', _get_group(group), (str,), self.name) self.split_count = split_count self.split_dim = split_dim self.concat_dim = concat_dim - self.add_prim_attr('group', get_group(group)) + self.add_prim_attr('group', _get_group(group)) def infer_shape(self, x_shape): x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count